ai_agents_skills/
executor.rs1use std::sync::Arc;
2
3use ai_agents_core::{AgentError, Result};
4use ai_agents_llm::{ChatMessage, LLMRegistry};
5use ai_agents_tools::ToolRegistry;
6use minijinja::Environment;
7
8use crate::definition::{SkillContext, SkillDefinition, SkillStep};
9
10pub struct SkillExecutor {
11 llm_registry: Arc<LLMRegistry>,
12 tools: Arc<ToolRegistry>,
13}
14
15impl SkillExecutor {
16 pub fn new(llm_registry: Arc<LLMRegistry>, tools: Arc<ToolRegistry>) -> Self {
17 Self {
18 llm_registry,
19 tools,
20 }
21 }
22
23 pub async fn execute(
24 &self,
25 skill: &SkillDefinition,
26 user_input: &str,
27 extra_context: serde_json::Value,
28 ) -> Result<String> {
29 let mut ctx = SkillContext::new(user_input).with_extra(extra_context);
30
31 for (index, step) in skill.steps.iter().enumerate() {
32 match step {
33 SkillStep::Tool {
34 tool,
35 args,
36 output_as: _,
37 } => {
38 let rendered_args = self.render_args(args.clone(), &ctx)?;
39 let tool_impl = self
40 .tools
41 .get(tool)
42 .ok_or_else(|| AgentError::Skill(format!("Tool not found: {}", tool)))?;
43
44 let result = tool_impl.execute(rendered_args.clone()).await;
45 eprintln!("[Skill] Tool '{}' returned: {}", tool, result.output);
46
47 let result_value: serde_json::Value = serde_json::from_str(&result.output)
48 .unwrap_or_else(|_| {
49 serde_json::json!({
50 "output": result.output,
51 "success": result.success
52 })
53 });
54
55 ctx.add_result(index, Some(rendered_args), result_value);
56
57 if !result.success {
58 return Err(AgentError::Skill(format!(
59 "Tool '{}' failed: {}",
60 tool, result.output
61 )));
62 }
63 }
64 SkillStep::Prompt { prompt, llm } => {
65 let rendered_prompt = self.render_prompt(prompt, &ctx)?;
66
67 let llm_provider = match llm {
68 Some(alias) => self.llm_registry.get(alias)?,
69 None => self.llm_registry.default()?,
70 };
71
72 let response = llm_provider
73 .complete(&[ChatMessage::user(&rendered_prompt)], None)
74 .await
75 .map_err(|e| AgentError::LLM(e.to_string()))?;
76
77 let result_value =
79 serde_json::Value::String(response.content.trim().to_string());
80 ctx.add_result(index, None, result_value);
81
82 if index == skill.steps.len() - 1 {
84 return Ok(response.content);
85 }
86 }
87 }
88 }
89
90 Err(AgentError::Skill(
91 "Skill has no prompt step to generate response".to_string(),
92 ))
93 }
94
95 fn render_args(
96 &self,
97 args: Option<serde_json::Value>,
98 ctx: &SkillContext,
99 ) -> Result<serde_json::Value> {
100 match args {
101 Some(value) => self.render_value(&value, ctx),
102 None => Ok(serde_json::json!({})),
103 }
104 }
105
106 fn render_value(
107 &self,
108 value: &serde_json::Value,
109 ctx: &SkillContext,
110 ) -> Result<serde_json::Value> {
111 match value {
112 serde_json::Value::String(s) => {
113 let rendered = self.render_template_string(s, ctx)?;
114 Ok(serde_json::Value::String(rendered))
115 }
116 serde_json::Value::Object(map) => {
117 let mut new_map = serde_json::Map::new();
118 for (k, v) in map {
119 new_map.insert(k.clone(), self.render_value(v, ctx)?);
120 }
121 Ok(serde_json::Value::Object(new_map))
122 }
123 serde_json::Value::Array(arr) => {
124 let new_arr: Result<Vec<_>> =
125 arr.iter().map(|v| self.render_value(v, ctx)).collect();
126 Ok(serde_json::Value::Array(new_arr?))
127 }
128 other => Ok(other.clone()),
129 }
130 }
131
132 fn render_prompt(&self, template: &str, ctx: &SkillContext) -> Result<String> {
133 self.render_template_string(template, ctx)
134 }
135
136 fn render_template_string(&self, template: &str, ctx: &SkillContext) -> Result<String> {
137 let env = Environment::new();
138
139 let tmpl = env
140 .template_from_str(template)
141 .map_err(|e| AgentError::Skill(format!("Template parse error: {}", e)))?;
142
143 let steps: Vec<serde_json::Value> = ctx
144 .step_results
145 .iter()
146 .map(|step| {
147 serde_json::json!({
148 "result": step.result,
149 "args": step.args.as_ref().unwrap_or(&serde_json::json!({}))
150 })
151 })
152 .collect();
153
154 let jinja_ctx = minijinja::context! {
155 user_input => &ctx.user_input,
156 steps => steps,
157 context => &ctx.extra,
158 };
159
160 tmpl.render(jinja_ctx)
161 .map_err(|e| AgentError::Skill(format!("Template render error: {}", e)))
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Executor with empty registries; template rendering never touches them.
    fn create_test_executor() -> SkillExecutor {
        SkillExecutor::new(Arc::new(LLMRegistry::new()), Arc::new(ToolRegistry::new()))
    }

    /// Context with one recorded tool step (weather for Seoul) and an extra `user_name`.
    fn create_test_context() -> SkillContext {
        let mut ctx = SkillContext::new("What should I wear?");
        ctx.add_result(
            0,
            Some(serde_json::json!({"location": "Seoul"})),
            serde_json::json!({"temperature": 15, "condition": "sunny"}),
        );
        ctx.extra = serde_json::json!({"user_name": "jay"});
        ctx
    }

    #[test]
    fn test_render_complex_template() {
        let executor = create_test_executor();
        let ctx = create_test_context();
        let template = r#"User {{ context.user_name }} asked: {{ user_input }}
Current weather in {{ steps[0].args.location }}: {{ steps[0].result.temperature }}°C, {{ steps[0].result.condition }}"#;

        let result = executor.render_template_string(template, &ctx).unwrap();
        assert!(result.contains("User jay asked: What should I wear?"));
        assert!(result.contains("Current weather in Seoul: 15°C, sunny"));
    }

    #[test]
    fn test_render_with_whitespace_variations() {
        let executor = create_test_executor();
        let ctx = create_test_context();

        // Each template differs only in the whitespace inside the `{{ }}` delimiters.
        for template in [
            "{{user_input}}",
            "{{ user_input }}",
            "{{   user_input   }}",
            "{{\n    user_input\n}}",
        ] {
            let result = executor.render_template_string(template, &ctx).unwrap();
            assert_eq!(result, "What should I wear?", "template: {:?}", template);
        }
    }

    #[test]
    fn test_render_with_filters() {
        let executor = create_test_executor();
        let ctx = create_test_context();
        let template = "{{ context.user_name | upper }}";

        let result = executor.render_template_string(template, &ctx).unwrap();
        assert_eq!(result, "JAY");
    }

    #[test]
    fn test_render_reports_parse_error() {
        let executor = create_test_executor();
        let ctx = create_test_context();

        // Unterminated variable block must surface as a skill error, not a panic.
        let result = executor.render_template_string("{{ user_input", &ctx);
        assert!(matches!(result, Err(AgentError::Skill(_))));
    }
}