claude_agent/config/
validator.rs

1//! Configuration Validation Layer
2//!
3//! Validates configuration values before use.
4
5use std::collections::HashMap;
6use std::ops::RangeInclusive;
7
8use serde_json::Value;
9
10use super::{ConfigError, ConfigResult, ValidationErrors};
11
12pub type ValidationFn = Box<dyn Fn(&Value) -> Result<(), String> + Send + Sync>;
13
14pub struct ConfigValidator {
15    required_keys: Vec<String>,
16    type_rules: HashMap<String, ValueType>,
17    range_rules: HashMap<String, RangeInclusive<i64>>,
18    pattern_rules: HashMap<String, regex::Regex>,
19    custom_rules: HashMap<String, ValidationFn>,
20}
21
/// JSON value categories that a type rule can demand.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValueType {
    /// A JSON string.
    String,
    /// Any JSON number, integer or floating point.
    Number,
    /// `true` or `false`.
    Boolean,
    /// A JSON array.
    Array,
    /// A JSON object.
    Object,
}
30
31impl ValueType {
32    fn matches(&self, value: &Value) -> bool {
33        match self {
34            ValueType::String => value.is_string(),
35            ValueType::Number => value.is_number(),
36            ValueType::Boolean => value.is_boolean(),
37            ValueType::Array => value.is_array(),
38            ValueType::Object => value.is_object(),
39        }
40    }
41
42    fn name(&self) -> &'static str {
43        match self {
44            ValueType::String => "string",
45            ValueType::Number => "number",
46            ValueType::Boolean => "boolean",
47            ValueType::Array => "array",
48            ValueType::Object => "object",
49        }
50    }
51}
52
53impl ConfigValidator {
54    pub fn new() -> Self {
55        Self {
56            required_keys: Vec::new(),
57            type_rules: HashMap::new(),
58            range_rules: HashMap::new(),
59            pattern_rules: HashMap::new(),
60            custom_rules: HashMap::new(),
61        }
62    }
63
64    pub fn require(mut self, key: impl Into<String>) -> Self {
65        self.required_keys.push(key.into());
66        self
67    }
68
69    pub fn require_many(mut self, keys: impl IntoIterator<Item = impl Into<String>>) -> Self {
70        self.required_keys.extend(keys.into_iter().map(Into::into));
71        self
72    }
73
74    pub fn expect_type(mut self, key: impl Into<String>, value_type: ValueType) -> Self {
75        self.type_rules.insert(key.into(), value_type);
76        self
77    }
78
79    pub fn expect_range(mut self, key: impl Into<String>, range: RangeInclusive<i64>) -> Self {
80        self.range_rules.insert(key.into(), range);
81        self
82    }
83
84    pub fn expect_pattern(mut self, key: impl Into<String>, pattern: &str) -> ConfigResult<Self> {
85        let key = key.into();
86        let regex = regex::Regex::new(pattern).map_err(|e| ConfigError::InvalidValue {
87            key: key.clone(),
88            message: format!("invalid regex pattern: {}", e),
89        })?;
90        self.pattern_rules.insert(key, regex);
91        Ok(self)
92    }
93
94    pub fn custom<F>(mut self, key: impl Into<String>, validator: F) -> Self
95    where
96        F: Fn(&Value) -> Result<(), String> + Send + Sync + 'static,
97    {
98        self.custom_rules.insert(key.into(), Box::new(validator));
99        self
100    }
101
102    pub fn validate(&self, config: &Value) -> ConfigResult<()> {
103        let mut errors = Vec::new();
104
105        for key in &self.required_keys {
106            if get_nested(config, key).is_none() {
107                errors.push(ConfigError::NotFound { key: key.clone() });
108            }
109        }
110
111        for (key, expected_type) in &self.type_rules {
112            if let Some(value) = get_nested(config, key)
113                && !expected_type.matches(value)
114            {
115                errors.push(ConfigError::InvalidValue {
116                    key: key.clone(),
117                    message: format!(
118                        "expected {}, got {}",
119                        expected_type.name(),
120                        value_type_name(value)
121                    ),
122                });
123            }
124        }
125
126        for (key, range) in &self.range_rules {
127            if let Some(value) = get_nested(config, key)
128                && let Some(num) = value.as_i64()
129                && !range.contains(&num)
130            {
131                errors.push(ConfigError::InvalidValue {
132                    key: key.clone(),
133                    message: format!(
134                        "value {} not in range {}..={}",
135                        num,
136                        range.start(),
137                        range.end()
138                    ),
139                });
140            }
141        }
142
143        for (key, pattern) in &self.pattern_rules {
144            if let Some(value) = get_nested(config, key)
145                && let Some(s) = value.as_str()
146                && !pattern.is_match(s)
147            {
148                errors.push(ConfigError::InvalidValue {
149                    key: key.clone(),
150                    message: format!("value '{}' does not match pattern", s),
151                });
152            }
153        }
154
155        for (key, validator) in &self.custom_rules {
156            if let Some(value) = get_nested(config, key)
157                && let Err(msg) = validator(value)
158            {
159                errors.push(ConfigError::InvalidValue {
160                    key: key.clone(),
161                    message: msg,
162                });
163            }
164        }
165
166        if errors.is_empty() {
167            Ok(())
168        } else {
169            Err(ConfigError::ValidationErrors(ValidationErrors(errors)))
170        }
171    }
172
173    pub fn validate_partial(&self, config: &Value) -> Vec<ConfigError> {
174        let mut errors = Vec::new();
175
176        for key in &self.required_keys {
177            if get_nested(config, key).is_none() {
178                errors.push(ConfigError::NotFound { key: key.clone() });
179            }
180        }
181
182        for (key, expected_type) in &self.type_rules {
183            if let Some(value) = get_nested(config, key)
184                && !expected_type.matches(value)
185            {
186                errors.push(ConfigError::InvalidValue {
187                    key: key.clone(),
188                    message: format!(
189                        "expected {}, got {}",
190                        expected_type.name(),
191                        value_type_name(value)
192                    ),
193                });
194            }
195        }
196
197        for (key, range) in &self.range_rules {
198            if let Some(value) = get_nested(config, key)
199                && let Some(num) = value.as_i64()
200                && !range.contains(&num)
201            {
202                errors.push(ConfigError::InvalidValue {
203                    key: key.clone(),
204                    message: format!(
205                        "value {} not in range {}..={}",
206                        num,
207                        range.start(),
208                        range.end()
209                    ),
210                });
211            }
212        }
213
214        for (key, pattern) in &self.pattern_rules {
215            if let Some(value) = get_nested(config, key)
216                && let Some(s) = value.as_str()
217                && !pattern.is_match(s)
218            {
219                errors.push(ConfigError::InvalidValue {
220                    key: key.clone(),
221                    message: format!("value '{}' does not match pattern", s),
222                });
223            }
224        }
225
226        for (key, validator) in &self.custom_rules {
227            if let Some(value) = get_nested(config, key)
228                && let Err(msg) = validator(value)
229            {
230                errors.push(ConfigError::InvalidValue {
231                    key: key.clone(),
232                    message: msg,
233                });
234            }
235        }
236
237        errors
238    }
239}
240
241impl Default for ConfigValidator {
242    fn default() -> Self {
243        Self::new()
244    }
245}
246
247fn get_nested<'a>(config: &'a Value, key: &str) -> Option<&'a Value> {
248    let parts: Vec<&str> = key.split('.').collect();
249    let mut current = config;
250
251    for part in parts {
252        current = current.get(part)?;
253    }
254
255    Some(current)
256}
257
258fn value_type_name(value: &Value) -> &'static str {
259    match value {
260        Value::Null => "null",
261        Value::Bool(_) => "boolean",
262        Value::Number(_) => "number",
263        Value::String(_) => "string",
264        Value::Array(_) => "array",
265        Value::Object(_) => "object",
266    }
267}
268
269#[cfg(test)]
270mod tests {
271    use super::*;
272    use serde_json::json;
273
274    #[test]
275    fn test_required_keys() {
276        let validator = ConfigValidator::new().require("api_key").require("model");
277
278        let config = json!({
279            "api_key": "sk-test",
280            "model": "claude-sonnet-4-5"
281        });
282        assert!(validator.validate(&config).is_ok());
283
284        let missing = json!({
285            "api_key": "sk-test"
286        });
287        assert!(validator.validate(&missing).is_err());
288    }
289
290    #[test]
291    fn test_type_validation() {
292        let validator = ConfigValidator::new()
293            .expect_type("port", ValueType::Number)
294            .expect_type("enabled", ValueType::Boolean);
295
296        let valid = json!({
297            "port": 8080,
298            "enabled": true
299        });
300        assert!(validator.validate(&valid).is_ok());
301
302        let invalid = json!({
303            "port": "8080",
304            "enabled": true
305        });
306        assert!(validator.validate(&invalid).is_err());
307    }
308
309    #[test]
310    fn test_range_validation() {
311        let validator = ConfigValidator::new()
312            .expect_range("port", 1..=65535)
313            .expect_range("timeout", 1..=300);
314
315        let valid = json!({
316            "port": 8080,
317            "timeout": 30
318        });
319        assert!(validator.validate(&valid).is_ok());
320
321        let invalid = json!({
322            "port": 70000,
323            "timeout": 30
324        });
325        assert!(validator.validate(&invalid).is_err());
326    }
327
328    #[test]
329    fn test_pattern_validation() {
330        let validator = ConfigValidator::new()
331            .expect_pattern("api_key", r"^sk-[a-zA-Z0-9]+$")
332            .unwrap();
333
334        let valid = json!({
335            "api_key": "sk-test123"
336        });
337        assert!(validator.validate(&valid).is_ok());
338
339        let invalid = json!({
340            "api_key": "invalid-key"
341        });
342        assert!(validator.validate(&invalid).is_err());
343    }
344
345    #[test]
346    fn test_nested_keys() {
347        let validator = ConfigValidator::new()
348            .require("database.host")
349            .expect_type("database.port", ValueType::Number);
350
351        let config = json!({
352            "database": {
353                "host": "localhost",
354                "port": 5432
355            }
356        });
357        assert!(validator.validate(&config).is_ok());
358    }
359
360    #[test]
361    fn test_custom_validator() {
362        let validator = ConfigValidator::new().custom("urls", |v| {
363            if let Some(arr) = v.as_array()
364                && arr.is_empty()
365            {
366                return Err("urls cannot be empty".to_string());
367            }
368            Ok(())
369        });
370
371        let valid = json!({
372            "urls": ["http://example.com"]
373        });
374        assert!(validator.validate(&valid).is_ok());
375
376        let invalid = json!({
377            "urls": []
378        });
379        assert!(validator.validate(&invalid).is_err());
380    }
381
382    #[test]
383    fn test_require_many() {
384        let validator = ConfigValidator::new().require_many(["host", "port", "database"]);
385
386        let config = json!({
387            "host": "localhost",
388            "port": 5432,
389            "database": "mydb"
390        });
391        assert!(validator.validate(&config).is_ok());
392    }
393
394    #[test]
395    fn test_validate_partial() {
396        let validator = ConfigValidator::new()
397            .require("a")
398            .require("b")
399            .require("c");
400
401        let config = json!({
402            "a": 1
403        });
404
405        let errors = validator.validate_partial(&config);
406        assert_eq!(errors.len(), 2);
407    }
408}