1use chrono::{DateTime, Utc};
2use roboticus_core::{Result, RoboticusError};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use tracing::{debug, info, warn};
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
9pub enum OrchestrationPattern {
10 Sequential,
11 Parallel,
12 FanOutFanIn,
13 Handoff,
14}
15
16impl std::fmt::Display for OrchestrationPattern {
17 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18 match self {
19 OrchestrationPattern::Sequential => write!(f, "sequential"),
20 OrchestrationPattern::Parallel => write!(f, "parallel"),
21 OrchestrationPattern::FanOutFanIn => write!(f, "fan-out/fan-in"),
22 OrchestrationPattern::Handoff => write!(f, "handoff"),
23 }
24 }
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct Subtask {
30 pub id: String,
31 pub description: String,
32 pub required_capabilities: Vec<String>,
33 #[serde(default)]
34 pub model_preference: Option<String>,
35 pub assigned_agent: Option<String>,
36 pub status: SubtaskStatus,
37 pub result: Option<String>,
38 pub created_at: DateTime<Utc>,
39 pub completed_at: Option<DateTime<Utc>>,
40}
41
42#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
44pub enum SubtaskStatus {
45 Pending,
46 Assigned,
47 Running,
48 Completed,
49 Failed,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct Workflow {
55 pub id: String,
56 pub name: String,
57 pub pattern: OrchestrationPattern,
58 pub subtasks: Vec<Subtask>,
59 pub status: WorkflowStatus,
60 pub created_at: DateTime<Utc>,
61 pub completed_at: Option<DateTime<Utc>>,
62}
63
64#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
66pub enum WorkflowStatus {
67 Created,
68 Running,
69 Completed,
70 Failed,
71 Cancelled,
72}
73
74pub struct Orchestrator {
76 workflows: HashMap<String, Workflow>,
77 workflow_counter: u64,
78}
79
80impl Orchestrator {
81 pub fn new() -> Self {
82 Self {
83 workflows: HashMap::new(),
84 workflow_counter: 0,
85 }
86 }
87
88 pub fn create_workflow(
90 &mut self,
91 name: &str,
92 pattern: OrchestrationPattern,
93 subtasks: Vec<(String, Vec<String>)>,
94 ) -> String {
95 self.workflow_counter += 1;
96 let workflow_id = format!("wf_{}", self.workflow_counter);
97
98 let tasks: Vec<Subtask> = subtasks
99 .into_iter()
100 .enumerate()
101 .map(|(i, (desc, caps))| Subtask {
102 id: format!("{}_task_{}", workflow_id, i),
103 description: desc,
104 required_capabilities: caps,
105 model_preference: None,
106 assigned_agent: None,
107 status: SubtaskStatus::Pending,
108 result: None,
109 created_at: Utc::now(),
110 completed_at: None,
111 })
112 .collect();
113
114 let workflow = Workflow {
115 id: workflow_id.clone(),
116 name: name.to_string(),
117 pattern,
118 subtasks: tasks,
119 status: WorkflowStatus::Created,
120 created_at: Utc::now(),
121 completed_at: None,
122 };
123
124 info!(id = %workflow_id, name, pattern = %pattern, tasks = workflow.subtasks.len(), "created workflow");
125 self.workflows.insert(workflow_id.clone(), workflow);
126 workflow_id
127 }
128
129 pub fn create_workflow_with_model_preferences(
131 &mut self,
132 name: &str,
133 pattern: OrchestrationPattern,
134 subtasks: Vec<(String, Vec<String>, Option<String>)>,
135 ) -> String {
136 self.workflow_counter += 1;
137 let workflow_id = format!("wf_{}", self.workflow_counter);
138
139 let tasks: Vec<Subtask> = subtasks
140 .into_iter()
141 .enumerate()
142 .map(|(i, (desc, caps, model_pref))| Subtask {
143 id: format!("{}_task_{}", workflow_id, i),
144 description: desc,
145 required_capabilities: caps,
146 model_preference: model_pref,
147 assigned_agent: None,
148 status: SubtaskStatus::Pending,
149 result: None,
150 created_at: Utc::now(),
151 completed_at: None,
152 })
153 .collect();
154
155 let workflow = Workflow {
156 id: workflow_id.clone(),
157 name: name.to_string(),
158 pattern,
159 subtasks: tasks,
160 status: WorkflowStatus::Created,
161 created_at: Utc::now(),
162 completed_at: None,
163 };
164
165 info!(
166 id = %workflow_id,
167 name,
168 pattern = %pattern,
169 tasks = workflow.subtasks.len(),
170 "created workflow with model preferences"
171 );
172 self.workflows.insert(workflow_id.clone(), workflow);
173 workflow_id
174 }
175
176 pub fn assign_agent(&mut self, workflow_id: &str, task_id: &str, agent_id: &str) -> Result<()> {
178 let workflow = self.workflows.get_mut(workflow_id).ok_or_else(|| {
179 RoboticusError::Config(format!("workflow '{}' not found", workflow_id))
180 })?;
181
182 let task = workflow
183 .subtasks
184 .iter_mut()
185 .find(|t| t.id == task_id)
186 .ok_or_else(|| RoboticusError::Config(format!("task '{}' not found", task_id)))?;
187
188 task.assigned_agent = Some(agent_id.to_string());
189 task.status = SubtaskStatus::Assigned;
190 debug!(
191 workflow = workflow_id,
192 task = task_id,
193 agent = agent_id,
194 "agent assigned"
195 );
196 Ok(())
197 }
198
199 pub fn set_task_model_preference(
201 &mut self,
202 workflow_id: &str,
203 task_id: &str,
204 model_preference: Option<String>,
205 ) -> Result<()> {
206 let workflow = self.workflows.get_mut(workflow_id).ok_or_else(|| {
207 RoboticusError::Config(format!("workflow '{}' not found", workflow_id))
208 })?;
209 let task = workflow
210 .subtasks
211 .iter_mut()
212 .find(|t| t.id == task_id)
213 .ok_or_else(|| RoboticusError::Config(format!("task '{}' not found", task_id)))?;
214 task.model_preference = model_preference;
215 Ok(())
216 }
217
218 pub fn match_capabilities(
220 &self,
221 workflow_id: &str,
222 available_agents: &[(String, Vec<String>)],
223 ) -> Result<Vec<(String, String)>> {
224 let workflow = self.workflows.get(workflow_id).ok_or_else(|| {
225 RoboticusError::Config(format!("workflow '{}' not found", workflow_id))
226 })?;
227
228 let mut assignments = Vec::new();
229
230 for task in &workflow.subtasks {
231 if task.status != SubtaskStatus::Pending {
232 continue;
233 }
234
235 let best_agent = available_agents.iter().max_by_key(|(_, caps)| {
236 task.required_capabilities
237 .iter()
238 .filter(|rc| caps.contains(rc))
239 .count()
240 });
241
242 if let Some((agent_id, caps)) = best_agent {
243 let overlap = task
244 .required_capabilities
245 .iter()
246 .filter(|rc| caps.contains(rc))
247 .count();
248 if overlap > 0 {
249 assignments.push((task.id.clone(), agent_id.clone()));
250 }
251 }
252 }
253
254 Ok(assignments)
255 }
256
257 pub fn start_task(&mut self, workflow_id: &str, task_id: &str) -> Result<()> {
259 let workflow = self.workflows.get_mut(workflow_id).ok_or_else(|| {
260 RoboticusError::Config(format!("workflow '{}' not found", workflow_id))
261 })?;
262
263 let task = workflow
264 .subtasks
265 .iter_mut()
266 .find(|t| t.id == task_id)
267 .ok_or_else(|| RoboticusError::Config(format!("task '{}' not found", task_id)))?;
268
269 task.status = SubtaskStatus::Running;
270 workflow.status = WorkflowStatus::Running;
271 Ok(())
272 }
273
274 pub fn complete_task(
276 &mut self,
277 workflow_id: &str,
278 task_id: &str,
279 result: String,
280 ) -> Result<()> {
281 let workflow = self.workflows.get_mut(workflow_id).ok_or_else(|| {
282 RoboticusError::Config(format!("workflow '{}' not found", workflow_id))
283 })?;
284
285 let task = workflow
286 .subtasks
287 .iter_mut()
288 .find(|t| t.id == task_id)
289 .ok_or_else(|| RoboticusError::Config(format!("task '{}' not found", task_id)))?;
290
291 task.status = SubtaskStatus::Completed;
292 task.result = Some(result);
293 task.completed_at = Some(Utc::now());
294
295 if workflow
296 .subtasks
297 .iter()
298 .all(|t| t.status == SubtaskStatus::Completed)
299 {
300 workflow.status = WorkflowStatus::Completed;
301 workflow.completed_at = Some(Utc::now());
302 info!(id = %workflow_id, "workflow completed");
303 }
304
305 Ok(())
306 }
307
308 pub fn fail_task(&mut self, workflow_id: &str, task_id: &str, error: &str) -> Result<()> {
310 let workflow = self.workflows.get_mut(workflow_id).ok_or_else(|| {
311 RoboticusError::Config(format!("workflow '{}' not found", workflow_id))
312 })?;
313
314 let task = workflow
315 .subtasks
316 .iter_mut()
317 .find(|t| t.id == task_id)
318 .ok_or_else(|| RoboticusError::Config(format!("task '{}' not found", task_id)))?;
319
320 task.status = SubtaskStatus::Failed;
321 task.result = Some(format!("ERROR: {}", error));
322 task.completed_at = Some(Utc::now());
323
324 workflow.status = WorkflowStatus::Failed;
325 warn!(workflow = workflow_id, task = task_id, error, "task failed");
326 Ok(())
327 }
328
329 pub fn get_workflow(&self, id: &str) -> Option<&Workflow> {
331 self.workflows.get(id)
332 }
333
334 pub fn next_tasks(&self, workflow_id: &str) -> Result<Vec<&Subtask>> {
336 let workflow = self.workflows.get(workflow_id).ok_or_else(|| {
337 RoboticusError::Config(format!("workflow '{}' not found", workflow_id))
338 })?;
339
340 match workflow.pattern {
341 OrchestrationPattern::Sequential => Ok(workflow
342 .subtasks
343 .iter()
344 .find(|t| t.status == SubtaskStatus::Pending || t.status == SubtaskStatus::Assigned)
345 .into_iter()
346 .collect()),
347 OrchestrationPattern::Parallel | OrchestrationPattern::FanOutFanIn => Ok(workflow
348 .subtasks
349 .iter()
350 .filter(|t| {
351 t.status == SubtaskStatus::Pending || t.status == SubtaskStatus::Assigned
352 })
353 .collect()),
354 OrchestrationPattern::Handoff => {
355 let last_completed = workflow
356 .subtasks
357 .iter()
358 .rposition(|t| t.status == SubtaskStatus::Completed);
359 let start_idx = last_completed.map(|i| i + 1).unwrap_or(0);
360 Ok(workflow.subtasks[start_idx..]
362 .iter()
363 .find(|t| {
364 t.status == SubtaskStatus::Pending || t.status == SubtaskStatus::Assigned
365 })
366 .into_iter()
367 .collect())
368 }
369 }
370 }
371
372 pub fn workflow_count(&self) -> usize {
373 self.workflows.len()
374 }
375}
376
377impl Default for Orchestrator {
378 fn default() -> Self {
379 Self::new()
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Three tasks, each needing a single distinct capability.
    fn simple_tasks() -> Vec<(String, Vec<String>)> {
        vec![
            ("Research the topic".into(), vec!["research".into()]),
            ("Write the summary".into(), vec!["summarization".into()]),
            ("Review the output".into(), vec!["review".into()]),
        ]
    }

    /// Ids of every subtask in `wf_id`, in declaration order.
    fn task_ids(orch: &Orchestrator, wf_id: &str) -> Vec<String> {
        orch.get_workflow(wf_id)
            .unwrap()
            .subtasks
            .iter()
            .map(|t| t.id.clone())
            .collect()
    }

    /// Two agents covering "research" and "summarization" but not "review".
    fn sample_agents() -> Vec<(String, Vec<String>)> {
        vec![
            (
                "researcher".into(),
                vec!["research".into(), "analysis".into()],
            ),
            (
                "writer".into(),
                vec!["summarization".into(), "writing".into()],
            ),
        ]
    }

    #[test]
    fn create_workflow() {
        let mut orch = Orchestrator::new();
        let id = orch.create_workflow(
            "Test Flow",
            OrchestrationPattern::Sequential,
            simple_tasks(),
        );
        assert!(id.starts_with("wf_"));
        let wf = orch.get_workflow(&id).unwrap();
        assert_eq!(wf.subtasks.len(), 3);
        assert_eq!(wf.status, WorkflowStatus::Created);
        assert_eq!(wf.subtasks[0].id, format!("{}_task_0", id));
        assert!(wf.subtasks.iter().all(|t| t.status == SubtaskStatus::Pending));
        assert!(wf.subtasks.iter().all(|t| t.model_preference.is_none()));
    }

    #[test]
    fn workflow_ids_are_unique() {
        let mut orch = Orchestrator::new();
        let first = orch.create_workflow("A", OrchestrationPattern::Parallel, simple_tasks());
        let second = orch.create_workflow("B", OrchestrationPattern::Parallel, simple_tasks());
        assert_ne!(first, second);
        assert_eq!(orch.workflow_count(), 2);
    }

    #[test]
    fn create_workflow_with_model_preferences() {
        let mut orch = Orchestrator::new();
        let id = orch.create_workflow_with_model_preferences(
            "Model Aware Flow",
            OrchestrationPattern::Parallel,
            vec![
                (
                    "Draft summary".into(),
                    vec!["summarization".into()],
                    Some("ollama/qwen3:8b".into()),
                ),
                ("Review output".into(), vec!["review".into()], None),
            ],
        );
        let wf = orch.get_workflow(&id).unwrap();
        assert_eq!(
            wf.subtasks[0].model_preference.as_deref(),
            Some("ollama/qwen3:8b")
        );
        assert!(wf.subtasks[1].model_preference.is_none());
    }

    #[test]
    fn assign_and_start() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow("Test", OrchestrationPattern::Sequential, simple_tasks());
        let task_id = task_ids(&orch, &wf_id)[0].clone();

        orch.assign_agent(&wf_id, &task_id, "agent-research").unwrap();
        let task = &orch.get_workflow(&wf_id).unwrap().subtasks[0];
        assert_eq!(task.status, SubtaskStatus::Assigned);
        assert_eq!(task.assigned_agent.as_deref(), Some("agent-research"));

        orch.start_task(&wf_id, &task_id).unwrap();
        let wf = orch.get_workflow(&wf_id).unwrap();
        assert_eq!(wf.subtasks[0].status, SubtaskStatus::Running);
        assert_eq!(wf.status, WorkflowStatus::Running);
    }

    #[test]
    fn set_task_model_preference_updates_task() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow(
            "Model Edit",
            OrchestrationPattern::Sequential,
            simple_tasks(),
        );
        let task_id = task_ids(&orch, &wf_id)[0].clone();

        orch.set_task_model_preference(&wf_id, &task_id, Some("openai/gpt-4o".into()))
            .unwrap();
        let task = &orch.get_workflow(&wf_id).unwrap().subtasks[0];
        assert_eq!(task.model_preference.as_deref(), Some("openai/gpt-4o"));

        orch.set_task_model_preference(&wf_id, &task_id, None)
            .unwrap();
        let task = &orch.get_workflow(&wf_id).unwrap().subtasks[0];
        assert!(task.model_preference.is_none());
    }

    #[test]
    fn complete_workflow() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow("Test", OrchestrationPattern::Parallel, simple_tasks());

        for tid in &task_ids(&orch, &wf_id) {
            orch.complete_task(&wf_id, tid, "done".into()).unwrap();
        }

        let wf = orch.get_workflow(&wf_id).unwrap();
        assert_eq!(wf.status, WorkflowStatus::Completed);
        assert!(wf.completed_at.is_some());
        assert!(wf
            .subtasks
            .iter()
            .all(|t| t.result.as_deref() == Some("done")));
    }

    #[test]
    fn fail_task_fails_workflow() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow("Test", OrchestrationPattern::Sequential, simple_tasks());
        let task_id = task_ids(&orch, &wf_id)[0].clone();

        orch.fail_task(&wf_id, &task_id, "something broke").unwrap();
        let wf = orch.get_workflow(&wf_id).unwrap();
        assert_eq!(wf.status, WorkflowStatus::Failed);
        assert_eq!(wf.subtasks[0].status, SubtaskStatus::Failed);
        assert_eq!(
            wf.subtasks[0].result.as_deref(),
            Some("ERROR: something broke")
        );
        assert!(wf.subtasks[0].completed_at.is_some());
    }

    #[test]
    fn sequential_next_tasks() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow("Seq", OrchestrationPattern::Sequential, simple_tasks());
        let ids = task_ids(&orch, &wf_id);

        let next = orch.next_tasks(&wf_id).unwrap();
        assert_eq!(next.len(), 1);
        assert_eq!(next[0].description, "Research the topic");

        orch.complete_task(&wf_id, &ids[0], "done".into()).unwrap();
        let next = orch.next_tasks(&wf_id).unwrap();
        assert_eq!(next.len(), 1);
        assert_eq!(next[0].description, "Write the summary");
    }

    #[test]
    fn parallel_next_tasks() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow("Par", OrchestrationPattern::Parallel, simple_tasks());
        let ids = task_ids(&orch, &wf_id);

        assert_eq!(orch.next_tasks(&wf_id).unwrap().len(), 3);

        // Assigned tasks are still dispatchable; running ones are not.
        orch.assign_agent(&wf_id, &ids[0], "agent").unwrap();
        assert_eq!(orch.next_tasks(&wf_id).unwrap().len(), 3);
        orch.start_task(&wf_id, &ids[0]).unwrap();
        assert_eq!(orch.next_tasks(&wf_id).unwrap().len(), 2);
    }

    #[test]
    fn capability_matching() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow("Match", OrchestrationPattern::Parallel, simple_tasks());
        let ids = task_ids(&orch, &wf_id);

        let matches = orch.match_capabilities(&wf_id, &sample_agents()).unwrap();
        // The "review" task has no capable agent and is left out.
        assert_eq!(
            matches,
            vec![
                (ids[0].clone(), "researcher".to_string()),
                (ids[1].clone(), "writer".to_string()),
            ]
        );
    }

    #[test]
    fn capability_matching_ignores_non_pending_tasks() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow("Match", OrchestrationPattern::Parallel, simple_tasks());
        let ids = task_ids(&orch, &wf_id);

        orch.assign_agent(&wf_id, &ids[0], "researcher").unwrap();
        let matches = orch.match_capabilities(&wf_id, &sample_agents()).unwrap();
        assert_eq!(matches, vec![(ids[1].clone(), "writer".to_string())]);
    }

    #[test]
    fn pattern_display() {
        assert_eq!(
            format!("{}", OrchestrationPattern::Sequential),
            "sequential"
        );
        assert_eq!(format!("{}", OrchestrationPattern::Parallel), "parallel");
        assert_eq!(
            format!("{}", OrchestrationPattern::FanOutFanIn),
            "fan-out/fan-in"
        );
        assert_eq!(format!("{}", OrchestrationPattern::Handoff), "handoff");
    }

    #[test]
    fn pattern_serde() {
        for pattern in [
            OrchestrationPattern::Sequential,
            OrchestrationPattern::Parallel,
            OrchestrationPattern::FanOutFanIn,
            OrchestrationPattern::Handoff,
        ] {
            let json = serde_json::to_string(&pattern).unwrap();
            let back: OrchestrationPattern = serde_json::from_str(&json).unwrap();
            assert_eq!(pattern, back);
        }
    }

    #[test]
    fn handoff_starts_with_first_task() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow("Handoff", OrchestrationPattern::Handoff, simple_tasks());

        let next = orch.next_tasks(&wf_id).unwrap();
        assert_eq!(next.len(), 1);
        assert_eq!(next[0].description, "Research the topic");
    }

    #[test]
    fn handoff_skips_failed_tasks() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow("Handoff", OrchestrationPattern::Handoff, simple_tasks());
        let ids = task_ids(&orch, &wf_id);

        orch.complete_task(&wf_id, &ids[0], "done".into()).unwrap();
        orch.fail_task(&wf_id, &ids[1], "broken").unwrap();

        let next = orch.next_tasks(&wf_id).unwrap();
        assert_eq!(next.len(), 1);
        assert_eq!(next[0].description, "Review the output");
    }

    #[test]
    fn workflow_not_found() {
        let mut orch = Orchestrator::new();
        assert!(orch.get_workflow("nope").is_none());
        assert!(orch.next_tasks("nope").is_err());
        assert!(orch.match_capabilities("nope", &sample_agents()).is_err());
        assert!(orch.assign_agent("nope", "t", "a").is_err());
        assert!(orch.set_task_model_preference("nope", "t", None).is_err());
        assert!(orch.start_task("nope", "t").is_err());
        assert!(orch.complete_task("nope", "t", "done".into()).is_err());
        assert!(orch.fail_task("nope", "t", "boom").is_err());
    }

    #[test]
    fn task_not_found() {
        let mut orch = Orchestrator::new();
        let wf_id = orch.create_workflow("Test", OrchestrationPattern::Sequential, simple_tasks());
        assert!(orch.assign_agent(&wf_id, "missing", "agent").is_err());
        assert!(orch
            .set_task_model_preference(&wf_id, "missing", None)
            .is_err());
        assert!(orch.start_task(&wf_id, "missing").is_err());
        assert!(orch
            .complete_task(&wf_id, "missing", "done".into())
            .is_err());
        assert!(orch.fail_task(&wf_id, "missing", "boom").is_err());
    }
}
597}