Skip to main content

roboticus_agent/
orchestration.rs

1use chrono::{DateTime, Utc};
2use roboticus_core::{Result, RoboticusError};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use tracing::{debug, info, warn};
6
7/// Orchestration pattern for multi-agent coordination.
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
9pub enum OrchestrationPattern {
10    Sequential,
11    Parallel,
12    FanOutFanIn,
13    Handoff,
14}
15
16impl std::fmt::Display for OrchestrationPattern {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        match self {
19            OrchestrationPattern::Sequential => write!(f, "sequential"),
20            OrchestrationPattern::Parallel => write!(f, "parallel"),
21            OrchestrationPattern::FanOutFanIn => write!(f, "fan-out/fan-in"),
22            OrchestrationPattern::Handoff => write!(f, "handoff"),
23        }
24    }
25}
26
27/// A subtask assigned to a specialist agent.
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct Subtask {
30    pub id: String,
31    pub description: String,
32    pub required_capabilities: Vec<String>,
33    #[serde(default)]
34    pub model_preference: Option<String>,
35    pub assigned_agent: Option<String>,
36    pub status: SubtaskStatus,
37    pub result: Option<String>,
38    pub created_at: DateTime<Utc>,
39    pub completed_at: Option<DateTime<Utc>>,
40}
41
42/// Status of a subtask.
43#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
44pub enum SubtaskStatus {
45    Pending,
46    Assigned,
47    Running,
48    Completed,
49    Failed,
50}
51
52/// A workflow composed of subtasks with an orchestration pattern.
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct Workflow {
55    pub id: String,
56    pub name: String,
57    pub pattern: OrchestrationPattern,
58    pub subtasks: Vec<Subtask>,
59    pub status: WorkflowStatus,
60    pub created_at: DateTime<Utc>,
61    pub completed_at: Option<DateTime<Utc>>,
62}
63
64/// Status of a workflow.
65#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
66pub enum WorkflowStatus {
67    Created,
68    Running,
69    Completed,
70    Failed,
71    Cancelled,
72}
73
74/// Coordinates workflows and agent assignment.
75pub struct Orchestrator {
76    workflows: HashMap<String, Workflow>,
77    workflow_counter: u64,
78}
79
80impl Orchestrator {
81    pub fn new() -> Self {
82        Self {
83            workflows: HashMap::new(),
84            workflow_counter: 0,
85        }
86    }
87
88    /// Create a new workflow from subtask descriptions.
89    pub fn create_workflow(
90        &mut self,
91        name: &str,
92        pattern: OrchestrationPattern,
93        subtasks: Vec<(String, Vec<String>)>,
94    ) -> String {
95        self.workflow_counter += 1;
96        let workflow_id = format!("wf_{}", self.workflow_counter);
97
98        let tasks: Vec<Subtask> = subtasks
99            .into_iter()
100            .enumerate()
101            .map(|(i, (desc, caps))| Subtask {
102                id: format!("{}_task_{}", workflow_id, i),
103                description: desc,
104                required_capabilities: caps,
105                model_preference: None,
106                assigned_agent: None,
107                status: SubtaskStatus::Pending,
108                result: None,
109                created_at: Utc::now(),
110                completed_at: None,
111            })
112            .collect();
113
114        let workflow = Workflow {
115            id: workflow_id.clone(),
116            name: name.to_string(),
117            pattern,
118            subtasks: tasks,
119            status: WorkflowStatus::Created,
120            created_at: Utc::now(),
121            completed_at: None,
122        };
123
124        info!(id = %workflow_id, name, pattern = %pattern, tasks = workflow.subtasks.len(), "created workflow");
125        self.workflows.insert(workflow_id.clone(), workflow);
126        workflow_id
127    }
128
129    /// Create a workflow where each task can optionally request a model.
130    pub fn create_workflow_with_model_preferences(
131        &mut self,
132        name: &str,
133        pattern: OrchestrationPattern,
134        subtasks: Vec<(String, Vec<String>, Option<String>)>,
135    ) -> String {
136        self.workflow_counter += 1;
137        let workflow_id = format!("wf_{}", self.workflow_counter);
138
139        let tasks: Vec<Subtask> = subtasks
140            .into_iter()
141            .enumerate()
142            .map(|(i, (desc, caps, model_pref))| Subtask {
143                id: format!("{}_task_{}", workflow_id, i),
144                description: desc,
145                required_capabilities: caps,
146                model_preference: model_pref,
147                assigned_agent: None,
148                status: SubtaskStatus::Pending,
149                result: None,
150                created_at: Utc::now(),
151                completed_at: None,
152            })
153            .collect();
154
155        let workflow = Workflow {
156            id: workflow_id.clone(),
157            name: name.to_string(),
158            pattern,
159            subtasks: tasks,
160            status: WorkflowStatus::Created,
161            created_at: Utc::now(),
162            completed_at: None,
163        };
164
165        info!(
166            id = %workflow_id,
167            name,
168            pattern = %pattern,
169            tasks = workflow.subtasks.len(),
170            "created workflow with model preferences"
171        );
172        self.workflows.insert(workflow_id.clone(), workflow);
173        workflow_id
174    }
175
176    /// Assign an agent to a subtask.
177    pub fn assign_agent(&mut self, workflow_id: &str, task_id: &str, agent_id: &str) -> Result<()> {
178        let workflow = self.workflows.get_mut(workflow_id).ok_or_else(|| {
179            RoboticusError::Config(format!("workflow '{}' not found", workflow_id))
180        })?;
181
182        let task = workflow
183            .subtasks
184            .iter_mut()
185            .find(|t| t.id == task_id)
186            .ok_or_else(|| RoboticusError::Config(format!("task '{}' not found", task_id)))?;
187
188        task.assigned_agent = Some(agent_id.to_string());
189        task.status = SubtaskStatus::Assigned;
190        debug!(
191            workflow = workflow_id,
192            task = task_id,
193            agent = agent_id,
194            "agent assigned"
195        );
196        Ok(())
197    }
198
199    /// Set or clear a model preference for a specific task.
200    pub fn set_task_model_preference(
201        &mut self,
202        workflow_id: &str,
203        task_id: &str,
204        model_preference: Option<String>,
205    ) -> Result<()> {
206        let workflow = self.workflows.get_mut(workflow_id).ok_or_else(|| {
207            RoboticusError::Config(format!("workflow '{}' not found", workflow_id))
208        })?;
209        let task = workflow
210            .subtasks
211            .iter_mut()
212            .find(|t| t.id == task_id)
213            .ok_or_else(|| RoboticusError::Config(format!("task '{}' not found", task_id)))?;
214        task.model_preference = model_preference;
215        Ok(())
216    }
217
218    /// Match subtasks to available agents by capability overlap.
219    pub fn match_capabilities(
220        &self,
221        workflow_id: &str,
222        available_agents: &[(String, Vec<String>)],
223    ) -> Result<Vec<(String, String)>> {
224        let workflow = self.workflows.get(workflow_id).ok_or_else(|| {
225            RoboticusError::Config(format!("workflow '{}' not found", workflow_id))
226        })?;
227
228        let mut assignments = Vec::new();
229
230        for task in &workflow.subtasks {
231            if task.status != SubtaskStatus::Pending {
232                continue;
233            }
234
235            let best_agent = available_agents.iter().max_by_key(|(_, caps)| {
236                task.required_capabilities
237                    .iter()
238                    .filter(|rc| caps.contains(rc))
239                    .count()
240            });
241
242            if let Some((agent_id, caps)) = best_agent {
243                let overlap = task
244                    .required_capabilities
245                    .iter()
246                    .filter(|rc| caps.contains(rc))
247                    .count();
248                if overlap > 0 {
249                    assignments.push((task.id.clone(), agent_id.clone()));
250                }
251            }
252        }
253
254        Ok(assignments)
255    }
256
257    /// Start a subtask.
258    pub fn start_task(&mut self, workflow_id: &str, task_id: &str) -> Result<()> {
259        let workflow = self.workflows.get_mut(workflow_id).ok_or_else(|| {
260            RoboticusError::Config(format!("workflow '{}' not found", workflow_id))
261        })?;
262
263        let task = workflow
264            .subtasks
265            .iter_mut()
266            .find(|t| t.id == task_id)
267            .ok_or_else(|| RoboticusError::Config(format!("task '{}' not found", task_id)))?;
268
269        task.status = SubtaskStatus::Running;
270        workflow.status = WorkflowStatus::Running;
271        Ok(())
272    }
273
274    /// Complete a subtask with a result.
275    pub fn complete_task(
276        &mut self,
277        workflow_id: &str,
278        task_id: &str,
279        result: String,
280    ) -> Result<()> {
281        let workflow = self.workflows.get_mut(workflow_id).ok_or_else(|| {
282            RoboticusError::Config(format!("workflow '{}' not found", workflow_id))
283        })?;
284
285        let task = workflow
286            .subtasks
287            .iter_mut()
288            .find(|t| t.id == task_id)
289            .ok_or_else(|| RoboticusError::Config(format!("task '{}' not found", task_id)))?;
290
291        task.status = SubtaskStatus::Completed;
292        task.result = Some(result);
293        task.completed_at = Some(Utc::now());
294
295        if workflow
296            .subtasks
297            .iter()
298            .all(|t| t.status == SubtaskStatus::Completed)
299        {
300            workflow.status = WorkflowStatus::Completed;
301            workflow.completed_at = Some(Utc::now());
302            info!(id = %workflow_id, "workflow completed");
303        }
304
305        Ok(())
306    }
307
308    /// Fail a subtask.
309    pub fn fail_task(&mut self, workflow_id: &str, task_id: &str, error: &str) -> Result<()> {
310        let workflow = self.workflows.get_mut(workflow_id).ok_or_else(|| {
311            RoboticusError::Config(format!("workflow '{}' not found", workflow_id))
312        })?;
313
314        let task = workflow
315            .subtasks
316            .iter_mut()
317            .find(|t| t.id == task_id)
318            .ok_or_else(|| RoboticusError::Config(format!("task '{}' not found", task_id)))?;
319
320        task.status = SubtaskStatus::Failed;
321        task.result = Some(format!("ERROR: {}", error));
322        task.completed_at = Some(Utc::now());
323
324        workflow.status = WorkflowStatus::Failed;
325        warn!(workflow = workflow_id, task = task_id, error, "task failed");
326        Ok(())
327    }
328
329    /// Get a workflow by ID.
330    pub fn get_workflow(&self, id: &str) -> Option<&Workflow> {
331        self.workflows.get(id)
332    }
333
334    /// Get the next actionable tasks for a workflow based on its pattern.
335    pub fn next_tasks(&self, workflow_id: &str) -> Result<Vec<&Subtask>> {
336        let workflow = self.workflows.get(workflow_id).ok_or_else(|| {
337            RoboticusError::Config(format!("workflow '{}' not found", workflow_id))
338        })?;
339
340        match workflow.pattern {
341            OrchestrationPattern::Sequential => Ok(workflow
342                .subtasks
343                .iter()
344                .find(|t| t.status == SubtaskStatus::Pending || t.status == SubtaskStatus::Assigned)
345                .into_iter()
346                .collect()),
347            OrchestrationPattern::Parallel | OrchestrationPattern::FanOutFanIn => Ok(workflow
348                .subtasks
349                .iter()
350                .filter(|t| {
351                    t.status == SubtaskStatus::Pending || t.status == SubtaskStatus::Assigned
352                })
353                .collect()),
354            OrchestrationPattern::Handoff => {
355                let last_completed = workflow
356                    .subtasks
357                    .iter()
358                    .rposition(|t| t.status == SubtaskStatus::Completed);
359                let start_idx = last_completed.map(|i| i + 1).unwrap_or(0);
360                // Skip past any Failed tasks to find the next actionable one
361                Ok(workflow.subtasks[start_idx..]
362                    .iter()
363                    .find(|t| {
364                        t.status == SubtaskStatus::Pending || t.status == SubtaskStatus::Assigned
365                    })
366                    .into_iter()
367                    .collect())
368            }
369        }
370    }
371
372    pub fn workflow_count(&self) -> usize {
373        self.workflows.len()
374    }
375}
376
377impl Default for Orchestrator {
378    fn default() -> Self {
379        Self::new()
380    }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Three single-capability tasks used by most tests.
    fn simple_tasks() -> Vec<(String, Vec<String>)> {
        vec![
            ("Research the topic".into(), vec!["research".into()]),
            ("Write the summary".into(), vec!["summarization".into()]),
            ("Review the output".into(), vec!["review".into()]),
        ]
    }

    /// Ids of every subtask in `wf_id`, in declaration order.
    fn task_ids(orch: &Orchestrator, wf_id: &str) -> Vec<String> {
        orch.get_workflow(wf_id)
            .unwrap()
            .subtasks
            .iter()
            .map(|t| t.id.clone())
            .collect()
    }

    /// Two agents: one for `research`, one for `summarization`; nobody covers `review`.
    fn sample_agents() -> Vec<(String, Vec<String>)> {
        vec![
            (
                "researcher".into(),
                vec!["research".into(), "analysis".into()],
            ),
            (
                "writer".into(),
                vec!["summarization".into(), "writing".into()],
            ),
        ]
    }

    #[test]
    fn create_workflow() {
        let mut orch = Orchestrator::new();
        let id = orch.create_workflow(
            "Test Flow",
            OrchestrationPattern::Sequential,
            simple_tasks(),
        );
        assert!(id.starts_with("wf_"));
        let wf = orch.get_workflow(&id).unwrap();
        assert_eq!(wf.subtasks.len(), 3);
        assert_eq!(wf.status, WorkflowStatus::Created);
        assert!(wf.subtasks.iter().all(|t| t.model_preference.is_none()));
        assert_eq!(wf.subtasks[0].id, format!("{}_task_0", id));
    }

    #[test]
    fn create_workflow_with_model_preferences() {
        let mut orch = Orchestrator::new();
        let id = orch.create_workflow_with_model_preferences(
            "Model Aware Flow",
            OrchestrationPattern::Parallel,
            vec![
                (
                    "Draft summary".into(),
                    vec!["summarization".into()],
                    Some("ollama/qwen3:8b".into()),
                ),
                ("Review output".into(), vec!["review".into()], None),
            ],
        );
        let wf = orch.get_workflow(&id).unwrap();
        assert_eq!(
            wf.subtasks[0].model_preference.as_deref(),
            Some("ollama/qwen3:8b")
        );
        assert!(wf.subtasks[1].model_preference.is_none());
    }

    #[test]
    fn workflow_ids_are_unique() {
        let mut orch = Orchestrator::new();
        let a = orch.create_workflow("A", OrchestrationPattern::Sequential, simple_tasks());
        let b = orch.create_workflow_with_model_preferences(
            "B",
            OrchestrationPattern::Parallel,
            vec![],
        );
        assert_eq!(a, "wf_1");
        assert_ne!(a, b);
        assert_eq!(orch.workflow_count(), 2);
        assert!(orch.get_workflow(&b).unwrap().subtasks.is_empty());
    }

    #[test]
    fn assign_and_start() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow("Test", OrchestrationPattern::Sequential, simple_tasks());
        let task_id = task_ids(&orch, &wf_id)[0].clone();

        orch.assign_agent(&wf_id, &task_id, "agent-research").unwrap();
        let task = &orch.get_workflow(&wf_id).unwrap().subtasks[0];
        assert_eq!(task.status, SubtaskStatus::Assigned);
        assert_eq!(task.assigned_agent.as_deref(), Some("agent-research"));

        orch.start_task(&wf_id, &task_id).unwrap();
        let wf = orch.get_workflow(&wf_id).unwrap();
        assert_eq!(wf.subtasks[0].status, SubtaskStatus::Running);
        assert_eq!(wf.status, WorkflowStatus::Running);
    }

    #[test]
    fn set_task_model_preference_updates_task() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow(
            "Model Edit",
            OrchestrationPattern::Sequential,
            simple_tasks(),
        );
        let task_id = task_ids(&orch, &wf_id)[0].clone();
        orch.set_task_model_preference(&wf_id, &task_id, Some("openai/gpt-4o".into()))
            .unwrap();
        let task = &orch.get_workflow(&wf_id).unwrap().subtasks[0];
        assert_eq!(task.model_preference.as_deref(), Some("openai/gpt-4o"));
    }

    #[test]
    fn complete_workflow() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow("Test", OrchestrationPattern::Parallel, simple_tasks());

        for tid in &task_ids(&orch, &wf_id) {
            orch.complete_task(&wf_id, tid, "done".into()).unwrap();
        }

        let wf = orch.get_workflow(&wf_id).unwrap();
        assert_eq!(wf.status, WorkflowStatus::Completed);
        assert!(wf.completed_at.is_some());
        assert!(wf.subtasks.iter().all(|t| t.result.as_deref() == Some("done")));
    }

    #[test]
    fn fail_task_fails_workflow() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow("Test", OrchestrationPattern::Sequential, simple_tasks());
        let task_id = task_ids(&orch, &wf_id)[0].clone();

        orch.fail_task(&wf_id, &task_id, "something broke").unwrap();
        let wf = orch.get_workflow(&wf_id).unwrap();
        assert_eq!(wf.status, WorkflowStatus::Failed);
        assert_eq!(
            wf.subtasks[0].result.as_deref(),
            Some("ERROR: something broke")
        );
    }

    #[test]
    fn sequential_next_tasks() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow("Seq", OrchestrationPattern::Sequential, simple_tasks());

        let next = orch.next_tasks(&wf_id).unwrap();
        assert_eq!(next.len(), 1);
        assert_eq!(next[0].description, "Research the topic");
    }

    #[test]
    fn parallel_next_tasks() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow("Par", OrchestrationPattern::Parallel, simple_tasks());

        let next = orch.next_tasks(&wf_id).unwrap();
        assert_eq!(next.len(), 3);
    }

    #[test]
    fn parallel_next_tasks_excludes_finished() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow("Par", OrchestrationPattern::Parallel, simple_tasks());
        let ids = task_ids(&orch, &wf_id);

        orch.complete_task(&wf_id, &ids[0], "done".into()).unwrap();
        orch.fail_task(&wf_id, &ids[1], "broken").unwrap();

        let next = orch.next_tasks(&wf_id).unwrap();
        assert_eq!(next.len(), 1);
        assert_eq!(next[0].description, "Review the output");
    }

    #[test]
    fn capability_matching() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow("Match", OrchestrationPattern::Parallel, simple_tasks());

        let matches = orch.match_capabilities(&wf_id, &sample_agents()).unwrap();
        // "review" is not covered by any agent, so the third task gets no assignment.
        assert_eq!(
            matches,
            vec![
                (format!("{}_task_0", wf_id), "researcher".to_string()),
                (format!("{}_task_1", wf_id), "writer".to_string()),
            ]
        );
    }

    #[test]
    fn match_capabilities_skips_non_pending_tasks() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow("Match", OrchestrationPattern::Parallel, simple_tasks());
        let ids = task_ids(&orch, &wf_id);
        orch.assign_agent(&wf_id, &ids[0], "researcher").unwrap();

        let matches = orch.match_capabilities(&wf_id, &sample_agents()).unwrap();
        assert_eq!(matches, vec![(ids[1].clone(), "writer".to_string())]);
    }

    #[test]
    fn pattern_display() {
        assert_eq!(
            format!("{}", OrchestrationPattern::Sequential),
            "sequential"
        );
        assert_eq!(format!("{}", OrchestrationPattern::Parallel), "parallel");
        assert_eq!(
            format!("{}", OrchestrationPattern::FanOutFanIn),
            "fan-out/fan-in"
        );
        assert_eq!(format!("{}", OrchestrationPattern::Handoff), "handoff");
    }

    #[test]
    fn pattern_serde() {
        for pattern in [
            OrchestrationPattern::Sequential,
            OrchestrationPattern::Parallel,
            OrchestrationPattern::FanOutFanIn,
            OrchestrationPattern::Handoff,
        ] {
            let json = serde_json::to_string(&pattern).unwrap();
            let back: OrchestrationPattern = serde_json::from_str(&json).unwrap();
            assert_eq!(pattern, back);
        }
    }

    #[test]
    fn handoff_resumes_after_last_completed() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow("Handoff", OrchestrationPattern::Handoff, simple_tasks());
        let ids = task_ids(&orch, &wf_id);

        orch.complete_task(&wf_id, &ids[0], "done".into()).unwrap();

        let next = orch.next_tasks(&wf_id).unwrap();
        assert_eq!(next.len(), 1);
        assert_eq!(next[0].description, "Write the summary");
    }

    #[test]
    fn handoff_skips_failed_tasks() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow("Handoff", OrchestrationPattern::Handoff, simple_tasks());
        let ids = task_ids(&orch, &wf_id);

        // Complete first task, fail second
        orch.complete_task(&wf_id, &ids[0], "done".into()).unwrap();
        orch.fail_task(&wf_id, &ids[1], "broken").unwrap();

        // Handoff should skip the Failed task and return the third (Pending) task
        let next = orch.next_tasks(&wf_id).unwrap();
        assert_eq!(next.len(), 1);
        assert_eq!(next[0].description, "Review the output");
    }

    #[test]
    fn workflow_not_found() {
        let orch = Orchestrator::new();
        assert!(orch.get_workflow("nope").is_none());
        assert!(orch.next_tasks("nope").is_err());
        assert!(orch.match_capabilities("nope", &[]).is_err());
    }

    #[test]
    fn unknown_task_is_an_error() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow("Test", OrchestrationPattern::Sequential, simple_tasks());

        assert!(orch.assign_agent(&wf_id, "missing", "agent").is_err());
        assert!(orch.set_task_model_preference(&wf_id, "missing", None).is_err());
        assert!(orch.start_task(&wf_id, "missing").is_err());
        assert!(orch.complete_task(&wf_id, "missing", "done".into()).is_err());
        assert!(orch.fail_task(&wf_id, "missing", "boom").is_err());
        assert!(orch.complete_task("nope", "missing", "done".into()).is_err());

        // None of the failed calls may have touched the workflow.
        assert_eq!(
            orch.get_workflow(&wf_id).unwrap().status,
            WorkflowStatus::Created
        );
    }
}