1use crate::{builder::MessageBuilder, Message};
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet, VecDeque};
9use uuid::Uuid;
10
/// A single task in a [`Workflow`], together with the arguments it is invoked with and the
/// ids of the tasks it depends on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowTask {
    /// Unique id of this task; `to_message` also uses it as the id of the built [`Message`].
    pub id: Uuid,
    /// Task name handed to [`MessageBuilder::new`] when building a message.
    pub task_name: String,
    /// Arguments forwarded to the message via `MessageBuilder::args`.
    pub args: Vec<serde_json::Value>,
    /// Named arguments forwarded to the message via `MessageBuilder::kwargs`.
    pub kwargs: HashMap<String, serde_json::Value>,
    /// Ids of tasks that must be ordered before this one when the owning [`Workflow`] is
    /// topologically sorted. Duplicates are allowed.
    pub dependencies: Vec<Uuid>,
}
25
26impl WorkflowTask {
27 pub fn new(task_name: impl Into<String>) -> Self {
29 Self {
30 id: Uuid::new_v4(),
31 task_name: task_name.into(),
32 args: Vec::new(),
33 kwargs: HashMap::new(),
34 dependencies: Vec::new(),
35 }
36 }
37
38 #[must_use]
40 pub fn with_args(mut self, args: Vec<serde_json::Value>) -> Self {
41 self.args = args;
42 self
43 }
44
45 #[must_use]
47 pub fn with_kwargs(mut self, kwargs: HashMap<String, serde_json::Value>) -> Self {
48 self.kwargs = kwargs;
49 self
50 }
51
52 #[must_use]
54 pub fn depends_on(mut self, task_id: Uuid) -> Self {
55 self.dependencies.push(task_id);
56 self
57 }
58
59 #[must_use]
61 pub fn depends_on_many(mut self, task_ids: Vec<Uuid>) -> Self {
62 self.dependencies.extend(task_ids);
63 self
64 }
65
66 pub fn to_message(&self, root_id: Option<Uuid>, parent_id: Option<Uuid>) -> Message {
68 let mut builder = MessageBuilder::new(&self.task_name)
69 .id(self.id)
70 .args(self.args.clone())
71 .kwargs(self.kwargs.clone());
72
73 if let Some(root) = root_id {
74 builder = builder.root(root);
75 }
76
77 if let Some(parent) = parent_id {
78 builder = builder.parent(parent);
79 }
80
81 builder.build().expect("Failed to build message")
82 }
83}
84
/// A named set of [`WorkflowTask`]s linked by their declared dependencies.
#[derive(Debug, Clone)]
pub struct Workflow {
    /// All tasks of the workflow, keyed by their id.
    tasks: HashMap<Uuid, WorkflowTask>,
    /// Root id stamped on every message produced by `to_messages`. It is `None` until a
    /// dependency-free task is added or `set_root` is called with a known id.
    root_id: Option<Uuid>,
    /// Name of the workflow, returned by `name()`.
    name: String,
}
95
96impl Workflow {
97 pub fn new(name: impl Into<String>) -> Self {
99 Self {
100 tasks: HashMap::new(),
101 root_id: None,
102 name: name.into(),
103 }
104 }
105
106 pub fn add_task(&mut self, task: WorkflowTask) -> Uuid {
108 let id = task.id;
109 if self.root_id.is_none() && task.dependencies.is_empty() {
110 self.root_id = Some(id);
111 }
112 self.tasks.insert(id, task);
113 id
114 }
115
116 pub fn get_task(&self, id: &Uuid) -> Option<&WorkflowTask> {
118 self.tasks.get(id)
119 }
120
121 pub fn set_root(&mut self, task_id: Uuid) {
123 if self.tasks.contains_key(&task_id) {
124 self.root_id = Some(task_id);
125 }
126 }
127
128 pub fn get_entry_tasks(&self) -> Vec<&WorkflowTask> {
130 self.tasks
131 .values()
132 .filter(|task| task.dependencies.is_empty())
133 .collect()
134 }
135
136 pub fn get_dependent_tasks(&self, task_id: &Uuid) -> Vec<&WorkflowTask> {
138 self.tasks
139 .values()
140 .filter(|task| task.dependencies.contains(task_id))
141 .collect()
142 }
143
144 pub fn has_cycles(&self) -> bool {
146 let mut visited = HashSet::new();
147 let mut rec_stack = HashSet::new();
148
149 for task_id in self.tasks.keys() {
150 if self.has_cycle_dfs(task_id, &mut visited, &mut rec_stack) {
151 return true;
152 }
153 }
154
155 false
156 }
157
158 fn has_cycle_dfs(
159 &self,
160 task_id: &Uuid,
161 visited: &mut HashSet<Uuid>,
162 rec_stack: &mut HashSet<Uuid>,
163 ) -> bool {
164 if rec_stack.contains(task_id) {
165 return true;
166 }
167
168 if visited.contains(task_id) {
169 return false;
170 }
171
172 visited.insert(*task_id);
173 rec_stack.insert(*task_id);
174
175 if let Some(task) = self.tasks.get(task_id) {
176 for dep_id in &task.dependencies {
177 if self.has_cycle_dfs(dep_id, visited, rec_stack) {
178 return true;
179 }
180 }
181 }
182
183 rec_stack.remove(task_id);
184 false
185 }
186
187 pub fn topological_sort(&self) -> Result<Vec<Uuid>, String> {
189 if self.has_cycles() {
190 return Err("Workflow contains cycles".to_string());
191 }
192
193 let mut in_degree: HashMap<Uuid, usize> = HashMap::new();
194 let mut adj_list: HashMap<Uuid, Vec<Uuid>> = HashMap::new();
195
196 for id in self.tasks.keys() {
198 in_degree.insert(*id, 0);
199 adj_list.insert(*id, Vec::new());
200 }
201
202 for (id, task) in &self.tasks {
204 for &dep_id in &task.dependencies {
205 adj_list.entry(dep_id).or_default().push(*id);
207 *in_degree.entry(*id).or_insert(0) += 1;
208 }
209 }
210
211 let mut queue: VecDeque<Uuid> = in_degree
213 .iter()
214 .filter(|(_, °ree)| degree == 0)
215 .map(|(&id, _)| id)
216 .collect();
217
218 let mut sorted = Vec::new();
219
220 while let Some(task_id) = queue.pop_front() {
221 sorted.push(task_id);
222
223 if let Some(dependents) = adj_list.get(&task_id) {
225 for &dependent_id in dependents {
226 if let Some(degree) = in_degree.get_mut(&dependent_id) {
227 *degree -= 1;
228 if *degree == 0 {
229 queue.push_back(dependent_id);
230 }
231 }
232 }
233 }
234 }
235
236 if sorted.len() != self.tasks.len() {
237 Err("Could not complete topological sort".to_string())
238 } else {
239 Ok(sorted)
240 }
241 }
242
243 pub fn to_messages(&self) -> Result<Vec<Message>, String> {
245 let order = self.topological_sort()?;
246 let root_id = self.root_id.unwrap_or_else(|| order[0]);
247
248 let messages = order
249 .into_iter()
250 .filter_map(|task_id| {
251 self.tasks.get(&task_id).map(|task| {
252 let parent_id = if task.dependencies.is_empty() {
253 None
254 } else {
255 task.dependencies.first().copied()
256 };
257 task.to_message(Some(root_id), parent_id)
258 })
259 })
260 .collect();
261
262 Ok(messages)
263 }
264
265 #[inline]
267 pub fn name(&self) -> &str {
268 &self.name
269 }
270
271 #[inline]
273 pub fn len(&self) -> usize {
274 self.tasks.len()
275 }
276
277 #[inline]
279 pub fn is_empty(&self) -> bool {
280 self.tasks.is_empty()
281 }
282}
283
/// Builds a linear [`Workflow`] in which each task depends on the task added before it.
#[derive(Debug, Clone)]
pub struct ChainBuilder {
    /// Tasks in the order they were added; this is also the chain order.
    tasks: Vec<WorkflowTask>,
    /// Name given to the built workflow.
    name: String,
}
290
291impl ChainBuilder {
292 pub fn new(name: impl Into<String>) -> Self {
294 Self {
295 tasks: Vec::new(),
296 name: name.into(),
297 }
298 }
299
300 #[must_use]
302 pub fn then(mut self, task_name: impl Into<String>) -> Self {
303 let task = WorkflowTask::new(task_name);
304 self.tasks.push(task);
305 self
306 }
307
308 #[must_use]
310 pub fn then_with_args(
311 mut self,
312 task_name: impl Into<String>,
313 args: Vec<serde_json::Value>,
314 ) -> Self {
315 let task = WorkflowTask::new(task_name).with_args(args);
316 self.tasks.push(task);
317 self
318 }
319
320 pub fn build(self) -> Workflow {
322 let mut workflow = Workflow::new(self.name);
323
324 let mut prev_id: Option<Uuid> = None;
325
326 for mut task in self.tasks {
327 if let Some(prev) = prev_id {
328 task = task.depends_on(prev);
329 }
330 prev_id = Some(task.id);
331 workflow.add_task(task);
332 }
333
334 workflow
335 }
336
337 pub fn build_messages(self) -> Result<Vec<Message>, String> {
339 self.build().to_messages()
340 }
341}
342
/// A collection of tasks that are converted to messages without root or parent links. Every
/// message it produces is tagged with the same group id.
#[derive(Debug, Clone)]
pub struct Group {
    /// Tasks in insertion order.
    tasks: Vec<WorkflowTask>,
    /// Id written into `headers.group` of every message returned by `to_messages`.
    group_id: Uuid,
}
349
350impl Group {
351 pub fn new() -> Self {
353 Self {
354 tasks: Vec::new(),
355 group_id: Uuid::new_v4(),
356 }
357 }
358
359 #[must_use]
361 pub fn with_task(mut self, task: WorkflowTask) -> Self {
362 self.tasks.push(task);
363 self
364 }
365
366 #[must_use]
368 pub fn add_task(mut self, task_name: impl Into<String>) -> Self {
369 self.tasks.push(WorkflowTask::new(task_name));
370 self
371 }
372
373 pub fn to_messages(&self) -> Vec<Message> {
375 self.tasks
376 .iter()
377 .map(|task| {
378 let mut msg = task.to_message(None, None);
379 msg.headers.group = Some(self.group_id);
380 msg
381 })
382 .collect()
383 }
384
385 #[inline]
387 pub fn id(&self) -> Uuid {
388 self.group_id
389 }
390
391 #[inline]
393 pub fn len(&self) -> usize {
394 self.tasks.len()
395 }
396
397 #[inline]
399 pub fn is_empty(&self) -> bool {
400 self.tasks.is_empty()
401 }
402}
403
404impl Default for Group {
405 fn default() -> Self {
406 Self::new()
407 }
408}
409
#[cfg(test)]
mod tests {
    use super::*;

    /// Index of `id` within `order`; panics if the id is missing.
    fn position_of(order: &[Uuid], id: Uuid) -> usize {
        order
            .iter()
            .position(|candidate| *candidate == id)
            .expect("task id missing from topological order")
    }

    #[test]
    fn test_workflow_task_creation() {
        let task = WorkflowTask::new("tasks.add")
            .with_args(vec![serde_json::json!(1), serde_json::json!(2)]);

        assert_eq!(task.task_name, "tasks.add");
        assert_eq!(task.args.len(), 2);
        assert!(task.dependencies.is_empty());
    }

    #[test]
    fn test_workflow_task_dependencies() {
        let task1_id = Uuid::new_v4();
        let task2_id = Uuid::new_v4();

        let task = WorkflowTask::new("task3")
            .depends_on(task1_id)
            .depends_on(task2_id);

        assert_eq!(task.dependencies.len(), 2);
        assert!(task.dependencies.contains(&task1_id));
        assert!(task.dependencies.contains(&task2_id));
    }

    #[test]
    fn test_workflow_add_task() {
        let mut workflow = Workflow::new("test_workflow");
        let task = WorkflowTask::new("tasks.test");
        let task_id = workflow.add_task(task);

        assert_eq!(workflow.len(), 1);
        assert!(workflow.get_task(&task_id).is_some());
    }

    #[test]
    fn test_workflow_entry_tasks() {
        let mut workflow = Workflow::new("test");

        let task1 = WorkflowTask::new("task1");
        let task2 = WorkflowTask::new("task2");
        let task1_id = task1.id;

        workflow.add_task(task1);
        workflow.add_task(task2);

        let task3 = WorkflowTask::new("task3").depends_on(task1_id);
        workflow.add_task(task3);

        let entry_tasks = workflow.get_entry_tasks();
        assert_eq!(entry_tasks.len(), 2);
    }

    #[test]
    fn test_workflow_dependent_tasks() {
        let mut workflow = Workflow::new("test");

        let task1 = WorkflowTask::new("task1");
        let task1_id = task1.id;
        workflow.add_task(task1);

        let task2 = WorkflowTask::new("task2").depends_on(task1_id);
        let task3 = WorkflowTask::new("task3").depends_on(task1_id);

        workflow.add_task(task2);
        workflow.add_task(task3);

        let dependents = workflow.get_dependent_tasks(&task1_id);
        assert_eq!(dependents.len(), 2);
    }

    #[test]
    fn test_workflow_no_cycles() {
        let mut workflow = Workflow::new("test");

        let task1 = WorkflowTask::new("task1");
        let task1_id = task1.id;
        workflow.add_task(task1);

        let task2 = WorkflowTask::new("task2").depends_on(task1_id);
        workflow.add_task(task2);

        assert!(!workflow.has_cycles());
    }

    #[test]
    fn test_workflow_topological_sort() {
        let mut workflow = Workflow::new("test");

        let task1 = WorkflowTask::new("task1");
        let task1_id = task1.id;
        workflow.add_task(task1);

        let task2 = WorkflowTask::new("task2").depends_on(task1_id);
        let task2_id = task2.id;
        workflow.add_task(task2);

        let task3 = WorkflowTask::new("task3").depends_on(task2_id);
        let task3_id = task3.id;
        workflow.add_task(task3);

        // A chain has exactly one valid order, so the whole result can be pinned.
        let sorted = workflow.topological_sort().unwrap();
        assert_eq!(sorted, vec![task1_id, task2_id, task3_id]);
    }

    #[test]
    fn test_workflow_diamond_respects_dependencies() {
        let mut workflow = Workflow::new("diamond");

        let top = WorkflowTask::new("top");
        let top_id = top.id;
        let left = WorkflowTask::new("left").depends_on(top_id);
        let right = WorkflowTask::new("right").depends_on(top_id);
        let (left_id, right_id) = (left.id, right.id);
        let bottom = WorkflowTask::new("bottom").depends_on_many(vec![left_id, right_id]);
        let bottom_id = bottom.id;

        workflow.add_task(top);
        workflow.add_task(left);
        workflow.add_task(right);
        workflow.add_task(bottom);

        assert!(!workflow.has_cycles());
        assert_eq!(workflow.get_entry_tasks().len(), 1);
        assert_eq!(workflow.get_dependent_tasks(&top_id).len(), 2);

        let sorted = workflow.topological_sort().unwrap();
        assert_eq!(sorted.len(), 4);
        assert_eq!(sorted[0], top_id);
        assert_eq!(sorted[3], bottom_id);
        assert!(position_of(&sorted, left_id) < position_of(&sorted, bottom_id));
        assert!(position_of(&sorted, right_id) < position_of(&sorted, bottom_id));
    }

    #[test]
    fn test_workflow_detects_self_dependency() {
        let mut workflow = Workflow::new("test");
        let task = WorkflowTask::new("loop");
        let task_id = task.id;
        workflow.add_task(task.depends_on(task_id));

        assert!(workflow.has_cycles());
        assert_eq!(
            workflow.topological_sort(),
            Err("Workflow contains cycles".to_string())
        );
        assert!(workflow.to_messages().is_err());
    }

    #[test]
    fn test_workflow_detects_two_task_cycle() {
        let mut workflow = Workflow::new("test");
        let first = WorkflowTask::new("first");
        let second = WorkflowTask::new("second");
        let (first_id, second_id) = (first.id, second.id);
        workflow.add_task(first.depends_on(second_id));
        workflow.add_task(second.depends_on(first_id));

        assert!(workflow.has_cycles());
        assert!(workflow.topological_sort().is_err());
    }

    #[test]
    fn test_workflow_unknown_dependency_fails_sort() {
        let mut workflow = Workflow::new("test");
        workflow.add_task(WorkflowTask::new("orphan").depends_on(Uuid::new_v4()));

        // A dangling dependency is not a cycle, but the task can never be scheduled.
        assert!(!workflow.has_cycles());
        assert_eq!(
            workflow.topological_sort(),
            Err("Could not complete topological sort".to_string())
        );
        assert!(workflow.to_messages().is_err());
    }

    #[test]
    fn test_workflow_to_messages() {
        let mut workflow = Workflow::new("test");

        let task1 = WorkflowTask::new("task1");
        let task1_id = task1.id;
        workflow.add_task(task1);

        let task2 = WorkflowTask::new("task2").depends_on(task1_id);
        workflow.add_task(task2);

        let messages = workflow.to_messages().unwrap();
        assert_eq!(messages.len(), 2);
    }

    #[test]
    fn test_workflow_root_is_first_dependency_free_task() {
        let mut workflow = Workflow::new("test");
        let first_id = workflow.add_task(WorkflowTask::new("first"));
        workflow.add_task(WorkflowTask::new("second"));

        // Ids that are not tasks of the workflow are ignored by set_root.
        workflow.set_root(Uuid::new_v4());

        let messages = workflow.to_messages().unwrap();
        assert_eq!(messages.len(), 2);
        for message in &messages {
            assert_eq!(message.headers.root_id, Some(first_id));
        }
    }

    #[test]
    fn test_chain_builder() {
        let chain = ChainBuilder::new("my_chain")
            .then("task1")
            .then("task2")
            .then("task3")
            .build();

        assert_eq!(chain.len(), 3);
        assert_eq!(chain.name(), "my_chain");

        let sorted = chain.topological_sort().unwrap();
        assert_eq!(sorted.len(), 3);
    }

    #[test]
    fn test_chain_builder_with_args() {
        let chain = ChainBuilder::new("my_chain")
            .then_with_args("task1", vec![serde_json::json!(42)])
            .then("task2")
            .build();

        assert_eq!(chain.len(), 2);
    }

    #[test]
    fn test_chain_to_messages() {
        let messages = ChainBuilder::new("my_chain")
            .then("task1")
            .then("task2")
            .build_messages()
            .unwrap();

        assert_eq!(messages.len(), 2);
        assert!(messages[0].has_root());
    }

    #[test]
    fn test_chain_messages_link_each_task_to_its_predecessor() {
        let workflow = ChainBuilder::new("chain")
            .then("first")
            .then("second")
            .then("third")
            .build();
        let order = workflow.topological_sort().unwrap();
        let messages = workflow.to_messages().unwrap();

        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].headers.task, "first");
        assert_eq!(messages[1].headers.parent_id, Some(order[0]));
        assert_eq!(messages[2].headers.parent_id, Some(order[1]));
        for message in &messages {
            assert_eq!(message.headers.root_id, Some(order[0]));
        }
    }

    #[test]
    fn test_long_chain_sorts_in_order() {
        let workflow = (0..1_000)
            .fold(ChainBuilder::new("long"), |chain, i| {
                chain.then(format!("task{}", i))
            })
            .build();

        assert!(!workflow.has_cycles());
        let sorted = workflow.topological_sort().unwrap();
        assert_eq!(sorted.len(), 1_000);
        // Each task must directly follow the single task it depends on.
        for pair in sorted.windows(2) {
            let later = workflow.get_task(&pair[1]).unwrap();
            assert_eq!(later.dependencies, vec![pair[0]]);
        }
    }

    #[test]
    fn test_group_creation() {
        let group = Group::new()
            .add_task("task1")
            .add_task("task2")
            .add_task("task3");

        assert_eq!(group.len(), 3);
    }

    #[test]
    fn test_group_to_messages() {
        let group = Group::new().add_task("task1").add_task("task2");

        let messages = group.to_messages();
        assert_eq!(messages.len(), 2);

        let group_id = messages[0].headers.group.unwrap();
        assert_eq!(group_id, group.id());
        assert_eq!(messages[1].headers.group.unwrap(), group_id);
    }

    #[test]
    fn test_group_default_is_empty() {
        let group = Group::default();
        assert!(group.is_empty());
        assert!(group.to_messages().is_empty());
        assert_ne!(Group::new().id(), Group::new().id());
    }

    #[test]
    fn test_workflow_task_to_message() {
        let task = WorkflowTask::new("tasks.test").with_args(vec![serde_json::json!(1)]);

        let root_id = Uuid::new_v4();
        let parent_id = Uuid::new_v4();

        let message = task.to_message(Some(root_id), Some(parent_id));

        assert_eq!(message.headers.task, "tasks.test");
        assert_eq!(message.headers.root_id, Some(root_id));
        assert_eq!(message.headers.parent_id, Some(parent_id));
    }
}