Skip to main content

brainwires_agents/
task_orchestrator.rs

1//! Task Orchestrator - Bridges TaskManager and AgentPool
2//!
3//! [`TaskOrchestrator`] runs a scheduling loop that queries ready tasks from the
4//! dependency graph, spawns agents via the pool, and feeds results back into the
5//! task manager.  This provides centralized status tracking with concurrent agent
6//! execution and dependency-aware ordering.
7
8use std::collections::HashMap;
9use std::sync::Arc;
10
11use anyhow::{Result, anyhow};
12use tokio::sync::RwLock;
13
14use brainwires_core::{Task, TaskPriority, TaskStatus};
15
16use crate::communication::{AgentMessage, CommunicationHub};
17use crate::pool::AgentPool;
18use crate::task_agent::{TaskAgentConfig, TaskAgentResult};
19use crate::task_manager::TaskManager;
20use crate::task_manager::TaskStats;
21
/// Polling interval, in milliseconds, used by the default orchestrator
/// configuration.
const DEFAULT_POLL_INTERVAL_MS: u64 = 250;
23
24// ── Public types ────────────────────────────────────────────────────────────
25
/// Policy applied by the scheduling loop when an agent's task fails.
///
/// The enum is a plain, fieldless value type, so it is `Copy`: it can be
/// passed around and compared without cloning or moving out of a config.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FailurePolicy {
    /// Stop scheduling new tasks and drain running agents (default).
    #[default]
    StopOnFirstFailure,
    /// Keep scheduling independent tasks that aren't blocked by the failure.
    ContinueOnFailure,
}
35
36/// Configuration for the orchestration loop.
37#[derive(Debug, Clone)]
38pub struct TaskOrchestratorConfig {
39    /// Behaviour on agent failure.
40    pub failure_policy: FailurePolicy,
41    /// Default agent config used when no per-task override exists.
42    pub default_agent_config: TaskAgentConfig,
43    /// Polling interval in milliseconds.  Default: 250.
44    pub poll_interval_ms: u64,
45    /// Identifier used in CommunicationHub messages.
46    pub orchestrator_id: String,
47}
48
49impl Default for TaskOrchestratorConfig {
50    fn default() -> Self {
51        Self {
52            failure_policy: FailurePolicy::default(),
53            default_agent_config: TaskAgentConfig::default(),
54            poll_interval_ms: DEFAULT_POLL_INTERVAL_MS,
55            orchestrator_id: "orchestrator".to_string(),
56        }
57    }
58}
59
60/// Specification for creating a task via the convenience API.
61#[derive(Debug, Clone)]
62pub struct TaskSpec {
63    /// Task description.
64    pub description: String,
65    /// Task priority.
66    pub priority: TaskPriority,
67    /// Indices into the spec list that this task depends on.
68    pub depends_on_indices: Vec<usize>,
69    /// Per-task agent config override (falls back to default).
70    pub agent_config: Option<TaskAgentConfig>,
71}
72
73/// Result of a complete orchestration run.
74#[derive(Debug)]
75pub struct OrchestrationResult {
76    /// `true` when every task succeeded.
77    pub all_succeeded: bool,
78    /// Per-task agent results keyed by task ID.
79    pub task_results: HashMap<String, TaskAgentResult>,
80    /// Task IDs that were never started (e.g. blocked by a failure).
81    pub unstarted_tasks: Vec<String>,
82    /// Final task statistics snapshot.
83    pub stats: TaskStats,
84}
85
86// ── TaskOrchestrator ────────────────────────────────────────────────────────
87
88/// Bridges [`TaskManager`] and [`AgentPool`] with a dependency-aware
89/// scheduling loop.
90pub struct TaskOrchestrator {
91    task_manager: Arc<TaskManager>,
92    agent_pool: Arc<AgentPool>,
93    communication_hub: Arc<CommunicationHub>,
94    config: TaskOrchestratorConfig,
95    /// Per-task agent config overrides.
96    per_task_configs: Arc<RwLock<HashMap<String, TaskAgentConfig>>>,
97    /// Maps agent_id -> task_id for running agents.
98    agent_to_task: Arc<RwLock<HashMap<String, String>>>,
99    /// Abort flag — set by `abort()`.
100    aborted: Arc<tokio::sync::Notify>,
101    abort_flag: Arc<std::sync::atomic::AtomicBool>,
102}
103
104impl TaskOrchestrator {
105    /// Create a new orchestrator.
106    pub fn new(
107        task_manager: Arc<TaskManager>,
108        agent_pool: Arc<AgentPool>,
109        communication_hub: Arc<CommunicationHub>,
110        config: TaskOrchestratorConfig,
111    ) -> Self {
112        Self {
113            task_manager,
114            agent_pool,
115            communication_hub,
116            config,
117            per_task_configs: Arc::new(RwLock::new(HashMap::new())),
118            agent_to_task: Arc::new(RwLock::new(HashMap::new())),
119            aborted: Arc::new(tokio::sync::Notify::new()),
120            abort_flag: Arc::new(std::sync::atomic::AtomicBool::new(false)),
121        }
122    }
123
124    /// Set a per-task agent config override.
125    pub async fn set_task_config(&self, task_id: impl Into<String>, config: TaskAgentConfig) {
126        self.per_task_configs
127            .write()
128            .await
129            .insert(task_id.into(), config);
130    }
131
132    /// Bulk-set per-task agent config overrides.
133    pub async fn set_task_configs(&self, configs: HashMap<String, TaskAgentConfig>) {
134        let mut map = self.per_task_configs.write().await;
135        map.extend(configs);
136    }
137
138    /// Convenience API: create tasks with dependencies in the TaskManager,
139    /// then run the scheduling loop.
140    ///
141    /// `parent_task_id` is an optional root task under which all specs are
142    /// created as subtasks.
143    pub async fn create_and_run(
144        &self,
145        parent_task_id: Option<&str>,
146        specs: Vec<TaskSpec>,
147    ) -> Result<OrchestrationResult> {
148        // Create tasks and collect their IDs (index-ordered).
149        let mut task_ids: Vec<String> = Vec::with_capacity(specs.len());
150        for spec in &specs {
151            let id = self
152                .task_manager
153                .create_task(
154                    spec.description.clone(),
155                    parent_task_id.map(|s| s.to_string()),
156                    spec.priority,
157                )
158                .await?;
159            task_ids.push(id);
160        }
161
162        // Wire up dependencies by index.
163        for (i, spec) in specs.iter().enumerate() {
164            for &dep_idx in &spec.depends_on_indices {
165                if dep_idx >= task_ids.len() {
166                    return Err(anyhow!(
167                        "TaskSpec[{}] depends_on_indices contains out-of-range index {}",
168                        i,
169                        dep_idx
170                    ));
171                }
172                self.task_manager
173                    .add_dependency(&task_ids[i], &task_ids[dep_idx])
174                    .await?;
175            }
176
177            // Store per-task config overrides.
178            if let Some(ref cfg) = spec.agent_config {
179                self.set_task_config(&task_ids[i], cfg.clone()).await;
180            }
181        }
182
183        // Determine root: either explicit parent or the first created task.
184        let root = parent_task_id
185            .map(|s| s.to_string())
186            .or_else(|| task_ids.first().cloned());
187
188        match root {
189            Some(id) => self.run(&id).await,
190            None => {
191                // Empty spec list — nothing to do.
192                Ok(OrchestrationResult {
193                    all_succeeded: true,
194                    task_results: HashMap::new(),
195                    unstarted_tasks: Vec::new(),
196                    stats: self.task_manager.get_stats().await,
197                })
198            }
199        }
200    }
201
202    /// Main scheduling loop over existing tasks in the TaskManager.
203    ///
204    /// Runs until all tasks reachable from `root_task_id` are completed/failed
205    /// or the failure policy halts scheduling.
206    pub async fn run(&self, root_task_id: &str) -> Result<OrchestrationResult> {
207        let mut task_results: HashMap<String, TaskAgentResult> = HashMap::new();
208        let mut halted = false;
209        let poll = tokio::time::Duration::from_millis(self.config.poll_interval_ms);
210
211        loop {
212            // Check abort flag.
213            if self.abort_flag.load(std::sync::atomic::Ordering::Relaxed) {
214                halted = true;
215            }
216
217            // ── 1. Harvest completed agents ──────────────────────────────
218            let completed = self.agent_pool.cleanup_completed().await;
219            for (agent_id, result) in completed {
220                let task_id = { self.agent_to_task.write().await.remove(&agent_id) };
221
222                if let Some(task_id) = task_id {
223                    match result {
224                        Ok(agent_result) => {
225                            if agent_result.success {
226                                let summary = agent_result.summary.clone();
227                                self.task_manager
228                                    .complete_task(&task_id, summary.clone())
229                                    .await?;
230
231                                if let Err(e) = self
232                                    .communication_hub
233                                    .broadcast(
234                                        self.config.orchestrator_id.clone(),
235                                        AgentMessage::AgentCompleted {
236                                            agent_id: agent_id.clone(),
237                                            task_id: task_id.clone(),
238                                            summary,
239                                        },
240                                    )
241                                    .await
242                                {
243                                    tracing::warn!(agent_id = %agent_id, task_id = %task_id, "Failed to broadcast agent completion: {}", e);
244                                }
245                            } else {
246                                let error = agent_result.summary.clone();
247                                self.task_manager.fail_task(&task_id, error.clone()).await?;
248
249                                if let Err(e) = self
250                                    .communication_hub
251                                    .broadcast(
252                                        self.config.orchestrator_id.clone(),
253                                        AgentMessage::AgentCompleted {
254                                            agent_id: agent_id.clone(),
255                                            task_id: task_id.clone(),
256                                            summary: format!("FAILED: {}", error),
257                                        },
258                                    )
259                                    .await
260                                {
261                                    tracing::warn!(agent_id = %agent_id, task_id = %task_id, "Failed to broadcast agent failure: {}", e);
262                                }
263
264                                if self.config.failure_policy == FailurePolicy::StopOnFirstFailure {
265                                    halted = true;
266                                }
267                            }
268                            task_results.insert(task_id, agent_result);
269                        }
270                        Err(e) => {
271                            let error = format!("Agent panicked: {}", e);
272                            self.task_manager.fail_task(&task_id, error.clone()).await?;
273
274                            if let Err(e) = self
275                                .communication_hub
276                                .broadcast(
277                                    self.config.orchestrator_id.clone(),
278                                    AgentMessage::AgentCompleted {
279                                        agent_id: agent_id.clone(),
280                                        task_id: task_id.clone(),
281                                        summary: error,
282                                    },
283                                )
284                                .await
285                            {
286                                tracing::warn!(agent_id = %agent_id, task_id = %task_id, "Failed to broadcast agent panic: {}", e);
287                            }
288
289                            if self.config.failure_policy == FailurePolicy::StopOnFirstFailure {
290                                halted = true;
291                            }
292                        }
293                    }
294                }
295            }
296
297            // ── 2. Schedule new tasks (unless halted) ────────────────────
298            if !halted {
299                let ready = self.task_manager.get_ready_tasks().await;
300
301                // Filter out tasks already assigned to an agent.
302                let assigned: std::collections::HashSet<String> = {
303                    let map = self.agent_to_task.read().await;
304                    map.values().cloned().collect()
305                };
306
307                let available_slots = {
308                    let stats = self.agent_pool.stats().await;
309                    stats.max_agents.saturating_sub(stats.running)
310                };
311
312                let mut spawned = 0usize;
313                for task in &ready {
314                    if spawned >= available_slots {
315                        break;
316                    }
317                    if assigned.contains(&task.id) {
318                        continue;
319                    }
320                    // Skip tasks already InProgress (shouldn't happen, but be safe).
321                    if task.status == TaskStatus::InProgress {
322                        continue;
323                    }
324                    // Skip parent/container tasks — they auto-complete via
325                    // check_parent_completion when all children finish.
326                    if !task.children.is_empty() {
327                        continue;
328                    }
329
330                    // Resolve config for this task.
331                    let agent_config = {
332                        let overrides = self.per_task_configs.read().await;
333                        overrides
334                            .get(&task.id)
335                            .cloned()
336                            .unwrap_or_else(|| self.config.default_agent_config.clone())
337                    };
338
339                    // Build a core Task for the agent pool.
340                    let agent_task = Task::new(task.id.clone(), task.description.clone());
341
342                    // Start + assign in TaskManager.
343                    self.task_manager.start_task(&task.id).await?;
344
345                    match self
346                        .agent_pool
347                        .spawn_agent(agent_task, Some(agent_config))
348                        .await
349                    {
350                        Ok(agent_id) => {
351                            self.task_manager.assign_task(&task.id, &agent_id).await?;
352                            self.agent_to_task
353                                .write()
354                                .await
355                                .insert(agent_id.clone(), task.id.clone());
356
357                            if let Err(e) = self
358                                .communication_hub
359                                .broadcast(
360                                    self.config.orchestrator_id.clone(),
361                                    AgentMessage::AgentSpawned {
362                                        agent_id,
363                                        task_id: task.id.clone(),
364                                    },
365                                )
366                                .await
367                            {
368                                tracing::warn!(task_id = %task.id, "Failed to broadcast agent spawn: {}", e);
369                            }
370
371                            spawned += 1;
372                        }
373                        Err(e) => {
374                            tracing::warn!(task_id = %task.id, error = %e, "failed to spawn agent");
375                            // Revert status back to Pending so it can be retried.
376                            self.task_manager
377                                .update_status(&task.id, TaskStatus::Pending, None)
378                                .await?;
379                        }
380                    }
381                }
382            }
383
384            // ── 3. Check termination ─────────────────────────────────────
385            let running = {
386                let map = self.agent_to_task.read().await;
387                map.len()
388            };
389
390            if running == 0 {
391                // No running agents.  If halted or no more schedulable tasks, we're done.
392                let ready = self.task_manager.get_ready_tasks().await;
393                let assigned: std::collections::HashSet<String> = {
394                    let map = self.agent_to_task.read().await;
395                    map.values().cloned().collect()
396                };
397                // Only leaf tasks (no children) are schedulable.
398                let schedulable: Vec<_> = ready
399                    .iter()
400                    .filter(|t| !assigned.contains(&t.id) && t.children.is_empty())
401                    .collect();
402
403                if halted || schedulable.is_empty() {
404                    break;
405                }
406            }
407
408            tokio::time::sleep(poll).await;
409        }
410
411        // ── Build result ────────────────────────────────────────────────
412        let stats = self.task_manager.get_stats().await;
413        let all_tasks = self.task_manager.get_task_tree(Some(root_task_id)).await;
414        let unstarted: Vec<String> = all_tasks
415            .iter()
416            .filter(|t| t.status == TaskStatus::Pending || t.status == TaskStatus::Blocked)
417            .map(|t| t.id.clone())
418            .collect();
419
420        let all_succeeded = stats.failed == 0 && unstarted.is_empty();
421
422        Ok(OrchestrationResult {
423            all_succeeded,
424            task_results,
425            unstarted_tasks: unstarted,
426            stats,
427        })
428    }
429
430    /// Cancel all running agents and return.
431    pub async fn abort(&self) {
432        self.abort_flag
433            .store(true, std::sync::atomic::Ordering::Relaxed);
434        self.aborted.notify_one();
435        self.agent_pool.shutdown().await;
436        self.agent_to_task.write().await.clear();
437    }
438
439    /// Live task statistics snapshot.
440    pub async fn progress(&self) -> TaskStats {
441        self.task_manager.get_stats().await
442    }
443
444    /// Map of currently running agent_id -> task_id.
445    pub async fn running_agents(&self) -> HashMap<String, String> {
446        self.agent_to_task.read().await.clone()
447    }
448}
449
450// ── Tests ───────────────────────────────────────────────────────────────────
451
#[cfg(test)]
mod tests {
    use super::*;
    use crate::communication::CommunicationHub;
    use crate::file_locks::FileLockManager;
    use crate::pool::AgentPool;
    use crate::task_agent::TaskAgentConfig;

    use async_trait::async_trait;
    use brainwires_core::{
        ChatOptions, ChatResponse, Message, Provider, StreamChunk, Tool, ToolContext, ToolResult,
        ToolUse, Usage,
    };
    use brainwires_tool_system::ToolExecutor;
    use futures::stream::BoxStream;

    // ── Mocks ────────────────────────────────────────────────────────────

    /// Provider that answers every chat request with one canned response.
    struct MockProvider(ChatResponse);

    impl MockProvider {
        /// Canned assistant reply `text` with finish reason `"stop"`.
        fn done(text: &str) -> Self {
            let response = ChatResponse {
                message: Message::assistant(text),
                finish_reason: Some("stop".to_string()),
                usage: Usage::default(),
            };
            Self(response)
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            "mock"
        }

        async fn chat(
            &self,
            _: &[Message],
            _: Option<&[Tool]>,
            _: &ChatOptions,
        ) -> Result<ChatResponse> {
            Ok(self.0.clone())
        }

        fn stream_chat<'a>(
            &'a self,
            _: &'a [Message],
            _: Option<&'a [Tool]>,
            _: &'a ChatOptions,
        ) -> BoxStream<'a, Result<StreamChunk>> {
            Box::pin(futures::stream::empty())
        }
    }

    /// Tool executor that exposes no tools and answers "ok" to any call.
    struct NoOpExecutor;

    #[async_trait]
    impl ToolExecutor for NoOpExecutor {
        async fn execute(&self, tu: &ToolUse, _: &ToolContext) -> Result<ToolResult> {
            Ok(ToolResult::success(tu.id.clone(), "ok".to_string()))
        }

        fn available_tools(&self) -> Vec<Tool> {
            vec![]
        }
    }

    // ── Helpers ──────────────────────────────────────────────────────────

    /// Task manager, agent pool (capacity `max_pool`) and hub wired to the
    /// mock provider and executor.
    fn make_deps(max_pool: usize) -> (Arc<TaskManager>, Arc<AgentPool>, Arc<CommunicationHub>) {
        let hub = Arc::new(CommunicationHub::new());
        let flm = Arc::new(FileLockManager::new());
        let provider: Arc<dyn Provider> = Arc::new(MockProvider::done("Done"));
        let executor: Arc<dyn ToolExecutor> = Arc::new(NoOpExecutor);

        let tm = Arc::new(TaskManager::new());
        let pool = Arc::new(AgentPool::new(
            max_pool,
            provider,
            executor,
            Arc::clone(&hub),
            flm,
            "/tmp",
        ));

        (tm, pool, hub)
    }

    /// Default agent config with validation switched off.
    fn no_validation() -> TaskAgentConfig {
        TaskAgentConfig {
            validation_config: None,
            ..Default::default()
        }
    }

    /// Orchestrator with the given failure policy and validation disabled.
    fn make_orchestrator_with_policy(
        tm: Arc<TaskManager>,
        pool: Arc<AgentPool>,
        hub: Arc<CommunicationHub>,
        failure_policy: FailurePolicy,
    ) -> TaskOrchestrator {
        let config = TaskOrchestratorConfig {
            failure_policy,
            default_agent_config: no_validation(),
            ..Default::default()
        };
        TaskOrchestrator::new(tm, pool, hub, config)
    }

    /// Orchestrator with the default failure policy and validation disabled.
    fn make_orchestrator(
        tm: Arc<TaskManager>,
        pool: Arc<AgentPool>,
        hub: Arc<CommunicationHub>,
    ) -> TaskOrchestrator {
        make_orchestrator_with_policy(tm, pool, hub, FailurePolicy::default())
    }

    /// Creates a normal-priority task, optionally below `parent`, and returns
    /// its ID.
    async fn add_task(tm: &TaskManager, description: &str, parent: Option<&str>) -> String {
        tm.create_task(
            description.to_string(),
            parent.map(str::to_string),
            TaskPriority::Normal,
        )
        .await
        .unwrap()
    }

    /// Normal-priority spec without a per-task agent config.
    fn spec(description: &str, depends_on_indices: Vec<usize>) -> TaskSpec {
        TaskSpec {
            description: description.to_string(),
            priority: TaskPriority::Normal,
            depends_on_indices,
            agent_config: None,
        }
    }

    // ── Tests ────────────────────────────────────────────────────────────

    #[tokio::test]
    async fn test_empty_orchestration() {
        let (tm, pool, hub) = make_deps(5);
        let orch = make_orchestrator(tm.clone(), pool, hub);

        // A childless root that is already completed leaves the loop with
        // nothing to schedule.
        let root = add_task(&tm, "root", None).await;
        tm.complete_task(&root, "already done".to_string())
            .await
            .unwrap();

        let result = orch.run(&root).await.unwrap();
        assert!(result.all_succeeded);
        assert!(result.task_results.is_empty());
        assert!(result.unstarted_tasks.is_empty());
    }

    #[tokio::test]
    async fn test_single_task() {
        let (tm, pool, hub) = make_deps(5);
        let orch = make_orchestrator(tm.clone(), pool, hub);

        let root = add_task(&tm, "build widget", None).await;

        let result = orch.run(&root).await.unwrap();
        assert!(result.all_succeeded);
        assert_eq!(result.task_results.len(), 1);
        assert!(result.task_results.contains_key(&root));
    }

    #[tokio::test]
    async fn test_sequential_dependency_chain() {
        let (tm, pool, hub) = make_deps(5);
        let orch = make_orchestrator(tm.clone(), pool, hub);

        // Chain A -> B -> C (C depends on B, B depends on A), grouped under a
        // parent so get_task_tree works.
        let parent = add_task(&tm, "parent", None).await;
        let a = add_task(&tm, "step A", Some(parent.as_str())).await;
        let b = add_task(&tm, "step B", Some(parent.as_str())).await;
        let c = add_task(&tm, "step C", Some(parent.as_str())).await;

        tm.add_dependency(&b, &a).await.unwrap();
        tm.add_dependency(&c, &b).await.unwrap();

        let result = orch.run(&parent).await.unwrap();
        assert!(result.all_succeeded);
        assert_eq!(result.task_results.len(), 3);
    }

    #[tokio::test]
    async fn test_parallel_independent_tasks() {
        let (tm, pool, hub) = make_deps(5);
        let orch = make_orchestrator(tm.clone(), pool, hub);

        let parent = add_task(&tm, "parent", None).await;
        let _a = add_task(&tm, "task A", Some(parent.as_str())).await;
        let _b = add_task(&tm, "task B", Some(parent.as_str())).await;
        let _c = add_task(&tm, "task C", Some(parent.as_str())).await;

        let result = orch.run(&parent).await.unwrap();
        assert!(result.all_succeeded);
        assert_eq!(result.task_results.len(), 3);
    }

    #[tokio::test]
    async fn test_diamond_dependency() {
        // A -> (B, C) -> D
        let (tm, pool, hub) = make_deps(5);
        let orch = make_orchestrator(tm.clone(), pool, hub);

        let parent = add_task(&tm, "parent", None).await;
        let a = add_task(&tm, "A", Some(parent.as_str())).await;
        let b = add_task(&tm, "B", Some(parent.as_str())).await;
        let c = add_task(&tm, "C", Some(parent.as_str())).await;
        let d = add_task(&tm, "D", Some(parent.as_str())).await;

        tm.add_dependency(&b, &a).await.unwrap();
        tm.add_dependency(&c, &a).await.unwrap();
        tm.add_dependency(&d, &b).await.unwrap();
        tm.add_dependency(&d, &c).await.unwrap();

        let result = orch.run(&parent).await.unwrap();
        assert!(result.all_succeeded);
        assert_eq!(result.task_results.len(), 4);
    }

    #[tokio::test]
    async fn test_stop_on_first_failure() {
        // The mock provider always answers "Done", so the failure is injected
        // by failing task A directly in the task manager before the run.
        let (tm, pool, hub) = make_deps(5);
        let orch = make_orchestrator_with_policy(
            Arc::clone(&tm),
            Arc::clone(&pool),
            hub,
            FailurePolicy::StopOnFirstFailure,
        );

        // Two independent children of the same parent.
        let parent = add_task(&tm, "parent", None).await;
        let a = add_task(&tm, "A", Some(parent.as_str())).await;
        let _b = add_task(&tm, "B", Some(parent.as_str())).await;

        tm.fail_task(&a, "forced failure".to_string())
            .await
            .unwrap();

        let result = orch.run(&parent).await.unwrap();
        // A is failed, B may or may not have run depending on timing.
        assert!(!result.all_succeeded);
    }

    #[tokio::test]
    async fn test_continue_on_failure() {
        let (tm, pool, hub) = make_deps(5);
        let orch = make_orchestrator_with_policy(
            Arc::clone(&tm),
            Arc::clone(&pool),
            hub,
            FailurePolicy::ContinueOnFailure,
        );

        let parent = add_task(&tm, "parent", None).await;
        let a = add_task(&tm, "A", Some(parent.as_str())).await;
        let b_id = add_task(&tm, "B", Some(parent.as_str())).await;

        // Pre-fail A; B is independent and should still run.
        tm.fail_task(&a, "forced failure".to_string())
            .await
            .unwrap();

        let result = orch.run(&parent).await.unwrap();
        // A failed, so the run is not fully successful, but B has a result.
        assert!(!result.all_succeeded);
        assert!(result.task_results.contains_key(&b_id));
    }

    #[tokio::test]
    async fn test_pool_capacity_respect() {
        // Pool size 1, 3 independent tasks — only one at a time.
        let (tm, pool, hub) = make_deps(1);
        let orch = make_orchestrator(tm.clone(), pool, hub);

        let parent = add_task(&tm, "parent", None).await;
        let _a = add_task(&tm, "A", Some(parent.as_str())).await;
        let _b = add_task(&tm, "B", Some(parent.as_str())).await;
        let _c = add_task(&tm, "C", Some(parent.as_str())).await;

        let result = orch.run(&parent).await.unwrap();
        assert!(result.all_succeeded);
        assert_eq!(result.task_results.len(), 3);
    }

    #[tokio::test]
    async fn test_assigned_to_tracking() {
        let (tm, pool, hub) = make_deps(5);
        let orch = make_orchestrator(tm.clone(), pool, hub);

        let root = add_task(&tm, "build widget", None).await;

        let result = orch.run(&root).await.unwrap();
        assert!(result.all_succeeded);

        // After completion, assigned_to should have been set.
        let task = tm.get_task(&root).await.unwrap();
        assert!(task.assigned_to.is_some());
    }

    #[tokio::test]
    async fn test_create_and_run() {
        let (tm, pool, hub) = make_deps(5);
        let orch = make_orchestrator(tm.clone(), pool, hub);

        // Diamond expressed through spec indices: A -> (B, C) -> D.
        let specs = vec![
            spec("step A", vec![]),
            spec("step B", vec![0]),
            spec("step C", vec![0]),
            spec("step D", vec![1, 2]),
        ];

        let result = orch.create_and_run(None, specs).await.unwrap();
        assert!(result.all_succeeded);
        assert_eq!(result.task_results.len(), 4);
    }
}