
evolve_llm/ollama.rs

//! Ollama native `/api/chat` client. No auth.
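//!
//! A minimal usage sketch (assumes a local Ollama daemon is listening on the
//! default port; module paths are illustrative and the example is not run as
//! a doctest):
//!
//! ```ignore
//! use crate::client::LlmClient;
//! use crate::ollama::OllamaClient;
//!
//! async fn demo() -> Result<(), crate::error::LlmError> {
//!     let client = OllamaClient::local();
//!     let reply = client.complete("Say hello in one word.", 32).await?;
//!     println!("{} ({} in / {} out)", reply.text, reply.usage.input, reply.usage.output);
//!     Ok(())
//! }
//! ```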

use crate::client::{CompletionResult, LlmClient, TokenUsage};
use crate::error::LlmError;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::time::Duration;

const DEFAULT_ENDPOINT: &str = "http://localhost:11434";
const DEFAULT_MODEL_ENV: &str = "OLLAMA_MODEL";
const FALLBACK_MODEL: &str = "qwen2.5-coder:7b";
const RETRY_DELAY: Duration = Duration::from_millis(500);

/// Minimal client for Ollama's native chat endpoint.
#[derive(Debug, Clone)]
pub struct OllamaClient {
    endpoint: String,
    model: String,
    http: reqwest::Client,
}

impl OllamaClient {
    /// Construct with the given endpoint and `OLLAMA_MODEL` (or the fallback)
    /// as the model tag.
    pub fn with_endpoint(endpoint: impl Into<String>) -> Self {
        let model = std::env::var(DEFAULT_MODEL_ENV).unwrap_or_else(|_| FALLBACK_MODEL.to_string());
        Self {
            endpoint: endpoint.into(),
            model,
            http: reqwest::Client::new(),
        }
    }

    /// Construct with explicit endpoint + model (used by tests).
    pub fn with_endpoint_and_model(endpoint: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            endpoint: endpoint.into(),
            model: model.into(),
            http: reqwest::Client::new(),
        }
    }

    /// Default constructor: `http://localhost:11434`.
    pub fn local() -> Self {
        Self::with_endpoint(DEFAULT_ENDPOINT)
    }
}

#[derive(Serialize)]
struct ChatRequest<'a> {
    model: &'a str,
    messages: Vec<Message<'a>>,
    stream: bool,
    options: Options,
}

#[derive(Serialize)]
struct Message<'a> {
    role: &'a str,
    content: &'a str,
}

#[derive(Serialize)]
struct Options {
    num_predict: u32,
}
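
// For reference, a call like `complete("hi", 16)` serializes the `ChatRequest`
// above to roughly the following JSON (a sketch based on the derives above):
//
//   {
//     "model": "<model tag>",
//     "messages": [{"role": "user", "content": "hi"}],
//     "stream": false,
//     "options": {"num_predict": 16}
//   }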

#[derive(Deserialize)]
struct ChatResponse {
    message: AssistantMessage,
    #[serde(default)]
    prompt_eval_count: u32,
    #[serde(default)]
    eval_count: u32,
}

#[derive(Deserialize)]
struct AssistantMessage {
    content: String,
}

#[async_trait]
impl LlmClient for OllamaClient {
    async fn complete(&self, prompt: &str, max_tokens: u32) -> Result<CompletionResult, LlmError> {
        let url = format!("{}/api/chat", self.endpoint);
        let body = ChatRequest {
            model: &self.model,
            messages: vec![Message {
                role: "user",
                content: prompt,
            }],
            stream: false,
            options: Options {
                num_predict: max_tokens,
            },
        };

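        // At most one retry: transient failures (connection refused, HTTP 429,
        // or any 5xx) sleep for RETRY_DELAY and then try again; anything else
        // returns immediately.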
        for attempt in 0..=1 {
            let result = self.http.post(&url).json(&body).send().await;

            match result {
                Ok(resp) => {
                    let status = resp.status();
                    if status.is_success() {
                        let raw = resp.text().await?;
                        let parsed: ChatResponse = serde_json::from_str(&raw)?;
                        return Ok(CompletionResult {
                            text: parsed.message.content,
                            usage: TokenUsage {
                                input: parsed.prompt_eval_count,
                                output: parsed.eval_count,
                            },
                        });
                    }
                    let retryable = status.as_u16() == 429 || status.is_server_error();
                    if retryable && attempt == 0 {
                        tokio::time::sleep(RETRY_DELAY).await;
                        continue;
                    }
                    let body = resp.text().await.unwrap_or_default();
                    // Cap the snippet at 512 bytes, backing up to a char
                    // boundary so slicing cannot panic on multi-byte UTF-8.
                    let mut cut = body.len().min(512);
                    while !body.is_char_boundary(cut) {
                        cut -= 1;
                    }
                    let snippet = &body[..cut];
                    return Err(LlmError::UnexpectedStatus {
                        status: status.as_u16(),
                        body: snippet.to_string(),
                    });
                }
                Err(e) if attempt == 0 && e.is_connect() => {
                    tokio::time::sleep(RETRY_DELAY).await;
                    continue;
                }
                Err(e) => return Err(LlmError::Http(e)),
            }
        }
        unreachable!("retry loop exits via return")
    }

    fn model_id(&self) -> &str {
        &self.model
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    const SAMPLE_RESPONSE: &str = r#"{
        "model": "qwen2.5-coder:7b",
        "created_at": "2026-04-23T10:00:00Z",
        "message": {"role": "assistant", "content": "Hello from mock Ollama"},
        "done": true,
        "prompt_eval_count": 10,
        "eval_count": 4
    }"#;

    #[tokio::test]
    async fn happy_path_parses_text_and_usage() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(ResponseTemplate::new(200).set_body_string(SAMPLE_RESPONSE))
            .mount(&server)
            .await;

        let client = OllamaClient::with_endpoint_and_model(server.uri(), "qwen2.5-coder:7b");
        let got = client.complete("hi", 16).await.unwrap();
        assert_eq!(got.text, "Hello from mock Ollama");
        assert_eq!(got.usage.input, 10);
        assert_eq!(got.usage.output, 4);
    }

    #[tokio::test]
    async fn retries_once_on_5xx_then_succeeds() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(ResponseTemplate::new(502))
            .up_to_n_times(1)
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(ResponseTemplate::new(200).set_body_string(SAMPLE_RESPONSE))
            .mount(&server)
            .await;

        let client = OllamaClient::with_endpoint_and_model(server.uri(), "qwen2.5-coder:7b");
        let got = client.complete("hi", 16).await.unwrap();
        assert_eq!(got.text, "Hello from mock Ollama");
    }

    #[tokio::test]
    async fn gives_up_after_second_5xx() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(ResponseTemplate::new(500))
            .mount(&server)
            .await;

        let client = OllamaClient::with_endpoint_and_model(server.uri(), "qwen2.5-coder:7b");
        let err = client.complete("hi", 16).await.unwrap_err();
        assert!(matches!(
            err,
            LlmError::UnexpectedStatus { status: 500, .. }
        ));
    }

    #[tokio::test]
    async fn does_not_retry_on_4xx_except_429() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(
                ResponseTemplate::new(404).set_body_string(r#"{"error":"model missing"}"#),
            )
            .expect(1)
            .mount(&server)
            .await;

        let client = OllamaClient::with_endpoint_and_model(server.uri(), "nonexistent");
        let err = client.complete("hi", 16).await.unwrap_err();
        assert!(matches!(
            err,
            LlmError::UnexpectedStatus { status: 404, .. }
        ));
    }
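
    // Added shape check (a sketch): pin down the JSON that the request structs
    // serialize to; the model tag and prompt here are arbitrary test values.
    #[test]
    fn request_serializes_to_expected_shape() {
        let req = ChatRequest {
            model: "qwen2.5-coder:7b",
            messages: vec![Message {
                role: "user",
                content: "hi",
            }],
            stream: false,
            options: Options { num_predict: 16 },
        };
        let value = serde_json::to_value(&req).unwrap();
        assert_eq!(value["model"], "qwen2.5-coder:7b");
        assert_eq!(value["messages"][0]["role"], "user");
        assert_eq!(value["messages"][0]["content"], "hi");
        assert_eq!(value["stream"], false);
        assert_eq!(value["options"]["num_predict"], 16);
    }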
}