1use serde::{Deserialize, Serialize};
2use uuid::Uuid;
3
4#[cfg(feature = "openapi")]
5use utoipa::ToSchema;
6
7pub type NodeId = Uuid;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12#[cfg_attr(feature = "openapi", derive(ToSchema))]
13pub struct Node {
14 #[cfg_attr(feature = "openapi", schema(value_type = String))]
16 pub id: NodeId,
17
18 pub name: String,
20
21 pub kind: NodeKind,
23
24 pub position: Option<(f64, f64)>,
26
27 #[serde(default)]
29 pub retry_config: Option<RetryConfig>,
30
31 #[serde(default)]
33 pub timeout_config: Option<TimeoutConfig>,
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38#[cfg_attr(feature = "openapi", derive(ToSchema))]
39pub struct RetryConfig {
40 pub max_retries: u32,
42
43 pub initial_delay_ms: u64,
45
46 pub backoff_multiplier: f64,
48
49 pub max_delay_ms: u64,
51}
52
53impl Default for RetryConfig {
54 fn default() -> Self {
55 Self {
56 max_retries: 3,
57 initial_delay_ms: 1000,
58 backoff_multiplier: 2.0,
59 max_delay_ms: 30000,
60 }
61 }
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66#[cfg_attr(feature = "openapi", derive(ToSchema))]
67pub struct TimeoutConfig {
68 pub execution_timeout_ms: u64,
70
71 #[serde(default)]
73 pub idle_timeout_ms: Option<u64>,
74
75 #[serde(default)]
77 pub timeout_action: TimeoutAction,
78}
79
80impl Default for TimeoutConfig {
81 fn default() -> Self {
82 Self {
83 execution_timeout_ms: 60000, idle_timeout_ms: None,
85 timeout_action: TimeoutAction::Fail,
86 }
87 }
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
92#[cfg_attr(feature = "openapi", derive(ToSchema))]
93pub enum TimeoutAction {
94 #[default]
96 Fail,
97
98 Skip,
100
101 UseDefault(String),
103}
104
105impl Node {
106 pub fn new(name: String, kind: NodeKind) -> Self {
107 Self {
108 id: Uuid::new_v4(),
109 name,
110 kind,
111 position: None,
112 retry_config: None,
113 timeout_config: None,
114 }
115 }
116
117 pub fn with_retry(mut self, retry_config: RetryConfig) -> Self {
118 self.retry_config = Some(retry_config);
119 self
120 }
121
122 pub fn with_timeout(mut self, timeout_config: TimeoutConfig) -> Self {
123 self.timeout_config = Some(timeout_config);
124 self
125 }
126
127 pub fn with_position(mut self, x: f64, y: f64) -> Self {
128 self.position = Some((x, y));
129 self
130 }
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
135#[cfg_attr(feature = "openapi", derive(ToSchema))]
136#[serde(tag = "type", content = "config")]
137pub enum NodeKind {
138 Start,
140
141 End,
143
144 LLM(LlmConfig),
146
147 Retriever(VectorConfig),
149
150 Code(ScriptConfig),
152
153 IfElse(Condition),
155
156 Tool(McpConfig),
158
159 Loop(LoopConfig),
161
162 TryCatch(TryCatchConfig),
164
165 SubWorkflow(SubWorkflowConfig),
167
168 Switch(SwitchConfig),
170
171 Parallel(ParallelConfig),
173
174 Approval(ApprovalConfig),
176
177 Form(FormConfig),
179
180 Vision(VisionConfig),
182}
183
184#[derive(Debug, Clone, Serialize, Deserialize)]
186#[cfg_attr(feature = "openapi", derive(ToSchema))]
187pub struct LlmConfig {
188 pub provider: String,
190
191 pub model: String,
193
194 pub system_prompt: Option<String>,
196
197 pub prompt_template: String,
199
200 pub temperature: Option<f64>,
202
203 pub max_tokens: Option<u32>,
205
206 #[serde(default)]
208 pub tools: Vec<serde_json::Value>,
209
210 #[serde(default)]
212 pub images: Vec<serde_json::Value>,
213
214 #[serde(default)]
216 pub extra_params: serde_json::Value,
217}
218
219#[derive(Debug, Clone, Serialize, Deserialize)]
221#[cfg_attr(feature = "openapi", derive(ToSchema))]
222pub struct VectorConfig {
223 pub db_type: String,
225
226 pub collection: String,
228
229 pub query: String,
231
232 pub top_k: usize,
234
235 pub score_threshold: Option<f64>,
237}
238
239#[derive(Debug, Clone, Serialize, Deserialize)]
241#[cfg_attr(feature = "openapi", derive(ToSchema))]
242pub struct ScriptConfig {
243 pub runtime: String,
245
246 pub code: String,
248
249 #[serde(default)]
251 pub inputs: Vec<String>,
252
253 pub output: String,
255}
256
257#[derive(Debug, Clone, Serialize, Deserialize)]
259#[cfg_attr(feature = "openapi", derive(ToSchema))]
260pub struct Condition {
261 pub expression: String,
263
264 #[cfg_attr(feature = "openapi", schema(value_type = String))]
266 pub true_branch: NodeId,
267
268 #[cfg_attr(feature = "openapi", schema(value_type = String))]
270 pub false_branch: NodeId,
271}
272
273#[derive(Debug, Clone, Serialize, Deserialize)]
275#[cfg_attr(feature = "openapi", derive(ToSchema))]
276pub struct McpConfig {
277 pub server_id: String,
279
280 pub tool_name: String,
282
283 #[serde(default)]
285 pub parameters: serde_json::Value,
286}
287
288#[derive(Debug, Clone, Serialize, Deserialize)]
290#[cfg_attr(feature = "openapi", derive(ToSchema))]
291pub struct LoopConfig {
292 pub loop_type: LoopType,
294
295 #[serde(default = "default_max_iterations")]
297 pub max_iterations: usize,
298}
299
/// Serde default for `LoopConfig::max_iterations`: 1000.
fn default_max_iterations() -> usize {
    const LIMIT: usize = 1_000;
    LIMIT
}
303
304#[derive(Debug, Clone, Serialize, Deserialize)]
306#[cfg_attr(feature = "openapi", derive(ToSchema))]
307#[serde(tag = "variant")]
308pub enum LoopType {
309 ForEach {
311 collection_path: String,
313
314 item_variable: String,
316
317 #[serde(default)]
319 index_variable: Option<String>,
320
321 body_expression: String,
324
325 #[serde(default)]
327 parallel: bool,
328
329 #[serde(default)]
332 max_concurrency: Option<usize>,
333 },
334
335 While {
337 condition: String,
339
340 body_expression: String,
342
343 #[serde(default)]
345 counter_variable: Option<String>,
346 },
347
348 Repeat {
350 count: String,
352
353 body_expression: String,
355
356 #[serde(default)]
358 index_variable: Option<String>,
359 },
360}
361
362#[derive(Debug, Clone, Serialize, Deserialize)]
364#[cfg_attr(feature = "openapi", derive(ToSchema))]
365pub struct TryCatchConfig {
366 pub try_expression: String,
368
369 #[serde(default)]
372 pub catch_expression: Option<String>,
373
374 #[serde(default)]
376 pub finally_expression: Option<String>,
377
378 #[serde(default)]
381 pub rethrow: bool,
382
383 #[serde(default = "default_error_variable")]
385 pub error_variable: String,
386}
387
/// Serde default for `TryCatchConfig::error_variable`: `"error"`.
fn default_error_variable() -> String {
    String::from("error")
}
391
392#[derive(Debug, Clone, Serialize, Deserialize)]
394#[cfg_attr(feature = "openapi", derive(ToSchema))]
395pub struct SubWorkflowConfig {
396 pub workflow_path: String,
398
399 #[serde(default)]
402 pub input_mappings: std::collections::HashMap<String, String>,
403
404 #[serde(default)]
407 pub output_variable: Option<String>,
408
409 #[serde(default)]
411 pub inherit_context: bool,
412}
413
414#[derive(Debug, Clone, Serialize, Deserialize)]
416#[cfg_attr(feature = "openapi", derive(ToSchema))]
417pub struct SwitchConfig {
418 pub switch_on: String,
420
421 pub cases: Vec<SwitchCase>,
423
424 #[serde(default)]
427 pub default_case: Option<String>,
428}
429
430#[derive(Debug, Clone, Serialize, Deserialize)]
432#[cfg_attr(feature = "openapi", derive(ToSchema))]
433pub struct SwitchCase {
434 pub match_value: String,
437
438 pub action: String,
441}
442
443#[derive(Debug, Clone, Serialize, Deserialize)]
445#[cfg_attr(feature = "openapi", derive(ToSchema))]
446pub struct ParallelConfig {
447 pub strategy: ParallelStrategy,
449
450 pub tasks: Vec<ParallelTask>,
453
454 #[serde(default)]
456 pub max_concurrency: Option<usize>,
457
458 #[serde(default)]
460 pub timeout_ms: Option<u64>,
461}
462
463#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
465#[cfg_attr(feature = "openapi", derive(ToSchema))]
466pub enum ParallelStrategy {
467 WaitAll,
470
471 Race,
474
475 AllSettled,
478}
479
480#[derive(Debug, Clone, Serialize, Deserialize)]
482#[cfg_attr(feature = "openapi", derive(ToSchema))]
483pub struct ParallelTask {
484 pub id: String,
486
487 pub expression: String,
489
490 #[serde(default)]
492 pub description: Option<String>,
493}
494
495#[derive(Debug, Clone, Serialize, Deserialize)]
497#[cfg_attr(feature = "openapi", derive(ToSchema))]
498pub struct ApprovalConfig {
499 pub message: String,
501
502 pub description: Option<String>,
504
505 #[serde(default)]
507 pub approvers: Vec<String>,
508
509 pub timeout_seconds: Option<u64>,
511
512 #[serde(default)]
514 pub context_data: serde_json::Value,
515}
516
517#[derive(Debug, Clone, Serialize, Deserialize)]
519#[cfg_attr(feature = "openapi", derive(ToSchema))]
520pub struct FormConfig {
521 pub title: String,
523
524 pub description: Option<String>,
526
527 pub fields: Vec<FormField>,
529
530 pub timeout_seconds: Option<u64>,
532
533 #[serde(default)]
535 pub allowed_submitters: Vec<String>,
536}
537
538#[derive(Debug, Clone, Serialize, Deserialize)]
540#[cfg_attr(feature = "openapi", derive(ToSchema))]
541pub struct FormField {
542 pub id: String,
544
545 pub label: String,
547
548 pub field_type: FormFieldType,
550
551 #[serde(default)]
553 pub required: bool,
554
555 pub default_value: Option<serde_json::Value>,
557
558 #[serde(default)]
560 pub validation: Option<serde_json::Value>,
561
562 #[serde(default)]
564 pub options: Vec<String>,
565}
566
567#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
569#[cfg_attr(feature = "openapi", derive(ToSchema))]
570pub enum FormFieldType {
571 Text,
572 Number,
573 Email,
574 Password,
575 TextArea,
576 Select,
577 MultiSelect,
578 Radio,
579 Checkbox,
580 Date,
581 DateTime,
582}
583
584#[derive(Debug, Clone, Serialize, Deserialize)]
586#[cfg_attr(feature = "openapi", derive(ToSchema))]
587pub struct VisionConfig {
588 pub provider: String,
590
591 #[serde(default)]
593 pub model_path: Option<String>,
594
595 #[serde(default = "default_output_format")]
597 pub output_format: String,
598
599 #[serde(default)]
601 pub use_gpu: bool,
602
603 #[serde(default)]
605 pub language: Option<String>,
606
607 pub image_input: String,
609
610 #[serde(default)]
612 pub options: serde_json::Value,
613}
614
/// Serde default for `VisionConfig::output_format`: `"markdown"`.
fn default_output_format() -> String {
    "markdown".to_owned()
}
618
#[cfg(test)]
mod tests {
    //! Unit tests for the node model: constructors/builders, per-kind payloads,
    //! and the serde defaults and tagging formats the wire format relies on.

    use super::*;

    #[test]
    fn test_node_creation() {
        let node = Node::new(
            "Test LLM".to_string(),
            NodeKind::LLM(LlmConfig {
                provider: "openai".to_string(),
                model: "gpt-4".to_string(),
                system_prompt: None,
                prompt_template: "Hello {{input}}".to_string(),
                temperature: Some(0.7),
                max_tokens: Some(1000),
                tools: vec![],
                images: vec![],
                extra_params: serde_json::Value::Null,
            }),
        )
        .with_position(100.0, 200.0);

        assert_eq!(node.name, "Test LLM");
        assert_eq!(node.position, Some((100.0, 200.0)));
        assert!(matches!(node.kind, NodeKind::LLM(_)));
        // `new` leaves the optional policies unset.
        assert!(node.retry_config.is_none());
        assert!(node.timeout_config.is_none());
    }

    #[test]
    fn test_new_nodes_get_distinct_ids() {
        let first = Node::new("a".to_string(), NodeKind::Start);
        let second = Node::new("b".to_string(), NodeKind::Start);
        assert_ne!(first.id, second.id);
    }

    #[test]
    fn test_switch_node() {
        let switch_config = SwitchConfig {
            switch_on: "{{status}}".to_string(),
            cases: vec![
                SwitchCase {
                    match_value: "success".to_string(),
                    action: "process_success".to_string(),
                },
                SwitchCase {
                    match_value: "error".to_string(),
                    action: "handle_error".to_string(),
                },
            ],
            default_case: Some("unknown_status".to_string()),
        };

        let node = Node::new("Status Router".to_string(), NodeKind::Switch(switch_config));

        assert_eq!(node.name, "Status Router");
        if let NodeKind::Switch(config) = &node.kind {
            assert_eq!(config.switch_on, "{{status}}");
            assert_eq!(config.cases.len(), 2);
            assert_eq!(config.default_case, Some("unknown_status".to_string()));
        } else {
            panic!("Expected Switch node");
        }
    }

    #[test]
    fn test_parallel_node() {
        let parallel_config = ParallelConfig {
            strategy: ParallelStrategy::WaitAll,
            tasks: vec![
                ParallelTask {
                    id: "task1".to_string(),
                    expression: "{{query1}}".to_string(),
                    description: Some("First query".to_string()),
                },
                ParallelTask {
                    id: "task2".to_string(),
                    expression: "{{query2}}".to_string(),
                    description: Some("Second query".to_string()),
                },
            ],
            max_concurrency: Some(2),
            timeout_ms: Some(30000),
        };

        let node = Node::new(
            "Parallel Execution".to_string(),
            NodeKind::Parallel(parallel_config),
        );

        assert_eq!(node.name, "Parallel Execution");
        if let NodeKind::Parallel(config) = &node.kind {
            assert_eq!(config.strategy, ParallelStrategy::WaitAll);
            assert_eq!(config.tasks.len(), 2);
            assert_eq!(config.max_concurrency, Some(2));
            assert_eq!(config.timeout_ms, Some(30000));
        } else {
            panic!("Expected Parallel node");
        }
    }

    #[test]
    fn test_parallel_strategy_race() {
        let parallel_config = ParallelConfig {
            strategy: ParallelStrategy::Race,
            tasks: vec![
                ParallelTask {
                    id: "fast".to_string(),
                    expression: "{{fast_api}}".to_string(),
                    description: None,
                },
                ParallelTask {
                    id: "slow".to_string(),
                    expression: "{{slow_api}}".to_string(),
                    description: None,
                },
            ],
            max_concurrency: None,
            timeout_ms: None,
        };

        let node = Node::new(
            "Race Condition".to_string(),
            NodeKind::Parallel(parallel_config),
        );

        if let NodeKind::Parallel(config) = &node.kind {
            assert_eq!(config.strategy, ParallelStrategy::Race);
        } else {
            panic!("Expected Parallel node");
        }
    }

    #[test]
    fn test_approval_node() {
        let approval_config = ApprovalConfig {
            message: "Please approve this action".to_string(),
            description: Some("This will deploy to production".to_string()),
            approvers: vec!["admin".to_string(), "manager".to_string()],
            timeout_seconds: Some(3600),
            context_data: serde_json::json!({
                "deployment": "production",
                "version": "1.0.0"
            }),
        };

        let node = Node::new(
            "Production Approval".to_string(),
            NodeKind::Approval(approval_config),
        );

        assert_eq!(node.name, "Production Approval");
        if let NodeKind::Approval(config) = &node.kind {
            assert_eq!(config.message, "Please approve this action");
            assert_eq!(config.approvers.len(), 2);
            assert_eq!(config.timeout_seconds, Some(3600));
            assert_eq!(config.context_data["deployment"], "production");
        } else {
            panic!("Expected Approval node");
        }
    }

    #[test]
    fn test_form_node() {
        let form_config = FormConfig {
            title: "User Information".to_string(),
            description: Some("Please provide your details".to_string()),
            fields: vec![
                FormField {
                    id: "name".to_string(),
                    label: "Full Name".to_string(),
                    field_type: FormFieldType::Text,
                    required: true,
                    default_value: None,
                    validation: None,
                    options: vec![],
                },
                FormField {
                    id: "email".to_string(),
                    label: "Email Address".to_string(),
                    field_type: FormFieldType::Email,
                    required: true,
                    default_value: None,
                    validation: None,
                    options: vec![],
                },
                FormField {
                    id: "age".to_string(),
                    label: "Age".to_string(),
                    field_type: FormFieldType::Number,
                    required: false,
                    default_value: None,
                    validation: Some(serde_json::json!({"min": 18, "max": 100})),
                    options: vec![],
                },
            ],
            timeout_seconds: Some(600),
            allowed_submitters: vec!["user1".to_string()],
        };

        let node = Node::new("User Form".to_string(), NodeKind::Form(form_config));

        assert_eq!(node.name, "User Form");
        if let NodeKind::Form(config) = &node.kind {
            assert_eq!(config.title, "User Information");
            assert_eq!(config.fields.len(), 3);
            assert_eq!(config.fields[0].field_type, FormFieldType::Text);
            assert_eq!(config.fields[1].field_type, FormFieldType::Email);
            assert_eq!(config.fields[2].field_type, FormFieldType::Number);
            assert!(config.fields[0].required);
            assert!(!config.fields[2].required);
        } else {
            panic!("Expected Form node");
        }
    }

    #[test]
    fn test_form_field_types() {
        let field_types = [
            FormFieldType::Text,
            FormFieldType::Number,
            FormFieldType::Email,
            FormFieldType::Password,
            FormFieldType::TextArea,
            FormFieldType::Select,
            FormFieldType::MultiSelect,
            FormFieldType::Radio,
            FormFieldType::Checkbox,
            FormFieldType::Date,
            FormFieldType::DateTime,
        ];

        // Every field type must survive a JSON round trip inside a field.
        for field_type in field_types {
            let field = FormField {
                id: "test".to_string(),
                label: "Test".to_string(),
                field_type,
                required: false,
                default_value: None,
                validation: None,
                options: vec![],
            };

            let json = serde_json::to_string(&field).unwrap();
            let back: FormField = serde_json::from_str(&json).unwrap();
            assert_eq!(back.field_type, field.field_type);
        }

        // Unit variants use serde's default external tagging: a bare string.
        assert_eq!(
            serde_json::to_value(&FormFieldType::TextArea).unwrap(),
            serde_json::json!("TextArea")
        );
    }

    #[test]
    fn test_node_with_retry_and_timeout() {
        let node = Node::new("Resilient Node".to_string(), NodeKind::Start)
            .with_retry(RetryConfig {
                max_retries: 5,
                initial_delay_ms: 500,
                backoff_multiplier: 3.0,
                max_delay_ms: 60000,
            })
            .with_timeout(TimeoutConfig {
                execution_timeout_ms: 10000,
                idle_timeout_ms: Some(5000),
                timeout_action: TimeoutAction::Skip,
            });

        // `expect` makes a missing config fail loudly instead of skipping the checks.
        let retry = node.retry_config.as_ref().expect("retry config was set");
        assert_eq!(retry.max_retries, 5);
        assert_eq!(retry.backoff_multiplier, 3.0);

        let timeout = node.timeout_config.as_ref().expect("timeout config was set");
        assert_eq!(timeout.execution_timeout_ms, 10000);
        assert_eq!(timeout.idle_timeout_ms, Some(5000));
        assert_eq!(timeout.timeout_action, TimeoutAction::Skip);
    }

    #[test]
    fn test_timeout_action_default_and_roundtrip() {
        assert_eq!(TimeoutAction::default(), TimeoutAction::Fail);

        let action = TimeoutAction::UseDefault("fallback".to_string());
        let json = serde_json::to_string(&action).unwrap();
        let back: TimeoutAction = serde_json::from_str(&json).unwrap();
        assert_eq!(back, action);
    }

    #[test]
    fn test_node_kind_adjacent_tagging() {
        let node = Node::new("Begin".to_string(), NodeKind::Start);
        let json = serde_json::to_value(&node).unwrap();

        // `#[serde(tag = "type", content = "config")]` puts the variant name under "type".
        assert_eq!(json["kind"]["type"], "Start");

        let back: Node = serde_json::from_value(json).unwrap();
        assert_eq!(back.id, node.id);
        assert_eq!(back.name, "Begin");
        assert!(matches!(back.kind, NodeKind::Start));
    }

    #[test]
    fn test_loop_config_defaults() {
        let config: LoopConfig = serde_json::from_value(serde_json::json!({
            "loop_type": {
                "variant": "Repeat",
                "count": "3",
                "body_expression": "step()"
            }
        }))
        .unwrap();

        assert_eq!(config.max_iterations, 1000);
        assert!(matches!(
            config.loop_type,
            LoopType::Repeat { index_variable: None, .. }
        ));
    }

    #[test]
    fn test_try_catch_defaults() {
        let config: TryCatchConfig =
            serde_json::from_value(serde_json::json!({ "try_expression": "risky()" })).unwrap();

        assert_eq!(config.try_expression, "risky()");
        assert!(config.catch_expression.is_none());
        assert!(config.finally_expression.is_none());
        assert!(!config.rethrow);
        assert_eq!(config.error_variable, "error");
    }

    #[test]
    fn test_vision_defaults() {
        let config: VisionConfig = serde_json::from_value(serde_json::json!({
            "provider": "ocr",
            "image_input": "{{image}}"
        }))
        .unwrap();

        assert_eq!(config.output_format, "markdown");
        assert!(!config.use_gpu);
        assert!(config.model_path.is_none());
        assert!(config.language.is_none());
        assert!(config.options.is_null());
    }

    #[test]
    fn test_foreach_parallel_execution() {
        let loop_config = LoopConfig {
            loop_type: LoopType::ForEach {
                collection_path: "items".to_string(),
                item_variable: "item".to_string(),
                index_variable: Some("idx".to_string()),
                body_expression: "process({{item}})".to_string(),
                parallel: true,
                max_concurrency: Some(10),
            },
            max_iterations: 1000,
        };

        let node = Node::new("Parallel Loop".to_string(), NodeKind::Loop(loop_config));

        if let NodeKind::Loop(config) = &node.kind {
            if let LoopType::ForEach {
                parallel,
                max_concurrency,
                collection_path,
                item_variable,
                ..
            } = &config.loop_type
            {
                assert!(parallel);
                assert_eq!(*max_concurrency, Some(10));
                assert_eq!(collection_path, "items");
                assert_eq!(item_variable, "item");
            } else {
                panic!("Expected ForEach loop");
            }
        } else {
            panic!("Expected Loop node");
        }
    }

    #[test]
    fn test_foreach_sequential_execution() {
        let loop_config = LoopConfig {
            loop_type: LoopType::ForEach {
                collection_path: "items".to_string(),
                item_variable: "item".to_string(),
                index_variable: None,
                body_expression: "process({{item}})".to_string(),
                parallel: false,
                max_concurrency: None,
            },
            max_iterations: 1000,
        };

        let node = Node::new("Sequential Loop".to_string(), NodeKind::Loop(loop_config));

        if let NodeKind::Loop(config) = &node.kind {
            if let LoopType::ForEach {
                parallel,
                max_concurrency,
                ..
            } = &config.loop_type
            {
                assert!(!parallel);
                assert_eq!(*max_concurrency, None);
            } else {
                panic!("Expected ForEach loop");
            }
        } else {
            panic!("Expected Loop node");
        }
    }

    #[test]
    fn test_foreach_serialization_with_parallel() {
        let loop_config = LoopConfig {
            loop_type: LoopType::ForEach {
                collection_path: "data".to_string(),
                item_variable: "x".to_string(),
                index_variable: Some("i".to_string()),
                body_expression: "{{x}} * 2".to_string(),
                parallel: true,
                max_concurrency: Some(5),
            },
            max_iterations: 100,
        };

        let json = serde_json::to_string(&loop_config).unwrap();
        let deserialized: LoopConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.max_iterations, 100);
        if let LoopType::ForEach {
            parallel,
            max_concurrency,
            ..
        } = deserialized.loop_type
        {
            assert!(parallel);
            assert_eq!(max_concurrency, Some(5));
        } else {
            panic!("Expected ForEach loop");
        }
    }
}