Skip to main content

car_workflow/
verify.rs

1//! Static workflow verification — validate structure before execution.
2
3use std::collections::{HashMap, HashSet, VecDeque};
4
5use crate::types::*;
6
/// A single verification finding produced by [`verify_workflow`].
///
/// Findings with severity `"error"` make the workflow invalid; `"warning"`
/// findings are informational only.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkflowIssue {
    /// Either `"error"` or `"warning"`.
    pub severity: String,
    /// Stage the finding is attached to, if any. Findings lifted out of a
    /// nested sub-workflow are qualified with the enclosing stage's id.
    pub stage_id: Option<String>,
    /// Human-readable description of the problem.
    pub message: String,
}
14
15/// Result of static workflow verification.
16#[derive(Debug)]
17pub struct WorkflowVerifyResult {
18    pub valid: bool,
19    pub issues: Vec<WorkflowIssue>,
20    pub reachable_stages: Vec<String>,
21    pub unreachable_stages: Vec<String>,
22    pub has_cycles: bool,
23}
24
25/// Statically verify a workflow definition for structural correctness.
26pub fn verify_workflow(workflow: &Workflow) -> WorkflowVerifyResult {
27    let mut issues = Vec::new();
28    let stage_ids: HashSet<&str> = workflow.stages.iter().map(|s| s.id.as_str()).collect();
29
30    // 1. Start stage exists
31    if !stage_ids.contains(workflow.start.as_str()) {
32        issues.push(WorkflowIssue {
33            severity: "error".into(),
34            stage_id: None,
35            message: format!("start stage '{}' does not exist", workflow.start),
36        });
37    }
38
39    // 2. Edge references valid stage IDs
40    for edge in &workflow.edges {
41        if !stage_ids.contains(edge.from.as_str()) {
42            issues.push(WorkflowIssue {
43                severity: "error".into(),
44                stage_id: None,
45                message: format!("edge from '{}' references unknown stage", edge.from),
46            });
47        }
48        if !stage_ids.contains(edge.to.as_str()) {
49            issues.push(WorkflowIssue {
50                severity: "error".into(),
51                stage_id: None,
52                message: format!("edge to '{}' references unknown stage", edge.to),
53            });
54        }
55    }
56
57    // 3. Compensation StageRef references valid stage IDs
58    for stage in &workflow.stages {
59        if let Some(CompensationHandler::StageRef { stage_id }) = &stage.compensation {
60            if !stage_ids.contains(stage_id.as_str()) {
61                issues.push(WorkflowIssue {
62                    severity: "error".into(),
63                    stage_id: Some(stage.id.clone()),
64                    message: format!(
65                        "compensation for stage '{}' references unknown stage '{}'",
66                        stage.id, stage_id
67                    ),
68                });
69            }
70        }
71    }
72
73    // 4. Reachability via BFS from start
74    let adj: HashMap<&str, Vec<&str>> = {
75        let mut m: HashMap<&str, Vec<&str>> = HashMap::new();
76        for edge in &workflow.edges {
77            m.entry(edge.from.as_str()).or_default().push(edge.to.as_str());
78        }
79        m
80    };
81
82    let mut visited: HashSet<&str> = HashSet::new();
83    let mut queue: VecDeque<&str> = VecDeque::new();
84    if stage_ids.contains(workflow.start.as_str()) {
85        queue.push_back(workflow.start.as_str());
86        visited.insert(workflow.start.as_str());
87    }
88    while let Some(node) = queue.pop_front() {
89        if let Some(neighbors) = adj.get(node) {
90            for &next in neighbors {
91                if visited.insert(next) {
92                    queue.push_back(next);
93                }
94            }
95        }
96    }
97
98    let reachable_stages: Vec<String> = visited.iter().map(|s| s.to_string()).collect();
99    let unreachable_stages: Vec<String> = stage_ids
100        .iter()
101        .filter(|s| !visited.contains(**s))
102        .map(|s| s.to_string())
103        .collect();
104
105    for id in &unreachable_stages {
106        issues.push(WorkflowIssue {
107            severity: "warning".into(),
108            stage_id: Some(id.clone()),
109            message: format!("stage '{}' is unreachable from start", id),
110        });
111    }
112
113    // 5. Cycle detection via DFS
114    let has_cycles = detect_cycles(&adj, workflow.start.as_str());
115    if has_cycles {
116        issues.push(WorkflowIssue {
117            severity: "warning".into(),
118            stage_id: None,
119            message: "workflow contains cycles (ensure max_iterations is set)".into(),
120        });
121    }
122
123    // 6. Recurse into sub-workflows
124    for stage in &workflow.stages {
125        if let StageStep::SubWorkflow(ref sw) = stage.step {
126            let sub_result = verify_workflow(&sw.workflow);
127            for issue in sub_result.issues {
128                issues.push(WorkflowIssue {
129                    severity: issue.severity,
130                    stage_id: Some(format!("{}.{}", stage.id, issue.stage_id.unwrap_or_default())),
131                    message: format!("[sub-workflow {}] {}", stage.id, issue.message),
132                });
133            }
134        }
135    }
136
137    // 7. Proposal verification via car-verify
138    for stage in &workflow.stages {
139        if let StageStep::Proposal(ref ps) = stage.step {
140            let vr = car_verify::verify(&ps.proposal, None, None, 100);
141            for issue in &vr.issues {
142                if issue.severity == "error" {
143                    issues.push(WorkflowIssue {
144                        severity: "error".into(),
145                        stage_id: Some(stage.id.clone()),
146                        message: format!("[proposal] {}", issue.message),
147                    });
148                }
149            }
150        }
151    }
152
153    let valid = !issues.iter().any(|i| i.severity == "error");
154
155    WorkflowVerifyResult {
156        valid,
157        issues,
158        reachable_stages,
159        unreachable_stages,
160        has_cycles,
161    }
162}
163
/// Report whether a directed cycle is reachable from `start` in `adj`.
///
/// Iterative depth-first search. `on_path` holds the nodes on the current DFS
/// path and `done` the nodes whose outgoing edges have all been explored. Any
/// edge that leads back into `on_path` closes a cycle. A node with no entry in
/// `adj` is treated as having no outgoing edges. Nodes not reachable from
/// `start` are never examined.
fn detect_cycles(adj: &HashMap<&str, Vec<&str>>, start: &str) -> bool {
    let mut done: HashSet<&str> = HashSet::new();
    let mut on_path: HashSet<&str> = HashSet::new();
    // Each frame is (node, index of the next outgoing edge to examine); the
    // explicit stack stands in for the call stack of a recursive DFS.
    let mut frames: Vec<(&str, usize)> = vec![(start, 0)];
    on_path.insert(start);

    while let Some(frame) = frames.last_mut() {
        let (node, cursor) = *frame;
        let successors: &[&str] = adj.get(node).map(|v| v.as_slice()).unwrap_or(&[]);

        match successors.get(cursor) {
            Some(&next) => {
                frame.1 += 1;
                if on_path.contains(next) {
                    return true; // back edge = cycle
                }
                if !done.contains(next) {
                    on_path.insert(next);
                    frames.push((next, 0));
                }
            }
            None => {
                // Every edge out of `node` explored: step back off the path.
                on_path.remove(node);
                done.insert(node);
                frames.pop();
            }
        }
    }

    false
}
195
#[cfg(test)]
mod tests {
    use super::*;
    use car_ir::ActionProposal;

    /// A stage whose step is an empty proposal with id `p-<id>`.
    fn make_stage(id: &str) -> Stage {
        let proposal = ActionProposal {
            id: format!("p-{}", id),
            source: "test".into(),
            actions: vec![],
            timestamp: chrono::Utc::now(),
            context: std::collections::HashMap::new(),
        };
        Stage {
            id: id.into(),
            name: id.into(),
            step: StageStep::Proposal(ProposalStep { proposal }),
            compensation: None,
            timeout_ms: None,
            metadata: std::collections::HashMap::new(),
        }
    }

    /// One `make_stage` per id, in order.
    fn stages(ids: &[&str]) -> Vec<Stage> {
        ids.iter().map(|&id| make_stage(id)).collect()
    }

    /// An unconditional, unlabelled edge `from -> to`.
    fn edge(from: &str, to: &str) -> Edge {
        Edge { from: from.into(), to: to.into(), conditions: vec![], label: String::new() }
    }

    /// A workflow with fixed id/name, `max_iterations = 100` and no metadata.
    fn workflow(start: &str, stages: Vec<Stage>, edges: Vec<Edge>) -> Workflow {
        Workflow {
            id: "test".into(),
            name: "Test".into(),
            start: start.into(),
            stages,
            edges,
            max_iterations: 100,
            metadata: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn valid_linear_workflow() {
        let wf = workflow("a", stages(&["a", "b", "c"]), vec![edge("a", "b"), edge("b", "c")]);
        let result = verify_workflow(&wf);
        assert!(result.valid);
        assert!(!result.has_cycles);
        assert_eq!(result.reachable_stages.len(), 3);
        assert!(result.unreachable_stages.is_empty());
    }

    #[test]
    fn missing_start_stage() {
        let result = verify_workflow(&workflow("nonexistent", stages(&["a"]), vec![]));
        assert!(!result.valid);
        assert!(result.issues.iter().any(|i| i.message.contains("nonexistent")));
    }

    #[test]
    fn unreachable_stage() {
        let wf = workflow("a", stages(&["a", "b", "orphan"]), vec![edge("a", "b")]);
        let result = verify_workflow(&wf);
        // Unreachable is a warning, not an error.
        assert!(result.valid);
        assert_eq!(result.unreachable_stages.len(), 1);
        assert!(result.unreachable_stages.iter().any(|s| s == "orphan"));
    }

    #[test]
    fn cycle_detected() {
        let wf = workflow("a", stages(&["a", "b"]), vec![edge("a", "b"), edge("b", "a")]);
        let result = verify_workflow(&wf);
        // Cycles are warnings.
        assert!(result.valid);
        assert!(result.has_cycles);
    }

    #[test]
    fn invalid_edge_reference() {
        let result = verify_workflow(&workflow("a", stages(&["a"]), vec![edge("a", "ghost")]));
        assert!(!result.valid);
        assert!(result.issues.iter().any(|i| i.message.contains("ghost")));
    }

    #[test]
    fn invalid_compensation_ref() {
        let mut stage = make_stage("a");
        stage.compensation = Some(CompensationHandler::StageRef {
            stage_id: "nonexistent".into(),
        });
        let result = verify_workflow(&workflow("a", vec![stage], vec![]));
        assert!(!result.valid);
    }
}