Skip to main content

chat_ollama/
lib.rs

1//! Ollama provider for chat-rs.
2//!
3//! Thin wrapper around [`chat_completions`] — Ollama serves an
4//! OpenAI-compatible `/v1/chat/completions` endpoint, so all chat,
5//! streaming, tools, structured output, and embedding logic lives in the
6//! `chat-completions` crate.
7//!
8//! What this crate adds on top:
9//! - Default base URL pointing at the local daemon
10//! - `OLLAMA_HOST` env var support
11//! - [`OllamaBuilder::pull`] to ensure the model is present before the
12//!   first request, hitting Ollama's native `/api/pull`. Returns the
13//!   builder so it slots into the normal chain.
14//!
15//! ```no_run
16//! # async fn run() -> Result<(), Box<dyn std::error::Error>> {
17//! use chat_ollama::OllamaBuilder;
18//!
19//! let client = OllamaBuilder::new()
20//!     .with_model("llama3.2")
21//!     .pull().await?
22//!     .build();
23//! # Ok(()) }
24//! ```
25
26use std::marker::PhantomData;
27
28use chat_completions::{
29    ChatError, CompletionsBuilder, CompletionsClient, Request, ReqwestTransport, Transport,
30    TransportError,
31};
32use serde::Deserialize;
33use serde_json::json;
34
/// Default Ollama base host when `OLLAMA_HOST` is not set.
pub const DEFAULT_OLLAMA_HOST: &str = "http://localhost:11434";

/// Name of the environment variable that [`OllamaBuilder::new`] reads to
/// override [`DEFAULT_OLLAMA_HOST`].
const OLLAMA_HOST_ENV: &str = "OLLAMA_HOST";
39
/// Typestate marker: the [`OllamaBuilder`] has not been given a model yet.
///
/// Builders in this state expose `with_model`, which moves them to
/// [`WithModel`]; `build` and `pull` are only defined on [`WithModel`].
#[derive(Debug, Clone, Copy)]
pub struct WithoutModel;

/// Typestate marker: the [`OllamaBuilder`] has a model configured, so
/// `build` and `pull` are available.
#[derive(Debug, Clone, Copy)]
pub struct WithModel;
42
/// Ollama-flavored builder. Wraps [`CompletionsBuilder`] and adds
/// `/api/pull` integration so a model can be fetched at build time.
///
/// `M` is a typestate marker ([`WithoutModel`] or [`WithModel`]) that
/// gates `build`/`pull` on a model having been set; `T` is the transport
/// used both for the native Ollama endpoints and for the final client.
pub struct OllamaBuilder<M = WithoutModel, T: Transport = ReqwestTransport> {
    /// URL scheme taken from the configured host URL (e.g. `http`).
    scheme: String,
    /// Host name plus `:port` when the URL carried an explicit port.
    host: String,
    /// Model name; populated by `with_model`.
    model: Option<String>,
    /// Optional API key. Sent as `Authorization: Bearer …` on `pull` and
    /// forwarded to the wrapped [`CompletionsBuilder`] on `build`.
    api_key: Option<String>,
    /// Additional `(name, value)` headers, kept in insertion order.
    extra_headers: Vec<(String, String)>,
    /// Optional description forwarded to the wrapped [`CompletionsBuilder`].
    description: Option<String>,
    /// Transport used for requests. Every constructor in this file stores
    /// `Some`, which is why later code unwraps it with `expect`.
    transport: Option<T>,
    /// Zero-sized carrier for the `M` typestate parameter.
    _m: PhantomData<M>,
}
55
56impl Default for OllamaBuilder<WithoutModel, ReqwestTransport> {
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62impl OllamaBuilder<WithoutModel, ReqwestTransport> {
63    /// Build pointed at `OLLAMA_HOST` if set, otherwise `http://localhost:11434`.
64    pub fn new() -> Self {
65        let host =
66            std::env::var(OLLAMA_HOST_ENV).unwrap_or_else(|_| DEFAULT_OLLAMA_HOST.to_string());
67        Self::with_host(host)
68    }
69
70    /// Build pointed at the given host. Accepts plain `http://host:port`
71    /// or a URL with a `/v1` suffix (the suffix is stripped — Ollama's
72    /// pull endpoint lives outside `/v1`).
73    pub fn with_host(host: impl AsRef<str>) -> Self {
74        let parsed = url::Url::parse(host.as_ref()).expect("Invalid Ollama host URL");
75        let scheme = parsed.scheme().to_string();
76        let host_port = parsed
77            .host_str()
78            .expect("Ollama host URL missing host")
79            .to_string()
80            + &parsed.port().map(|p| format!(":{p}")).unwrap_or_default();
81
82        Self {
83            scheme,
84            host: host_port,
85            model: None,
86            api_key: None,
87            extra_headers: Vec::new(),
88            description: None,
89            transport: Some(ReqwestTransport::default()),
90            _m: PhantomData,
91        }
92    }
93}
94
95impl<M, T: Transport> OllamaBuilder<M, T> {
96    /// Confirm the Ollama daemon is reachable at the configured host.
97    ///
98    /// Hits `/api/version`. Returns `Ok(())` if anything answers
99    /// (including 4xx — the daemon is alive, just rejecting the
100    /// request), or [`ChatError::Provider`] with an actionable install
101    /// hint when the connection is refused.
102    pub async fn ping(&self) -> Result<(), ChatError> {
103        let transport = self.transport.as_ref().expect("transport set");
104        let req = Request {
105            scheme: self.scheme.clone(),
106            host: self.host.clone(),
107            path: "/api/version".to_string(),
108            headers: vec![("Content-Type".into(), "application/json".into())],
109            body: Vec::new(),
110        };
111        match transport.send(req).await {
112            Ok(_) => Ok(()),
113            Err(e) => Err(map_transport_error(&self.scheme, &self.host, e)),
114        }
115    }
116
117    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
118        self.api_key = Some(api_key.into());
119        self
120    }
121
122    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
123        self.extra_headers.push((key.into(), value.into()));
124        self
125    }
126
127    pub fn with_description(mut self, description: impl Into<String>) -> Self {
128        self.description = Some(description.into());
129        self
130    }
131
132    pub fn with_transport<T2: Transport>(self, transport: T2) -> OllamaBuilder<M, T2> {
133        OllamaBuilder {
134            scheme: self.scheme,
135            host: self.host,
136            model: self.model,
137            api_key: self.api_key,
138            extra_headers: self.extra_headers,
139            description: self.description,
140            transport: Some(transport),
141            _m: PhantomData,
142        }
143    }
144}
145
146impl<T: Transport> OllamaBuilder<WithoutModel, T> {
147    pub fn with_model(self, model: impl Into<String>) -> OllamaBuilder<WithModel, T> {
148        OllamaBuilder {
149            scheme: self.scheme,
150            host: self.host,
151            model: Some(model.into()),
152            api_key: self.api_key,
153            extra_headers: self.extra_headers,
154            description: self.description,
155            transport: self.transport,
156            _m: PhantomData,
157        }
158    }
159}
160
161impl<T: Transport> OllamaBuilder<WithModel, T> {
162    /// Build the client without contacting the daemon.
163    pub fn build(self) -> CompletionsClient<T> {
164        let transport = self.transport.expect("transport set");
165        let model = self.model.expect("model set");
166
167        let base_url = format!("{}://{}/v1", self.scheme, self.host);
168        let mut b = CompletionsBuilder::new()
169            .with_base_url(base_url)
170            .with_model(model)
171            .with_transport(transport);
172
173        if let Some(key) = self.api_key {
174            b = b.with_api_key(key);
175        }
176        for (k, v) in self.extra_headers {
177            b = b.with_header(k, v);
178        }
179        if let Some(desc) = self.description {
180            b = b.with_description(desc);
181        }
182        b.build()
183    }
184
185    /// Ensure the configured model is downloaded.
186    ///
187    /// Issues a `POST /api/pull` against the daemon with `stream: false`.
188    /// If the model is already present locally this returns near-instantly;
189    /// otherwise it blocks until the download completes (no progress output).
190    ///
191    /// Returns the builder so the caller can keep chaining — pair with
192    /// `.build()` to get a ready-to-use client:
193    ///
194    /// ```no_run
195    /// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
196    /// # use chat_ollama::OllamaBuilder;
197    /// let client = OllamaBuilder::new()
198    ///     .with_model("llama3.2")
199    ///     .pull().await?
200    ///     .build();
201    /// # Ok(()) }
202    /// ```
203    pub async fn pull(self) -> Result<Self, ChatError> {
204        let model = self.model.as_ref().expect("model set");
205        let transport = self.transport.as_ref().expect("transport set");
206
207        let body = serde_json::to_vec(&json!({
208            "model": model,
209            "stream": false,
210        }))
211        .map_err(|e| ChatError::Other(e.to_string()))?;
212
213        let mut headers = vec![("Content-Type".into(), "application/json".into())];
214        if let Some(key) = &self.api_key {
215            headers.push(("Authorization".into(), format!("Bearer {key}")));
216        }
217        headers.extend(self.extra_headers.iter().cloned());
218
219        let req = Request {
220            scheme: self.scheme.clone(),
221            host: self.host.clone(),
222            path: "/api/pull".to_string(),
223            headers,
224            body,
225        };
226
227        let res = transport
228            .send(req)
229            .await
230            .map_err(|e| map_transport_error(&self.scheme, &self.host, e))?;
231        if !(200..300).contains(&res.status) {
232            let body = String::from_utf8_lossy(&res.body);
233            return Err(ChatError::Provider(format!(
234                "Ollama pull failed (HTTP {}): {body}",
235                res.status
236            )));
237        }
238
239        #[derive(Deserialize)]
240        struct PullResponse {
241            #[serde(default)]
242            status: Option<String>,
243            #[serde(default)]
244            error: Option<String>,
245        }
246
247        let parsed: PullResponse = serde_json::from_slice(&res.body).unwrap_or(PullResponse {
248            status: None,
249            error: None,
250        });
251
252        if let Some(err) = parsed.error {
253            return Err(ChatError::Provider(format!("Ollama pull: {err}")));
254        }
255        if let Some(status) = parsed.status
256            && status != "success"
257            && !status.is_empty()
258        {
259            return Err(ChatError::Provider(format!("Ollama pull status: {status}")));
260        }
261        Ok(self)
262    }
263}
264
265/// Translate a transport-level failure into a [`ChatError`] with an
266/// install/start hint when the daemon was unreachable.
267fn map_transport_error(scheme: &str, host: &str, err: TransportError) -> ChatError {
268    match &err {
269        TransportError::Connection(msg) => ChatError::Provider(format!(
270            "Ollama daemon unreachable at {scheme}://{host} ({msg}). \
271             Install from https://ollama.com/download, then run `ollama serve`. \
272             Override the host with OLLAMA_HOST=http://your-host:port."
273        )),
274        _ => ChatError::from(err),
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// Run `input` through [`OllamaBuilder::with_host`] and hand back the
    /// `(scheme, host)` pair it stored.
    fn parts(input: &str) -> (String, String) {
        let OllamaBuilder { scheme, host, .. } = OllamaBuilder::with_host(input);
        (scheme, host)
    }

    /// Convenience for spelling the expected `(scheme, host)` pair.
    fn expected(scheme: &str, host: &str) -> (String, String) {
        (scheme.to_owned(), host.to_owned())
    }

    #[test]
    fn parses_host_only() {
        assert_eq!(
            parts("http://localhost:11434"),
            expected("http", "localhost:11434")
        );
    }

    #[test]
    fn parses_host_with_v1_path() {
        assert_eq!(
            parts("http://localhost:11434/v1"),
            expected("http", "localhost:11434")
        );
    }

    #[test]
    fn parses_remote_host() {
        assert_eq!(
            parts("https://my-ollama.example.com:8443"),
            expected("https", "my-ollama.example.com:8443")
        );
    }
}