1use std::collections::HashMap;
6use std::ops::RangeInclusive;
7
8use serde_json::Value;
9
10use super::{ConfigError, ConfigResult, ValidationErrors};
11
12pub type ValidationFn = Box<dyn Fn(&Value) -> Result<(), String> + Send + Sync>;
13
14pub struct ConfigValidator {
15 required_keys: Vec<String>,
16 type_rules: HashMap<String, ValueType>,
17 range_rules: HashMap<String, RangeInclusive<i64>>,
18 pattern_rules: HashMap<String, regex::Regex>,
19 custom_rules: HashMap<String, ValidationFn>,
20}
21
/// JSON value categories that a type rule can demand.
///
/// There is deliberately no `Null` variant: a type rule always names a
/// concrete, non-null kind of value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ValueType {
    /// A JSON string.
    String,
    /// Any JSON number, integer or floating point.
    Number,
    /// `true` or `false`.
    Boolean,
    /// A JSON array.
    Array,
    /// A JSON object.
    Object,
}
30
31impl ValueType {
32 fn matches(&self, value: &Value) -> bool {
33 match self {
34 ValueType::String => value.is_string(),
35 ValueType::Number => value.is_number(),
36 ValueType::Boolean => value.is_boolean(),
37 ValueType::Array => value.is_array(),
38 ValueType::Object => value.is_object(),
39 }
40 }
41
42 fn name(&self) -> &'static str {
43 match self {
44 ValueType::String => "string",
45 ValueType::Number => "number",
46 ValueType::Boolean => "boolean",
47 ValueType::Array => "array",
48 ValueType::Object => "object",
49 }
50 }
51}
52
53impl ConfigValidator {
54 pub fn new() -> Self {
55 Self {
56 required_keys: Vec::new(),
57 type_rules: HashMap::new(),
58 range_rules: HashMap::new(),
59 pattern_rules: HashMap::new(),
60 custom_rules: HashMap::new(),
61 }
62 }
63
64 pub fn require(mut self, key: impl Into<String>) -> Self {
65 self.required_keys.push(key.into());
66 self
67 }
68
69 pub fn require_many(mut self, keys: impl IntoIterator<Item = impl Into<String>>) -> Self {
70 self.required_keys.extend(keys.into_iter().map(Into::into));
71 self
72 }
73
74 pub fn expect_type(mut self, key: impl Into<String>, value_type: ValueType) -> Self {
75 self.type_rules.insert(key.into(), value_type);
76 self
77 }
78
79 pub fn expect_range(mut self, key: impl Into<String>, range: RangeInclusive<i64>) -> Self {
80 self.range_rules.insert(key.into(), range);
81 self
82 }
83
84 pub fn expect_pattern(mut self, key: impl Into<String>, pattern: &str) -> ConfigResult<Self> {
85 let key = key.into();
86 let regex = regex::Regex::new(pattern).map_err(|e| ConfigError::InvalidValue {
87 key: key.clone(),
88 message: format!("Invalid regex pattern: {}", e),
89 })?;
90 self.pattern_rules.insert(key, regex);
91 Ok(self)
92 }
93
94 pub fn custom<F>(mut self, key: impl Into<String>, validator: F) -> Self
95 where
96 F: Fn(&Value) -> Result<(), String> + Send + Sync + 'static,
97 {
98 self.custom_rules.insert(key.into(), Box::new(validator));
99 self
100 }
101
102 pub fn validate(&self, config: &Value) -> ConfigResult<()> {
103 let errors = self.collect_errors(config);
104 if errors.is_empty() {
105 Ok(())
106 } else {
107 Err(ConfigError::ValidationErrors(ValidationErrors(errors)))
108 }
109 }
110
111 pub fn validate_partial(&self, config: &Value) -> Vec<ConfigError> {
112 self.collect_errors(config)
113 }
114
115 fn collect_errors(&self, config: &Value) -> Vec<ConfigError> {
116 let mut errors = Vec::new();
117
118 for key in &self.required_keys {
119 if get_nested(config, key).is_none() {
120 errors.push(ConfigError::NotFound { key: key.clone() });
121 }
122 }
123
124 for (key, expected_type) in &self.type_rules {
125 if let Some(value) = get_nested(config, key)
126 && !expected_type.matches(value)
127 {
128 errors.push(ConfigError::InvalidValue {
129 key: key.clone(),
130 message: format!(
131 "expected {}, got {}",
132 expected_type.name(),
133 value_type_name(value)
134 ),
135 });
136 }
137 }
138
139 for (key, range) in &self.range_rules {
140 if let Some(value) = get_nested(config, key)
141 && let Some(num) = value.as_i64()
142 && !range.contains(&num)
143 {
144 errors.push(ConfigError::InvalidValue {
145 key: key.clone(),
146 message: format!(
147 "value {} not in range {}..={}",
148 num,
149 range.start(),
150 range.end()
151 ),
152 });
153 }
154 }
155
156 for (key, pattern) in &self.pattern_rules {
157 if let Some(value) = get_nested(config, key)
158 && let Some(s) = value.as_str()
159 && !pattern.is_match(s)
160 {
161 errors.push(ConfigError::InvalidValue {
162 key: key.clone(),
163 message: format!("Value '{}' does not match pattern", s),
164 });
165 }
166 }
167
168 for (key, validator) in &self.custom_rules {
169 if let Some(value) = get_nested(config, key)
170 && let Err(msg) = validator(value)
171 {
172 errors.push(ConfigError::InvalidValue {
173 key: key.clone(),
174 message: msg,
175 });
176 }
177 }
178
179 errors
180 }
181}
182
183impl Default for ConfigValidator {
184 fn default() -> Self {
185 Self::new()
186 }
187}
188
189fn get_nested<'a>(config: &'a Value, key: &str) -> Option<&'a Value> {
190 let parts: Vec<&str> = key.split('.').collect();
191 let mut current = config;
192
193 for part in parts {
194 current = current.get(part)?;
195 }
196
197 Some(current)
198}
199
200fn value_type_name(value: &Value) -> &'static str {
201 match value {
202 Value::Null => "null",
203 Value::Bool(_) => "boolean",
204 Value::Number(_) => "number",
205 Value::String(_) => "string",
206 Value::Array(_) => "array",
207 Value::Object(_) => "object",
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Required keys pass when present and are reported by name when absent.
    #[test]
    fn test_required_keys() {
        let validator = ConfigValidator::new().require("api_key").require("model");

        let config = json!({
            "api_key": "sk-test",
            "model": "claude-sonnet-4-5"
        });
        assert!(validator.validate(&config).is_ok());

        let missing = json!({
            "api_key": "sk-test"
        });
        assert!(validator.validate(&missing).is_err());

        let errors = validator.validate_partial(&missing);
        assert_eq!(errors.len(), 1);
        assert!(matches!(
            &errors[0],
            ConfigError::NotFound { key } if key == "model"
        ));
    }

    /// A type mismatch is reported with both the expected and the actual type.
    #[test]
    fn test_type_validation() {
        let validator = ConfigValidator::new()
            .expect_type("port", ValueType::Number)
            .expect_type("enabled", ValueType::Boolean);

        let valid = json!({
            "port": 8080,
            "enabled": true
        });
        assert!(validator.validate(&valid).is_ok());

        let invalid = json!({
            "port": "8080",
            "enabled": true
        });
        assert!(validator.validate(&invalid).is_err());

        let errors = validator.validate_partial(&invalid);
        assert_eq!(errors.len(), 1);
        assert!(matches!(
            &errors[0],
            ConfigError::InvalidValue { key, message }
                if key == "port" && message == "expected number, got string"
        ));

        // Type rules only apply to keys that are present.
        assert!(validator.validate(&json!({})).is_ok());
    }

    /// Range bounds are inclusive on both ends.
    #[test]
    fn test_range_validation() {
        let validator = ConfigValidator::new()
            .expect_range("port", 1..=65535)
            .expect_range("timeout", 1..=300);

        let valid = json!({
            "port": 8080,
            "timeout": 30
        });
        assert!(validator.validate(&valid).is_ok());

        let invalid = json!({
            "port": 70000,
            "timeout": 30
        });
        assert!(validator.validate(&invalid).is_err());

        let lower_bound = json!({ "port": 1, "timeout": 1 });
        assert!(validator.validate(&lower_bound).is_ok());

        let upper_bound = json!({ "port": 65535, "timeout": 300 });
        assert!(validator.validate(&upper_bound).is_ok());

        let below = json!({ "port": 0, "timeout": 301 });
        assert_eq!(validator.validate_partial(&below).len(), 2);
    }

    /// String values must match the registered pattern.
    #[test]
    fn test_pattern_validation() {
        let validator = ConfigValidator::new()
            .expect_pattern("api_key", r"^sk-[a-zA-Z0-9]+$")
            .unwrap();

        let valid = json!({
            "api_key": "sk-test123"
        });
        assert!(validator.validate(&valid).is_ok());

        let invalid = json!({
            "api_key": "invalid-key"
        });
        assert!(validator.validate(&invalid).is_err());
    }

    /// A malformed regular expression is rejected when the rule is registered.
    #[test]
    fn test_invalid_pattern_is_rejected() {
        let result = ConfigValidator::new().expect_pattern("api_key", "(unclosed");
        assert!(matches!(
            result,
            Err(ConfigError::InvalidValue { key, .. }) if key == "api_key"
        ));
    }

    /// Dot-separated keys descend into nested objects.
    #[test]
    fn test_nested_keys() {
        let validator = ConfigValidator::new()
            .require("database.host")
            .expect_type("database.port", ValueType::Number);

        let config = json!({
            "database": {
                "host": "localhost",
                "port": 5432
            }
        });
        assert!(validator.validate(&config).is_ok());

        // Missing nested key and wrongly typed nested value are both reported,
        // required-key violations first.
        let broken = json!({
            "database": {
                "port": "5432"
            }
        });
        let errors = validator.validate_partial(&broken);
        assert_eq!(errors.len(), 2);
        assert!(matches!(
            &errors[0],
            ConfigError::NotFound { key } if key == "database.host"
        ));
        assert!(matches!(
            &errors[1],
            ConfigError::InvalidValue { key, .. } if key == "database.port"
        ));

        // A parent that is not an object cannot contain the nested key.
        let scalar_parent = json!({ "database": "sqlite" });
        assert!(validator.validate(&scalar_parent).is_err());
    }

    /// Custom callbacks run only for present keys and surface their message.
    #[test]
    fn test_custom_validator() {
        let validator = ConfigValidator::new().custom("urls", |v| match v.as_array() {
            Some(urls) if urls.is_empty() => Err("urls cannot be empty".to_string()),
            _ => Ok(()),
        });

        let valid = json!({
            "urls": ["http://example.com"]
        });
        assert!(validator.validate(&valid).is_ok());

        let invalid = json!({
            "urls": []
        });
        assert!(validator.validate(&invalid).is_err());

        let errors = validator.validate_partial(&invalid);
        assert_eq!(errors.len(), 1);
        assert!(matches!(
            &errors[0],
            ConfigError::InvalidValue { key, message }
                if key == "urls" && message == "urls cannot be empty"
        ));

        // The callback is not invoked when the key is absent.
        assert!(validator.validate(&json!({})).is_ok());
    }

    /// `require_many` registers every listed key as mandatory.
    #[test]
    fn test_require_many() {
        let validator = ConfigValidator::new().require_many(["host", "port", "database"]);

        let config = json!({
            "host": "localhost",
            "port": 5432,
            "database": "mydb"
        });
        assert!(validator.validate(&config).is_ok());

        let partial = json!({
            "host": "localhost"
        });
        assert_eq!(validator.validate_partial(&partial).len(), 2);
    }

    /// `validate_partial` lists each missing key, in registration order.
    #[test]
    fn test_validate_partial() {
        let validator = ConfigValidator::new()
            .require("a")
            .require("b")
            .require("c");

        let config = json!({
            "a": 1
        });

        let errors = validator.validate_partial(&config);
        assert_eq!(errors.len(), 2);
        assert!(matches!(
            &errors[0],
            ConfigError::NotFound { key } if key == "b"
        ));
        assert!(matches!(
            &errors[1],
            ConfigError::NotFound { key } if key == "c"
        ));
    }

    /// A validator without rules accepts any configuration.
    #[test]
    fn test_default_accepts_everything() {
        let validator = ConfigValidator::default();
        assert!(validator.validate(&json!({})).is_ok());
        assert!(validator.validate(&json!({ "anything": [1, 2, 3] })).is_ok());
        assert!(validator.validate_partial(&json!(null)).is_empty());
    }
}