Skip to main content

mofa_foundation/workflow/dsl/
schema.rs

1//! Workflow DSL Schema
2//!
3//! Defines the data structures for declarative workflow configuration.
4
5use crate::workflow::node::NodeType;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9/// Workflow definition from YAML/TOML
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct WorkflowDefinition {
12    /// Workflow metadata
13    pub metadata: WorkflowMetadata,
14
15    /// Workflow configuration
16    #[serde(default)]
17    pub config: WorkflowConfig,
18
19    /// Node definitions
20    pub nodes: Vec<NodeDefinition>,
21
22    /// Edge definitions
23    #[serde(default)]
24    pub edges: Vec<EdgeDefinition>,
25
26    /// Agent definitions (inline or reusable)
27    #[serde(default)]
28    pub agents: HashMap<String, LlmAgentConfig>,
29}
30
31/// Workflow metadata
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct WorkflowMetadata {
34    /// Unique workflow identifier
35    pub id: String,
36
37    /// Human-readable name
38    pub name: String,
39
40    /// Workflow description
41    #[serde(default)]
42    pub description: String,
43
44    /// Workflow version
45    #[serde(default)]
46    pub version: Option<String>,
47
48    /// Author
49    #[serde(default)]
50    pub author: Option<String>,
51
52    /// Tags
53    #[serde(default)]
54    pub tags: Vec<String>,
55}
56
57/// Workflow-level configuration
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct WorkflowConfig {
60    /// Maximum parallel executions
61    #[serde(default = "default_max_parallel")]
62    pub max_parallel: usize,
63
64    /// Default timeout (milliseconds)
65    #[serde(default = "default_timeout")]
66    pub default_timeout_ms: u64,
67
68    /// Enable checkpoints
69    #[serde(default)]
70    pub enable_checkpoints: bool,
71
72    /// Retry policy for all nodes
73    #[serde(default)]
74    pub retry_policy: Option<RetryPolicy>,
75}
76
77impl Default for WorkflowConfig {
78    fn default() -> Self {
79        Self {
80            max_parallel: default_max_parallel(),
81            default_timeout_ms: default_timeout(),
82            enable_checkpoints: false,
83            retry_policy: None,
84        }
85    }
86}
87
/// Default for `WorkflowConfig::max_parallel`: ten parallel executions.
fn default_max_parallel() -> usize {
    const MAX_PARALLEL: usize = 10;
    MAX_PARALLEL
}
91
/// Default for `WorkflowConfig::default_timeout_ms`: one minute, in milliseconds.
fn default_timeout() -> u64 {
    const ONE_MINUTE_MS: u64 = 60 * 1_000;
    ONE_MINUTE_MS
}
95
96/// Node definition (tagged enum for different node types)
97#[derive(Debug, Clone, Serialize, Deserialize)]
98#[serde(tag = "type", rename_all = "snake_case")]
99pub enum NodeDefinition {
100    /// Start node
101    Start {
102        id: String,
103        #[serde(default)]
104        name: Option<String>,
105    },
106
107    /// End node
108    End {
109        id: String,
110        #[serde(default)]
111        name: Option<String>,
112    },
113
114    /// Task node (custom executor)
115    Task {
116        id: String,
117        name: String,
118        #[serde(flatten)]
119        executor: TaskExecutorDef,
120        #[serde(default)]
121        config: NodeConfigDef,
122    },
123
124    /// LLM Agent node
125    #[serde(rename = "llm_agent")]
126    LLM_AGENT {
127        id: String,
128        name: String,
129        /// Agent reference (agent_id for registry agents, inline for embedded)
130        agent: AgentRef,
131        /// Optional prompt template
132        #[serde(default)]
133        prompt_template: Option<String>,
134        #[serde(default)]
135        config: NodeConfigDef,
136    },
137
138    /// Condition node
139    Condition {
140        id: String,
141        name: String,
142        condition: ConditionDef,
143        #[serde(default)]
144        config: NodeConfigDef,
145    },
146
147    /// Parallel node
148    Parallel {
149        id: String,
150        name: String,
151        #[serde(default)]
152        config: NodeConfigDef,
153    },
154
155    /// Join node
156    Join {
157        id: String,
158        name: String,
159        /// List of node IDs to wait for
160        #[serde(default)]
161        wait_for: Vec<String>,
162        #[serde(default)]
163        config: NodeConfigDef,
164    },
165
166    /// Loop node
167    Loop {
168        id: String,
169        name: String,
170        #[serde(flatten)]
171        body: TaskExecutorDef,
172        condition: LoopConditionDef,
173        #[serde(default)]
174        max_iterations: u32,
175        #[serde(default)]
176        config: NodeConfigDef,
177    },
178
179    /// Transform node
180    Transform {
181        id: String,
182        name: String,
183        #[serde(flatten)]
184        transform: TransformDef,
185        #[serde(default)]
186        config: NodeConfigDef,
187    },
188
189    /// Sub-workflow node
190    SubWorkflow {
191        id: String,
192        name: String,
193        /// Reference to another workflow
194        workflow_id: String,
195        #[serde(default)]
196        config: NodeConfigDef,
197    },
198
199    /// Wait node
200    Wait {
201        id: String,
202        name: String,
203        /// Event type to wait for
204        event_type: String,
205        #[serde(default)]
206        config: NodeConfigDef,
207    },
208}
209
210impl NodeDefinition {
211    /// Get the node ID
212    pub fn id(&self) -> &str {
213        match self {
214            NodeDefinition::Start { id, .. } => id,
215            NodeDefinition::End { id, .. } => id,
216            NodeDefinition::Task { id, .. } => id,
217            NodeDefinition::LLM_AGENT { id, .. } => id,
218            NodeDefinition::Condition { id, .. } => id,
219            NodeDefinition::Parallel { id, .. } => id,
220            NodeDefinition::Join { id, .. } => id,
221            NodeDefinition::Loop { id, .. } => id,
222            NodeDefinition::Transform { id, .. } => id,
223            NodeDefinition::SubWorkflow { id, .. } => id,
224            NodeDefinition::Wait { id, .. } => id,
225        }
226    }
227
228    /// Get the node type
229    pub fn node_type(&self) -> NodeType {
230        match self {
231            NodeDefinition::Start { .. } => NodeType::Start,
232            NodeDefinition::End { .. } => NodeType::End,
233            NodeDefinition::Task { .. } => NodeType::Task,
234            NodeDefinition::LLM_AGENT { .. } => NodeType::Agent,
235            NodeDefinition::Condition { .. } => NodeType::Condition,
236            NodeDefinition::Parallel { .. } => NodeType::Parallel,
237            NodeDefinition::Join { .. } => NodeType::Join,
238            NodeDefinition::Loop { .. } => NodeType::Loop,
239            NodeDefinition::Transform { .. } => NodeType::Transform,
240            NodeDefinition::SubWorkflow { .. } => NodeType::SubWorkflow,
241            NodeDefinition::Wait { .. } => NodeType::Wait,
242        }
243    }
244}
245
246/// Edge definition
247#[derive(Debug, Clone, Serialize, Deserialize)]
248pub struct EdgeDefinition {
249    /// Source node ID
250    pub from: String,
251
252    /// Target node ID
253    pub to: String,
254
255    /// Conditional edge (optional)
256    #[serde(default)]
257    pub condition: Option<String>,
258
259    /// Edge label (optional)
260    #[serde(default)]
261    pub label: Option<String>,
262}
263
264/// LLM Agent configuration
265#[derive(Debug, Clone, Serialize, Deserialize)]
266pub struct LlmAgentConfig {
267    /// Model identifier
268    pub model: String,
269
270    /// System prompt
271    #[serde(default)]
272    pub system_prompt: Option<String>,
273
274    /// Temperature
275    #[serde(default)]
276    pub temperature: Option<f32>,
277
278    /// Max tokens
279    #[serde(default)]
280    pub max_tokens: Option<u32>,
281
282    /// Context window size (in rounds)
283    #[serde(default)]
284    pub context_window_size: Option<usize>,
285
286    /// User ID for persistence
287    #[serde(default)]
288    pub user_id: Option<String>,
289
290    /// Tenant ID for persistence
291    #[serde(default)]
292    pub tenant_id: Option<String>,
293}
294
295/// Agent reference (can be registry or inline)
296#[derive(Debug, Clone, Serialize, Deserialize)]
297#[serde(untagged)]
298pub enum AgentRef {
299    /// Reference to agent by ID (in registry)
300    Registry { agent_id: String },
301
302    /// Inline agent configuration
303    Inline(Box<LlmAgentConfig>),
304}
305
306/// Task executor definition
307#[derive(Debug, Clone, Serialize, Deserialize)]
308#[serde(tag = "executor_type", rename_all = "snake_case")]
309pub enum TaskExecutorDef {
310    /// Function executor (for code-defined tasks)
311    Function { function: String },
312
313    /// HTTP request executor
314    Http {
315        url: String,
316        #[serde(default)]
317        method: Option<String>,
318    },
319
320    /// Script executor (Rhai)
321    Script { script: String },
322
323    /// No-op executor (for testing)
324    None,
325}
326
327/// Condition definition
328#[derive(Debug, Clone, Serialize, Deserialize)]
329#[serde(tag = "condition_type", rename_all = "snake_case")]
330pub enum ConditionDef {
331    /// Expression-based condition
332    Expression { expr: String },
333
334    /// Value-based condition
335    Value {
336        field: String,
337        operator: String,
338        value: serde_json::Value,
339    },
340}
341
342/// Loop condition definition
343#[derive(Debug, Clone, Serialize, Deserialize)]
344#[serde(tag = "condition_type", rename_all = "snake_case")]
345pub enum LoopConditionDef {
346    /// While-style loop
347    While { expr: String },
348
349    /// Until-style loop
350    Until { expr: String },
351
352    /// Count-based loop
353    Count { max: u32 },
354}
355
356/// Transform definition
357#[derive(Debug, Clone, Serialize, Deserialize)]
358#[serde(tag = "transform_type", rename_all = "snake_case")]
359pub enum TransformDef {
360    /// Jinja-style template
361    Template { template: String },
362
363    /// JavaScript expression
364    Expression { expr: String },
365
366    /// Map/reduce operation
367    MapReduce {
368        #[serde(default)]
369        map: Option<String>,
370        #[serde(default)]
371        reduce: Option<String>,
372    },
373}
374
375/// Node-level configuration
376#[derive(Debug, Clone, Serialize, Deserialize, Default)]
377pub struct NodeConfigDef {
378    /// Retry policy
379    #[serde(default)]
380    pub retry_policy: Option<RetryPolicy>,
381
382    /// Timeout (milliseconds)
383    #[serde(default)]
384    pub timeout_ms: Option<u64>,
385
386    /// Custom metadata
387    #[serde(default)]
388    pub metadata: HashMap<String, String>,
389}
390
391/// Retry policy configuration
392#[derive(Debug, Clone, Serialize, Deserialize)]
393pub struct RetryPolicy {
394    /// Maximum retry attempts
395    #[serde(default = "default_max_retries")]
396    pub max_retries: u32,
397
398    /// Delay between retries (milliseconds)
399    #[serde(default = "default_retry_delay")]
400    pub retry_delay_ms: u64,
401
402    /// Enable exponential backoff
403    #[serde(default = "default_exponential_backoff")]
404    pub exponential_backoff: bool,
405
406    /// Maximum delay (milliseconds)
407    #[serde(default = "default_max_delay")]
408    pub max_delay_ms: u64,
409}
410
411impl Default for RetryPolicy {
412    fn default() -> Self {
413        Self {
414            max_retries: default_max_retries(),
415            retry_delay_ms: default_retry_delay(),
416            exponential_backoff: default_exponential_backoff(),
417            max_delay_ms: default_max_delay(),
418        }
419    }
420}
421
/// Default for `RetryPolicy::max_retries`: three attempts.
fn default_max_retries() -> u32 {
    const MAX_RETRIES: u32 = 3;
    MAX_RETRIES
}
425
/// Default for `RetryPolicy::retry_delay_ms`: one second, in milliseconds.
fn default_retry_delay() -> u64 {
    const RETRY_DELAY_MS: u64 = 1_000;
    RETRY_DELAY_MS
}
429
/// Default for `RetryPolicy::exponential_backoff`: enabled.
fn default_exponential_backoff() -> bool {
    const BACKOFF_ENABLED: bool = true;
    BACKOFF_ENABLED
}
433
/// Default for `RetryPolicy::max_delay_ms`: thirty seconds, in milliseconds.
fn default_max_delay() -> u64 {
    const MAX_DELAY_MS: u64 = 30 * 1_000;
    MAX_DELAY_MS
}
437
438/// Timeout configuration
439#[derive(Debug, Clone, Serialize, Deserialize)]
440pub struct TimeoutConfig {
441    /// Execution timeout (milliseconds)
442    pub execution_timeout_ms: u64,
443
444    /// Cancel on timeout
445    #[serde(default = "default_cancel_on_timeout")]
446    pub cancel_on_timeout: bool,
447}
448
449impl Default for TimeoutConfig {
450    fn default() -> Self {
451        Self {
452            execution_timeout_ms: 60000,
453            cancel_on_timeout: true,
454        }
455    }
456}
457
/// Default for `TimeoutConfig::cancel_on_timeout`: cancel when the timeout fires.
fn default_cancel_on_timeout() -> bool {
    const CANCEL_ON_TIMEOUT: bool = true;
    CANCEL_ON_TIMEOUT
}
461
#[cfg(test)]
mod tests {
    use super::*;

    /// A YAML document with start, LLM agent and end nodes plus two edges
    /// deserializes into a full definition.
    #[test]
    fn test_parse_workflow_yaml() {
        let source = r#"
metadata:
  id: test_workflow
  name: Test Workflow
  description: A test workflow

nodes:
  - type: start
    id: start

  - type: llm_agent
    id: agent1
    name: First Agent
    agent:
      agent_id: my_agent

  - type: end
    id: end

edges:
  - from: start
    to: agent1
  - from: agent1
    to: end
"#;

        let parsed = serde_yaml::from_str::<WorkflowDefinition>(source).unwrap();
        assert_eq!(parsed.metadata.id, "test_workflow");
        assert_eq!(parsed.nodes.len(), 3);
        assert_eq!(parsed.edges.len(), 2);
    }

    /// The top-level `agents` map is parsed into `LlmAgentConfig` entries.
    #[test]
    fn test_parse_agent_config() {
        let source = r#"
metadata:
  id: agent_config_test
  name: Agent Config Test

nodes:
  - type: start
    id: start
  - type: end
    id: end

agents:
  my_agent:
    model: gpt-4
    system_prompt: "You are helpful"
    temperature: 0.7
    max_tokens: 2000
"#;

        let parsed = serde_yaml::from_str::<WorkflowDefinition>(source).unwrap();
        assert_eq!(parsed.agents.len(), 1);
        let my_agent = parsed.agents.get("my_agent").unwrap();
        assert_eq!(my_agent.model, "gpt-4");
        assert_eq!(my_agent.temperature, Some(0.7));
    }

    /// The same schema is accepted from a TOML document.
    #[test]
    fn test_parse_toml() {
        let document = r#"
[metadata]
id = "test_workflow"
name = "Test Workflow"

[[nodes]]
type = "start"
id = "start"

[[nodes]]
type = "end"
id = "end"

[[edges]]
from = "start"
to = "end"
"#;

        let parsed = toml::from_str::<WorkflowDefinition>(document).unwrap();
        assert_eq!(parsed.metadata.id, "test_workflow");
        assert_eq!(parsed.nodes.len(), 2);
    }
}