use crate::{AgentBuilder, InferenceEngine, InferenceScheduler};
use std::sync::Arc;
use anyhow::{Result, anyhow};
use std::path::PathBuf;
use serde_json::Value;
use std::collections::HashMap;
use regex::Regex;
use serde::{Deserialize, Serialize};
/// A named, ordered pipeline of steps, deserialized from a skill's
/// `workflow.json`.
///
/// `Clone` added for consistency with `WorkflowStep` (which is already
/// `Clone`); a `Workflow` is plain data and cheap to duplicate.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Workflow {
    /// Human-readable workflow name.
    pub name: String,
    /// Steps executed in declaration order by `WorkflowEngine::run`.
    pub steps: Vec<WorkflowStep>,
}
/// One stage of a skill workflow, declared in the skill's `workflow.json`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct WorkflowStep {
/// Unique step name; used as the key for cached artifacts and for the
/// entry merged into the shared JSON context.
pub name: String,
/// Human-readable description, surfaced via `PipelineEvent::StepStarted`.
pub description: String,
/// Filename of the step's system-prompt file, resolved relative to
/// `<skills_path>/<skill_name>/`.
pub agent_prompt: String,
/// Sampling temperature passed to the agent builder.
pub temperature: f32,
/// Optional repetition penalty forwarded to the agent builder.
pub repetition_penalty: Option<f32>,
/// Optional stop sequences forwarded to the agent builder.
pub stop_sequences: Option<Vec<String>>,
/// `"json"` makes the engine parse the step output into the context (and
/// hard-fail if unparseable); any other value stores it as a plain string.
pub output_type: String,
/// NOTE(review): not referenced anywhere in this file — presumably
/// consumed by other components; confirm before removing.
pub output_model: Option<String>,
/// Optional gate of the form `"<source_step>.<field>"`; the step runs only
/// when that field is `true` (bool) or the string `"true"`.
pub conditional: Option<String>,
/// Optional map of agent-input key -> source step name (or context key)
/// used by `resolve_input` to build this step's input object.
pub input_mapping: Option<HashMap<String, String>>,
}
/// Progress notifications delivered to the `on_progress` callback while a
/// workflow runs.
#[derive(Debug, Clone, Serialize)]
pub enum PipelineEvent {
/// A step has started executing.
StepStarted { name: String, description: String },
/// A step finished and its result was recorded.
StepCompleted { name: String },
/// Informational message: cache hits, conditional skips, warnings.
Processing { name: String, message: String },
/// A streamed text fragment produced by the step's agent.
Token { name: String, token: String },
}
/// Regex rewrite rules loaded from a skill's `schemas/post_process.json`,
/// applied in order to selected step outputs (see `apply_post_process`).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PostProcessSchema {
/// Rules applied sequentially, each over the previous rule's output.
pub rules: Vec<PostProcessRule>,
}
/// A single find/replace rule: a regex pattern and its replacement text
/// (replacement may use regex capture-group syntax).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PostProcessRule {
/// Regular-expression pattern; invalid patterns are skipped at apply time.
pub pattern: String,
/// Replacement string passed to `Regex::replace_all`.
pub replacement: String,
}
/// Persistence backend for step artifacts, keyed by session.
pub trait WorkflowStorage: Send + Sync {
/// Stores one step output. `artifact_type` is the step name as used by the
/// engine; `is_json` records whether `content` is JSON-typed output.
fn insert_artifact(
&self,
session_id: &str,
artifact_type: &str,
content: &str,
is_json: bool,
) -> Result<()>;
/// Returns the most recent stored artifact per step name for the session.
fn get_latest_artifacts(&self, session_id: &str) -> Result<HashMap<String, String>>;
}
/// Orchestrates skill workflows: loads `workflow.json`, runs each step
/// through an agent, caches artifacts via `WorkflowStorage`, and streams
/// progress events to the caller.
pub struct WorkflowEngine {
// Shared inference backend handed to every agent built for a step.
engine: Arc<InferenceEngine>,
// Shared scheduler handed to every agent built for a step.
scheduler: Arc<InferenceScheduler>,
// Artifact persistence backend (session-scoped step outputs).
storage: Arc<dyn WorkflowStorage>,
// Root directory containing one subdirectory per skill.
skills_path: PathBuf,
}
impl WorkflowEngine {
pub fn new(
engine: Arc<InferenceEngine>,
scheduler: Arc<InferenceScheduler>,
storage: Arc<dyn WorkflowStorage>,
skills_path: PathBuf,
) -> Self {
Self { engine, scheduler, storage, skills_path }
}
/// Returns a cloned handle to the artifact storage backend.
pub fn storage(&self) -> Arc<dyn WorkflowStorage> {
    Arc::clone(&self.storage)
}
pub async fn run<F>(
&self,
skill_name: &str,
session_id: &str,
mut context: Value,
resume_from_step: Option<String>,
force_regenerate: Vec<String>,
mut on_progress: Option<F>,
) -> Result<HashMap<String, String>>
where
F: FnMut(PipelineEvent),
{
let workflow = self.load_workflow(skill_name)?;
let post_process_schema = self.load_post_process_schema(skill_name).ok();
let mut results = self.storage.get_latest_artifacts(session_id).unwrap_or_default();
for (name, response) in &results {
if let Some(step) = workflow.steps.iter().find(|s| &s.name == name) {
if step.output_type == "json" {
if let Ok(val) = self.extract_json(response) {
context.as_object_mut().unwrap().insert(name.clone(), val);
}
} else {
context.as_object_mut().unwrap().insert(name.clone(), Value::String(response.clone()));
}
}
}
let mut found_resume_point = resume_from_step.is_none();
if !context.is_object() {
context = serde_json::json!({});
}
for step in &workflow.steps {
if let Some(resume_name) = &resume_from_step {
if step.name == *resume_name {
found_resume_point = true;
}
}
if !found_resume_point && results.contains_key(&step.name) && !force_regenerate.contains(&step.name) {
if let Some(f) = &mut on_progress {
f(PipelineEvent::Processing {
name: step.name.clone(),
message: format!("Using cached result for {}.", step.name)
});
}
continue;
}
if results.contains_key(&step.name) && !force_regenerate.contains(&step.name) {
if let Some(f) = &mut on_progress {
f(PipelineEvent::Processing {
name: step.name.clone(),
message: format!("Step {} already exists, skipping.", step.name)
});
}
continue;
}
if let Some(cond) = &step.conditional {
let parts: Vec<&str> = cond.split('.').collect();
if parts.len() == 2 {
let source_step = parts[0];
let field = parts[1];
if let Some(source_val) = results.get(source_step) {
if let Ok(source_json) = self.extract_json(source_val) {
let should_run = match source_json.get(field) {
Some(Value::Bool(b)) => *b,
Some(Value::String(s)) => s.to_lowercase() == "true",
_ => false, };
if !should_run {
if let Some(f) = &mut on_progress {
f(PipelineEvent::Processing {
name: step.name.clone(),
message: format!("Skipping step due to conditional: {} is false.", cond)
});
}
continue;
}
} else {
if let Some(f) = &mut on_progress {
f(PipelineEvent::Processing {
name: step.name.clone(),
message: format!("Warning: Could not extract JSON from {} to evaluate {}.", source_step, cond)
});
}
continue;
}
} else {
if let Some(f) = &mut on_progress {
f(PipelineEvent::Processing {
name: step.name.clone(),
message: format!("Warning: Source step {} not found to evaluate {}.", source_step, cond)
});
}
continue;
}
}
}
if let Some(f) = &mut on_progress {
f(PipelineEvent::StepStarted {
name: step.name.clone(),
description: step.description.clone()
});
}
let prompt_path = self.skills_path.join(skill_name).join(&step.agent_prompt);
let system_prompt = std::fs::read_to_string(&prompt_path)?;
let mut agent_builder = AgentBuilder::new()
.engine(self.engine.clone())
.scheduler(self.scheduler.clone())
.skills_path(self.skills_path.clone())
.activate_skill(skill_name)
.skip_builtin_tools()
.no_agents_md()
.system_prompt(&system_prompt)
.temperature(step.temperature);
if let Some(rp) = step.repetition_penalty {
agent_builder = agent_builder.repeat_penalty(rp);
}
if let Some(stops) = &step.stop_sequences {
for stop in stops {
agent_builder = agent_builder.stop_sequence(stop);
}
}
let mut agent = agent_builder.build()
.map_err(|e| anyhow!("Failed to build agent '{}': {}", step.name, e))?;
let input = self.resolve_input(step, &context, &results)?;
let mut response = String::new();
{
let step_name = step.name.clone();
let on_progress_ref = &mut on_progress;
agent.chat(&format!("Process Input: {}", input), |event| {
if let crate::AgentEvent::TextDelta(text) = event {
response.push_str(&text);
if let Some(f) = on_progress_ref {
f(PipelineEvent::Token {
name: step_name.clone(),
token: text
});
}
}
}).map_err(|e| anyhow!("Agent '{}' error: {}", step.name, e))?;
}
if let Some(schema) = &post_process_schema {
if step.name == "writer" || step.name == "rewrite" {
response = self.apply_post_process(&response, schema);
}
}
drop(agent);
results.insert(step.name.clone(), response.clone());
if let Some(f) = &mut on_progress {
f(PipelineEvent::StepCompleted { name: step.name.clone() });
}
if step.output_type == "json" {
let val = self.extract_json(&response)?;
context.as_object_mut().unwrap().insert(step.name.clone(), val);
} else {
context.as_object_mut().unwrap().insert(step.name.clone(), Value::String(response.clone()));
}
if let Err(e) = self.storage.insert_artifact(
session_id,
&step.name,
&response,
step.output_type == "json"
) {
if let Some(f) = &mut on_progress {
f(PipelineEvent::Processing {
name: step.name.clone(),
message: format!("Warning: Failed to save artifact: {}", e)
});
}
}
}
Ok(results)
}
/// Reads and parses `<skills_path>/<skill_name>/workflow.json`.
fn load_workflow(&self, skill_name: &str) -> Result<Workflow> {
    let path = self.skills_path.join(skill_name).join("workflow.json");
    let raw = std::fs::read_to_string(&path)
        .map_err(|e| anyhow!("Failed to load workflow.json for skill {}: {}", skill_name, e))?;
    let workflow: Workflow = serde_json::from_str(&raw)
        .map_err(|e| anyhow!("Failed to parse workflow.json: {}", e))?;
    Ok(workflow)
}
/// Builds the JSON input string handed to a step's agent.
///
/// With an `input_mapping`, each mapped key is filled from a prior step's
/// raw result (parsed as JSON when possible, extracted from surrounding
/// text as a fallback, otherwise kept verbatim as a string) or from the
/// shared context. Without a mapping, the entire context is serialized.
///
/// # Errors
/// Fails when a mapped source exists in neither `results` nor `context`,
/// or when serialization fails.
fn resolve_input(&self, step: &WorkflowStep, context: &Value, results: &HashMap<String, String>) -> Result<String> {
    // No mapping: the agent receives the whole shared context.
    let mapping = match &step.input_mapping {
        Some(m) => m,
        None => return Ok(serde_json::to_string(context)?),
    };
    let mut input_obj = serde_json::Map::new();
    for (key, source_step) in mapping {
        let value = if let Some(result) = results.get(source_step) {
            // Prefer strict JSON, then embedded-JSON extraction, then raw text.
            serde_json::from_str::<Value>(result)
                .or_else(|_| self.extract_json(result))
                .unwrap_or_else(|_| Value::String(result.clone()))
        } else if let Some(val) = context.get(source_step) {
            val.clone()
        } else {
            return Err(anyhow!("Input mapping error: Source '{}' not found for step '{}'", source_step, step.name));
        };
        input_obj.insert(key.clone(), value);
    }
    Ok(serde_json::to_string(&Value::Object(input_obj))?)
}
/// Extracts the span from the first `{` to the last `}` in `input`, strips
/// control characters (keeping `\n`, `\r`, `\t`), and parses it as JSON.
///
/// # Errors
/// Fails when no `{`/`}` pair is present or the sanitized span is not
/// valid JSON.
fn extract_json(&self, input: &str) -> Result<Value> {
    let start = input.find('{').ok_or_else(|| anyhow!("No JSON object found in output"))?;
    let end = input.rfind('}').ok_or_else(|| anyhow!("No closing brace found in output"))?;
    // Drop stray control characters that would break the JSON parser.
    let mut sanitized = String::with_capacity(end + 1 - start);
    for ch in input[start..=end].chars() {
        if !ch.is_control() || matches!(ch, '\n' | '\r' | '\t') {
            sanitized.push(ch);
        }
    }
    serde_json::from_str(&sanitized).map_err(|e| anyhow!("JSON parse error: {} (Input: {})", e, sanitized))
}
/// Reads and parses `<skills_path>/<skill_name>/schemas/post_process.json`.
fn load_post_process_schema(&self, skill_name: &str) -> Result<PostProcessSchema> {
    let path = self
        .skills_path
        .join(skill_name)
        .join("schemas")
        .join("post_process.json");
    let raw = std::fs::read_to_string(&path)
        .map_err(|e| anyhow!("Failed to load post_process.json for skill {}: {}", skill_name, e))?;
    serde_json::from_str(&raw).map_err(|e| anyhow!("Failed to parse post_process.json: {}", e))
}
/// Applies every regex replacement rule in `schema` to `input`, in order,
/// each rule operating on the previous rule's output. Rules whose pattern
/// fails to compile are skipped silently (best-effort post-processing).
fn apply_post_process(&self, input: &str, schema: &PostProcessSchema) -> String {
    schema
        .rules
        .iter()
        .fold(input.to_string(), |acc, rule| match Regex::new(&rule.pattern) {
            Ok(re) => re.replace_all(&acc, &rule.replacement).to_string(),
            Err(_) => acc,
        })
}
}