1use anyhow::{Context, Result};
37use futures_util::StreamExt;
38use serde::{Deserialize, Serialize};
39use thiserror::Error;
40use tokio::sync::mpsc;
41
42#[derive(Debug, Error)]
46pub enum OllamaError {
47 #[error("connection failed: {0}")]
49 ConnectionFailed(String),
50
51 #[error("request timed out")]
53 Timeout,
54
55 #[error("invalid response: {0}")]
57 InvalidResponse(String),
58
59 #[error("model not found: {0}")]
61 ModelNotFound(String),
62
63 #[error("server error (HTTP {0})")]
65 ServerError(u16),
66}
67
/// Base URL used by [`OllamaConfig::default`].
const DEFAULT_BASE_URL: &str = "http://127.0.0.1:11434";
/// Request timeout, in seconds, used by [`OllamaConfig::default`].
const DEFAULT_TIMEOUT_SECS: u64 = 60;
/// Pause, in milliseconds, before a failed chat request is retried.
const RETRY_BACKOFF_MS: u64 = 1_000;
72
73#[derive(Clone)]
77pub struct OllamaClient {
78 http: reqwest::Client,
79 base_url: String,
80}
81
/// Settings used to build an [`OllamaClient`].
#[derive(Clone, Debug)]
pub struct OllamaConfig {
    /// Base URL of the Ollama server; endpoint paths are appended to it.
    pub base_url: String,
    /// Timeout, in seconds, applied to the HTTP client.
    pub timeout_secs: u64,
}
90
91impl Default for OllamaConfig {
92 fn default() -> Self {
93 Self {
94 base_url: DEFAULT_BASE_URL.to_string(),
95 timeout_secs: DEFAULT_TIMEOUT_SECS,
96 }
97 }
98}
99
100#[derive(Serialize)]
101struct ChatRequest<'a> {
102 model: &'a str,
103 messages: Vec<ChatMessage<'a>>,
104 stream: bool,
105 options: ChatOptions,
106}
107
108#[derive(Serialize)]
109struct ChatOptions {
110 temperature: f32,
111 num_predict: i32,
112}
113
114#[derive(Serialize)]
115struct ChatMessage<'a> {
116 role: &'a str,
117 content: &'a str,
118}
119
120#[derive(Deserialize)]
121struct ChatResponse {
122 message: Option<ResponseMessage>,
123}
124
125#[derive(Deserialize)]
126struct ResponseMessage {
127 content: String,
128}
129
130#[derive(Deserialize)]
131struct StreamResponse {
132 message: Option<ResponseMessage>,
133 #[serde(default)]
134 done: bool,
135}
136
137#[derive(Deserialize)]
138struct TagsResponse {
139 models: Option<Vec<ModelInfo>>,
140}
141
142#[derive(Deserialize)]
143struct ModelInfo {
144 name: String,
145}
146
147impl OllamaClient {
148 pub fn new() -> Self {
155 Self::with_config(OllamaConfig::default())
156 .expect("failed to build HTTP client with default OllamaConfig")
157 }
158
159 pub fn with_config(config: OllamaConfig) -> Result<Self, OllamaError> {
166 let http = reqwest::Client::builder()
167 .timeout(std::time::Duration::from_secs(config.timeout_secs))
168 .build()
169 .map_err(|e| OllamaError::ConnectionFailed(e.to_string()))?;
170
171 Ok(Self {
172 http,
173 base_url: config.base_url,
174 })
175 }
176
177 pub async fn is_available(&self) -> bool {
179 let url = format!("{}/api/tags", self.base_url);
180 self.http
181 .get(&url)
182 .timeout(std::time::Duration::from_secs(3))
183 .send()
184 .await
185 .map(|r| r.status().is_success())
186 .unwrap_or(false)
187 }
188
189 pub async fn has_model(&self, model: &str) -> bool {
191 let url = format!("{}/api/tags", self.base_url);
192 match self.http.get(&url).send().await {
193 Ok(resp) => {
194 if let Ok(tags) = resp.json::<TagsResponse>().await {
195 if let Some(models) = tags.models {
196 return models
197 .iter()
198 .any(|m| m.name == model || m.name.starts_with(&format!("{model}:")));
199 }
200 }
201 false
202 }
203 Err(_) => false,
204 }
205 }
206
207 pub async fn complete(
211 &self,
212 model: &str,
213 system: &str,
214 user_message: &str,
215 temperature: f32,
216 max_tokens: i32,
217 ) -> Result<String> {
218 let body = ChatRequest {
219 model,
220 messages: vec![
221 ChatMessage {
222 role: "system",
223 content: system,
224 },
225 ChatMessage {
226 role: "user",
227 content: user_message,
228 },
229 ],
230 stream: false,
231 options: ChatOptions {
232 temperature,
233 num_predict: max_tokens,
234 },
235 };
236
237 match self.send_request(&body).await {
238 Ok(text) => return Ok(text),
239 Err(e) => {
240 if Self::is_retryable(&e) {
241 tracing::warn!(
242 "Ollama retryable error: {e} — retrying in {RETRY_BACKOFF_MS}ms"
243 );
244 tokio::time::sleep(std::time::Duration::from_millis(RETRY_BACKOFF_MS)).await;
245 } else {
246 return Err(e);
247 }
248 }
249 }
250
251 self.send_request(&body).await
252 }
253
254 pub async fn complete_with_history(
256 &self,
257 model: &str,
258 messages: &[(&str, &str)], temperature: f32,
260 max_tokens: i32,
261 ) -> Result<String> {
262 let chat_messages: Vec<ChatMessage<'_>> = messages
263 .iter()
264 .map(|(role, content)| ChatMessage { role, content })
265 .collect();
266
267 let body = ChatRequest {
268 model,
269 messages: chat_messages,
270 stream: false,
271 options: ChatOptions {
272 temperature,
273 num_predict: max_tokens,
274 },
275 };
276
277 self.send_request(&body).await
278 }
279
280 async fn send_request(&self, body: &ChatRequest<'_>) -> Result<String> {
281 let url = format!("{}/api/chat", self.base_url);
282 let resp = self
283 .http
284 .post(&url)
285 .json(body)
286 .send()
287 .await
288 .context("Failed to reach Ollama — is it running? (ollama serve)")?;
289
290 let status = resp.status();
291 if !status.is_success() {
292 let body_text = resp.text().await.unwrap_or_default();
293 anyhow::bail!("Ollama error ({}): {}", status.as_u16(), body_text);
294 }
295
296 let response: ChatResponse = resp
297 .json()
298 .await
299 .context("Failed to parse Ollama response")?;
300
301 let text = response.message.map(|m| m.content).unwrap_or_default();
302
303 if text.trim().is_empty() {
304 anyhow::bail!("Empty response from Ollama");
305 }
306
307 Ok(text)
308 }
309
310 pub async fn complete_streaming(
327 &self,
328 model: &str,
329 system: &str,
330 user_message: &str,
331 temperature: f32,
332 max_tokens: i32,
333 ) -> Result<mpsc::Receiver<String>> {
334 let (tx, rx) = mpsc::channel(64);
335
336 let url = format!("{}/api/chat", self.base_url);
337 let body = ChatRequest {
338 model,
339 messages: vec![
340 ChatMessage {
341 role: "system",
342 content: system,
343 },
344 ChatMessage {
345 role: "user",
346 content: user_message,
347 },
348 ],
349 stream: true,
350 options: ChatOptions {
351 temperature,
352 num_predict: max_tokens,
353 },
354 };
355
356 let resp = self
357 .http
358 .post(&url)
359 .json(&body)
360 .send()
361 .await
362 .context("Failed to reach Ollama for streaming")?;
363
364 if !resp.status().is_success() {
365 let status = resp.status();
366 let body_text = resp.text().await.unwrap_or_default();
367 anyhow::bail!(
368 "Ollama streaming error ({}): {}",
369 status.as_u16(),
370 body_text
371 );
372 }
373
374 tokio::spawn(async move {
375 let mut stream = resp.bytes_stream();
376 let mut buffer = String::new();
377
378 while let Some(chunk_result) = stream.next().await {
379 let bytes = match chunk_result {
380 Ok(b) => b,
381 Err(_) => break,
382 };
383
384 buffer.push_str(&String::from_utf8_lossy(&bytes));
385
386 while let Some(newline_pos) = buffer.find('\n') {
387 let line = buffer[..newline_pos].to_string();
388 buffer = buffer[newline_pos + 1..].to_string();
389
390 if line.trim().is_empty() {
391 continue;
392 }
393
394 if let Ok(resp) = serde_json::from_str::<StreamResponse>(&line) {
395 if let Some(msg) = resp.message {
396 if !msg.content.is_empty() && tx.send(msg.content).await.is_err() {
397 return;
398 }
399 }
400 if resp.done {
401 return;
402 }
403 }
404 }
405 }
406 });
407
408 Ok(rx)
409 }
410
411 fn is_retryable(error: &anyhow::Error) -> bool {
412 let msg = error.to_string();
413 msg.contains("connection refused")
414 || msg.contains("timeout")
415 || msg.contains("Connection reset")
416 }
417}
418
419impl Default for OllamaClient {
420 fn default() -> Self {
421 Self::new()
422 }
423}
424
#[cfg(test)]
mod tests {
    use super::*;

    /// `new()` must pick up the default base URL.
    #[test]
    fn test_client_creation() {
        assert_eq!(OllamaClient::new().base_url, DEFAULT_BASE_URL);
    }

    /// A custom base URL must be carried over into the client.
    #[test]
    fn test_custom_config() {
        let config = OllamaConfig {
            base_url: "http://10.0.0.5:11434".to_string(),
            timeout_secs: 120,
        };
        let client = OllamaClient::with_config(config)
            .expect("failed to build client with custom config");
        assert_eq!(client.base_url, "http://10.0.0.5:11434");
    }

    /// Connection and timeout failures are retryable; a missing model is not.
    #[test]
    fn test_retryable_errors() {
        let refused = anyhow::anyhow!("connection refused");
        let timed_out = anyhow::anyhow!("request timeout");
        let missing_model = anyhow::anyhow!("model not found");

        assert!(OllamaClient::is_retryable(&refused));
        assert!(OllamaClient::is_retryable(&timed_out));
        assert!(!OllamaClient::is_retryable(&missing_model));
    }

    /// The default configuration must match the documented defaults.
    #[test]
    fn test_default_config() {
        let OllamaConfig {
            base_url,
            timeout_secs,
        } = OllamaConfig::default();
        assert_eq!(base_url, DEFAULT_BASE_URL);
        assert_eq!(timeout_secs, 60);
    }
}