Skip to main content

chant/
agent.rs

1//! Agent runtime for ollama with function calling support.
2//!
3//! Uses ureq HTTP client for direct API communication.
4
5use anyhow::{anyhow, Result};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use url::Url;
9
10use crate::tools;
11
/// Upper bound on model round-trips in one agent run, so a model that keeps
/// requesting tools cannot loop forever.
const MAX_ITERATIONS: usize = 50;
/// Endpoint used when the caller-supplied endpoint cannot be parsed as a URL.
const DEFAULT_OLLAMA_ENDPOINT: &str = "http://localhost:11434";
14
/// Calculate the exponential-backoff delay (in milliseconds) before a retry.
///
/// The nominal delay is `base_delay_ms * 2^(attempt - 1)`: attempt 1 waits
/// `base_delay_ms`, attempt 2 twice that, and so on. `attempt == 0` is treated
/// like attempt 1.
///
/// A random jitter of up to ±10% of the nominal delay is applied, so that
/// several clients retrying at the same moment spread out instead of hitting
/// the server in lockstep (thundering herd). All arithmetic saturates, so huge
/// attempts or base delays never overflow or panic.
fn calculate_backoff(attempt: u32, base_delay_ms: u64) -> u64 {
    use std::collections::hash_map::RandomState;
    use std::hash::BuildHasher;

    // `saturating_sub` keeps attempt 0 from underflowing (a panic in debug
    // builds, an effectively infinite sleep in release builds).
    let exponent = attempt.saturating_sub(1);
    let delay = base_delay_ms.saturating_mul(2u64.saturating_pow(exponent));

    let max_jitter = delay / 10;
    if max_jitter == 0 {
        return delay;
    }

    // Jitter must differ between clients to be useful. A formula based only on
    // `attempt` gives every client the same delay. `RandomState::new()` is
    // initialised with random keys, which gives cheap std-only randomness
    // without a new dependency.
    let random = RandomState::new().hash_one(attempt);
    // Offset in 0..=2*max_jitter, i.e. -max_jitter..=+max_jitter after the shift.
    // 2*max_jitter+1 cannot overflow because max_jitter <= u64::MAX / 10.
    let offset = random % (max_jitter * 2 + 1);
    delay.saturating_sub(max_jitter).saturating_add(offset)
}
29
30#[derive(Debug, Serialize, Deserialize)]
31struct ChatRequest {
32    model: String,
33    messages: Vec<Message>,
34    #[serde(skip_serializing_if = "Vec::is_empty")]
35    tools: Vec<Value>,
36    stream: bool,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
40struct Message {
41    role: String,
42    content: String,
43    #[serde(skip_serializing_if = "Option::is_none")]
44    tool_calls: Option<Vec<ToolCall>>,
45    #[serde(skip_serializing_if = "Option::is_none")]
46    tool_call_id: Option<String>,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
50struct ToolCall {
51    id: String,
52    function: ToolCallFunction,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
56struct ToolCallFunction {
57    name: String,
58    arguments: Value,
59}
60
61#[derive(Debug, Deserialize)]
62struct ChatResponse {
63    message: ResponseMessage,
64}
65
66#[derive(Debug, Deserialize)]
67struct ResponseMessage {
68    #[allow(dead_code)]
69    role: String,
70    content: String,
71    #[serde(default)]
72    tool_calls: Vec<ToolCall>,
73}
74
75/// Run an agent loop using direct HTTP with function calling.
76///
77/// This creates an agent loop where:
78/// 1. Model receives a prompt and thinks about what to do
79/// 2. Model requests tool execution if needed
80/// 3. Runtime executes the tool
81/// 4. Runtime feeds result back to model
82/// 5. Loop continues until task is complete
83pub fn run_agent(
84    endpoint: &str,
85    model: &str,
86    system_prompt: &str,
87    user_message: &str,
88    callback: &mut dyn FnMut(&str) -> Result<()>,
89) -> Result<String> {
90    run_agent_with_retries(
91        endpoint,
92        model,
93        system_prompt,
94        user_message,
95        callback,
96        3,
97        1000,
98    )
99}
100
101/// Run an agent loop with configurable retry policy
102pub fn run_agent_with_retries(
103    endpoint: &str,
104    model: &str,
105    system_prompt: &str,
106    user_message: &str,
107    callback: &mut dyn FnMut(&str) -> Result<()>,
108    max_retries: u32,
109    retry_delay_ms: u64,
110) -> Result<String> {
111    // Parse endpoint to get base URL
112    // Use const for fallback to avoid nested unwrap
113    let url = Url::parse(endpoint).unwrap_or_else(|_| {
114        Url::parse(DEFAULT_OLLAMA_ENDPOINT).expect("DEFAULT_OLLAMA_ENDPOINT is valid")
115    });
116    let base_url = format!(
117        "{}://{}:{}",
118        url.scheme(),
119        url.host_str().unwrap_or("localhost"),
120        url.port().unwrap_or(11434)
121    );
122    let chat_url = format!("{}/api/chat", base_url);
123
124    // Build messages
125    let mut messages: Vec<Message> = vec![];
126    if !system_prompt.is_empty() {
127        messages.push(Message {
128            role: "system".to_string(),
129            content: system_prompt.to_string(),
130            tool_calls: None,
131            tool_call_id: None,
132        });
133    }
134    messages.push(Message {
135        role: "user".to_string(),
136        content: user_message.to_string(),
137        tool_calls: None,
138        tool_call_id: None,
139    });
140
141    // Get tool definitions (already in correct format with lowercase "function")
142    let tool_defs = tools::get_tool_definitions();
143
144    // Tool calling loop
145    let mut iteration = 0;
146    let mut final_response = String::new();
147
148    loop {
149        iteration += 1;
150        if iteration > MAX_ITERATIONS {
151            callback(&format!(
152                "Warning: Reached max iterations ({})",
153                MAX_ITERATIONS
154            ))?;
155            break;
156        }
157
158        // Build request
159        let request = ChatRequest {
160            model: model.to_string(),
161            messages: messages.clone(),
162            tools: tool_defs.clone(),
163            stream: false,
164        };
165
166        // Send request with retry logic
167        let mut attempt = 0;
168        let chat_response = loop {
169            attempt += 1;
170
171            let client = ureq::Agent::new();
172            let response = client
173                .post(&chat_url)
174                .set("Content-Type", "application/json")
175                .send_json(&request);
176
177            match response {
178                Ok(resp) => {
179                    let status = resp.status();
180
181                    // Check for retryable HTTP errors
182                    let is_retryable = status == 429
183                        || status == 500
184                        || status == 502
185                        || status == 503
186                        || status == 504;
187
188                    if status == 200 {
189                        // Success
190                        let response_text = resp.into_string()?;
191                        match serde_json::from_str::<ChatResponse>(&response_text) {
192                            Ok(parsed) => break parsed,
193                            Err(e) => {
194                                return Err(anyhow!(
195                                    "Failed to parse response: {} - body: {}",
196                                    e,
197                                    response_text
198                                ))
199                            }
200                        }
201                    } else if is_retryable && attempt <= max_retries {
202                        // Retryable error - wait and retry
203                        let delay_ms = calculate_backoff(attempt, retry_delay_ms);
204                        callback(&format!(
205                            "[Retry {}] HTTP {} - waiting {}ms before retry",
206                            attempt, status, delay_ms
207                        ))?;
208                        std::thread::sleep(std::time::Duration::from_millis(delay_ms));
209                        continue;
210                    } else {
211                        // Non-retryable error or max retries exceeded
212                        return Err(anyhow!(
213                            "HTTP request failed with status {}: {} (after {} attempt{})",
214                            status,
215                            resp.status_text(),
216                            attempt,
217                            if attempt == 1 { "" } else { "s" }
218                        ));
219                    }
220                }
221                Err(e) => {
222                    // Network error - check if retryable
223                    let error_str = e.to_string();
224                    let is_retryable = error_str.contains("Connection")
225                        || error_str.contains("timeout")
226                        || error_str.contains("reset");
227
228                    if is_retryable && attempt <= max_retries {
229                        let delay_ms = calculate_backoff(attempt, retry_delay_ms);
230                        callback(&format!(
231                            "[Retry {}] Network error - waiting {}ms before retry: {}",
232                            attempt, delay_ms, error_str
233                        ))?;
234                        std::thread::sleep(std::time::Duration::from_millis(delay_ms));
235                        continue;
236                    } else {
237                        return Err(anyhow!("HTTP request failed: {}", e));
238                    }
239                }
240            }
241        };
242
243        // Check if model requested tool calls
244        if chat_response.message.tool_calls.is_empty() {
245            // No tool calls - model has provided final response
246            final_response = chat_response.message.content.clone();
247
248            // Buffer content and only call callback when we have complete lines
249            let mut line_buffer = String::new();
250            for ch in final_response.chars() {
251                line_buffer.push(ch);
252                if ch == '\n' {
253                    let line = line_buffer.trim_end_matches('\n');
254                    callback(line)?;
255                    line_buffer.clear();
256                }
257            }
258
259            // Flush any remaining buffered content
260            if !line_buffer.is_empty() {
261                callback(&line_buffer)?;
262            }
263            break;
264        }
265
266        // Add assistant message with tool calls to history
267        messages.push(Message {
268            role: "assistant".to_string(),
269            content: chat_response.message.content.clone(),
270            tool_calls: Some(chat_response.message.tool_calls.clone()),
271            tool_call_id: None,
272        });
273
274        // Process each tool call from the model
275        for tool_call in &chat_response.message.tool_calls {
276            let tool_name = &tool_call.function.name;
277            let tool_args = &tool_call.function.arguments;
278
279            // Log the tool call
280            callback(&format!("[Tool: {}] {}", tool_name, tool_args))?;
281
282            // Execute the tool
283            let result = match tools::execute_tool(tool_name, tool_args) {
284                Ok(output) => output,
285                Err(error) => format!("Error: {}", error),
286            };
287
288            // Log abbreviated result
289            let result_preview = if result.len() > 200 {
290                format!("{}... ({} bytes)", &result[..200], result.len())
291            } else {
292                result.clone()
293            };
294            callback(&format!("[Result] {}", result_preview))?;
295
296            // Add tool response to messages
297            messages.push(Message {
298                role: "tool".to_string(),
299                content: result,
300                tool_calls: None,
301                tool_call_id: Some(tool_call.id.clone()),
302            });
303        }
304    }
305
306    Ok(final_response)
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Signature of the public entry point, checked at compile time below.
    type EntryPoint =
        fn(&str, &str, &str, &str, &mut dyn FnMut(&str) -> Result<()>) -> Result<String>;

    #[test]
    fn test_agent_initialization() {
        // Fails to compile if `run_agent`'s public signature changes.
        let entry: EntryPoint = run_agent;
        let _ = entry;
    }

    #[test]
    fn backoff_stays_within_ten_percent_of_nominal() {
        for attempt in 1..=6u32 {
            let nominal = 1000u64 << (attempt - 1);
            let delay = calculate_backoff(attempt, 1000);
            assert!(
                delay >= nominal - nominal / 10 && delay <= nominal + nominal / 10,
                "attempt {} gave {}",
                attempt,
                delay
            );
        }
    }

    #[test]
    fn backoff_saturates_instead_of_overflowing() {
        assert!(calculate_backoff(200, u64::MAX) >= u64::MAX - u64::MAX / 10);
    }

    #[test]
    fn message_omits_unset_optional_fields() {
        let msg = Message {
            role: "user".to_string(),
            content: "hi".to_string(),
            tool_calls: None,
            tool_call_id: None,
        };
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json, serde_json::json!({ "role": "user", "content": "hi" }));
    }

    #[test]
    fn chat_request_omits_empty_tool_list() {
        let request = ChatRequest {
            model: "m".to_string(),
            messages: vec![],
            tools: vec![],
            stream: false,
        };
        let json = serde_json::to_value(&request).unwrap();
        assert!(json.get("tools").is_none());
        assert_eq!(json["stream"], serde_json::json!(false));
    }

    #[test]
    fn response_without_tool_calls_parses_to_empty_list() {
        let body = r#"{"model":"m","message":{"role":"assistant","content":"done"},"done":true}"#;
        let parsed: ChatResponse = serde_json::from_str(body).unwrap();
        assert_eq!(parsed.message.content, "done");
        assert!(parsed.message.tool_calls.is_empty());
    }

    #[test]
    fn response_tool_calls_are_parsed() {
        let body = r#"{"message":{"role":"assistant","content":"","tool_calls":[{"id":"call_1","function":{"name":"read_file","arguments":{"path":"a.txt"}}}]}}"#;
        let parsed: ChatResponse = serde_json::from_str(body).unwrap();
        let call = &parsed.message.tool_calls[0];
        assert_eq!(call.id, "call_1");
        assert_eq!(call.function.name, "read_file");
        assert_eq!(call.function.arguments["path"], "a.txt");
    }
}