use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Kind of a workflow node, stored under the `type` key of [`NodeDef`].
///
/// Serialized as a snake_case string (`rename_all = "snake_case"`), so
/// `SubWorkflow` is written as `sub_workflow`. Within this file only `Start`
/// and `End` receive special treatment (see [`WorkflowDef::validate`]); what
/// the remaining kinds do at run time is decided by the code that executes
/// workflows, not here.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum NodeType {
/// Entry marker; `WorkflowDef::validate` requires at least one such node.
Start,
/// Exit marker; `WorkflowDef::validate` requires at least one such node.
End,
Task,
Condition,
Parallel,
SubWorkflow,
Wait,
Approval,
}
/// Discriminator written to the `type` key of [`FailureStrategyConfig`].
///
/// Serialized in snake_case, i.e. `retry`, `ignore`, `abort` or `goto`. Each
/// variant maps one-to-one onto the same-named [`FailureStrategy`] variant in
/// the `From` conversions defined in this file.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FailureStrategyType {
Retry,
Ignore,
Abort,
Goto,
}
/// Flat, serde-derived representation of a [`FailureStrategy`].
///
/// [`FailureStrategy`] serializes and deserializes through this struct. Every
/// payload field is optional here; which ones are read depends on
/// `strategy_type` (see `From<FailureStrategyConfig> for FailureStrategy`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FailureStrategyConfig {
/// Which strategy this is; stored under the key `type`. Falls back to
/// `abort` (via `default_failure_strategy_type`) when the key is absent.
#[serde(rename = "type", default = "default_failure_strategy_type")]
pub strategy_type: FailureStrategyType,
/// Attempt count, read only for `retry`; the `From` conversion uses 1 when
/// it is absent. Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_attempts: Option<u32>,
/// Read only for `retry` and passed through unchanged. Milliseconds, going
/// by the field name. Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub interval_ms: Option<u64>,
/// Read only for `goto`; the `From` conversion substitutes an empty string
/// when it is absent. Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub target: Option<String>,
}
fn default_failure_strategy_type() -> FailureStrategyType {
FailureStrategyType::Abort
}
impl From<FailureStrategyConfig> for FailureStrategy {
fn from(config: FailureStrategyConfig) -> Self {
match config.strategy_type {
FailureStrategyType::Retry => FailureStrategy::Retry {
max_attempts: config.max_attempts.unwrap_or(1),
interval_ms: config.interval_ms,
},
FailureStrategyType::Ignore => FailureStrategy::Ignore,
FailureStrategyType::Abort => FailureStrategy::Abort,
FailureStrategyType::Goto => FailureStrategy::Goto {
target: config.target.unwrap_or_default(),
},
}
}
}
impl From<FailureStrategy> for FailureStrategyConfig {
    /// Flattens a typed strategy into its config form.
    ///
    /// Fields that do not belong to the given variant are set to `None`, so
    /// they are skipped when the config is serialized.
    fn from(strategy: FailureStrategy) -> Self {
        // Work out the four field values per variant, then build the struct once.
        let (strategy_type, max_attempts, interval_ms, target) = match strategy {
            FailureStrategy::Retry {
                max_attempts,
                interval_ms,
            } => (
                FailureStrategyType::Retry,
                Some(max_attempts),
                interval_ms,
                None,
            ),
            FailureStrategy::Ignore => (FailureStrategyType::Ignore, None, None, None),
            FailureStrategy::Abort => (FailureStrategyType::Abort, None, None, None),
            FailureStrategy::Goto { target } => {
                (FailureStrategyType::Goto, None, None, Some(target))
            }
        };

        Self {
            strategy_type,
            max_attempts,
            interval_ms,
            target,
        }
    }
}
/// Failure-handling policy attached to a node (`NodeDef::on_failure`) and to a
/// workflow as a whole (`WorkflowDef::default_failure_strategy`).
///
/// The default is [`FailureStrategy::Abort`]. Serde support is hand-written in
/// this file and goes through [`FailureStrategyConfig`], so the wire format is
/// a flat object with a `type` key rather than serde's derived enum layout.
/// How each policy is acted on is up to the code that executes workflows.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum FailureStrategy {
/// Retry policy with an attempt limit.
Retry {
/// NOTE(review): presumably the total number of attempts rather than the
/// number of extra retries — confirm against the executor.
max_attempts: u32,
/// Optional interval between attempts; milliseconds, going by the name.
interval_ms: Option<u64>,
},
Ignore,
#[default]
Abort,
/// Redirect policy carrying a target.
Goto {
/// NOTE(review): presumably a node id; nothing in this file checks that
/// it refers to an existing node — confirm.
target: String,
},
}
impl Serialize for FailureStrategy {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let config: FailureStrategyConfig = self.clone().into();
config.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for FailureStrategy {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let config: FailureStrategyConfig = FailureStrategyConfig::deserialize(deserializer)?;
Ok(config.into())
}
}
/// A directed connection between two nodes of a workflow.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EdgeDef {
/// Edge identifier. When the key is missing from the input it is filled in
/// by `generate_edge_id`, so each deserialization produces a fresh value.
#[serde(default = "generate_edge_id")]
pub id: String,
/// Id of the source node; `WorkflowDef::validate` rejects unknown ids.
pub from: String,
/// Id of the target node; `WorkflowDef::validate` rejects unknown ids.
pub to: String,
/// Optional condition string; it is not interpreted anywhere in this file.
/// Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub condition: Option<String>,
/// Optional label. Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub label: Option<String>,
}
/// Serde default for `EdgeDef::id`: the prefix `edge_` followed by a freshly
/// generated version-4 UUID.
fn generate_edge_id() -> String {
    let unique = uuid::Uuid::new_v4();
    format!("edge_{unique}")
}
/// Definition of a single workflow node.
///
/// One struct covers every [`NodeType`]; the optional fields at the bottom are
/// presumably each meaningful for one node kind only (noted per field), but
/// nothing in this file enforces that pairing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeDef {
/// Node identifier. `WorkflowDef::validate` rejects duplicates among the
/// workflow's top-level nodes, and `WorkflowDef::get_node` looks nodes up by it.
pub id: String,
/// Kind of node; stored under the key `type`.
#[serde(rename = "type")]
pub node_type: NodeType,
/// Name of the node.
pub name: String,
/// Optional description. Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// NOTE(review): presumably identifies the task to run for `Task` nodes —
/// confirm. Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub task: Option<String>,
/// Arbitrary JSON parameters keyed by name; an empty map when the key is missing.
#[serde(default)]
pub params: HashMap<String, serde_json::Value>,
/// Failure policy for this node; `FailureStrategy::Abort` when the key is missing.
#[serde(default)]
pub on_failure: FailureStrategy,
/// Optional timeout; milliseconds, going by the name. Omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub timeout_ms: Option<u64>,
/// NOTE(review): presumably used by `Condition` nodes — confirm. Omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub branches: Option<Vec<BranchDef>>,
/// NOTE(review): presumably used by `Parallel` nodes — confirm. Omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub parallel_branches: Option<Vec<ParallelBranchDef>>,
/// NOTE(review): presumably names the workflow run by `SubWorkflow` nodes —
/// confirm. Omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub workflow: Option<String>,
/// NOTE(review): presumably the delay for `Wait` nodes, in milliseconds —
/// confirm. Omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub wait_ms: Option<u64>,
/// NOTE(review): presumably who may approve an `Approval` node — confirm.
/// Omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub approvers: Option<Vec<String>>,
}
/// One entry of `NodeDef::branches`. All three fields are required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BranchDef {
/// Name of the branch.
pub name: String,
/// Condition string; it is not parsed or evaluated anywhere in this file.
pub condition: String,
/// NOTE(review): presumably the id of the node this branch leads to;
/// `WorkflowDef::validate` does not check it — confirm.
pub target: String,
}
/// One entry of `NodeDef::parallel_branches`: a named list of nested nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelBranchDef {
/// Name of the branch.
pub name: String,
/// Nodes belonging to this branch. `WorkflowDef::validate` only walks the
/// workflow's top-level `nodes`, so these nested nodes are not checked by it.
pub nodes: Vec<NodeDef>,
}
/// Top-level definition of a workflow: its nodes, the edges between them, and
/// workflow-wide settings.
///
/// Only `id`, `name` and `nodes` must be present when deserializing; every
/// other field has a default or is optional. Call [`WorkflowDef::validate`] to
/// check the structural rules implemented in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowDef {
/// Workflow identifier.
pub id: String,
/// Name of the workflow.
pub name: String,
/// Version string; `"1.0.0"` (from `default_version`) when the key is missing.
#[serde(default = "default_version")]
pub version: String,
/// Optional description. Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Declared inputs; empty when the key is missing.
#[serde(default)]
pub inputs: Vec<InputDef>,
/// Declared outputs; empty when the key is missing.
#[serde(default)]
pub outputs: Vec<OutputDef>,
/// Top-level nodes of the workflow.
pub nodes: Vec<NodeDef>,
/// Edges between nodes; empty when the key is missing.
#[serde(default)]
pub edges: Vec<EdgeDef>,
/// Arbitrary JSON values keyed by variable name; empty when the key is missing.
#[serde(default)]
pub variables: HashMap<String, serde_json::Value>,
/// Workflow-level failure policy; `FailureStrategy::Abort` when the key is missing.
#[serde(default)]
pub default_failure_strategy: FailureStrategy,
/// Optional timeout; milliseconds, going by the name. Omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub timeout_ms: Option<u64>,
}
/// Serde default for `WorkflowDef::version`.
fn default_version() -> String {
    String::from("1.0.0")
}
/// Declaration of one workflow input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InputDef {
/// Name of the input.
pub name: String,
/// Type name as a free-form string, stored under the key `type`;
/// `"string"` (from `default_input_type`) when the key is missing.
#[serde(rename = "type", default = "default_input_type")]
pub input_type: String,
/// Whether the input is required; `false` when the key is missing.
#[serde(default)]
pub required: bool,
/// Optional default value as arbitrary JSON. Omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub default: Option<serde_json::Value>,
/// Optional description. Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
}
/// Serde default for `InputDef::input_type`.
fn default_input_type() -> String {
    "string".to_owned()
}
/// Declaration of one workflow output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputDef {
/// Name of the output.
pub name: String,
/// Value of the output as a string. NOTE(review): presumably an expression
/// or reference resolved at run time; it is not interpreted in this file — confirm.
pub value: String,
/// Optional description. Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
}
impl WorkflowDef {
    /// Returns the first top-level node whose id equals `id`, or `None`.
    pub fn get_node(&self, id: &str) -> Option<&NodeDef> {
        self.nodes.iter().find(|node| node.id == id)
    }

    /// Returns the first top-level node of type [`NodeType::Start`], in
    /// declaration order, or `None` if there is none.
    pub fn get_start_node(&self) -> Option<&NodeDef> {
        self.nodes
            .iter()
            .find(|node| node.node_type == NodeType::Start)
    }

    /// Returns the first top-level node of type [`NodeType::End`], in
    /// declaration order, or `None` if there is none.
    pub fn get_end_node(&self) -> Option<&NodeDef> {
        self.nodes
            .iter()
            .find(|node| node.node_type == NodeType::End)
    }

    /// Returns every edge whose `from` equals `node_id`, in declaration order.
    pub fn get_outgoing_edges(&self, node_id: &str) -> Vec<&EdgeDef> {
        self.edges
            .iter()
            .filter(|edge| edge.from == node_id)
            .collect()
    }

    /// Checks the structural rules of the workflow, stopping at the first
    /// violation.
    ///
    /// Inputs, outputs, nodes nested in parallel branches, branch targets and
    /// failure-strategy targets are not inspected.
    ///
    /// # Errors
    ///
    /// Fails, in this order, when
    /// - there is no start node,
    /// - there is no end node,
    /// - two top-level nodes share an id,
    /// - an edge's `from` or `to` is not the id of a top-level node.
    pub fn validate(&self) -> anyhow::Result<()> {
        if self.get_start_node().is_none() {
            anyhow::bail!("Workflow must have a start node");
        }
        if self.get_end_node().is_none() {
            anyhow::bail!("Workflow must have an end node");
        }

        // `insert` returns false for a value that is already present, which is
        // exactly the duplicate-id case.
        let mut node_ids = std::collections::HashSet::with_capacity(self.nodes.len());
        for node in &self.nodes {
            if !node_ids.insert(node.id.as_str()) {
                anyhow::bail!("Duplicate node id: {}", node.id);
            }
        }

        for edge in &self.edges {
            if !node_ids.contains(edge.from.as_str()) {
                anyhow::bail!("Edge references unknown source node: {}", edge.from);
            }
            if !node_ids.contains(edge.to.as_str()) {
                anyhow::bail!("Edge references unknown target node: {}", edge.to);
            }
        }

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a node with only the required fields set; every optional field
    /// is `None`, `params` is empty and `on_failure` is the default strategy.
    fn bare_node(id: &str, node_type: NodeType, name: &str) -> NodeDef {
        NodeDef {
            id: id.to_string(),
            node_type,
            name: name.to_string(),
            description: None,
            task: None,
            params: HashMap::new(),
            on_failure: FailureStrategy::default(),
            timeout_ms: None,
            branches: None,
            parallel_branches: None,
            workflow: None,
            wait_ms: None,
            approvers: None,
        }
    }

    /// A minimal start -> end workflow passes validation.
    #[test]
    fn test_workflow_def_validation() {
        let start_to_end = EdgeDef {
            id: "e1".to_string(),
            from: "start".to_string(),
            to: "end".to_string(),
            condition: None,
            label: None,
        };

        let workflow = WorkflowDef {
            id: "test-workflow".to_string(),
            name: "Test Workflow".to_string(),
            version: "1.0.0".to_string(),
            description: None,
            inputs: vec![],
            outputs: vec![],
            nodes: vec![
                bare_node("start", NodeType::Start, "Start"),
                bare_node("end", NodeType::End, "End"),
            ],
            edges: vec![start_to_end],
            variables: HashMap::new(),
            default_failure_strategy: FailureStrategy::default(),
            timeout_ms: None,
        };

        assert!(workflow.validate().is_ok());
    }
}