use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Kind of a node in a workflow graph.
///
/// Serialized in `snake_case` per the `rename_all` attribute below
/// (e.g. `SubWorkflow` <-> `"sub_workflow"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum NodeType {
/// Entry node; `WorkflowDef::validate` requires at least one among top-level nodes.
Start,
/// Exit node; `WorkflowDef::validate` requires at least one among top-level nodes.
End,
/// Runs a task — presumably the one named in `NodeDef::task`; confirm with the executor.
Task,
/// Conditional routing — presumably driven by `NodeDef::branches`; confirm.
Condition,
/// Fan-out node; `NodeDef::get_execution_mode` uses its `execution_mode`, defaulting to
/// `ExecutionMode::Parallel`.
Parallel,
/// Fan-out node that `NodeDef::get_execution_mode` always reports as
/// `ExecutionMode::Pipeline`.
Pipeline,
/// Nested workflow — presumably referenced by `NodeDef::workflow`; confirm.
SubWorkflow,
/// Delay — presumably for `NodeDef::wait_ms`; confirm.
Wait,
/// Manual approval — presumably by `NodeDef::approvers`; confirm.
Approval,
}
/// How the tasks of a fan-out node or parallel branch are scheduled.
///
/// Serialized in `snake_case`; defaults to `Parallel`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ExecutionMode {
/// No barrier (`has_barrier()` is `false`): tasks flow on without waiting for each other.
Pipeline,
/// Barrier (`has_barrier()` is `true`): all tasks run in parallel and everything waits
/// for all of them to finish. This is the default.
#[default]
Parallel,
}
impl ExecutionMode {
pub fn has_barrier(&self) -> bool {
match self {
Self::Pipeline => false,
Self::Parallel => true,
}
}
pub fn display_name(&self) -> &'static str {
match self {
Self::Pipeline => "流式",
Self::Parallel => "并行",
}
}
pub fn description(&self) -> &'static str {
match self {
Self::Pipeline => "任务流转执行,不等待其他任务",
Self::Parallel => "所有任务并行执行,等待全部完成",
}
}
}
/// Discriminant of a failure strategy.
///
/// Appears as the `type` field of `FailureStrategyConfig`, serialized in `snake_case`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FailureStrategyType {
/// Maps to `FailureStrategy::Retry`.
Retry,
/// Maps to `FailureStrategy::Ignore`.
Ignore,
/// Maps to `FailureStrategy::Abort`.
Abort,
/// Maps to `FailureStrategy::Goto`.
Goto,
}
/// Flat, serde-facing representation of a `FailureStrategy`.
///
/// `FailureStrategy`'s manual `Serialize`/`Deserialize` impls go through this struct, so
/// the wire shape is e.g. `{"type": "retry", "max_attempts": 3, "interval_ms": 500}`.
/// Fields irrelevant to the chosen `type` are dropped when converting to
/// `FailureStrategy`, and every `None` field is omitted on output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FailureStrategyConfig {
/// Which strategy this is; serialized as `type`, and `abort` when absent on input.
#[serde(rename = "type", default = "default_failure_strategy_type")]
pub strategy_type: FailureStrategyType,
/// Only used by `retry`; converts to `1` when absent.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_attempts: Option<u32>,
/// Only used by `retry`; presumably the delay between attempts in milliseconds.
#[serde(skip_serializing_if = "Option::is_none")]
pub interval_ms: Option<u64>,
/// Only used by `goto`; presumably a node id. Converts to an empty string when absent.
#[serde(skip_serializing_if = "Option::is_none")]
pub target: Option<String>,
}
/// Serde default for `FailureStrategyConfig::strategy_type`: `Abort`, matching
/// `FailureStrategy::default()`.
fn default_failure_strategy_type() -> FailureStrategyType {
FailureStrategyType::Abort
}
impl From<FailureStrategyConfig> for FailureStrategy {
fn from(config: FailureStrategyConfig) -> Self {
match config.strategy_type {
FailureStrategyType::Retry => FailureStrategy::Retry {
max_attempts: config.max_attempts.unwrap_or(1),
interval_ms: config.interval_ms,
},
FailureStrategyType::Ignore => FailureStrategy::Ignore,
FailureStrategyType::Abort => FailureStrategy::Abort,
FailureStrategyType::Goto => FailureStrategy::Goto {
target: config.target.unwrap_or_default(),
},
}
}
}
impl From<FailureStrategy> for FailureStrategyConfig {
    /// Flattens a strategy into its config form, leaving fields the variant lacks as `None`.
    fn from(strategy: FailureStrategy) -> Self {
        // (type, max_attempts, interval_ms, target) for each variant.
        let (strategy_type, max_attempts, interval_ms, target) = match strategy {
            FailureStrategy::Retry {
                max_attempts,
                interval_ms,
            } => (
                FailureStrategyType::Retry,
                Some(max_attempts),
                interval_ms,
                None,
            ),
            FailureStrategy::Ignore => (FailureStrategyType::Ignore, None, None, None),
            FailureStrategy::Abort => (FailureStrategyType::Abort, None, None, None),
            FailureStrategy::Goto { target } => {
                (FailureStrategyType::Goto, None, None, Some(target))
            }
        };

        Self {
            strategy_type,
            max_attempts,
            interval_ms,
            target,
        }
    }
}
/// What to do when a node fails.
///
/// `Serialize`/`Deserialize` are implemented by hand through `FailureStrategyConfig`
/// rather than derived. Defaults to `Abort`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum FailureStrategy {
/// Retry the failed node.
Retry {
/// Attempt limit; NOTE(review): whether this includes the first run is up to the
/// executor — confirm.
max_attempts: u32,
/// Optional delay between attempts, presumably in milliseconds (per the `_ms` suffix).
interval_ms: Option<u64>,
},
/// Ignore the failure — presumably the workflow carries on; confirm with the executor.
Ignore,
/// Abort — presumably stops the workflow. This is the default.
#[default]
Abort,
/// Jump elsewhere on failure.
Goto {
/// Presumably a node id. Can be empty when deserialized from a config without `target`.
target: String,
},
}
impl Serialize for FailureStrategy {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let config: FailureStrategyConfig = self.clone().into();
config.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for FailureStrategy {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let config: FailureStrategyConfig = FailureStrategyConfig::deserialize(deserializer)?;
Ok(config.into())
}
}
/// A directed edge between two nodes of a workflow.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EdgeDef {
/// Edge id; when absent on input, a random `edge_<uuid v4>` id is generated.
#[serde(default = "generate_edge_id")]
pub id: String,
/// Source node id; `WorkflowDef::validate` checks it names a top-level node.
pub from: String,
/// Target node id; `WorkflowDef::validate` checks it names a top-level node.
pub to: String,
/// Optional condition — its syntax and evaluation are not defined in this module.
#[serde(skip_serializing_if = "Option::is_none")]
pub condition: Option<String>,
/// Optional human-readable label.
#[serde(skip_serializing_if = "Option::is_none")]
pub label: Option<String>,
}
/// Serde default for `EdgeDef::id`: `"edge_"` followed by a fresh random (v4) UUID in its
/// standard `Display` form.
fn generate_edge_id() -> String {
    let mut id = String::from("edge_");
    id.push_str(&uuid::Uuid::new_v4().to_string());
    id
}
/// A single node in a workflow graph.
///
/// Which optional fields matter depends on `node_type`; the type pairings noted below are
/// inferred from field names and should be confirmed against the executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeDef {
/// Node id; `WorkflowDef::validate` rejects duplicates among top-level nodes.
pub id: String,
/// Node kind, serialized as `type`.
#[serde(rename = "type")]
pub node_type: NodeType,
/// Human-readable name.
pub name: String,
/// Optional free-text description.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Task to run — presumably only for `Task` nodes.
#[serde(skip_serializing_if = "Option::is_none")]
pub task: Option<String>,
/// Arbitrary JSON parameters; empty when absent on input.
#[serde(default)]
pub params: HashMap<String, serde_json::Value>,
/// Strategy applied when this node fails; `Abort` when absent on input.
#[serde(default)]
pub on_failure: FailureStrategy,
/// Timeout, presumably in milliseconds (per the `_ms` suffix).
#[serde(skip_serializing_if = "Option::is_none")]
pub timeout_ms: Option<u64>,
/// Conditional branches — presumably for `Condition` nodes.
#[serde(skip_serializing_if = "Option::is_none")]
pub branches: Option<Vec<BranchDef>>,
/// Nested branches — presumably for `Parallel`/`Pipeline` nodes.
#[serde(skip_serializing_if = "Option::is_none")]
pub parallel_branches: Option<Vec<ParallelBranchDef>>,
/// Explicit mode; `get_execution_mode` only consults it for `Parallel` nodes.
#[serde(skip_serializing_if = "Option::is_none")]
pub execution_mode: Option<ExecutionMode>,
/// Referenced workflow — presumably a workflow id for `SubWorkflow` nodes.
#[serde(skip_serializing_if = "Option::is_none")]
pub workflow: Option<String>,
/// Wait duration, presumably in milliseconds, for `Wait` nodes (inferred).
#[serde(skip_serializing_if = "Option::is_none")]
pub wait_ms: Option<u64>,
/// Approver identifiers — presumably for `Approval` nodes.
#[serde(skip_serializing_if = "Option::is_none")]
pub approvers: Option<Vec<String>>,
}
impl NodeDef {
    /// Effective execution mode of this node.
    ///
    /// - `Pipeline` nodes are always [`ExecutionMode::Pipeline`]; `execution_mode` is ignored.
    /// - `Parallel` nodes use `execution_mode`, falling back to [`ExecutionMode::Parallel`].
    /// - All other node types have no execution mode and yield `None`.
    pub fn get_execution_mode(&self) -> Option<ExecutionMode> {
        let mode = match self.node_type {
            NodeType::Pipeline => ExecutionMode::Pipeline,
            NodeType::Parallel => self.execution_mode.unwrap_or(ExecutionMode::Parallel),
            // Listed explicitly so adding a NodeType variant forces a decision here.
            NodeType::Start
            | NodeType::End
            | NodeType::Task
            | NodeType::Condition
            | NodeType::SubWorkflow
            | NodeType::Wait
            | NodeType::Approval => return None,
        };
        Some(mode)
    }

    /// Whether this node waits for all of its tasks before proceeding.
    ///
    /// Nodes without an execution mode never have a barrier.
    pub fn has_barrier(&self) -> bool {
        matches!(self.get_execution_mode(), Some(mode) if mode.has_barrier())
    }
}
/// A named conditional branch — presumably "go to `target` when `condition` holds".
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BranchDef {
/// Branch name.
pub name: String,
/// Condition text; its syntax is not defined in this module.
pub condition: String,
/// Destination — presumably a node id; confirm with the executor.
pub target: String,
}
/// A named branch of nested nodes run under a given execution mode.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelBranchDef {
/// Branch name.
pub name: String,
/// Nodes belonging to this branch.
pub nodes: Vec<NodeDef>,
/// Scheduling mode; `ExecutionMode::default()` (`Parallel`) when absent on input.
#[serde(default)]
pub mode: ExecutionMode,
}
impl ParallelBranchDef {
pub fn new(name: String, nodes: Vec<NodeDef>) -> Self {
Self {
name,
nodes,
mode: ExecutionMode::default(),
}
}
pub fn pipeline(name: String, nodes: Vec<NodeDef>) -> Self {
Self {
name,
nodes,
mode: ExecutionMode::Pipeline,
}
}
pub fn parallel(name: String, nodes: Vec<NodeDef>) -> Self {
Self {
name,
nodes,
mode: ExecutionMode::Parallel,
}
}
pub fn has_barrier(&self) -> bool {
self.mode.has_barrier()
}
pub fn mode_description(&self) -> &'static str {
self.mode.description()
}
}
/// A complete workflow definition: nodes, the edges between them, and workflow-level
/// settings. Use `WorkflowDef::validate` to check its structural integrity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowDef {
/// Workflow id.
pub id: String,
/// Human-readable name.
pub name: String,
/// Version string; `"1.0.0"` when absent on input.
#[serde(default = "default_version")]
pub version: String,
/// Optional free-text description.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Declared inputs; empty when absent on input.
#[serde(default)]
pub inputs: Vec<InputDef>,
/// Declared outputs; empty when absent on input.
#[serde(default)]
pub outputs: Vec<OutputDef>,
/// Top-level nodes; required on input.
pub nodes: Vec<NodeDef>,
/// Edges between top-level nodes; empty when absent on input.
#[serde(default)]
pub edges: Vec<EdgeDef>,
/// Arbitrary JSON variables; empty when absent on input.
#[serde(default)]
pub variables: HashMap<String, serde_json::Value>,
/// Workflow-wide failure strategy; `Abort` when absent on input.
/// NOTE(review): how it interacts with per-node `on_failure` is decided by the executor.
#[serde(default)]
pub default_failure_strategy: FailureStrategy,
/// Overall timeout, presumably in milliseconds (per the `_ms` suffix).
#[serde(skip_serializing_if = "Option::is_none")]
pub timeout_ms: Option<u64>,
}
/// Serde default for `WorkflowDef::version` when the field is absent.
fn default_version() -> String {
    String::from("1.0.0")
}
/// A declared workflow input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InputDef {
/// Input name.
pub name: String,
/// Type name, serialized as `type`; `"string"` when absent on input. The set of
/// accepted type names is not defined in this module.
#[serde(rename = "type", default = "default_input_type")]
pub input_type: String,
/// Whether the input is required; `false` when absent on input.
#[serde(default)]
pub required: bool,
/// Optional default JSON value.
#[serde(skip_serializing_if = "Option::is_none")]
pub default: Option<serde_json::Value>,
/// Optional free-text description.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
}
/// Serde default for `InputDef::input_type` (serialized as `type`) when absent.
fn default_input_type() -> String {
    String::from("string")
}
/// A declared workflow output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputDef {
/// Output name.
pub name: String,
/// Output value text — presumably an expression or reference resolved by the executor;
/// its syntax is not defined in this module.
pub value: String,
/// Optional free-text description.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
}
impl WorkflowDef {
pub fn get_node(&self, id: &str) -> Option<&NodeDef> {
self.nodes.iter().find(|n| n.id == id)
}
pub fn get_start_node(&self) -> Option<&NodeDef> {
self.nodes.iter().find(|n| n.node_type == NodeType::Start)
}
pub fn get_end_node(&self) -> Option<&NodeDef> {
self.nodes.iter().find(|n| n.node_type == NodeType::End)
}
pub fn get_outgoing_edges(&self, node_id: &str) -> Vec<&EdgeDef> {
self.edges.iter().filter(|e| e.from == node_id).collect()
}
pub fn validate(&self) -> anyhow::Result<()> {
if self.get_start_node().is_none() {
anyhow::bail!("Workflow must have a start node");
}
if self.get_end_node().is_none() {
anyhow::bail!("Workflow must have an end node");
}
let mut node_ids = std::collections::HashSet::new();
for node in &self.nodes {
if !node_ids.insert(&node.id) {
anyhow::bail!("Duplicate node id: {}", node.id);
}
}
for edge in &self.edges {
if !node_ids.contains(&edge.from) {
anyhow::bail!("Edge references unknown source node: {}", edge.from);
}
if !node_ids.contains(&edge.to) {
anyhow::bail!("Edge references unknown target node: {}", edge.to);
}
}
for input in &self.inputs {
if input.required && input.default.is_none() {
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a node of the given type with every optional field unset and the default
    /// failure strategy; tests override single fields with struct-update syntax.
    fn bare_node(id: &str, node_type: NodeType, name: &str) -> NodeDef {
        NodeDef {
            id: id.to_string(),
            node_type,
            name: name.to_string(),
            description: None,
            task: None,
            params: HashMap::new(),
            on_failure: FailureStrategy::default(),
            timeout_ms: None,
            branches: None,
            parallel_branches: None,
            execution_mode: None,
            workflow: None,
            wait_ms: None,
            approvers: None,
        }
    }

    /// Asserts a branch's stored mode and the barrier flag derived from it.
    fn assert_branch_mode(branch: &ParallelBranchDef, expected: ExecutionMode, barrier: bool) {
        assert_eq!(branch.mode, expected);
        assert_eq!(branch.has_barrier(), barrier);
    }

    #[test]
    fn test_workflow_def_validation() {
        let start_to_end = EdgeDef {
            id: "e1".to_string(),
            from: "start".to_string(),
            to: "end".to_string(),
            condition: None,
            label: None,
        };
        let workflow = WorkflowDef {
            id: "test-workflow".to_string(),
            name: "Test Workflow".to_string(),
            version: "1.0.0".to_string(),
            description: None,
            inputs: Vec::new(),
            outputs: Vec::new(),
            nodes: vec![
                bare_node("start", NodeType::Start, "Start"),
                bare_node("end", NodeType::End, "End"),
            ],
            edges: vec![start_to_end],
            variables: HashMap::new(),
            default_failure_strategy: FailureStrategy::default(),
            timeout_ms: None,
        };
        assert!(workflow.validate().is_ok());
    }

    #[test]
    fn test_execution_mode_default() {
        let mode = ExecutionMode::default();
        assert_eq!(mode, ExecutionMode::Parallel);
        assert!(mode.has_barrier());
    }

    #[test]
    fn test_execution_mode_pipeline_no_barrier() {
        let mode = ExecutionMode::Pipeline;
        assert!(!mode.has_barrier());
        assert_eq!(mode.display_name(), "流式");
    }

    #[test]
    fn test_execution_mode_parallel_has_barrier() {
        let mode = ExecutionMode::Parallel;
        assert!(mode.has_barrier());
        assert_eq!(mode.display_name(), "并行");
    }

    #[test]
    fn test_parallel_branch_def_default_mode() {
        let branch = ParallelBranchDef::new("test".to_string(), vec![]);
        assert_branch_mode(&branch, ExecutionMode::Parallel, true);
    }

    #[test]
    fn test_parallel_branch_def_pipeline_mode() {
        let branch = ParallelBranchDef::pipeline("test".to_string(), vec![]);
        assert_branch_mode(&branch, ExecutionMode::Pipeline, false);
    }

    #[test]
    fn test_parallel_branch_def_parallel_mode() {
        let branch = ParallelBranchDef::parallel("test".to_string(), vec![]);
        assert_branch_mode(&branch, ExecutionMode::Parallel, true);
    }

    #[test]
    fn test_node_def_get_execution_mode_pipeline() {
        let node = bare_node("pipeline-node", NodeType::Pipeline, "Pipeline Node");
        assert_eq!(node.get_execution_mode(), Some(ExecutionMode::Pipeline));
        assert!(!node.has_barrier());
    }

    #[test]
    fn test_node_def_get_execution_mode_parallel() {
        let node = bare_node("parallel-node", NodeType::Parallel, "Parallel Node");
        assert_eq!(node.get_execution_mode(), Some(ExecutionMode::Parallel));
        assert!(node.has_barrier());
    }

    #[test]
    fn test_node_def_get_execution_mode_custom() {
        let node = NodeDef {
            execution_mode: Some(ExecutionMode::Pipeline),
            ..bare_node("custom-node", NodeType::Parallel, "Custom Node")
        };
        assert_eq!(node.get_execution_mode(), Some(ExecutionMode::Pipeline));
        assert!(!node.has_barrier());
    }

    #[test]
    fn test_node_def_get_execution_mode_other_types() {
        let node = NodeDef {
            task: Some("do_something".to_string()),
            ..bare_node("task-node", NodeType::Task, "Task Node")
        };
        assert_eq!(node.get_execution_mode(), None);
    }
}