Skip to main content

llama_cpp_v3_agent_sdk/
workflow.rs

1use crate::{AgentBuilder, InferenceEngine, InferenceScheduler};
2use std::sync::Arc;
3use anyhow::{Result, anyhow};
4use std::path::PathBuf;
5use serde_json::Value;
6use std::collections::HashMap;
7use regex::Regex;
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Deserialize, Serialize)]
11pub struct Workflow {
12    pub name: String,
13    pub steps: Vec<WorkflowStep>,
14}
15
16#[derive(Debug, Deserialize, Serialize, Clone)]
17pub struct WorkflowStep {
18    pub name: String,
19    pub description: String,
20    pub agent_prompt: String,
21    pub temperature: f32,
22    pub repetition_penalty: Option<f32>,
23    pub stop_sequences: Option<Vec<String>>,
24    pub output_type: String,
25    pub output_model: Option<String>,
26    pub conditional: Option<String>,
27    pub input_mapping: Option<HashMap<String, String>>,
28}
29
30#[derive(Debug, Clone, Serialize)]
31pub enum PipelineEvent {
32    StepStarted { name: String, description: String },
33    StepCompleted { name: String },
34    Processing { name: String, message: String },
35    Token { name: String, token: String },
36}
37
38#[derive(Debug, Serialize, Deserialize, Clone)]
39pub struct PostProcessSchema {
40    pub rules: Vec<PostProcessRule>,
41}
42
43#[derive(Debug, Serialize, Deserialize, Clone)]
44pub struct PostProcessRule {
45    pub pattern: String,
46    pub replacement: String,
47}
48
49pub trait WorkflowStorage: Send + Sync {
50    fn insert_artifact(
51        &self,
52        session_id: &str,
53        artifact_type: &str,
54        content: &str,
55        is_json: bool,
56    ) -> Result<()>;
57    
58    fn get_latest_artifacts(&self, session_id: &str) -> Result<HashMap<String, String>>;
59}
60
61pub struct WorkflowEngine {
62    engine: Arc<InferenceEngine>,
63    scheduler: Arc<InferenceScheduler>,
64    storage: Arc<dyn WorkflowStorage>,
65    skills_path: PathBuf,
66}
67
68impl WorkflowEngine {
69    pub fn new(
70        engine: Arc<InferenceEngine>,
71        scheduler: Arc<InferenceScheduler>,
72        storage: Arc<dyn WorkflowStorage>,
73        skills_path: PathBuf,
74    ) -> Self {
75        Self { engine, scheduler, storage, skills_path }
76    }
77
78    pub fn storage(&self) -> Arc<dyn WorkflowStorage> {
79        self.storage.clone()
80    }
81
82    pub async fn run<F>(
83        &self,
84        skill_name: &str,
85        session_id: &str,
86        mut context: Value,
87        resume_from_step: Option<String>,
88        force_regenerate: Vec<String>,
89        mut on_progress: Option<F>,
90    ) -> Result<HashMap<String, String>>
91    where
92        F: FnMut(PipelineEvent),
93    {
94        let workflow = self.load_workflow(skill_name)?;
95        let post_process_schema = self.load_post_process_schema(skill_name).ok();
96        
97        // 1. Load existing artifacts to support resuming
98        let mut results = self.storage.get_latest_artifacts(session_id).unwrap_or_default();
99        
100        // Populate context with already existing results
101        for (name, response) in &results {
102            if let Some(step) = workflow.steps.iter().find(|s| &s.name == name) {
103                if step.output_type == "json" {
104                    if let Ok(val) = self.extract_json(response) {
105                        context.as_object_mut().unwrap().insert(name.clone(), val);
106                    }
107                } else {
108                    context.as_object_mut().unwrap().insert(name.clone(), Value::String(response.clone()));
109                }
110            }
111        }
112
113        let mut found_resume_point = resume_from_step.is_none();
114
115        // Ensure context is an object
116        if !context.is_object() {
117            context = serde_json::json!({});
118        }
119
120        for step in &workflow.steps {
121            // Update resume point tracking
122            if let Some(resume_name) = &resume_from_step {
123                if step.name == *resume_name {
124                    found_resume_point = true;
125                }
126            }
127
128            // Skip step if we haven't reached the resume point AND we already have the result
129            if !found_resume_point && results.contains_key(&step.name) && !force_regenerate.contains(&step.name) {
130                if let Some(f) = &mut on_progress {
131                    f(PipelineEvent::Processing { 
132                        name: step.name.clone(), 
133                        message: format!("Using cached result for {}.", step.name) 
134                    });
135                }
136                continue;
137            }
138
139            // Skip step if we have a result AND it's NOT in force_regenerate
140            if results.contains_key(&step.name) && !force_regenerate.contains(&step.name) {
141                if let Some(f) = &mut on_progress {
142                    f(PipelineEvent::Processing { 
143                        name: step.name.clone(), 
144                        message: format!("Step {} already exists, skipping.", step.name) 
145                    });
146                }
147                continue;
148            }
149
150            // Check conditional dynamically using dot notation (e.g., "critic.rewrite_required")
151            if let Some(cond) = &step.conditional {
152                let parts: Vec<&str> = cond.split('.').collect();
153                if parts.len() == 2 {
154                    let source_step = parts[0];
155                    let field = parts[1];
156                    
157                    if let Some(source_val) = results.get(source_step) {
158                        if let Ok(source_json) = self.extract_json(source_val) {
159                            let should_run = match source_json.get(field) {
160                                Some(Value::Bool(b)) => *b,
161                                Some(Value::String(s)) => s.to_lowercase() == "true",
162                                _ => false, // Default to false if not found or not boolean
163                            };
164                            
165                            if !should_run {
166                                if let Some(f) = &mut on_progress {
167                                    f(PipelineEvent::Processing { 
168                                        name: step.name.clone(), 
169                                        message: format!("Skipping step due to conditional: {} is false.", cond) 
170                                    });
171                                }
172                                continue;
173                            }
174                        } else {
175                            if let Some(f) = &mut on_progress {
176                                f(PipelineEvent::Processing { 
177                                    name: step.name.clone(), 
178                                    message: format!("Warning: Could not extract JSON from {} to evaluate {}.", source_step, cond) 
179                                });
180                            }
181                            continue;
182                        }
183                    } else {
184                        if let Some(f) = &mut on_progress {
185                            f(PipelineEvent::Processing { 
186                                name: step.name.clone(), 
187                                message: format!("Warning: Source step {} not found to evaluate {}.", source_step, cond) 
188                            });
189                        }
190                        continue;
191                    }
192                }
193            }
194
195            if let Some(f) = &mut on_progress {
196                f(PipelineEvent::StepStarted { 
197                    name: step.name.clone(), 
198                    description: step.description.clone() 
199                });
200            }
201
202            let prompt_path = self.skills_path.join(skill_name).join(&step.agent_prompt);
203            let system_prompt = std::fs::read_to_string(&prompt_path)?;
204
205            let mut agent_builder = AgentBuilder::new()
206                .engine(self.engine.clone())
207                .scheduler(self.scheduler.clone())
208                .skills_path(self.skills_path.clone())
209                .activate_skill(skill_name)
210                .skip_builtin_tools()
211                .no_agents_md()
212                .system_prompt(&system_prompt)
213                .temperature(step.temperature);
214
215            if let Some(rp) = step.repetition_penalty {
216                agent_builder = agent_builder.repeat_penalty(rp);
217            }
218            if let Some(stops) = &step.stop_sequences {
219                for stop in stops {
220                    agent_builder = agent_builder.stop_sequence(stop);
221                }
222            }
223
224            let mut agent = agent_builder.build()
225                .map_err(|e| anyhow!("Failed to build agent '{}': {}", step.name, e))?;
226
227            let input = self.resolve_input(step, &context, &results)?;
228
229            let mut response = String::new();
230
231            {
232                let step_name = step.name.clone();
233                let on_progress_ref = &mut on_progress;
234                
235                agent.chat(&format!("Process Input: {}", input), |event| {
236                    if let crate::AgentEvent::TextDelta(text) = event {
237                        response.push_str(&text);
238                        if let Some(f) = on_progress_ref {
239                            f(PipelineEvent::Token { 
240                                name: step_name.clone(), 
241                                token: text 
242                            });
243                        }
244                    }
245                }).map_err(|e| anyhow!("Agent '{}' error: {}", step.name, e))?;
246            }
247
248            // Post-process if schema is available
249            if let Some(schema) = &post_process_schema {
250                if step.name == "writer" || step.name == "rewrite" {
251                    response = self.apply_post_process(&response, schema);
252                }
253            }
254
255            // Explicitly drop agent and its associated llama_context to free VRAM
256            drop(agent);
257            
258            results.insert(step.name.clone(), response.clone());
259            
260            if let Some(f) = &mut on_progress {
261                f(PipelineEvent::StepCompleted { name: step.name.clone() });
262            }
263            
264            // Update context with the result if it's JSON
265            if step.output_type == "json" {
266                let val = self.extract_json(&response)?;
267                context.as_object_mut().unwrap().insert(step.name.clone(), val);
268            } else {
269                context.as_object_mut().unwrap().insert(step.name.clone(), Value::String(response.clone()));
270            }
271
272            // Persist the result as an artifact
273            if let Err(e) = self.storage.insert_artifact(
274                session_id, 
275                &step.name, 
276                &response, 
277                step.output_type == "json"
278            ) {
279                if let Some(f) = &mut on_progress {
280                    f(PipelineEvent::Processing { 
281                        name: step.name.clone(), 
282                        message: format!("Warning: Failed to save artifact: {}", e) 
283                    });
284                }
285            }
286        }
287
288        Ok(results)
289    }
290
291    fn load_workflow(&self, skill_name: &str) -> Result<Workflow> {
292        let path = self.skills_path.join(skill_name).join("workflow.json");
293        let content = std::fs::read_to_string(&path)
294            .map_err(|e| anyhow!("Failed to load workflow.json for skill {}: {}", skill_name, e))?;
295        serde_json::from_str(&content).map_err(|e| anyhow!("Failed to parse workflow.json: {}", e))
296    }
297
298    fn resolve_input(&self, step: &WorkflowStep, context: &Value, results: &HashMap<String, String>) -> Result<String> {
299        if let Some(mapping) = &step.input_mapping {
300            let mut input_obj = serde_json::Map::new();
301            for (key, source_step) in mapping {
302                if let Some(result) = results.get(source_step) {
303                    if let Ok(val) = serde_json::from_str::<Value>(result) {
304                        input_obj.insert(key.clone(), val);
305                    } else if let Ok(val) = self.extract_json(result) {
306                        input_obj.insert(key.clone(), val);
307                    } else {
308                        input_obj.insert(key.clone(), Value::String(result.clone()));
309                    }
310                } else if let Some(val) = context.get(source_step) {
311                    input_obj.insert(key.clone(), val.clone());
312                } else {
313                    return Err(anyhow!("Input mapping error: Source '{}' not found for step '{}'", source_step, step.name));
314                }
315            }
316            Ok(serde_json::to_string(&Value::Object(input_obj))?)
317        } else {
318            Ok(serde_json::to_string(context)?)
319        }
320    }
321
322    fn extract_json(&self, input: &str) -> Result<Value> {
323        let start = input.find('{').ok_or_else(|| anyhow!("No JSON object found in output"))?;
324        let end = input.rfind('}').ok_or_else(|| anyhow!("No closing brace found in output"))?;
325        let json_str = &input[start..=end];
326        let sanitized = json_str.chars()
327            .filter(|c| !c.is_control() || *c == '\n' || *c == '\r' || *c == '\t')
328            .collect::<String>();
329        serde_json::from_str(&sanitized).map_err(|e| anyhow!("JSON parse error: {} (Input: {})", e, sanitized))
330    }
331
332    fn load_post_process_schema(&self, skill_name: &str) -> Result<PostProcessSchema> {
333        let path = self.skills_path.join(skill_name).join("schemas").join("post_process.json");
334        let content = std::fs::read_to_string(&path)
335            .map_err(|e| anyhow!("Failed to load post_process.json for skill {}: {}", skill_name, e))?;
336        serde_json::from_str(&content).map_err(|e| anyhow!("Failed to parse post_process.json: {}", e))
337    }
338
339    fn apply_post_process(&self, input: &str, schema: &PostProcessSchema) -> String {
340        let mut output = input.to_string();
341        for rule in &schema.rules {
342            if let Ok(re) = Regex::new(&rule.pattern) {
343                output = re.replace_all(&output, &rule.replacement).to_string();
344            }
345        }
346        output
347    }
348}