1use cortexai_core::{errors::CrewError, Task};
37use serde::{Deserialize, Serialize};
38use std::collections::HashMap;
39use std::sync::Arc;
40use tokio::sync::RwLock;
41
42#[derive(Debug, Clone)]
44pub enum NextAction {
45 Continue,
47
48 ContinueAndExecute,
50
51 WaitForInput(InputRequest),
53
54 Branch(Vec<String>),
56
57 Complete(WorkflowResult),
59
60 Failed(String),
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct InputRequest {
67 pub id: String,
69
70 pub prompt: String,
72
73 pub input_type: InputType,
75
76 pub waiting_node: String,
78
79 pub default: Option<serde_json::Value>,
81
82 pub timeout_secs: Option<u64>,
84}
85
86#[derive(Debug, Clone, Serialize, Deserialize)]
88pub enum InputType {
89 Text,
91
92 Confirmation,
94
95 Selection(Vec<String>),
97
98 Approval,
100
101 Json(serde_json::Value), }
104
105#[derive(Debug, Clone, Serialize, Deserialize)]
107pub struct WorkflowResult {
108 pub output: serde_json::Value,
110
111 pub trace: Vec<NodeExecution>,
113
114 pub duration_ms: u64,
116
117 pub status: WorkflowStatus,
119}
120
121#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
123pub enum WorkflowStatus {
124 Success,
125 Failed,
126 Cancelled,
127 TimedOut,
128}
129
130#[derive(Debug, Clone, Serialize, Deserialize)]
132pub struct NodeExecution {
133 pub node_id: String,
135
136 pub started_at: chrono::DateTime<chrono::Utc>,
138
139 pub ended_at: chrono::DateTime<chrono::Utc>,
141
142 pub result: NodeResult,
144
145 pub output: serde_json::Value,
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize)]
151pub enum NodeResult {
152 Success,
153 Failed(String),
154 Skipped,
155 WaitingForInput,
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize)]
160pub struct Condition {
161 pub field: String,
163
164 pub operator: ConditionOperator,
166
167 pub value: serde_json::Value,
169}
170
171impl Condition {
172 pub fn when(field: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
174 Self {
175 field: field.into(),
176 operator: ConditionOperator::Equals,
177 value: value.into(),
178 }
179 }
180
181 pub fn when_not(field: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
183 Self {
184 field: field.into(),
185 operator: ConditionOperator::NotEquals,
186 value: value.into(),
187 }
188 }
189
190 pub fn when_gt(field: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
192 Self {
193 field: field.into(),
194 operator: ConditionOperator::GreaterThan,
195 value: value.into(),
196 }
197 }
198
199 pub fn evaluate(&self, context: &serde_json::Value) -> bool {
201 let field_value = context.get(&self.field);
202
203 match (&self.operator, field_value) {
204 (ConditionOperator::Equals, Some(v)) => v == &self.value,
205 (ConditionOperator::NotEquals, Some(v)) => v != &self.value,
206 (ConditionOperator::GreaterThan, Some(v)) => match (v.as_f64(), self.value.as_f64()) {
207 (Some(a), Some(b)) => a > b,
208 _ => false,
209 },
210 (ConditionOperator::LessThan, Some(v)) => match (v.as_f64(), self.value.as_f64()) {
211 (Some(a), Some(b)) => a < b,
212 _ => false,
213 },
214 (ConditionOperator::Contains, Some(v)) => {
215 if let (Some(arr), Some(needle)) = (v.as_array(), self.value.as_str()) {
216 arr.iter().any(|x| x.as_str() == Some(needle))
217 } else if let (Some(s), Some(needle)) = (v.as_str(), self.value.as_str()) {
218 s.contains(needle)
219 } else {
220 false
221 }
222 }
223 (ConditionOperator::Exists, _) => field_value.is_some(),
224 _ => false,
225 }
226 }
227}
228
229#[derive(Debug, Clone, Serialize, Deserialize)]
231pub enum ConditionOperator {
232 Equals,
233 NotEquals,
234 GreaterThan,
235 LessThan,
236 Contains,
237 Exists,
238}
239
240#[derive(Clone)]
242pub struct WorkflowNode {
243 pub id: String,
245
246 pub name: String,
248
249 pub node_type: NodeType,
251
252 pub task: Option<Task>,
254
255 pub executor: Option<Arc<dyn NodeExecutor>>,
257
258 pub metadata: HashMap<String, serde_json::Value>,
260}
261
262impl std::fmt::Debug for WorkflowNode {
263 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
264 f.debug_struct("WorkflowNode")
265 .field("id", &self.id)
266 .field("name", &self.name)
267 .field("node_type", &self.node_type)
268 .finish()
269 }
270}
271
272#[derive(Debug, Clone, Serialize, Deserialize)]
274pub enum NodeType {
275 Task,
277
278 HumanInput,
280
281 Decision,
283
284 Join,
286
287 Start,
289
290 End,
292
293 Parallel,
295}
296
297#[async_trait::async_trait]
299pub trait NodeExecutor: Send + Sync {
300 async fn execute(
302 &self,
303 node: &WorkflowNode,
304 context: &mut WorkflowContext,
305 ) -> Result<NextAction, CrewError>;
306}
307
308#[derive(Debug, Clone)]
310pub struct WorkflowEdge {
311 pub from: String,
313
314 pub to: String,
316
317 pub condition: Option<Condition>,
319
320 pub priority: i32,
322}
323
324#[derive(Clone)]
326pub struct Workflow {
327 pub id: String,
329
330 pub name: String,
332
333 pub nodes: HashMap<String, WorkflowNode>,
335
336 pub edges: Vec<WorkflowEdge>,
338
339 pub entry_node: String,
341
342 pub metadata: HashMap<String, serde_json::Value>,
344}
345
346impl std::fmt::Debug for Workflow {
347 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
348 f.debug_struct("Workflow")
349 .field("id", &self.id)
350 .field("name", &self.name)
351 .field("nodes", &self.nodes.keys().collect::<Vec<_>>())
352 .field("entry_node", &self.entry_node)
353 .finish()
354 }
355}
356
357pub struct WorkflowBuilder {
359 id: String,
360 name: String,
361 nodes: HashMap<String, WorkflowNode>,
362 edges: Vec<WorkflowEdge>,
363 entry_node: Option<String>,
364 metadata: HashMap<String, serde_json::Value>,
365}
366
367impl WorkflowBuilder {
368 pub fn new(id: impl Into<String>) -> Self {
370 let id = id.into();
371 Self {
372 name: id.clone(),
373 id,
374 nodes: HashMap::new(),
375 edges: Vec::new(),
376 entry_node: None,
377 metadata: HashMap::new(),
378 }
379 }
380
381 pub fn name(mut self, name: impl Into<String>) -> Self {
383 self.name = name.into();
384 self
385 }
386
387 pub fn add_task_node(mut self, id: impl Into<String>, task: Task) -> Self {
389 let id = id.into();
390 self.nodes.insert(
391 id.clone(),
392 WorkflowNode {
393 id: id.clone(),
394 name: task.description.clone(),
395 node_type: NodeType::Task,
396 task: Some(task),
397 executor: None,
398 metadata: HashMap::new(),
399 },
400 );
401 self
402 }
403
404 pub fn add_input_node(
406 mut self,
407 id: impl Into<String>,
408 prompt: impl Into<String>,
409 input_type: InputType,
410 ) -> Self {
411 let id = id.into();
412 let prompt = prompt.into();
413 self.nodes.insert(
414 id.clone(),
415 WorkflowNode {
416 id: id.clone(),
417 name: prompt.clone(),
418 node_type: NodeType::HumanInput,
419 task: None,
420 executor: None,
421 metadata: {
422 let mut m = HashMap::new();
423 m.insert("prompt".to_string(), serde_json::json!(prompt));
424 m.insert(
425 "input_type".to_string(),
426 serde_json::to_value(&input_type).unwrap(),
427 );
428 m
429 },
430 },
431 );
432 self
433 }
434
435 pub fn add_decision_node(mut self, id: impl Into<String>, name: impl Into<String>) -> Self {
437 let id = id.into();
438 self.nodes.insert(
439 id.clone(),
440 WorkflowNode {
441 id: id.clone(),
442 name: name.into(),
443 node_type: NodeType::Decision,
444 task: None,
445 executor: None,
446 metadata: HashMap::new(),
447 },
448 );
449 self
450 }
451
452 pub fn add_custom_node(
454 mut self,
455 id: impl Into<String>,
456 name: impl Into<String>,
457 executor: Arc<dyn NodeExecutor>,
458 ) -> Self {
459 let id = id.into();
460 self.nodes.insert(
461 id.clone(),
462 WorkflowNode {
463 id: id.clone(),
464 name: name.into(),
465 node_type: NodeType::Task,
466 task: None,
467 executor: Some(executor),
468 metadata: HashMap::new(),
469 },
470 );
471 self
472 }
473
474 pub fn connect(
476 mut self,
477 from: impl Into<String>,
478 to: impl Into<String>,
479 condition: Option<Condition>,
480 ) -> Self {
481 self.edges.push(WorkflowEdge {
482 from: from.into(),
483 to: to.into(),
484 condition,
485 priority: 0,
486 });
487 self
488 }
489
490 pub fn connect_with_priority(
492 mut self,
493 from: impl Into<String>,
494 to: impl Into<String>,
495 condition: Option<Condition>,
496 priority: i32,
497 ) -> Self {
498 self.edges.push(WorkflowEdge {
499 from: from.into(),
500 to: to.into(),
501 condition,
502 priority,
503 });
504 self
505 }
506
507 pub fn set_entry(mut self, node_id: impl Into<String>) -> Self {
509 self.entry_node = Some(node_id.into());
510 self
511 }
512
513 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
515 self.metadata.insert(key.into(), value);
516 self
517 }
518
519 pub fn build(self) -> Result<Workflow, CrewError> {
521 let entry_node = self.entry_node.ok_or_else(|| {
522 CrewError::InvalidConfiguration("No entry node specified".to_string())
523 })?;
524
525 if !self.nodes.contains_key(&entry_node) {
526 return Err(CrewError::InvalidConfiguration(format!(
527 "Entry node '{}' not found",
528 entry_node
529 )));
530 }
531
532 for edge in &self.edges {
534 if !self.nodes.contains_key(&edge.from) {
535 return Err(CrewError::InvalidConfiguration(format!(
536 "Edge source '{}' not found",
537 edge.from
538 )));
539 }
540 if !self.nodes.contains_key(&edge.to) {
541 return Err(CrewError::InvalidConfiguration(format!(
542 "Edge target '{}' not found",
543 edge.to
544 )));
545 }
546 }
547
548 Ok(Workflow {
549 id: self.id,
550 name: self.name,
551 nodes: self.nodes,
552 edges: self.edges,
553 entry_node,
554 metadata: self.metadata,
555 })
556 }
557}
558
559#[derive(Debug, Clone, Default)]
561pub struct WorkflowContext {
562 pub data: serde_json::Value,
564
565 pub trace: Vec<NodeExecution>,
567
568 pub current_node: Option<String>,
570
571 pub pending_input: Option<InputRequest>,
573
574 pub provided_input: Option<serde_json::Value>,
576}
577
578impl WorkflowContext {
579 pub fn new() -> Self {
581 Self {
582 data: serde_json::json!({}),
583 trace: Vec::new(),
584 current_node: None,
585 pending_input: None,
586 provided_input: None,
587 }
588 }
589
590 pub fn with_data(data: serde_json::Value) -> Self {
592 Self {
593 data,
594 ..Self::new()
595 }
596 }
597
598 pub fn set(&mut self, key: &str, value: serde_json::Value) {
600 if let Some(obj) = self.data.as_object_mut() {
601 obj.insert(key.to_string(), value);
602 }
603 }
604
605 pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
607 self.data.get(key)
608 }
609}
610
611pub struct WorkflowRunner {
613 workflow: Workflow,
614 context: Arc<RwLock<WorkflowContext>>,
615 started_at: Option<std::time::Instant>,
616}
617
618impl WorkflowRunner {
619 pub fn new(workflow: Workflow) -> Self {
621 Self {
622 workflow,
623 context: Arc::new(RwLock::new(WorkflowContext::new())),
624 started_at: None,
625 }
626 }
627
628 pub fn with_context(workflow: Workflow, context: WorkflowContext) -> Self {
630 Self {
631 workflow,
632 context: Arc::new(RwLock::new(context)),
633 started_at: None,
634 }
635 }
636
637 pub async fn step(&mut self) -> Result<NextAction, CrewError> {
639 if self.started_at.is_none() {
640 self.started_at = Some(std::time::Instant::now());
641 }
642
643 let mut ctx = self.context.write().await;
644
645 if ctx.pending_input.is_some() && ctx.provided_input.is_none() {
647 return Ok(NextAction::WaitForInput(ctx.pending_input.clone().unwrap()));
648 }
649
650 let current_id = ctx
652 .current_node
653 .clone()
654 .unwrap_or_else(|| self.workflow.entry_node.clone());
655
656 let node = self
657 .workflow
658 .nodes
659 .get(¤t_id)
660 .ok_or_else(|| CrewError::TaskNotFound(current_id.clone()))?
661 .clone();
662
663 let start_time = chrono::Utc::now();
664
665 let (result, output) = match &node.node_type {
667 NodeType::Start => (NodeResult::Success, serde_json::json!({})),
668
669 NodeType::End => {
670 let duration = self
671 .started_at
672 .map(|s| s.elapsed().as_millis() as u64)
673 .unwrap_or(0);
674 return Ok(NextAction::Complete(WorkflowResult {
675 output: ctx.data.clone(),
676 trace: ctx.trace.clone(),
677 duration_ms: duration,
678 status: WorkflowStatus::Success,
679 }));
680 }
681
682 NodeType::HumanInput => {
683 if let Some(input) = ctx.provided_input.take() {
685 ctx.set(¤t_id, input.clone());
687 ctx.pending_input = None;
688 (NodeResult::Success, input)
689 } else {
690 let prompt = node
692 .metadata
693 .get("prompt")
694 .and_then(|v| v.as_str())
695 .unwrap_or("Please provide input")
696 .to_string();
697
698 let input_type: InputType = node
699 .metadata
700 .get("input_type")
701 .and_then(|v| serde_json::from_value(v.clone()).ok())
702 .unwrap_or(InputType::Text);
703
704 let request = InputRequest {
705 id: format!("{}_{}", current_id, chrono::Utc::now().timestamp()),
706 prompt,
707 input_type,
708 waiting_node: current_id.clone(),
709 default: None,
710 timeout_secs: None,
711 };
712
713 ctx.pending_input = Some(request.clone());
714 return Ok(NextAction::WaitForInput(request));
715 }
716 }
717
718 NodeType::Decision => {
719 (NodeResult::Success, serde_json::json!({}))
721 }
722
723 NodeType::Task => {
724 if let Some(executor) = &node.executor {
725 match executor.execute(&node, &mut ctx).await {
726 Ok(action) => return Ok(action),
727 Err(e) => (NodeResult::Failed(e.to_string()), serde_json::json!({})),
728 }
729 } else {
730 (NodeResult::Success, serde_json::json!({"task": node.name}))
732 }
733 }
734
735 NodeType::Join => {
736 (NodeResult::Success, serde_json::json!({}))
738 }
739
740 NodeType::Parallel => {
741 (NodeResult::Success, serde_json::json!({}))
743 }
744 };
745
746 ctx.trace.push(NodeExecution {
748 node_id: current_id.clone(),
749 started_at: start_time,
750 ended_at: chrono::Utc::now(),
751 result,
752 output,
753 });
754
755 let next_node = self.find_next_node(¤t_id, &ctx.data)?;
757
758 if let Some(next_id) = next_node {
759 ctx.current_node = Some(next_id);
760 Ok(NextAction::Continue)
761 } else {
762 let duration = self
764 .started_at
765 .map(|s| s.elapsed().as_millis() as u64)
766 .unwrap_or(0);
767 Ok(NextAction::Complete(WorkflowResult {
768 output: ctx.data.clone(),
769 trace: ctx.trace.clone(),
770 duration_ms: duration,
771 status: WorkflowStatus::Success,
772 }))
773 }
774 }
775
776 pub async fn provide_input(&mut self, input: serde_json::Value) {
778 let mut ctx = self.context.write().await;
779 ctx.provided_input = Some(input);
780 }
781
782 pub async fn run(&mut self) -> Result<WorkflowResult, CrewError> {
784 loop {
785 match self.step().await? {
786 NextAction::Continue | NextAction::ContinueAndExecute => continue,
787 NextAction::Complete(result) => return Ok(result),
788 NextAction::Failed(err) => return Err(CrewError::ExecutionFailed(err)),
789 NextAction::WaitForInput(_) => {
790 return Err(CrewError::ExecutionFailed(
791 "Workflow requires input but running in non-interactive mode".to_string(),
792 ))
793 }
794 NextAction::Branch(_) => {
795 continue;
797 }
798 }
799 }
800 }
801
802 fn find_next_node(
804 &self,
805 current: &str,
806 context: &serde_json::Value,
807 ) -> Result<Option<String>, CrewError> {
808 let mut edges: Vec<_> = self
810 .workflow
811 .edges
812 .iter()
813 .filter(|e| e.from == current)
814 .collect();
815
816 edges.sort_by(|a, b| b.priority.cmp(&a.priority));
817
818 for edge in edges {
820 match &edge.condition {
821 None => return Ok(Some(edge.to.clone())),
822 Some(cond) if cond.evaluate(context) => return Ok(Some(edge.to.clone())),
823 _ => continue,
824 }
825 }
826
827 Ok(None)
828 }
829
830 pub async fn context(&self) -> WorkflowContext {
832 self.context.read().await.clone()
833 }
834}
835
#[cfg(test)]
mod tests {
    use super::*;

    /// A two-task chain builds with both nodes, the single edge and the chosen entry.
    #[test]
    fn test_workflow_builder() {
        let workflow = WorkflowBuilder::new("test")
            .add_task_node("step1", Task::new("First step"))
            .add_task_node("step2", Task::new("Second step"))
            .connect("step1", "step2", None)
            .set_entry("step1")
            .build()
            .expect("a valid two-node chain must build");

        assert_eq!(
            (
                workflow.nodes.len(),
                workflow.edges.len(),
                workflow.entry_node.as_str()
            ),
            (2, 1, "step1")
        );
    }

    /// Equality and greater-than conditions against a flat JSON object.
    #[test]
    fn test_condition_evaluation() {
        let data = serde_json::json!({
            "approved": true,
            "amount": 100,
            "name": "test"
        });

        let expectations = [
            (Condition::when("approved", true), true),
            (Condition::when("approved", false), false),
            (Condition::when_gt("amount", 50), true),
            (Condition::when_gt("amount", 150), false),
        ];

        for (condition, expected) in expectations.iter() {
            assert_eq!(
                condition.evaluate(&data),
                *expected,
                "condition: {:?}",
                condition
            );
        }
    }

    /// Two chained task nodes both run and the workflow succeeds.
    #[tokio::test]
    async fn test_simple_workflow_execution() {
        let workflow = WorkflowBuilder::new("simple")
            .add_task_node("start", Task::new("Start task"))
            .add_task_node("end", Task::new("End task"))
            .connect("start", "end", None)
            .set_entry("start")
            .build()
            .unwrap();

        let mut runner = WorkflowRunner::new(workflow);
        let outcome = runner.run().await.unwrap();

        assert_eq!(outcome.status, WorkflowStatus::Success);
        assert_eq!(outcome.trace.len(), 2);
    }

    /// Only the branch whose condition holds is taken.
    #[tokio::test]
    async fn test_conditional_workflow() {
        let workflow = WorkflowBuilder::new("conditional")
            .add_task_node("check", Task::new("Check condition"))
            .add_task_node("yes_path", Task::new("Yes path"))
            .add_task_node("no_path", Task::new("No path"))
            .connect("check", "yes_path", Some(Condition::when("approved", true)))
            .connect("check", "no_path", Some(Condition::when("approved", false)))
            .set_entry("check")
            .build()
            .unwrap();

        let seeded = WorkflowContext::with_data(serde_json::json!({"approved": true}));
        let mut runner = WorkflowRunner::with_context(workflow, seeded);
        let outcome = runner.run().await.unwrap();

        let visited = |id: &str| outcome.trace.iter().any(|entry| entry.node_id == id);
        assert!(visited("yes_path"));
        assert!(!visited("no_path"));
    }

    /// `build` rejects a missing entry, an unknown entry and a dangling edge.
    #[test]
    fn test_workflow_validation() {
        let missing_entry = WorkflowBuilder::new("test")
            .add_task_node("step1", Task::new("Step"))
            .build();
        assert!(missing_entry.is_err(), "missing entry node must be rejected");

        let unknown_entry = WorkflowBuilder::new("test")
            .add_task_node("step1", Task::new("Step"))
            .set_entry("nonexistent")
            .build();
        assert!(unknown_entry.is_err(), "unknown entry node must be rejected");

        let dangling_edge = WorkflowBuilder::new("test")
            .add_task_node("step1", Task::new("Step"))
            .connect("step1", "nonexistent", None)
            .set_entry("step1")
            .build();
        assert!(dangling_edge.is_err(), "edge to unknown node must be rejected");
    }
}