1use crate::{AgentBuilder, InferenceEngine, InferenceScheduler};
2use std::sync::Arc;
3use anyhow::{Result, anyhow};
4use std::path::PathBuf;
5use serde_json::Value;
6use std::collections::HashMap;
7use regex::Regex;
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Deserialize, Serialize)]
11pub struct Workflow {
12 pub name: String,
13 pub steps: Vec<WorkflowStep>,
14}
15
/// One step of a workflow, deserialized from `workflow.json`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct WorkflowStep {
    /// Unique step name; used as the key for artifacts, cached results,
    /// and context entries.
    pub name: String,
    /// Human-readable description, surfaced via `PipelineEvent::StepStarted`.
    pub description: String,
    /// File name of the system-prompt file, relative to the skill directory.
    pub agent_prompt: String,
    /// Sampling temperature passed to the agent builder.
    pub temperature: f32,
    /// Optional repetition penalty; only applied to the agent when present.
    pub repetition_penalty: Option<f32>,
    /// Optional stop sequences; each one is registered on the agent builder.
    pub stop_sequences: Option<Vec<String>>,
    /// `"json"` means the response is parsed with `extract_json`; any other
    /// value stores the raw text.
    pub output_type: String,
    // NOTE(review): not read anywhere in this file — presumably consumed by
    // another module; confirm before removing.
    pub output_model: Option<String>,
    /// Optional `"source_step.field"` gate; the step only runs when that
    /// field is truthy in the source step's JSON output.
    pub conditional: Option<String>,
    /// Optional map of input key -> source step/context key used to assemble
    /// this step's input object instead of passing the whole context.
    pub input_mapping: Option<HashMap<String, String>>,
}
29
/// Progress events emitted to the `on_progress` callback during a run.
#[derive(Debug, Clone, Serialize)]
pub enum PipelineEvent {
    /// A step has started executing.
    StepStarted { name: String, description: String },
    /// A step's full response has been produced (emitted before the artifact
    /// is persisted).
    StepCompleted { name: String },
    /// Informational or warning message: cache hits, conditional skips,
    /// persistence failures.
    Processing { name: String, message: String },
    /// A streamed text fragment from the step's agent.
    Token { name: String, token: String },
}
37
/// Regex-based cleanup rules loaded from the skill's
/// `schemas/post_process.json`, applied in order to selected step outputs.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PostProcessSchema {
    /// Rules applied sequentially by `apply_post_process`.
    pub rules: Vec<PostProcessRule>,
}
42
/// A single find/replace rule: a regex pattern and its replacement string.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PostProcessRule {
    /// Regular-expression pattern; rules with invalid patterns are silently
    /// skipped at apply time.
    pub pattern: String,
    /// Replacement text passed to `Regex::replace_all`.
    pub replacement: String,
}
48
/// Persistence backend for step artifacts, implemented by the host
/// application. Must be `Send + Sync` because the engine holds it behind an
/// `Arc<dyn WorkflowStorage>`.
pub trait WorkflowStorage: Send + Sync {
    /// Saves one step output for a session. `artifact_type` is the step name
    /// and `is_json` flags whether `content` is a JSON payload.
    fn insert_artifact(
        &self,
        session_id: &str,
        artifact_type: &str,
        content: &str,
        is_json: bool,
    ) -> Result<()>;

    /// Returns the most recent artifact per step name for the session
    /// (the engine treats the returned keys as step names).
    fn get_latest_artifacts(&self, session_id: &str) -> Result<HashMap<String, String>>;
}
60
/// Drives skill workflows: loads `workflow.json`, runs each step through a
/// freshly built agent, and persists step outputs via `WorkflowStorage`.
pub struct WorkflowEngine {
    /// Shared inference engine handed to every step's agent.
    engine: Arc<InferenceEngine>,
    /// Shared scheduler handed to every step's agent.
    scheduler: Arc<InferenceScheduler>,
    /// Artifact persistence backend (also used to resume cached sessions).
    storage: Arc<dyn WorkflowStorage>,
    /// Root directory containing one subdirectory per skill.
    skills_path: PathBuf,
}
67
68impl WorkflowEngine {
69 pub fn new(
70 engine: Arc<InferenceEngine>,
71 scheduler: Arc<InferenceScheduler>,
72 storage: Arc<dyn WorkflowStorage>,
73 skills_path: PathBuf,
74 ) -> Self {
75 Self { engine, scheduler, storage, skills_path }
76 }
77
78 pub fn storage(&self) -> Arc<dyn WorkflowStorage> {
79 self.storage.clone()
80 }
81
82 pub async fn run<F>(
83 &self,
84 skill_name: &str,
85 session_id: &str,
86 mut context: Value,
87 resume_from_step: Option<String>,
88 force_regenerate: Vec<String>,
89 mut on_progress: Option<F>,
90 ) -> Result<HashMap<String, String>>
91 where
92 F: FnMut(PipelineEvent),
93 {
94 let workflow = self.load_workflow(skill_name)?;
95 let post_process_schema = self.load_post_process_schema(skill_name).ok();
96
97 let mut results = self.storage.get_latest_artifacts(session_id).unwrap_or_default();
99
100 for (name, response) in &results {
102 if let Some(step) = workflow.steps.iter().find(|s| &s.name == name) {
103 if step.output_type == "json" {
104 if let Ok(val) = self.extract_json(response) {
105 context.as_object_mut().unwrap().insert(name.clone(), val);
106 }
107 } else {
108 context.as_object_mut().unwrap().insert(name.clone(), Value::String(response.clone()));
109 }
110 }
111 }
112
113 let mut found_resume_point = resume_from_step.is_none();
114
115 if !context.is_object() {
117 context = serde_json::json!({});
118 }
119
120 for step in &workflow.steps {
121 if let Some(resume_name) = &resume_from_step {
123 if step.name == *resume_name {
124 found_resume_point = true;
125 }
126 }
127
128 if !found_resume_point && results.contains_key(&step.name) && !force_regenerate.contains(&step.name) {
130 if let Some(f) = &mut on_progress {
131 f(PipelineEvent::Processing {
132 name: step.name.clone(),
133 message: format!("Using cached result for {}.", step.name)
134 });
135 }
136 continue;
137 }
138
139 if results.contains_key(&step.name) && !force_regenerate.contains(&step.name) {
141 if let Some(f) = &mut on_progress {
142 f(PipelineEvent::Processing {
143 name: step.name.clone(),
144 message: format!("Step {} already exists, skipping.", step.name)
145 });
146 }
147 continue;
148 }
149
150 if let Some(cond) = &step.conditional {
152 let parts: Vec<&str> = cond.split('.').collect();
153 if parts.len() == 2 {
154 let source_step = parts[0];
155 let field = parts[1];
156
157 if let Some(source_val) = results.get(source_step) {
158 if let Ok(source_json) = self.extract_json(source_val) {
159 let should_run = match source_json.get(field) {
160 Some(Value::Bool(b)) => *b,
161 Some(Value::String(s)) => s.to_lowercase() == "true",
162 _ => false, };
164
165 if !should_run {
166 if let Some(f) = &mut on_progress {
167 f(PipelineEvent::Processing {
168 name: step.name.clone(),
169 message: format!("Skipping step due to conditional: {} is false.", cond)
170 });
171 }
172 continue;
173 }
174 } else {
175 if let Some(f) = &mut on_progress {
176 f(PipelineEvent::Processing {
177 name: step.name.clone(),
178 message: format!("Warning: Could not extract JSON from {} to evaluate {}.", source_step, cond)
179 });
180 }
181 continue;
182 }
183 } else {
184 if let Some(f) = &mut on_progress {
185 f(PipelineEvent::Processing {
186 name: step.name.clone(),
187 message: format!("Warning: Source step {} not found to evaluate {}.", source_step, cond)
188 });
189 }
190 continue;
191 }
192 }
193 }
194
195 if let Some(f) = &mut on_progress {
196 f(PipelineEvent::StepStarted {
197 name: step.name.clone(),
198 description: step.description.clone()
199 });
200 }
201
202 let prompt_path = self.skills_path.join(skill_name).join(&step.agent_prompt);
203 let system_prompt = std::fs::read_to_string(&prompt_path)?;
204
205 let mut agent_builder = AgentBuilder::new()
206 .engine(self.engine.clone())
207 .scheduler(self.scheduler.clone())
208 .skills_path(self.skills_path.clone())
209 .activate_skill(skill_name)
210 .skip_builtin_tools()
211 .no_agents_md()
212 .system_prompt(&system_prompt)
213 .temperature(step.temperature);
214
215 if let Some(rp) = step.repetition_penalty {
216 agent_builder = agent_builder.repeat_penalty(rp);
217 }
218 if let Some(stops) = &step.stop_sequences {
219 for stop in stops {
220 agent_builder = agent_builder.stop_sequence(stop);
221 }
222 }
223
224 let mut agent = agent_builder.build()
225 .map_err(|e| anyhow!("Failed to build agent '{}': {}", step.name, e))?;
226
227 let input = self.resolve_input(step, &context, &results)?;
228
229 let mut response = String::new();
230
231 {
232 let step_name = step.name.clone();
233 let on_progress_ref = &mut on_progress;
234
235 agent.chat(&format!("Process Input: {}", input), |event| {
236 if let crate::AgentEvent::TextDelta(text) = event {
237 response.push_str(&text);
238 if let Some(f) = on_progress_ref {
239 f(PipelineEvent::Token {
240 name: step_name.clone(),
241 token: text
242 });
243 }
244 }
245 }).map_err(|e| anyhow!("Agent '{}' error: {}", step.name, e))?;
246 }
247
248 if let Some(schema) = &post_process_schema {
250 if step.name == "writer" || step.name == "rewrite" {
251 response = self.apply_post_process(&response, schema);
252 }
253 }
254
255 drop(agent);
257
258 results.insert(step.name.clone(), response.clone());
259
260 if let Some(f) = &mut on_progress {
261 f(PipelineEvent::StepCompleted { name: step.name.clone() });
262 }
263
264 if step.output_type == "json" {
266 let val = self.extract_json(&response)?;
267 context.as_object_mut().unwrap().insert(step.name.clone(), val);
268 } else {
269 context.as_object_mut().unwrap().insert(step.name.clone(), Value::String(response.clone()));
270 }
271
272 if let Err(e) = self.storage.insert_artifact(
274 session_id,
275 &step.name,
276 &response,
277 step.output_type == "json"
278 ) {
279 if let Some(f) = &mut on_progress {
280 f(PipelineEvent::Processing {
281 name: step.name.clone(),
282 message: format!("Warning: Failed to save artifact: {}", e)
283 });
284 }
285 }
286 }
287
288 Ok(results)
289 }
290
291 fn load_workflow(&self, skill_name: &str) -> Result<Workflow> {
292 let path = self.skills_path.join(skill_name).join("workflow.json");
293 let content = std::fs::read_to_string(&path)
294 .map_err(|e| anyhow!("Failed to load workflow.json for skill {}: {}", skill_name, e))?;
295 serde_json::from_str(&content).map_err(|e| anyhow!("Failed to parse workflow.json: {}", e))
296 }
297
298 fn resolve_input(&self, step: &WorkflowStep, context: &Value, results: &HashMap<String, String>) -> Result<String> {
299 if let Some(mapping) = &step.input_mapping {
300 let mut input_obj = serde_json::Map::new();
301 for (key, source_step) in mapping {
302 if let Some(result) = results.get(source_step) {
303 if let Ok(val) = serde_json::from_str::<Value>(result) {
304 input_obj.insert(key.clone(), val);
305 } else if let Ok(val) = self.extract_json(result) {
306 input_obj.insert(key.clone(), val);
307 } else {
308 input_obj.insert(key.clone(), Value::String(result.clone()));
309 }
310 } else if let Some(val) = context.get(source_step) {
311 input_obj.insert(key.clone(), val.clone());
312 } else {
313 return Err(anyhow!("Input mapping error: Source '{}' not found for step '{}'", source_step, step.name));
314 }
315 }
316 Ok(serde_json::to_string(&Value::Object(input_obj))?)
317 } else {
318 Ok(serde_json::to_string(context)?)
319 }
320 }
321
322 fn extract_json(&self, input: &str) -> Result<Value> {
323 let start = input.find('{').ok_or_else(|| anyhow!("No JSON object found in output"))?;
324 let end = input.rfind('}').ok_or_else(|| anyhow!("No closing brace found in output"))?;
325 let json_str = &input[start..=end];
326 let sanitized = json_str.chars()
327 .filter(|c| !c.is_control() || *c == '\n' || *c == '\r' || *c == '\t')
328 .collect::<String>();
329 serde_json::from_str(&sanitized).map_err(|e| anyhow!("JSON parse error: {} (Input: {})", e, sanitized))
330 }
331
332 fn load_post_process_schema(&self, skill_name: &str) -> Result<PostProcessSchema> {
333 let path = self.skills_path.join(skill_name).join("schemas").join("post_process.json");
334 let content = std::fs::read_to_string(&path)
335 .map_err(|e| anyhow!("Failed to load post_process.json for skill {}: {}", skill_name, e))?;
336 serde_json::from_str(&content).map_err(|e| anyhow!("Failed to parse post_process.json: {}", e))
337 }
338
339 fn apply_post_process(&self, input: &str, schema: &PostProcessSchema) -> String {
340 let mut output = input.to_string();
341 for rule in &schema.rules {
342 if let Ok(re) = Regex::new(&rule.pattern) {
343 output = re.replace_all(&output, &rule.replacement).to_string();
344 }
345 }
346 output
347 }
348}