1use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
10#[serde(rename_all = "snake_case")]
11pub enum NodeType {
12 Start,
14 End,
16 Task,
18 Condition,
20 Parallel,
22 SubWorkflow,
24 Wait,
26 Approval,
28}
29
30#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
32#[serde(rename_all = "snake_case")]
33pub enum FailureStrategyType {
34 Retry,
35 Ignore,
36 Abort,
37 Goto,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct FailureStrategyConfig {
43 #[serde(rename = "type", default = "default_failure_strategy_type")]
45 pub strategy_type: FailureStrategyType,
46 #[serde(skip_serializing_if = "Option::is_none")]
48 pub max_attempts: Option<u32>,
49 #[serde(skip_serializing_if = "Option::is_none")]
51 pub interval_ms: Option<u64>,
52 #[serde(skip_serializing_if = "Option::is_none")]
54 pub target: Option<String>,
55}
56
57fn default_failure_strategy_type() -> FailureStrategyType {
58 FailureStrategyType::Abort
59}
60
61impl From<FailureStrategyConfig> for FailureStrategy {
62 fn from(config: FailureStrategyConfig) -> Self {
63 match config.strategy_type {
64 FailureStrategyType::Retry => FailureStrategy::Retry {
65 max_attempts: config.max_attempts.unwrap_or(1),
66 interval_ms: config.interval_ms,
67 },
68 FailureStrategyType::Ignore => FailureStrategy::Ignore,
69 FailureStrategyType::Abort => FailureStrategy::Abort,
70 FailureStrategyType::Goto => FailureStrategy::Goto {
71 target: config.target.unwrap_or_default(),
72 },
73 }
74 }
75}
76
77impl From<FailureStrategy> for FailureStrategyConfig {
78 fn from(strategy: FailureStrategy) -> Self {
79 match strategy {
80 FailureStrategy::Retry { max_attempts, interval_ms } => FailureStrategyConfig {
81 strategy_type: FailureStrategyType::Retry,
82 max_attempts: Some(max_attempts),
83 interval_ms,
84 target: None,
85 },
86 FailureStrategy::Ignore => FailureStrategyConfig {
87 strategy_type: FailureStrategyType::Ignore,
88 max_attempts: None,
89 interval_ms: None,
90 target: None,
91 },
92 FailureStrategy::Abort => FailureStrategyConfig {
93 strategy_type: FailureStrategyType::Abort,
94 max_attempts: None,
95 interval_ms: None,
96 target: None,
97 },
98 FailureStrategy::Goto { target } => FailureStrategyConfig {
99 strategy_type: FailureStrategyType::Goto,
100 max_attempts: None,
101 interval_ms: None,
102 target: Some(target),
103 },
104 }
105 }
106}
107
/// Failure-handling strategy, used by `NodeDef::on_failure` and
/// `WorkflowDef::default_failure_strategy`.
///
/// It has no serde derives: (de)serialization goes through
/// [`FailureStrategyConfig`] via hand-written impls. `Default` is `Abort`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum FailureStrategy {
    /// Retry with an attempt limit and an optional interval.
    Retry {
        max_attempts: u32,
        interval_ms: Option<u64>,
    },
    Ignore,
    #[default]
    Abort,
    /// Continue at the named `target`.
    Goto {
        target: String,
    },
}
130
131
132impl Serialize for FailureStrategy {
133 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
134 where
135 S: serde::Serializer,
136 {
137 let config: FailureStrategyConfig = self.clone().into();
138 config.serialize(serializer)
139 }
140}
141
142impl<'de> Deserialize<'de> for FailureStrategy {
143 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
144 where
145 D: serde::Deserializer<'de>,
146 {
147 let config: FailureStrategyConfig = FailureStrategyConfig::deserialize(deserializer)?;
148 Ok(config.into())
149 }
150}
151
152#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct EdgeDef {
155 #[serde(default = "generate_edge_id")]
157 pub id: String,
158 pub from: String,
160 pub to: String,
162 #[serde(skip_serializing_if = "Option::is_none")]
164 pub condition: Option<String>,
165 #[serde(skip_serializing_if = "Option::is_none")]
167 pub label: Option<String>,
168}
169
170fn generate_edge_id() -> String {
171 format!("edge_{}", uuid::Uuid::new_v4())
172}
173
174#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct NodeDef {
177 pub id: String,
179 #[serde(rename = "type")]
181 pub node_type: NodeType,
182 pub name: String,
184 #[serde(skip_serializing_if = "Option::is_none")]
186 pub description: Option<String>,
187 #[serde(skip_serializing_if = "Option::is_none")]
189 pub task: Option<String>,
190 #[serde(default)]
192 pub params: HashMap<String, serde_json::Value>,
193 #[serde(default)]
195 pub on_failure: FailureStrategy,
196 #[serde(skip_serializing_if = "Option::is_none")]
198 pub timeout_ms: Option<u64>,
199 #[serde(skip_serializing_if = "Option::is_none")]
201 pub branches: Option<Vec<BranchDef>>,
202 #[serde(skip_serializing_if = "Option::is_none")]
204 pub parallel_branches: Option<Vec<ParallelBranchDef>>,
205 #[serde(skip_serializing_if = "Option::is_none")]
207 pub workflow: Option<String>,
208 #[serde(skip_serializing_if = "Option::is_none")]
210 pub wait_ms: Option<u64>,
211 #[serde(skip_serializing_if = "Option::is_none")]
213 pub approvers: Option<Vec<String>>,
214}
215
216#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct BranchDef {
219 pub name: String,
221 pub condition: String,
223 pub target: String,
225}
226
227#[derive(Debug, Clone, Serialize, Deserialize)]
229pub struct ParallelBranchDef {
230 pub name: String,
232 pub nodes: Vec<NodeDef>,
234}
235
236#[derive(Debug, Clone, Serialize, Deserialize)]
238pub struct WorkflowDef {
239 pub id: String,
241 pub name: String,
243 #[serde(default = "default_version")]
245 pub version: String,
246 #[serde(skip_serializing_if = "Option::is_none")]
248 pub description: Option<String>,
249 #[serde(default)]
251 pub inputs: Vec<InputDef>,
252 #[serde(default)]
254 pub outputs: Vec<OutputDef>,
255 pub nodes: Vec<NodeDef>,
257 #[serde(default)]
259 pub edges: Vec<EdgeDef>,
260 #[serde(default)]
262 pub variables: HashMap<String, serde_json::Value>,
263 #[serde(default)]
265 pub default_failure_strategy: FailureStrategy,
266 #[serde(skip_serializing_if = "Option::is_none")]
268 pub timeout_ms: Option<u64>,
269}
270
271fn default_version() -> String {
272 "1.0.0".to_string()
273}
274
275#[derive(Debug, Clone, Serialize, Deserialize)]
277pub struct InputDef {
278 pub name: String,
280 #[serde(rename = "type", default = "default_input_type")]
282 pub input_type: String,
283 #[serde(default)]
285 pub required: bool,
286 #[serde(skip_serializing_if = "Option::is_none")]
288 pub default: Option<serde_json::Value>,
289 #[serde(skip_serializing_if = "Option::is_none")]
291 pub description: Option<String>,
292}
293
294fn default_input_type() -> String {
295 "string".to_string()
296}
297
298#[derive(Debug, Clone, Serialize, Deserialize)]
300pub struct OutputDef {
301 pub name: String,
303 pub value: String,
305 #[serde(skip_serializing_if = "Option::is_none")]
307 pub description: Option<String>,
308}
309
310impl WorkflowDef {
311 pub fn get_node(&self, id: &str) -> Option<&NodeDef> {
313 self.nodes.iter().find(|n| n.id == id)
314 }
315
316 pub fn get_start_node(&self) -> Option<&NodeDef> {
318 self.nodes.iter().find(|n| n.node_type == NodeType::Start)
319 }
320
321 pub fn get_end_node(&self) -> Option<&NodeDef> {
323 self.nodes.iter().find(|n| n.node_type == NodeType::End)
324 }
325
326 pub fn get_outgoing_edges(&self, node_id: &str) -> Vec<&EdgeDef> {
328 self.edges.iter().filter(|e| e.from == node_id).collect()
329 }
330
331 pub fn validate(&self) -> anyhow::Result<()> {
333 if self.get_start_node().is_none() {
335 anyhow::bail!("Workflow must have a start node");
336 }
337
338 if self.get_end_node().is_none() {
340 anyhow::bail!("Workflow must have an end node");
341 }
342
343 let mut node_ids = std::collections::HashSet::new();
345 for node in &self.nodes {
346 if !node_ids.insert(&node.id) {
347 anyhow::bail!("Duplicate node id: {}", node.id);
348 }
349 }
350
351 for edge in &self.edges {
353 if !node_ids.contains(&edge.from) {
354 anyhow::bail!("Edge references unknown source node: {}", edge.from);
355 }
356 if !node_ids.contains(&edge.to) {
357 anyhow::bail!("Edge references unknown target node: {}", edge.to);
358 }
359 }
360
361 for input in &self.inputs {
363 if input.required && input.default.is_none() {
364 }
366 }
367
368 Ok(())
369 }
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a node with the given id, type and name and every optional field unset.
    fn node(id: &str, node_type: NodeType, name: &str) -> NodeDef {
        NodeDef {
            id: id.to_string(),
            node_type,
            name: name.to_string(),
            description: None,
            task: None,
            params: HashMap::new(),
            on_failure: FailureStrategy::default(),
            timeout_ms: None,
            branches: None,
            parallel_branches: None,
            workflow: None,
            wait_ms: None,
            approvers: None,
        }
    }

    /// Builds an unlabeled, unconditional edge.
    fn edge(id: &str, from: &str, to: &str) -> EdgeDef {
        EdgeDef {
            id: id.to_string(),
            from: from.to_string(),
            to: to.to_string(),
            condition: None,
            label: None,
        }
    }

    /// Wraps nodes and edges in a workflow with every other field at its default.
    fn workflow(nodes: Vec<NodeDef>, edges: Vec<EdgeDef>) -> WorkflowDef {
        WorkflowDef {
            id: "test-workflow".to_string(),
            name: "Test Workflow".to_string(),
            version: "1.0.0".to_string(),
            description: None,
            inputs: vec![],
            outputs: vec![],
            nodes,
            edges,
            variables: HashMap::new(),
            default_failure_strategy: FailureStrategy::default(),
            timeout_ms: None,
        }
    }

    /// start -> end, the smallest workflow `validate` accepts.
    fn minimal_workflow() -> WorkflowDef {
        workflow(
            vec![
                node("start", NodeType::Start, "Start"),
                node("end", NodeType::End, "End"),
            ],
            vec![edge("e1", "start", "end")],
        )
    }

    fn validation_error(wf: &WorkflowDef) -> String {
        wf.validate()
            .expect_err("workflow should be rejected")
            .to_string()
    }

    #[test]
    fn test_workflow_def_validation() {
        assert!(minimal_workflow().validate().is_ok());
    }

    #[test]
    fn test_validation_rejects_missing_start() {
        let wf = workflow(vec![node("end", NodeType::End, "End")], vec![]);
        assert_eq!(validation_error(&wf), "Workflow must have a start node");
    }

    #[test]
    fn test_validation_rejects_missing_end() {
        let wf = workflow(vec![node("start", NodeType::Start, "Start")], vec![]);
        assert_eq!(validation_error(&wf), "Workflow must have an end node");
    }

    #[test]
    fn test_validation_rejects_duplicate_node_id() {
        let mut wf = minimal_workflow();
        wf.nodes.push(node("end", NodeType::Task, "Another"));
        assert_eq!(validation_error(&wf), "Duplicate node id: end");
    }

    #[test]
    fn test_validation_rejects_dangling_edges() {
        let mut wf = minimal_workflow();
        wf.edges.push(edge("e2", "ghost", "end"));
        assert_eq!(
            validation_error(&wf),
            "Edge references unknown source node: ghost"
        );

        let mut wf = minimal_workflow();
        wf.edges.push(edge("e3", "start", "ghost"));
        assert_eq!(
            validation_error(&wf),
            "Edge references unknown target node: ghost"
        );
    }

    #[test]
    fn test_failure_strategy_default_is_abort() {
        assert_eq!(FailureStrategy::default(), FailureStrategy::Abort);
        let parsed: FailureStrategy = serde_json::from_value(serde_json::json!({})).unwrap();
        assert_eq!(parsed, FailureStrategy::Abort);
    }

    #[test]
    fn test_failure_strategy_serde_round_trip() {
        let strategies = [
            FailureStrategy::Retry {
                max_attempts: 3,
                interval_ms: Some(500),
            },
            FailureStrategy::Retry {
                max_attempts: 2,
                interval_ms: None,
            },
            FailureStrategy::Ignore,
            FailureStrategy::Abort,
            FailureStrategy::Goto {
                target: "end".to_string(),
            },
        ];
        for strategy in strategies {
            let json = serde_json::to_value(&strategy).unwrap();
            let back: FailureStrategy = serde_json::from_value(json).unwrap();
            assert_eq!(back, strategy);
        }
    }
}
431}