1use crate::workflow::node::NodeType;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct WorkflowDefinition {
12 pub metadata: WorkflowMetadata,
14
15 #[serde(default)]
17 pub config: WorkflowConfig,
18
19 pub nodes: Vec<NodeDefinition>,
21
22 #[serde(default)]
24 pub edges: Vec<EdgeDefinition>,
25
26 #[serde(default)]
28 pub agents: HashMap<String, LlmAgentConfig>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct WorkflowMetadata {
34 pub id: String,
36
37 pub name: String,
39
40 #[serde(default)]
42 pub description: String,
43
44 #[serde(default)]
46 pub version: Option<String>,
47
48 #[serde(default)]
50 pub author: Option<String>,
51
52 #[serde(default)]
54 pub tags: Vec<String>,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct WorkflowConfig {
60 #[serde(default = "default_max_parallel")]
62 pub max_parallel: usize,
63
64 #[serde(default = "default_timeout")]
66 pub default_timeout_ms: u64,
67
68 #[serde(default)]
70 pub enable_checkpoints: bool,
71
72 #[serde(default)]
74 pub retry_policy: Option<RetryPolicy>,
75}
76
77impl Default for WorkflowConfig {
78 fn default() -> Self {
79 Self {
80 max_parallel: default_max_parallel(),
81 default_timeout_ms: default_timeout(),
82 enable_checkpoints: false,
83 retry_policy: None,
84 }
85 }
86}
87
/// Serde fallback for `WorkflowConfig::max_parallel`.
fn default_max_parallel() -> usize {
    const DEFAULT_MAX_PARALLEL: usize = 10;
    DEFAULT_MAX_PARALLEL
}
91
/// Serde fallback for `WorkflowConfig::default_timeout_ms`, in milliseconds.
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT_MS: u64 = 60_000;
    DEFAULT_TIMEOUT_MS
}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
98#[serde(tag = "type", rename_all = "snake_case")]
99pub enum NodeDefinition {
100 Start {
102 id: String,
103 #[serde(default)]
104 name: Option<String>,
105 },
106
107 End {
109 id: String,
110 #[serde(default)]
111 name: Option<String>,
112 },
113
114 Task {
116 id: String,
117 name: String,
118 #[serde(flatten)]
119 executor: TaskExecutorDef,
120 #[serde(default)]
121 config: NodeConfigDef,
122 },
123
124 #[serde(rename = "llm_agent")]
126 LLM_AGENT {
127 id: String,
128 name: String,
129 agent: AgentRef,
131 #[serde(default)]
133 prompt_template: Option<String>,
134 #[serde(default)]
135 config: NodeConfigDef,
136 },
137
138 Condition {
140 id: String,
141 name: String,
142 condition: ConditionDef,
143 #[serde(default)]
144 config: NodeConfigDef,
145 },
146
147 Parallel {
149 id: String,
150 name: String,
151 #[serde(default)]
152 config: NodeConfigDef,
153 },
154
155 Join {
157 id: String,
158 name: String,
159 #[serde(default)]
161 wait_for: Vec<String>,
162 #[serde(default)]
163 config: NodeConfigDef,
164 },
165
166 Loop {
168 id: String,
169 name: String,
170 #[serde(flatten)]
171 body: TaskExecutorDef,
172 condition: LoopConditionDef,
173 #[serde(default)]
174 max_iterations: u32,
175 #[serde(default)]
176 config: NodeConfigDef,
177 },
178
179 Transform {
181 id: String,
182 name: String,
183 #[serde(flatten)]
184 transform: TransformDef,
185 #[serde(default)]
186 config: NodeConfigDef,
187 },
188
189 SubWorkflow {
191 id: String,
192 name: String,
193 workflow_id: String,
195 #[serde(default)]
196 config: NodeConfigDef,
197 },
198
199 Wait {
201 id: String,
202 name: String,
203 event_type: String,
205 #[serde(default)]
206 config: NodeConfigDef,
207 },
208}
209
210impl NodeDefinition {
211 pub fn id(&self) -> &str {
213 match self {
214 NodeDefinition::Start { id, .. } => id,
215 NodeDefinition::End { id, .. } => id,
216 NodeDefinition::Task { id, .. } => id,
217 NodeDefinition::LLM_AGENT { id, .. } => id,
218 NodeDefinition::Condition { id, .. } => id,
219 NodeDefinition::Parallel { id, .. } => id,
220 NodeDefinition::Join { id, .. } => id,
221 NodeDefinition::Loop { id, .. } => id,
222 NodeDefinition::Transform { id, .. } => id,
223 NodeDefinition::SubWorkflow { id, .. } => id,
224 NodeDefinition::Wait { id, .. } => id,
225 }
226 }
227
228 pub fn node_type(&self) -> NodeType {
230 match self {
231 NodeDefinition::Start { .. } => NodeType::Start,
232 NodeDefinition::End { .. } => NodeType::End,
233 NodeDefinition::Task { .. } => NodeType::Task,
234 NodeDefinition::LLM_AGENT { .. } => NodeType::Agent,
235 NodeDefinition::Condition { .. } => NodeType::Condition,
236 NodeDefinition::Parallel { .. } => NodeType::Parallel,
237 NodeDefinition::Join { .. } => NodeType::Join,
238 NodeDefinition::Loop { .. } => NodeType::Loop,
239 NodeDefinition::Transform { .. } => NodeType::Transform,
240 NodeDefinition::SubWorkflow { .. } => NodeType::SubWorkflow,
241 NodeDefinition::Wait { .. } => NodeType::Wait,
242 }
243 }
244}
245
246#[derive(Debug, Clone, Serialize, Deserialize)]
248pub struct EdgeDefinition {
249 pub from: String,
251
252 pub to: String,
254
255 #[serde(default)]
257 pub condition: Option<String>,
258
259 #[serde(default)]
261 pub label: Option<String>,
262}
263
264#[derive(Debug, Clone, Serialize, Deserialize)]
266pub struct LlmAgentConfig {
267 pub model: String,
269
270 #[serde(default)]
272 pub system_prompt: Option<String>,
273
274 #[serde(default)]
276 pub temperature: Option<f32>,
277
278 #[serde(default)]
280 pub max_tokens: Option<u32>,
281
282 #[serde(default)]
284 pub context_window_size: Option<usize>,
285
286 #[serde(default)]
288 pub user_id: Option<String>,
289
290 #[serde(default)]
292 pub tenant_id: Option<String>,
293}
294
295#[derive(Debug, Clone, Serialize, Deserialize)]
297#[serde(untagged)]
298pub enum AgentRef {
299 Registry { agent_id: String },
301
302 Inline(Box<LlmAgentConfig>),
304}
305
306#[derive(Debug, Clone, Serialize, Deserialize)]
308#[serde(tag = "executor_type", rename_all = "snake_case")]
309pub enum TaskExecutorDef {
310 Function { function: String },
312
313 Http {
315 url: String,
316 #[serde(default)]
317 method: Option<String>,
318 },
319
320 Script { script: String },
322
323 None,
325}
326
327#[derive(Debug, Clone, Serialize, Deserialize)]
329#[serde(tag = "condition_type", rename_all = "snake_case")]
330pub enum ConditionDef {
331 Expression { expr: String },
333
334 Value {
336 field: String,
337 operator: String,
338 value: serde_json::Value,
339 },
340}
341
342#[derive(Debug, Clone, Serialize, Deserialize)]
344#[serde(tag = "condition_type", rename_all = "snake_case")]
345pub enum LoopConditionDef {
346 While { expr: String },
348
349 Until { expr: String },
351
352 Count { max: u32 },
354}
355
356#[derive(Debug, Clone, Serialize, Deserialize)]
358#[serde(tag = "transform_type", rename_all = "snake_case")]
359pub enum TransformDef {
360 Template { template: String },
362
363 Expression { expr: String },
365
366 MapReduce {
368 #[serde(default)]
369 map: Option<String>,
370 #[serde(default)]
371 reduce: Option<String>,
372 },
373}
374
375#[derive(Debug, Clone, Serialize, Deserialize, Default)]
377pub struct NodeConfigDef {
378 #[serde(default)]
380 pub retry_policy: Option<RetryPolicy>,
381
382 #[serde(default)]
384 pub timeout_ms: Option<u64>,
385
386 #[serde(default)]
388 pub metadata: HashMap<String, String>,
389}
390
391#[derive(Debug, Clone, Serialize, Deserialize)]
393pub struct RetryPolicy {
394 #[serde(default = "default_max_retries")]
396 pub max_retries: u32,
397
398 #[serde(default = "default_retry_delay")]
400 pub retry_delay_ms: u64,
401
402 #[serde(default = "default_exponential_backoff")]
404 pub exponential_backoff: bool,
405
406 #[serde(default = "default_max_delay")]
408 pub max_delay_ms: u64,
409}
410
411impl Default for RetryPolicy {
412 fn default() -> Self {
413 Self {
414 max_retries: default_max_retries(),
415 retry_delay_ms: default_retry_delay(),
416 exponential_backoff: default_exponential_backoff(),
417 max_delay_ms: default_max_delay(),
418 }
419 }
420}
421
/// Serde fallback for `RetryPolicy::max_retries`.
fn default_max_retries() -> u32 {
    const DEFAULT_MAX_RETRIES: u32 = 3;
    DEFAULT_MAX_RETRIES
}
425
/// Serde fallback for `RetryPolicy::retry_delay_ms`, in milliseconds.
fn default_retry_delay() -> u64 {
    const DEFAULT_RETRY_DELAY_MS: u64 = 1_000;
    DEFAULT_RETRY_DELAY_MS
}
429
/// Serde fallback for `RetryPolicy::exponential_backoff`.
fn default_exponential_backoff() -> bool {
    const BACKOFF_ENABLED_BY_DEFAULT: bool = true;
    BACKOFF_ENABLED_BY_DEFAULT
}
433
/// Serde fallback for `RetryPolicy::max_delay_ms`, in milliseconds.
fn default_max_delay() -> u64 {
    const DEFAULT_MAX_DELAY_MS: u64 = 30_000;
    DEFAULT_MAX_DELAY_MS
}
437
438#[derive(Debug, Clone, Serialize, Deserialize)]
440pub struct TimeoutConfig {
441 pub execution_timeout_ms: u64,
443
444 #[serde(default = "default_cancel_on_timeout")]
446 pub cancel_on_timeout: bool,
447}
448
449impl Default for TimeoutConfig {
450 fn default() -> Self {
451 Self {
452 execution_timeout_ms: 60000,
453 cancel_on_timeout: true,
454 }
455 }
456}
457
458fn default_cancel_on_timeout() -> bool {
459 true
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    /// A YAML workflow with three nodes and two edges parses into the
    /// expected shape, and the `llm_agent` tag selects the agent variant.
    #[test]
    fn test_parse_workflow_yaml() {
        // YAML nesting is indentation-sensitive: the keys of each list item
        // must be indented deeper than the `-` that opens the item.
        let yaml = r#"
metadata:
  id: test_workflow
  name: Test Workflow
  description: A test workflow

nodes:
  - type: start
    id: start

  - type: llm_agent
    id: agent1
    name: First Agent
    agent:
      agent_id: my_agent

  - type: end
    id: end

edges:
  - from: start
    to: agent1
  - from: agent1
    to: end
"#;

        let def: WorkflowDefinition =
            serde_yaml::from_str(yaml).expect("workflow YAML fixture should parse");
        assert_eq!(def.metadata.id, "test_workflow");
        assert_eq!(def.nodes.len(), 3);
        assert_eq!(def.edges.len(), 2);
        assert!(matches!(def.nodes[1], NodeDefinition::LLM_AGENT { .. }));
    }

    /// The top-level `agents` map is parsed into `LlmAgentConfig` entries.
    #[test]
    fn test_parse_agent_config() {
        let yaml = r#"
metadata:
  id: agent_config_test
  name: Agent Config Test

nodes:
  - type: start
    id: start
  - type: end
    id: end

agents:
  my_agent:
    model: gpt-4
    system_prompt: "You are helpful"
    temperature: 0.7
    max_tokens: 2000
"#;

        let def: WorkflowDefinition =
            serde_yaml::from_str(yaml).expect("agent config YAML fixture should parse");
        assert_eq!(def.agents.len(), 1);
        let agent = def
            .agents
            .get("my_agent")
            .expect("`my_agent` should be present in the agents map");
        assert_eq!(agent.model, "gpt-4");
        assert_eq!(agent.temperature, Some(0.7));
    }

    /// The same data model can be read from TOML.
    #[test]
    fn test_parse_toml() {
        let toml = r#"
[metadata]
id = "test_workflow"
name = "Test Workflow"

[[nodes]]
type = "start"
id = "start"

[[nodes]]
type = "end"
id = "end"

[[edges]]
from = "start"
to = "end"
"#;

        let def: WorkflowDefinition =
            toml::from_str(toml).expect("workflow TOML fixture should parse");
        assert_eq!(def.metadata.id, "test_workflow");
        assert_eq!(def.nodes.len(), 2);
    }
}