Skip to main content

ai_agents_skills/
executor.rs

1use std::sync::Arc;
2
3use ai_agents_core::{AgentError, Result};
4use ai_agents_llm::{ChatMessage, LLMRegistry};
5use ai_agents_tools::ToolRegistry;
6use minijinja::Environment;
7
8use crate::definition::{SkillContext, SkillDefinition, SkillStep};
9
10pub struct SkillExecutor {
11    llm_registry: Arc<LLMRegistry>,
12    tools: Arc<ToolRegistry>,
13}
14
15impl SkillExecutor {
16    pub fn new(llm_registry: Arc<LLMRegistry>, tools: Arc<ToolRegistry>) -> Self {
17        Self {
18            llm_registry,
19            tools,
20        }
21    }
22
23    pub async fn execute(
24        &self,
25        skill: &SkillDefinition,
26        user_input: &str,
27        extra_context: serde_json::Value,
28    ) -> Result<String> {
29        let mut ctx = SkillContext::new(user_input).with_extra(extra_context);
30
31        for (index, step) in skill.steps.iter().enumerate() {
32            match step {
33                SkillStep::Tool {
34                    tool,
35                    args,
36                    output_as: _,
37                } => {
38                    let rendered_args = self.render_args(args.clone(), &ctx)?;
39                    let tool_impl = self
40                        .tools
41                        .get(tool)
42                        .ok_or_else(|| AgentError::Skill(format!("Tool not found: {}", tool)))?;
43
44                    let result = tool_impl.execute(rendered_args.clone()).await;
45                    eprintln!("[Skill] Tool '{}' returned: {}", tool, result.output);
46
47                    let result_value: serde_json::Value = serde_json::from_str(&result.output)
48                        .unwrap_or_else(|_| {
49                            serde_json::json!({
50                                "output": result.output,
51                                "success": result.success
52                            })
53                        });
54
55                    ctx.add_result(index, Some(rendered_args), result_value);
56
57                    if !result.success {
58                        return Err(AgentError::Skill(format!(
59                            "Tool '{}' failed: {}",
60                            tool, result.output
61                        )));
62                    }
63                }
64                SkillStep::Prompt { prompt, llm } => {
65                    let rendered_prompt = self.render_prompt(prompt, &ctx)?;
66
67                    let llm_provider = match llm {
68                        Some(alias) => self.llm_registry.get(alias)?,
69                        None => self.llm_registry.default()?,
70                    };
71
72                    let response = llm_provider
73                        .complete(&[ChatMessage::user(&rendered_prompt)], None)
74                        .await
75                        .map_err(|e| AgentError::LLM(e.to_string()))?;
76
77                    // Store prompt result directly as string for simpler template access
78                    let result_value =
79                        serde_json::Value::String(response.content.trim().to_string());
80                    ctx.add_result(index, None, result_value);
81
82                    // Only return on the last step
83                    if index == skill.steps.len() - 1 {
84                        return Ok(response.content);
85                    }
86                }
87            }
88        }
89
90        Err(AgentError::Skill(
91            "Skill has no prompt step to generate response".to_string(),
92        ))
93    }
94
95    fn render_args(
96        &self,
97        args: Option<serde_json::Value>,
98        ctx: &SkillContext,
99    ) -> Result<serde_json::Value> {
100        match args {
101            Some(value) => self.render_value(&value, ctx),
102            None => Ok(serde_json::json!({})),
103        }
104    }
105
106    fn render_value(
107        &self,
108        value: &serde_json::Value,
109        ctx: &SkillContext,
110    ) -> Result<serde_json::Value> {
111        match value {
112            serde_json::Value::String(s) => {
113                let rendered = self.render_template_string(s, ctx)?;
114                Ok(serde_json::Value::String(rendered))
115            }
116            serde_json::Value::Object(map) => {
117                let mut new_map = serde_json::Map::new();
118                for (k, v) in map {
119                    new_map.insert(k.clone(), self.render_value(v, ctx)?);
120                }
121                Ok(serde_json::Value::Object(new_map))
122            }
123            serde_json::Value::Array(arr) => {
124                let new_arr: Result<Vec<_>> =
125                    arr.iter().map(|v| self.render_value(v, ctx)).collect();
126                Ok(serde_json::Value::Array(new_arr?))
127            }
128            other => Ok(other.clone()),
129        }
130    }
131
132    fn render_prompt(&self, template: &str, ctx: &SkillContext) -> Result<String> {
133        self.render_template_string(template, ctx)
134    }
135
136    fn render_template_string(&self, template: &str, ctx: &SkillContext) -> Result<String> {
137        let env = Environment::new();
138
139        let tmpl = env
140            .template_from_str(template)
141            .map_err(|e| AgentError::Skill(format!("Template parse error: {}", e)))?;
142
143        let steps: Vec<serde_json::Value> = ctx
144            .step_results
145            .iter()
146            .map(|step| {
147                serde_json::json!({
148                    "result": step.result,
149                    "args": step.args.as_ref().unwrap_or(&serde_json::json!({}))
150                })
151            })
152            .collect();
153
154        let jinja_ctx = minijinja::context! {
155            user_input => &ctx.user_input,
156            steps => steps,
157            context => &ctx.extra,
158        };
159
160        tmpl.render(jinja_ctx)
161            .map_err(|e| AgentError::Skill(format!("Template render error: {}", e)))
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an executor over empty registries; the rendering tests below
    /// never look anything up in them.
    fn empty_executor() -> SkillExecutor {
        SkillExecutor::new(Arc::new(LLMRegistry::new()), Arc::new(ToolRegistry::new()))
    }

    /// Context with one recorded tool result (step 0) and a `user_name` extra.
    fn create_test_context() -> SkillContext {
        let step_args = serde_json::json!({"location": "Seoul"});
        let step_result = serde_json::json!({"temperature": 15, "condition": "sunny"});

        let mut ctx = SkillContext::new("What should I wear?");
        ctx.add_result(0, Some(step_args), step_result);
        ctx.extra = serde_json::json!({"user_name": "jay"});
        ctx
    }

    /// Templates can mix the extra context, the user input and step data.
    #[test]
    fn test_render_complex_template() {
        let executor = empty_executor();
        let ctx = create_test_context();

        let template = r#"User {{ context.user_name }} asked: {{ user_input }}
Current weather in {{ steps[0].args.location }}: {{ steps[0].result.temperature }}°C, {{ steps[0].result.condition }}"#;

        let rendered = executor.render_template_string(template, &ctx).unwrap();
        assert!(rendered.contains("User jay asked: What should I wear?"));
        assert!(rendered.contains("Current weather in Seoul: 15°C, sunny"));
    }

    /// Whitespace inside the expression delimiters does not affect the output.
    #[test]
    fn test_render_with_whitespace_variations() {
        let executor = empty_executor();
        let ctx = create_test_context();

        for template in ["{{user_input}}", "{{ user_input }}", "{{  user_input  }}"] {
            let rendered = executor.render_template_string(template, &ctx).unwrap();
            assert_eq!(rendered, "What should I wear?");
        }
    }

    /// Built-in filters are available in templates.
    #[test]
    fn test_render_with_filters() {
        let executor = empty_executor();
        let ctx = create_test_context();

        let rendered = executor
            .render_template_string("{{ context.user_name | upper }}", &ctx)
            .unwrap();
        assert_eq!(rendered, "JAY");
    }
}