Skip to main content

axon/
tool_validator.rs

1//! Tool result validation and effect tracking.
2//!
3//! After a tool executes, validates the output against its declared
4//! `output_schema` and records which effects were activated from the
5//! tool's `effect_row`.
6//!
//! Validation rules (by `output_schema` value, trimmed and matched case-insensitively):
//!   "json"                                     — output must be valid JSON
//!   "number" / "numeric" / "integer" / "float" — output must be numeric
//!   "boolean" / "bool"                         — output must be "true" or "false" (any letter case)
//!   "nonempty" / "non_empty" / "required"      — output must not be empty or whitespace-only
//!   ""                                         — no validation (always passes)
//!   other                                      — treated as a type name, validated as non-empty
14//!
15//! Effect categories:
16//!   read     — reads data from external source
17//!   write    — writes/persists data externally
18//!   network  — makes network calls
19//!   compute  — performs significant computation
20//!   side     — general side effect
21//!
22//! The `EffectTracker` accumulates effect records during execution
23//! for inclusion in the execution report.
24
25use std::collections::HashMap;
26
27// ── Validation ─────────────────────────────────────────────────────────────
28
/// Outcome of checking one tool output against its declared output schema.
///
/// `PartialEq`/`Eq` are derived so that results can be compared directly
/// (for example with `assert_eq!`) instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationResult {
    /// Name of the tool whose output was validated.
    pub tool_name: String,
    /// Schema string the output was validated against.
    pub schema: String,
    /// `true` when the output satisfied the schema.
    pub passed: bool,
    /// Human-readable explanation of the outcome.
    pub message: String,
}
37
38/// Validate a tool's output against its declared output_schema.
39pub fn validate_output(tool_name: &str, output: &str, schema: &str) -> ValidationResult {
40    let schema_lower = schema.trim().to_lowercase();
41
42    let (passed, message) = match schema_lower.as_str() {
43        // No schema declared — always passes
44        "" => (true, "no schema declared".to_string()),
45
46        // JSON validation
47        "json" => {
48            match serde_json::from_str::<serde_json::Value>(output) {
49                Ok(_) => (true, "valid JSON".to_string()),
50                Err(e) => (false, format!("invalid JSON: {e}")),
51            }
52        }
53
54        // Numeric validation
55        "number" | "numeric" | "integer" | "float" => {
56            match output.trim().parse::<f64>() {
57                Ok(_) => (true, "valid number".to_string()),
58                Err(_) => (false, format!("expected number, got: '{}'", truncate(output, 50))),
59            }
60        }
61
62        // Boolean validation
63        "boolean" | "bool" => {
64            let lower = output.trim().to_lowercase();
65            if lower == "true" || lower == "false" {
66                (true, "valid boolean".to_string())
67            } else {
68                (false, format!("expected boolean, got: '{}'", truncate(output, 50)))
69            }
70        }
71
72        // Non-empty validation
73        "nonempty" | "non_empty" | "required" => {
74            if output.trim().is_empty() {
75                (false, "output is empty".to_string())
76            } else {
77                (true, "non-empty output".to_string())
78            }
79        }
80
81        // Named type — treated as non-empty check
82        _ => {
83            if output.trim().is_empty() {
84                (false, format!("expected {schema} output, got empty"))
85            } else {
86                (true, format!("output present (schema: {schema})"))
87            }
88        }
89    };
90
91    ValidationResult {
92        tool_name: tool_name.to_string(),
93        schema: schema.to_string(),
94        passed,
95        message,
96    }
97}
98
/// Shorten `s` to at most `max` bytes for use in a message, appending `…`
/// when anything was cut off.
///
/// The cut point is moved back to the nearest UTF-8 character boundary, so a
/// multi-byte character straddling `max` is dropped whole instead of causing
/// a slicing panic.
fn truncate(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_string();
    }

    // `max` may land inside a multi-byte character. Index 0 is always a
    // boundary, so this loop terminates.
    let mut end = max;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}…", &s[..end])
}
106
107// ── Effect tracking ────────────────────────────────────────────────────────
108
/// A recorded tool effect event: one tool execution and the effects declared for it.
///
/// `PartialEq`/`Eq` are derived so records can be compared directly in
/// assertions and report diffing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EffectRecord {
    /// Name of the tool that was executed.
    pub tool_name: String,
    /// Step label supplied by the caller when the execution was recorded.
    pub step_name: String,
    /// Unit label supplied by the caller when the execution was recorded.
    pub unit_name: String,
    /// Effect names declared for this execution, in the order they were given.
    pub effects: Vec<String>,
}
117
118/// Tracks tool effects during execution.
119#[derive(Debug)]
120pub struct EffectTracker {
121    records: Vec<EffectRecord>,
122    effect_counts: HashMap<String, usize>,
123}
124
125impl EffectTracker {
126    pub fn new() -> Self {
127        EffectTracker {
128            records: Vec::new(),
129            effect_counts: HashMap::new(),
130        }
131    }
132
133    /// Record a tool execution with its declared effects.
134    pub fn record(
135        &mut self,
136        tool_name: &str,
137        step_name: &str,
138        unit_name: &str,
139        effects: &[String],
140    ) {
141        for effect in effects {
142            *self.effect_counts.entry(effect.clone()).or_insert(0) += 1;
143        }
144        self.records.push(EffectRecord {
145            tool_name: tool_name.to_string(),
146            step_name: step_name.to_string(),
147            unit_name: unit_name.to_string(),
148            effects: effects.to_vec(),
149        });
150    }
151
152    /// All recorded effect events.
153    pub fn records(&self) -> &[EffectRecord] {
154        &self.records
155    }
156
157    /// Total number of tool executions tracked.
158    pub fn total_executions(&self) -> usize {
159        self.records.len()
160    }
161
162    /// Count of a specific effect type across all executions.
163    pub fn effect_count(&self, effect: &str) -> usize {
164        self.effect_counts.get(effect).copied().unwrap_or(0)
165    }
166
167    /// All distinct effect types observed.
168    pub fn distinct_effects(&self) -> Vec<&str> {
169        let mut effects: Vec<&str> = self.effect_counts.keys().map(|k| k.as_str()).collect();
170        effects.sort();
171        effects
172    }
173
174    /// Whether any network effects have been recorded.
175    pub fn has_network_effects(&self) -> bool {
176        self.effect_count("network") > 0
177    }
178
179    /// Whether any write effects have been recorded.
180    pub fn has_write_effects(&self) -> bool {
181        self.effect_count("write") > 0
182    }
183
184    /// Summary string for display.
185    pub fn summary(&self) -> String {
186        if self.records.is_empty() {
187            return "no tool effects".to_string();
188        }
189        let parts: Vec<String> = self
190            .effect_counts
191            .iter()
192            .map(|(k, v)| format!("{k}:{v}"))
193            .collect();
194        format!(
195            "{} tool executions, effects: {}",
196            self.records.len(),
197            parts.join(", ")
198        )
199    }
200}
201
202// ── Tests ──────────────────────────────────────────────────────────────────
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Turn string literals into the owned effect list `EffectTracker::record` expects.
    fn effect_list(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    // ── Validation tests ───────────────────────────────────────────

    /// An empty schema accepts any output.
    #[test]
    fn validate_no_schema() {
        assert!(validate_output("Tool", "anything", "").passed);
    }

    /// A JSON object passes, and the schema name is case-insensitive.
    #[test]
    fn validate_json_valid() {
        let outcome = validate_output("Tool", r#"{"key": "value"}"#, "JSON");
        assert!(outcome.passed);
    }

    /// A JSON array passes too.
    #[test]
    fn validate_json_array() {
        assert!(validate_output("Tool", "[1, 2, 3]", "json").passed);
    }

    /// Non-JSON text fails with an "invalid JSON" message.
    #[test]
    fn validate_json_invalid() {
        let outcome = validate_output("Tool", "not json at all", "JSON");
        assert!(!outcome.passed);
        assert!(outcome.message.contains("invalid JSON"));
    }

    /// Every numeric schema alias accepts a well-formed number.
    #[test]
    fn validate_number_valid() {
        for (output, schema) in [("42", "number"), ("3.14", "numeric"), ("-100", "integer")] {
            assert!(validate_output("Calc", output, schema).passed);
        }
    }

    /// Plain words are not numbers.
    #[test]
    fn validate_number_invalid() {
        assert!(!validate_output("Calc", "not a number", "number").passed);
    }

    /// Booleans are accepted under both aliases and in any letter case.
    #[test]
    fn validate_boolean_valid() {
        for (output, schema) in [("true", "boolean"), ("false", "bool"), ("TRUE", "boolean")] {
            assert!(validate_output("T", output, schema).passed);
        }
    }

    /// Anything other than true/false fails the boolean schema.
    #[test]
    fn validate_boolean_invalid() {
        assert!(!validate_output("T", "maybe", "boolean").passed);
    }

    /// Real content satisfies `nonempty`.
    #[test]
    fn validate_nonempty_valid() {
        assert!(validate_output("T", "has content", "nonempty").passed);
    }

    /// Whitespace-only output counts as empty.
    #[test]
    fn validate_nonempty_invalid() {
        assert!(!validate_output("T", "  ", "nonempty").passed);
    }

    /// An unknown schema name passes on non-empty output and is echoed in the message.
    #[test]
    fn validate_named_type_present() {
        let outcome = validate_output("T", "some data", "EntityMap");
        assert!(outcome.passed);
        assert!(outcome.message.contains("EntityMap"));
    }

    /// An unknown schema name fails on empty output.
    #[test]
    fn validate_named_type_empty() {
        assert!(!validate_output("T", "", "RiskAnalysis").passed);
    }

    // ── Effect tracker tests ───────────────────────────────────────

    /// A fresh tracker reports nothing.
    #[test]
    fn tracker_empty() {
        let fresh = EffectTracker::new();
        assert_eq!(fresh.total_executions(), 0);
        assert!(fresh.distinct_effects().is_empty());
        assert!(!fresh.has_network_effects());
        assert!(!fresh.has_write_effects());
        assert_eq!(fresh.summary(), "no tool effects");
    }

    /// One execution with two effects bumps each counter once.
    #[test]
    fn tracker_record_effects() {
        let declared = effect_list(&["network", "read"]);
        let mut tracker = EffectTracker::new();
        tracker.record("WebSearch", "Search", "Flow1", &declared);

        assert_eq!(tracker.total_executions(), 1);
        assert!(tracker.has_network_effects());
        assert!(!tracker.has_write_effects());
        assert_eq!(tracker.effect_count("network"), 1);
        assert_eq!(tracker.effect_count("read"), 1);
    }

    /// Counts accumulate across several executions.
    #[test]
    fn tracker_multiple_records() {
        let runs: [(&str, &str, &[&str]); 3] = [
            ("WebSearch", "S1", &["network"]),
            ("DBWrite", "S2", &["write", "network"]),
            ("Calculator", "S3", &["compute"]),
        ];

        let mut tracker = EffectTracker::new();
        for (tool, step, effects) in runs {
            tracker.record(tool, step, "F1", &effect_list(effects));
        }

        assert_eq!(tracker.total_executions(), 3);
        for (effect, expected) in [("network", 2), ("write", 1), ("compute", 1)] {
            assert_eq!(tracker.effect_count(effect), expected);
        }
        assert!(tracker.has_network_effects());
        assert!(tracker.has_write_effects());
    }

    /// `distinct_effects` returns each effect once, alphabetically.
    #[test]
    fn tracker_distinct_effects_sorted() {
        let mut tracker = EffectTracker::new();
        tracker.record("T1", "S", "F", &effect_list(&["write", "compute"]));
        tracker.record("T2", "S", "F", &effect_list(&["network", "read"]));

        assert_eq!(
            tracker.distinct_effects(),
            vec!["compute", "network", "read", "write"]
        );
    }

    /// Stored records expose exactly what was passed to `record`.
    #[test]
    fn tracker_records_accessible() {
        let mut tracker = EffectTracker::new();
        tracker.record("WebSearch", "Search", "Flow1", &effect_list(&["network"]));

        let stored = tracker.records();
        assert_eq!(stored.len(), 1);

        let first = &stored[0];
        assert_eq!(first.tool_name, "WebSearch");
        assert_eq!(first.step_name, "Search");
        assert_eq!(first.unit_name, "Flow1");
        assert_eq!(first.effects, vec!["network"]);
    }

    /// The summary mentions the execution count and per-effect totals.
    #[test]
    fn tracker_summary_format() {
        let mut tracker = EffectTracker::new();
        for tool in ["T1", "T2"] {
            tracker.record(tool, "S", "F", &effect_list(&["network"]));
        }

        let rendered = tracker.summary();
        assert!(rendered.contains("2 tool executions"));
        assert!(rendered.contains("network:2"));
    }
}