1use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
10#[serde(rename_all = "snake_case")]
11pub enum NodeType {
12 Start,
14 End,
16 Task,
18 Condition,
20 Parallel,
22 SubWorkflow,
24 Wait,
26 Approval,
28}
29
30#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
32#[serde(rename_all = "snake_case")]
33pub enum FailureStrategyType {
34 Retry,
35 Ignore,
36 Abort,
37 Goto,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct FailureStrategyConfig {
43 #[serde(rename = "type", default = "default_failure_strategy_type")]
45 pub strategy_type: FailureStrategyType,
46 #[serde(skip_serializing_if = "Option::is_none")]
48 pub max_attempts: Option<u32>,
49 #[serde(skip_serializing_if = "Option::is_none")]
51 pub interval_ms: Option<u64>,
52 #[serde(skip_serializing_if = "Option::is_none")]
54 pub target: Option<String>,
55}
56
57fn default_failure_strategy_type() -> FailureStrategyType {
58 FailureStrategyType::Abort
59}
60
61impl From<FailureStrategyConfig> for FailureStrategy {
62 fn from(config: FailureStrategyConfig) -> Self {
63 match config.strategy_type {
64 FailureStrategyType::Retry => FailureStrategy::Retry {
65 max_attempts: config.max_attempts.unwrap_or(1),
66 interval_ms: config.interval_ms,
67 },
68 FailureStrategyType::Ignore => FailureStrategy::Ignore,
69 FailureStrategyType::Abort => FailureStrategy::Abort,
70 FailureStrategyType::Goto => FailureStrategy::Goto {
71 target: config.target.unwrap_or_default(),
72 },
73 }
74 }
75}
76
77impl From<FailureStrategy> for FailureStrategyConfig {
78 fn from(strategy: FailureStrategy) -> Self {
79 match strategy {
80 FailureStrategy::Retry {
81 max_attempts,
82 interval_ms,
83 } => FailureStrategyConfig {
84 strategy_type: FailureStrategyType::Retry,
85 max_attempts: Some(max_attempts),
86 interval_ms,
87 target: None,
88 },
89 FailureStrategy::Ignore => FailureStrategyConfig {
90 strategy_type: FailureStrategyType::Ignore,
91 max_attempts: None,
92 interval_ms: None,
93 target: None,
94 },
95 FailureStrategy::Abort => FailureStrategyConfig {
96 strategy_type: FailureStrategyType::Abort,
97 max_attempts: None,
98 interval_ms: None,
99 target: None,
100 },
101 FailureStrategy::Goto { target } => FailureStrategyConfig {
102 strategy_type: FailureStrategyType::Goto,
103 max_attempts: None,
104 interval_ms: None,
105 target: Some(target),
106 },
107 }
108 }
109}
110
/// Action taken when a node fails.
///
/// The default is [`FailureStrategy::Abort`]. Serde support is written by hand, through the
/// flat `FailureStrategyConfig` form, instead of being derived.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum FailureStrategy {
    /// Re-run the failed node.
    Retry {
        /// Attempt count. NOTE(review): whether this includes the first run is not visible here.
        max_attempts: u32,
        /// Optional pause between attempts (presumably milliseconds, per the `_ms` suffix).
        interval_ms: Option<u64>,
    },
    /// Disregard the failure.
    Ignore,
    /// Stop on failure; this is the `Default` variant.
    #[default]
    Abort,
    /// Jump to another node.
    Goto {
        /// Id of the node to jump to.
        target: String,
    },
}
132
133impl Serialize for FailureStrategy {
134 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
135 where
136 S: serde::Serializer,
137 {
138 let config: FailureStrategyConfig = self.clone().into();
139 config.serialize(serializer)
140 }
141}
142
143impl<'de> Deserialize<'de> for FailureStrategy {
144 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
145 where
146 D: serde::Deserializer<'de>,
147 {
148 let config: FailureStrategyConfig = FailureStrategyConfig::deserialize(deserializer)?;
149 Ok(config.into())
150 }
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct EdgeDef {
156 #[serde(default = "generate_edge_id")]
158 pub id: String,
159 pub from: String,
161 pub to: String,
163 #[serde(skip_serializing_if = "Option::is_none")]
165 pub condition: Option<String>,
166 #[serde(skip_serializing_if = "Option::is_none")]
168 pub label: Option<String>,
169}
170
171fn generate_edge_id() -> String {
172 format!("edge_{}", uuid::Uuid::new_v4())
173}
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
177pub struct NodeDef {
178 pub id: String,
180 #[serde(rename = "type")]
182 pub node_type: NodeType,
183 pub name: String,
185 #[serde(skip_serializing_if = "Option::is_none")]
187 pub description: Option<String>,
188 #[serde(skip_serializing_if = "Option::is_none")]
190 pub task: Option<String>,
191 #[serde(default)]
193 pub params: HashMap<String, serde_json::Value>,
194 #[serde(default)]
196 pub on_failure: FailureStrategy,
197 #[serde(skip_serializing_if = "Option::is_none")]
199 pub timeout_ms: Option<u64>,
200 #[serde(skip_serializing_if = "Option::is_none")]
202 pub branches: Option<Vec<BranchDef>>,
203 #[serde(skip_serializing_if = "Option::is_none")]
205 pub parallel_branches: Option<Vec<ParallelBranchDef>>,
206 #[serde(skip_serializing_if = "Option::is_none")]
208 pub workflow: Option<String>,
209 #[serde(skip_serializing_if = "Option::is_none")]
211 pub wait_ms: Option<u64>,
212 #[serde(skip_serializing_if = "Option::is_none")]
214 pub approvers: Option<Vec<String>>,
215}
216
217#[derive(Debug, Clone, Serialize, Deserialize)]
219pub struct BranchDef {
220 pub name: String,
222 pub condition: String,
224 pub target: String,
226}
227
228#[derive(Debug, Clone, Serialize, Deserialize)]
230pub struct ParallelBranchDef {
231 pub name: String,
233 pub nodes: Vec<NodeDef>,
235}
236
237#[derive(Debug, Clone, Serialize, Deserialize)]
239pub struct WorkflowDef {
240 pub id: String,
242 pub name: String,
244 #[serde(default = "default_version")]
246 pub version: String,
247 #[serde(skip_serializing_if = "Option::is_none")]
249 pub description: Option<String>,
250 #[serde(default)]
252 pub inputs: Vec<InputDef>,
253 #[serde(default)]
255 pub outputs: Vec<OutputDef>,
256 pub nodes: Vec<NodeDef>,
258 #[serde(default)]
260 pub edges: Vec<EdgeDef>,
261 #[serde(default)]
263 pub variables: HashMap<String, serde_json::Value>,
264 #[serde(default)]
266 pub default_failure_strategy: FailureStrategy,
267 #[serde(skip_serializing_if = "Option::is_none")]
269 pub timeout_ms: Option<u64>,
270}
271
/// Serde default for `WorkflowDef::version`.
fn default_version() -> String {
    String::from("1.0.0")
}
275
276#[derive(Debug, Clone, Serialize, Deserialize)]
278pub struct InputDef {
279 pub name: String,
281 #[serde(rename = "type", default = "default_input_type")]
283 pub input_type: String,
284 #[serde(default)]
286 pub required: bool,
287 #[serde(skip_serializing_if = "Option::is_none")]
289 pub default: Option<serde_json::Value>,
290 #[serde(skip_serializing_if = "Option::is_none")]
292 pub description: Option<String>,
293}
294
/// Serde default for `InputDef::input_type`.
fn default_input_type() -> String {
    String::from("string")
}
298
299#[derive(Debug, Clone, Serialize, Deserialize)]
301pub struct OutputDef {
302 pub name: String,
304 pub value: String,
306 #[serde(skip_serializing_if = "Option::is_none")]
308 pub description: Option<String>,
309}
310
311impl WorkflowDef {
312 pub fn get_node(&self, id: &str) -> Option<&NodeDef> {
314 self.nodes.iter().find(|n| n.id == id)
315 }
316
317 pub fn get_start_node(&self) -> Option<&NodeDef> {
319 self.nodes.iter().find(|n| n.node_type == NodeType::Start)
320 }
321
322 pub fn get_end_node(&self) -> Option<&NodeDef> {
324 self.nodes.iter().find(|n| n.node_type == NodeType::End)
325 }
326
327 pub fn get_outgoing_edges(&self, node_id: &str) -> Vec<&EdgeDef> {
329 self.edges.iter().filter(|e| e.from == node_id).collect()
330 }
331
332 pub fn validate(&self) -> anyhow::Result<()> {
334 if self.get_start_node().is_none() {
336 anyhow::bail!("Workflow must have a start node");
337 }
338
339 if self.get_end_node().is_none() {
341 anyhow::bail!("Workflow must have an end node");
342 }
343
344 let mut node_ids = std::collections::HashSet::new();
346 for node in &self.nodes {
347 if !node_ids.insert(&node.id) {
348 anyhow::bail!("Duplicate node id: {}", node.id);
349 }
350 }
351
352 for edge in &self.edges {
354 if !node_ids.contains(&edge.from) {
355 anyhow::bail!("Edge references unknown source node: {}", edge.from);
356 }
357 if !node_ids.contains(&edge.to) {
358 anyhow::bail!("Edge references unknown target node: {}", edge.to);
359 }
360 }
361
362 for input in &self.inputs {
364 if input.required && input.default.is_none() {
365 }
367 }
368
369 Ok(())
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a node of the given type with every optional field left empty.
    fn node(id: &str, node_type: NodeType, name: &str) -> NodeDef {
        NodeDef {
            id: id.to_string(),
            node_type,
            name: name.to_string(),
            description: None,
            task: None,
            params: HashMap::new(),
            on_failure: FailureStrategy::default(),
            timeout_ms: None,
            branches: None,
            parallel_branches: None,
            workflow: None,
            wait_ms: None,
            approvers: None,
        }
    }

    /// Builds an unconditional, unlabeled edge.
    fn edge(id: &str, from: &str, to: &str) -> EdgeDef {
        EdgeDef {
            id: id.to_string(),
            from: from.to_string(),
            to: to.to_string(),
            condition: None,
            label: None,
        }
    }

    /// Builds a workflow with the given nodes and edges and defaults everywhere else.
    fn workflow(nodes: Vec<NodeDef>, edges: Vec<EdgeDef>) -> WorkflowDef {
        WorkflowDef {
            id: "test-workflow".to_string(),
            name: "Test Workflow".to_string(),
            version: "1.0.0".to_string(),
            description: None,
            inputs: vec![],
            outputs: vec![],
            nodes,
            edges,
            variables: HashMap::new(),
            default_failure_strategy: FailureStrategy::default(),
            timeout_ms: None,
        }
    }

    /// Minimal valid workflow: `start -> end`.
    fn start_end_workflow() -> WorkflowDef {
        workflow(
            vec![
                node("start", NodeType::Start, "Start"),
                node("end", NodeType::End, "End"),
            ],
            vec![edge("e1", "start", "end")],
        )
    }

    #[test]
    fn test_workflow_def_validation() {
        assert!(start_end_workflow().validate().is_ok());
    }

    #[test]
    fn test_validation_requires_start_node() {
        let wf = workflow(vec![node("end", NodeType::End, "End")], vec![]);
        assert!(wf.validate().is_err());
    }

    #[test]
    fn test_validation_requires_end_node() {
        let wf = workflow(vec![node("start", NodeType::Start, "Start")], vec![]);
        assert!(wf.validate().is_err());
    }

    #[test]
    fn test_validation_rejects_duplicate_node_ids() {
        let mut wf = start_end_workflow();
        wf.nodes.push(node("start", NodeType::Task, "Again"));
        assert!(wf.validate().is_err());
    }

    #[test]
    fn test_validation_rejects_dangling_edges() {
        let mut unknown_target = start_end_workflow();
        unknown_target.edges.push(edge("e2", "end", "missing"));
        assert!(unknown_target.validate().is_err());

        let mut unknown_source = start_end_workflow();
        unknown_source.edges.push(edge("e3", "missing", "end"));
        assert!(unknown_source.validate().is_err());
    }

    #[test]
    fn test_failure_strategy_serde() {
        let defaulted: FailureStrategy = serde_json::from_str("{}").unwrap();
        assert_eq!(defaulted, FailureStrategy::Abort);

        let retry: FailureStrategy =
            serde_json::from_str(r#"{"type":"retry","max_attempts":3}"#).unwrap();
        assert_eq!(
            retry,
            FailureStrategy::Retry {
                max_attempts: 3,
                interval_ms: None
            }
        );

        assert_eq!(
            serde_json::to_value(&retry).unwrap(),
            serde_json::json!({"type": "retry", "max_attempts": 3})
        );
        assert_eq!(
            serde_json::to_value(FailureStrategy::Abort).unwrap(),
            serde_json::json!({"type": "abort"})
        );
    }
}