1use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20
21#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
23#[serde(rename_all = "snake_case")]
24pub enum NodeType {
25 Start,
27 End,
29 Task,
31 Condition,
33 Parallel,
35 Pipeline,
37 SubWorkflow,
39 Wait,
41 Approval,
43}
44
45#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
47#[serde(rename_all = "snake_case")]
48pub enum ExecutionMode {
49 Pipeline,
54
55 #[default]
59 Parallel,
60}
61
62impl ExecutionMode {
63 pub fn has_barrier(&self) -> bool {
65 match self {
66 Self::Pipeline => false,
67 Self::Parallel => true,
68 }
69 }
70
71 pub fn display_name(&self) -> &'static str {
73 match self {
74 Self::Pipeline => "流式",
75 Self::Parallel => "并行",
76 }
77 }
78
79 pub fn description(&self) -> &'static str {
81 match self {
82 Self::Pipeline => "任务流转执行,不等待其他任务",
83 Self::Parallel => "所有任务并行执行,等待全部完成",
84 }
85 }
86}
87
88#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
90#[serde(rename_all = "snake_case")]
91pub enum FailureStrategyType {
92 Retry,
93 Ignore,
94 Abort,
95 Goto,
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct FailureStrategyConfig {
101 #[serde(rename = "type", default = "default_failure_strategy_type")]
103 pub strategy_type: FailureStrategyType,
104 #[serde(skip_serializing_if = "Option::is_none")]
106 pub max_attempts: Option<u32>,
107 #[serde(skip_serializing_if = "Option::is_none")]
109 pub interval_ms: Option<u64>,
110 #[serde(skip_serializing_if = "Option::is_none")]
112 pub target: Option<String>,
113}
114
115fn default_failure_strategy_type() -> FailureStrategyType {
116 FailureStrategyType::Abort
117}
118
119impl From<FailureStrategyConfig> for FailureStrategy {
120 fn from(config: FailureStrategyConfig) -> Self {
121 match config.strategy_type {
122 FailureStrategyType::Retry => FailureStrategy::Retry {
123 max_attempts: config.max_attempts.unwrap_or(1),
124 interval_ms: config.interval_ms,
125 },
126 FailureStrategyType::Ignore => FailureStrategy::Ignore,
127 FailureStrategyType::Abort => FailureStrategy::Abort,
128 FailureStrategyType::Goto => FailureStrategy::Goto {
129 target: config.target.unwrap_or_default(),
130 },
131 }
132 }
133}
134
135impl From<FailureStrategy> for FailureStrategyConfig {
136 fn from(strategy: FailureStrategy) -> Self {
137 match strategy {
138 FailureStrategy::Retry {
139 max_attempts,
140 interval_ms,
141 } => FailureStrategyConfig {
142 strategy_type: FailureStrategyType::Retry,
143 max_attempts: Some(max_attempts),
144 interval_ms,
145 target: None,
146 },
147 FailureStrategy::Ignore => FailureStrategyConfig {
148 strategy_type: FailureStrategyType::Ignore,
149 max_attempts: None,
150 interval_ms: None,
151 target: None,
152 },
153 FailureStrategy::Abort => FailureStrategyConfig {
154 strategy_type: FailureStrategyType::Abort,
155 max_attempts: None,
156 interval_ms: None,
157 target: None,
158 },
159 FailureStrategy::Goto { target } => FailureStrategyConfig {
160 strategy_type: FailureStrategyType::Goto,
161 max_attempts: None,
162 interval_ms: None,
163 target: Some(target),
164 },
165 }
166 }
167}
168
/// What to do when a node fails. The default is [`FailureStrategy::Abort`].
///
/// (De)serialization goes through [`FailureStrategyConfig`] via hand-written impls.
/// The JSON form is therefore flat, e.g. `{"type": "retry", "max_attempts": 3}`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum FailureStrategy {
    /// Retry the failed node.
    Retry {
        /// Attempt limit; `1` when omitted in the serialized form.
        max_attempts: u32,
        /// Optional delay between attempts (unit presumed from the `_ms` suffix).
        interval_ms: Option<u64>,
    },
    /// Ignore the failure (presumably continuing execution).
    Ignore,
    /// Abort on failure (presumably stopping the workflow).
    #[default]
    Abort,
    /// Jump elsewhere on failure.
    Goto {
        /// Destination; presumably the id of a node in the same workflow.
        target: String,
    },
}
190
191impl Serialize for FailureStrategy {
192 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
193 where
194 S: serde::Serializer,
195 {
196 let config: FailureStrategyConfig = self.clone().into();
197 config.serialize(serializer)
198 }
199}
200
201impl<'de> Deserialize<'de> for FailureStrategy {
202 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
203 where
204 D: serde::Deserializer<'de>,
205 {
206 let config: FailureStrategyConfig = FailureStrategyConfig::deserialize(deserializer)?;
207 Ok(config.into())
208 }
209}
210
211#[derive(Debug, Clone, Serialize, Deserialize)]
213pub struct EdgeDef {
214 #[serde(default = "generate_edge_id")]
216 pub id: String,
217 pub from: String,
219 pub to: String,
221 #[serde(skip_serializing_if = "Option::is_none")]
223 pub condition: Option<String>,
224 #[serde(skip_serializing_if = "Option::is_none")]
226 pub label: Option<String>,
227}
228
229fn generate_edge_id() -> String {
230 format!("edge_{}", uuid::Uuid::new_v4())
231}
232
233#[derive(Debug, Clone, Serialize, Deserialize)]
235pub struct NodeDef {
236 pub id: String,
238 #[serde(rename = "type")]
240 pub node_type: NodeType,
241 pub name: String,
243 #[serde(skip_serializing_if = "Option::is_none")]
245 pub description: Option<String>,
246 #[serde(skip_serializing_if = "Option::is_none")]
248 pub task: Option<String>,
249 #[serde(default)]
251 pub params: HashMap<String, serde_json::Value>,
252 #[serde(default)]
254 pub on_failure: FailureStrategy,
255 #[serde(skip_serializing_if = "Option::is_none")]
257 pub timeout_ms: Option<u64>,
258 #[serde(skip_serializing_if = "Option::is_none")]
260 pub branches: Option<Vec<BranchDef>>,
261 #[serde(skip_serializing_if = "Option::is_none")]
263 pub parallel_branches: Option<Vec<ParallelBranchDef>>,
264 #[serde(skip_serializing_if = "Option::is_none")]
266 pub execution_mode: Option<ExecutionMode>,
267 #[serde(skip_serializing_if = "Option::is_none")]
269 pub workflow: Option<String>,
270 #[serde(skip_serializing_if = "Option::is_none")]
272 pub wait_ms: Option<u64>,
273 #[serde(skip_serializing_if = "Option::is_none")]
275 pub approvers: Option<Vec<String>>,
276}
277
278impl NodeDef {
279 pub fn get_execution_mode(&self) -> Option<ExecutionMode> {
284 match self.node_type {
285 NodeType::Pipeline => Some(ExecutionMode::Pipeline),
286 NodeType::Parallel => Some(
287 self.execution_mode.unwrap_or(ExecutionMode::Parallel)
288 ),
289 _ => None,
290 }
291 }
292
293 pub fn has_barrier(&self) -> bool {
295 self.get_execution_mode()
296 .map(|m| m.has_barrier())
297 .unwrap_or(false)
298 }
299}
300
301#[derive(Debug, Clone, Serialize, Deserialize)]
303pub struct BranchDef {
304 pub name: String,
306 pub condition: String,
308 pub target: String,
310}
311
312#[derive(Debug, Clone, Serialize, Deserialize)]
314pub struct ParallelBranchDef {
315 pub name: String,
317 pub nodes: Vec<NodeDef>,
319 #[serde(default)]
323 pub mode: ExecutionMode,
324}
325
326impl ParallelBranchDef {
327 pub fn new(name: String, nodes: Vec<NodeDef>) -> Self {
329 Self {
330 name,
331 nodes,
332 mode: ExecutionMode::default(),
333 }
334 }
335
336 pub fn pipeline(name: String, nodes: Vec<NodeDef>) -> Self {
338 Self {
339 name,
340 nodes,
341 mode: ExecutionMode::Pipeline,
342 }
343 }
344
345 pub fn parallel(name: String, nodes: Vec<NodeDef>) -> Self {
347 Self {
348 name,
349 nodes,
350 mode: ExecutionMode::Parallel,
351 }
352 }
353
354 pub fn has_barrier(&self) -> bool {
356 self.mode.has_barrier()
357 }
358
359 pub fn mode_description(&self) -> &'static str {
361 self.mode.description()
362 }
363}
364
365#[derive(Debug, Clone, Serialize, Deserialize)]
367pub struct WorkflowDef {
368 pub id: String,
370 pub name: String,
372 #[serde(default = "default_version")]
374 pub version: String,
375 #[serde(skip_serializing_if = "Option::is_none")]
377 pub description: Option<String>,
378 #[serde(default)]
380 pub inputs: Vec<InputDef>,
381 #[serde(default)]
383 pub outputs: Vec<OutputDef>,
384 pub nodes: Vec<NodeDef>,
386 #[serde(default)]
388 pub edges: Vec<EdgeDef>,
389 #[serde(default)]
391 pub variables: HashMap<String, serde_json::Value>,
392 #[serde(default)]
394 pub default_failure_strategy: FailureStrategy,
395 #[serde(skip_serializing_if = "Option::is_none")]
397 pub timeout_ms: Option<u64>,
398}
399
400fn default_version() -> String {
401 "1.0.0".to_string()
402}
403
404#[derive(Debug, Clone, Serialize, Deserialize)]
406pub struct InputDef {
407 pub name: String,
409 #[serde(rename = "type", default = "default_input_type")]
411 pub input_type: String,
412 #[serde(default)]
414 pub required: bool,
415 #[serde(skip_serializing_if = "Option::is_none")]
417 pub default: Option<serde_json::Value>,
418 #[serde(skip_serializing_if = "Option::is_none")]
420 pub description: Option<String>,
421}
422
423fn default_input_type() -> String {
424 "string".to_string()
425}
426
427#[derive(Debug, Clone, Serialize, Deserialize)]
429pub struct OutputDef {
430 pub name: String,
432 pub value: String,
434 #[serde(skip_serializing_if = "Option::is_none")]
436 pub description: Option<String>,
437}
438
439impl WorkflowDef {
440 pub fn get_node(&self, id: &str) -> Option<&NodeDef> {
442 self.nodes.iter().find(|n| n.id == id)
443 }
444
445 pub fn get_start_node(&self) -> Option<&NodeDef> {
447 self.nodes.iter().find(|n| n.node_type == NodeType::Start)
448 }
449
450 pub fn get_end_node(&self) -> Option<&NodeDef> {
452 self.nodes.iter().find(|n| n.node_type == NodeType::End)
453 }
454
455 pub fn get_outgoing_edges(&self, node_id: &str) -> Vec<&EdgeDef> {
457 self.edges.iter().filter(|e| e.from == node_id).collect()
458 }
459
460 pub fn validate(&self) -> anyhow::Result<()> {
462 if self.get_start_node().is_none() {
464 anyhow::bail!("Workflow must have a start node");
465 }
466
467 if self.get_end_node().is_none() {
469 anyhow::bail!("Workflow must have an end node");
470 }
471
472 let mut node_ids = std::collections::HashSet::new();
474 for node in &self.nodes {
475 if !node_ids.insert(&node.id) {
476 anyhow::bail!("Duplicate node id: {}", node.id);
477 }
478 }
479
480 for edge in &self.edges {
482 if !node_ids.contains(&edge.from) {
483 anyhow::bail!("Edge references unknown source node: {}", edge.from);
484 }
485 if !node_ids.contains(&edge.to) {
486 anyhow::bail!("Edge references unknown target node: {}", edge.to);
487 }
488 }
489
490 for input in &self.inputs {
492 if input.required && input.default.is_none() {
493 }
495 }
496
497 Ok(())
498 }
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a node with only id/type/name set and every optional field empty.
    /// Tests override just the fields they care about via struct-update syntax.
    fn bare_node(id: &str, node_type: NodeType, name: &str) -> NodeDef {
        NodeDef {
            id: id.to_string(),
            node_type,
            name: name.to_string(),
            description: None,
            task: None,
            params: HashMap::new(),
            on_failure: FailureStrategy::default(),
            timeout_ms: None,
            branches: None,
            parallel_branches: None,
            execution_mode: None,
            workflow: None,
            wait_ms: None,
            approvers: None,
        }
    }

    #[test]
    fn test_workflow_def_validation() {
        let workflow = WorkflowDef {
            id: "test-workflow".to_string(),
            name: "Test Workflow".to_string(),
            version: "1.0.0".to_string(),
            description: None,
            inputs: vec![],
            outputs: vec![],
            nodes: vec![
                bare_node("start", NodeType::Start, "Start"),
                bare_node("end", NodeType::End, "End"),
            ],
            edges: vec![EdgeDef {
                id: "e1".to_string(),
                from: "start".to_string(),
                to: "end".to_string(),
                condition: None,
                label: None,
            }],
            variables: HashMap::new(),
            default_failure_strategy: FailureStrategy::default(),
            timeout_ms: None,
        };

        assert!(workflow.validate().is_ok());
    }

    #[test]
    fn test_execution_mode_default() {
        let mode = ExecutionMode::default();
        assert_eq!(mode, ExecutionMode::Parallel);
        assert!(mode.has_barrier());
    }

    #[test]
    fn test_execution_mode_pipeline_no_barrier() {
        let mode = ExecutionMode::Pipeline;
        assert!(!mode.has_barrier());
        assert_eq!(mode.display_name(), "流式");
    }

    #[test]
    fn test_execution_mode_parallel_has_barrier() {
        let mode = ExecutionMode::Parallel;
        assert!(mode.has_barrier());
        assert_eq!(mode.display_name(), "并行");
    }

    #[test]
    fn test_parallel_branch_def_default_mode() {
        let branch = ParallelBranchDef::new("test".to_string(), vec![]);
        assert_eq!(branch.mode, ExecutionMode::Parallel);
        assert!(branch.has_barrier());
    }

    #[test]
    fn test_parallel_branch_def_pipeline_mode() {
        let branch = ParallelBranchDef::pipeline("test".to_string(), vec![]);
        assert_eq!(branch.mode, ExecutionMode::Pipeline);
        assert!(!branch.has_barrier());
    }

    #[test]
    fn test_parallel_branch_def_parallel_mode() {
        let branch = ParallelBranchDef::parallel("test".to_string(), vec![]);
        assert_eq!(branch.mode, ExecutionMode::Parallel);
        assert!(branch.has_barrier());
    }

    #[test]
    fn test_node_def_get_execution_mode_pipeline() {
        let node = bare_node("pipeline-node", NodeType::Pipeline, "Pipeline Node");
        assert_eq!(node.get_execution_mode(), Some(ExecutionMode::Pipeline));
        assert!(!node.has_barrier());
    }

    #[test]
    fn test_node_def_get_execution_mode_parallel() {
        let node = bare_node("parallel-node", NodeType::Parallel, "Parallel Node");
        assert_eq!(node.get_execution_mode(), Some(ExecutionMode::Parallel));
        assert!(node.has_barrier());
    }

    #[test]
    fn test_node_def_get_execution_mode_custom() {
        // A Parallel node may override its mode explicitly.
        let node = NodeDef {
            execution_mode: Some(ExecutionMode::Pipeline),
            ..bare_node("custom-node", NodeType::Parallel, "Custom Node")
        };
        assert_eq!(node.get_execution_mode(), Some(ExecutionMode::Pipeline));
        assert!(!node.has_barrier());
    }

    #[test]
    fn test_node_def_get_execution_mode_other_types() {
        let node = NodeDef {
            task: Some("do_something".to_string()),
            ..bare_node("task-node", NodeType::Task, "Task Node")
        };
        assert_eq!(node.get_execution_mode(), None);
    }

    #[test]
    fn test_failure_strategy_serde_round_trip() {
        let retry = FailureStrategy::Retry {
            max_attempts: 3,
            interval_ms: Some(100),
        };
        let json = serde_json::to_value(&retry).unwrap();
        assert_eq!(
            json,
            serde_json::json!({ "type": "retry", "max_attempts": 3, "interval_ms": 100 })
        );
        assert_eq!(serde_json::from_value::<FailureStrategy>(json).unwrap(), retry);

        let goto = FailureStrategy::Goto {
            target: "end".to_string(),
        };
        let json = serde_json::to_value(&goto).unwrap();
        assert_eq!(json, serde_json::json!({ "type": "goto", "target": "end" }));
        assert_eq!(serde_json::from_value::<FailureStrategy>(json).unwrap(), goto);
    }

    #[test]
    fn test_failure_strategy_missing_type_defaults_to_abort() {
        let parsed: FailureStrategy = serde_json::from_value(serde_json::json!({})).unwrap();
        assert_eq!(parsed, FailureStrategy::Abort);
    }

    #[test]
    fn test_node_def_minimal_json_uses_defaults() {
        let node: NodeDef = serde_json::from_value(serde_json::json!({
            "id": "w",
            "type": "sub_workflow",
            "name": "W"
        }))
        .unwrap();
        assert_eq!(node.node_type, NodeType::SubWorkflow);
        assert_eq!(node.on_failure, FailureStrategy::Abort);
        assert!(node.params.is_empty());
        assert_eq!(node.get_execution_mode(), None);
    }

    #[test]
    fn test_parallel_branch_def_mode_defaults_when_deserialized() {
        let branch: ParallelBranchDef =
            serde_json::from_value(serde_json::json!({ "name": "b", "nodes": [] })).unwrap();
        assert_eq!(branch.mode, ExecutionMode::Parallel);
    }
}