Skip to main content

forge_agent/workflow/
yaml.rs

1//! YAML workflow definition and parsing.
2//!
3//! Provides YAML-based workflow definition for simple workflows.
4//! Complex workflows with custom task types should use the Rust API.
5
6use crate::workflow::{
7    task::TaskId,
8    tasks::{GraphQueryTask, GraphQueryType, AgentLoopTask, ShellCommandTask},
9    dag::{Workflow, WorkflowError},
10    builder::WorkflowBuilder,
11};
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use std::collections::HashMap;
15use std::path::Path;
16use thiserror::Error;
17
/// Top-level workflow document as written in a YAML file.
///
/// Only `name` and `tasks` are required; `version` and `description` may be
/// omitted. Turn a parsed value into an executable [`Workflow`] with
/// `Workflow::try_from` (or parse and convert in one go with
/// [`load_workflow_from_string`]).
///
/// ```yaml
/// name: "My Workflow"
/// version: "1.0"            # optional
/// description: "Example"    # optional
/// tasks:
///   - id: "task1"
///     name: "Task 1"
///     type: GRAPH_QUERY
///     params:
///       query_type: "find_symbol"
///       target: "my_function"
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct YamlWorkflow {
    /// Workflow name (required).
    pub name: String,
    /// Optional version for future compatibility; `None` when the key is omitted.
    #[serde(default)]
    pub version: Option<String>,
    /// Optional free-form description; `None` when the key is omitted.
    #[serde(default)]
    pub description: Option<String>,
    /// Workflow tasks, in declaration order (required).
    pub tasks: Vec<YamlTask>,
}
32
/// A single task entry inside a YAML workflow's `tasks` list.
///
/// ```yaml
/// - id: "analyze"
///   name: "Analyze Impact"
///   type: GRAPH_QUERY
///   depends_on: ["find"]
///   params:
///     query_type: "impact"
///     target: "process_data"
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct YamlTask {
    /// Task identifier, referenced from other tasks' `depends_on` lists.
    /// Expected to be unique within the workflow; uniqueness is not checked
    /// during deserialization.
    pub id: String,
    /// Human-readable task name.
    pub name: String,
    /// Task type, written under the YAML key `type` (a Rust keyword, hence the
    /// field rename).
    #[serde(rename = "type")]
    pub task_type: YamlTaskType,
    /// IDs of tasks that must run before this one; empty when the key is omitted.
    #[serde(default)]
    pub depends_on: Vec<String>,
    /// Task-specific parameters; an empty map when the `params` key is omitted.
    /// Which keys are read depends on `task_type`.
    #[serde(default)]
    pub params: YamlTaskParams,
}
50
51/// Task type enumeration for YAML workflows.
52#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
53#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
54pub enum YamlTaskType {
55    /// Graph query task (find symbols, references, impact)
56    GraphQuery,
57    /// Agent loop task (AI-driven operations)
58    AgentLoop,
59    /// Shell command task (stub for Phase 11)
60    Shell,
61}
62
/// Task parameters as untyped JSON values.
///
/// Every key under a task's `params:` mapping ends up in [`params`](Self::params);
/// thanks to `#[serde(flatten)]` the YAML needs no extra nesting level. No
/// validation happens here — required keys and their types are checked when
/// the workflow is converted into a [`Workflow`].
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct YamlTaskParams {
    /// Parameter name → raw value (string, number, list, ...).
    #[serde(flatten)]
    pub params: HashMap<String, Value>,
}
70
/// Errors that can occur while loading a YAML workflow.
///
/// Covers every stage: reading the file ([`Io`](Self::Io)), parsing the YAML
/// ([`YamlParse`](Self::YamlParse)), and converting the parsed document into a
/// [`Workflow`] (the remaining variants).
#[derive(Error, Debug)]
pub enum YamlWorkflowError {
    /// The document is structurally invalid for conversion — for example an
    /// unknown `query_type`, or a workflow with no tasks.
    #[error("Invalid workflow schema: {0}")]
    InvalidSchema(String),

    /// An unrecognised task type.
    ///
    /// NOTE(review): not constructed anywhere in this module; unknown `type`
    /// values are rejected by serde while deserializing [`YamlTaskType`] and
    /// surface as [`YamlParse`](Self::YamlParse) instead.
    #[error("Invalid task type: {0}")]
    InvalidTaskType(String),

    /// A required task parameter is absent, or present but not a string. The
    /// payload is the parameter name, e.g. `"target"`.
    #[error("Missing required parameter: {0}")]
    MissingParameter(String),

    /// Building the workflow graph failed for a reason other than an empty
    /// workflow (for example a dependency cycle); the payload describes the
    /// underlying builder error.
    #[error("Conversion error: {0}")]
    ConversionError(String),

    /// The workflow file could not be read.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// The text is not valid YAML, or does not match the expected document
    /// shape (missing required fields, unknown task `type`, ...).
    #[error("YAML parsing error: {0}")]
    YamlParse(#[from] serde_yaml::Error),
}
98
99impl TryFrom<YamlWorkflow> for Workflow {
100    type Error = YamlWorkflowError;
101
102    fn try_from(yaml_workflow: YamlWorkflow) -> Result<Self, Self::Error> {
103        let mut builder = WorkflowBuilder::new();
104
105        // Add all tasks first
106        for yaml_task in &yaml_workflow.tasks {
107            let task_id = TaskId::new(yaml_task.id.clone());
108
109            match yaml_task.task_type {
110                YamlTaskType::GraphQuery => {
111                    // Extract required parameters
112                    let query_type_str = yaml_task.params.params.get("query_type")
113                        .and_then(|v| v.as_str())
114                        .ok_or_else(|| YamlWorkflowError::MissingParameter("query_type".to_string()))?;
115
116                    let target = yaml_task.params.params.get("target")
117                        .and_then(|v| v.as_str())
118                        .ok_or_else(|| YamlWorkflowError::MissingParameter("target".to_string()))?;
119
120                    // Convert query_type string to enum
121                    let query_type = match query_type_str {
122                        "find_symbol" => GraphQueryType::FindSymbol,
123                        "references" => GraphQueryType::References,
124                        "impact" | "impact_analysis" => GraphQueryType::ImpactAnalysis,
125                        _ => return Err(YamlWorkflowError::InvalidSchema(format!("Unknown query_type: {}", query_type_str))),
126                    };
127
128                    let task = GraphQueryTask::with_id(
129                        task_id.clone(),
130                        query_type,
131                        target,
132                    );
133
134                    builder = builder.add_task(Box::new(task));
135                }
136                YamlTaskType::AgentLoop => {
137                    // Extract query parameter
138                    let query = yaml_task.params.params.get("query")
139                        .and_then(|v| v.as_str())
140                        .ok_or_else(|| YamlWorkflowError::MissingParameter("query".to_string()))?;
141
142                    let task = AgentLoopTask::new(
143                        task_id.clone(),
144                        yaml_task.name.clone(),
145                        query,
146                    );
147
148                    builder = builder.add_task(Box::new(task));
149                }
150                YamlTaskType::Shell => {
151                    // Extract command parameter
152                    let command = yaml_task.params.params.get("command")
153                        .and_then(|v| v.as_str())
154                        .ok_or_else(|| YamlWorkflowError::MissingParameter("command".to_string()))?;
155
156                    // Extract args (optional)
157                    let args: Vec<String> = yaml_task.params.params.get("args")
158                        .and_then(|v| v.as_array())
159                        .map(|arr| {
160                            arr.iter()
161                                .filter_map(|v| v.as_str())
162                                .map(|s| s.to_string())
163                                .collect()
164                        })
165                        .unwrap_or_default();
166
167                    let task = ShellCommandTask::new(
168                        task_id.clone(),
169                        yaml_task.name.clone(),
170                        command,
171                    ).with_args(args);
172
173                    builder = builder.add_task(Box::new(task));
174                }
175            }
176        }
177
178        // Add dependencies after all tasks are added
179        for yaml_task in &yaml_workflow.tasks {
180            let task_id = TaskId::new(yaml_task.id.clone());
181            for dep_id in &yaml_task.depends_on {
182                builder = builder.dependency(
183                    TaskId::new(dep_id.clone()),
184                    task_id.clone(),
185                );
186            }
187        }
188
189        // Build and validate workflow
190        let workflow = builder.build()
191            .map_err(|e| match e {
192                WorkflowError::EmptyWorkflow => YamlWorkflowError::InvalidSchema("Workflow has no tasks".to_string()),
193                WorkflowError::CycleDetected(msg) => YamlWorkflowError::ConversionError(format!("Cycle detected: {:?}", msg)),
194                WorkflowError::TaskNotFound(id) => YamlWorkflowError::ConversionError(format!("Task not found: {}", id)),
195                WorkflowError::MissingDependency(id) => YamlWorkflowError::ConversionError(format!("Missing dependency: {}", id)),
196                WorkflowError::CheckpointCorrupted(msg) => YamlWorkflowError::ConversionError(format!("Checkpoint corrupted: {}", msg)),
197                WorkflowError::CheckpointNotFound(msg) => YamlWorkflowError::ConversionError(format!("Checkpoint not found: {}", msg)),
198                WorkflowError::WorkflowChanged(msg) => YamlWorkflowError::ConversionError(format!("Workflow changed: {}", msg)),
199                WorkflowError::Timeout(err) => YamlWorkflowError::ConversionError(format!("Timeout: {}", err)),
200                WorkflowError::TaskFailed(msg) => YamlWorkflowError::ConversionError(format!("Task failed: {}", msg)),
201            })?;
202
203        Ok(workflow)
204    }
205}
206
207/// Loads a workflow from a YAML file.
208///
209/// # Arguments
210///
211/// * `path` - Path to the YAML file
212///
213/// # Returns
214///
215/// - `Ok(Workflow)` - If workflow loaded and converted successfully
216/// - `Err(YamlWorkflowError)` - If file cannot be read or YAML is invalid
217///
218/// # Example
219///
220/// ```ignore
221/// use forge_agent::workflow::yaml::load_workflow_from_file;
222///
223/// let workflow = load_workflow_from_file(Path::new("workflow.yaml")).await?;
224/// ```
225pub async fn load_workflow_from_file(path: &Path) -> Result<Workflow, YamlWorkflowError> {
226    let content = tokio::fs::read_to_string(path).await?;
227    Ok(load_workflow_from_string(&content)?)
228}
229
230/// Loads a workflow from a YAML string.
231///
232/// # Arguments
233///
234/// * `yaml` - YAML string containing workflow definition
235///
236/// # Returns
237///
238/// - `Ok(Workflow)` - If workflow parsed and converted successfully
239/// - `Err(YamlWorkflowError)` - If YAML is invalid
240///
241/// # Example
242///
243/// ```ignore
244/// use forge_agent::workflow::yaml::load_workflow_from_string;
245///
246/// let yaml = r#"
247/// name: "My Workflow"
248/// tasks:
249///   - id: "task1"
250///     name: "Task 1"
251///     type: GRAPH_QUERY
252///     params:
253///       query_type: "find_symbol"
254///       target: "my_function"
255/// "#;
256///
257/// let workflow = load_workflow_from_string(yaml)?;
258/// ```
259pub fn load_workflow_from_string(yaml: &str) -> Result<Workflow, YamlWorkflowError> {
260    let yaml_workflow: YamlWorkflow = serde_yaml::from_str(yaml)?;
261    yaml_workflow.try_into()
262}
263
/// Unit tests for YAML parsing (`YamlWorkflow` deserialization), conversion into
/// `Workflow`, and the `load_workflow_from_*` entry points.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Deserialization only (YAML -> YamlWorkflow), no conversion ----

    #[test]
    fn test_yaml_parse_basic() {
        let yaml = r#"
name: "Test Workflow"
tasks:
  - id: "task1"
    name: "First Task"
    type: GRAPH_QUERY
    params:
      query_type: "find_symbol"
      target: "my_function"
"#;

        let workflow: YamlWorkflow = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(workflow.name, "Test Workflow");
        assert_eq!(workflow.tasks.len(), 1);
        assert_eq!(workflow.tasks[0].id, "task1");
        assert_eq!(workflow.tasks[0].task_type, YamlTaskType::GraphQuery);
    }

    #[test]
    fn test_yaml_parse_with_dependencies() {
        let yaml = r#"
name: "Dependent Workflow"
tasks:
  - id: "find"
    name: "Find Symbol"
    type: GRAPH_QUERY
    params:
      query_type: "find_symbol"
      target: "process_data"
  - id: "analyze"
    name: "Analyze Impact"
    type: GRAPH_QUERY
    depends_on: ["find"]
    params:
      query_type: "impact"
      target: "process_data"
"#;

        let workflow: YamlWorkflow = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(workflow.tasks.len(), 2);
        assert_eq!(workflow.tasks[1].depends_on, vec!["find"]);
    }

    #[test]
    fn test_yaml_parse_with_optional_fields() {
        let yaml = r#"
name: "Simple Workflow"
version: "1.0"
description: "A test workflow"
tasks:
  - id: "task1"
    name: "Task 1"
    type: AGENT_LOOP
    params:
      query: "Find all functions"
"#;

        let workflow: YamlWorkflow = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(workflow.version, Some("1.0".to_string()));
        assert_eq!(workflow.description, Some("A test workflow".to_string()));
    }

    // An omitted `depends_on` key must deserialize to an empty list.
    #[test]
    fn test_yaml_parse_empty_depends_on() {
        let yaml = r#"
name: "Simple Workflow"
tasks:
  - id: "task1"
    name: "Task 1"
    type: GRAPH_QUERY
    params:
      query_type: "find_symbol"
      target: "test"
"#;

        let workflow: YamlWorkflow = serde_yaml::from_str(yaml).unwrap();
        assert!(workflow.tasks[0].depends_on.is_empty());
    }

    #[test]
    fn test_yaml_parse_invalid_schema() {
        // Missing required 'name' field
        let yaml = r#"
tasks:
  - id: "task1"
    type: GRAPH_QUERY
"#;

        let result: Result<YamlWorkflow, _> = serde_yaml::from_str(yaml);
        assert!(result.is_err());
    }

    // ---- Conversion (YamlWorkflow -> Workflow) ----

    #[test]
    fn test_graph_query_conversion() {
        let yaml = r#"
name: "Graph Query Test"
tasks:
  - id: "find"
    name: "Find Symbol"
    type: GRAPH_QUERY
    params:
      query_type: "find_symbol"
      target: "process_data"
"#;

        let yaml_workflow: YamlWorkflow = serde_yaml::from_str(yaml).unwrap();
        let workflow: Result<Workflow, _> = yaml_workflow.try_into();

        assert!(workflow.is_ok());
        let workflow = workflow.unwrap();
        assert_eq!(workflow.task_count(), 1);
    }

    #[test]
    fn test_yaml_to_workflow() {
        let yaml = r#"
name: "Test Workflow"
tasks:
  - id: "find"
    name: "Find Symbol"
    type: GRAPH_QUERY
    params:
      query_type: "find_symbol"
      target: "my_function"
  - id: "analyze"
    name: "Analyze Impact"
    type: GRAPH_QUERY
    depends_on: ["find"]
    params:
      query_type: "impact"
      target: "my_function"
"#;

        let yaml_workflow: YamlWorkflow = serde_yaml::from_str(yaml).unwrap();
        let workflow: Result<Workflow, _> = yaml_workflow.try_into();

        assert!(workflow.is_ok());
        let workflow = workflow.unwrap();
        assert_eq!(workflow.task_count(), 2);

        // Verify execution order respects dependencies
        let execution_order = workflow.execution_order().unwrap();
        assert_eq!(execution_order[0], TaskId::new("find"));
        assert_eq!(execution_order[1], TaskId::new("analyze"));
    }

    // Deserialization succeeds (params is an untyped map); the missing key is
    // only detected during conversion.
    #[test]
    fn test_missing_parameter_error() {
        let yaml = r#"
name: "Missing Parameter Test"
tasks:
  - id: "task1"
    name: "Task 1"
    type: GRAPH_QUERY
    params:
      query_type: "find_symbol"
      # Missing 'target' parameter
"#;

        let yaml_workflow: YamlWorkflow = serde_yaml::from_str(yaml).unwrap();
        let result: Result<Workflow, _> = yaml_workflow.try_into();

        assert!(result.is_err());
        assert!(matches!(result, Err(YamlWorkflowError::MissingParameter(_))));
    }

    #[test]
    fn test_agent_loop_conversion() {
        let yaml = r#"
name: "Agent Loop Test"
tasks:
  - id: "observe"
    name: "Gather Context"
    type: AGENT_LOOP
    params:
      query: "Find all functions that call process_data"
"#;

        let yaml_workflow: YamlWorkflow = serde_yaml::from_str(yaml).unwrap();
        let workflow: Result<Workflow, _> = yaml_workflow.try_into();

        assert!(workflow.is_ok());
        let workflow = workflow.unwrap();
        assert_eq!(workflow.task_count(), 1);
    }

    #[test]
    fn test_agent_loop_missing_query() {
        let yaml = r#"
name: "Agent Loop Missing Query"
tasks:
  - id: "task1"
    name: "Task 1"
    type: AGENT_LOOP
    params:
      # Missing 'query' parameter
      other: "value"
"#;

        let yaml_workflow: YamlWorkflow = serde_yaml::from_str(yaml).unwrap();
        let result: Result<Workflow, _> = yaml_workflow.try_into();

        assert!(result.is_err());
        assert!(matches!(result, Err(YamlWorkflowError::MissingParameter(_))));
    }

    #[test]
    fn test_shell_task_stub() {
        let yaml = r#"
name: "Shell Task Test"
tasks:
  - id: "run"
    name: "Run Command"
    type: SHELL
    params:
      command: "echo"
      args: ["hello", "world"]
"#;

        let yaml_workflow: YamlWorkflow = serde_yaml::from_str(yaml).unwrap();
        let workflow: Result<Workflow, _> = yaml_workflow.try_into();

        assert!(workflow.is_ok());
        let workflow = workflow.unwrap();
        assert_eq!(workflow.task_count(), 1);
    }

    // `args` is optional for SHELL tasks.
    #[test]
    fn test_shell_task_args_default() {
        let yaml = r#"
name: "Shell Task No Args"
tasks:
  - id: "run"
    name: "Run Command"
    type: SHELL
    params:
      command: "ls"
"#;

        let yaml_workflow: YamlWorkflow = serde_yaml::from_str(yaml).unwrap();
        let workflow: Result<Workflow, _> = yaml_workflow.try_into();

        assert!(workflow.is_ok());
        let workflow = workflow.unwrap();
        assert_eq!(workflow.task_count(), 1);
    }

    // ---- Public loaders and round-tripping ----
    // NOTE(review): only `test_load_from_file` awaits anything; the other
    // `#[tokio::test]` functions below could be plain `#[test]`s.

    #[tokio::test]
    async fn test_load_from_string() {
        let yaml = r#"
name: "Test Workflow"
tasks:
  - id: "task1"
    name: "Task 1"
    type: GRAPH_QUERY
    params:
      query_type: "find_symbol"
      target: "test_function"
"#;

        let workflow = load_workflow_from_string(yaml).unwrap();
        assert_eq!(workflow.task_count(), 1);
    }

    #[tokio::test]
    async fn test_load_from_file() {
        use tempfile::NamedTempFile;
        use std::io::Write;

        let yaml = r#"
name: "File Workflow"
tasks:
  - id: "task1"
    name: "Task 1"
    type: GRAPH_QUERY
    params:
      query_type: "find_symbol"
      target: "my_function"
"#;

        let mut temp_file = NamedTempFile::new().unwrap();
        write!(temp_file, "{}", yaml).unwrap();

        let workflow = load_workflow_from_file(temp_file.path()).await.unwrap();
        assert_eq!(workflow.task_count(), 1);
    }

    // Serialize -> deserialize must preserve the document; only `name` and the
    // task count are compared.
    #[tokio::test]
    async fn test_yaml_round_trip() {
        let yaml = r#"
name: "Round Trip Test"
version: "1.0"
description: "Test serialization round trip"
tasks:
  - id: "task1"
    name: "Task 1"
    type: GRAPH_QUERY
    params:
      query_type: "find_symbol"
      target: "test"
"#;

        // Parse YAML to YamlWorkflow
        let yaml_workflow: YamlWorkflow = serde_yaml::from_str(yaml).unwrap();

        // Serialize back to YAML
        let yaml_out = serde_yaml::to_string(&yaml_workflow).unwrap();

        // Parse again
        let yaml_workflow2: YamlWorkflow = serde_yaml::from_str(&yaml_out).unwrap();

        // Should be identical
        assert_eq!(yaml_workflow.name, yaml_workflow2.name);
        assert_eq!(yaml_workflow.tasks.len(), yaml_workflow2.tasks.len());
    }

    // The example fixtures below are embedded at compile time with
    // `include_str!`, resolved relative to this source file.

    #[tokio::test]
    async fn test_load_simple_graph_query_example() {
        let yaml = include_str!("examples/simple_graph_query.yaml");

        let workflow = load_workflow_from_string(yaml).unwrap();
        assert_eq!(workflow.task_count(), 2);

        let execution_order = workflow.execution_order().unwrap();
        assert_eq!(execution_order[0], TaskId::new("find"));
        assert_eq!(execution_order[1], TaskId::new("analyze"));
    }

    #[tokio::test]
    async fn test_load_agent_assisted_example() {
        let yaml = include_str!("examples/agent_assisted.yaml");

        let workflow = load_workflow_from_string(yaml).unwrap();
        assert_eq!(workflow.task_count(), 2);

        let execution_order = workflow.execution_order().unwrap();
        assert_eq!(execution_order[0], TaskId::new("observe"));
        assert_eq!(execution_order[1], TaskId::new("plan"));
    }

    #[tokio::test]
    async fn test_load_complex_dependencies_example() {
        let yaml = include_str!("examples/complex_dependencies.yaml");

        let workflow = load_workflow_from_string(yaml).unwrap();
        assert_eq!(workflow.task_count(), 4);

        let execution_order = workflow.execution_order().unwrap();
        // Verify topological sort respects diamond pattern
        assert_eq!(execution_order[0], TaskId::new("a"));
        // B and C can be in either order after A
        let b_index = execution_order.iter().position(|id| id == &TaskId::new("b")).unwrap();
        let c_index = execution_order.iter().position(|id| id == &TaskId::new("c")).unwrap();
        assert!(b_index > 0 && c_index > 0);
        // D must be after both B and C
        let d_index = execution_order.iter().position(|id| id == &TaskId::new("d")).unwrap();
        assert!(d_index > b_index && d_index > c_index);
    }
}