Skip to main content

orchestrator_collab/
dag.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5/// Directed acyclic graph definition used by collaboration workflows.
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct WorkflowDag {
8    /// Stable workflow identifier.
9    pub id: String,
10    /// Human-readable workflow name.
11    pub name: String,
12    /// Nodes keyed by node identifier.
13    pub nodes: HashMap<String, WorkflowNode>,
14    /// Directed edges describing dependencies.
15    pub edges: Vec<WorkflowEdge>,
16}
17
18impl WorkflowDag {
19    /// Creates an empty DAG with the given identifier and name.
20    pub fn new(id: String, name: String) -> Self {
21        Self {
22            id,
23            name,
24            nodes: HashMap::new(),
25            edges: Vec::new(),
26        }
27    }
28
29    /// Inserts or replaces a node by its identifier.
30    pub fn add_node(&mut self, node: WorkflowNode) {
31        self.nodes.insert(node.id.clone(), node);
32    }
33
34    /// Appends a directed edge.
35    pub fn add_edge(&mut self, edge: WorkflowEdge) {
36        self.edges.push(edge);
37    }
38
39    /// Returns node identifiers that have no incoming edges.
40    pub fn get_entry_nodes(&self) -> Vec<&String> {
41        let targets: std::collections::HashSet<_> = self.edges.iter().map(|e| &e.to).collect();
42
43        self.nodes.keys().filter(|k| !targets.contains(k)).collect()
44    }
45
46    /// Returns nodes whose dependencies have all been completed.
47    pub fn get_ready_nodes(&self, completed: &std::collections::HashSet<String>) -> Vec<String> {
48        self.nodes
49            .keys()
50            .filter(|k| !completed.contains(*k))
51            .filter(|k| {
52                let deps = self.get_dependencies(k);
53                deps.iter().all(|d| completed.contains(d))
54            })
55            .cloned()
56            .collect()
57    }
58
59    fn get_dependencies(&self, node_id: &str) -> Vec<String> {
60        self.edges
61            .iter()
62            .filter(|e| e.to == node_id)
63            .map(|e| e.from.clone())
64            .collect()
65    }
66}
67
68/// Executable node inside a collaboration DAG.
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct WorkflowNode {
71    /// Stable node identifier.
72    pub id: String,
73    /// Step kind associated with the node.
74    pub step_type: StepType,
75    /// Agent-selection requirements for the node.
76    pub agent_requirement: AgentRequirement,
77    /// Optional prehook expression that gates execution.
78    pub prehook: Option<String>,
79    /// Runtime execution settings for the node.
80    pub config: NodeConfig,
81}
82
83/// Agent-selection constraints attached to a DAG node.
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct AgentRequirement {
86    /// Required capability for candidate agents.
87    pub capability: Option<String>,
88    /// Preferred agent identifiers ranked ahead of others.
89    pub preferred_agents: Vec<String>,
90    /// Optional minimum historical success rate for selection.
91    pub min_success_rate: Option<f32>,
92}
93
94/// Runtime configuration for a DAG node.
95#[derive(Debug, Clone, Serialize, Deserialize, Default)]
96pub struct NodeConfig {
97    /// Optional timeout in milliseconds.
98    pub timeout_ms: Option<u64>,
99    /// Enables retry behavior for node execution.
100    pub retry_enabled: bool,
101    /// Maximum retry count when retries are enabled.
102    pub max_retries: u32,
103}
104
105/// Directed edge between two workflow nodes.
106#[derive(Debug, Clone, Serialize, Deserialize)]
107pub struct WorkflowEdge {
108    /// Upstream node identifier.
109    pub from: String,
110    /// Downstream node identifier.
111    pub to: String,
112    /// Optional expression that must pass for the edge to activate.
113    pub condition: Option<String>,
114    /// Optional transform applied to upstream output before passing it forward.
115    pub transform: Option<OutputTransform>,
116}
117
118/// Mapping from upstream output into downstream shared state.
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct OutputTransform {
121    /// Source phase used to select upstream output.
122    pub source_phase: String,
123    /// Extraction strategy applied to the source output.
124    pub extraction: OutputExtraction,
125    /// Shared-state key populated on the downstream node.
126    pub target_key: String,
127}
128
129/// Supported output extraction strategies for DAG edges.
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub enum OutputExtraction {
132    /// Forward all artifacts from the source phase.
133    AllArtifacts,
134    /// Forward artifacts matching a single artifact kind string.
135    ArtifactKind(String),
136    /// Forward only the last `N` artifacts.
137    LastN(u32),
138    /// Apply a custom filter expression.
139    Filter(String),
140}
141
142/// Logical step kinds used by collaboration DAG nodes.
143#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
144pub enum StepType {
145    /// One-time initialization step.
146    InitOnce,
147    /// Quality-assurance step.
148    Qa,
149    /// Ticket-scanning step.
150    TicketScan,
151    /// Remediation or implementation step.
152    Fix,
153    /// Re-test step after a fix.
154    Retest,
155    /// Loop-guard or termination-check step.
156    LoopGuard,
157    /// User-defined custom step type.
158    Custom(String),
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    use std::collections::HashSet;

    /// Builds a node with the given id, step type and optional capability;
    /// everything else is left empty or at its default.
    fn make_node(id: &str, step_type: StepType, capability: Option<&str>) -> WorkflowNode {
        WorkflowNode {
            id: String::from(id),
            step_type,
            agent_requirement: AgentRequirement {
                capability: capability.map(String::from),
                preferred_agents: Vec::new(),
                min_success_rate: None,
            },
            prehook: None,
            config: NodeConfig::default(),
        }
    }

    /// Builds an unconditional edge without an output transform.
    fn make_edge(from: &str, to: &str) -> WorkflowEdge {
        WorkflowEdge {
            from: String::from(from),
            to: String::from(to),
            condition: None,
            transform: None,
        }
    }

    #[test]
    fn test_workflow_dag_entry_nodes() {
        let mut dag = WorkflowDag::new(String::from("test"), String::from("Test Workflow"));
        dag.add_node(make_node("start", StepType::InitOnce, None));
        dag.add_node(make_node("qa", StepType::Qa, Some("qa")));
        dag.add_edge(make_edge("start", "qa"));

        // Only "start" has no incoming edge.
        let entries = dag.get_entry_nodes();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0], "start");
    }

    #[test]
    fn test_workflow_dag_get_ready_nodes() {
        let mut dag = WorkflowDag::new(String::from("test"), String::from("Test"));
        dag.add_node(make_node("a", StepType::InitOnce, None));
        dag.add_node(make_node("b", StepType::Qa, None));
        dag.add_edge(make_edge("a", "b"));

        // Nothing completed yet: only the dependency-free node is ready.
        let nothing_done: HashSet<String> = HashSet::new();
        let ready = dag.get_ready_nodes(&nothing_done);
        assert_eq!(ready.len(), 1);
        assert!(ready.contains(&String::from("a")));

        // Once "a" is completed, its dependent "b" becomes the only ready node.
        let a_done: HashSet<String> = std::iter::once(String::from("a")).collect();
        let ready = dag.get_ready_nodes(&a_done);
        assert_eq!(ready.len(), 1);
        assert!(ready.contains(&String::from("b")));
    }
}