Skip to main content

jamjet_ir/
validate.rs

1use crate::error::{IrError, IrResult};
2use crate::workflow::WorkflowIr;
3use std::collections::{HashSet, VecDeque};
4
5/// Validate a WorkflowIr. Returns Ok(()) if valid, Err with the first
6/// violation found otherwise.
7///
8/// Validation rules:
9/// 1. workflow_id and version are non-empty; version is valid semver
10/// 2. No duplicate node ids
11/// 3. start_node exists in nodes
12/// 4. All edge targets exist
13/// 5. All nodes are reachable from start_node
14/// 6. All paths from start lead to a terminal node ("end" or a terminal kind)
15/// 7. All tool_ref, model_ref, agent_ref resolve to known definitions
16/// 8. All MCP servers referenced by mcp_tool nodes are configured
17/// 9. All remote agents referenced by a2a_task nodes are configured
18pub fn validate_workflow(ir: &WorkflowIr) -> IrResult<()> {
19    validate_metadata(ir)?;
20    validate_no_duplicate_nodes(ir)?;
21    validate_start_node(ir)?;
22    validate_edges(ir)?;
23    validate_reachability(ir)?;
24    validate_refs(ir)?;
25    Ok(())
26}
27
28fn validate_metadata(ir: &WorkflowIr) -> IrResult<()> {
29    if ir.workflow_id.is_empty() {
30        return Err(IrError::InvalidVersion("workflow_id is empty".into()));
31    }
32    // Basic semver check: must contain two dots
33    let parts: Vec<&str> = ir.version.split('.').collect();
34    if parts.len() != 3 || parts.iter().any(|p| p.parse::<u32>().is_err()) {
35        return Err(IrError::InvalidVersion(ir.version.clone()));
36    }
37    Ok(())
38}
39
40fn validate_no_duplicate_nodes(ir: &WorkflowIr) -> IrResult<()> {
41    // Node ids are the HashMap keys so duplicates are structurally impossible,
42    // but we check the stored ids match their keys.
43    for (key, node) in &ir.nodes {
44        if key != &node.id {
45            return Err(IrError::DuplicateNodeId(node.id.clone()));
46        }
47    }
48    Ok(())
49}
50
51fn validate_start_node(ir: &WorkflowIr) -> IrResult<()> {
52    if ir.start_node.is_empty() {
53        return Err(IrError::NoStartNode);
54    }
55    if !ir.nodes.contains_key(&ir.start_node) {
56        return Err(IrError::UnreachableNode(ir.start_node.clone()));
57    }
58    Ok(())
59}
60
61fn validate_edges(ir: &WorkflowIr) -> IrResult<()> {
62    for edge in &ir.edges {
63        if !ir.nodes.contains_key(&edge.to) && edge.to != "end" {
64            return Err(IrError::UnknownEdgeTarget {
65                from: edge.from.clone(),
66                to: edge.to.clone(),
67            });
68        }
69    }
70    Ok(())
71}
72
73fn validate_reachability(ir: &WorkflowIr) -> IrResult<()> {
74    // BFS from start_node
75    let mut visited: HashSet<&str> = HashSet::new();
76    let mut queue: VecDeque<&str> = VecDeque::new();
77    queue.push_back(&ir.start_node);
78    visited.insert(&ir.start_node);
79
80    while let Some(current) = queue.pop_front() {
81        for edge in ir.edges_from(current) {
82            let next = edge.to.as_str();
83            if next != "end" && !visited.contains(next) {
84                visited.insert(next);
85                queue.push_back(next);
86            }
87        }
88    }
89
90    // Check all nodes are reachable
91    for node_id in ir.nodes.keys() {
92        if !visited.contains(node_id.as_str()) {
93            return Err(IrError::UnreachableNode(node_id.clone()));
94        }
95    }
96    Ok(())
97}
98
99// `if !map.contains_key(...)` inside each match arm is intentional here —
100// reads more naturally than collapsing into match guards (which would require
101// an explicit success arm per variant and duplicate the destructuring).
102#[allow(clippy::collapsible_match, clippy::collapsible_if)]
103fn validate_refs(ir: &WorkflowIr) -> IrResult<()> {
104    use jamjet_core::node::NodeKind;
105
106    for (node_id, node) in &ir.nodes {
107        match &node.kind {
108            NodeKind::Tool { tool_ref, .. } => {
109                if !ir.tools.contains_key(tool_ref) {
110                    return Err(IrError::UnknownToolRef(node_id.clone(), tool_ref.clone()));
111                }
112            }
113            NodeKind::Model { model_ref, .. } => {
114                if !ir.models.contains_key(model_ref) {
115                    return Err(IrError::UnknownModelRef(node_id.clone(), model_ref.clone()));
116                }
117            }
118            NodeKind::McpTool { server, .. } => {
119                if !ir.mcp_servers.contains_key(server) {
120                    return Err(IrError::UnknownMcpServer(node_id.clone(), server.clone()));
121                }
122            }
123            NodeKind::A2aTask { remote_agent, .. } => {
124                if !ir.remote_agents.contains_key(remote_agent) {
125                    return Err(IrError::UnknownRemoteAgent(
126                        node_id.clone(),
127                        remote_agent.clone(),
128                    ));
129                }
130            }
131            _ => {}
132        }
133    }
134    Ok(())
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::workflow::*;
    use jamjet_core::node::NodeKind;
    use jamjet_core::timeout::TimeoutConfig;
    use std::collections::HashMap;

    /// A `Condition` node kind with no branches; `validate_refs` performs no
    /// registry lookup for it.
    fn bare_condition() -> NodeKind {
        NodeKind::Condition { branches: vec![] }
    }

    /// A `NodeDef` whose stored `id` is `id`, with every optional setting unset.
    fn node(id: &str, kind: NodeKind) -> NodeDef {
        NodeDef {
            id: id.to_string(),
            kind,
            retry_policy: None,
            node_timeout_secs: None,
            description: None,
            labels: HashMap::new(),
            policy: None,
            data_policy: None,
        }
    }

    /// An unconditional edge `from -> to`.
    fn edge(from: &str, to: &str) -> EdgeDef {
        EdgeDef {
            from: from.to_string(),
            to: to.to_string(),
            condition: None,
        }
    }

    /// Assemble a `WorkflowIr` with id `"test"` and version `"0.1.0"` from the
    /// given start node, nodes and edges; every registry, budget and policy is
    /// left empty.
    fn make_ir(start: &str, nodes: Vec<(&str, NodeKind)>, edges: Vec<(&str, &str)>) -> WorkflowIr {
        WorkflowIr {
            workflow_id: "test".into(),
            version: "0.1.0".into(),
            name: None,
            description: None,
            state_schema: "schemas.State".into(),
            start_node: start.to_string(),
            nodes: nodes
                .into_iter()
                .map(|(id, kind)| (id.to_string(), node(id, kind)))
                .collect(),
            edges: edges.into_iter().map(|(from, to)| edge(from, to)).collect(),
            retry_policies: HashMap::new(),
            timeouts: TimeoutConfig::default(),
            models: HashMap::new(),
            tools: HashMap::new(),
            mcp_servers: HashMap::new(),
            remote_agents: HashMap::new(),
            labels: HashMap::new(),
            policy: None,
            token_budget: None,
            cost_budget_usd: None,
            on_budget_exceeded: None,
            data_policy: None,
        }
    }

    #[test]
    fn valid_simple_workflow() {
        let ir = make_ir("a", vec![("a", bare_condition())], vec![("a", "end")]);
        assert!(validate_workflow(&ir).is_ok());
    }

    #[test]
    fn unknown_edge_target() {
        let ir = make_ir("a", vec![("a", bare_condition())], vec![("a", "nonexistent")]);
        assert!(matches!(
            validate_workflow(&ir),
            Err(IrError::UnknownEdgeTarget { .. })
        ));
    }

    #[test]
    fn unreachable_node() {
        let ir = make_ir(
            "a",
            vec![("a", bare_condition()), ("orphan", bare_condition())],
            vec![("a", "end")],
        );
        assert!(matches!(
            validate_workflow(&ir),
            Err(IrError::UnreachableNode(_))
        ));
    }

    #[test]
    fn invalid_semver() {
        let mut ir = make_ir("a", vec![("a", bare_condition())], vec![("a", "end")]);
        ir.version = "not-semver".into();
        assert!(matches!(
            validate_workflow(&ir),
            Err(IrError::InvalidVersion(_))
        ));
    }
}