Skip to main content

laminae_ollama/
lib.rs

//! # laminae-ollama — Ollama Client for Local LLM Inference
//!
//! Standalone async HTTP client for [Ollama](https://ollama.ai). Supports both
//! single-response and streaming completions via the `/api/chat` endpoint.
//!
//! Has no dependencies on other laminae crates — it talks to an Ollama server
//! (by default the local instance at `http://127.0.0.1:11434`, configurable via
//! `OllamaConfig`) and nothing else.
//!
//! ## Quick Start
//!
//! ```rust,no_run
//! use laminae_ollama::OllamaClient;
//!
//! #[tokio::main]
//! async fn main() -> anyhow::Result<()> {
//!     let client = OllamaClient::new();
//!
//!     if !client.is_available().await {
//!         eprintln!("Ollama is not running — start it with `ollama serve`");
//!         return Ok(());
//!     }
//!
//!     let response = client.complete(
//!         "llama3.2",
//!         "You are a helpful assistant.",
//!         "What is 2 + 2?",
//!         0.7,
//!         256,
//!     ).await?;
//!
//!     println!("{response}");
//!     Ok(())
//! }
//! ```
35
36use anyhow::{Context, Result};
37use futures_util::StreamExt;
38use serde::{Deserialize, Serialize};
39use thiserror::Error;
40use tokio::sync::mpsc;
41
42// ── Typed Errors ──
43
44/// Typed errors for the Ollama client.
45#[derive(Debug, Error)]
46pub enum OllamaError {
47    /// Failed to connect to the Ollama server.
48    #[error("connection failed: {0}")]
49    ConnectionFailed(String),
50
51    /// Request timed out waiting for a response.
52    #[error("request timed out")]
53    Timeout,
54
55    /// The server returned a response that could not be parsed.
56    #[error("invalid response: {0}")]
57    InvalidResponse(String),
58
59    /// The requested model is not available locally.
60    #[error("model not found: {0}")]
61    ModelNotFound(String),
62
63    /// The Ollama server returned an HTTP error status.
64    #[error("server error (HTTP {0})")]
65    ServerError(u16),
66}
67
/// Default Ollama API endpoint (loopback interface, port 11434).
const DEFAULT_BASE_URL: &str = "http://127.0.0.1:11434";
/// Default per-request timeout, in seconds, installed on the HTTP client
/// (individual requests may override it).
const DEFAULT_TIMEOUT_SECS: u64 = 60;
/// Pause, in milliseconds, before re-sending a request that failed with a
/// transient error.
const RETRY_BACKOFF_MS: u64 = 1_000;
72
73/// Client for Ollama's local LLM API.
74///
75/// Runs entirely on-device — zero cost, no API key, no network egress.
76#[derive(Clone)]
77pub struct OllamaClient {
78    http: reqwest::Client,
79    base_url: String,
80}
81
/// Configuration for creating an [`OllamaClient`].
///
/// Start from [`OllamaConfig::default`] and override only what you need:
///
/// ```rust,no_run
/// # use laminae_ollama::OllamaConfig;
/// let config = OllamaConfig {
///     timeout_secs: 120,
///     ..OllamaConfig::default()
/// };
/// ```
// `PartialEq`/`Eq` let callers compare configurations (e.g. in tests or when
// deciding whether a client must be rebuilt); both fields support them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OllamaConfig {
    /// Base URL for the Ollama API (default: `http://127.0.0.1:11434`).
    pub base_url: String,
    /// Request timeout in seconds (default: 60).
    pub timeout_secs: u64,
}
90
91impl Default for OllamaConfig {
92    fn default() -> Self {
93        Self {
94            base_url: DEFAULT_BASE_URL.to_string(),
95            timeout_secs: DEFAULT_TIMEOUT_SECS,
96        }
97    }
98}
99
100#[derive(Serialize)]
101struct ChatRequest<'a> {
102    model: &'a str,
103    messages: Vec<ChatMessage<'a>>,
104    stream: bool,
105    options: ChatOptions,
106}
107
108#[derive(Serialize)]
109struct ChatOptions {
110    temperature: f32,
111    num_predict: i32,
112}
113
114#[derive(Serialize)]
115struct ChatMessage<'a> {
116    role: &'a str,
117    content: &'a str,
118}
119
120#[derive(Deserialize)]
121struct ChatResponse {
122    message: Option<ResponseMessage>,
123}
124
125#[derive(Deserialize)]
126struct ResponseMessage {
127    content: String,
128}
129
130#[derive(Deserialize)]
131struct StreamResponse {
132    message: Option<ResponseMessage>,
133    #[serde(default)]
134    done: bool,
135}
136
137#[derive(Deserialize)]
138struct TagsResponse {
139    models: Option<Vec<ModelInfo>>,
140}
141
142#[derive(Deserialize)]
143struct ModelInfo {
144    name: String,
145}
146
147impl OllamaClient {
148    /// Create a client with default settings (localhost:11434, 60s timeout).
149    ///
150    /// # Panics
151    ///
152    /// Panics if the HTTP client cannot be initialized. This should only happen
153    /// if the TLS backend is unavailable on the platform.
154    pub fn new() -> Self {
155        Self::with_config(OllamaConfig::default())
156            .expect("failed to build HTTP client with default OllamaConfig")
157    }
158
159    /// Create a client with custom configuration.
160    ///
161    /// # Errors
162    ///
163    /// Returns [`OllamaError::ConnectionFailed`] if the HTTP client cannot be
164    /// initialized (e.g. TLS backend unavailable).
165    pub fn with_config(config: OllamaConfig) -> Result<Self, OllamaError> {
166        let http = reqwest::Client::builder()
167            .timeout(std::time::Duration::from_secs(config.timeout_secs))
168            .build()
169            .map_err(|e| OllamaError::ConnectionFailed(e.to_string()))?;
170
171        Ok(Self {
172            http,
173            base_url: config.base_url,
174        })
175    }
176
177    /// Check if Ollama is running and reachable.
178    pub async fn is_available(&self) -> bool {
179        let url = format!("{}/api/tags", self.base_url);
180        self.http
181            .get(&url)
182            .timeout(std::time::Duration::from_secs(3))
183            .send()
184            .await
185            .map(|r| r.status().is_success())
186            .unwrap_or(false)
187    }
188
189    /// Check if a specific model is pulled locally.
190    pub async fn has_model(&self, model: &str) -> bool {
191        let url = format!("{}/api/tags", self.base_url);
192        match self.http.get(&url).send().await {
193            Ok(resp) => {
194                if let Ok(tags) = resp.json::<TagsResponse>().await {
195                    if let Some(models) = tags.models {
196                        return models
197                            .iter()
198                            .any(|m| m.name == model || m.name.starts_with(&format!("{model}:")));
199                    }
200                }
201                false
202            }
203            Err(_) => false,
204        }
205    }
206
207    /// Send a completion request (non-streaming).
208    ///
209    /// Retries once on transient connection errors.
210    pub async fn complete(
211        &self,
212        model: &str,
213        system: &str,
214        user_message: &str,
215        temperature: f32,
216        max_tokens: i32,
217    ) -> Result<String> {
218        let body = ChatRequest {
219            model,
220            messages: vec![
221                ChatMessage {
222                    role: "system",
223                    content: system,
224                },
225                ChatMessage {
226                    role: "user",
227                    content: user_message,
228                },
229            ],
230            stream: false,
231            options: ChatOptions {
232                temperature,
233                num_predict: max_tokens,
234            },
235        };
236
237        match self.send_request(&body).await {
238            Ok(text) => return Ok(text),
239            Err(e) => {
240                if Self::is_retryable(&e) {
241                    tracing::warn!(
242                        "Ollama retryable error: {e} — retrying in {RETRY_BACKOFF_MS}ms"
243                    );
244                    tokio::time::sleep(std::time::Duration::from_millis(RETRY_BACKOFF_MS)).await;
245                } else {
246                    return Err(e);
247                }
248            }
249        }
250
251        self.send_request(&body).await
252    }
253
254    /// Send a completion with full message history (multi-turn conversation).
255    pub async fn complete_with_history(
256        &self,
257        model: &str,
258        messages: &[(&str, &str)], // (role, content) pairs
259        temperature: f32,
260        max_tokens: i32,
261    ) -> Result<String> {
262        let chat_messages: Vec<ChatMessage<'_>> = messages
263            .iter()
264            .map(|(role, content)| ChatMessage { role, content })
265            .collect();
266
267        let body = ChatRequest {
268            model,
269            messages: chat_messages,
270            stream: false,
271            options: ChatOptions {
272                temperature,
273                num_predict: max_tokens,
274            },
275        };
276
277        self.send_request(&body).await
278    }
279
280    async fn send_request(&self, body: &ChatRequest<'_>) -> Result<String> {
281        let url = format!("{}/api/chat", self.base_url);
282        let resp = self
283            .http
284            .post(&url)
285            .json(body)
286            .send()
287            .await
288            .context("Failed to reach Ollama — is it running? (ollama serve)")?;
289
290        let status = resp.status();
291        if !status.is_success() {
292            let body_text = resp.text().await.unwrap_or_default();
293            anyhow::bail!("Ollama error ({}): {}", status.as_u16(), body_text);
294        }
295
296        let response: ChatResponse = resp
297            .json()
298            .await
299            .context("Failed to parse Ollama response")?;
300
301        let text = response.message.map(|m| m.content).unwrap_or_default();
302
303        if text.trim().is_empty() {
304            anyhow::bail!("Empty response from Ollama");
305        }
306
307        Ok(text)
308    }
309
310    /// Streaming completion — yields text chunks via an mpsc channel.
311    ///
312    /// ```rust,no_run
313    /// # use laminae_ollama::OllamaClient;
314    /// # async fn example() -> anyhow::Result<()> {
315    /// let client = OllamaClient::new();
316    /// let mut rx = client.complete_streaming(
317    ///     "llama3.2", "You are helpful.", "Hello!", 0.7, 256,
318    /// ).await?;
319    ///
320    /// while let Some(chunk) = rx.recv().await {
321    ///     print!("{chunk}");
322    /// }
323    /// # Ok(())
324    /// # }
325    /// ```
326    pub async fn complete_streaming(
327        &self,
328        model: &str,
329        system: &str,
330        user_message: &str,
331        temperature: f32,
332        max_tokens: i32,
333    ) -> Result<mpsc::Receiver<String>> {
334        let (tx, rx) = mpsc::channel(64);
335
336        let url = format!("{}/api/chat", self.base_url);
337        let body = ChatRequest {
338            model,
339            messages: vec![
340                ChatMessage {
341                    role: "system",
342                    content: system,
343                },
344                ChatMessage {
345                    role: "user",
346                    content: user_message,
347                },
348            ],
349            stream: true,
350            options: ChatOptions {
351                temperature,
352                num_predict: max_tokens,
353            },
354        };
355
356        let resp = self
357            .http
358            .post(&url)
359            .json(&body)
360            .send()
361            .await
362            .context("Failed to reach Ollama for streaming")?;
363
364        if !resp.status().is_success() {
365            let status = resp.status();
366            let body_text = resp.text().await.unwrap_or_default();
367            anyhow::bail!(
368                "Ollama streaming error ({}): {}",
369                status.as_u16(),
370                body_text
371            );
372        }
373
374        tokio::spawn(async move {
375            let mut stream = resp.bytes_stream();
376            let mut buffer = String::new();
377
378            while let Some(chunk_result) = stream.next().await {
379                let bytes = match chunk_result {
380                    Ok(b) => b,
381                    Err(_) => break,
382                };
383
384                buffer.push_str(&String::from_utf8_lossy(&bytes));
385
386                while let Some(newline_pos) = buffer.find('\n') {
387                    let line = buffer[..newline_pos].to_string();
388                    buffer = buffer[newline_pos + 1..].to_string();
389
390                    if line.trim().is_empty() {
391                        continue;
392                    }
393
394                    if let Ok(resp) = serde_json::from_str::<StreamResponse>(&line) {
395                        if let Some(msg) = resp.message {
396                            if !msg.content.is_empty() && tx.send(msg.content).await.is_err() {
397                                return;
398                            }
399                        }
400                        if resp.done {
401                            return;
402                        }
403                    }
404                }
405            }
406        });
407
408        Ok(rx)
409    }
410
411    fn is_retryable(error: &anyhow::Error) -> bool {
412        let msg = error.to_string();
413        msg.contains("connection refused")
414            || msg.contains("timeout")
415            || msg.contains("Connection reset")
416    }
417}
418
419impl Default for OllamaClient {
420    fn default() -> Self {
421        Self::new()
422    }
423}
424
#[cfg(test)]
mod tests {
    use super::*;

    /// `new()` targets the default local endpoint.
    #[test]
    fn test_client_creation() {
        let OllamaClient { base_url, .. } = OllamaClient::new();
        assert_eq!(base_url, DEFAULT_BASE_URL);
    }

    /// A caller-supplied base URL is kept.
    #[test]
    fn test_custom_config() {
        let config = OllamaConfig {
            base_url: "http://10.0.0.5:11434".to_string(),
            timeout_secs: 120,
        };
        let client =
            OllamaClient::with_config(config).expect("failed to build client with custom config");
        assert_eq!(client.base_url, "http://10.0.0.5:11434");
    }

    /// Connection-level failures are retried; semantic failures are not.
    #[test]
    fn test_retryable_errors() {
        let cases = [
            ("connection refused", true),
            ("request timeout", true),
            ("model not found", false),
        ];
        for &(message, expected) in &cases {
            let error = anyhow::anyhow!(message);
            assert_eq!(
                OllamaClient::is_retryable(&error),
                expected,
                "unexpected retry verdict for {message:?}"
            );
        }
    }

    /// Defaults match the documented values.
    #[test]
    fn test_default_config() {
        let OllamaConfig {
            base_url,
            timeout_secs,
        } = OllamaConfig::default();
        assert_eq!(base_url, DEFAULT_BASE_URL);
        assert_eq!(timeout_secs, 60);
    }
}