Skip to main content

cruxx_script/
validator.rs

1//! Static pipeline validation against registered handler metadata.
2
3use std::fmt;
4
5use serde_json::Value;
6
7use crate::metadata::ArgType;
8use crate::registry::HandlerRegistry;
9use crate::schema::{ArmDef, PipelineDef, RouteBranch, StepDef};
10
/// How serious a validation finding is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
    /// Counted by `ValidationReport::error_count`; any error makes the report not ok.
    Error,
    /// Counted by `ValidationReport::warning_count`; does not affect `is_ok`.
    Warning,
}

impl fmt::Display for DiagnosticSeverity {
    /// Writes the lowercase label (`error` / `warning`) for the severity.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::Error => "error",
            Self::Warning => "warning",
        };
        f.write_str(label)
    }
}
25
26#[derive(Debug, Clone, PartialEq, Eq)]
27pub struct ValidationDiagnostic {
28    pub severity: DiagnosticSeverity,
29    pub location: String,
30    pub message: String,
31}
32
33impl ValidationDiagnostic {
34    pub fn error(location: impl Into<String>, message: impl Into<String>) -> Self {
35        Self {
36            severity: DiagnosticSeverity::Error,
37            location: location.into(),
38            message: message.into(),
39        }
40    }
41
42    pub fn warning(location: impl Into<String>, message: impl Into<String>) -> Self {
43        Self {
44            severity: DiagnosticSeverity::Warning,
45            location: location.into(),
46            message: message.into(),
47        }
48    }
49}
50
51#[derive(Debug, Clone, Default, PartialEq, Eq)]
52pub struct ValidationReport {
53    pub diagnostics: Vec<ValidationDiagnostic>,
54}
55
56impl ValidationReport {
57    pub fn is_ok(&self) -> bool {
58        self.error_count() == 0
59    }
60
61    pub fn error_count(&self) -> usize {
62        self.diagnostics
63            .iter()
64            .filter(|d| d.severity == DiagnosticSeverity::Error)
65            .count()
66    }
67
68    pub fn warning_count(&self) -> usize {
69        self.diagnostics
70            .iter()
71            .filter(|d| d.severity == DiagnosticSeverity::Warning)
72            .count()
73    }
74
75    fn push(&mut self, diagnostic: ValidationDiagnostic) {
76        self.diagnostics.push(diagnostic);
77    }
78}
79
80/// Validate a parsed pipeline against the registered handler metadata.
81pub fn validate_pipeline(pipeline: &PipelineDef, registry: &HandlerRegistry) -> ValidationReport {
82    let mut report = ValidationReport::default();
83
84    for (idx, step) in pipeline.steps.iter().enumerate() {
85        let location = format!("steps[{idx}]");
86        match step {
87            StepDef::Step(node) => {
88                let handler = node.handler.as_deref().unwrap_or(&node.step);
89                validate_handler_ref(
90                    &mut report,
91                    registry,
92                    &location,
93                    handler,
94                    node.args.as_ref(),
95                );
96            }
97            StepDef::Delegate(node) => {
98                if registry.get_agent(&node.delegate).is_none() {
99                    report.push(ValidationDiagnostic::warning(
100                        &location,
101                        format!("agent '{}' is not registered", node.delegate),
102                    ));
103                }
104            }
105            StepDef::Pipe(node) => {
106                for (stage_idx, arm) in node.stages.iter().enumerate() {
107                    validate_arm(
108                        &mut report,
109                        registry,
110                        &format!("{location}.stages[{stage_idx}]"),
111                        arm,
112                    );
113                }
114            }
115            StepDef::JoinAll(node) => {
116                for (arm_idx, arm) in node.arms.iter().enumerate() {
117                    validate_arm(
118                        &mut report,
119                        registry,
120                        &format!("{location}.arms[{arm_idx}]"),
121                        arm,
122                    );
123                }
124            }
125            StepDef::RouteOnConfidence(node) => {
126                validate_routes(&mut report, &location, &node.routes);
127                for (route_idx, branch) in node.routes.iter().enumerate() {
128                    validate_handler_ref(
129                        &mut report,
130                        registry,
131                        &format!("{location}.routes[{route_idx}]"),
132                        &branch.handler,
133                        branch.args.as_ref(),
134                    );
135                }
136            }
137            StepDef::Speculate(node) => {
138                for (arm_idx, arm) in node.arms.iter().enumerate() {
139                    validate_arm(
140                        &mut report,
141                        registry,
142                        &format!("{location}.arms[{arm_idx}]"),
143                        arm,
144                    );
145                }
146            }
147        }
148    }
149
150    report
151}
152
153fn validate_arm(
154    report: &mut ValidationReport,
155    registry: &HandlerRegistry,
156    location: &str,
157    arm: &ArmDef,
158) {
159    validate_handler_ref(report, registry, location, arm.handler_name(), arm.args());
160}
161
162fn validate_handler_ref(
163    report: &mut ValidationReport,
164    registry: &HandlerRegistry,
165    location: &str,
166    handler: &str,
167    args: Option<&Value>,
168) {
169    let Some(metadata) = registry.get_metadata(handler) else {
170        report.push(ValidationDiagnostic::error(
171            location,
172            format!("handler '{handler}' is not registered"),
173        ));
174        return;
175    };
176
177    let Some(schema_args) = args else {
178        if metadata.args.has_required_args() {
179            let missing = metadata
180                .args
181                .args
182                .iter()
183                .filter(|spec| spec.required)
184                .map(|spec| spec.name.as_str())
185                .collect::<Vec<_>>()
186                .join(", ");
187            report.push(ValidationDiagnostic::error(
188                location,
189                format!("handler '{handler}' is missing required args: {missing}"),
190            ));
191        }
192        return;
193    };
194
195    let Some(arg_map) = schema_args.as_object() else {
196        report.push(ValidationDiagnostic::error(
197            location,
198            format!("handler '{handler}' args must be an object"),
199        ));
200        return;
201    };
202
203    for spec in &metadata.args.args {
204        let Some(value) = arg_map.get(&spec.name) else {
205            if spec.required {
206                report.push(ValidationDiagnostic::error(
207                    location,
208                    format!(
209                        "handler '{handler}' is missing required arg '{}'",
210                        spec.name
211                    ),
212                ));
213            }
214            continue;
215        };
216
217        if is_template_string(value) {
218            continue;
219        }
220
221        if !spec.arg_type.matches(value) {
222            report.push(ValidationDiagnostic::error(
223                location,
224                format!(
225                    "handler '{handler}' arg '{}' expected {}, got {}",
226                    spec.name,
227                    display_arg_type(spec.arg_type),
228                    display_value_type(value)
229                ),
230            ));
231        }
232    }
233
234    if !metadata.args.allow_extra {
235        for key in arg_map.keys() {
236            if metadata.args.get(key).is_none() {
237                report.push(ValidationDiagnostic::error(
238                    location,
239                    format!("handler '{handler}' received unexpected arg '{key}'"),
240                ));
241            }
242        }
243    }
244}
245
246fn is_template_string(value: &Value) -> bool {
247    value
248        .as_str()
249        .map(|s| s.trim_start().starts_with("{{"))
250        .unwrap_or(false)
251}
252
/// Lowercase label for a declared argument type, used in the
/// "expected …, got …" diagnostic message.
///
/// The match is deliberately exhaustive (no `_` arm) so adding an `ArgType`
/// variant forces this table to be updated.
fn display_arg_type(arg_type: ArgType) -> &'static str {
    match arg_type {
        ArgType::Any => "any",
        ArgType::String => "string",
        ArgType::Number => "number",
        ArgType::Integer => "integer",
        ArgType::Boolean => "boolean",
        ArgType::Object => "object",
        ArgType::Array => "array",
    }
}
264
/// Lowercase label for the JSON kind of `value`, used in the
/// "expected …, got …" diagnostic message.
///
/// All JSON numbers are labelled `number`; integers are not distinguished.
fn display_value_type(value: &Value) -> &'static str {
    match value {
        Value::Null => "null",
        Value::Bool(_) => "boolean",
        Value::Number(_) => "number",
        Value::String(_) => "string",
        Value::Array(_) => "array",
        Value::Object(_) => "object",
    }
}
275
/// A confidence range as produced by `parse_range`: the two numeric bounds
/// plus the kind of closing bracket. The kind of opening bracket is not stored.
#[derive(Debug, Clone, Copy)]
struct ParsedRange {
    /// Lower bound (text before the comma).
    lo: f32,
    /// Upper bound (text after the comma).
    hi: f32,
    /// `true` when the range text ends with `]`, `false` when it ends with `)`.
    include_hi: bool,
}
282
283fn validate_routes(report: &mut ValidationReport, location: &str, routes: &[RouteBranch]) {
284    let mut parsed = Vec::new();
285
286    for (idx, branch) in routes.iter().enumerate() {
287        match parse_range(&branch.range) {
288            Ok(range) => {
289                if range.lo < 0.0 || range.hi > 1.0 {
290                    report.push(ValidationDiagnostic::error(
291                        format!("{location}.routes[{idx}]"),
292                        format!(
293                            "confidence range '{}' must stay within [0.0, 1.0]",
294                            branch.range
295                        ),
296                    ));
297                }
298                if range.lo > range.hi || (range.lo == range.hi && !range.include_hi) {
299                    report.push(ValidationDiagnostic::error(
300                        format!("{location}.routes[{idx}]"),
301                        format!("confidence range '{}' is empty", branch.range),
302                    ));
303                }
304                parsed.push((idx, range));
305            }
306            Err(e) => report.push(ValidationDiagnostic::error(
307                format!("{location}.routes[{idx}]"),
308                format!("invalid confidence range '{}': {e}", branch.range),
309            )),
310        }
311    }
312
313    parsed.sort_by(|a, b| {
314        a.1.lo
315            .partial_cmp(&b.1.lo)
316            .unwrap_or(std::cmp::Ordering::Equal)
317    });
318    for pair in parsed.windows(2) {
319        let (left_idx, left) = pair[0];
320        let (right_idx, right) = pair[1];
321        if ranges_overlap(left, right) {
322            report.push(ValidationDiagnostic::error(
323                location,
324                format!("confidence ranges for routes {left_idx} and {right_idx} overlap"),
325            ));
326        }
327    }
328}
329
330fn parse_range(s: &str) -> Result<ParsedRange, &'static str> {
331    let s = s.trim();
332    if !(s.starts_with('[') || s.starts_with('(')) {
333        return Err("missing opening bracket");
334    }
335    let include_hi = if s.ends_with(']') {
336        true
337    } else if s.ends_with(')') {
338        false
339    } else {
340        return Err("missing closing bracket");
341    };
342
343    let inner = &s[1..s.len() - 1];
344    let Some((lo, hi)) = inner.split_once(',') else {
345        return Err("expected lower and upper bounds");
346    };
347    let lo = lo
348        .trim()
349        .parse::<f32>()
350        .map_err(|_| "invalid lower bound")?;
351    let hi = hi
352        .trim()
353        .parse::<f32>()
354        .map_err(|_| "invalid upper bound")?;
355    Ok(ParsedRange { lo, hi, include_hi })
356}
357
358fn ranges_overlap(left: ParsedRange, right: ParsedRange) -> bool {
359    if left.hi > right.lo {
360        return true;
361    }
362    left.hi == right.lo && left.include_hi
363}