celers_protocol/
workflow.rs

1//! Workflow and task chain utilities
2//!
3//! This module provides helpers for building and managing task workflows,
4//! chains, and directed acyclic graphs (DAGs) of tasks.
5
6use crate::{builder::MessageBuilder, Message};
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet, VecDeque};
9use uuid::Uuid;
10
11/// A task in a workflow
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct WorkflowTask {
14    /// Unique task identifier
15    pub id: Uuid,
16    /// Task name
17    pub task_name: String,
18    /// Task arguments (JSON)
19    pub args: Vec<serde_json::Value>,
20    /// Task keyword arguments (JSON)
21    pub kwargs: HashMap<String, serde_json::Value>,
22    /// Dependencies (task IDs that must complete first)
23    pub dependencies: Vec<Uuid>,
24}
25
26impl WorkflowTask {
27    /// Create a new workflow task
28    pub fn new(task_name: impl Into<String>) -> Self {
29        Self {
30            id: Uuid::new_v4(),
31            task_name: task_name.into(),
32            args: Vec::new(),
33            kwargs: HashMap::new(),
34            dependencies: Vec::new(),
35        }
36    }
37
38    /// Set task arguments
39    #[must_use]
40    pub fn with_args(mut self, args: Vec<serde_json::Value>) -> Self {
41        self.args = args;
42        self
43    }
44
45    /// Set task keyword arguments
46    #[must_use]
47    pub fn with_kwargs(mut self, kwargs: HashMap<String, serde_json::Value>) -> Self {
48        self.kwargs = kwargs;
49        self
50    }
51
52    /// Add a dependency on another task
53    #[must_use]
54    pub fn depends_on(mut self, task_id: Uuid) -> Self {
55        self.dependencies.push(task_id);
56        self
57    }
58
59    /// Add multiple dependencies
60    #[must_use]
61    pub fn depends_on_many(mut self, task_ids: Vec<Uuid>) -> Self {
62        self.dependencies.extend(task_ids);
63        self
64    }
65
66    /// Convert to a Message with workflow metadata
67    pub fn to_message(&self, root_id: Option<Uuid>, parent_id: Option<Uuid>) -> Message {
68        let mut builder = MessageBuilder::new(&self.task_name)
69            .id(self.id)
70            .args(self.args.clone())
71            .kwargs(self.kwargs.clone());
72
73        if let Some(root) = root_id {
74            builder = builder.root(root);
75        }
76
77        if let Some(parent) = parent_id {
78            builder = builder.parent(parent);
79        }
80
81        builder.build().expect("Failed to build message")
82    }
83}
84
85/// A workflow of tasks with dependencies
86#[derive(Debug, Clone)]
87pub struct Workflow {
88    /// All tasks in the workflow
89    tasks: HashMap<Uuid, WorkflowTask>,
90    /// Root task ID (entry point)
91    root_id: Option<Uuid>,
92    /// Workflow name
93    name: String,
94}
95
96impl Workflow {
97    /// Create a new workflow
98    pub fn new(name: impl Into<String>) -> Self {
99        Self {
100            tasks: HashMap::new(),
101            root_id: None,
102            name: name.into(),
103        }
104    }
105
106    /// Add a task to the workflow
107    pub fn add_task(&mut self, task: WorkflowTask) -> Uuid {
108        let id = task.id;
109        if self.root_id.is_none() && task.dependencies.is_empty() {
110            self.root_id = Some(id);
111        }
112        self.tasks.insert(id, task);
113        id
114    }
115
116    /// Get a task by ID
117    pub fn get_task(&self, id: &Uuid) -> Option<&WorkflowTask> {
118        self.tasks.get(id)
119    }
120
121    /// Set the root task
122    pub fn set_root(&mut self, task_id: Uuid) {
123        if self.tasks.contains_key(&task_id) {
124            self.root_id = Some(task_id);
125        }
126    }
127
128    /// Get all tasks with no dependencies (entry points)
129    pub fn get_entry_tasks(&self) -> Vec<&WorkflowTask> {
130        self.tasks
131            .values()
132            .filter(|task| task.dependencies.is_empty())
133            .collect()
134    }
135
136    /// Get tasks that depend on a specific task
137    pub fn get_dependent_tasks(&self, task_id: &Uuid) -> Vec<&WorkflowTask> {
138        self.tasks
139            .values()
140            .filter(|task| task.dependencies.contains(task_id))
141            .collect()
142    }
143
144    /// Check if the workflow has cycles (invalid)
145    pub fn has_cycles(&self) -> bool {
146        let mut visited = HashSet::new();
147        let mut rec_stack = HashSet::new();
148
149        for task_id in self.tasks.keys() {
150            if self.has_cycle_dfs(task_id, &mut visited, &mut rec_stack) {
151                return true;
152            }
153        }
154
155        false
156    }
157
158    fn has_cycle_dfs(
159        &self,
160        task_id: &Uuid,
161        visited: &mut HashSet<Uuid>,
162        rec_stack: &mut HashSet<Uuid>,
163    ) -> bool {
164        if rec_stack.contains(task_id) {
165            return true;
166        }
167
168        if visited.contains(task_id) {
169            return false;
170        }
171
172        visited.insert(*task_id);
173        rec_stack.insert(*task_id);
174
175        if let Some(task) = self.tasks.get(task_id) {
176            for dep_id in &task.dependencies {
177                if self.has_cycle_dfs(dep_id, visited, rec_stack) {
178                    return true;
179                }
180            }
181        }
182
183        rec_stack.remove(task_id);
184        false
185    }
186
187    /// Get tasks in topological order (execution order)
188    pub fn topological_sort(&self) -> Result<Vec<Uuid>, String> {
189        if self.has_cycles() {
190            return Err("Workflow contains cycles".to_string());
191        }
192
193        let mut in_degree: HashMap<Uuid, usize> = HashMap::new();
194        let mut adj_list: HashMap<Uuid, Vec<Uuid>> = HashMap::new();
195
196        // Initialize all tasks
197        for id in self.tasks.keys() {
198            in_degree.insert(*id, 0);
199            adj_list.insert(*id, Vec::new());
200        }
201
202        // Build adjacency list and count in-degrees
203        for (id, task) in &self.tasks {
204            for &dep_id in &task.dependencies {
205                // dep_id -> id (dep_id points to id)
206                adj_list.entry(dep_id).or_default().push(*id);
207                *in_degree.entry(*id).or_insert(0) += 1;
208            }
209        }
210
211        // Find all tasks with no dependencies
212        let mut queue: VecDeque<Uuid> = in_degree
213            .iter()
214            .filter(|(_, &degree)| degree == 0)
215            .map(|(&id, _)| id)
216            .collect();
217
218        let mut sorted = Vec::new();
219
220        while let Some(task_id) = queue.pop_front() {
221            sorted.push(task_id);
222
223            // For all tasks that depend on this task
224            if let Some(dependents) = adj_list.get(&task_id) {
225                for &dependent_id in dependents {
226                    if let Some(degree) = in_degree.get_mut(&dependent_id) {
227                        *degree -= 1;
228                        if *degree == 0 {
229                            queue.push_back(dependent_id);
230                        }
231                    }
232                }
233            }
234        }
235
236        if sorted.len() != self.tasks.len() {
237            Err("Could not complete topological sort".to_string())
238        } else {
239            Ok(sorted)
240        }
241    }
242
243    /// Convert workflow to messages in execution order
244    pub fn to_messages(&self) -> Result<Vec<Message>, String> {
245        let order = self.topological_sort()?;
246        let root_id = self.root_id.unwrap_or_else(|| order[0]);
247
248        let messages = order
249            .into_iter()
250            .filter_map(|task_id| {
251                self.tasks.get(&task_id).map(|task| {
252                    let parent_id = if task.dependencies.is_empty() {
253                        None
254                    } else {
255                        task.dependencies.first().copied()
256                    };
257                    task.to_message(Some(root_id), parent_id)
258                })
259            })
260            .collect();
261
262        Ok(messages)
263    }
264
265    /// Get the workflow name
266    #[inline]
267    pub fn name(&self) -> &str {
268        &self.name
269    }
270
271    /// Get the number of tasks in the workflow
272    #[inline]
273    pub fn len(&self) -> usize {
274        self.tasks.len()
275    }
276
277    /// Check if the workflow is empty
278    #[inline]
279    pub fn is_empty(&self) -> bool {
280        self.tasks.is_empty()
281    }
282}
283
284/// Builder for creating task chains (linear workflows)
285#[derive(Debug, Clone)]
286pub struct ChainBuilder {
287    tasks: Vec<WorkflowTask>,
288    name: String,
289}
290
291impl ChainBuilder {
292    /// Create a new chain builder
293    pub fn new(name: impl Into<String>) -> Self {
294        Self {
295            tasks: Vec::new(),
296            name: name.into(),
297        }
298    }
299
300    /// Add a task to the chain
301    #[must_use]
302    pub fn then(mut self, task_name: impl Into<String>) -> Self {
303        let task = WorkflowTask::new(task_name);
304        self.tasks.push(task);
305        self
306    }
307
308    /// Add a task with arguments
309    #[must_use]
310    pub fn then_with_args(
311        mut self,
312        task_name: impl Into<String>,
313        args: Vec<serde_json::Value>,
314    ) -> Self {
315        let task = WorkflowTask::new(task_name).with_args(args);
316        self.tasks.push(task);
317        self
318    }
319
320    /// Build the chain as a workflow
321    pub fn build(self) -> Workflow {
322        let mut workflow = Workflow::new(self.name);
323
324        let mut prev_id: Option<Uuid> = None;
325
326        for mut task in self.tasks {
327            if let Some(prev) = prev_id {
328                task = task.depends_on(prev);
329            }
330            prev_id = Some(task.id);
331            workflow.add_task(task);
332        }
333
334        workflow
335    }
336
337    /// Build and convert to messages
338    pub fn build_messages(self) -> Result<Vec<Message>, String> {
339        self.build().to_messages()
340    }
341}
342
343/// Group multiple tasks to run in parallel
344#[derive(Debug, Clone)]
345pub struct Group {
346    tasks: Vec<WorkflowTask>,
347    group_id: Uuid,
348}
349
350impl Group {
351    /// Create a new task group
352    pub fn new() -> Self {
353        Self {
354            tasks: Vec::new(),
355            group_id: Uuid::new_v4(),
356        }
357    }
358
359    /// Add a task to the group
360    #[must_use]
361    pub fn with_task(mut self, task: WorkflowTask) -> Self {
362        self.tasks.push(task);
363        self
364    }
365
366    /// Add a simple task by name
367    #[must_use]
368    pub fn add_task(mut self, task_name: impl Into<String>) -> Self {
369        self.tasks.push(WorkflowTask::new(task_name));
370        self
371    }
372
373    /// Convert to messages with group ID
374    pub fn to_messages(&self) -> Vec<Message> {
375        self.tasks
376            .iter()
377            .map(|task| {
378                let mut msg = task.to_message(None, None);
379                msg.headers.group = Some(self.group_id);
380                msg
381            })
382            .collect()
383    }
384
385    /// Get the group ID
386    #[inline]
387    pub fn id(&self) -> Uuid {
388        self.group_id
389    }
390
391    /// Get the number of tasks in the group
392    #[inline]
393    pub fn len(&self) -> usize {
394        self.tasks.len()
395    }
396
397    /// Check if the group is empty
398    #[inline]
399    pub fn is_empty(&self) -> bool {
400        self.tasks.is_empty()
401    }
402}
403
404impl Default for Group {
405    fn default() -> Self {
406        Self::new()
407    }
408}
409
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_workflow_task_creation() {
        let task = WorkflowTask::new("tasks.add")
            .with_args(vec![serde_json::json!(1), serde_json::json!(2)]);

        assert_eq!(task.task_name, "tasks.add");
        assert_eq!(task.args.len(), 2);
        assert!(task.kwargs.is_empty());
        assert!(task.dependencies.is_empty());
    }

    #[test]
    fn test_workflow_task_dependencies() {
        let task1_id = Uuid::new_v4();
        let task2_id = Uuid::new_v4();

        let task = WorkflowTask::new("task3")
            .depends_on(task1_id)
            .depends_on(task2_id);

        // Dependencies are kept in the order they were added.
        assert_eq!(task.dependencies, vec![task1_id, task2_id]);
    }

    #[test]
    fn test_workflow_add_task() {
        let mut workflow = Workflow::new("test_workflow");
        assert!(workflow.is_empty());

        let task = WorkflowTask::new("tasks.test");
        let task_id = workflow.add_task(task);

        assert_eq!(workflow.name(), "test_workflow");
        assert_eq!(workflow.len(), 1);
        assert!(!workflow.is_empty());
        assert!(workflow.get_task(&task_id).is_some());
    }

    #[test]
    fn test_workflow_entry_tasks() {
        let mut workflow = Workflow::new("test");

        let task1 = WorkflowTask::new("task1");
        let task2 = WorkflowTask::new("task2");
        let task1_id = task1.id;

        workflow.add_task(task1);
        workflow.add_task(task2);

        let task3 = WorkflowTask::new("task3").depends_on(task1_id);
        workflow.add_task(task3);

        let entry_tasks = workflow.get_entry_tasks();
        assert_eq!(entry_tasks.len(), 2); // task1 and task2 have no dependencies
        assert!(entry_tasks.iter().all(|task| task.dependencies.is_empty()));
    }

    #[test]
    fn test_workflow_dependent_tasks() {
        let mut workflow = Workflow::new("test");

        let task1 = WorkflowTask::new("task1");
        let task1_id = task1.id;
        workflow.add_task(task1);

        let task2 = WorkflowTask::new("task2").depends_on(task1_id);
        let task3 = WorkflowTask::new("task3").depends_on(task1_id);

        workflow.add_task(task2);
        workflow.add_task(task3);

        let dependents = workflow.get_dependent_tasks(&task1_id);
        assert_eq!(dependents.len(), 2);
        assert!(dependents
            .iter()
            .all(|task| task.dependencies.contains(&task1_id)));
    }

    #[test]
    fn test_workflow_no_cycles() {
        let mut workflow = Workflow::new("test");

        let task1 = WorkflowTask::new("task1");
        let task1_id = task1.id;
        workflow.add_task(task1);

        let task2 = WorkflowTask::new("task2").depends_on(task1_id);
        workflow.add_task(task2);

        assert!(!workflow.has_cycles());
    }

    #[test]
    fn test_workflow_detects_cycles() {
        let mut workflow = Workflow::new("test");

        // task_a and task_b depend on each other.
        let task_a = WorkflowTask::new("task_a");
        let task_b = WorkflowTask::new("task_b").depends_on(task_a.id);
        let task_a = task_a.depends_on(task_b.id);

        workflow.add_task(task_a);
        workflow.add_task(task_b);

        assert!(workflow.has_cycles());
        assert!(workflow.topological_sort().is_err());
        assert!(workflow.to_messages().is_err());
    }

    #[test]
    fn test_workflow_topological_sort() {
        let mut workflow = Workflow::new("test");

        let task1 = WorkflowTask::new("task1");
        let task1_id = task1.id;
        workflow.add_task(task1);

        let task2 = WorkflowTask::new("task2").depends_on(task1_id);
        let task2_id = task2.id;
        workflow.add_task(task2);

        let task3 = WorkflowTask::new("task3").depends_on(task2_id);
        let task3_id = task3.id;
        workflow.add_task(task3);

        let sorted = workflow.topological_sort().unwrap();
        assert_eq!(sorted, vec![task1_id, task2_id, task3_id]);
    }

    #[test]
    fn test_workflow_to_messages() {
        let mut workflow = Workflow::new("test");

        let task1 = WorkflowTask::new("task1");
        let task1_id = task1.id;
        workflow.add_task(task1);

        let task2 = WorkflowTask::new("task2").depends_on(task1_id);
        workflow.add_task(task2);

        let messages = workflow.to_messages().unwrap();
        assert_eq!(messages.len(), 2);

        // Messages come out in execution order.
        assert_eq!(messages[0].headers.task, "task1");
        assert_eq!(messages[1].headers.task, "task2");

        // task1 is the root of the workflow and the parent of task2.
        assert_eq!(messages[0].headers.root_id, Some(task1_id));
        assert_eq!(messages[1].headers.root_id, Some(task1_id));
        assert_eq!(messages[1].headers.parent_id, Some(task1_id));
    }

    #[test]
    fn test_chain_builder() {
        let chain = ChainBuilder::new("my_chain")
            .then("task1")
            .then("task2")
            .then("task3")
            .build();

        assert_eq!(chain.len(), 3);

        let sorted = chain.topological_sort().unwrap();
        let names: Vec<&str> = sorted
            .iter()
            .map(|id| chain.get_task(id).unwrap().task_name.as_str())
            .collect();
        assert_eq!(names, vec!["task1", "task2", "task3"]);
    }

    #[test]
    fn test_chain_builder_with_args() {
        let chain = ChainBuilder::new("my_chain")
            .then_with_args("task1", vec![serde_json::json!(42)])
            .then("task2")
            .build();

        assert_eq!(chain.len(), 2);

        let entry_tasks = chain.get_entry_tasks();
        assert_eq!(entry_tasks.len(), 1);
        assert_eq!(entry_tasks[0].task_name, "task1");
        assert_eq!(entry_tasks[0].args, vec![serde_json::json!(42)]);
    }

    #[test]
    fn test_chain_to_messages() {
        let messages = ChainBuilder::new("my_chain")
            .then("task1")
            .then("task2")
            .build_messages()
            .unwrap();

        assert_eq!(messages.len(), 2);
        assert!(messages[0].has_root());

        // The first task is the root of the chain and the parent of the second.
        assert_eq!(messages[1].headers.root_id, messages[0].headers.root_id);
        assert_eq!(messages[1].headers.parent_id, messages[0].headers.root_id);
    }

    #[test]
    fn test_group_creation() {
        let group = Group::new()
            .add_task("task1")
            .add_task("task2")
            .add_task("task3");

        assert_eq!(group.len(), 3);
        assert!(!group.is_empty());
        assert!(Group::default().is_empty());
    }

    #[test]
    fn test_group_to_messages() {
        let group = Group::new().add_task("task1").add_task("task2");

        let messages = group.to_messages();
        assert_eq!(messages.len(), 2);

        // All messages should carry the group's own ID
        assert_eq!(messages[0].headers.group, Some(group.id()));
        assert_eq!(messages[1].headers.group, Some(group.id()));
    }

    #[test]
    fn test_workflow_task_to_message() {
        let task = WorkflowTask::new("tasks.test").with_args(vec![serde_json::json!(1)]);

        let root_id = Uuid::new_v4();
        let parent_id = Uuid::new_v4();

        let message = task.to_message(Some(root_id), Some(parent_id));

        assert_eq!(message.headers.task, "tasks.test");
        assert_eq!(message.headers.root_id, Some(root_id));
        assert_eq!(message.headers.parent_id, Some(parent_id));
    }
}
605}