1use anyhow::{anyhow, Result};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use url::Url;
9
10use crate::tools;
11
/// Upper bound on chat round-trips in a single agent run; once exceeded, the
/// agent loop emits a warning through the callback and stops.
const MAX_ITERATIONS: usize = 50;
/// Endpoint used when the caller-supplied endpoint string cannot be parsed as a URL.
const DEFAULT_OLLAMA_ENDPOINT: &str = "http://localhost:11434";
14
/// Computes the delay, in milliseconds, to wait before retry number `attempt`.
///
/// The base delay doubles with every attempt (`base_delay_ms * 2^(attempt - 1)`),
/// and a deterministic pseudo-jitter is then applied: the jitter factor
/// `(attempt * 7 % 21) / 10` uses integer division, so it is either 0 or 1,
/// which makes the jitter either nothing or 10% of the delay. Even attempts add
/// the jitter, odd attempts subtract it.
///
/// All arithmetic saturates, so very large attempts or base delays clamp at
/// `u64::MAX` instead of overflowing. `attempt == 0` is treated like the first
/// attempt's exponent (2^0) rather than underflowing.
fn calculate_backoff(attempt: u32, base_delay_ms: u64) -> u64 {
    // `saturating_sub` keeps `attempt == 0` from underflowing (a panic in debug
    // builds, a wrap to `u32::MAX` in release builds).
    let exponential = 2u64.saturating_pow(attempt.saturating_sub(1));
    let delay = base_delay_ms.saturating_mul(exponential);

    // `u32::MAX * 7` always fits in a `u64`, so this multiplication cannot overflow.
    let jitter_factor = (u64::from(attempt) * 7) % 21 / 10;
    let jitter = (delay / 10).saturating_mul(jitter_factor);

    if attempt.is_multiple_of(2) {
        delay.saturating_add(jitter)
    } else {
        delay.saturating_sub(jitter)
    }
}
29
/// JSON body posted to the `/api/chat` endpoint.
#[derive(Debug, Serialize, Deserialize)]
struct ChatRequest {
    /// Name of the model that should answer.
    model: String,
    /// Full conversation so far, oldest message first.
    messages: Vec<Message>,
    /// Tool definitions offered to the model; the field is omitted from the
    /// serialized JSON when the list is empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<Value>,
    /// Streaming flag sent to the server; the agent loop in this file always sets it to `false`.
    stream: bool,
}
38
/// One entry of the conversation history sent with every chat request.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Message {
    /// Speaker of the message; this file uses "system", "user", "assistant" and "tool".
    role: String,
    /// Text of the message (for "tool" messages, the tool's output or error text).
    content: String,
    /// Tool invocations requested by the assistant; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<ToolCall>>,
    /// Id of the tool call a "tool" message answers; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
48
/// A single tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ToolCall {
    /// Identifier echoed back as `tool_call_id` on the matching "tool" message.
    // NOTE(review): there is no `#[serde(default)]` here, so a response whose tool
    // calls lack an `id` fails to deserialize — confirm the server always sends one.
    id: String,
    /// Which function to call and with what arguments.
    function: ToolCallFunction,
}
54
/// Function name and arguments carried by a [`ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ToolCallFunction {
    /// Tool name, passed to `tools::execute_tool` for dispatch.
    name: String,
    /// Arguments as arbitrary JSON, handed to the tool unchanged.
    arguments: Value,
}
60
/// Subset of the chat endpoint's JSON response that the agent reads.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// The model's reply for this turn.
    message: ResponseMessage,
}
65
/// The model's reply inside a [`ChatResponse`].
#[derive(Debug, Deserialize)]
struct ResponseMessage {
    // Required in the response JSON but never read by this file, hence the lint allowance.
    #[allow(dead_code)]
    role: String,
    /// Reply text; becomes the final answer when no tool calls are present.
    content: String,
    /// Tool invocations requested by the model; defaults to an empty list when
    /// the field is absent from the response.
    #[serde(default)]
    tool_calls: Vec<ToolCall>,
}
74
75pub fn run_agent(
84 endpoint: &str,
85 model: &str,
86 system_prompt: &str,
87 user_message: &str,
88 callback: &mut dyn FnMut(&str) -> Result<()>,
89) -> Result<String> {
90 run_agent_with_retries(
91 endpoint,
92 model,
93 system_prompt,
94 user_message,
95 callback,
96 3,
97 1000,
98 )
99}
100
101pub fn run_agent_with_retries(
103 endpoint: &str,
104 model: &str,
105 system_prompt: &str,
106 user_message: &str,
107 callback: &mut dyn FnMut(&str) -> Result<()>,
108 max_retries: u32,
109 retry_delay_ms: u64,
110) -> Result<String> {
111 let url = Url::parse(endpoint).unwrap_or_else(|_| {
114 Url::parse(DEFAULT_OLLAMA_ENDPOINT).expect("DEFAULT_OLLAMA_ENDPOINT is valid")
115 });
116 let base_url = format!(
117 "{}://{}:{}",
118 url.scheme(),
119 url.host_str().unwrap_or("localhost"),
120 url.port().unwrap_or(11434)
121 );
122 let chat_url = format!("{}/api/chat", base_url);
123
124 let mut messages: Vec<Message> = vec![];
126 if !system_prompt.is_empty() {
127 messages.push(Message {
128 role: "system".to_string(),
129 content: system_prompt.to_string(),
130 tool_calls: None,
131 tool_call_id: None,
132 });
133 }
134 messages.push(Message {
135 role: "user".to_string(),
136 content: user_message.to_string(),
137 tool_calls: None,
138 tool_call_id: None,
139 });
140
141 let tool_defs = tools::get_tool_definitions();
143
144 let mut iteration = 0;
146 let mut final_response = String::new();
147
148 loop {
149 iteration += 1;
150 if iteration > MAX_ITERATIONS {
151 callback(&format!(
152 "Warning: Reached max iterations ({})",
153 MAX_ITERATIONS
154 ))?;
155 break;
156 }
157
158 let request = ChatRequest {
160 model: model.to_string(),
161 messages: messages.clone(),
162 tools: tool_defs.clone(),
163 stream: false,
164 };
165
166 let mut attempt = 0;
168 let chat_response = loop {
169 attempt += 1;
170
171 let client = ureq::Agent::new();
172 let response = client
173 .post(&chat_url)
174 .set("Content-Type", "application/json")
175 .send_json(&request);
176
177 match response {
178 Ok(resp) => {
179 let status = resp.status();
180
181 let is_retryable = status == 429
183 || status == 500
184 || status == 502
185 || status == 503
186 || status == 504;
187
188 if status == 200 {
189 let response_text = resp.into_string()?;
191 match serde_json::from_str::<ChatResponse>(&response_text) {
192 Ok(parsed) => break parsed,
193 Err(e) => {
194 return Err(anyhow!(
195 "Failed to parse response: {} - body: {}",
196 e,
197 response_text
198 ))
199 }
200 }
201 } else if is_retryable && attempt <= max_retries {
202 let delay_ms = calculate_backoff(attempt, retry_delay_ms);
204 callback(&format!(
205 "[Retry {}] HTTP {} - waiting {}ms before retry",
206 attempt, status, delay_ms
207 ))?;
208 std::thread::sleep(std::time::Duration::from_millis(delay_ms));
209 continue;
210 } else {
211 return Err(anyhow!(
213 "HTTP request failed with status {}: {} (after {} attempt{})",
214 status,
215 resp.status_text(),
216 attempt,
217 if attempt == 1 { "" } else { "s" }
218 ));
219 }
220 }
221 Err(e) => {
222 let error_str = e.to_string();
224 let is_retryable = error_str.contains("Connection")
225 || error_str.contains("timeout")
226 || error_str.contains("reset");
227
228 if is_retryable && attempt <= max_retries {
229 let delay_ms = calculate_backoff(attempt, retry_delay_ms);
230 callback(&format!(
231 "[Retry {}] Network error - waiting {}ms before retry: {}",
232 attempt, delay_ms, error_str
233 ))?;
234 std::thread::sleep(std::time::Duration::from_millis(delay_ms));
235 continue;
236 } else {
237 return Err(anyhow!("HTTP request failed: {}", e));
238 }
239 }
240 }
241 };
242
243 if chat_response.message.tool_calls.is_empty() {
245 final_response = chat_response.message.content.clone();
247
248 let mut line_buffer = String::new();
250 for ch in final_response.chars() {
251 line_buffer.push(ch);
252 if ch == '\n' {
253 let line = line_buffer.trim_end_matches('\n');
254 callback(line)?;
255 line_buffer.clear();
256 }
257 }
258
259 if !line_buffer.is_empty() {
261 callback(&line_buffer)?;
262 }
263 break;
264 }
265
266 messages.push(Message {
268 role: "assistant".to_string(),
269 content: chat_response.message.content.clone(),
270 tool_calls: Some(chat_response.message.tool_calls.clone()),
271 tool_call_id: None,
272 });
273
274 for tool_call in &chat_response.message.tool_calls {
276 let tool_name = &tool_call.function.name;
277 let tool_args = &tool_call.function.arguments;
278
279 callback(&format!("[Tool: {}] {}", tool_name, tool_args))?;
281
282 let result = match tools::execute_tool(tool_name, tool_args) {
284 Ok(output) => output,
285 Err(error) => format!("Error: {}", error),
286 };
287
288 let result_preview = if result.len() > 200 {
290 format!("{}... ({} bytes)", &result[..200], result.len())
291 } else {
292 result.clone()
293 };
294 callback(&format!("[Result] {}", result_preview))?;
295
296 messages.push(Message {
298 role: "tool".to_string(),
299 content: result,
300 tool_calls: None,
301 tool_call_id: Some(tool_call.id.clone()),
302 });
303 }
304 }
305
306 Ok(final_response)
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: builds the inputs an agent run would take, including a
    /// callback of the expected shape, without contacting any server.
    #[test]
    fn test_agent_initialization() {
        let mut callback = |_line: &str| -> Result<()> { Ok(()) };

        let _endpoint = "http://localhost:11434";
        let _model = "qwen2.5:7b";
        let _system_prompt = "You are a helpful assistant.";
        let _user_message = "Hello, who are you?";
        let _ = &mut callback;
    }
}