1use crate::errors::{GentError, GentResult};
4use crate::interpreter::{AgentValue, OutputSchema};
5use crate::logging::{LogLevel, Logger, NullLogger};
6use crate::runtime::validation::validate_output;
7use crate::runtime::{LLMClient, LLMResponse, Message, ToolDefinition, ToolRegistry, ToolResult};
8
/// Default cap on agent loop iterations, used when `AgentValue::max_steps` is unset.
const DEFAULT_MAX_STEPS: u32 = 10;
10
11pub async fn run_agent(
13 agent: &AgentValue,
14 input: Option<String>,
15 llm: &dyn LLMClient,
16) -> GentResult<String> {
17 let registry = ToolRegistry::new();
18 let logger = NullLogger;
19 run_agent_with_tools(agent, input, llm, ®istry, &logger).await
20}
21
/// Runs an agent's tool-use loop until it produces a final (tool-free) answer.
///
/// The initial conversation is built from the agent's system prompt — augmented
/// with JSON-schema instructions when `output_schema` is set — followed by the
/// agent's own `user_prompt` or, failing that, the caller-supplied `input`.
/// Each step sends the conversation to `llm`; any tool calls in the response
/// are executed via `tools` and their results appended, while a response with
/// no tool calls ends the loop (after schema validation when an output schema
/// is present). If both prompts and `input` are absent, returns an empty
/// string without contacting the LLM.
///
/// # Errors
///
/// Returns `GentError::MaxStepsExceeded` when the LLM is still requesting
/// tools after `max_steps` iterations, and propagates LLM and
/// output-validation errors.
pub async fn run_agent_with_tools(
    agent: &AgentValue,
    input: Option<String>,
    llm: &dyn LLMClient,
    tools: &ToolRegistry,
    logger: &dyn Logger,
) -> GentResult<String> {
    let max_steps = agent.max_steps.unwrap_or(DEFAULT_MAX_STEPS);
    // Only the tools the agent declares are exposed to the LLM.
    let tool_defs = tools.definitions_for(&agent.tools);
    let model = agent.model.as_deref();
    // JSON mode is implied by the presence of an output schema.
    let json_mode = agent.output_schema.is_some();

    logger.log(
        LogLevel::Debug,
        "agent",
        &format!("Agent '{}' requested tools: {:?}", agent.name, agent.tools),
    );
    logger.log(
        LogLevel::Debug,
        "agent",
        &format!("Tool definitions provided to LLM: {}", tool_defs.len()),
    );
    for def in &tool_defs {
        logger.log(
            LogLevel::Trace,
            "agent",
            &format!(" - {} : {}", def.name, def.description),
        );
    }

    let mut messages = Vec::new();

    if !agent.system_prompt.is_empty() {
        // When a schema is set, append output instructions plus the
        // pretty-printed JSON schema to the system prompt.
        let system_prompt = if let Some(schema) = &agent.output_schema {
            logger.log(
                LogLevel::Debug,
                "agent",
                "Agent has output schema, enabling JSON mode",
            );
            let default_instructions = "You must respond with JSON matching this schema:";
            let instructions = agent
                .output_instructions
                .as_deref()
                .unwrap_or(default_instructions);
            format!(
                "{}\n\n{}\n{}",
                agent.system_prompt,
                instructions,
                // Schema serialization failure degrades to a placeholder
                // rather than aborting the run.
                serde_json::to_string_pretty(&schema.to_json_schema())
                    .unwrap_or_else(|_| "<schema>".to_string())
            )
        } else {
            agent.system_prompt.clone()
        };
        messages.push(Message::system(&system_prompt));
    }

    // The agent's own user prompt takes precedence over caller input.
    if let Some(user_prompt) = &agent.user_prompt {
        messages.push(Message::user(user_prompt.clone()));
    } else if let Some(user_input) = input {
        messages.push(Message::user(user_input));
    }

    if messages.is_empty() {
        logger.log(
            LogLevel::Debug,
            "agent",
            "No prompts provided, returning empty result",
        );
        return Ok(String::new());
    }

    for step in 0..max_steps {
        logger.log(
            LogLevel::Debug,
            "agent",
            &format!("Step {}/{}", step + 1, max_steps),
        );
        let response = llm
            .chat(messages.clone(), tool_defs.clone(), model, json_mode)
            .await?;

        // A tool-free response is the agent's final answer.
        if response.tool_calls.is_empty() {
            logger.log(
                LogLevel::Debug,
                "agent",
                "No tool calls, returning response",
            );
            let content = response.content.unwrap_or_default();

            if let Some(schema) = &agent.output_schema {
                return validate_and_retry_output(
                    &content, schema, agent, &messages, llm, &tool_defs, model, logger,
                )
                .await;
            }

            return Ok(content);
        }

        logger.log(
            LogLevel::Debug,
            "agent",
            &format!("LLM made {} tool call(s)", response.tool_calls.len()),
        );
        for call in &response.tool_calls {
            logger.log(
                LogLevel::Trace,
                "agent",
                &format!(" - {}({})", call.name, call.arguments),
            );
        }

        // The assistant's tool-call message must precede the tool results
        // in the conversation.
        messages.push(Message::assistant_with_tool_calls(
            response.tool_calls.clone(),
        ));

        // Execute every requested call; failures are fed back to the LLM as
        // error-flagged results rather than aborting the loop.
        for call in &response.tool_calls {
            let result = match tools.get(&call.name) {
                Some(tool) => match tool.execute(call.arguments.clone()).await {
                    Ok(output) => {
                        logger.log(
                            LogLevel::Debug,
                            "agent",
                            &format!("Tool '{}' returned: {}", call.name, output),
                        );
                        ToolResult {
                            call_id: call.id.clone(),
                            content: output,
                            is_error: false,
                        }
                    }
                    Err(error) => {
                        logger.log(
                            LogLevel::Warn,
                            "agent",
                            &format!("Tool '{}' error: {}", call.name, error),
                        );
                        ToolResult {
                            call_id: call.id.clone(),
                            content: error,
                            is_error: true,
                        }
                    }
                },
                None => {
                    logger.log(
                        LogLevel::Warn,
                        "agent",
                        &format!("Unknown tool: {}", call.name),
                    );
                    ToolResult {
                        call_id: call.id.clone(),
                        content: format!("Unknown tool: {}", call.name),
                        is_error: true,
                    }
                }
            };

            messages.push(Message::tool_result(result));
        }
    }

    Err(GentError::MaxStepsExceeded { limit: max_steps })
}
196
197pub async fn run_agent_full(
199 agent: &AgentValue,
200 input: Option<String>,
201 llm: &dyn LLMClient,
202) -> GentResult<LLMResponse> {
203 let mut messages = Vec::new();
205
206 if !agent.system_prompt.is_empty() {
208 messages.push(Message::system(&agent.system_prompt));
209 }
210
211 if let Some(user_prompt) = &agent.user_prompt {
213 messages.push(Message::user(user_prompt.clone()));
214 } else if let Some(user_input) = input {
215 messages.push(Message::user(user_input));
216 }
217
218 if messages.is_empty() {
220 return Ok(LLMResponse {
221 content: Some(String::new()),
222 tool_calls: vec![],
223 });
224 }
225
226 let model = agent.model.as_deref();
227 llm.chat(messages, vec![], model, false).await
228}
229
/// Validates `content` against `schema`, re-prompting the LLM on failure.
///
/// Makes up to `agent.output_retries` additional chat calls. Each failed
/// attempt — unparseable JSON or a schema mismatch — appends the bad answer
/// plus a corrective user message (the agent's `retry_prompt` if set) to a
/// copy of the conversation and asks again in JSON mode. On the final
/// attempt a failure becomes `GentError::OutputValidationError` carrying the
/// schema and the offending output.
#[allow(clippy::too_many_arguments)]
async fn validate_and_retry_output(
    content: &str,
    schema: &OutputSchema,
    agent: &AgentValue,
    messages: &[Message],
    llm: &dyn LLMClient,
    tools: &[ToolDefinition],
    model: Option<&str>,
    logger: &dyn Logger,
) -> GentResult<String> {
    let mut last_content = content.to_string();
    // Work on a copy so the caller's conversation is not polluted with
    // retry traffic.
    let mut retry_messages = messages.to_vec();

    // Inclusive range: one initial validation pass plus `output_retries`
    // retry passes.
    for retry in 0..=agent.output_retries {
        // First gate: the content must parse as JSON at all.
        let json: serde_json::Value = match serde_json::from_str(&last_content) {
            Ok(j) => j,
            Err(e) => {
                // Out of retries: surface the parse failure as a
                // validation error.
                if retry >= agent.output_retries {
                    return Err(GentError::OutputValidationError {
                        message: format!("Invalid JSON: {}", e),
                        expected: serde_json::to_string(&schema.to_json_schema())
                            .unwrap_or_else(|_| "<schema>".to_string()),
                        got: last_content,
                    });
                }
                logger.log(
                    LogLevel::Debug,
                    "agent",
                    &format!("Retry {}: invalid JSON", retry + 1),
                );
                let default_retry = "Please respond with valid JSON.";
                let retry_msg = agent.retry_prompt.as_deref().unwrap_or(default_retry);
                retry_messages.push(Message::assistant(&last_content));
                retry_messages.push(Message::user(retry_msg));
                let response = llm
                    .chat(retry_messages.clone(), tools.to_vec(), model, true)
                    .await?;
                last_content = response.content.unwrap_or_default();
                continue;
            }
        };

        // Second gate: the parsed JSON must match the output schema.
        match validate_output(&json, schema) {
            Ok(()) => {
                logger.log(LogLevel::Debug, "agent", "Output validation successful");
                return Ok(last_content);
            }
            Err(e) => {
                // Out of retries: surface the schema mismatch.
                if retry >= agent.output_retries {
                    return Err(GentError::OutputValidationError {
                        message: e,
                        expected: serde_json::to_string(&schema.to_json_schema())
                            .unwrap_or_else(|_| "<schema>".to_string()),
                        got: last_content,
                    });
                }
                logger.log(
                    LogLevel::Debug,
                    "agent",
                    &format!("Retry {}: {}", retry + 1, e),
                );
                let default_retry = format!(
                    "Invalid response: {}. Please respond with JSON matching the schema.",
                    e
                );
                let retry_msg = agent
                    .retry_prompt
                    .clone()
                    .unwrap_or(default_retry);
                retry_messages.push(Message::assistant(&last_content));
                retry_messages.push(Message::user(retry_msg));
                let response = llm
                    .chat(retry_messages.clone(), tools.to_vec(), model, true)
                    .await?;
                last_content = response.content.unwrap_or_default();
            }
        }
    }

    // Fallback if the loop completes without returning; every iteration
    // appears to return or retry, so this is presumably unreachable — kept
    // to satisfy the compiler.
    Ok(last_content)
}