1use std::collections::HashMap;
6use std::ops::RangeInclusive;
7
8use serde_json::Value;
9
10use super::{ConfigError, ConfigResult, ValidationErrors};
11
12pub type ValidationFn = Box<dyn Fn(&Value) -> Result<(), String> + Send + Sync>;
13
14pub struct ConfigValidator {
15 required_keys: Vec<String>,
16 type_rules: HashMap<String, ValueType>,
17 range_rules: HashMap<String, RangeInclusive<i64>>,
18 pattern_rules: HashMap<String, regex::Regex>,
19 custom_rules: HashMap<String, ValidationFn>,
20}
21
/// The JSON kind a configuration value is expected to have.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ValueType {
    /// A JSON string.
    String,
    /// Any JSON number (integer or floating point).
    Number,
    /// `true` or `false`.
    Boolean,
    /// A JSON array.
    Array,
    /// A JSON object.
    Object,
}
30
31impl ValueType {
32 fn matches(&self, value: &Value) -> bool {
33 match self {
34 ValueType::String => value.is_string(),
35 ValueType::Number => value.is_number(),
36 ValueType::Boolean => value.is_boolean(),
37 ValueType::Array => value.is_array(),
38 ValueType::Object => value.is_object(),
39 }
40 }
41
42 fn name(&self) -> &'static str {
43 match self {
44 ValueType::String => "string",
45 ValueType::Number => "number",
46 ValueType::Boolean => "boolean",
47 ValueType::Array => "array",
48 ValueType::Object => "object",
49 }
50 }
51}
52
53impl ConfigValidator {
54 pub fn new() -> Self {
55 Self {
56 required_keys: Vec::new(),
57 type_rules: HashMap::new(),
58 range_rules: HashMap::new(),
59 pattern_rules: HashMap::new(),
60 custom_rules: HashMap::new(),
61 }
62 }
63
64 pub fn require(mut self, key: impl Into<String>) -> Self {
65 self.required_keys.push(key.into());
66 self
67 }
68
69 pub fn require_many(mut self, keys: impl IntoIterator<Item = impl Into<String>>) -> Self {
70 self.required_keys.extend(keys.into_iter().map(Into::into));
71 self
72 }
73
74 pub fn expect_type(mut self, key: impl Into<String>, value_type: ValueType) -> Self {
75 self.type_rules.insert(key.into(), value_type);
76 self
77 }
78
79 pub fn expect_range(mut self, key: impl Into<String>, range: RangeInclusive<i64>) -> Self {
80 self.range_rules.insert(key.into(), range);
81 self
82 }
83
84 pub fn expect_pattern(mut self, key: impl Into<String>, pattern: &str) -> ConfigResult<Self> {
85 let key = key.into();
86 let regex = regex::Regex::new(pattern).map_err(|e| ConfigError::InvalidValue {
87 key: key.clone(),
88 message: format!("invalid regex pattern: {}", e),
89 })?;
90 self.pattern_rules.insert(key, regex);
91 Ok(self)
92 }
93
94 pub fn custom<F>(mut self, key: impl Into<String>, validator: F) -> Self
95 where
96 F: Fn(&Value) -> Result<(), String> + Send + Sync + 'static,
97 {
98 self.custom_rules.insert(key.into(), Box::new(validator));
99 self
100 }
101
102 pub fn validate(&self, config: &Value) -> ConfigResult<()> {
103 let mut errors = Vec::new();
104
105 for key in &self.required_keys {
106 if get_nested(config, key).is_none() {
107 errors.push(ConfigError::NotFound { key: key.clone() });
108 }
109 }
110
111 for (key, expected_type) in &self.type_rules {
112 if let Some(value) = get_nested(config, key)
113 && !expected_type.matches(value)
114 {
115 errors.push(ConfigError::InvalidValue {
116 key: key.clone(),
117 message: format!(
118 "expected {}, got {}",
119 expected_type.name(),
120 value_type_name(value)
121 ),
122 });
123 }
124 }
125
126 for (key, range) in &self.range_rules {
127 if let Some(value) = get_nested(config, key)
128 && let Some(num) = value.as_i64()
129 && !range.contains(&num)
130 {
131 errors.push(ConfigError::InvalidValue {
132 key: key.clone(),
133 message: format!(
134 "value {} not in range {}..={}",
135 num,
136 range.start(),
137 range.end()
138 ),
139 });
140 }
141 }
142
143 for (key, pattern) in &self.pattern_rules {
144 if let Some(value) = get_nested(config, key)
145 && let Some(s) = value.as_str()
146 && !pattern.is_match(s)
147 {
148 errors.push(ConfigError::InvalidValue {
149 key: key.clone(),
150 message: format!("value '{}' does not match pattern", s),
151 });
152 }
153 }
154
155 for (key, validator) in &self.custom_rules {
156 if let Some(value) = get_nested(config, key)
157 && let Err(msg) = validator(value)
158 {
159 errors.push(ConfigError::InvalidValue {
160 key: key.clone(),
161 message: msg,
162 });
163 }
164 }
165
166 if errors.is_empty() {
167 Ok(())
168 } else {
169 Err(ConfigError::ValidationErrors(ValidationErrors(errors)))
170 }
171 }
172
173 pub fn validate_partial(&self, config: &Value) -> Vec<ConfigError> {
174 let mut errors = Vec::new();
175
176 for key in &self.required_keys {
177 if get_nested(config, key).is_none() {
178 errors.push(ConfigError::NotFound { key: key.clone() });
179 }
180 }
181
182 for (key, expected_type) in &self.type_rules {
183 if let Some(value) = get_nested(config, key)
184 && !expected_type.matches(value)
185 {
186 errors.push(ConfigError::InvalidValue {
187 key: key.clone(),
188 message: format!(
189 "expected {}, got {}",
190 expected_type.name(),
191 value_type_name(value)
192 ),
193 });
194 }
195 }
196
197 for (key, range) in &self.range_rules {
198 if let Some(value) = get_nested(config, key)
199 && let Some(num) = value.as_i64()
200 && !range.contains(&num)
201 {
202 errors.push(ConfigError::InvalidValue {
203 key: key.clone(),
204 message: format!(
205 "value {} not in range {}..={}",
206 num,
207 range.start(),
208 range.end()
209 ),
210 });
211 }
212 }
213
214 for (key, pattern) in &self.pattern_rules {
215 if let Some(value) = get_nested(config, key)
216 && let Some(s) = value.as_str()
217 && !pattern.is_match(s)
218 {
219 errors.push(ConfigError::InvalidValue {
220 key: key.clone(),
221 message: format!("value '{}' does not match pattern", s),
222 });
223 }
224 }
225
226 for (key, validator) in &self.custom_rules {
227 if let Some(value) = get_nested(config, key)
228 && let Err(msg) = validator(value)
229 {
230 errors.push(ConfigError::InvalidValue {
231 key: key.clone(),
232 message: msg,
233 });
234 }
235 }
236
237 errors
238 }
239}
240
241impl Default for ConfigValidator {
242 fn default() -> Self {
243 Self::new()
244 }
245}
246
247fn get_nested<'a>(config: &'a Value, key: &str) -> Option<&'a Value> {
248 let parts: Vec<&str> = key.split('.').collect();
249 let mut current = config;
250
251 for part in parts {
252 current = current.get(part)?;
253 }
254
255 Some(current)
256}
257
258fn value_type_name(value: &Value) -> &'static str {
259 match value {
260 Value::Null => "null",
261 Value::Bool(_) => "boolean",
262 Value::Number(_) => "number",
263 Value::String(_) => "string",
264 Value::Array(_) => "array",
265 Value::Object(_) => "object",
266 }
267}
268
269#[cfg(test)]
270mod tests {
271 use super::*;
272 use serde_json::json;
273
274 #[test]
275 fn test_required_keys() {
276 let validator = ConfigValidator::new().require("api_key").require("model");
277
278 let config = json!({
279 "api_key": "sk-test",
280 "model": "claude-sonnet-4-5"
281 });
282 assert!(validator.validate(&config).is_ok());
283
284 let missing = json!({
285 "api_key": "sk-test"
286 });
287 assert!(validator.validate(&missing).is_err());
288 }
289
290 #[test]
291 fn test_type_validation() {
292 let validator = ConfigValidator::new()
293 .expect_type("port", ValueType::Number)
294 .expect_type("enabled", ValueType::Boolean);
295
296 let valid = json!({
297 "port": 8080,
298 "enabled": true
299 });
300 assert!(validator.validate(&valid).is_ok());
301
302 let invalid = json!({
303 "port": "8080",
304 "enabled": true
305 });
306 assert!(validator.validate(&invalid).is_err());
307 }
308
309 #[test]
310 fn test_range_validation() {
311 let validator = ConfigValidator::new()
312 .expect_range("port", 1..=65535)
313 .expect_range("timeout", 1..=300);
314
315 let valid = json!({
316 "port": 8080,
317 "timeout": 30
318 });
319 assert!(validator.validate(&valid).is_ok());
320
321 let invalid = json!({
322 "port": 70000,
323 "timeout": 30
324 });
325 assert!(validator.validate(&invalid).is_err());
326 }
327
328 #[test]
329 fn test_pattern_validation() {
330 let validator = ConfigValidator::new()
331 .expect_pattern("api_key", r"^sk-[a-zA-Z0-9]+$")
332 .unwrap();
333
334 let valid = json!({
335 "api_key": "sk-test123"
336 });
337 assert!(validator.validate(&valid).is_ok());
338
339 let invalid = json!({
340 "api_key": "invalid-key"
341 });
342 assert!(validator.validate(&invalid).is_err());
343 }
344
345 #[test]
346 fn test_nested_keys() {
347 let validator = ConfigValidator::new()
348 .require("database.host")
349 .expect_type("database.port", ValueType::Number);
350
351 let config = json!({
352 "database": {
353 "host": "localhost",
354 "port": 5432
355 }
356 });
357 assert!(validator.validate(&config).is_ok());
358 }
359
360 #[test]
361 fn test_custom_validator() {
362 let validator = ConfigValidator::new().custom("urls", |v| {
363 if let Some(arr) = v.as_array()
364 && arr.is_empty()
365 {
366 return Err("urls cannot be empty".to_string());
367 }
368 Ok(())
369 });
370
371 let valid = json!({
372 "urls": ["http://example.com"]
373 });
374 assert!(validator.validate(&valid).is_ok());
375
376 let invalid = json!({
377 "urls": []
378 });
379 assert!(validator.validate(&invalid).is_err());
380 }
381
382 #[test]
383 fn test_require_many() {
384 let validator = ConfigValidator::new().require_many(["host", "port", "database"]);
385
386 let config = json!({
387 "host": "localhost",
388 "port": 5432,
389 "database": "mydb"
390 });
391 assert!(validator.validate(&config).is_ok());
392 }
393
394 #[test]
395 fn test_validate_partial() {
396 let validator = ConfigValidator::new()
397 .require("a")
398 .require("b")
399 .require("c");
400
401 let config = json!({
402 "a": 1
403 });
404
405 let errors = validator.validate_partial(&config);
406 assert_eq!(errors.len(), 2);
407 }
408}