1use crate::workflow::{
7 task::TaskId,
8 tasks::{GraphQueryTask, GraphQueryType, AgentLoopTask, ShellCommandTask},
9 dag::{Workflow, WorkflowError},
10 builder::WorkflowBuilder,
11};
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use std::collections::HashMap;
15use std::path::Path;
16use thiserror::Error;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct YamlWorkflow {
21 pub name: String,
23 #[serde(default)]
25 pub version: Option<String>,
26 #[serde(default)]
28 pub description: Option<String>,
29 pub tasks: Vec<YamlTask>,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct YamlTask {
36 pub id: String,
38 pub name: String,
40 #[serde(rename = "type")]
42 pub task_type: YamlTaskType,
43 #[serde(default)]
45 pub depends_on: Vec<String>,
46 #[serde(default)]
48 pub params: YamlTaskParams,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
53#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
54pub enum YamlTaskType {
55 GraphQuery,
57 AgentLoop,
59 Shell,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize, Default)]
65pub struct YamlTaskParams {
66 #[serde(flatten)]
68 pub params: HashMap<String, Value>,
69}
70
71#[derive(Error, Debug)]
73pub enum YamlWorkflowError {
74 #[error("Invalid workflow schema: {0}")]
76 InvalidSchema(String),
77
78 #[error("Invalid task type: {0}")]
80 InvalidTaskType(String),
81
82 #[error("Missing required parameter: {0}")]
84 MissingParameter(String),
85
86 #[error("Conversion error: {0}")]
88 ConversionError(String),
89
90 #[error("I/O error: {0}")]
92 Io(#[from] std::io::Error),
93
94 #[error("YAML parsing error: {0}")]
96 YamlParse(#[from] serde_yaml::Error),
97}
98
99impl TryFrom<YamlWorkflow> for Workflow {
100 type Error = YamlWorkflowError;
101
102 fn try_from(yaml_workflow: YamlWorkflow) -> Result<Self, Self::Error> {
103 let mut builder = WorkflowBuilder::new();
104
105 for yaml_task in &yaml_workflow.tasks {
107 let task_id = TaskId::new(yaml_task.id.clone());
108
109 match yaml_task.task_type {
110 YamlTaskType::GraphQuery => {
111 let query_type_str = yaml_task.params.params.get("query_type")
113 .and_then(|v| v.as_str())
114 .ok_or_else(|| YamlWorkflowError::MissingParameter("query_type".to_string()))?;
115
116 let target = yaml_task.params.params.get("target")
117 .and_then(|v| v.as_str())
118 .ok_or_else(|| YamlWorkflowError::MissingParameter("target".to_string()))?;
119
120 let query_type = match query_type_str {
122 "find_symbol" => GraphQueryType::FindSymbol,
123 "references" => GraphQueryType::References,
124 "impact" | "impact_analysis" => GraphQueryType::ImpactAnalysis,
125 _ => return Err(YamlWorkflowError::InvalidSchema(format!("Unknown query_type: {}", query_type_str))),
126 };
127
128 let task = GraphQueryTask::with_id(
129 task_id.clone(),
130 query_type,
131 target,
132 );
133
134 builder = builder.add_task(Box::new(task));
135 }
136 YamlTaskType::AgentLoop => {
137 let query = yaml_task.params.params.get("query")
139 .and_then(|v| v.as_str())
140 .ok_or_else(|| YamlWorkflowError::MissingParameter("query".to_string()))?;
141
142 let task = AgentLoopTask::new(
143 task_id.clone(),
144 yaml_task.name.clone(),
145 query,
146 );
147
148 builder = builder.add_task(Box::new(task));
149 }
150 YamlTaskType::Shell => {
151 let command = yaml_task.params.params.get("command")
153 .and_then(|v| v.as_str())
154 .ok_or_else(|| YamlWorkflowError::MissingParameter("command".to_string()))?;
155
156 let args: Vec<String> = yaml_task.params.params.get("args")
158 .and_then(|v| v.as_array())
159 .map(|arr| {
160 arr.iter()
161 .filter_map(|v| v.as_str())
162 .map(|s| s.to_string())
163 .collect()
164 })
165 .unwrap_or_default();
166
167 let task = ShellCommandTask::new(
168 task_id.clone(),
169 yaml_task.name.clone(),
170 command,
171 ).with_args(args);
172
173 builder = builder.add_task(Box::new(task));
174 }
175 }
176 }
177
178 for yaml_task in &yaml_workflow.tasks {
180 let task_id = TaskId::new(yaml_task.id.clone());
181 for dep_id in &yaml_task.depends_on {
182 builder = builder.dependency(
183 TaskId::new(dep_id.clone()),
184 task_id.clone(),
185 );
186 }
187 }
188
189 let workflow = builder.build()
191 .map_err(|e| match e {
192 WorkflowError::EmptyWorkflow => YamlWorkflowError::InvalidSchema("Workflow has no tasks".to_string()),
193 WorkflowError::CycleDetected(msg) => YamlWorkflowError::ConversionError(format!("Cycle detected: {:?}", msg)),
194 WorkflowError::TaskNotFound(id) => YamlWorkflowError::ConversionError(format!("Task not found: {}", id)),
195 WorkflowError::MissingDependency(id) => YamlWorkflowError::ConversionError(format!("Missing dependency: {}", id)),
196 WorkflowError::CheckpointCorrupted(msg) => YamlWorkflowError::ConversionError(format!("Checkpoint corrupted: {}", msg)),
197 WorkflowError::CheckpointNotFound(msg) => YamlWorkflowError::ConversionError(format!("Checkpoint not found: {}", msg)),
198 WorkflowError::WorkflowChanged(msg) => YamlWorkflowError::ConversionError(format!("Workflow changed: {}", msg)),
199 WorkflowError::Timeout(err) => YamlWorkflowError::ConversionError(format!("Timeout: {}", err)),
200 WorkflowError::TaskFailed(msg) => YamlWorkflowError::ConversionError(format!("Task failed: {}", msg)),
201 })?;
202
203 Ok(workflow)
204 }
205}
206
207pub async fn load_workflow_from_file(path: &Path) -> Result<Workflow, YamlWorkflowError> {
226 let content = tokio::fs::read_to_string(path).await?;
227 Ok(load_workflow_from_string(&content)?)
228}
229
230pub fn load_workflow_from_string(yaml: &str) -> Result<Workflow, YamlWorkflowError> {
260 let yaml_workflow: YamlWorkflow = serde_yaml::from_str(yaml)?;
261 yaml_workflow.try_into()
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `yaml` into a [`YamlWorkflow`] (panicking on malformed YAML) and
    /// returns the result of the `TryFrom` conversion for inspection.
    fn convert(yaml: &str) -> Result<Workflow, YamlWorkflowError> {
        let yaml_workflow: YamlWorkflow = serde_yaml::from_str(yaml).unwrap();
        yaml_workflow.try_into()
    }

    #[test]
    fn test_yaml_parse_basic() {
        let yaml = r#"
name: "Test Workflow"
tasks:
  - id: "task1"
    name: "First Task"
    type: GRAPH_QUERY
    params:
      query_type: "find_symbol"
      target: "my_function"
"#;

        let workflow: YamlWorkflow = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(workflow.name, "Test Workflow");
        assert_eq!(workflow.version, None);
        assert_eq!(workflow.description, None);
        assert_eq!(workflow.tasks.len(), 1);
        assert_eq!(workflow.tasks[0].id, "task1");
        assert_eq!(workflow.tasks[0].task_type, YamlTaskType::GraphQuery);
    }

    #[test]
    fn test_yaml_parse_with_dependencies() {
        let yaml = r#"
name: "Dependent Workflow"
tasks:
  - id: "find"
    name: "Find Symbol"
    type: GRAPH_QUERY
    params:
      query_type: "find_symbol"
      target: "process_data"
  - id: "analyze"
    name: "Analyze Impact"
    type: GRAPH_QUERY
    depends_on: ["find"]
    params:
      query_type: "impact"
      target: "process_data"
"#;

        let workflow: YamlWorkflow = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(workflow.tasks.len(), 2);
        assert_eq!(workflow.tasks[1].depends_on, vec!["find"]);
    }

    #[test]
    fn test_yaml_parse_with_optional_fields() {
        let yaml = r#"
name: "Simple Workflow"
version: "1.0"
description: "A test workflow"
tasks:
  - id: "task1"
    name: "Task 1"
    type: AGENT_LOOP
    params:
      query: "Find all functions"
"#;

        let workflow: YamlWorkflow = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(workflow.version, Some("1.0".to_string()));
        assert_eq!(workflow.description, Some("A test workflow".to_string()));
    }

    #[test]
    fn test_yaml_parse_empty_depends_on() {
        let yaml = r#"
name: "Simple Workflow"
tasks:
  - id: "task1"
    name: "Task 1"
    type: GRAPH_QUERY
    params:
      query_type: "find_symbol"
      target: "test"
"#;

        let workflow: YamlWorkflow = serde_yaml::from_str(yaml).unwrap();
        assert!(workflow.tasks[0].depends_on.is_empty());
    }

    #[test]
    fn test_yaml_parse_invalid_schema() {
        // Both the workflow-level and the task-level `name` fields are missing.
        let yaml = r#"
tasks:
  - id: "task1"
    type: GRAPH_QUERY
"#;

        let result: Result<YamlWorkflow, _> = serde_yaml::from_str(yaml);
        assert!(result.is_err());
    }

    #[test]
    fn test_graph_query_conversion() {
        let yaml = r#"
name: "Graph Query Test"
tasks:
  - id: "find"
    name: "Find Symbol"
    type: GRAPH_QUERY
    params:
      query_type: "find_symbol"
      target: "process_data"
"#;

        let workflow = convert(yaml).unwrap();
        assert_eq!(workflow.task_count(), 1);
    }

    #[test]
    fn test_query_type_aliases() {
        let yaml = r#"
name: "Query Type Aliases"
tasks:
  - id: "refs"
    name: "References"
    type: GRAPH_QUERY
    params:
      query_type: "references"
      target: "my_function"
  - id: "impact"
    name: "Impact"
    type: GRAPH_QUERY
    params:
      query_type: "impact_analysis"
      target: "my_function"
"#;

        let workflow = convert(yaml).unwrap();
        assert_eq!(workflow.task_count(), 2);
    }

    #[test]
    fn test_unknown_query_type_is_rejected() {
        let yaml = r#"
name: "Unknown Query Type"
tasks:
  - id: "task1"
    name: "Task 1"
    type: GRAPH_QUERY
    params:
      query_type: "not_a_query"
      target: "my_function"
"#;

        assert!(matches!(convert(yaml), Err(YamlWorkflowError::InvalidSchema(_))));
    }

    #[test]
    fn test_yaml_to_workflow() {
        let yaml = r#"
name: "Test Workflow"
tasks:
  - id: "find"
    name: "Find Symbol"
    type: GRAPH_QUERY
    params:
      query_type: "find_symbol"
      target: "my_function"
  - id: "analyze"
    name: "Analyze Impact"
    type: GRAPH_QUERY
    depends_on: ["find"]
    params:
      query_type: "impact"
      target: "my_function"
"#;

        let workflow = convert(yaml).unwrap();
        assert_eq!(workflow.task_count(), 2);

        // `analyze` depends on `find`, so `find` must be scheduled first.
        let execution_order = workflow.execution_order().unwrap();
        assert_eq!(execution_order[0], TaskId::new("find"));
        assert_eq!(execution_order[1], TaskId::new("analyze"));
    }

    #[test]
    fn test_missing_parameter_error() {
        let yaml = r#"
name: "Missing Parameter Test"
tasks:
  - id: "task1"
    name: "Task 1"
    type: GRAPH_QUERY
    params:
      query_type: "find_symbol"
      # Missing 'target' parameter
"#;

        assert!(matches!(
            convert(yaml),
            Err(YamlWorkflowError::MissingParameter(ref p)) if p == "target"
        ));
    }

    #[test]
    fn test_agent_loop_conversion() {
        let yaml = r#"
name: "Agent Loop Test"
tasks:
  - id: "observe"
    name: "Gather Context"
    type: AGENT_LOOP
    params:
      query: "Find all functions that call process_data"
"#;

        let workflow = convert(yaml).unwrap();
        assert_eq!(workflow.task_count(), 1);
    }

    #[test]
    fn test_agent_loop_missing_query() {
        let yaml = r#"
name: "Agent Loop Missing Query"
tasks:
  - id: "task1"
    name: "Task 1"
    type: AGENT_LOOP
    params:
      # Missing 'query' parameter
      other: "value"
"#;

        assert!(matches!(
            convert(yaml),
            Err(YamlWorkflowError::MissingParameter(ref p)) if p == "query"
        ));
    }

    #[test]
    fn test_shell_task_stub() {
        let yaml = r#"
name: "Shell Task Test"
tasks:
  - id: "run"
    name: "Run Command"
    type: SHELL
    params:
      command: "echo"
      args: ["hello", "world"]
"#;

        let workflow = convert(yaml).unwrap();
        assert_eq!(workflow.task_count(), 1);
    }

    #[test]
    fn test_shell_task_args_default() {
        let yaml = r#"
name: "Shell Task No Args"
tasks:
  - id: "run"
    name: "Run Command"
    type: SHELL
    params:
      command: "ls"
"#;

        let workflow = convert(yaml).unwrap();
        assert_eq!(workflow.task_count(), 1);
    }

    #[test]
    fn test_shell_missing_command() {
        let yaml = r#"
name: "Shell Missing Command"
tasks:
  - id: "run"
    name: "Run Command"
    type: SHELL
    params:
      args: ["hello"]
"#;

        assert!(matches!(
            convert(yaml),
            Err(YamlWorkflowError::MissingParameter(ref p)) if p == "command"
        ));
    }

    // Synchronous API: no async runtime needed.
    #[test]
    fn test_load_from_string() {
        let yaml = r#"
name: "Test Workflow"
tasks:
  - id: "task1"
    name: "Task 1"
    type: GRAPH_QUERY
    params:
      query_type: "find_symbol"
      target: "test_function"
"#;

        let workflow = load_workflow_from_string(yaml).unwrap();
        assert_eq!(workflow.task_count(), 1);
    }

    #[tokio::test]
    async fn test_load_from_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let yaml = r#"
name: "File Workflow"
tasks:
  - id: "task1"
    name: "Task 1"
    type: GRAPH_QUERY
    params:
      query_type: "find_symbol"
      target: "my_function"
"#;

        let mut temp_file = NamedTempFile::new().unwrap();
        write!(temp_file, "{}", yaml).unwrap();

        let workflow = load_workflow_from_file(temp_file.path()).await.unwrap();
        assert_eq!(workflow.task_count(), 1);
    }

    #[test]
    fn test_yaml_round_trip() {
        let yaml = r#"
name: "Round Trip Test"
version: "1.0"
description: "Test serialization round trip"
tasks:
  - id: "task1"
    name: "Task 1"
    type: GRAPH_QUERY
    params:
      query_type: "find_symbol"
      target: "test"
"#;

        // Deserialize, serialize back to YAML, then deserialize again.
        let original: YamlWorkflow = serde_yaml::from_str(yaml).unwrap();
        let yaml_out = serde_yaml::to_string(&original).unwrap();
        let reparsed: YamlWorkflow = serde_yaml::from_str(&yaml_out).unwrap();

        assert_eq!(original.name, reparsed.name);
        assert_eq!(original.version, reparsed.version);
        assert_eq!(original.description, reparsed.description);
        assert_eq!(original.tasks.len(), reparsed.tasks.len());

        let (before, after) = (&original.tasks[0], &reparsed.tasks[0]);
        assert_eq!(before.id, after.id);
        assert_eq!(before.name, after.name);
        assert_eq!(before.task_type, after.task_type);
        assert_eq!(before.depends_on, after.depends_on);
        assert_eq!(before.params.params, after.params.params);
    }

    #[test]
    fn test_load_simple_graph_query_example() {
        let yaml = include_str!("examples/simple_graph_query.yaml");

        let workflow = load_workflow_from_string(yaml).unwrap();
        assert_eq!(workflow.task_count(), 2);

        let execution_order = workflow.execution_order().unwrap();
        assert_eq!(execution_order[0], TaskId::new("find"));
        assert_eq!(execution_order[1], TaskId::new("analyze"));
    }

    #[test]
    fn test_load_agent_assisted_example() {
        let yaml = include_str!("examples/agent_assisted.yaml");

        let workflow = load_workflow_from_string(yaml).unwrap();
        assert_eq!(workflow.task_count(), 2);

        let execution_order = workflow.execution_order().unwrap();
        assert_eq!(execution_order[0], TaskId::new("observe"));
        assert_eq!(execution_order[1], TaskId::new("plan"));
    }

    #[test]
    fn test_load_complex_dependencies_example() {
        let yaml = include_str!("examples/complex_dependencies.yaml");

        let workflow = load_workflow_from_string(yaml).unwrap();
        assert_eq!(workflow.task_count(), 4);

        let execution_order = workflow.execution_order().unwrap();
        assert_eq!(execution_order[0], TaskId::new("a"));

        // `b` and `c` may run in either order, but both come after `a`...
        let position = |name: &str| {
            execution_order
                .iter()
                .position(|id| id == &TaskId::new(name))
                .unwrap()
        };
        let b_index = position("b");
        let c_index = position("c");
        assert!(b_index > 0 && c_index > 0);

        // ...and `d` comes after both of them.
        let d_index = position("d");
        assert!(d_index > b_index && d_index > c_index);
    }
}