// scud/attractor/graph.rs

//! Internal pipeline graph representation using petgraph.

use anyhow::{Context, Result};
use petgraph::graph::{DiGraph, NodeIndex};
use petgraph::visit::EdgeRef;
use std::collections::HashMap;
use std::time::Duration;

use super::dot_parser::{AttrValue, DotGraph};

11/// The internal pipeline graph built from a parsed DOT file.
12#[derive(Debug)]
13pub struct PipelineGraph {
14    pub name: String,
15    pub graph_attrs: GraphAttrs,
16    pub graph: DiGraph<PipelineNode, PipelineEdge>,
17    pub node_index: HashMap<String, NodeIndex>,
18    pub start_node: NodeIndex,
19    pub exit_node: NodeIndex,
20}
21
22/// Graph-level attributes.
23#[derive(Debug, Clone, Default)]
24pub struct GraphAttrs {
25    pub goal: Option<String>,
26    pub fidelity: Option<FidelityMode>,
27    pub model_stylesheet: Option<String>,
28    pub extra: HashMap<String, String>,
29}
30
31/// A node in the pipeline graph.
32#[derive(Debug, Clone)]
33pub struct PipelineNode {
34    pub id: String,
35    pub label: String,
36    pub shape: String,
37    pub handler_type: String,
38    pub prompt: String,
39    pub max_retries: u32,
40    pub goal_gate: bool,
41    pub retry_target: Option<String>,
42    pub fallback_retry_target: Option<String>,
43    pub fidelity: Option<FidelityMode>,
44    pub thread_id: Option<String>,
45    pub classes: Vec<String>,
46    pub timeout: Option<Duration>,
47    pub llm_model: Option<String>,
48    pub llm_provider: Option<String>,
49    pub reasoning_effort: String,
50    pub auto_status: bool,
51    pub allow_partial: bool,
52    pub extra_attrs: HashMap<String, AttrValue>,
53}
54
55impl Default for PipelineNode {
56    fn default() -> Self {
57        Self {
58            id: String::new(),
59            label: String::new(),
60            shape: "box".into(),
61            handler_type: "codergen".into(),
62            prompt: String::new(),
63            max_retries: 0,
64            goal_gate: false,
65            retry_target: None,
66            fallback_retry_target: None,
67            fidelity: None,
68            thread_id: None,
69            classes: vec![],
70            timeout: None,
71            llm_model: None,
72            llm_provider: None,
73            reasoning_effort: "high".into(),
74            auto_status: true,
75            allow_partial: false,
76            extra_attrs: HashMap::new(),
77        }
78    }
79}
80
81/// An edge in the pipeline graph.
82#[derive(Debug, Clone)]
83pub struct PipelineEdge {
84    pub label: String,
85    pub condition: String,
86    pub weight: i32,
87    pub fidelity: Option<FidelityMode>,
88    pub thread_id: Option<String>,
89    pub loop_restart: bool,
90}
91
92impl Default for PipelineEdge {
93    fn default() -> Self {
94        Self {
95            label: String::new(),
96            condition: String::new(),
97            weight: 0,
98            fidelity: None,
99            thread_id: None,
100            loop_restart: false,
101        }
102    }
103}
104
/// Fidelity mode for context passing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FidelityMode {
    Full,
    Truncate,
    Compact,
    /// Summarised context at the given level of detail.
    Summary(SummaryLevel),
}

/// Summary detail level.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SummaryLevel {
    Low,
    Medium,
    High,
}

impl FidelityMode {
    /// Parses a fidelity mode from its textual name, ignoring letter case.
    ///
    /// Accepted spellings are `full`, `truncate`, `compact`, `summary`
    /// (equivalent to `summary-medium`), `summary-low`, `summary-medium` and
    /// `summary-high`. Any other input, including input with surrounding
    /// whitespace, yields `None`.
    // Deliberately an inherent method returning `Option` rather than a
    // `FromStr` impl: existing callers treat unknown names as "not set".
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        match s.to_lowercase().as_str() {
            "full" => Some(FidelityMode::Full),
            "truncate" => Some(FidelityMode::Truncate),
            "compact" => Some(FidelityMode::Compact),
            "summary" | "summary-medium" => Some(FidelityMode::Summary(SummaryLevel::Medium)),
            "summary-low" => Some(FidelityMode::Summary(SummaryLevel::Low)),
            "summary-high" => Some(FidelityMode::Summary(SummaryLevel::High)),
            _ => None,
        }
    }
}

/// Shape-to-handler mapping per spec Section 2.8.
///
/// The shape name is matched case-insensitively. Shapes without a dedicated
/// handler fall back to `"codergen"`.
///
/// The result is always one of the string literals below, so it is returned
/// as `&'static str` and does not borrow from `shape`.
fn handler_type_from_shape(shape: &str) -> &'static str {
    match shape.to_lowercase().as_str() {
        "mdiamond" => "start",
        "msquare" => "exit",
        "box" | "rect" | "rectangle" => "codergen",
        "hexagon" => "wait.human",
        "diamond" => "conditional",
        "component" => "parallel",
        "tripleoctagon" => "parallel.fan_in",
        "parallelogram" => "tool",
        "house" => "stack.manager_loop",
        _ => "codergen", // default
    }
}

152impl PipelineGraph {
153    /// Build a PipelineGraph from a parsed DotGraph.
154    pub fn from_dot(dot: &DotGraph) -> Result<Self> {
155        let mut graph = DiGraph::new();
156        let mut node_index = HashMap::new();
157
158        // Extract graph-level attrs
159        let graph_attrs = GraphAttrs {
160            goal: dot.graph_attrs.get("goal").map(|v| v.as_str()),
161            fidelity: dot
162                .graph_attrs
163                .get("fidelity")
164                .and_then(|v| FidelityMode::from_str(&v.as_str())),
165            model_stylesheet: dot.graph_attrs.get("model_stylesheet").map(|v| v.as_str()),
166            extra: dot
167                .graph_attrs
168                .iter()
169                .filter(|(k, _)| !["goal", "fidelity", "model_stylesheet"].contains(&k.as_str()))
170                .map(|(k, v)| (k.clone(), v.as_str()))
171                .collect(),
172        };
173
174        // Collect all node IDs (from explicit nodes, edges, and subgraphs)
175        let mut all_node_ids: Vec<String> = Vec::new();
176        for node in &dot.nodes {
177            if !all_node_ids.contains(&node.id) {
178                all_node_ids.push(node.id.clone());
179            }
180        }
181        for edge in &dot.edges {
182            if !all_node_ids.contains(&edge.from) {
183                all_node_ids.push(edge.from.clone());
184            }
185            if !all_node_ids.contains(&edge.to) {
186                all_node_ids.push(edge.to.clone());
187            }
188        }
189        for sg in &dot.subgraphs {
190            for node in &sg.nodes {
191                if !all_node_ids.contains(&node.id) {
192                    all_node_ids.push(node.id.clone());
193                }
194            }
195            for edge in &sg.edges {
196                if !all_node_ids.contains(&edge.from) {
197                    all_node_ids.push(edge.from.clone());
198                }
199                if !all_node_ids.contains(&edge.to) {
200                    all_node_ids.push(edge.to.clone());
201                }
202            }
203        }
204
205        // Build lookup for node attrs
206        let mut node_attrs_map: HashMap<String, HashMap<String, AttrValue>> = HashMap::new();
207        for node in &dot.nodes {
208            node_attrs_map.insert(node.id.clone(), node.attrs.clone());
209        }
210        for sg in &dot.subgraphs {
211            for node in &sg.nodes {
212                node_attrs_map.insert(node.id.clone(), node.attrs.clone());
213            }
214        }
215
216        // Create petgraph nodes
217        for id in &all_node_ids {
218            let attrs = node_attrs_map.get(id).cloned().unwrap_or_default();
219            let merged_attrs = merge_with_defaults(&attrs, &dot.node_defaults);
220            let pipeline_node = build_pipeline_node(id, &merged_attrs);
221            let idx = graph.add_node(pipeline_node);
222            node_index.insert(id.clone(), idx);
223        }
224
225        // Add edges
226        let all_edges: Vec<_> = dot
227            .edges
228            .iter()
229            .chain(dot.subgraphs.iter().flat_map(|sg| sg.edges.iter()))
230            .collect();
231
232        for edge in all_edges {
233            let from_idx = *node_index
234                .get(&edge.from)
235                .context(format!("Edge source '{}' not found", edge.from))?;
236            let to_idx = *node_index
237                .get(&edge.to)
238                .context(format!("Edge target '{}' not found", edge.to))?;
239
240            let merged = merge_with_defaults(&edge.attrs, &dot.edge_defaults);
241            let pipeline_edge = build_pipeline_edge(&merged);
242            graph.add_edge(from_idx, to_idx, pipeline_edge);
243        }
244
245        // Find start and exit nodes
246        let start_node = find_node_by_handler(&graph, &node_index, "start")
247            .context("No start node found (need a node with shape=Mdiamond)")?;
248        let exit_node = find_node_by_handler(&graph, &node_index, "exit")
249            .context("No exit node found (need a node with shape=Msquare)")?;
250
251        Ok(PipelineGraph {
252            name: dot.name.clone(),
253            graph_attrs,
254            graph,
255            node_index,
256            start_node,
257            exit_node,
258        })
259    }
260
261    /// Get a node by its string ID.
262    pub fn node(&self, id: &str) -> Option<&PipelineNode> {
263        self.node_index.get(id).map(|idx| &self.graph[*idx])
264    }
265
266    /// Get all outgoing edges from a node.
267    pub fn outgoing_edges(&self, idx: NodeIndex) -> Vec<(NodeIndex, &PipelineEdge)> {
268        self.graph
269            .edges(idx)
270            .map(|e| (e.target(), e.weight()))
271            .collect()
272    }
273
274    /// Get the node IDs in topological order.
275    pub fn topo_order(&self) -> Result<Vec<NodeIndex>> {
276        petgraph::algo::toposort(&self.graph, None)
277            .map_err(|_| anyhow::anyhow!("Pipeline graph contains a cycle"))
278    }
279}
280
281fn merge_with_defaults(
282    attrs: &HashMap<String, AttrValue>,
283    defaults: &HashMap<String, AttrValue>,
284) -> HashMap<String, AttrValue> {
285    let mut merged = defaults.clone();
286    for (k, v) in attrs {
287        merged.insert(k.clone(), v.clone());
288    }
289    merged
290}
291
292fn build_pipeline_node(id: &str, attrs: &HashMap<String, AttrValue>) -> PipelineNode {
293    let shape = attrs
294        .get("shape")
295        .map(|v| v.as_str())
296        .unwrap_or_else(|| "box".into());
297
298    let explicit_type = attrs.get("type").map(|v| v.as_str());
299    let handler_type = explicit_type.unwrap_or_else(|| handler_type_from_shape(&shape).into());
300
301    let label = attrs
302        .get("label")
303        .map(|v| v.as_str())
304        .unwrap_or_else(|| id.to_string());
305
306    let classes = attrs
307        .get("class")
308        .map(|v| v.as_str().split_whitespace().map(String::from).collect())
309        .unwrap_or_default();
310
311    let mut extra_attrs = HashMap::new();
312    let known_keys = [
313        "shape",
314        "type",
315        "label",
316        "prompt",
317        "max_retries",
318        "goal_gate",
319        "retry_target",
320        "fallback_retry_target",
321        "fidelity",
322        "thread_id",
323        "class",
324        "timeout",
325        "llm_model",
326        "llm_provider",
327        "reasoning_effort",
328        "auto_status",
329        "allow_partial",
330    ];
331    for (k, v) in attrs {
332        if !known_keys.contains(&k.as_str()) {
333            extra_attrs.insert(k.clone(), v.clone());
334        }
335    }
336
337    PipelineNode {
338        id: id.to_string(),
339        label,
340        shape,
341        handler_type,
342        prompt: attrs.get("prompt").map(|v| v.as_str()).unwrap_or_default(),
343        max_retries: attrs
344            .get("max_retries")
345            .and_then(|v| v.as_int())
346            .unwrap_or(0) as u32,
347        goal_gate: attrs
348            .get("goal_gate")
349            .and_then(|v| v.as_bool())
350            .unwrap_or(false),
351        retry_target: attrs.get("retry_target").map(|v| v.as_str()),
352        fallback_retry_target: attrs.get("fallback_retry_target").map(|v| v.as_str()),
353        fidelity: attrs
354            .get("fidelity")
355            .and_then(|v| FidelityMode::from_str(&v.as_str())),
356        thread_id: attrs.get("thread_id").map(|v| v.as_str()),
357        classes,
358        timeout: attrs.get("timeout").and_then(|v| match v {
359            AttrValue::Duration(d) => Some(*d),
360            _ => None,
361        }),
362        llm_model: attrs.get("llm_model").map(|v| v.as_str()),
363        llm_provider: attrs.get("llm_provider").map(|v| v.as_str()),
364        reasoning_effort: attrs
365            .get("reasoning_effort")
366            .map(|v| v.as_str())
367            .unwrap_or_else(|| "high".into()),
368        auto_status: attrs
369            .get("auto_status")
370            .and_then(|v| v.as_bool())
371            .unwrap_or(true),
372        allow_partial: attrs
373            .get("allow_partial")
374            .and_then(|v| v.as_bool())
375            .unwrap_or(false),
376        extra_attrs,
377    }
378}
379
380fn build_pipeline_edge(attrs: &HashMap<String, AttrValue>) -> PipelineEdge {
381    PipelineEdge {
382        label: attrs.get("label").map(|v| v.as_str()).unwrap_or_default(),
383        condition: attrs
384            .get("condition")
385            .map(|v| v.as_str())
386            .unwrap_or_default(),
387        weight: attrs.get("weight").and_then(|v| v.as_int()).unwrap_or(0) as i32,
388        fidelity: attrs
389            .get("fidelity")
390            .and_then(|v| FidelityMode::from_str(&v.as_str())),
391        thread_id: attrs.get("thread_id").map(|v| v.as_str()),
392        loop_restart: attrs
393            .get("loop_restart")
394            .and_then(|v| v.as_bool())
395            .unwrap_or(false),
396    }
397}
398
399fn find_node_by_handler(
400    graph: &DiGraph<PipelineNode, PipelineEdge>,
401    node_index: &HashMap<String, NodeIndex>,
402    handler: &str,
403) -> Option<NodeIndex> {
404    node_index
405        .values()
406        .copied()
407        .find(|idx| graph[*idx].handler_type == handler)
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use crate::attractor::dot_parser::parse_dot;

    /// A minimal start -> task -> exit pipeline is converted node-for-node.
    #[test]
    fn test_build_simple_pipeline() {
        let input = r#"
        digraph pipeline {
            graph [goal="Build feature X"]
            start [shape=Mdiamond]
            task_a [shape=box, label="Implement A", prompt="Write the code for A"]
            finish [shape=Msquare]
            start -> task_a -> finish
        }
        "#;
        let dot = parse_dot(input).unwrap();
        let pipeline = PipelineGraph::from_dot(&dot).unwrap();

        assert_eq!(pipeline.name, "pipeline");
        assert_eq!(pipeline.graph_attrs.goal, Some("Build feature X".into()));
        assert_eq!(pipeline.graph.node_count(), 3);
        assert_eq!(pipeline.graph.edge_count(), 2);

        let start = &pipeline.graph[pipeline.start_node];
        assert_eq!(start.handler_type, "start");

        let exit = &pipeline.graph[pipeline.exit_node];
        assert_eq!(exit.handler_type, "exit");

        let task = pipeline.node("task_a").unwrap();
        assert_eq!(task.handler_type, "codergen");
        assert_eq!(task.prompt, "Write the code for A");
    }

    /// Every shape listed in the mapping resolves to its handler, including
    /// the `box` aliases and the fallback for unknown shapes.
    #[test]
    fn test_shape_to_handler_mapping() {
        assert_eq!(handler_type_from_shape("Mdiamond"), "start");
        assert_eq!(handler_type_from_shape("Msquare"), "exit");
        assert_eq!(handler_type_from_shape("box"), "codergen");
        assert_eq!(handler_type_from_shape("rect"), "codergen");
        assert_eq!(handler_type_from_shape("rectangle"), "codergen");
        assert_eq!(handler_type_from_shape("hexagon"), "wait.human");
        assert_eq!(handler_type_from_shape("diamond"), "conditional");
        assert_eq!(handler_type_from_shape("component"), "parallel");
        assert_eq!(handler_type_from_shape("tripleoctagon"), "parallel.fan_in");
        assert_eq!(handler_type_from_shape("parallelogram"), "tool");
        assert_eq!(handler_type_from_shape("house"), "stack.manager_loop");
        assert_eq!(handler_type_from_shape("no-such-shape"), "codergen");
    }

    /// Both edges leaving the start node are reported.
    #[test]
    fn test_outgoing_edges() {
        let input = r#"
        digraph test {
            start [shape=Mdiamond]
            a [shape=box]
            b [shape=box]
            finish [shape=Msquare]
            start -> a [label="go"]
            start -> b [label="alt"]
            a -> finish
            b -> finish
        }
        "#;
        let dot = parse_dot(input).unwrap();
        let pipeline = PipelineGraph::from_dot(&dot).unwrap();

        let edges = pipeline.outgoing_edges(pipeline.start_node);
        assert_eq!(edges.len(), 2);
    }

    /// A `node [...]` default statement applies to nodes that do not
    /// override the attribute themselves.
    #[test]
    fn test_node_defaults_applied() {
        let input = r#"
        digraph test {
            node [reasoning_effort="medium"]
            start [shape=Mdiamond]
            a [shape=box]
            finish [shape=Msquare]
            start -> a -> finish
        }
        "#;
        let dot = parse_dot(input).unwrap();
        let pipeline = PipelineGraph::from_dot(&dot).unwrap();
        let a = pipeline.node("a").unwrap();
        assert_eq!(a.reasoning_effort, "medium");
    }
}