pub mod contract;
pub mod health;
pub mod progress;
pub mod role;
pub mod story_loop;
use crate::graph::schema::{GraphDefinition, NodeDefinition};
use anyhow::{anyhow, Context, Result};
use contract::{FailureAction, StepContract};
use health::{HealthCheckConfig, RunHealthWatchdog};
use progress::ProgressJournalWriter;
use role::{RoleDefinition, RoleRegistry};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
/// Top-level workflow document, deserialized from YAML by [`WorkflowLoader`].
///
/// Only `name` and `steps` are required; every other field falls back to a
/// default (or `None`) when omitted from the YAML.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowDefinition {
    /// Workflow name; also used as the compiled graph's name.
    pub name: String,
    /// Workflow version; `"1.0.0"` when omitted.
    #[serde(default = "default_version")]
    pub version: String,
    /// Optional human-readable description, copied to the compiled graph.
    pub description: Option<String>,
    /// Workflow-wide model setting, copied to the compiled graph.
    #[serde(default)]
    pub model: Option<String>,
    /// Trigger names, copied verbatim to the compiled graph.
    #[serde(default)]
    pub triggers: Vec<String>,
    /// Declared input parameters, keyed by parameter name.
    #[serde(default)]
    pub inputs: HashMap<String, InputParameter>,
    /// Roles that steps reference by ID through [`WorkflowStep::role`].
    #[serde(default)]
    pub roles: Vec<RoleDefinition>,
    /// Steps in declaration order; the compiler links each step to the next one.
    pub steps: Vec<WorkflowStep>,
    /// Progress journal settings; `ProgressJournalConfig::default()` when omitted.
    #[serde(default)]
    pub progress_journal: ProgressJournalConfig,
    /// Health check settings; `HealthCheckConfig::default()` when omitted.
    #[serde(default)]
    pub health_checks: HealthCheckConfig,
}
/// Serde default for [`WorkflowDefinition::version`] when the YAML omits it.
fn default_version() -> String {
    String::from("1.0.0")
}
/// Declaration of a single workflow input parameter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InputParameter {
    /// Declared type name, written as `type` in YAML (e.g. `string`).
    #[serde(rename = "type")]
    pub param_type: String,
    /// Optional description; becomes this parameter's entry in the compiled
    /// graph's `inputs` map (empty string when absent).
    pub description: Option<String>,
    /// Whether the caller must supply this input; `false` when omitted.
    /// Not checked anywhere in this file.
    #[serde(default)]
    pub required: bool,
    /// Optional default value; not applied anywhere in this file.
    #[serde(default)]
    pub default: Option<serde_json::Value>,
}
/// One step of a workflow as written in YAML.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowStep {
    /// Step ID; must be unique within the workflow. It is the compiled node's
    /// key and the prefix in `{{step_id.FIELD}}` placeholders.
    pub id: String,
    /// Display name; interpolated into the generated LLM system prompt.
    pub name: String,
    /// ID of the [`RoleDefinition`] that performs this step.
    pub role: String,
    /// Optional description; appended to the generated LLM system prompt.
    pub description: Option<String>,
    /// Kind of step; [`StepType::Standard`] when omitted.
    #[serde(default = "default_step_type")]
    pub step_type: StepType,
    /// Loop settings, presumably only meaningful for `story_loop` steps
    /// (not read in this file — TODO confirm).
    #[serde(default)]
    pub loop_config: Option<story_loop::StoryLoopConfig>,
    /// Input template for the step, e.g. `"Task: {{task}}"`.
    /// NOTE(review): presumably expanded with
    /// `WorkflowContext::substitute_variables` by the runner — not shown here.
    pub input: String,
    /// Per-step model override, copied onto the compiled LLM node.
    #[serde(default)]
    pub model: Option<String>,
    /// Optional output contract, checked by [`WorkflowValidator`].
    #[serde(default)]
    pub contract: Option<StepContract>,
}
/// Serde default for [`WorkflowStep::step_type`] when the YAML omits it.
fn default_step_type() -> StepType {
    StepType::Standard
}
/// Kind of a workflow step, written in snake_case in YAML (e.g. `story_loop`).
///
/// `Standard`, `Verification` and `Test` all compile to an LLM node;
/// `StoryLoop` compiles to a nested-graph node named `<step_id>_loop`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StepType {
    /// Ordinary single-shot step (the default).
    Standard,
    /// Step that runs a nested loop graph.
    StoryLoop,
    /// Verification step; compiled the same way as `Standard`.
    Verification,
    /// Test step; compiled the same way as `Standard`.
    Test,
}
/// Settings for the per-run progress journal.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProgressJournalConfig {
    /// Whether the journal is enabled; `true` when omitted.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Optional journal template. Its format is defined by the `progress`
    /// module, which is not shown here.
    #[serde(default)]
    pub template: Option<String>,
}
/// Serde default helper that yields `true` (used for `ProgressJournalConfig::enabled`).
fn default_true() -> bool {
    true
}
impl Default for ProgressJournalConfig {
fn default() -> Self {
Self {
enabled: true,
template: None,
}
}
}
/// Reads [`WorkflowDefinition`]s from YAML text, files, or workflow directories.
pub struct WorkflowLoader;
impl WorkflowLoader {
    /// Reads the file at `path` and parses it as workflow YAML.
    ///
    /// # Errors
    ///
    /// Returns an error naming `path` if the file cannot be read, or the
    /// error from [`Self::load_from_str`] if its contents do not parse.
    pub async fn load_from_file(path: &Path) -> Result<WorkflowDefinition> {
        let content = tokio::fs::read_to_string(path)
            .await
            .with_context(|| format!("Failed to read workflow file: {}", path.display()))?;
        Self::load_from_str(&content)
    }

    /// Parses workflow YAML held in memory.
    ///
    /// # Errors
    ///
    /// Returns "Failed to parse workflow YAML" (wrapping the serde error) when
    /// the text is not valid YAML or lacks required fields such as `name` or
    /// `steps`.
    pub fn load_from_str(yaml: &str) -> Result<WorkflowDefinition> {
        serde_yaml::from_str(yaml).context("Failed to parse workflow YAML")
    }

    /// Loads `workflow.yml` from the directory `dir`.
    ///
    /// # Errors
    ///
    /// Returns "Workflow file not found: …" when `dir/workflow.yml` does not
    /// exist; otherwise returns the same errors as [`Self::load_from_file`].
    pub async fn load_from_directory(dir: &Path) -> Result<WorkflowDefinition> {
        let workflow_file = dir.join("workflow.yml");
        // Use tokio's async metadata call rather than the blocking `Path::exists`.
        // Only a genuine NotFound becomes the "not found" error. Other failures
        // (e.g. permission denied) fall through, so the read below reports the
        // real cause instead of a misleading "not found".
        if let Err(err) = tokio::fs::metadata(&workflow_file).await {
            if err.kind() == std::io::ErrorKind::NotFound {
                return Err(anyhow!(
                    "Workflow file not found: {}",
                    workflow_file.display()
                ));
            }
        }
        Self::load_from_file(&workflow_file).await
    }
}
pub struct WorkflowCompiler;
impl WorkflowCompiler {
pub fn compile(workflow: &WorkflowDefinition) -> Result<GraphDefinition> {
let mut nodes = HashMap::new();
for step in workflow.steps.iter() {
let node_def = Self::compile_step(step)?;
nodes.insert(step.id.clone(), node_def);
}
for i in 0..workflow.steps.len() - 1 {
let current_id = &workflow.steps[i].id;
let next_id = &workflow.steps[i + 1].id;
if let Some(node) = nodes.get_mut(current_id) {
let mut edges = node.edges().clone();
edges.insert("_default".to_string(), next_id.clone());
*node = Self::update_node_edges(node.clone(), edges)?;
}
}
if let Some(last_step) = workflow.steps.last() {
if let Some(node) = nodes.get_mut(&last_step.id) {
let mut edges = node.edges().clone();
edges.insert("_default".to_string(), "END".to_string());
*node = Self::update_node_edges(node.clone(), edges)?;
}
}
Ok(GraphDefinition {
name: workflow.name.clone(),
version: workflow.version.clone(),
description: workflow.description.clone(),
model: workflow.model.clone(),
triggers: workflow.triggers.clone(),
inputs: workflow
.inputs
.iter()
.map(|(k, v)| (k.clone(), v.description.clone().unwrap_or_default()))
.collect(),
nodes,
})
}
fn compile_step(step: &WorkflowStep) -> Result<NodeDefinition> {
match step.step_type {
StepType::Standard | StepType::Verification | StepType::Test => {
Ok(NodeDefinition::Llm {
model: step.model.clone(),
system_prompt: format!(
"You are the {} agent. {}",
step.name,
step.description.clone().unwrap_or_default()
),
tools: vec![],
edges: HashMap::new(),
})
}
StepType::StoryLoop => {
Ok(NodeDefinition::Graph {
graph_name: format!("{}_loop", step.id),
edges: HashMap::new(),
})
}
}
}
fn update_node_edges(
node: NodeDefinition,
edges: HashMap<String, String>,
) -> Result<NodeDefinition> {
match node {
NodeDefinition::Llm {
model,
system_prompt,
tools,
..
} => Ok(NodeDefinition::Llm {
model,
system_prompt,
tools,
edges,
}),
NodeDefinition::Function { action, .. } => {
Ok(NodeDefinition::Function { action, edges })
}
NodeDefinition::Condition { expr, .. } => Ok(NodeDefinition::Condition { expr, edges }),
NodeDefinition::Graph { graph_name, .. } => {
Ok(NodeDefinition::Graph { graph_name, edges })
}
}
}
}
/// Static checks on a [`WorkflowDefinition`] that can be done without running it.
pub struct WorkflowValidator;
impl WorkflowValidator {
    /// Checks `workflow` for structural problems and returns every issue found.
    ///
    /// Issues are reported in this order:
    /// - an error if the workflow has no steps;
    /// - an error for each step whose ID repeats an earlier step's ID;
    /// - an error for each step whose `role` is not the ID of a declared role;
    /// - a warning for each step whose contract fails [`Self::validate_contract`].
    ///
    /// Currently always returns `Ok`; problems are reported as issues, not errors.
    pub fn validate(workflow: &WorkflowDefinition) -> Result<Vec<ValidationIssue>> {
        let mut issues = Vec::new();

        // An empty workflow has nothing to execute and cannot be chained into a graph.
        if workflow.steps.is_empty() {
            issues.push(ValidationIssue {
                severity: ValidationSeverity::Error,
                message: "Workflow has no steps".to_string(),
                location: "steps".to_string(),
            });
        }

        // Borrow IDs instead of cloning them; the sets live only for this call.
        let mut seen_ids = std::collections::HashSet::new();
        for step in &workflow.steps {
            if !seen_ids.insert(step.id.as_str()) {
                issues.push(ValidationIssue {
                    severity: ValidationSeverity::Error,
                    message: format!("Duplicate step ID: {}", step.id),
                    location: format!("steps.{}", step.id),
                });
            }
        }

        let role_ids: std::collections::HashSet<&str> =
            workflow.roles.iter().map(|role| role.id.as_str()).collect();
        for step in &workflow.steps {
            if !role_ids.contains(step.role.as_str()) {
                issues.push(ValidationIssue {
                    severity: ValidationSeverity::Error,
                    message: format!(
                        "Step '{}' references undefined role '{}'",
                        step.id, step.role
                    ),
                    location: format!("steps.{}.role", step.id),
                });
            }
        }

        for step in &workflow.steps {
            if let Some(contract) = &step.contract {
                if let Err(e) = Self::validate_contract(contract) {
                    issues.push(ValidationIssue {
                        severity: ValidationSeverity::Warning,
                        message: format!("Invalid contract in step '{}': {}", step.id, e),
                        location: format!("steps.{}.contract", step.id),
                    });
                }
            }
        }

        Ok(issues)
    }

    /// Checks a single step contract.
    ///
    /// Currently the only check is that a `Retry` failure action with an
    /// explicit `retry_target` does not use an empty target.
    fn validate_contract(contract: &StepContract) -> Result<()> {
        if let Some(FailureAction::Retry {
            retry_target: Some(target),
            ..
        }) = &contract.on_failure
        {
            if target.is_empty() {
                return Err(anyhow!("Empty retry target"));
            }
        }
        Ok(())
    }
}
/// One problem reported by [`WorkflowValidator::validate`].
#[derive(Debug, Clone)]
pub struct ValidationIssue {
    /// How serious the problem is.
    pub severity: ValidationSeverity,
    /// Human-readable description of the problem.
    pub message: String,
    /// Dotted path to the offending element, e.g. `steps.<id>.role`.
    pub location: String,
}
/// Severity of a [`ValidationIssue`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidationSeverity {
    /// Informational only; not produced by `WorkflowValidator` in this file.
    Info,
    /// Suspicious but not necessarily fatal (used for invalid step contracts).
    Warning,
    /// A definite problem, e.g. a duplicate step ID or an undefined role.
    Error,
}
/// State for one run of a workflow.
///
/// Holds the definition, its role registry, the optional progress journal and
/// health watchdog, the run's input values, and the outputs recorded per step.
pub struct WorkflowContext {
    /// The workflow being run.
    pub workflow: WorkflowDefinition,
    /// Registry loaded from `workflow.roles` at construction time.
    pub role_registry: RoleRegistry,
    /// Progress journal writer, if one was attached.
    pub progress_writer: Option<ProgressJournalWriter>,
    /// Health watchdog, if one was attached.
    pub health_watchdog: Option<RunHealthWatchdog>,
    /// Run input values, used for `{{name}}` template placeholders.
    pub inputs: HashMap<String, serde_json::Value>,
    /// Recorded step outputs keyed by step ID, used for `{{step_id.FIELD}}` placeholders.
    pub step_outputs: HashMap<String, contract::ParsedOutput>,
}
impl WorkflowContext {
pub fn new(workflow: WorkflowDefinition, inputs: HashMap<String, serde_json::Value>) -> Self {
let mut role_registry = RoleRegistry::new();
role_registry.load_from_workflow(workflow.roles.clone());
Self {
workflow,
role_registry,
progress_writer: None,
health_watchdog: None,
inputs,
step_outputs: HashMap::new(),
}
}
pub fn with_progress_journal(mut self, run_id: String, task: String) -> Self {
self.progress_writer = Some(ProgressJournalWriter::new(
run_id,
self.workflow.name.clone(),
task,
));
self
}
pub fn with_health_watchdog(mut self, config: HealthCheckConfig) -> Self {
self.health_watchdog = Some(RunHealthWatchdog::new(config));
self
}
pub fn get_step_output(&self, step_id: &str) -> Option<&contract::ParsedOutput> {
self.step_outputs.get(step_id)
}
pub fn set_step_output(&mut self, step_id: String, output: contract::ParsedOutput) {
self.step_outputs.insert(step_id, output);
}
pub fn substitute_variables(&self, template: &str) -> String {
let mut result = template.to_string();
for (key, value) in &self.inputs {
let placeholder = format!("{{{{{}}}}}", key);
let value_str = match value {
serde_json::Value::String(s) => s.clone(),
_ => value.to_string(),
};
result = result.replace(&placeholder, &value_str);
}
for (step_id, output) in &self.step_outputs {
for (field_name, field_value) in &output.fields {
let placeholder = format!("{{{{{}.{}}}}}", step_id, field_name);
let value_str = match field_value {
serde_json::Value::String(s) => s.clone(),
_ => field_value.to_string(),
};
result = result.replace(&placeholder, &value_str);
}
}
result
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use contract::StepStatus;

    /// Minimal one-step fixture: one `task` input, one `planner` role, and a
    /// `plan` step whose contract retries on failure. The YAML relies on
    /// nesting, so its indentation is significant.
    const TEST_WORKFLOW_YAML: &str = r#"
name: test-workflow
version: "1.0.0"
description: A test workflow
inputs:
  task:
    type: string
    required: true
roles:
  - id: planner
    name: Planner
    profile: analysis
steps:
  - id: plan
    name: Plan
    role: planner
    input: "Task: {{task}}"
    contract:
      expects:
        status: done
      on_failure:
        action: retry
        max_retries: 2
"#;

    #[test]
    fn test_load_workflow() {
        let workflow = WorkflowLoader::load_from_str(TEST_WORKFLOW_YAML).unwrap();
        assert_eq!(workflow.name, "test-workflow");
        assert_eq!(workflow.steps.len(), 1);
    }

    #[test]
    fn test_validate_workflow() {
        let workflow = WorkflowLoader::load_from_str(TEST_WORKFLOW_YAML).unwrap();
        let issues = WorkflowValidator::validate(&workflow).unwrap();
        assert!(issues.is_empty());
    }

    /// The only (and therefore last) step must be wired to the `END` sentinel.
    #[test]
    fn test_compile_links_last_step_to_end() {
        let workflow = WorkflowLoader::load_from_str(TEST_WORKFLOW_YAML).unwrap();
        let graph = WorkflowCompiler::compile(&workflow).unwrap();
        let node = graph.nodes.get("plan").expect("step compiled into a node");
        assert_eq!(
            node.edges().get("_default").map(String::as_str),
            Some("END")
        );
    }

    #[test]
    fn test_workflow_context_substitution() {
        let workflow = WorkflowLoader::load_from_str(TEST_WORKFLOW_YAML).unwrap();
        let mut context = WorkflowContext::new(
            workflow,
            [(
                "task".to_string(),
                serde_json::Value::String("Implement feature".to_string()),
            )]
            .into_iter()
            .collect(),
        );
        context.set_step_output(
            "plan".to_string(),
            contract::ParsedOutput {
                status: StepStatus::Done,
                fields: [(
                    "REPO".to_string(),
                    serde_json::Value::String("/path/to/repo".to_string()),
                )]
                .into_iter()
                .collect(),
                raw_output: "STATUS: done\nREPO: /path/to/repo".to_string(),
            },
        );
        let template = "Task: {{task}}, Repo: {{plan.REPO}}";
        let result = context.substitute_variables(template);
        assert!(result.contains("Implement feature"));
        assert!(result.contains("/path/to/repo"));
    }
}