Skip to main content

claude_agent/config/
validator.rs

//! Configuration Validation Layer
//!
//! Validates configuration values (a `serde_json::Value` tree) against
//! required-key, type, range, regex-pattern, and custom rules before use.
4
5use std::collections::HashMap;
6use std::ops::RangeInclusive;
7
8use serde_json::Value;
9
10use super::{ConfigError, ConfigResult, ValidationErrors};
11
12pub type ValidationFn = Box<dyn Fn(&Value) -> Result<(), String> + Send + Sync>;
13
14pub struct ConfigValidator {
15    required_keys: Vec<String>,
16    type_rules: HashMap<String, ValueType>,
17    range_rules: HashMap<String, RangeInclusive<i64>>,
18    pattern_rules: HashMap<String, regex::Regex>,
19    custom_rules: HashMap<String, ValidationFn>,
20}
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq)]
23pub enum ValueType {
24    String,
25    Number,
26    Boolean,
27    Array,
28    Object,
29}
30
31impl ValueType {
32    fn matches(&self, value: &Value) -> bool {
33        match self {
34            ValueType::String => value.is_string(),
35            ValueType::Number => value.is_number(),
36            ValueType::Boolean => value.is_boolean(),
37            ValueType::Array => value.is_array(),
38            ValueType::Object => value.is_object(),
39        }
40    }
41
42    fn name(&self) -> &'static str {
43        match self {
44            ValueType::String => "string",
45            ValueType::Number => "number",
46            ValueType::Boolean => "boolean",
47            ValueType::Array => "array",
48            ValueType::Object => "object",
49        }
50    }
51}
52
53impl ConfigValidator {
54    pub fn new() -> Self {
55        Self {
56            required_keys: Vec::new(),
57            type_rules: HashMap::new(),
58            range_rules: HashMap::new(),
59            pattern_rules: HashMap::new(),
60            custom_rules: HashMap::new(),
61        }
62    }
63
64    pub fn require(mut self, key: impl Into<String>) -> Self {
65        self.required_keys.push(key.into());
66        self
67    }
68
69    pub fn require_many(mut self, keys: impl IntoIterator<Item = impl Into<String>>) -> Self {
70        self.required_keys.extend(keys.into_iter().map(Into::into));
71        self
72    }
73
74    pub fn expect_type(mut self, key: impl Into<String>, value_type: ValueType) -> Self {
75        self.type_rules.insert(key.into(), value_type);
76        self
77    }
78
79    pub fn expect_range(mut self, key: impl Into<String>, range: RangeInclusive<i64>) -> Self {
80        self.range_rules.insert(key.into(), range);
81        self
82    }
83
84    pub fn expect_pattern(mut self, key: impl Into<String>, pattern: &str) -> ConfigResult<Self> {
85        let key = key.into();
86        let regex = regex::Regex::new(pattern).map_err(|e| ConfigError::InvalidValue {
87            key: key.clone(),
88            message: format!("Invalid regex pattern: {}", e),
89        })?;
90        self.pattern_rules.insert(key, regex);
91        Ok(self)
92    }
93
94    pub fn custom<F>(mut self, key: impl Into<String>, validator: F) -> Self
95    where
96        F: Fn(&Value) -> Result<(), String> + Send + Sync + 'static,
97    {
98        self.custom_rules.insert(key.into(), Box::new(validator));
99        self
100    }
101
102    pub fn validate(&self, config: &Value) -> ConfigResult<()> {
103        let errors = self.collect_errors(config);
104        if errors.is_empty() {
105            Ok(())
106        } else {
107            Err(ConfigError::ValidationErrors(ValidationErrors(errors)))
108        }
109    }
110
111    pub fn validate_partial(&self, config: &Value) -> Vec<ConfigError> {
112        self.collect_errors(config)
113    }
114
115    fn collect_errors(&self, config: &Value) -> Vec<ConfigError> {
116        let mut errors = Vec::new();
117
118        for key in &self.required_keys {
119            if get_nested(config, key).is_none() {
120                errors.push(ConfigError::NotFound { key: key.clone() });
121            }
122        }
123
124        for (key, expected_type) in &self.type_rules {
125            if let Some(value) = get_nested(config, key)
126                && !expected_type.matches(value)
127            {
128                errors.push(ConfigError::InvalidValue {
129                    key: key.clone(),
130                    message: format!(
131                        "expected {}, got {}",
132                        expected_type.name(),
133                        value_type_name(value)
134                    ),
135                });
136            }
137        }
138
139        for (key, range) in &self.range_rules {
140            if let Some(value) = get_nested(config, key)
141                && let Some(num) = value.as_i64()
142                && !range.contains(&num)
143            {
144                errors.push(ConfigError::InvalidValue {
145                    key: key.clone(),
146                    message: format!(
147                        "value {} not in range {}..={}",
148                        num,
149                        range.start(),
150                        range.end()
151                    ),
152                });
153            }
154        }
155
156        for (key, pattern) in &self.pattern_rules {
157            if let Some(value) = get_nested(config, key)
158                && let Some(s) = value.as_str()
159                && !pattern.is_match(s)
160            {
161                errors.push(ConfigError::InvalidValue {
162                    key: key.clone(),
163                    message: format!("Value '{}' does not match pattern", s),
164                });
165            }
166        }
167
168        for (key, validator) in &self.custom_rules {
169            if let Some(value) = get_nested(config, key)
170                && let Err(msg) = validator(value)
171            {
172                errors.push(ConfigError::InvalidValue {
173                    key: key.clone(),
174                    message: msg,
175                });
176            }
177        }
178
179        errors
180    }
181}
182
183impl Default for ConfigValidator {
184    fn default() -> Self {
185        Self::new()
186    }
187}
188
189fn get_nested<'a>(config: &'a Value, key: &str) -> Option<&'a Value> {
190    let parts: Vec<&str> = key.split('.').collect();
191    let mut current = config;
192
193    for part in parts {
194        current = current.get(part)?;
195    }
196
197    Some(current)
198}
199
200fn value_type_name(value: &Value) -> &'static str {
201    match value {
202        Value::Null => "null",
203        Value::Bool(_) => "boolean",
204        Value::Number(_) => "number",
205        Value::String(_) => "string",
206        Value::Array(_) => "array",
207        Value::Object(_) => "object",
208    }
209}
210
211#[cfg(test)]
212mod tests {
213    use super::*;
214    use serde_json::json;
215
216    #[test]
217    fn test_required_keys() {
218        let validator = ConfigValidator::new().require("api_key").require("model");
219
220        let config = json!({
221            "api_key": "sk-test",
222            "model": "claude-sonnet-4-5"
223        });
224        assert!(validator.validate(&config).is_ok());
225
226        let missing = json!({
227            "api_key": "sk-test"
228        });
229        assert!(validator.validate(&missing).is_err());
230    }
231
232    #[test]
233    fn test_type_validation() {
234        let validator = ConfigValidator::new()
235            .expect_type("port", ValueType::Number)
236            .expect_type("enabled", ValueType::Boolean);
237
238        let valid = json!({
239            "port": 8080,
240            "enabled": true
241        });
242        assert!(validator.validate(&valid).is_ok());
243
244        let invalid = json!({
245            "port": "8080",
246            "enabled": true
247        });
248        assert!(validator.validate(&invalid).is_err());
249    }
250
251    #[test]
252    fn test_range_validation() {
253        let validator = ConfigValidator::new()
254            .expect_range("port", 1..=65535)
255            .expect_range("timeout", 1..=300);
256
257        let valid = json!({
258            "port": 8080,
259            "timeout": 30
260        });
261        assert!(validator.validate(&valid).is_ok());
262
263        let invalid = json!({
264            "port": 70000,
265            "timeout": 30
266        });
267        assert!(validator.validate(&invalid).is_err());
268    }
269
270    #[test]
271    fn test_pattern_validation() {
272        let validator = ConfigValidator::new()
273            .expect_pattern("api_key", r"^sk-[a-zA-Z0-9]+$")
274            .unwrap();
275
276        let valid = json!({
277            "api_key": "sk-test123"
278        });
279        assert!(validator.validate(&valid).is_ok());
280
281        let invalid = json!({
282            "api_key": "invalid-key"
283        });
284        assert!(validator.validate(&invalid).is_err());
285    }
286
287    #[test]
288    fn test_nested_keys() {
289        let validator = ConfigValidator::new()
290            .require("database.host")
291            .expect_type("database.port", ValueType::Number);
292
293        let config = json!({
294            "database": {
295                "host": "localhost",
296                "port": 5432
297            }
298        });
299        assert!(validator.validate(&config).is_ok());
300    }
301
302    #[test]
303    fn test_custom_validator() {
304        let validator = ConfigValidator::new().custom("urls", |v| {
305            if let Some(arr) = v.as_array()
306                && arr.is_empty()
307            {
308                return Err("urls cannot be empty".to_string());
309            }
310            Ok(())
311        });
312
313        let valid = json!({
314            "urls": ["http://example.com"]
315        });
316        assert!(validator.validate(&valid).is_ok());
317
318        let invalid = json!({
319            "urls": []
320        });
321        assert!(validator.validate(&invalid).is_err());
322    }
323
324    #[test]
325    fn test_require_many() {
326        let validator = ConfigValidator::new().require_many(["host", "port", "database"]);
327
328        let config = json!({
329            "host": "localhost",
330            "port": 5432,
331            "database": "mydb"
332        });
333        assert!(validator.validate(&config).is_ok());
334    }
335
336    #[test]
337    fn test_validate_partial() {
338        let validator = ConfigValidator::new()
339            .require("a")
340            .require("b")
341            .require("c");
342
343        let config = json!({
344            "a": 1
345        });
346
347        let errors = validator.validate_partial(&config);
348        assert_eq!(errors.len(), 2);
349    }
350}