gent/runtime/
agent.rs

1//! Agent execution for GENT
2
3use crate::errors::{GentError, GentResult};
4use crate::interpreter::{AgentValue, OutputSchema};
5use crate::logging::{LogLevel, Logger, NullLogger};
6use crate::runtime::validation::validate_output;
7use crate::runtime::{LLMClient, LLMResponse, Message, ToolDefinition, ToolRegistry, ToolResult};
8
/// Step budget used by `run_agent_with_tools` when the agent does not set
/// `max_steps`: the maximum number of LLM round-trips in the tool-call loop.
const DEFAULT_MAX_STEPS: u32 = 10;
10
11/// Run an agent with the given input (simple, no tools)
12pub async fn run_agent(
13    agent: &AgentValue,
14    input: Option<String>,
15    llm: &dyn LLMClient,
16) -> GentResult<String> {
17    let registry = ToolRegistry::new();
18    let logger = NullLogger;
19    run_agent_with_tools(agent, input, llm, &registry, &logger).await
20}
21
22/// Run an agent with tools
23pub async fn run_agent_with_tools(
24    agent: &AgentValue,
25    input: Option<String>,
26    llm: &dyn LLMClient,
27    tools: &ToolRegistry,
28    logger: &dyn Logger,
29) -> GentResult<String> {
30    let max_steps = agent.max_steps.unwrap_or(DEFAULT_MAX_STEPS);
31    let tool_defs = tools.definitions_for(&agent.tools);
32    let model = agent.model.as_deref();
33    let json_mode = agent.output_schema.is_some();
34
35    logger.log(
36        LogLevel::Debug,
37        "agent",
38        &format!("Agent '{}' requested tools: {:?}", agent.name, agent.tools),
39    );
40    logger.log(
41        LogLevel::Debug,
42        "agent",
43        &format!("Tool definitions provided to LLM: {}", tool_defs.len()),
44    );
45    for def in &tool_defs {
46        logger.log(
47            LogLevel::Trace,
48            "agent",
49            &format!("  - {} : {}", def.name, def.description),
50        );
51    }
52
53    // Build messages based on which prompts are present
54    let mut messages = Vec::new();
55
56    // Add system message if prompt is not empty
57    if !agent.system_prompt.is_empty() {
58        let system_prompt = if let Some(schema) = &agent.output_schema {
59            logger.log(
60                LogLevel::Debug,
61                "agent",
62                "Agent has output schema, enabling JSON mode",
63            );
64            let default_instructions = "You must respond with JSON matching this schema:";
65            let instructions = agent
66                .output_instructions
67                .as_deref()
68                .unwrap_or(default_instructions);
69            format!(
70                "{}\n\n{}\n{}",
71                agent.system_prompt,
72                instructions,
73                serde_json::to_string_pretty(&schema.to_json_schema())
74                    .unwrap_or_else(|_| "<schema>".to_string())
75            )
76        } else {
77            agent.system_prompt.clone()
78        };
79        messages.push(Message::system(&system_prompt));
80    }
81
82    // Add user message from agent's user_prompt or from input parameter
83    if let Some(user_prompt) = &agent.user_prompt {
84        messages.push(Message::user(user_prompt.clone()));
85    } else if let Some(user_input) = input {
86        messages.push(Message::user(user_input));
87    }
88
89    // If no messages at all, return empty result
90    if messages.is_empty() {
91        logger.log(
92            LogLevel::Debug,
93            "agent",
94            "No prompts provided, returning empty result",
95        );
96        return Ok(String::new());
97    }
98
99    for step in 0..max_steps {
100        logger.log(
101            LogLevel::Debug,
102            "agent",
103            &format!("Step {}/{}", step + 1, max_steps),
104        );
105        let response = llm
106            .chat(messages.clone(), tool_defs.clone(), model, json_mode)
107            .await?;
108
109        // If no tool calls, validate and return the response content
110        if response.tool_calls.is_empty() {
111            logger.log(
112                LogLevel::Debug,
113                "agent",
114                "No tool calls, returning response",
115            );
116            let content = response.content.unwrap_or_default();
117
118            // Validate output if schema exists
119            if let Some(schema) = &agent.output_schema {
120                return validate_and_retry_output(
121                    &content, schema, agent, &messages, llm, &tool_defs, model, logger,
122                )
123                .await;
124            }
125
126            return Ok(content);
127        }
128
129        logger.log(
130            LogLevel::Debug,
131            "agent",
132            &format!("LLM made {} tool call(s)", response.tool_calls.len()),
133        );
134        for call in &response.tool_calls {
135            logger.log(
136                LogLevel::Trace,
137                "agent",
138                &format!("  - {}({})", call.name, call.arguments),
139            );
140        }
141
142        // Add assistant message with tool calls
143        messages.push(Message::assistant_with_tool_calls(
144            response.tool_calls.clone(),
145        ));
146
147        // Execute each tool call
148        for call in &response.tool_calls {
149            let result = match tools.get(&call.name) {
150                Some(tool) => match tool.execute(call.arguments.clone()).await {
151                    Ok(output) => {
152                        logger.log(
153                            LogLevel::Debug,
154                            "agent",
155                            &format!("Tool '{}' returned: {}", call.name, output),
156                        );
157                        ToolResult {
158                            call_id: call.id.clone(),
159                            content: output,
160                            is_error: false,
161                        }
162                    }
163                    Err(error) => {
164                        logger.log(
165                            LogLevel::Warn,
166                            "agent",
167                            &format!("Tool '{}' error: {}", call.name, error),
168                        );
169                        ToolResult {
170                            call_id: call.id.clone(),
171                            content: error,
172                            is_error: true,
173                        }
174                    }
175                },
176                None => {
177                    logger.log(
178                        LogLevel::Warn,
179                        "agent",
180                        &format!("Unknown tool: {}", call.name),
181                    );
182                    ToolResult {
183                        call_id: call.id.clone(),
184                        content: format!("Unknown tool: {}", call.name),
185                        is_error: true,
186                    }
187                }
188            };
189
190            messages.push(Message::tool_result(result));
191        }
192    }
193
194    Err(GentError::MaxStepsExceeded { limit: max_steps })
195}
196
197/// Run an agent and return the full LLM response
198pub async fn run_agent_full(
199    agent: &AgentValue,
200    input: Option<String>,
201    llm: &dyn LLMClient,
202) -> GentResult<LLMResponse> {
203    // Build messages based on which prompts are present
204    let mut messages = Vec::new();
205
206    // Add system message if prompt is not empty
207    if !agent.system_prompt.is_empty() {
208        messages.push(Message::system(&agent.system_prompt));
209    }
210
211    // Add user message from agent's user_prompt or from input parameter
212    if let Some(user_prompt) = &agent.user_prompt {
213        messages.push(Message::user(user_prompt.clone()));
214    } else if let Some(user_input) = input {
215        messages.push(Message::user(user_input));
216    }
217
218    // If no messages at all, return empty response
219    if messages.is_empty() {
220        return Ok(LLMResponse {
221            content: Some(String::new()),
222            tool_calls: vec![],
223        });
224    }
225
226    let model = agent.model.as_deref();
227    llm.chat(messages, vec![], model, false).await
228}
229
230/// Validate output and retry on failure
231#[allow(clippy::too_many_arguments)]
232async fn validate_and_retry_output(
233    content: &str,
234    schema: &OutputSchema,
235    agent: &AgentValue,
236    messages: &[Message],
237    llm: &dyn LLMClient,
238    tools: &[ToolDefinition],
239    model: Option<&str>,
240    logger: &dyn Logger,
241) -> GentResult<String> {
242    let mut last_content = content.to_string();
243    let mut retry_messages = messages.to_vec();
244
245    for retry in 0..=agent.output_retries {
246        // Try to parse as JSON
247        let json: serde_json::Value = match serde_json::from_str(&last_content) {
248            Ok(j) => j,
249            Err(e) => {
250                if retry >= agent.output_retries {
251                    return Err(GentError::OutputValidationError {
252                        message: format!("Invalid JSON: {}", e),
253                        expected: serde_json::to_string(&schema.to_json_schema())
254                            .unwrap_or_else(|_| "<schema>".to_string()),
255                        got: last_content,
256                    });
257                }
258                logger.log(
259                    LogLevel::Debug,
260                    "agent",
261                    &format!("Retry {}: invalid JSON", retry + 1),
262                );
263                let default_retry = "Please respond with valid JSON.";
264                let retry_msg = agent.retry_prompt.as_deref().unwrap_or(default_retry);
265                retry_messages.push(Message::assistant(&last_content));
266                retry_messages.push(Message::user(retry_msg));
267                let response = llm
268                    .chat(retry_messages.clone(), tools.to_vec(), model, true)
269                    .await?;
270                last_content = response.content.unwrap_or_default();
271                continue;
272            }
273        };
274
275        // Validate against schema
276        match validate_output(&json, schema) {
277            Ok(()) => {
278                logger.log(LogLevel::Debug, "agent", "Output validation successful");
279                return Ok(last_content);
280            }
281            Err(e) => {
282                if retry >= agent.output_retries {
283                    return Err(GentError::OutputValidationError {
284                        message: e,
285                        expected: serde_json::to_string(&schema.to_json_schema())
286                            .unwrap_or_else(|_| "<schema>".to_string()),
287                        got: last_content,
288                    });
289                }
290                logger.log(
291                    LogLevel::Debug,
292                    "agent",
293                    &format!("Retry {}: {}", retry + 1, e),
294                );
295                let default_retry = format!(
296                    "Invalid response: {}. Please respond with JSON matching the schema.",
297                    e
298                );
299                let retry_msg = agent
300                    .retry_prompt
301                    .clone()
302                    .unwrap_or(default_retry);
303                retry_messages.push(Message::assistant(&last_content));
304                retry_messages.push(Message::user(retry_msg));
305                let response = llm
306                    .chat(retry_messages.clone(), tools.to_vec(), model, true)
307                    .await?;
308                last_content = response.content.unwrap_or_default();
309            }
310        }
311    }
312
313    Ok(last_content)
314}