Skip to main content

openclaw_core/validation/
mod.rs

//! Input validation and sanitization.
//!
//! Defense-in-depth: validate all external inputs before processing.
4
5use thiserror::Error;
6use unicode_normalization::UnicodeNormalization;
7
8/// Validation error types.
9#[derive(Error, Debug)]
10pub enum ValidationError {
11    /// Input exceeds maximum allowed length.
12    #[error("Input exceeds maximum length ({max} bytes, got {actual})")]
13    TooLong {
14        /// Maximum allowed length.
15        max: usize,
16        /// Actual input length.
17        actual: usize,
18    },
19
20    /// Invalid UTF-8 encoding.
21    #[error("Invalid UTF-8 encoding")]
22    InvalidUtf8,
23
24    /// Disallowed characters in input.
25    #[error("Disallowed characters in input")]
26    DisallowedChars,
27
28    /// Input failed schema validation.
29    #[error("Input failed schema validation: {0}")]
30    SchemaViolation(String),
31
32    /// JSON parsing error.
33    #[error("JSON error: {0}")]
34    JsonError(#[from] serde_json::Error),
35}
36
/// Size limits, one per category of external input.
///
/// Sizes use binary multiples (1 KiB = 1024, 1 MiB = 1024 KiB).
pub mod limits {
    /// One kibibyte.
    const KIB: usize = 1024;
    /// One mebibyte.
    const MIB: usize = 1024 * KIB;

    /// Upper bound on message content: 64 KiB.
    pub const MAX_MESSAGE_LENGTH: usize = 64 * KIB;

    /// Upper bound on serialized tool parameters: 1 MiB.
    pub const MAX_TOOL_PARAMS_SIZE: usize = MIB;

    /// Upper bound on a skill file: 256 KiB.
    pub const MAX_SKILL_FILE_SIZE: usize = 256 * KIB;

    /// Upper bound on a config file: 1 MiB.
    pub const MAX_CONFIG_FILE_SIZE: usize = MIB;

    /// Upper bound on an attachment: 50 MiB.
    pub const MAX_ATTACHMENT_SIZE: usize = 50 * MIB;

    /// Upper bound on JSON nesting depth (a level count, not a size).
    pub const MAX_JSON_DEPTH: usize = 32;
}
57
58/// Validate and sanitize message content from channels.
59///
60/// Performs:
61/// 1. Length check (prevent memory exhaustion)
62/// 2. Strip null bytes and control chars (except newlines/tabs)
63/// 3. Unicode normalization (NFKC - prevent homograph attacks)
64///
65/// # Errors
66///
67/// Returns `ValidationError::TooLong` if input exceeds `max_len`.
68pub fn validate_message_content(input: &str, max_len: usize) -> Result<String, ValidationError> {
69    // 1. Length check (prevent memory exhaustion)
70    if input.len() > max_len {
71        return Err(ValidationError::TooLong {
72            max: max_len,
73            actual: input.len(),
74        });
75    }
76
77    // 2. Strip null bytes and control chars (except newlines/tabs)
78    let sanitized: String = input
79        .chars()
80        .filter(|c| !c.is_control() || *c == '\n' || *c == '\t' || *c == '\r')
81        .collect();
82
83    // 3. Normalize unicode (NFKC - prevent homograph attacks in allowlists)
84    let normalized: String = sanitized.nfkc().collect();
85
86    Ok(normalized)
87}
88
89/// Validate tool parameters against a JSON schema.
90///
91/// # Errors
92///
93/// Returns `ValidationError::SchemaViolation` if validation fails.
94pub fn validate_tool_params(
95    params: &serde_json::Value,
96    schema: &serde_json::Value,
97) -> Result<(), ValidationError> {
98    // Check size limit first
99    let size = serde_json::to_string(params)?.len();
100    if size > limits::MAX_TOOL_PARAMS_SIZE {
101        return Err(ValidationError::TooLong {
102            max: limits::MAX_TOOL_PARAMS_SIZE,
103            actual: size,
104        });
105    }
106
107    // Check JSON depth
108    check_json_depth(params, 0, limits::MAX_JSON_DEPTH)?;
109
110    // JSON Schema validation would go here
111    // For now, we do basic structural validation
112    validate_json_structure(params, schema)?;
113
114    Ok(())
115}
116
117/// Check JSON nesting depth to prevent stack overflow.
118fn check_json_depth(
119    value: &serde_json::Value,
120    depth: usize,
121    max: usize,
122) -> Result<(), ValidationError> {
123    if depth > max {
124        return Err(ValidationError::SchemaViolation(format!(
125            "JSON nesting depth exceeds maximum ({max})"
126        )));
127    }
128
129    match value {
130        serde_json::Value::Array(arr) => {
131            for item in arr {
132                check_json_depth(item, depth + 1, max)?;
133            }
134        }
135        serde_json::Value::Object(obj) => {
136            for (_, item) in obj {
137                check_json_depth(item, depth + 1, max)?;
138            }
139        }
140        _ => {}
141    }
142
143    Ok(())
144}
145
146/// Basic JSON structure validation against schema.
147fn validate_json_structure(
148    params: &serde_json::Value,
149    schema: &serde_json::Value,
150) -> Result<(), ValidationError> {
151    let schema_type = schema.get("type").and_then(|t| t.as_str());
152
153    match schema_type {
154        Some("object") => {
155            if !params.is_object() {
156                return Err(ValidationError::SchemaViolation(
157                    "Expected object".to_string(),
158                ));
159            }
160
161            // Check required fields
162            if let Some(required) = schema.get("required").and_then(|r| r.as_array()) {
163                let obj = params.as_object().unwrap();
164                for req in required {
165                    if let Some(field) = req.as_str() {
166                        if !obj.contains_key(field) {
167                            return Err(ValidationError::SchemaViolation(format!(
168                                "Missing required field: {field}"
169                            )));
170                        }
171                    }
172                }
173            }
174        }
175        Some("array") => {
176            if !params.is_array() {
177                return Err(ValidationError::SchemaViolation(
178                    "Expected array".to_string(),
179                ));
180            }
181        }
182        Some("string") => {
183            if !params.is_string() {
184                return Err(ValidationError::SchemaViolation(
185                    "Expected string".to_string(),
186                ));
187            }
188        }
189        Some("number" | "integer") => {
190            if !params.is_number() {
191                return Err(ValidationError::SchemaViolation(
192                    "Expected number".to_string(),
193                ));
194            }
195        }
196        Some("boolean") => {
197            if !params.is_boolean() {
198                return Err(ValidationError::SchemaViolation(
199                    "Expected boolean".to_string(),
200                ));
201            }
202        }
203        _ => {}
204    }
205
206    Ok(())
207}
208
209/// Validate a file path to prevent path traversal attacks.
210///
211/// # Errors
212///
213/// Returns error if path contains traversal sequences.
214pub fn validate_path(path: &str) -> Result<(), ValidationError> {
215    if path.contains("..") || path.contains('\0') {
216        return Err(ValidationError::DisallowedChars);
217    }
218    Ok(())
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_message_content() {
        // Normal content passes through unchanged.
        assert_eq!(
            validate_message_content("Hello, world!", 100).unwrap(),
            "Hello, world!"
        );

        // NUL and other control characters are stripped.
        assert_eq!(
            validate_message_content("Hello\x00World", 100).unwrap(),
            "HelloWorld"
        );
        assert_eq!(
            validate_message_content("\u{1b}[31mred", 100).unwrap(),
            "[31mred"
        );

        // Newlines, carriage returns and tabs are preserved.
        assert_eq!(
            validate_message_content("Line1\nLine2", 100).unwrap(),
            "Line1\nLine2"
        );
        assert_eq!(
            validate_message_content("a\r\nb\tc", 100).unwrap(),
            "a\r\nb\tc"
        );

        // Exactly at the limit is accepted; one past it is not.
        assert!(validate_message_content(&"x".repeat(100), 100).is_ok());
        let result = validate_message_content(&"x".repeat(200), 100);
        assert!(matches!(result, Err(ValidationError::TooLong { .. })));
    }

    #[test]
    fn test_unicode_normalization() {
        // U+FB01 LATIN SMALL LIGATURE FI is folded to "fi" by NFKC. The escape
        // keeps the test meaningful even if the file is re-encoded.
        assert_eq!(validate_message_content("\u{FB01}", 100).unwrap(), "fi");

        // Full-width Latin letters fold to their ASCII forms.
        assert_eq!(
            validate_message_content("\u{FF21}\u{FF22}", 100).unwrap(),
            "AB"
        );
    }

    #[test]
    fn test_validate_path() {
        assert!(validate_path("/home/user/file.txt").is_ok());
        assert!(validate_path("file.txt").is_ok());
        assert!(validate_path("../etc/passwd").is_err());
        assert!(validate_path("a/../b").is_err());
        assert!(validate_path("..\\windows").is_err());
        assert!(validate_path("/home/user/\0file").is_err());
    }

    #[test]
    fn test_validate_tool_params() {
        let schema = serde_json::json!({"type": "object", "required": ["name"]});

        assert!(validate_tool_params(&serde_json::json!({"name": "x"}), &schema).is_ok());
        assert!(matches!(
            validate_tool_params(&serde_json::json!({}), &schema),
            Err(ValidationError::SchemaViolation(_))
        ));
        assert!(matches!(
            validate_tool_params(&serde_json::json!([1, 2]), &schema),
            Err(ValidationError::SchemaViolation(_))
        ));

        // Serialized size includes the surrounding quotes, so this is over the limit.
        let big = serde_json::Value::String("x".repeat(limits::MAX_TOOL_PARAMS_SIZE));
        assert!(matches!(
            validate_tool_params(&big, &serde_json::json!({})),
            Err(ValidationError::TooLong { .. })
        ));
    }

    #[test]
    fn test_json_depth() {
        let shallow = serde_json::json!({"a": {"b": "c"}});
        assert!(check_json_depth(&shallow, 0, 10).is_ok());

        // Create deeply nested JSON.
        let mut deep = serde_json::json!("leaf");
        for _ in 0..50 {
            deep = serde_json::json!({"nested": deep});
        }
        assert!(check_json_depth(&deep, 0, 32).is_err());

        // Boundary: a leaf at exactly depth 32 passes; one more level fails.
        let mut at_limit = serde_json::json!("leaf");
        for _ in 0..32 {
            at_limit = serde_json::json!({"nested": at_limit});
        }
        assert!(check_json_depth(&at_limit, 0, 32).is_ok());
        let over = serde_json::json!({"nested": at_limit});
        assert!(check_json_depth(&over, 0, 32).is_err());
    }
}