Skip to main content

a3s_flow/
graph.rs

1//! DAG graph representation and validation.
2//!
3//! [`DagGraph`] parses a JSON flow definition, builds a directed acyclic graph,
4//! validates it (no cycles, all referenced node IDs exist), and exposes
5//! topological-order iteration for the runner.
6
7use petgraph::algo::toposort;
8use petgraph::graph::{DiGraph, NodeIndex};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::collections::HashMap;
12
13use crate::condition::Condition;
14use crate::error::{FlowError, Result};
15use crate::node::RetryPolicy;
16
17/// A single node declaration inside a flow definition.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct NodeDef {
20    /// Unique identifier within the flow (e.g. `"llm_1"`, `"http_call"`).
21    pub id: String,
22    /// The node type, used to resolve the executor (e.g. `"noop"`, `"llm"`, `"http"`).
23    #[serde(rename = "type")]
24    pub node_type: String,
25    /// Static configuration for this node (prompt template, URL, etc.).
26    /// Mirrors Dify's `data` field.
27    #[serde(default)]
28    pub data: Value,
29    /// Optional guard condition extracted from `data["run_if"]` during `from_json`.
30    /// Not present in the wire format — lives inside `data`.
31    #[serde(skip)]
32    pub(crate) run_if: Option<Condition>,
33    /// Optional retry policy extracted from `data["retry"]` during `from_json`.
34    #[serde(skip)]
35    pub(crate) retry: Option<RetryPolicy>,
36    /// Optional execution timeout in milliseconds extracted from `data["timeout_ms"]`.
37    #[serde(skip)]
38    pub(crate) timeout_ms: Option<u64>,
39    /// If true, a node failure produces `{"__error__": "reason"}` as output
40    /// instead of halting the flow. Parsed from `data["continue_on_error"]`.
41    #[serde(skip)]
42    pub(crate) continue_on_error: bool,
43    /// If true, the runner merges the node's output object into the live
44    /// variable map after the wave completes, making the values available to
45    /// all downstream nodes in `ctx.variables`.
46    ///
47    /// Set automatically for `node_type == "assign"`.
48    #[serde(skip)]
49    pub(crate) write_to_variables: bool,
50}
51
52/// A directed edge from one node to another in a flow definition.
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct EdgeDef {
55    /// ID of the upstream (source) node.
56    pub source: String,
57    /// ID of the downstream (target) node.
58    pub target: String,
59}
60
61/// Internal: the top-level shape of a flow definition JSON object.
62#[derive(Deserialize)]
63struct FlowDef {
64    nodes: Vec<NodeDef>,
65    #[serde(default)]
66    edges: Vec<EdgeDef>,
67}
68
69/// A parsed and validated directed acyclic graph representing a workflow.
70///
71/// `Clone` is implemented so the `"iteration"` node can share one parsed
72/// sub-flow DAG across concurrent loop iterations without re-parsing.
73///
74/// Constructed via [`DagGraph::from_json`]. After construction the graph is
75/// guaranteed to be acyclic and self-consistent (all edge source/target IDs exist).
76#[derive(Clone)]
77pub struct DagGraph {
78    /// Node definitions, keyed by node ID.
79    pub(crate) nodes: HashMap<String, NodeDef>,
80    /// Topologically sorted node IDs (sources first, sinks last).
81    pub(crate) topo_order: Vec<String>,
82    /// Adjacency: node ID → list of direct dependency node IDs.
83    pub(crate) deps: HashMap<String, Vec<String>>,
84}
85
86impl DagGraph {
87    /// Parse and validate a flow definition from a JSON value.
88    ///
89    /// The expected shape is a `{ "nodes": [...], "edges": [...] }` object
90    /// (Dify-compatible format):
91    ///
92    /// ```json
93    /// {
94    ///   "nodes": [
95    ///     { "id": "fetch", "type": "http-request", "data": { "url": "https://api.example.com" } },
96    ///     { "id": "notify", "type": "noop", "data": { "run_if": { "from": "fetch", "path": "ok", "op": "eq", "value": true } } }
97    ///   ],
98    ///   "edges": [
99    ///     { "source": "fetch", "target": "notify" }
100    ///   ]
101    /// }
102    /// ```
103    pub fn from_json(value: &Value) -> Result<Self> {
104        let raw: FlowDef = serde_json::from_value(value.clone())
105            .map_err(|e| FlowError::InvalidDefinition(e.to_string()))?;
106
107        if raw.nodes.is_empty() {
108            return Err(FlowError::InvalidDefinition(
109                "flow must contain at least one node".into(),
110            ));
111        }
112
113        // Build id → NodeDef map, extracting run_if from data["run_if"].
114        let mut nodes: HashMap<String, NodeDef> = HashMap::new();
115        for mut def in raw.nodes {
116            if nodes.contains_key(&def.id) {
117                return Err(FlowError::InvalidDefinition(format!(
118                    "duplicate node id: {}",
119                    def.id
120                )));
121            }
122            if let Some(run_if_val) = def.data.get("run_if") {
123                def.run_if = serde_json::from_value(run_if_val.clone()).map_err(|e| {
124                    FlowError::InvalidDefinition(format!("node '{}': invalid run_if: {e}", def.id))
125                })?;
126            }
127            if let Some(retry_val) = def.data.get("retry") {
128                def.retry = serde_json::from_value(retry_val.clone()).map_err(|e| {
129                    FlowError::InvalidDefinition(format!("node '{}': invalid retry: {e}", def.id))
130                })?;
131            }
132            if let Some(ms) = def.data.get("timeout_ms").and_then(|v| v.as_u64()) {
133                def.timeout_ms = Some(ms);
134            }
135            if let Some(true) = def.data.get("continue_on_error").and_then(|v| v.as_bool()) {
136                def.continue_on_error = true;
137            }
138            if def.node_type == "assign" {
139                def.write_to_variables = true;
140            }
141            nodes.insert(def.id.clone(), def);
142        }
143
144        // Build deps from edges (target → [sources]). Validate all IDs exist.
145        let mut deps: HashMap<String, Vec<String>> =
146            nodes.keys().map(|id| (id.clone(), vec![])).collect();
147        for edge in &raw.edges {
148            if !nodes.contains_key(&edge.source) {
149                return Err(FlowError::UnknownNode(edge.source.clone()));
150            }
151            if !nodes.contains_key(&edge.target) {
152                return Err(FlowError::UnknownNode(edge.target.clone()));
153            }
154            deps.entry(edge.target.clone())
155                .or_default()
156                .push(edge.source.clone());
157        }
158
159        // Build petgraph DiGraph for cycle detection and topological sort.
160        let mut graph: DiGraph<String, ()> = DiGraph::new();
161        let mut id_to_idx: HashMap<String, NodeIndex> = HashMap::new();
162
163        for id in nodes.keys() {
164            let idx = graph.add_node(id.clone());
165            id_to_idx.insert(id.clone(), idx);
166        }
167
168        for (target, sources) in &deps {
169            let to = id_to_idx[target];
170            for source in sources {
171                let from = id_to_idx[source];
172                graph.add_edge(from, to, ());
173            }
174        }
175
176        let sorted = toposort(&graph, None).map_err(|_| FlowError::CyclicGraph)?;
177
178        let topo_order: Vec<String> = sorted.into_iter().map(|idx| graph[idx].clone()).collect();
179
180        Ok(Self {
181            nodes,
182            topo_order,
183            deps,
184        })
185    }
186
187    /// Returns node definitions in topological order (dependencies first).
188    pub fn nodes_in_order(&self) -> impl Iterator<Item = &NodeDef> {
189        self.topo_order.iter().map(move |id| &self.nodes[id])
190    }
191
192    /// Returns the direct dependencies of a node by ID.
193    pub fn dependencies_of(&self, id: &str) -> &[String] {
194        self.deps.get(id).map(Vec::as_slice).unwrap_or(&[])
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Index of `id` within `order`.
    ///
    /// Panics when the ID is absent. Comparing raw `Option<usize>` positions
    /// would let a missing node slip through, because `None < Some(_)` holds.
    fn position_of(order: &[&str], id: &str) -> usize {
        order
            .iter()
            .position(|&x| x == id)
            .unwrap_or_else(|| panic!("node '{}' missing from topological order", id))
    }

    #[test]
    fn parse_simple_linear_dag() {
        let def = json!({
            "nodes": [
                { "id": "a", "type": "noop" },
                { "id": "b", "type": "noop" },
                { "id": "c", "type": "noop" }
            ],
            "edges": [
                { "source": "a", "target": "b" },
                { "source": "b", "target": "c" }
            ]
        });
        let dag = DagGraph::from_json(&def).unwrap();
        let order: Vec<&str> = dag.nodes_in_order().map(|n| n.id.as_str()).collect();
        assert_eq!(order.len(), 3);
        // "a" must come before "b", "b" before "c"
        assert!(position_of(&order, "a") < position_of(&order, "b"));
        assert!(position_of(&order, "b") < position_of(&order, "c"));
    }

    #[test]
    fn reports_direct_dependencies() {
        let def = json!({
            "nodes": [
                { "id": "a", "type": "noop" },
                { "id": "b", "type": "noop" },
                { "id": "c", "type": "noop" }
            ],
            "edges": [
                { "source": "a", "target": "c" },
                { "source": "b", "target": "c" }
            ]
        });
        let dag = DagGraph::from_json(&def).unwrap();

        let mut of_c: Vec<&str> = dag.dependencies_of("c").iter().map(String::as_str).collect();
        of_c.sort_unstable();
        assert_eq!(of_c, ["a", "b"]);

        assert!(dag.dependencies_of("a").is_empty());
        // Unknown IDs yield an empty slice rather than panicking.
        assert!(dag.dependencies_of("missing").is_empty());
    }

    #[test]
    fn accepts_flow_without_edges_key() {
        let def = json!({ "nodes": [ { "id": "solo", "type": "noop" } ] });
        let dag = DagGraph::from_json(&def).unwrap();
        let order: Vec<&str> = dag.nodes_in_order().map(|n| n.id.as_str()).collect();
        assert_eq!(order, ["solo"]);
        assert!(dag.dependencies_of("solo").is_empty());
    }

    #[test]
    fn extracts_node_options_from_data() {
        let def = json!({
            "nodes": [
                {
                    "id": "a",
                    "type": "assign",
                    "data": { "timeout_ms": 250, "continue_on_error": true }
                },
                { "id": "b", "type": "noop", "data": { "timeout_ms": "soon" } }
            ]
        });
        let dag = DagGraph::from_json(&def).unwrap();

        let a = &dag.nodes["a"];
        assert_eq!(a.timeout_ms, Some(250));
        assert!(a.continue_on_error);
        assert!(a.write_to_variables);

        // A non-integer timeout is ignored; the other flags keep their defaults.
        let b = &dag.nodes["b"];
        assert_eq!(b.timeout_ms, None);
        assert!(!b.continue_on_error);
        assert!(!b.write_to_variables);
    }

    #[test]
    fn rejects_cyclic_graph() {
        let def = json!({
            "nodes": [
                { "id": "a", "type": "noop" },
                { "id": "b", "type": "noop" }
            ],
            "edges": [
                { "source": "b", "target": "a" },
                { "source": "a", "target": "b" }
            ]
        });
        assert!(matches!(
            DagGraph::from_json(&def),
            Err(FlowError::CyclicGraph)
        ));
    }

    #[test]
    fn rejects_unknown_dependency() {
        let def = json!({
            "nodes": [
                { "id": "a", "type": "noop" }
            ],
            "edges": [
                { "source": "nonexistent", "target": "a" }
            ]
        });
        assert!(matches!(
            DagGraph::from_json(&def),
            Err(FlowError::UnknownNode(_))
        ));
    }

    #[test]
    fn rejects_duplicate_node_ids() {
        let def = json!({
            "nodes": [
                { "id": "a", "type": "noop" },
                { "id": "a", "type": "noop" }
            ],
            "edges": []
        });
        assert!(matches!(
            DagGraph::from_json(&def),
            Err(FlowError::InvalidDefinition(_))
        ));
    }

    #[test]
    fn rejects_empty_flow() {
        let def = json!({ "nodes": [], "edges": [] });
        assert!(matches!(
            DagGraph::from_json(&def),
            Err(FlowError::InvalidDefinition(_))
        ));
    }

    #[test]
    fn rejects_definition_without_nodes_key() {
        let def = json!({ "edges": [] });
        assert!(matches!(
            DagGraph::from_json(&def),
            Err(FlowError::InvalidDefinition(_))
        ));
    }
}
280}