Skip to main content

deepstrike_core/orchestration/
task_graph.rs

1use crate::types::error::{DeepStrikeError, Result};
2use crate::types::result::LoopResult;
3use crate::types::task::RuntimeTask;
4
/// Scheduling state of a single task in a task graph.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskStatus {
    /// Still waiting on at least one dependency that has not been completed.
    Pending,
    /// Eligible to start: dependencies satisfied (or re-armed via `set_ready`), not yet running.
    Ready,
    /// Started via `TaskGraph::start` and not yet marked completed or failed.
    Running,
    /// Marked done via `TaskGraph::complete`, which also stores the task's result.
    Completed,
    /// Marked via `TaskGraph::fail`; its dependents are left `Pending`.
    Failed,
}
13
14#[derive(Debug, Clone)]
15pub struct TaskNode {
16    pub id: usize,
17    pub task: RuntimeTask,
18    pub status: TaskStatus,
19    pub result: Option<LoopResult>,
20    pub dependencies: Vec<usize>,
21}
22
23/// DAG of tasks with dependency tracking.
24/// Maintains an in-degree counter so `ready_tasks()` is O(1) amortized.
25pub struct TaskGraph {
26    nodes: Vec<TaskNode>,
27    /// Number of incomplete dependencies per task.
28    in_degree: Vec<usize>,
29}
30
31impl TaskGraph {
32    pub fn new() -> Self {
33        Self {
34            nodes: Vec::new(),
35            in_degree: Vec::new(),
36        }
37    }
38
39    /// Add a task, returns its ID.
40    pub fn add(&mut self, task: RuntimeTask, dependencies: Vec<usize>) -> usize {
41        let id = self.nodes.len();
42        let deg = dependencies.len();
43        self.nodes.push(TaskNode {
44            id,
45            task,
46            status: if deg == 0 {
47                TaskStatus::Ready
48            } else {
49                TaskStatus::Pending
50            },
51            result: None,
52            dependencies,
53        });
54        self.in_degree.push(deg);
55        id
56    }
57
58    /// Topological sort — returns ordered IDs or error if cycle detected.
59    pub fn topological_sort(&self) -> Result<Vec<usize>> {
60        let n = self.nodes.len();
61        let mut in_deg = self.in_degree.clone();
62        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
63
64        for node in &self.nodes {
65            for &dep in &node.dependencies {
66                adj[dep].push(node.id);
67            }
68        }
69
70        let mut queue: Vec<usize> = (0..n).filter(|&i| in_deg[i] == 0).collect();
71        let mut order = Vec::with_capacity(n);
72
73        while let Some(id) = queue.pop() {
74            order.push(id);
75            for &next in &adj[id] {
76                in_deg[next] -= 1;
77                if in_deg[next] == 0 {
78                    queue.push(next);
79                }
80            }
81        }
82
83        if order.len() != n {
84            return Err(DeepStrikeError::OrchestrationCycle);
85        }
86        Ok(order)
87    }
88
89    /// Return IDs of tasks that are Ready (deps satisfied, not yet started).
90    pub fn ready_tasks(&self) -> Vec<usize> {
91        self.nodes
92            .iter()
93            .filter(|n| n.status == TaskStatus::Ready)
94            .map(|n| n.id)
95            .collect()
96    }
97
98    /// Mark a task as running.
99    pub fn start(&mut self, task_id: usize) {
100        if let Some(node) = self.nodes.get_mut(task_id) {
101            node.status = TaskStatus::Running;
102        }
103    }
104
105    /// Re-mark a (running) task as Ready without touching dependents — used to re-arm a loop node
106    /// for its next iteration. Unlike [`complete`](Self::complete), this does NOT decrement any
107    /// in-degree, so the loop node's dependents stay pending until the loop finally `complete`s.
108    pub fn set_ready(&mut self, task_id: usize) {
109        if let Some(node) = self.nodes.get_mut(task_id) {
110            node.status = TaskStatus::Ready;
111        }
112    }
113
114    /// Mark a task as completed; promote dependents whose in-degree reaches 0.
115    pub fn complete(&mut self, task_id: usize, result: LoopResult) {
116        if let Some(node) = self.nodes.get_mut(task_id) {
117            node.status = TaskStatus::Completed;
118            node.result = Some(result);
119        }
120        // Collect dependents first to avoid borrow conflict
121        let dependents: Vec<usize> = self
122            .nodes
123            .iter()
124            .filter(|n| n.dependencies.contains(&task_id))
125            .map(|n| n.id)
126            .collect();
127        for dep_id in dependents {
128            self.in_degree[dep_id] -= 1;
129            if self.in_degree[dep_id] == 0 {
130                if let Some(n) = self.nodes.get_mut(dep_id) {
131                    if n.status == TaskStatus::Pending {
132                        n.status = TaskStatus::Ready;
133                    }
134                }
135            }
136        }
137    }
138
139    /// Mark a task as failed (dependents remain Pending — caller decides policy).
140    pub fn fail(&mut self, task_id: usize) {
141        if let Some(node) = self.nodes.get_mut(task_id) {
142            node.status = TaskStatus::Failed;
143        }
144    }
145
146    pub fn get(&self, task_id: usize) -> Option<&TaskNode> {
147        self.nodes.get(task_id)
148    }
149
150    pub fn len(&self) -> usize {
151        self.nodes.len()
152    }
153
154    pub fn is_empty(&self) -> bool {
155        self.nodes.is_empty()
156    }
157
158    pub fn all_done(&self) -> bool {
159        self.nodes
160            .iter()
161            .all(|n| matches!(n.status, TaskStatus::Completed | TaskStatus::Failed))
162    }
163}
164
165impl Default for TaskGraph {
166    fn default() -> Self {
167        Self::new()
168    }
169}
170
171#[cfg(test)]
172mod tests {
173    use super::*;
174
175    #[test]
176    fn topological_sort_linear() {
177        let mut g = TaskGraph::new();
178        let a = g.add(RuntimeTask::new("A"), vec![]);
179        let b = g.add(RuntimeTask::new("B"), vec![a]);
180        let c = g.add(RuntimeTask::new("C"), vec![b]);
181
182        let order = g.topological_sort().unwrap();
183        assert_eq!(order, vec![0, 1, 2]);
184        let _ = (a, c);
185    }
186
187    #[test]
188    fn detects_cycle() {
189        let mut g = TaskGraph::new();
190        g.nodes.push(TaskNode {
191            id: 0,
192            task: RuntimeTask::new("A"),
193            status: TaskStatus::Pending,
194            result: None,
195            dependencies: vec![1],
196        });
197        g.nodes.push(TaskNode {
198            id: 1,
199            task: RuntimeTask::new("B"),
200            status: TaskStatus::Pending,
201            result: None,
202            dependencies: vec![0],
203        });
204        g.in_degree.push(1);
205        g.in_degree.push(1);
206
207        assert!(g.topological_sort().is_err());
208    }
209
210    #[test]
211    fn ready_tasks_respects_deps() {
212        let mut g = TaskGraph::new();
213        let a = g.add(RuntimeTask::new("A"), vec![]);
214        let _b = g.add(RuntimeTask::new("B"), vec![a]);
215
216        assert_eq!(g.ready_tasks(), vec![0]); // only A is Ready
217    }
218
219    #[test]
220    fn set_ready_rearms_without_promoting_dependents() {
221        let mut g = TaskGraph::new();
222        let a = g.add(RuntimeTask::new("A"), vec![]); // loop node
223        let b = g.add(RuntimeTask::new("B"), vec![a]); // dependent
224        g.start(a);
225        // Re-arm A for its next iteration: A is Ready again, but B stays Pending (no promotion).
226        g.set_ready(a);
227        assert_eq!(g.nodes[a].status, TaskStatus::Ready);
228        assert_eq!(g.nodes[b].status, TaskStatus::Pending);
229        assert_eq!(g.ready_tasks(), vec![a]);
230    }
231
232    #[test]
233    fn complete_promotes_dependent() {
234        use crate::types::result::{LoopResult, TerminationReason};
235        let mut g = TaskGraph::new();
236        let a = g.add(RuntimeTask::new("A"), vec![]);
237        let b = g.add(RuntimeTask::new("B"), vec![a]);
238
239        assert_eq!(g.nodes[b].status, TaskStatus::Pending);
240        g.complete(
241            a,
242            LoopResult {
243                termination: TerminationReason::Completed,
244                final_message: None,
245                turns_used: 1,
246                total_tokens_used: 0,
247                loop_continue: None,
248                classify_branch: None,
249                tournament_winner: None,
250            },
251        );
252        assert_eq!(g.nodes[b].status, TaskStatus::Ready);
253    }
254}