1use std::fmt;
4
5use serde_json::Value;
6
7use crate::metadata::ArgType;
8use crate::registry::HandlerRegistry;
9use crate::schema::{ArmDef, PipelineDef, RouteBranch, StepDef};
10
/// How serious a [`ValidationDiagnostic`] is.
///
/// Any `Error` in a report makes [`ValidationReport::is_ok`] return `false`;
/// warnings are recorded but do not fail validation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DiagnosticSeverity {
    /// A problem that makes the pipeline invalid.
    Error,
    /// A suspicious condition that does not fail validation on its own.
    Warning,
}
16
17impl fmt::Display for DiagnosticSeverity {
18 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19 match self {
20 DiagnosticSeverity::Error => f.write_str("error"),
21 DiagnosticSeverity::Warning => f.write_str("warning"),
22 }
23 }
24}
25
26#[derive(Debug, Clone, PartialEq, Eq)]
27pub struct ValidationDiagnostic {
28 pub severity: DiagnosticSeverity,
29 pub location: String,
30 pub message: String,
31}
32
33impl ValidationDiagnostic {
34 pub fn error(location: impl Into<String>, message: impl Into<String>) -> Self {
35 Self {
36 severity: DiagnosticSeverity::Error,
37 location: location.into(),
38 message: message.into(),
39 }
40 }
41
42 pub fn warning(location: impl Into<String>, message: impl Into<String>) -> Self {
43 Self {
44 severity: DiagnosticSeverity::Warning,
45 location: location.into(),
46 message: message.into(),
47 }
48 }
49}
50
51#[derive(Debug, Clone, Default, PartialEq, Eq)]
52pub struct ValidationReport {
53 pub diagnostics: Vec<ValidationDiagnostic>,
54}
55
56impl ValidationReport {
57 pub fn is_ok(&self) -> bool {
58 self.error_count() == 0
59 }
60
61 pub fn error_count(&self) -> usize {
62 self.diagnostics
63 .iter()
64 .filter(|d| d.severity == DiagnosticSeverity::Error)
65 .count()
66 }
67
68 pub fn warning_count(&self) -> usize {
69 self.diagnostics
70 .iter()
71 .filter(|d| d.severity == DiagnosticSeverity::Warning)
72 .count()
73 }
74
75 fn push(&mut self, diagnostic: ValidationDiagnostic) {
76 self.diagnostics.push(diagnostic);
77 }
78}
79
80pub fn validate_pipeline(pipeline: &PipelineDef, registry: &HandlerRegistry) -> ValidationReport {
82 let mut report = ValidationReport::default();
83
84 for (idx, step) in pipeline.steps.iter().enumerate() {
85 let location = format!("steps[{idx}]");
86 match step {
87 StepDef::Step(node) => {
88 let handler = node.handler.as_deref().unwrap_or(&node.step);
89 validate_handler_ref(
90 &mut report,
91 registry,
92 &location,
93 handler,
94 node.args.as_ref(),
95 );
96 }
97 StepDef::Delegate(node) => {
98 if registry.get_agent(&node.delegate).is_none() {
99 report.push(ValidationDiagnostic::warning(
100 &location,
101 format!("agent '{}' is not registered", node.delegate),
102 ));
103 }
104 }
105 StepDef::Pipe(node) => {
106 for (stage_idx, arm) in node.stages.iter().enumerate() {
107 validate_arm(
108 &mut report,
109 registry,
110 &format!("{location}.stages[{stage_idx}]"),
111 arm,
112 );
113 }
114 }
115 StepDef::JoinAll(node) => {
116 for (arm_idx, arm) in node.arms.iter().enumerate() {
117 validate_arm(
118 &mut report,
119 registry,
120 &format!("{location}.arms[{arm_idx}]"),
121 arm,
122 );
123 }
124 }
125 StepDef::RouteOnConfidence(node) => {
126 validate_routes(&mut report, &location, &node.routes);
127 for (route_idx, branch) in node.routes.iter().enumerate() {
128 validate_handler_ref(
129 &mut report,
130 registry,
131 &format!("{location}.routes[{route_idx}]"),
132 &branch.handler,
133 branch.args.as_ref(),
134 );
135 }
136 }
137 StepDef::Speculate(node) => {
138 for (arm_idx, arm) in node.arms.iter().enumerate() {
139 validate_arm(
140 &mut report,
141 registry,
142 &format!("{location}.arms[{arm_idx}]"),
143 arm,
144 );
145 }
146 }
147 }
148 }
149
150 report
151}
152
153fn validate_arm(
154 report: &mut ValidationReport,
155 registry: &HandlerRegistry,
156 location: &str,
157 arm: &ArmDef,
158) {
159 validate_handler_ref(report, registry, location, arm.handler_name(), arm.args());
160}
161
162fn validate_handler_ref(
163 report: &mut ValidationReport,
164 registry: &HandlerRegistry,
165 location: &str,
166 handler: &str,
167 args: Option<&Value>,
168) {
169 let Some(metadata) = registry.get_metadata(handler) else {
170 report.push(ValidationDiagnostic::error(
171 location,
172 format!("handler '{handler}' is not registered"),
173 ));
174 return;
175 };
176
177 let Some(schema_args) = args else {
178 if metadata.args.has_required_args() {
179 let missing = metadata
180 .args
181 .args
182 .iter()
183 .filter(|spec| spec.required)
184 .map(|spec| spec.name.as_str())
185 .collect::<Vec<_>>()
186 .join(", ");
187 report.push(ValidationDiagnostic::error(
188 location,
189 format!("handler '{handler}' is missing required args: {missing}"),
190 ));
191 }
192 return;
193 };
194
195 let Some(arg_map) = schema_args.as_object() else {
196 report.push(ValidationDiagnostic::error(
197 location,
198 format!("handler '{handler}' args must be an object"),
199 ));
200 return;
201 };
202
203 for spec in &metadata.args.args {
204 let Some(value) = arg_map.get(&spec.name) else {
205 if spec.required {
206 report.push(ValidationDiagnostic::error(
207 location,
208 format!(
209 "handler '{handler}' is missing required arg '{}'",
210 spec.name
211 ),
212 ));
213 }
214 continue;
215 };
216
217 if is_template_string(value) {
218 continue;
219 }
220
221 if !spec.arg_type.matches(value) {
222 report.push(ValidationDiagnostic::error(
223 location,
224 format!(
225 "handler '{handler}' arg '{}' expected {}, got {}",
226 spec.name,
227 display_arg_type(spec.arg_type),
228 display_value_type(value)
229 ),
230 ));
231 }
232 }
233
234 if !metadata.args.allow_extra {
235 for key in arg_map.keys() {
236 if metadata.args.get(key).is_none() {
237 report.push(ValidationDiagnostic::error(
238 location,
239 format!("handler '{handler}' received unexpected arg '{key}'"),
240 ));
241 }
242 }
243 }
244}
245
246fn is_template_string(value: &Value) -> bool {
247 value
248 .as_str()
249 .map(|s| s.trim_start().starts_with("{{"))
250 .unwrap_or(false)
251}
252
253fn display_arg_type(arg_type: ArgType) -> &'static str {
254 match arg_type {
255 ArgType::Any => "any",
256 ArgType::String => "string",
257 ArgType::Number => "number",
258 ArgType::Integer => "integer",
259 ArgType::Boolean => "boolean",
260 ArgType::Object => "object",
261 ArgType::Array => "array",
262 }
263}
264
265fn display_value_type(value: &Value) -> &'static str {
266 match value {
267 Value::Null => "null",
268 Value::Bool(_) => "boolean",
269 Value::Number(_) => "number",
270 Value::String(_) => "string",
271 Value::Array(_) => "array",
272 Value::Object(_) => "object",
273 }
274}
275
276#[derive(Debug, Clone, Copy)]
277struct ParsedRange {
278 lo: f32,
279 hi: f32,
280 include_hi: bool,
281}
282
283fn validate_routes(report: &mut ValidationReport, location: &str, routes: &[RouteBranch]) {
284 let mut parsed = Vec::new();
285
286 for (idx, branch) in routes.iter().enumerate() {
287 match parse_range(&branch.range) {
288 Ok(range) => {
289 if range.lo < 0.0 || range.hi > 1.0 {
290 report.push(ValidationDiagnostic::error(
291 format!("{location}.routes[{idx}]"),
292 format!(
293 "confidence range '{}' must stay within [0.0, 1.0]",
294 branch.range
295 ),
296 ));
297 }
298 if range.lo > range.hi || (range.lo == range.hi && !range.include_hi) {
299 report.push(ValidationDiagnostic::error(
300 format!("{location}.routes[{idx}]"),
301 format!("confidence range '{}' is empty", branch.range),
302 ));
303 }
304 parsed.push((idx, range));
305 }
306 Err(e) => report.push(ValidationDiagnostic::error(
307 format!("{location}.routes[{idx}]"),
308 format!("invalid confidence range '{}': {e}", branch.range),
309 )),
310 }
311 }
312
313 parsed.sort_by(|a, b| {
314 a.1.lo
315 .partial_cmp(&b.1.lo)
316 .unwrap_or(std::cmp::Ordering::Equal)
317 });
318 for pair in parsed.windows(2) {
319 let (left_idx, left) = pair[0];
320 let (right_idx, right) = pair[1];
321 if ranges_overlap(left, right) {
322 report.push(ValidationDiagnostic::error(
323 location,
324 format!("confidence ranges for routes {left_idx} and {right_idx} overlap"),
325 ));
326 }
327 }
328}
329
330fn parse_range(s: &str) -> Result<ParsedRange, &'static str> {
331 let s = s.trim();
332 if !(s.starts_with('[') || s.starts_with('(')) {
333 return Err("missing opening bracket");
334 }
335 let include_hi = if s.ends_with(']') {
336 true
337 } else if s.ends_with(')') {
338 false
339 } else {
340 return Err("missing closing bracket");
341 };
342
343 let inner = &s[1..s.len() - 1];
344 let Some((lo, hi)) = inner.split_once(',') else {
345 return Err("expected lower and upper bounds");
346 };
347 let lo = lo
348 .trim()
349 .parse::<f32>()
350 .map_err(|_| "invalid lower bound")?;
351 let hi = hi
352 .trim()
353 .parse::<f32>()
354 .map_err(|_| "invalid upper bound")?;
355 Ok(ParsedRange { lo, hi, include_hi })
356}
357
358fn ranges_overlap(left: ParsedRange, right: ParsedRange) -> bool {
359 if left.hi > right.lo {
360 return true;
361 }
362 left.hi == right.lo && left.include_hi
363}