Skip to main content

brainwires_agents/task_manager/
dependency_ops.rs

1//! Dependency Operations
2//!
3//! Task dependency management: add, remove, check dependencies, cycle detection.
4
5use anyhow::{Context, Result};
6
7use super::TaskManager;
8use brainwires_core::TaskStatus;
9
10impl TaskManager {
11    /// Add a dependency between tasks
12    pub async fn add_dependency(&self, task_id: &str, depends_on: &str) -> Result<()> {
13        // Check for circular dependency before acquiring write lock
14        if self.would_create_cycle(task_id, depends_on).await? {
15            anyhow::bail!(
16                "Adding dependency '{}' -> '{}' would create a circular dependency",
17                task_id,
18                depends_on
19            );
20        }
21
22        let mut tasks = self.tasks.write().await;
23
24        // Verify both tasks exist
25        if !tasks.contains_key(depends_on) {
26            anyhow::bail!("Dependency task '{}' not found", depends_on);
27        }
28
29        let task = tasks
30            .get_mut(task_id)
31            .context(format!("Task '{}' not found", task_id))?;
32
33        task.add_dependency(depends_on.to_string());
34
35        // If dependency is not complete/skipped, mark task as blocked
36        let dep_status = tasks
37            .get(depends_on)
38            .expect("dependency existence verified above")
39            .status
40            .clone();
41        if dep_status != TaskStatus::Completed && dep_status != TaskStatus::Skipped {
42            tasks
43                .get_mut(task_id)
44                .expect("task existence verified above")
45                .status = TaskStatus::Blocked;
46        }
47
48        Ok(())
49    }
50
51    /// Check if adding a dependency would create a circular dependency
52    async fn would_create_cycle(&self, task_id: &str, depends_on: &str) -> Result<bool> {
53        // If task_id == depends_on, it's a self-dependency (cycle)
54        if task_id == depends_on {
55            return Ok(true);
56        }
57
58        let tasks = self.tasks.read().await;
59
60        // BFS to check if depends_on can reach task_id through its dependencies
61        let mut visited = std::collections::HashSet::new();
62        let mut queue = std::collections::VecDeque::new();
63
64        queue.push_back(depends_on.to_string());
65
66        while let Some(current) = queue.pop_front() {
67            if current == task_id {
68                return Ok(true); // Found a cycle
69            }
70
71            if visited.contains(&current) {
72                continue;
73            }
74            visited.insert(current.clone());
75
76            if let Some(task) = tasks.get(&current) {
77                for dep in &task.depends_on {
78                    if !visited.contains(dep) {
79                        queue.push_back(dep.clone());
80                    }
81                }
82            }
83        }
84
85        Ok(false)
86    }
87
88    /// Check if a task can be started (all dependencies are complete/skipped)
89    /// Returns Ok(true) if task can start, or Err with list of blocking task IDs
90    pub async fn can_start(&self, task_id: &str) -> std::result::Result<bool, Vec<String>> {
91        let tasks = self.tasks.read().await;
92
93        let task = match tasks.get(task_id) {
94            Some(t) => t,
95            None => return Ok(false), // Task doesn't exist
96        };
97
98        // If task is already completed, failed, or skipped, it can't be started
99        if matches!(
100            task.status,
101            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Skipped
102        ) {
103            return Ok(false);
104        }
105
106        // Collect blocking dependencies
107        let blocking: Vec<String> = task
108            .depends_on
109            .iter()
110            .filter(|dep_id| {
111                tasks
112                    .get(*dep_id)
113                    .map(|t| t.status != TaskStatus::Completed && t.status != TaskStatus::Skipped)
114                    .unwrap_or(true) // Missing dependency is considered blocking
115            })
116            .cloned()
117            .collect();
118
119        if blocking.is_empty() {
120            Ok(true)
121        } else {
122            Err(blocking)
123        }
124    }
125
126    /// Remove a dependency between tasks
127    pub async fn remove_dependency(&self, task_id: &str, depends_on: &str) -> Result<()> {
128        let mut tasks = self.tasks.write().await;
129
130        // First get the task, update it, and collect info for the check
131        let (is_blocked, remaining_deps) = {
132            let task = tasks
133                .get_mut(task_id)
134                .context(format!("Task '{}' not found", task_id))?;
135
136            task.depends_on.retain(|d| d != depends_on);
137            task.updated_at = chrono::Utc::now().timestamp();
138
139            (task.status == TaskStatus::Blocked, task.depends_on.clone())
140        };
141
142        // Check if task should be unblocked
143        if is_blocked {
144            let all_deps_done = remaining_deps.iter().all(|dep_id| {
145                tasks
146                    .get(dep_id)
147                    .map(|t| t.status == TaskStatus::Completed || t.status == TaskStatus::Skipped)
148                    .unwrap_or(false)
149            });
150
151            if all_deps_done && let Some(task) = tasks.get_mut(task_id) {
152                task.status = TaskStatus::Pending;
153            }
154        }
155
156        Ok(())
157    }
158
159    /// Unblock tasks that depend on a completed/skipped task
160    pub async fn unblock_dependents(&self, completed_task_id: &str) -> Result<()> {
161        let mut tasks = self.tasks.write().await;
162
163        // Find all tasks that depend on the completed task
164        let dependent_ids: Vec<String> = tasks
165            .values()
166            .filter(|t| t.depends_on.contains(&completed_task_id.to_string()))
167            .map(|t| t.id.clone())
168            .collect();
169
170        // Collect the dependency lists for each task first
171        let mut tasks_to_check: Vec<(String, Vec<String>)> = Vec::new();
172        for dep_id in &dependent_ids {
173            if let Some(task) = tasks.get(dep_id)
174                && task.status == TaskStatus::Blocked
175            {
176                tasks_to_check.push((dep_id.clone(), task.depends_on.clone()));
177            }
178        }
179
180        // Now update tasks based on dependency status
181        for (dep_id, deps) in tasks_to_check {
182            // Check if all dependencies are now complete/skipped
183            let all_deps_done = deps.iter().all(|d| {
184                tasks
185                    .get(d)
186                    .map(|t| t.status == TaskStatus::Completed || t.status == TaskStatus::Skipped)
187                    .unwrap_or(false)
188            });
189
190            if all_deps_done && let Some(task) = tasks.get_mut(&dep_id) {
191                task.status = TaskStatus::Pending;
192                task.updated_at = chrono::Utc::now().timestamp();
193            }
194        }
195
196        Ok(())
197    }
198}