1pub mod contract;
13pub mod health;
14pub mod progress;
15pub mod role;
16pub mod story_loop;
17
18use crate::graph::schema::{GraphDefinition, NodeDefinition};
19use anyhow::{anyhow, Context, Result};
20use contract::{FailureAction, StepContract};
21use health::{HealthCheckConfig, RunHealthWatchdog};
22use progress::ProgressJournalWriter;
23use role::{RoleDefinition, RoleRegistry};
24use serde::{Deserialize, Serialize};
25use std::collections::HashMap;
26use std::path::Path;
27
/// A workflow as declared in YAML (see [`WorkflowLoader`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowDefinition {
    /// Workflow name; copied into the compiled graph and the progress journal writer.
    pub name: String,
    /// Workflow version; `"1.0.0"` when omitted.
    #[serde(default = "default_version")]
    pub version: String,
    /// Optional human-readable description; copied into the compiled graph.
    pub description: Option<String>,
    /// Optional workflow-level model; copied into the compiled graph.
    #[serde(default)]
    pub model: Option<String>,
    /// Trigger names; copied into the compiled graph. Empty when omitted.
    #[serde(default)]
    pub triggers: Vec<String>,
    /// Declared input parameters keyed by name. Only their descriptions are
    /// carried into the compiled graph's `inputs` map.
    #[serde(default)]
    pub inputs: HashMap<String, InputParameter>,
    /// Role definitions. Every step's `role` must match one of these IDs
    /// (checked by `WorkflowValidator::validate`).
    #[serde(default)]
    pub roles: Vec<RoleDefinition>,
    /// Steps in execution order; the compiler chains them in this order.
    pub steps: Vec<WorkflowStep>,
    /// Progress journal settings; `ProgressJournalConfig::default()` when omitted.
    #[serde(default)]
    pub progress_journal: ProgressJournalConfig,
    /// Health-check settings; `HealthCheckConfig::default()` when omitted.
    #[serde(default)]
    pub health_checks: HealthCheckConfig,
}
59
/// Serde default for [`WorkflowDefinition::version`]: the version assumed when YAML omits it.
fn default_version() -> String {
    String::from("1.0.0")
}
63
/// Declaration of one workflow input parameter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InputParameter {
    /// Declared type name, written as `type` in YAML (e.g. `string`).
    #[serde(rename = "type")]
    pub param_type: String,
    /// Optional description; carried into the compiled graph's `inputs` map.
    pub description: Option<String>,
    /// Whether the input is required; `false` when omitted.
    /// NOTE(review): not enforced anywhere in this file — confirm where it is checked.
    #[serde(default)]
    pub required: bool,
    /// Optional default value.
    /// NOTE(review): `WorkflowContext::new` does not apply it — confirm whether callers do.
    #[serde(default)]
    pub default: Option<serde_json::Value>,
}
79
/// One step of a workflow, as declared in YAML.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowStep {
    /// Step identifier. Used as the compiled graph node key and checked for
    /// duplicates by `WorkflowValidator::validate`.
    pub id: String,
    /// Display name; interpolated into the generated system prompt
    /// ("You are the {name} agent.") for LLM-backed steps.
    pub name: String,
    /// ID of the role that runs this step. Must match an entry in
    /// `WorkflowDefinition::roles` (checked by `WorkflowValidator::validate`).
    pub role: String,
    /// Optional description; appended to the generated system prompt.
    pub description: Option<String>,
    /// Step kind; `StepType::Standard` when omitted.
    #[serde(default = "default_step_type")]
    pub step_type: StepType,
    /// Optional story-loop settings.
    /// NOTE(review): not read by `WorkflowCompiler::compile_step` in this file — confirm
    /// where it is consumed.
    #[serde(default)]
    pub loop_config: Option<story_loop::StoryLoopConfig>,
    /// Input template for the step. Presumably expanded with
    /// `WorkflowContext::substitute_variables` by the runner — TODO confirm.
    pub input: String,
    /// Optional model; becomes the model of the compiled LLM node
    /// (ignored for `StoryLoop` steps).
    #[serde(default)]
    pub model: Option<String>,
    /// Optional output contract; its retry target is checked by the validator.
    #[serde(default)]
    pub contract: Option<StepContract>,
}
106
107fn default_step_type() -> StepType {
108 StepType::Standard
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize)]
113#[serde(rename_all = "snake_case")]
114pub enum StepType {
115 Standard,
117 StoryLoop,
119 Verification,
121 Test,
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct ProgressJournalConfig {
128 #[serde(default = "default_true")]
130 pub enabled: bool,
131 #[serde(default)]
133 pub template: Option<String>,
134}
135
136fn default_true() -> bool {
137 true
138}
139
140impl Default for ProgressJournalConfig {
141 fn default() -> Self {
142 Self {
143 enabled: true,
144 template: None,
145 }
146 }
147}
148
149pub struct WorkflowLoader;
151
152impl WorkflowLoader {
153 pub async fn load_from_file(path: &Path) -> Result<WorkflowDefinition> {
155 let content = tokio::fs::read_to_string(path)
156 .await
157 .context("Failed to read workflow file")?;
158
159 Self::load_from_str(&content)
160 }
161
162 pub fn load_from_str(yaml: &str) -> Result<WorkflowDefinition> {
164 serde_yaml::from_str(yaml).context("Failed to parse workflow YAML")
165 }
166
167 pub async fn load_from_directory(dir: &Path) -> Result<WorkflowDefinition> {
169 let workflow_file = dir.join("workflow.yml");
170
171 if !workflow_file.exists() {
172 return Err(anyhow!(
173 "Workflow file not found: {}",
174 workflow_file.display()
175 ));
176 }
177
178 Self::load_from_file(&workflow_file).await
179 }
180}
181
182pub struct WorkflowCompiler;
184
185impl WorkflowCompiler {
186 pub fn compile(workflow: &WorkflowDefinition) -> Result<GraphDefinition> {
188 let mut nodes = HashMap::new();
189
190 for step in workflow.steps.iter() {
192 let node_def = Self::compile_step(step)?;
193 nodes.insert(step.id.clone(), node_def);
194 }
195
196 for i in 0..workflow.steps.len() - 1 {
199 let current_id = &workflow.steps[i].id;
200 let next_id = &workflow.steps[i + 1].id;
201
202 if let Some(node) = nodes.get_mut(current_id) {
203 let mut edges = node.edges().clone();
204 edges.insert("_default".to_string(), next_id.clone());
205 *node = Self::update_node_edges(node.clone(), edges)?;
206 }
207 }
208
209 if let Some(last_step) = workflow.steps.last() {
211 if let Some(node) = nodes.get_mut(&last_step.id) {
212 let mut edges = node.edges().clone();
213 edges.insert("_default".to_string(), "END".to_string());
214 *node = Self::update_node_edges(node.clone(), edges)?;
215 }
216 }
217
218 Ok(GraphDefinition {
219 name: workflow.name.clone(),
220 version: workflow.version.clone(),
221 description: workflow.description.clone(),
222 model: workflow.model.clone(),
223 triggers: workflow.triggers.clone(),
224 inputs: workflow
225 .inputs
226 .iter()
227 .map(|(k, v)| (k.clone(), v.description.clone().unwrap_or_default()))
228 .collect(),
229 nodes,
230 })
231 }
232
233 fn compile_step(step: &WorkflowStep) -> Result<NodeDefinition> {
235 match step.step_type {
236 StepType::Standard | StepType::Verification | StepType::Test => {
237 Ok(NodeDefinition::Llm {
238 model: step.model.clone(),
239 system_prompt: format!(
240 "You are the {} agent. {}",
241 step.name,
242 step.description.clone().unwrap_or_default()
243 ),
244 tools: vec![],
245 edges: HashMap::new(),
246 })
247 }
248 StepType::StoryLoop => {
249 Ok(NodeDefinition::Graph {
251 graph_name: format!("{}_loop", step.id),
252 edges: HashMap::new(),
253 })
254 }
255 }
256 }
257
258 fn update_node_edges(
260 node: NodeDefinition,
261 edges: HashMap<String, String>,
262 ) -> Result<NodeDefinition> {
263 match node {
264 NodeDefinition::Llm {
265 model,
266 system_prompt,
267 tools,
268 ..
269 } => Ok(NodeDefinition::Llm {
270 model,
271 system_prompt,
272 tools,
273 edges,
274 }),
275 NodeDefinition::Function { action, .. } => {
276 Ok(NodeDefinition::Function { action, edges })
277 }
278 NodeDefinition::Condition { expr, .. } => Ok(NodeDefinition::Condition { expr, edges }),
279 NodeDefinition::Graph { graph_name, .. } => {
280 Ok(NodeDefinition::Graph { graph_name, edges })
281 }
282 }
283 }
284}
285
286pub struct WorkflowValidator;
288
289impl WorkflowValidator {
290 pub fn validate(workflow: &WorkflowDefinition) -> Result<Vec<ValidationIssue>> {
292 let mut issues = vec![];
293
294 let mut seen_ids = std::collections::HashSet::new();
296 for step in &workflow.steps {
297 if !seen_ids.insert(step.id.clone()) {
298 issues.push(ValidationIssue {
299 severity: ValidationSeverity::Error,
300 message: format!("Duplicate step ID: {}", step.id),
301 location: format!("steps.{}", step.id),
302 });
303 }
304 }
305
306 let role_ids: std::collections::HashSet<_> = workflow.roles.iter().map(|r| &r.id).collect();
308
309 for step in &workflow.steps {
310 if !role_ids.contains(&step.role) {
311 issues.push(ValidationIssue {
312 severity: ValidationSeverity::Error,
313 message: format!(
314 "Step '{}' references undefined role '{}'",
315 step.id, step.role
316 ),
317 location: format!("steps.{}.role", step.id),
318 });
319 }
320 }
321
322 for step in &workflow.steps {
324 if let Some(contract) = &step.contract {
325 if let Err(e) = Self::validate_contract(contract) {
326 issues.push(ValidationIssue {
327 severity: ValidationSeverity::Warning,
328 message: format!("Invalid contract in step '{}': {}", step.id, e),
329 location: format!("steps.{}.contract", step.id),
330 });
331 }
332 }
333 }
334
335 Ok(issues)
336 }
337
338 fn validate_contract(contract: &StepContract) -> Result<()> {
340 if let Some(FailureAction::Retry {
342 retry_target: Some(target),
343 ..
344 }) = &contract.on_failure
345 {
346 if target.is_empty() {
349 return Err(anyhow!("Empty retry target"));
350 }
351 }
352
353 Ok(())
354 }
355}
356
/// A single problem found by [`WorkflowValidator::validate`].
#[derive(Debug, Clone)]
pub struct ValidationIssue {
    /// How serious the problem is.
    pub severity: ValidationSeverity,
    /// Human-readable description of the problem.
    pub message: String,
    /// Dotted path to the offending element, e.g. `steps.<id>` or `steps.<id>.role`.
    pub location: String,
}
367
/// Severity of a [`ValidationIssue`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidationSeverity {
    /// Informational only; not emitted by `WorkflowValidator::validate` in this file.
    Info,
    /// Suspicious but not fatal (used for invalid step contracts).
    Warning,
    /// Definite defect (used for duplicate step IDs and undefined roles).
    Error,
}
375
376pub struct WorkflowContext {
378 pub workflow: WorkflowDefinition,
380 pub role_registry: RoleRegistry,
382 pub progress_writer: Option<ProgressJournalWriter>,
384 pub health_watchdog: Option<RunHealthWatchdog>,
386 pub inputs: HashMap<String, serde_json::Value>,
388 pub step_outputs: HashMap<String, contract::ParsedOutput>,
390}
391
392impl WorkflowContext {
393 pub fn new(workflow: WorkflowDefinition, inputs: HashMap<String, serde_json::Value>) -> Self {
395 let mut role_registry = RoleRegistry::new();
396 role_registry.load_from_workflow(workflow.roles.clone());
397
398 Self {
399 workflow,
400 role_registry,
401 progress_writer: None,
402 health_watchdog: None,
403 inputs,
404 step_outputs: HashMap::new(),
405 }
406 }
407
408 pub fn with_progress_journal(mut self, run_id: String, task: String) -> Self {
410 self.progress_writer = Some(ProgressJournalWriter::new(
411 run_id,
412 self.workflow.name.clone(),
413 task,
414 ));
415 self
416 }
417
418 pub fn with_health_watchdog(mut self, config: HealthCheckConfig) -> Self {
420 self.health_watchdog = Some(RunHealthWatchdog::new(config));
421 self
422 }
423
424 pub fn get_step_output(&self, step_id: &str) -> Option<&contract::ParsedOutput> {
426 self.step_outputs.get(step_id)
427 }
428
429 pub fn set_step_output(&mut self, step_id: String, output: contract::ParsedOutput) {
431 self.step_outputs.insert(step_id, output);
432 }
433
434 pub fn substitute_variables(&self, template: &str) -> String {
436 let mut result = template.to_string();
437
438 for (key, value) in &self.inputs {
440 let placeholder = format!("{{{{{}}}}}", key);
441 let value_str = match value {
442 serde_json::Value::String(s) => s.clone(),
443 _ => value.to_string(),
444 };
445 result = result.replace(&placeholder, &value_str);
446 }
447
448 for (step_id, output) in &self.step_outputs {
450 for (field_name, field_value) in &output.fields {
451 let placeholder = format!("{{{{{}.{}}}}}", step_id, field_name);
452 let value_str = match field_value {
453 serde_json::Value::String(s) => s.clone(),
454 _ => field_value.to_string(),
455 };
456 result = result.replace(&placeholder, &value_str);
457 }
458 }
459
460 result
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use contract::StepStatus;
    use std::collections::HashMap;

    const TEST_WORKFLOW_YAML: &str = r#"
name: test-workflow
version: "1.0.0"
description: A test workflow

inputs:
  task:
    type: string
    required: true

roles:
  - id: planner
    name: Planner
    profile: analysis

steps:
  - id: plan
    name: Plan
    role: planner
    input: "Task: {{task}}"
    contract:
      expects:
        status: done
      on_failure:
        action: retry
        max_retries: 2
"#;

    /// Two valid steps sharing one role, used to check edge wiring.
    const TWO_STEP_WORKFLOW_YAML: &str = r#"
name: two-step
roles:
  - id: planner
    name: Planner
    profile: analysis
steps:
  - id: first
    name: First
    role: planner
    input: "a"
  - id: second
    name: Second
    role: planner
    input: "b"
"#;

    /// A duplicated step ID whose steps both reference an undeclared role.
    const BROKEN_WORKFLOW_YAML: &str = r#"
name: broken
steps:
  - id: dup
    name: A
    role: ghost
    input: "x"
  - id: dup
    name: B
    role: ghost
    input: "y"
"#;

    #[test]
    fn test_load_workflow() {
        let workflow = WorkflowLoader::load_from_str(TEST_WORKFLOW_YAML).unwrap();
        assert_eq!(workflow.name, "test-workflow");
        assert_eq!(workflow.steps.len(), 1);
    }

    #[test]
    fn test_validate_workflow() {
        let workflow = WorkflowLoader::load_from_str(TEST_WORKFLOW_YAML).unwrap();
        let issues = WorkflowValidator::validate(&workflow).unwrap();
        assert!(issues.is_empty());
    }

    #[test]
    fn test_validate_reports_duplicate_ids_and_undefined_roles() {
        let workflow = WorkflowLoader::load_from_str(BROKEN_WORKFLOW_YAML).unwrap();
        let issues = WorkflowValidator::validate(&workflow).unwrap();

        // One duplicate-ID error plus one undefined-role error per step.
        assert_eq!(issues.len(), 3);
        assert!(issues
            .iter()
            .all(|issue| issue.severity == ValidationSeverity::Error));
        assert!(issues
            .iter()
            .any(|issue| issue.message == "Duplicate step ID: dup"));
        assert_eq!(
            issues
                .iter()
                .filter(|issue| issue.location == "steps.dup.role")
                .count(),
            2
        );
    }

    #[test]
    fn test_compile_links_steps_in_order() {
        let workflow = WorkflowLoader::load_from_str(TWO_STEP_WORKFLOW_YAML).unwrap();
        let graph = WorkflowCompiler::compile(&workflow).unwrap();

        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(
            graph.nodes["first"].edges().get("_default").cloned(),
            Some("second".to_string())
        );
        assert_eq!(
            graph.nodes["second"].edges().get("_default").cloned(),
            Some("END".to_string())
        );
    }

    #[test]
    fn test_workflow_context_substitution() {
        let workflow = WorkflowLoader::load_from_str(TEST_WORKFLOW_YAML).unwrap();
        let mut context = WorkflowContext::new(
            workflow,
            [(
                "task".to_string(),
                serde_json::Value::String("Implement feature".to_string()),
            )]
            .into_iter()
            .collect(),
        );

        context.set_step_output(
            "plan".to_string(),
            contract::ParsedOutput {
                status: StepStatus::Done,
                fields: [(
                    "REPO".to_string(),
                    serde_json::Value::String("/path/to/repo".to_string()),
                )]
                .into_iter()
                .collect(),
                raw_output: "STATUS: done\nREPO: /path/to/repo".to_string(),
            },
        );

        let template = "Task: {{task}}, Repo: {{plan.REPO}}";
        let result = context.substitute_variables(template);

        assert_eq!(result, "Task: Implement feature, Repo: /path/to/repo");
    }

    #[test]
    fn test_substitution_renders_non_string_values_as_json() {
        let workflow = WorkflowLoader::load_from_str(TEST_WORKFLOW_YAML).unwrap();
        let inputs: HashMap<String, serde_json::Value> =
            [("count".to_string(), serde_json::Value::from(3))]
                .into_iter()
                .collect();
        let context = WorkflowContext::new(workflow, inputs);

        assert_eq!(context.substitute_variables("n={{count}}"), "n=3");
    }

    #[test]
    fn test_substitution_leaves_unknown_placeholders() {
        let workflow = WorkflowLoader::load_from_str(TEST_WORKFLOW_YAML).unwrap();
        let context = WorkflowContext::new(workflow, HashMap::new());

        assert_eq!(
            context.substitute_variables("{{missing}} and {{plan.REPO}}"),
            "{{missing}} and {{plan.REPO}}"
        );
    }
}