// runbeam_sdk/validation/schema.rs

//! Schema DSL parser for TOML configuration validation.
//!
//! This module parses schema definition files written in the Harmony DSL format
//! (TOML with a `[schema]` header plus `[[table]]` and `[[table.field]]` entries)
//! and converts them into structured data that can be used for validation.

6use crate::validation::error::ValidationError;
7use regex::Regex;
8use serde::Deserialize;
9use std::collections::HashMap;
10
11/// A parsed schema containing table and field definitions.
12#[derive(Debug, Clone)]
13pub struct Schema {
14    /// Schema version
15    pub version: String,
16    /// Schema description
17    pub description: String,
18    /// Table definitions indexed by name
19    pub tables: HashMap<String, TableDefinition>,
20}
21
22/// A table definition in the schema.
23#[derive(Debug, Clone)]
24pub struct TableDefinition {
25    /// Table name (may include wildcards like "network.*")
26    pub name: String,
27    /// Whether this table name is a pattern (contains wildcards)
28    pub is_pattern: bool,
29    /// Pattern constraint regex if this is a pattern table
30    pub pattern_constraint: Option<Regex>,
31    /// Whether this table is required
32    pub required: bool,
33    /// Table description
34    pub description: Option<String>,
35    /// Field definitions for this table
36    pub fields: Vec<FieldDefinition>,
37}
38
39/// A field definition within a table.
40#[derive(Debug, Clone)]
41pub struct FieldDefinition {
42    /// Field name (may be a path like "tcp_config.bind_address")
43    pub name: String,
44    /// Field type (string, integer, boolean, float, array, table)
45    pub field_type: String,
46    /// Whether this field is required
47    pub required: bool,
48    /// Conditional requirement expression
49    pub required_if: Option<String>,
50    /// Default value
51    pub default: Option<toml::Value>,
52    /// Allowed enum values
53    pub enum_values: Option<Vec<String>>,
54    /// Minimum value (for numeric types)
55    pub min: Option<i64>,
56    /// Maximum value (for numeric types)
57    pub max: Option<i64>,
58    /// Minimum number of array items
59    pub min_items: Option<usize>,
60    /// Maximum number of array items
61    pub max_items: Option<usize>,
62    /// Expected type of array items
63    pub array_item_type: Option<String>,
64    /// Pattern constraint for string values
65    pub pattern: Option<Regex>,
66    /// Field description
67    pub description: Option<String>,
68}
69
70/// Raw schema format as parsed from TOML
71#[derive(Debug, Deserialize)]
72struct RawSchema {
73    schema: SchemaMetadata,
74    table: Vec<RawTable>,
75}
76
77#[derive(Debug, Deserialize)]
78struct SchemaMetadata {
79    version: String,
80    description: String,
81}
82
83#[derive(Debug, Deserialize)]
84struct RawTable {
85    name: String,
86    #[serde(default)]
87    required: bool,
88    #[serde(default)]
89    pattern: bool,
90    pattern_constraint: Option<String>,
91    description: Option<String>,
92    #[serde(default)]
93    field: Vec<RawField>,
94}
95
96#[derive(Debug, Deserialize)]
97struct RawField {
98    name: String,
99    #[serde(rename = "type")]
100    field_type: String,
101    #[serde(default)]
102    required: bool,
103    required_if: Option<String>,
104    default: Option<toml::Value>,
105    #[serde(rename = "enum")]
106    enum_values: Option<Vec<String>>,
107    min: Option<i64>,
108    max: Option<i64>,
109    min_items: Option<usize>,
110    max_items: Option<usize>,
111    array_item_type: Option<String>,
112    pattern_constraint: Option<String>,
113    description: Option<String>,
114}
115
116impl Schema {
117    /// Parse a schema from TOML string
118    ///
119    /// Note: This is similar to `FromStr::from_str` but returns our custom error type
120    #[allow(clippy::should_implement_trait)]
121    pub fn from_str(schema_toml: &str) -> Result<Self, ValidationError> {
122        let raw: RawSchema = toml::from_str(schema_toml).map_err(|e| {
123            ValidationError::SchemaParseError(format!("Failed to parse schema TOML: {}", e))
124        })?;
125
126        let mut tables = HashMap::new();
127
128        for raw_table in raw.table {
129            let pattern_constraint = if let Some(pattern_str) = &raw_table.pattern_constraint {
130                Some(Regex::new(pattern_str).map_err(|e| {
131                    ValidationError::SchemaParseError(format!(
132                        "Invalid pattern constraint '{}': {}",
133                        pattern_str, e
134                    ))
135                })?)
136            } else {
137                None
138            };
139
140            let mut fields = Vec::new();
141            for raw_field in raw_table.field {
142                let pattern = if let Some(pattern_str) = &raw_field.pattern_constraint {
143                    Some(Regex::new(pattern_str).map_err(|e| {
144                        ValidationError::SchemaParseError(format!(
145                            "Invalid pattern for field '{}': {}",
146                            raw_field.name, e
147                        ))
148                    })?)
149                } else {
150                    None
151                };
152
153                fields.push(FieldDefinition {
154                    name: raw_field.name,
155                    field_type: raw_field.field_type,
156                    required: raw_field.required,
157                    required_if: raw_field.required_if,
158                    default: raw_field.default,
159                    enum_values: raw_field.enum_values,
160                    min: raw_field.min,
161                    max: raw_field.max,
162                    min_items: raw_field.min_items,
163                    max_items: raw_field.max_items,
164                    array_item_type: raw_field.array_item_type,
165                    pattern,
166                    description: raw_field.description,
167                });
168            }
169
170            let table_def = TableDefinition {
171                name: raw_table.name.clone(),
172                is_pattern: raw_table.pattern,
173                pattern_constraint,
174                required: raw_table.required,
175                description: raw_table.description,
176                fields,
177            };
178
179            tables.insert(raw_table.name, table_def);
180        }
181
182        Ok(Schema {
183            version: raw.schema.version,
184            description: raw.schema.description,
185            tables,
186        })
187    }
188
189    /// Find a table definition that matches the given table path.
190    ///
191    /// This handles both exact matches and pattern matches (e.g., "network.*" matches "network.default").
192    pub fn find_table(&self, table_path: &str) -> Option<&TableDefinition> {
193        // Try exact match first
194        if let Some(table_def) = self.tables.get(table_path) {
195            return Some(table_def);
196        }
197
198        // Try pattern matches
199        self.tables.values().find(|&table_def| {
200            table_def.is_pattern && self.matches_pattern(table_path, &table_def.name)
201        })
202    }
203
204    /// Check if a table path matches a pattern table name.
205    ///
206    /// For example, "network.default" matches pattern "network.*"
207    pub fn matches_pattern(&self, table_path: &str, pattern: &str) -> bool {
208        if !pattern.contains('*') {
209            return table_path == pattern;
210        }
211
212        // Convert pattern to regex
213        // "network.*" becomes "^network\.[^.]+$"
214        let pattern_regex = pattern.replace(".", r"\.").replace("*", "[^.]+");
215        let pattern_regex = format!("^{}$", pattern_regex);
216
217        if let Ok(re) = Regex::new(&pattern_regex) {
218            re.is_match(table_path)
219        } else {
220            false
221        }
222    }
223
224    /// Get all tables that should be validated
225    pub fn get_concrete_tables(&self) -> impl Iterator<Item = &TableDefinition> {
226        self.tables.values().filter(|t| !t.is_pattern)
227    }
228}
229
230impl TableDefinition {
231    /// Find a field definition by name (supports nested paths like "tcp_config.bind_address")
232    pub fn find_field(&self, field_name: &str) -> Option<&FieldDefinition> {
233        self.fields.iter().find(|f| f.name == field_name)
234    }
235
236    /// Get all fields for this table
237    pub fn get_fields(&self) -> &[FieldDefinition] {
238        &self.fields
239    }
240}
241
242impl FieldDefinition {
243    /// Check if this field is conditionally required based on the given table data
244    pub fn is_conditionally_required(&self, table_data: &toml::Value) -> bool {
245        if let Some(condition) = &self.required_if {
246            evaluate_condition(condition, table_data)
247        } else {
248            false
249        }
250    }
251}
252
253/// Evaluate a required_if condition expression
254///
255/// Supports:
256/// - `field == "value"` - Field equals value
257/// - `field != "value"` - Field doesn't equal value  
258/// - `field == true` - Boolean field is true
259/// - `field == false` - Boolean field is false
260/// - `field exists` or just `field` - Field exists
261fn evaluate_condition(condition: &str, table_data: &toml::Value) -> bool {
262    let condition = condition.trim();
263
264    // Handle "field exists" or just "field"
265    if !condition.contains("==") && !condition.contains("!=") {
266        let field_name = condition.replace(" exists", "").trim().to_string();
267        return table_data.get(&field_name).is_some();
268    }
269
270    // Handle == and != operators
271    if let Some((left, right)) = condition.split_once("==") {
272        let field_name = left.trim();
273        let expected_value = right.trim().trim_matches('"').trim_matches('\'');
274
275        if let Some(field_value) = table_data.get(field_name) {
276            match field_value {
277                toml::Value::String(s) => return s == expected_value,
278                toml::Value::Boolean(b) => {
279                    return expected_value == "true" && *b || expected_value == "false" && !*b
280                }
281                toml::Value::Integer(i) => {
282                    if let Ok(expected_int) = expected_value.parse::<i64>() {
283                        return *i == expected_int;
284                    }
285                }
286                _ => {}
287            }
288        }
289        return false;
290    }
291
292    if let Some((left, right)) = condition.split_once("!=") {
293        let field_name = left.trim();
294        let expected_value = right.trim().trim_matches('"').trim_matches('\'');
295
296        if let Some(field_value) = table_data.get(field_name) {
297            match field_value {
298                toml::Value::String(s) => return s != expected_value,
299                toml::Value::Boolean(b) => {
300                    return !(expected_value == "true" && *b || expected_value == "false" && !*b)
301                }
302                toml::Value::Integer(i) => {
303                    if let Ok(expected_int) = expected_value.parse::<i64>() {
304                        return *i != expected_int;
305                    }
306                }
307                _ => {}
308            }
309        }
310        return true; // Field doesn't exist or doesn't match, so != is satisfied
311    }
312
313    false
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    /// A small schema covering an exact table (with enum and range fields) and a
    /// wildcard pattern table.
    fn sample_schema() -> &'static str {
        r#"
[schema]
version = "1.0"
description = "Test schema"

[[table]]
name = "proxy"
required = true
description = "Core proxy configuration"

[[table.field]]
name = "id"
type = "string"
required = true
description = "Proxy ID"

[[table.field]]
name = "log_level"
type = "string"
required = false
default = "error"
enum = ["trace", "debug", "info", "warn", "error"]

[[table.field]]
name = "port"
type = "integer"
required = false
min = 1
max = 65535

[[table]]
name = "network.*"
pattern = true
pattern_constraint = "^[a-z0-9_-]+$"
required = false

[[table.field]]
name = "bind_address"
type = "string"
required = true
"#
    }

    /// Parses the sample schema, panicking (and so failing the test) if it is rejected.
    fn parsed_sample() -> Schema {
        Schema::from_str(sample_schema()).unwrap()
    }

    /// Builds a TOML table value holding exactly one `key = value` entry.
    fn one_entry_table(key: &str, value: toml::Value) -> toml::Value {
        let mut entries = toml::map::Map::new();
        entries.insert(key.to_string(), value);
        toml::Value::Table(entries)
    }

    #[test]
    fn test_parse_schema() {
        let schema = parsed_sample();
        assert_eq!(schema.version, "1.0");
        assert_eq!(schema.description, "Test schema");
        assert_eq!(schema.tables.len(), 2);
    }

    #[test]
    fn test_find_table_exact() {
        let schema = parsed_sample();
        let proxy = schema
            .find_table("proxy")
            .expect("exact table should be found");
        assert_eq!(proxy.name, "proxy");
    }

    #[test]
    fn test_find_table_pattern() {
        let schema = parsed_sample();
        let network = schema
            .find_table("network.default")
            .expect("pattern table should be found");
        assert_eq!(network.name, "network.*");
    }

    #[test]
    fn test_find_field() {
        let schema = parsed_sample();
        let id = schema
            .find_table("proxy")
            .and_then(|table| table.find_field("id"))
            .expect("proxy.id should be defined");
        assert_eq!(id.field_type, "string");
        assert!(id.required);
    }

    #[test]
    fn test_enum_values() {
        let schema = parsed_sample();
        let log_level = schema
            .find_table("proxy")
            .unwrap()
            .find_field("log_level")
            .unwrap();
        let allowed = log_level
            .enum_values
            .as_deref()
            .expect("log_level should declare enum values");
        assert_eq!(allowed.len(), 5);
        assert!(allowed.iter().any(|value| value == "error"));
    }

    #[test]
    fn test_numeric_range() {
        let schema = parsed_sample();
        let port = schema
            .find_table("proxy")
            .unwrap()
            .find_field("port")
            .unwrap();
        assert_eq!((port.min, port.max), (Some(1), Some(65535)));
    }

    #[test]
    fn test_evaluate_condition_equals() {
        let data = one_entry_table("enabled", toml::Value::Boolean(true));
        assert!(evaluate_condition("enabled == true", &data));
        assert!(!evaluate_condition("enabled == false", &data));
    }

    #[test]
    fn test_evaluate_condition_string() {
        let data = one_entry_table("type", toml::Value::String("http".to_string()));
        assert!(evaluate_condition("type == \"http\"", &data));
        assert!(!evaluate_condition("type == \"tcp\"", &data));
    }

    #[test]
    fn test_evaluate_condition_exists() {
        let data = one_entry_table("field", toml::Value::String("value".to_string()));
        assert!(evaluate_condition("field exists", &data));
        assert!(evaluate_condition("field", &data));
        assert!(!evaluate_condition("missing", &data));
    }

    #[test]
    fn test_pattern_matching() {
        let schema = parsed_sample();
        let cases = [
            ("network.default", true),
            ("network.management", true),
            ("network.sub.deep", false),
            ("other.default", false),
            ("network", false),
        ];
        for (path, expected) in cases {
            assert_eq!(
                schema.matches_pattern(path, "network.*"),
                expected,
                "matching {:?} against \"network.*\"",
                path
            );
        }
    }
}