//! llama-cpp-v3-agent-sdk 0.1.7
//!
//! Agentic tool-use loop on top of llama-cpp-v3 — local LLM agents with
//! built-in tools. See the crate documentation for details.
use crate::{AgentBuilder, InferenceEngine, InferenceScheduler};
use std::sync::Arc;
use anyhow::{Result, anyhow};
use std::path::PathBuf;
use serde_json::Value;
use std::collections::HashMap;
use regex::Regex;
use serde::{Deserialize, Serialize};

/// A declarative multi-step pipeline, loaded from a skill's `workflow.json`.
// `Clone` added for consistency with `WorkflowStep`, which already derives it.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Workflow {
    /// Human-readable workflow name.
    pub name: String,
    /// Steps executed in declaration order.
    pub steps: Vec<WorkflowStep>,
}

/// A single step of a [`Workflow`].
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct WorkflowStep {
    /// Unique step name; also used as the artifact and context key.
    pub name: String,
    /// Human-readable description, surfaced via `PipelineEvent::StepStarted`.
    pub description: String,
    /// File name of the system-prompt file, relative to the skill directory.
    pub agent_prompt: String,
    /// Sampling temperature for this step's agent.
    pub temperature: f32,
    /// Optional repetition penalty forwarded to the agent builder.
    pub repetition_penalty: Option<f32>,
    /// Optional stop sequences forwarded to the agent builder.
    pub stop_sequences: Option<Vec<String>>,
    /// `"json"` means the response is parsed into the context as JSON;
    /// any other value stores it as a plain string.
    pub output_type: String,
    /// Optional output model/schema name.
    /// NOTE(review): not read anywhere in this file — presumably consumed
    /// elsewhere; confirm before removing.
    pub output_model: Option<String>,
    /// Optional `"source_step.field"` gate; the step runs only when that
    /// field is boolean `true` (or the string `"true"`) in the source
    /// step's JSON output.
    pub conditional: Option<String>,
    /// Optional map of input key -> source step (or context key) used to
    /// assemble this step's input object; when absent the full context is
    /// passed instead.
    pub input_mapping: Option<HashMap<String, String>>,
}

/// Progress events emitted to the `on_progress` callback while a workflow runs.
#[derive(Debug, Clone, Serialize)]
pub enum PipelineEvent {
    /// A step is about to execute.
    StepStarted { name: String, description: String },
    /// A step finished and its result was recorded.
    StepCompleted { name: String },
    /// Informational message (cache hits, conditional skips, warnings).
    Processing { name: String, message: String },
    /// One streamed token of model output for the named step.
    Token { name: String, token: String },
}

/// Optional per-skill output clean-up, loaded from `schemas/post_process.json`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PostProcessSchema {
    /// Regex rules applied in order to a step's response.
    pub rules: Vec<PostProcessRule>,
}

/// One regex find/replace rule of a [`PostProcessSchema`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PostProcessRule {
    /// Regex pattern (compiled with the `regex` crate at apply time).
    pub pattern: String,
    /// Replacement text; may use `regex` capture-group syntax like `$1`.
    pub replacement: String,
}

/// Persistence backend for workflow artifacts (one artifact per step).
///
/// Implementations must be thread-safe (`Send + Sync`) because the engine
/// hands out shared `Arc<dyn WorkflowStorage>` handles.
pub trait WorkflowStorage: Send + Sync {
    /// Stores `content` as the latest artifact of `artifact_type` (the step
    /// name) for `session_id`. `is_json` marks whether `content` is JSON.
    fn insert_artifact(
        &self,
        session_id: &str,
        artifact_type: &str,
        content: &str,
        is_json: bool,
    ) -> Result<()>;
    
    /// Returns the most recent artifact per step name for `session_id`.
    fn get_latest_artifacts(&self, session_id: &str) -> Result<HashMap<String, String>>;
}

/// Executes declarative workflows: for each step it builds an agent, streams
/// the model response, optionally post-processes it, and persists the result
/// so a session can be resumed later.
pub struct WorkflowEngine {
    /// Shared inference engine handed to every step's agent.
    engine: Arc<InferenceEngine>,
    /// Scheduler coordinating access to the inference engine.
    scheduler: Arc<InferenceScheduler>,
    /// Artifact persistence backend (enables resume).
    storage: Arc<dyn WorkflowStorage>,
    /// Root directory containing one sub-directory per skill.
    skills_path: PathBuf,
}

impl WorkflowEngine {
    pub fn new(
        engine: Arc<InferenceEngine>,
        scheduler: Arc<InferenceScheduler>,
        storage: Arc<dyn WorkflowStorage>,
        skills_path: PathBuf,
    ) -> Self {
        Self { engine, scheduler, storage, skills_path }
    }

    pub fn storage(&self) -> Arc<dyn WorkflowStorage> {
        self.storage.clone()
    }

    pub async fn run<F>(
        &self,
        skill_name: &str,
        session_id: &str,
        mut context: Value,
        resume_from_step: Option<String>,
        force_regenerate: Vec<String>,
        mut on_progress: Option<F>,
    ) -> Result<HashMap<String, String>>
    where
        F: FnMut(PipelineEvent),
    {
        let workflow = self.load_workflow(skill_name)?;
        let post_process_schema = self.load_post_process_schema(skill_name).ok();
        
        // 1. Load existing artifacts to support resuming
        let mut results = self.storage.get_latest_artifacts(session_id).unwrap_or_default();
        
        // Populate context with already existing results
        for (name, response) in &results {
            if let Some(step) = workflow.steps.iter().find(|s| &s.name == name) {
                if step.output_type == "json" {
                    if let Ok(val) = self.extract_json(response) {
                        context.as_object_mut().unwrap().insert(name.clone(), val);
                    }
                } else {
                    context.as_object_mut().unwrap().insert(name.clone(), Value::String(response.clone()));
                }
            }
        }

        let mut found_resume_point = resume_from_step.is_none();

        // Ensure context is an object
        if !context.is_object() {
            context = serde_json::json!({});
        }

        for step in &workflow.steps {
            // Update resume point tracking
            if let Some(resume_name) = &resume_from_step {
                if step.name == *resume_name {
                    found_resume_point = true;
                }
            }

            // Skip step if we haven't reached the resume point AND we already have the result
            if !found_resume_point && results.contains_key(&step.name) && !force_regenerate.contains(&step.name) {
                if let Some(f) = &mut on_progress {
                    f(PipelineEvent::Processing { 
                        name: step.name.clone(), 
                        message: format!("Using cached result for {}.", step.name) 
                    });
                }
                continue;
            }

            // Skip step if we have a result AND it's NOT in force_regenerate
            if results.contains_key(&step.name) && !force_regenerate.contains(&step.name) {
                if let Some(f) = &mut on_progress {
                    f(PipelineEvent::Processing { 
                        name: step.name.clone(), 
                        message: format!("Step {} already exists, skipping.", step.name) 
                    });
                }
                continue;
            }

            // Check conditional dynamically using dot notation (e.g., "critic.rewrite_required")
            if let Some(cond) = &step.conditional {
                let parts: Vec<&str> = cond.split('.').collect();
                if parts.len() == 2 {
                    let source_step = parts[0];
                    let field = parts[1];
                    
                    if let Some(source_val) = results.get(source_step) {
                        if let Ok(source_json) = self.extract_json(source_val) {
                            let should_run = match source_json.get(field) {
                                Some(Value::Bool(b)) => *b,
                                Some(Value::String(s)) => s.to_lowercase() == "true",
                                _ => false, // Default to false if not found or not boolean
                            };
                            
                            if !should_run {
                                if let Some(f) = &mut on_progress {
                                    f(PipelineEvent::Processing { 
                                        name: step.name.clone(), 
                                        message: format!("Skipping step due to conditional: {} is false.", cond) 
                                    });
                                }
                                continue;
                            }
                        } else {
                            if let Some(f) = &mut on_progress {
                                f(PipelineEvent::Processing { 
                                    name: step.name.clone(), 
                                    message: format!("Warning: Could not extract JSON from {} to evaluate {}.", source_step, cond) 
                                });
                            }
                            continue;
                        }
                    } else {
                        if let Some(f) = &mut on_progress {
                            f(PipelineEvent::Processing { 
                                name: step.name.clone(), 
                                message: format!("Warning: Source step {} not found to evaluate {}.", source_step, cond) 
                            });
                        }
                        continue;
                    }
                }
            }

            if let Some(f) = &mut on_progress {
                f(PipelineEvent::StepStarted { 
                    name: step.name.clone(), 
                    description: step.description.clone() 
                });
            }

            let prompt_path = self.skills_path.join(skill_name).join(&step.agent_prompt);
            let system_prompt = std::fs::read_to_string(&prompt_path)?;

            let mut agent_builder = AgentBuilder::new()
                .engine(self.engine.clone())
                .scheduler(self.scheduler.clone())
                .skills_path(self.skills_path.clone())
                .activate_skill(skill_name)
                .skip_builtin_tools()
                .no_agents_md()
                .system_prompt(&system_prompt)
                .temperature(step.temperature);

            if let Some(rp) = step.repetition_penalty {
                agent_builder = agent_builder.repeat_penalty(rp);
            }
            if let Some(stops) = &step.stop_sequences {
                for stop in stops {
                    agent_builder = agent_builder.stop_sequence(stop);
                }
            }

            let mut agent = agent_builder.build()
                .map_err(|e| anyhow!("Failed to build agent '{}': {}", step.name, e))?;

            let input = self.resolve_input(step, &context, &results)?;

            let mut response = String::new();

            {
                let step_name = step.name.clone();
                let on_progress_ref = &mut on_progress;
                
                agent.chat(&format!("Process Input: {}", input), |event| {
                    if let crate::AgentEvent::TextDelta(text) = event {
                        response.push_str(&text);
                        if let Some(f) = on_progress_ref {
                            f(PipelineEvent::Token { 
                                name: step_name.clone(), 
                                token: text 
                            });
                        }
                    }
                }).map_err(|e| anyhow!("Agent '{}' error: {}", step.name, e))?;
            }

            // Post-process if schema is available
            if let Some(schema) = &post_process_schema {
                if step.name == "writer" || step.name == "rewrite" {
                    response = self.apply_post_process(&response, schema);
                }
            }

            // Explicitly drop agent and its associated llama_context to free VRAM
            drop(agent);
            
            results.insert(step.name.clone(), response.clone());
            
            if let Some(f) = &mut on_progress {
                f(PipelineEvent::StepCompleted { name: step.name.clone() });
            }
            
            // Update context with the result if it's JSON
            if step.output_type == "json" {
                let val = self.extract_json(&response)?;
                context.as_object_mut().unwrap().insert(step.name.clone(), val);
            } else {
                context.as_object_mut().unwrap().insert(step.name.clone(), Value::String(response.clone()));
            }

            // Persist the result as an artifact
            if let Err(e) = self.storage.insert_artifact(
                session_id, 
                &step.name, 
                &response, 
                step.output_type == "json"
            ) {
                if let Some(f) = &mut on_progress {
                    f(PipelineEvent::Processing { 
                        name: step.name.clone(), 
                        message: format!("Warning: Failed to save artifact: {}", e) 
                    });
                }
            }
        }

        Ok(results)
    }

    fn load_workflow(&self, skill_name: &str) -> Result<Workflow> {
        let path = self.skills_path.join(skill_name).join("workflow.json");
        let content = std::fs::read_to_string(&path)
            .map_err(|e| anyhow!("Failed to load workflow.json for skill {}: {}", skill_name, e))?;
        serde_json::from_str(&content).map_err(|e| anyhow!("Failed to parse workflow.json: {}", e))
    }

    fn resolve_input(&self, step: &WorkflowStep, context: &Value, results: &HashMap<String, String>) -> Result<String> {
        if let Some(mapping) = &step.input_mapping {
            let mut input_obj = serde_json::Map::new();
            for (key, source_step) in mapping {
                if let Some(result) = results.get(source_step) {
                    if let Ok(val) = serde_json::from_str::<Value>(result) {
                        input_obj.insert(key.clone(), val);
                    } else if let Ok(val) = self.extract_json(result) {
                        input_obj.insert(key.clone(), val);
                    } else {
                        input_obj.insert(key.clone(), Value::String(result.clone()));
                    }
                } else if let Some(val) = context.get(source_step) {
                    input_obj.insert(key.clone(), val.clone());
                } else {
                    return Err(anyhow!("Input mapping error: Source '{}' not found for step '{}'", source_step, step.name));
                }
            }
            Ok(serde_json::to_string(&Value::Object(input_obj))?)
        } else {
            Ok(serde_json::to_string(context)?)
        }
    }

    fn extract_json(&self, input: &str) -> Result<Value> {
        let start = input.find('{').ok_or_else(|| anyhow!("No JSON object found in output"))?;
        let end = input.rfind('}').ok_or_else(|| anyhow!("No closing brace found in output"))?;
        let json_str = &input[start..=end];
        let sanitized = json_str.chars()
            .filter(|c| !c.is_control() || *c == '\n' || *c == '\r' || *c == '\t')
            .collect::<String>();
        serde_json::from_str(&sanitized).map_err(|e| anyhow!("JSON parse error: {} (Input: {})", e, sanitized))
    }

    fn load_post_process_schema(&self, skill_name: &str) -> Result<PostProcessSchema> {
        let path = self.skills_path.join(skill_name).join("schemas").join("post_process.json");
        let content = std::fs::read_to_string(&path)
            .map_err(|e| anyhow!("Failed to load post_process.json for skill {}: {}", skill_name, e))?;
        serde_json::from_str(&content).map_err(|e| anyhow!("Failed to parse post_process.json: {}", e))
    }

    fn apply_post_process(&self, input: &str, schema: &PostProcessSchema) -> String {
        let mut output = input.to_string();
        for rule in &schema.rules {
            if let Ok(re) = Regex::new(&rule.pattern) {
                output = re.replace_all(&output, &rule.replacement).to_string();
            }
        }
        output
    }
}