1use super::state::{NodeResult, WorkflowContext, WorkflowValue};
6use crate::llm::LLMAgent;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12use tracing::{debug, info, warn};
13
14pub type NodeExecutorFn = Arc<
16 dyn Fn(
17 WorkflowContext,
18 WorkflowValue,
19 ) -> Pin<Box<dyn Future<Output = Result<WorkflowValue, String>> + Send>>
20 + Send
21 + Sync,
22>;
23
24pub type ConditionFn = Arc<
26 dyn Fn(WorkflowContext, WorkflowValue) -> Pin<Box<dyn Future<Output = bool> + Send>>
27 + Send
28 + Sync,
29>;
30
31pub type TransformFn = Arc<
33 dyn Fn(HashMap<String, WorkflowValue>) -> Pin<Box<dyn Future<Output = WorkflowValue> + Send>>
34 + Send
35 + Sync,
36>;
37
38#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
40pub enum NodeType {
41 Start,
43 End,
45 Task,
47 Agent,
49 Condition,
51 Parallel,
53 Join,
55 Loop,
57 SubWorkflow,
59 Wait,
61 Transform,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct RetryPolicy {
68 pub max_retries: u32,
70 pub retry_delay_ms: u64,
72 pub exponential_backoff: bool,
74 pub max_delay_ms: u64,
76}
77
78impl Default for RetryPolicy {
79 fn default() -> Self {
80 Self {
81 max_retries: 3,
82 retry_delay_ms: 1000,
83 exponential_backoff: true,
84 max_delay_ms: 30000,
85 }
86 }
87}
88
89impl RetryPolicy {
90 pub fn no_retry() -> Self {
91 Self {
92 max_retries: 0,
93 ..Default::default()
94 }
95 }
96
97 pub fn with_retries(max_retries: u32) -> Self {
98 Self {
99 max_retries,
100 ..Default::default()
101 }
102 }
103
104 pub fn get_delay(&self, retry_count: u32) -> u64 {
106 if self.exponential_backoff {
107 let delay = self.retry_delay_ms * 2u64.pow(retry_count);
108 delay.min(self.max_delay_ms)
109 } else {
110 self.retry_delay_ms
111 }
112 }
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct TimeoutConfig {
118 pub execution_timeout_ms: u64,
120 pub cancel_on_timeout: bool,
122}
123
124impl Default for TimeoutConfig {
125 fn default() -> Self {
126 Self {
127 execution_timeout_ms: 60000, cancel_on_timeout: true,
129 }
130 }
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct NodeConfig {
136 pub id: String,
138 pub name: String,
140 pub node_type: NodeType,
142 pub description: String,
144 pub retry_policy: RetryPolicy,
146 pub timeout: TimeoutConfig,
148 pub metadata: HashMap<String, String>,
150}
151
152impl NodeConfig {
153 pub fn new(id: &str, name: &str, node_type: NodeType) -> Self {
154 Self {
155 id: id.to_string(),
156 name: name.to_string(),
157 node_type,
158 description: String::new(),
159 retry_policy: RetryPolicy::default(),
160 timeout: TimeoutConfig::default(),
161 metadata: HashMap::new(),
162 }
163 }
164
165 pub fn with_description(mut self, desc: &str) -> Self {
166 self.description = desc.to_string();
167 self
168 }
169
170 pub fn with_retry_policy(mut self, policy: RetryPolicy) -> Self {
171 self.retry_policy = policy;
172 self
173 }
174
175 pub fn with_timeout(mut self, timeout: TimeoutConfig) -> Self {
176 self.timeout = timeout;
177 self
178 }
179
180 pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
181 self.metadata.insert(key.to_string(), value.to_string());
182 self
183 }
184}
185
186pub struct WorkflowNode {
188 pub config: NodeConfig,
190 executor: Option<NodeExecutorFn>,
192 condition: Option<ConditionFn>,
194 transform: Option<TransformFn>,
196 loop_condition: Option<ConditionFn>,
198 max_iterations: Option<u32>,
200 parallel_branches: Vec<String>,
202 join_nodes: Vec<String>,
204 sub_workflow_id: Option<String>,
206 wait_event_type: Option<String>,
208 condition_branches: HashMap<String, String>,
210}
211
212impl WorkflowNode {
213 pub fn start(id: &str) -> Self {
215 Self {
216 config: NodeConfig::new(id, "Start", NodeType::Start),
217 executor: None,
218 condition: None,
219 transform: None,
220 loop_condition: None,
221 max_iterations: None,
222 parallel_branches: Vec::new(),
223 join_nodes: Vec::new(),
224 sub_workflow_id: None,
225 wait_event_type: None,
226 condition_branches: HashMap::new(),
227 }
228 }
229
230 pub fn end(id: &str) -> Self {
232 Self {
233 config: NodeConfig::new(id, "End", NodeType::End),
234 executor: None,
235 condition: None,
236 transform: None,
237 loop_condition: None,
238 max_iterations: None,
239 parallel_branches: Vec::new(),
240 join_nodes: Vec::new(),
241 sub_workflow_id: None,
242 wait_event_type: None,
243 condition_branches: HashMap::new(),
244 }
245 }
246
247 pub fn task<F, Fut>(id: &str, name: &str, executor: F) -> Self
249 where
250 F: Fn(WorkflowContext, WorkflowValue) -> Fut + Send + Sync + 'static,
251 Fut: Future<Output = Result<WorkflowValue, String>> + Send + 'static,
252 {
253 Self {
254 config: NodeConfig::new(id, name, NodeType::Task),
255 executor: Some(Arc::new(move |ctx, input| Box::pin(executor(ctx, input)))),
256 condition: None,
257 transform: None,
258 loop_condition: None,
259 max_iterations: None,
260 parallel_branches: Vec::new(),
261 join_nodes: Vec::new(),
262 sub_workflow_id: None,
263 wait_event_type: None,
264 condition_branches: HashMap::new(),
265 }
266 }
267
268 pub fn agent<F, Fut>(id: &str, name: &str, agent_executor: F) -> Self
270 where
271 F: Fn(WorkflowContext, WorkflowValue) -> Fut + Send + Sync + 'static,
272 Fut: Future<Output = Result<WorkflowValue, String>> + Send + 'static,
273 {
274 Self {
275 config: NodeConfig::new(id, name, NodeType::Agent),
276 executor: Some(Arc::new(move |ctx, input| {
277 Box::pin(agent_executor(ctx, input))
278 })),
279 condition: None,
280 transform: None,
281 loop_condition: None,
282 max_iterations: None,
283 parallel_branches: Vec::new(),
284 join_nodes: Vec::new(),
285 sub_workflow_id: None,
286 wait_event_type: None,
287 condition_branches: HashMap::new(),
288 }
289 }
290
291 pub fn condition<F, Fut>(id: &str, name: &str, condition_fn: F) -> Self
293 where
294 F: Fn(WorkflowContext, WorkflowValue) -> Fut + Send + Sync + 'static,
295 Fut: Future<Output = bool> + Send + 'static,
296 {
297 Self {
298 config: NodeConfig::new(id, name, NodeType::Condition),
299 executor: None,
300 condition: Some(Arc::new(move |ctx, input| {
301 Box::pin(condition_fn(ctx, input))
302 })),
303 transform: None,
304 loop_condition: None,
305 max_iterations: None,
306 parallel_branches: Vec::new(),
307 join_nodes: Vec::new(),
308 sub_workflow_id: None,
309 wait_event_type: None,
310 condition_branches: HashMap::new(),
311 }
312 }
313
314 pub fn parallel(id: &str, name: &str, branches: Vec<&str>) -> Self {
316 Self {
317 config: NodeConfig::new(id, name, NodeType::Parallel),
318 executor: None,
319 condition: None,
320 transform: None,
321 loop_condition: None,
322 max_iterations: None,
323 parallel_branches: branches.into_iter().map(|s| s.to_string()).collect(),
324 join_nodes: Vec::new(),
325 sub_workflow_id: None,
326 wait_event_type: None,
327 condition_branches: HashMap::new(),
328 }
329 }
330
331 pub fn join(id: &str, name: &str, wait_for: Vec<&str>) -> Self {
333 Self {
334 config: NodeConfig::new(id, name, NodeType::Join),
335 executor: None,
336 condition: None,
337 transform: None,
338 loop_condition: None,
339 max_iterations: None,
340 parallel_branches: Vec::new(),
341 join_nodes: wait_for.into_iter().map(|s| s.to_string()).collect(),
342 sub_workflow_id: None,
343 wait_event_type: None,
344 condition_branches: HashMap::new(),
345 }
346 }
347
348 pub fn join_with_transform<F, Fut>(
350 id: &str,
351 name: &str,
352 wait_for: Vec<&str>,
353 transform: F,
354 ) -> Self
355 where
356 F: Fn(HashMap<String, WorkflowValue>) -> Fut + Send + Sync + 'static,
357 Fut: Future<Output = WorkflowValue> + Send + 'static,
358 {
359 Self {
360 config: NodeConfig::new(id, name, NodeType::Join),
361 executor: None,
362 condition: None,
363 transform: Some(Arc::new(move |inputs| Box::pin(transform(inputs)))),
364 loop_condition: None,
365 max_iterations: None,
366 parallel_branches: Vec::new(),
367 join_nodes: wait_for.into_iter().map(|s| s.to_string()).collect(),
368 sub_workflow_id: None,
369 wait_event_type: None,
370 condition_branches: HashMap::new(),
371 }
372 }
373
374 pub fn loop_node<F, Fut, C, CFut>(
376 id: &str,
377 name: &str,
378 body: F,
379 condition: C,
380 max_iterations: u32,
381 ) -> Self
382 where
383 F: Fn(WorkflowContext, WorkflowValue) -> Fut + Send + Sync + 'static,
384 Fut: Future<Output = Result<WorkflowValue, String>> + Send + 'static,
385 C: Fn(WorkflowContext, WorkflowValue) -> CFut + Send + Sync + 'static,
386 CFut: Future<Output = bool> + Send + 'static,
387 {
388 Self {
389 config: NodeConfig::new(id, name, NodeType::Loop),
390 executor: Some(Arc::new(move |ctx, input| Box::pin(body(ctx, input)))),
391 condition: None,
392 transform: None,
393 loop_condition: Some(Arc::new(move |ctx, input| Box::pin(condition(ctx, input)))),
394 max_iterations: Some(max_iterations),
395 parallel_branches: Vec::new(),
396 join_nodes: Vec::new(),
397 sub_workflow_id: None,
398 wait_event_type: None,
399 condition_branches: HashMap::new(),
400 }
401 }
402
403 pub fn sub_workflow(id: &str, name: &str, sub_workflow_id: &str) -> Self {
405 Self {
406 config: NodeConfig::new(id, name, NodeType::SubWorkflow),
407 executor: None,
408 condition: None,
409 transform: None,
410 loop_condition: None,
411 max_iterations: None,
412 parallel_branches: Vec::new(),
413 join_nodes: Vec::new(),
414 sub_workflow_id: Some(sub_workflow_id.to_string()),
415 wait_event_type: None,
416 condition_branches: HashMap::new(),
417 }
418 }
419
420 pub fn wait(id: &str, name: &str, event_type: &str) -> Self {
422 Self {
423 config: NodeConfig::new(id, name, NodeType::Wait),
424 executor: None,
425 condition: None,
426 transform: None,
427 loop_condition: None,
428 max_iterations: None,
429 parallel_branches: Vec::new(),
430 join_nodes: Vec::new(),
431 sub_workflow_id: None,
432 wait_event_type: Some(event_type.to_string()),
433 condition_branches: HashMap::new(),
434 }
435 }
436
437 pub fn transform<F, Fut>(id: &str, name: &str, transform_fn: F) -> Self
439 where
440 F: Fn(HashMap<String, WorkflowValue>) -> Fut + Send + Sync + 'static,
441 Fut: Future<Output = WorkflowValue> + Send + 'static,
442 {
443 Self {
444 config: NodeConfig::new(id, name, NodeType::Transform),
445 executor: None,
446 condition: None,
447 transform: Some(Arc::new(move |inputs| Box::pin(transform_fn(inputs)))),
448 loop_condition: None,
449 max_iterations: None,
450 parallel_branches: Vec::new(),
451 join_nodes: Vec::new(),
452 sub_workflow_id: None,
453 wait_event_type: None,
454 condition_branches: HashMap::new(),
455 }
456 }
457
458 pub fn llm_agent(id: &str, name: &str, agent: Arc<LLMAgent>) -> Self {
463 let agent_clone = Arc::clone(&agent);
464 Self {
465 config: NodeConfig::new(id, name, NodeType::Agent),
466 executor: Some(Arc::new(move |_ctx, input| {
467 let agent = Arc::clone(&agent_clone);
468 Box::pin(async move {
469 let prompt_str = match input.as_str() {
470 Some(s) => s.to_string(),
471 None => serde_json::to_string(&input).unwrap_or_default(),
472 };
473 agent
474 .chat(&prompt_str)
475 .await
476 .map(WorkflowValue::String)
477 .map_err(|e| e.to_string())
478 })
479 })),
480 condition: None,
481 transform: None,
482 loop_condition: None,
483 max_iterations: None,
484 parallel_branches: Vec::new(),
485 join_nodes: Vec::new(),
486 sub_workflow_id: None,
487 wait_event_type: None,
488 condition_branches: HashMap::new(),
489 }
490 }
491
492 pub fn llm_agent_with_template(
499 id: &str,
500 name: &str,
501 agent: Arc<LLMAgent>,
502 prompt_template: String,
503 ) -> Self {
504 let agent_clone = Arc::clone(&agent);
505 Self {
506 config: NodeConfig::new(id, name, NodeType::Agent),
507 executor: Some(Arc::new(move |_ctx, input| {
508 let agent = Arc::clone(&agent_clone);
509 let template = prompt_template.clone();
510 Box::pin(async move {
511 let prompt = if template.contains("{{") {
513 let input_str = match input.as_str() {
514 Some(s) => s.to_string(),
515 None => serde_json::to_string(&input).unwrap_or_default(),
516 };
517 template.replace("{{ input }}", &input_str)
518 } else {
519 template
520 };
521 agent
522 .chat(&prompt)
523 .await
524 .map(WorkflowValue::String)
525 .map_err(|e| e.to_string())
526 })
527 })),
528 condition: None,
529 transform: None,
530 loop_condition: None,
531 max_iterations: None,
532 parallel_branches: Vec::new(),
533 join_nodes: Vec::new(),
534 sub_workflow_id: None,
535 wait_event_type: None,
536 condition_branches: HashMap::new(),
537 }
538 }
539
540 pub fn with_description(mut self, desc: &str) -> Self {
542 self.config.description = desc.to_string();
543 self
544 }
545
546 pub fn with_retry(mut self, policy: RetryPolicy) -> Self {
548 self.config.retry_policy = policy;
549 self
550 }
551
552 pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
554 self.config.timeout.execution_timeout_ms = timeout_ms;
555 self
556 }
557
558 pub fn with_branch(mut self, condition_name: &str, target_node_id: &str) -> Self {
560 self.condition_branches
561 .insert(condition_name.to_string(), target_node_id.to_string());
562 self
563 }
564
565 pub fn id(&self) -> &str {
567 &self.config.id
568 }
569
570 pub fn node_type(&self) -> &NodeType {
572 &self.config.node_type
573 }
574
575 pub fn parallel_branches(&self) -> &[String] {
577 &self.parallel_branches
578 }
579
580 pub fn join_nodes(&self) -> &[String] {
582 &self.join_nodes
583 }
584
585 pub fn condition_branches(&self) -> &HashMap<String, String> {
587 &self.condition_branches
588 }
589
590 pub fn sub_workflow_id(&self) -> Option<&str> {
592 self.sub_workflow_id.as_deref()
593 }
594
595 pub fn wait_event_type(&self) -> Option<&str> {
597 self.wait_event_type.as_deref()
598 }
599
600 pub async fn execute(&self, ctx: &WorkflowContext, input: WorkflowValue) -> NodeResult {
602 let start_time = std::time::Instant::now();
603 let node_id = &self.config.id;
604
605 info!("Executing node: {} ({})", node_id, self.config.name);
606
607 match self.config.node_type {
608 NodeType::Start => {
609 NodeResult::success(node_id, input, start_time.elapsed().as_millis() as u64)
611 }
612 NodeType::End => {
613 NodeResult::success(node_id, input, start_time.elapsed().as_millis() as u64)
615 }
616 NodeType::Task | NodeType::Agent => {
617 self.execute_with_retry(ctx, input, start_time).await
618 }
619 NodeType::Condition => {
620 if let Some(ref condition_fn) = self.condition {
622 let result = condition_fn(ctx.clone(), input.clone()).await;
623 let branch = if result { "true" } else { "false" };
624 debug!("Condition {} evaluated to: {}", node_id, branch);
625 NodeResult::success(
626 node_id,
627 WorkflowValue::String(branch.to_string()),
628 start_time.elapsed().as_millis() as u64,
629 )
630 } else {
631 NodeResult::failed(node_id, "No condition function", 0)
632 }
633 }
634 NodeType::Parallel => {
635 NodeResult::success(
637 node_id,
638 WorkflowValue::List(
639 self.parallel_branches
640 .iter()
641 .map(|b| WorkflowValue::String(b.clone()))
642 .collect(),
643 ),
644 start_time.elapsed().as_millis() as u64,
645 )
646 }
647 NodeType::Join => {
648 let outputs = ctx
650 .get_node_outputs(
651 &self
652 .join_nodes
653 .iter()
654 .map(|s| s.as_str())
655 .collect::<Vec<_>>(),
656 )
657 .await;
658
659 let result = if let Some(ref transform_fn) = self.transform {
660 transform_fn(outputs).await
661 } else {
662 WorkflowValue::Map(outputs)
664 };
665
666 NodeResult::success(node_id, result, start_time.elapsed().as_millis() as u64)
667 }
668 NodeType::Loop => self.execute_loop(ctx, input, start_time).await,
669 NodeType::Transform => {
670 if let Some(ref transform_fn) = self.transform {
671 let mut inputs = HashMap::new();
672 inputs.insert("input".to_string(), input);
673 let result = transform_fn(inputs).await;
674 NodeResult::success(node_id, result, start_time.elapsed().as_millis() as u64)
675 } else {
676 NodeResult::failed(node_id, "No transform function", 0)
677 }
678 }
679 NodeType::SubWorkflow => {
680 NodeResult::success(node_id, input, start_time.elapsed().as_millis() as u64)
682 }
683 NodeType::Wait => {
684 NodeResult::success(node_id, input, start_time.elapsed().as_millis() as u64)
686 }
687 }
688 }
689
690 async fn execute_with_retry(
692 &self,
693 ctx: &WorkflowContext,
694 input: WorkflowValue,
695 start_time: std::time::Instant,
696 ) -> NodeResult {
697 let node_id = &self.config.id;
698 let policy = &self.config.retry_policy;
699
700 let executor = match &self.executor {
701 Some(e) => e,
702 None => return NodeResult::failed(node_id, "No executor function", 0),
703 };
704
705 let mut retry_count = 0;
706 loop {
707 match executor(ctx.clone(), input.clone()).await {
708 Ok(output) => {
709 let mut result = NodeResult::success(
710 node_id,
711 output,
712 start_time.elapsed().as_millis() as u64,
713 );
714 result.retry_count = retry_count;
715 return result;
716 }
717 Err(e) => {
718 if retry_count < policy.max_retries {
719 let delay = policy.get_delay(retry_count);
720 warn!(
721 "Node {} failed (attempt {}), retrying in {}ms: {}",
722 node_id,
723 retry_count + 1,
724 delay,
725 e
726 );
727 tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
728 retry_count += 1;
729 } else {
730 let mut result = NodeResult::failed(
731 node_id,
732 &e,
733 start_time.elapsed().as_millis() as u64,
734 );
735 result.retry_count = retry_count;
736 return result;
737 }
738 }
739 }
740 }
741 }
742
743 async fn execute_loop(
745 &self,
746 ctx: &WorkflowContext,
747 mut input: WorkflowValue,
748 start_time: std::time::Instant,
749 ) -> NodeResult {
750 let node_id = &self.config.id;
751 let max_iter = self.max_iterations.unwrap_or(1000);
752
753 let executor = match &self.executor {
754 Some(e) => e,
755 None => return NodeResult::failed(node_id, "No executor function", 0),
756 };
757
758 let condition = match &self.loop_condition {
759 Some(c) => c,
760 None => return NodeResult::failed(node_id, "No loop condition", 0),
761 };
762
763 let mut iteration = 0;
764 while iteration < max_iter {
765 if !condition(ctx.clone(), input.clone()).await {
767 debug!(
768 "Loop {} condition false, exiting after {} iterations",
769 node_id, iteration
770 );
771 break;
772 }
773
774 match executor(ctx.clone(), input.clone()).await {
776 Ok(output) => {
777 input = output;
778 ctx.set_variable(
779 &format!("{}_iteration", node_id),
780 WorkflowValue::Int(iteration as i64),
781 )
782 .await;
783 }
784 Err(e) => {
785 return NodeResult::failed(
786 node_id,
787 &format!("Loop failed at iteration {}: {}", iteration, e),
788 start_time.elapsed().as_millis() as u64,
789 );
790 }
791 }
792
793 iteration += 1;
794 }
795
796 if iteration >= max_iter {
797 warn!("Loop {} reached max iterations: {}", node_id, max_iter);
798 }
799
800 NodeResult::success(node_id, input, start_time.elapsed().as_millis() as u64)
801 }
802}
803
#[cfg(test)]
mod tests {
    use super::*;

    /// A task node runs its executor and reports the executor's output.
    #[tokio::test]
    async fn test_task_node() {
        let ctx = WorkflowContext::new("test");
        let doubler = WorkflowNode::task("task1", "Test Task", |_ctx, input| async move {
            let value = input.as_i64().unwrap_or(0);
            Ok(WorkflowValue::Int(value * 2))
        });

        let outcome = doubler.execute(&ctx, WorkflowValue::Int(21)).await;

        assert!(outcome.status.is_success());
        assert_eq!(outcome.output.as_i64(), Some(42));
    }

    /// A condition node reports the predicate's result as a string.
    #[tokio::test]
    async fn test_condition_node() {
        let ctx = WorkflowContext::new("test");
        let above_ten = WorkflowNode::condition("cond1", "Check Value", |_ctx, input| async move {
            input.as_i64().unwrap_or(0) > 10
        });

        let high = above_ten.execute(&ctx, WorkflowValue::Int(20)).await;
        assert_eq!(high.output.as_str(), Some("true"));

        let low = above_ten.execute(&ctx, WorkflowValue::Int(5)).await;
        assert_eq!(low.output.as_str(), Some("false"));
    }

    /// A loop node feeds each iteration's output back in until the condition fails.
    #[tokio::test]
    async fn test_loop_node() {
        let ctx = WorkflowContext::new("test");
        let counter = WorkflowNode::loop_node(
            "loop1",
            "Counter Loop",
            |_ctx, input| async move {
                let value = input.as_i64().unwrap_or(0);
                Ok(WorkflowValue::Int(value + 1))
            },
            |_ctx, input| async move { input.as_i64().unwrap_or(0) < 5 },
            100,
        );

        let outcome = counter.execute(&ctx, WorkflowValue::Int(0)).await;

        assert!(outcome.status.is_success());
        assert_eq!(outcome.output.as_i64(), Some(5));
    }

    /// The default policy doubles the delay per retry and caps it at 30 seconds.
    #[test]
    fn test_retry_policy() {
        let policy = RetryPolicy::default();
        let expectations: [(u32, u64); 4] = [(0, 1000), (1, 2000), (2, 4000), (10, 30000)];

        for (retry, expected_ms) in expectations {
            assert_eq!(policy.get_delay(retry), expected_ms);
        }
    }
}