use crate::client::{CompletionResult, LlmClient, TokenUsage};
use crate::error::LlmError;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::time::Duration;

const DEFAULT_ENDPOINT: &str = "http://localhost:11434";
const DEFAULT_MODEL_ENV: &str = "OLLAMA_MODEL";
const FALLBACK_MODEL: &str = "qwen2.5-coder:7b";
const RETRY_DELAY: Duration = Duration::from_millis(500);

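/// Chat-completion client for a local Ollama server.
///
/// Talks to the non-streaming `/api/chat` endpoint; connection errors,
/// HTTP 429, and 5xx responses are retried once after `RETRY_DELAY`.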
#[derive(Debug, Clone)]
pub struct OllamaClient {
    endpoint: String,
    model: String,
    http: reqwest::Client,
}

impl OllamaClient {
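    /// Builds a client for `endpoint`, reading the model name from the
    /// `OLLAMA_MODEL` environment variable and falling back to
    /// `qwen2.5-coder:7b` when it is unset.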
    pub fn with_endpoint(endpoint: impl Into<String>) -> Self {
        let model =
            std::env::var(DEFAULT_MODEL_ENV).unwrap_or_else(|_| FALLBACK_MODEL.to_string());
        Self {
            endpoint: endpoint.into(),
            model,
            http: reqwest::Client::new(),
        }
    }

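    /// Builds a client for `endpoint` using an explicit model name,
    /// bypassing the `OLLAMA_MODEL` environment lookup.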
    pub fn with_endpoint_and_model(endpoint: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            endpoint: endpoint.into(),
            model: model.into(),
            http: reqwest::Client::new(),
        }
    }

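    /// Builds a client for the default local endpoint,
    /// `http://localhost:11434`.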
    pub fn local() -> Self {
        Self::with_endpoint(DEFAULT_ENDPOINT)
    }
}

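// Wire types for Ollama's `/api/chat` endpoint (non-streaming).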
#[derive(Serialize)]
struct ChatRequest<'a> {
    model: &'a str,
    messages: Vec<Message<'a>>,
    stream: bool,
    options: Options,
}

#[derive(Serialize)]
struct Message<'a> {
    role: &'a str,
    content: &'a str,
}

#[derive(Serialize)]
struct Options {
    /// Ollama's cap on the number of generated tokens (maps from `max_tokens`).
    num_predict: u32,
}

#[derive(Deserialize)]
struct ChatResponse {
    message: AssistantMessage,
    #[serde(default)]
    prompt_eval_count: u32,
    #[serde(default)]
    eval_count: u32,
}

#[derive(Deserialize)]
struct AssistantMessage {
    content: String,
}

#[async_trait]
impl LlmClient for OllamaClient {
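    /// Sends a single-turn chat request and returns the assistant's reply
    /// plus token usage. Connection errors, HTTP 429, and 5xx responses are
    /// retried once after `RETRY_DELAY`; any other failure is returned as-is.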
    async fn complete(&self, prompt: &str, max_tokens: u32) -> Result<CompletionResult, LlmError> {
        let url = format!("{}/api/chat", self.endpoint);
        let body = ChatRequest {
            model: &self.model,
            messages: vec![Message {
                role: "user",
                content: prompt,
            }],
            stream: false,
            options: Options {
                num_predict: max_tokens,
            },
        };

        for attempt in 0..=1 {
            let result = self.http.post(&url).json(&body).send().await;

            match result {
                Ok(resp) => {
                    let status = resp.status();
                    if status.is_success() {
                        let raw = resp.text().await?;
                        let parsed: ChatResponse = serde_json::from_str(&raw)?;
                        return Ok(CompletionResult {
                            text: parsed.message.content,
                            usage: TokenUsage {
                                input: parsed.prompt_eval_count,
                                output: parsed.eval_count,
                            },
                        });
                    }
                    let retryable = status.as_u16() == 429 || status.is_server_error();
                    if retryable && attempt == 0 {
                        tokio::time::sleep(RETRY_DELAY).await;
                        continue;
                    }
                    // Truncate by characters, not bytes: slicing `&err_body[..512]`
                    // would panic if byte 512 fell inside a multi-byte UTF-8 char.
                    let err_body = resp.text().await.unwrap_or_default();
                    let snippet: String = err_body.chars().take(512).collect();
                    return Err(LlmError::UnexpectedStatus {
                        status: status.as_u16(),
                        body: snippet,
                    });
                }
                Err(e) if attempt == 0 && e.is_connect() => {
                    tokio::time::sleep(RETRY_DELAY).await;
                    continue;
                }
                Err(e) => return Err(LlmError::Http(e)),
            }
        }
        unreachable!("retry loop exits via return")
    }

    fn model_id(&self) -> &str {
        &self.model
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    const SAMPLE_RESPONSE: &str = r#"{
        "model": "qwen2.5-coder:7b",
        "created_at": "2026-04-23T10:00:00Z",
        "message": {"role": "assistant", "content": "Hello from mock Ollama"},
        "done": true,
        "prompt_eval_count": 10,
        "eval_count": 4
    }"#;

160
161 #[tokio::test]
162 async fn happy_path_parses_text_and_usage() {
163 let server = MockServer::start().await;
164 Mock::given(method("POST"))
165 .and(path("/api/chat"))
166 .respond_with(ResponseTemplate::new(200).set_body_string(SAMPLE_RESPONSE))
167 .mount(&server)
168 .await;
169
170 let client = OllamaClient::with_endpoint_and_model(server.uri(), "qwen2.5-coder:7b");
171 let got = client.complete("hi", 16).await.unwrap();
172 assert_eq!(got.text, "Hello from mock Ollama");
173 assert_eq!(got.usage.input, 10);
174 assert_eq!(got.usage.output, 4);
175 }

    #[tokio::test]
    async fn retries_once_on_5xx_then_succeeds() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(ResponseTemplate::new(502))
            .up_to_n_times(1)
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(ResponseTemplate::new(200).set_body_string(SAMPLE_RESPONSE))
            .mount(&server)
            .await;

        let client = OllamaClient::with_endpoint_and_model(server.uri(), "qwen2.5-coder:7b");
        let got = client.complete("hi", 16).await.unwrap();
        assert_eq!(got.text, "Hello from mock Ollama");
    }
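
    // Sketch of a companion case (not in the original suite): the `retryable`
    // check in `complete` also covers HTTP 429, so a single 429 followed by a
    // 200 should succeed the same way the 5xx retry test above does.
    #[tokio::test]
    async fn retries_once_on_429_then_succeeds() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(ResponseTemplate::new(429))
            .up_to_n_times(1)
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(ResponseTemplate::new(200).set_body_string(SAMPLE_RESPONSE))
            .mount(&server)
            .await;

        let client = OllamaClient::with_endpoint_and_model(server.uri(), "qwen2.5-coder:7b");
        let got = client.complete("hi", 16).await.unwrap();
        assert_eq!(got.text, "Hello from mock Ollama");
    }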

    #[tokio::test]
    async fn gives_up_after_second_5xx() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(ResponseTemplate::new(500))
            .mount(&server)
            .await;

        let client = OllamaClient::with_endpoint_and_model(server.uri(), "qwen2.5-coder:7b");
        let err = client.complete("hi", 16).await.unwrap_err();
        assert!(matches!(
            err,
            LlmError::UnexpectedStatus { status: 500, .. }
        ));
    }

    #[tokio::test]
    async fn does_not_retry_on_4xx_except_429() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(
                ResponseTemplate::new(404).set_body_string(r#"{"error":"model missing"}"#),
            )
            .expect(1)
            .mount(&server)
            .await;

        let client = OllamaClient::with_endpoint_and_model(server.uri(), "nonexistent");
        let err = client.complete("hi", 16).await.unwrap_err();
        assert!(matches!(
            err,
            LlmError::UnexpectedStatus { status: 404, .. }
        ));
    }
}