1use crate::llm::LLMProvider;
11use crate::llm::client::ChatSession;
12use crate::llm::context::AgentContextBuilder;
13use crate::llm::tool_executor::ToolExecutor;
14use crate::llm::types::{
15 ChatCompletionRequest, ChatMessage, ContentPart, ImageUrl, LLMResult, MessageContent, Role,
16 Tool,
17};
18use anyhow::Result;
19use std::collections::HashMap;
20use std::path::Path;
21use std::sync::Arc;
22
/// Runtime tunables for an [`AgentLoop`] run.
#[derive(Debug, Clone)]
pub struct AgentLoopConfig {
    /// Upper bound on model ↔ tool round-trips before the loop gives up.
    pub max_tool_iterations: usize,
    /// Model name used when the caller does not supply one explicitly.
    pub default_model: String,
    /// Sampling temperature forwarded to the provider (`None` = provider default).
    pub temperature: Option<f32>,
    /// Hard cap on completion tokens (`None` = provider default).
    pub max_tokens: Option<u32>,
}
35
36impl Default for AgentLoopConfig {
37 fn default() -> Self {
38 Self {
39 max_tool_iterations: 10,
40 default_model: "gpt-4o-mini".to_string(),
41 temperature: Some(0.7),
42 max_tokens: None,
43 }
44 }
45}
46
/// Drives the LLM ↔ tool-execution loop: sends the conversation to the
/// provider, executes any tool calls the model requests, feeds results back,
/// and repeats until the model returns plain text (or the iteration cap hits).
pub struct AgentLoop {
    /// Chat-completion backend.
    provider: Arc<dyn LLMProvider>,
    /// Executor that advertises and runs the tools exposed to the model.
    tools: Arc<dyn ToolExecutor>,
    /// Loop tunables (default model, temperature, iteration cap, …).
    config: AgentLoopConfig,
}
56
/// Builder-style convenience wrapper around [`AgentLoop`] carrying optional
/// per-run state: a context builder, a persistent chat session, and a model
/// override.
pub struct AgentLoopRunner {
    agent_loop: AgentLoop,
    /// Optional builder used to assemble the message list for each run.
    context_builder: Option<AgentContextBuilder>,
    /// Optional persistent session updated after each run.
    session: Option<ChatSession>,
    /// Optional model override; `None` falls back to the loop's default model.
    model: Option<String>,
}
67
68impl AgentLoop {
69 pub fn new(
71 provider: Arc<dyn LLMProvider>,
72 tools: Arc<dyn ToolExecutor>,
73 config: AgentLoopConfig,
74 ) -> Self {
75 Self {
76 provider,
77 tools,
78 config,
79 }
80 }
81
82 pub fn with_defaults(provider: Arc<dyn LLMProvider>, tools: Arc<dyn ToolExecutor>) -> Self {
84 Self::new(provider, tools, AgentLoopConfig::default())
85 }
86
87 pub async fn process_message(
89 &self,
90 context: Vec<ChatMessage>,
91 content: &str,
92 media: Option<Vec<String>>,
93 ) -> Result<String> {
94 self.process_with_options(context, content, media, None)
95 .await
96 }
97
98 pub async fn process_with_model(
100 &self,
101 context: Vec<ChatMessage>,
102 content: &str,
103 media: Option<Vec<String>>,
104 model: &str,
105 ) -> Result<String> {
106 self.process_with_options(context, content, media, Some(model))
107 .await
108 }
109
110 pub async fn process_with_options(
112 &self,
113 mut context: Vec<ChatMessage>,
114 content: &str,
115 media: Option<Vec<String>>,
116 model: Option<&str>,
117 ) -> Result<String> {
118 let user_msg = if let Some(media_paths) = media {
120 if !media_paths.is_empty() {
121 Self::build_vision_message(content, &media_paths)?
122 } else {
123 ChatMessage::user(content)
124 }
125 } else {
126 ChatMessage::user(content)
127 };
128
129 context.push(user_msg);
130
131 let tools = self
133 .tools
134 .available_tools()
135 .await
136 .map_err(|e| anyhow::anyhow!(e))?;
137
138 self.run_agent_loop(context, &tools, model).await
140 }
141
142 pub async fn process_with_context_builder(
144 &self,
145 context_builder: &AgentContextBuilder,
146 history: Vec<ChatMessage>,
147 content: &str,
148 media: Option<Vec<String>>,
149 model: Option<&str>,
150 ) -> Result<String> {
151 let context = context_builder
152 .build_messages(history, content, media)
153 .await?;
154
155 let tools = self
156 .tools
157 .available_tools()
158 .await
159 .map_err(|e| anyhow::anyhow!(e))?;
160
161 self.run_agent_loop(context, &tools, model).await
162 }
163
164 pub async fn process_with_session(
170 &self,
171 session: &mut ChatSession,
172 content: &str,
173 media: Option<Vec<String>>,
174 context_builder: Option<&AgentContextBuilder>,
175 model: Option<&str>,
176 ) -> Result<String> {
177 let history = session.messages().to_vec();
178
179 let context = if let Some(builder) = context_builder {
180 builder
181 .build_messages(history, content, media.clone())
182 .await?
183 } else {
184 let mut messages = history;
185 let user_msg = if let Some(media_paths) = media.clone() {
186 if !media_paths.is_empty() {
187 Self::build_vision_message(content, &media_paths)?
188 } else {
189 ChatMessage::user(content)
190 }
191 } else {
192 ChatMessage::user(content)
193 };
194 messages.push(user_msg);
195 messages
196 };
197
198 let tools = self
199 .tools
200 .available_tools()
201 .await
202 .map_err(|e| anyhow::anyhow!(e))?;
203
204 let response = self.run_agent_loop(context, &tools, model).await?;
205
206 let user_msg = if let Some(media_paths) = media {
207 if !media_paths.is_empty() {
208 Self::build_vision_message(content, &media_paths)?
209 } else {
210 ChatMessage::user(content)
211 }
212 } else {
213 ChatMessage::user(content)
214 };
215 session.messages_mut().push(user_msg);
216 session
217 .messages_mut()
218 .push(ChatMessage::assistant(&response));
219
220 Ok(response)
221 }
222
223 async fn run_agent_loop(
225 &self,
226 mut messages: Vec<ChatMessage>,
227 tools: &[Tool],
228 model: Option<&str>,
229 ) -> Result<String> {
230 let model = model.unwrap_or(&self.config.default_model);
231
232 for _iteration in 0..self.config.max_tool_iterations {
233 let mut request = ChatCompletionRequest::new(model);
235 request.messages = messages.clone();
236 request.temperature = self.config.temperature;
237 request.max_tokens = self.config.max_tokens;
238
239 if !tools.is_empty() {
240 request.tools = Some(tools.to_vec());
241 }
242
243 let response = self.provider.chat(request).await?;
245
246 if let Some(tool_calls) = response.tool_calls()
248 && !tool_calls.is_empty()
249 {
250 messages.push(ChatMessage::assistant_with_tool_calls(tool_calls.clone()));
252
253 for tool_call in tool_calls {
255 tracing::debug!(
256 "Executing tool: {} with args: {:?}",
257 tool_call.function.name,
258 tool_call.function.arguments
259 );
260
261 let result = self
262 .execute_tool(&tool_call.function.name, &tool_call.function.arguments)
263 .await;
264
265 messages.push(ChatMessage::tool_result(
266 &tool_call.id,
267 result.unwrap_or_else(|e| format!("Error: {}", e)),
268 ));
269 }
270
271 continue;
272 }
273
274 if let Some(content) = response.content() {
276 return Ok(content.to_string());
277 } else {
278 return Ok("No response generated.".to_string());
279 }
280 }
281
282 tracing::warn!(
284 "Agent loop exceeded max iterations ({})",
285 self.config.max_tool_iterations
286 );
287 Ok("I've completed processing but hit the maximum iteration limit.".to_string())
288 }
289
290 async fn execute_tool(&self, name: &str, arguments: &str) -> LLMResult<String> {
292 self.tools.execute(name, arguments).await
293 }
294
295 fn build_vision_message(text: &str, image_paths: &[String]) -> Result<ChatMessage> {
297 let mut parts = vec![ContentPart::Text {
298 text: text.to_string(),
299 }];
300
301 for path in image_paths {
302 let image_url = Self::encode_image_data_url(Path::new(path))?;
303 parts.push(ContentPart::Image { image_url });
304 }
305
306 Ok(ChatMessage {
307 role: Role::User,
308 content: Some(MessageContent::Parts(parts)),
309 name: None,
310 tool_calls: None,
311 tool_call_id: None,
312 })
313 }
314
315 fn encode_image_data_url(path: &Path) -> Result<ImageUrl> {
317 use base64::Engine;
318 use base64::engine::general_purpose::STANDARD_NO_PAD;
319 use std::fs;
320
321 let bytes = fs::read(path)?;
322 let mime_type = infer::get_from_path(path)?
323 .ok_or_else(|| anyhow::anyhow!("Unknown MIME type for: {:?}", path))?
324 .mime_type()
325 .to_string();
326
327 let base64 = STANDARD_NO_PAD.encode(&bytes);
328 let url = format!("data:{};base64,{}", mime_type, base64);
329
330 Ok(ImageUrl { url, detail: None })
331 }
332
333 pub fn config(&self) -> &AgentLoopConfig {
335 &self.config
336 }
337}
338
339impl AgentLoopRunner {
340 pub fn new(agent_loop: AgentLoop) -> Self {
342 Self {
343 agent_loop,
344 context_builder: None,
345 session: None,
346 model: None,
347 }
348 }
349
350 pub fn with_context_builder(mut self, context_builder: AgentContextBuilder) -> Self {
352 self.context_builder = Some(context_builder);
353 self
354 }
355
356 pub fn with_session(mut self, session: ChatSession) -> Self {
358 self.session = Some(session);
359 self
360 }
361
362 pub fn with_model(mut self, model: impl Into<String>) -> Self {
364 self.model = Some(model.into());
365 self
366 }
367
368 pub fn session_mut(&mut self) -> Option<&mut ChatSession> {
370 self.session.as_mut()
371 }
372
373 pub async fn run(&mut self, content: &str, media: Option<Vec<String>>) -> Result<String> {
375 if let Some(session) = self.session.as_mut() {
376 let builder = self.context_builder.as_ref();
377 return self
378 .agent_loop
379 .process_with_session(session, content, media, builder, self.model.as_deref())
380 .await;
381 }
382
383 if let Some(builder) = self.context_builder.as_ref() {
384 return self
385 .agent_loop
386 .process_with_context_builder(
387 builder,
388 Vec::new(),
389 content,
390 media,
391 self.model.as_deref(),
392 )
393 .await;
394 }
395
396 self.agent_loop
397 .process_with_options(Vec::new(), content, media, self.model.as_deref())
398 .await
399 }
400}
401
/// Minimal in-process [`ToolExecutor`] backed by a name → closure map.
pub struct SimpleToolExecutor {
    // Each handler receives the raw argument string (JSON from the model)
    // and returns the tool's textual output.
    tools: HashMap<String, Box<dyn Fn(&str) -> Result<String> + Send + Sync>>,
}
406
407impl SimpleToolExecutor {
408 pub fn new() -> Self {
409 Self {
410 tools: HashMap::new(),
411 }
412 }
413
414 pub fn register<F>(&mut self, name: impl Into<String>, handler: F) -> &mut Self
415 where
416 F: Fn(&str) -> Result<String> + Send + Sync + 'static,
417 {
418 self.tools.insert(name.into(), Box::new(handler));
419 self
420 }
421}
422
423impl Default for SimpleToolExecutor {
424 fn default() -> Self {
425 Self::new()
426 }
427}
428
429#[async_trait::async_trait]
430impl ToolExecutor for SimpleToolExecutor {
431 async fn execute(&self, name: &str, arguments: &str) -> LLMResult<String> {
432 if let Some(handler) = self.tools.get(name) {
433 handler(arguments).map_err(|e| crate::llm::types::LLMError::Other(e.to_string()))
434 } else {
435 Err(crate::llm::types::LLMError::Other(format!(
436 "Unknown tool: {}",
437 name
438 )))
439 }
440 }
441
442 async fn available_tools(&self) -> LLMResult<Vec<Tool>> {
443 Ok(Vec::new())
445 }
446}
447
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config must match the documented defaults for every field.
    #[test]
    fn test_agent_loop_config_default() {
        let config = AgentLoopConfig::default();
        assert_eq!(config.max_tool_iterations, 10);
        assert_eq!(config.default_model, "gpt-4o-mini");
        assert_eq!(config.temperature, Some(0.7));
        assert_eq!(config.max_tokens, None);
    }

    /// A hand-built config must retain every field it was constructed with
    /// (previously `temperature`/`max_tokens` were set but never checked).
    #[test]
    fn test_agent_loop_config_custom() {
        let config = AgentLoopConfig {
            max_tool_iterations: 5,
            default_model: "gpt-4".to_string(),
            temperature: Some(0.5),
            max_tokens: Some(1000),
        };
        assert_eq!(config.max_tool_iterations, 5);
        assert_eq!(config.default_model, "gpt-4");
        assert_eq!(config.temperature, Some(0.5));
        assert_eq!(config.max_tokens, Some(1000));
    }
}
470}