Skip to main content

llm_stack_ollama/
provider.rs

1//! Ollama `Provider` implementation.
2
3use std::collections::HashSet;
4
5use llm_stack::ChatResponse;
6use llm_stack::error::LlmError;
7use llm_stack::provider::{Capability, ChatParams, Provider, ProviderMetadata};
8use llm_stack::stream::ChatStream;
9use tracing::instrument;
10
11use crate::config::OllamaConfig;
12use crate::convert;
13
/// Ollama provider implementing [`Provider`].
///
/// Talks to the Ollama server at `OllamaConfig::base_url` — typically a
/// locally running instance, but any reachable host works. No
/// authentication is required by default.
///
/// # Example
///
/// ```rust,no_run
/// use llm_stack_ollama::{OllamaConfig, OllamaProvider};
/// use llm_stack::{ChatParams, ChatMessage, Provider};
///
/// # async fn example() -> Result<(), llm_stack::LlmError> {
/// let provider = OllamaProvider::new(OllamaConfig::default());
///
/// let response = provider.generate(&ChatParams {
///     messages: vec![ChatMessage::user("Hello!")],
///     ..Default::default()
/// }).await?;
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct OllamaProvider {
    /// Provider settings: model name, base URL, timeout and optional
    /// pre-built HTTP client.
    config: OllamaConfig,
    /// HTTP client used for every request — either the one supplied via
    /// `OllamaConfig::client` or one built by [`OllamaProvider::new`].
    client: reqwest::Client,
}
40
41impl OllamaProvider {
42    /// Create a new Ollama provider from configuration.
43    ///
44    /// If `config.client` is `Some`, that client is reused for connection
45    /// pooling. Otherwise a new client is built with the configured timeout.
46    pub fn new(config: OllamaConfig) -> Self {
47        let client = config.client.clone().unwrap_or_else(|| {
48            let mut builder = reqwest::Client::builder();
49            if let Some(timeout) = config.timeout {
50                builder = builder.timeout(timeout);
51            }
52            builder.build().expect("failed to build HTTP client")
53        });
54        Self { config, client }
55    }
56
57    /// Build the full URL for the chat endpoint.
58    fn chat_url(&self) -> String {
59        let base = self.config.base_url.trim_end_matches('/');
60        format!("{base}/api/chat")
61    }
62
63    /// Send a request to the Ollama API and return the raw response.
64    async fn send_request(
65        &self,
66        params: &ChatParams,
67        stream: bool,
68    ) -> Result<reqwest::Response, LlmError> {
69        let request_body = convert::build_request(params, &self.config, stream)?;
70
71        let mut headers = reqwest::header::HeaderMap::new();
72        headers.insert(
73            "content-type",
74            reqwest::header::HeaderValue::from_static("application/json"),
75        );
76        if let Some(extra) = &params.extra_headers {
77            headers.extend(extra.iter().map(|(k, v)| (k.clone(), v.clone())));
78        }
79
80        let mut req = self
81            .client
82            .post(self.chat_url())
83            .headers(headers)
84            .json(&request_body);
85
86        if let Some(timeout) = params.timeout {
87            req = req.timeout(timeout);
88        }
89
90        let response = req.send().await.map_err(|e| {
91            if e.is_timeout() {
92                LlmError::Timeout {
93                    elapsed_ms: params
94                        .timeout
95                        .or(self.config.timeout)
96                        .map_or(0, |d| u64::try_from(d.as_millis()).unwrap_or(u64::MAX)),
97                }
98            } else {
99                LlmError::Http {
100                    status: e.status().map(|s| {
101                        http::StatusCode::from_u16(s.as_u16())
102                            .unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR)
103                    }),
104                    message: e.to_string(),
105                    retryable: e.is_connect() || e.is_timeout(),
106                }
107            }
108        })?;
109
110        let status = response.status();
111        if !status.is_success() {
112            let body = response.text().await.unwrap_or_default();
113            let http_status = http::StatusCode::from_u16(status.as_u16())
114                .unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR);
115            return Err(convert::convert_error(http_status, &body));
116        }
117
118        Ok(response)
119    }
120}
121
122impl Provider for OllamaProvider {
123    #[instrument(skip_all, fields(model = %self.config.model))]
124    async fn generate(&self, params: &ChatParams) -> Result<ChatResponse, LlmError> {
125        let response = self.send_request(params, false).await?;
126
127        let body = response
128            .text()
129            .await
130            .map_err(|e| LlmError::ResponseFormat {
131                message: format!("Failed to read Ollama response body: {e}"),
132                raw: String::new(),
133            })?;
134
135        let api_response: crate::types::Response =
136            serde_json::from_str(&body).map_err(|e| LlmError::ResponseFormat {
137                message: format!("Failed to parse Ollama response: {e}"),
138                raw: body,
139            })?;
140
141        Ok(convert::convert_response(api_response))
142    }
143
144    #[instrument(skip_all, fields(model = %self.config.model))]
145    async fn stream(&self, params: &ChatParams) -> Result<ChatStream, LlmError> {
146        let response = self.send_request(params, true).await?;
147        Ok(crate::stream::into_stream(response))
148    }
149
150    fn metadata(&self) -> ProviderMetadata {
151        let mut capabilities = HashSet::new();
152        capabilities.insert(Capability::Tools);
153        capabilities.insert(Capability::Vision);
154        capabilities.insert(Capability::StructuredOutput);
155
156        ProviderMetadata {
157            name: "ollama".into(),
158            model: self.config.model.clone(),
159            context_window: context_window_for_model(&self.config.model),
160            capabilities,
161        }
162    }
163}
164
/// Look up the context window size (in tokens) for known Ollama models.
///
/// `model` is an Ollama model reference such as `"mistral"`, `"gemma2:9b"`,
/// `"library/mixtral:8x7b"` or `"hf.co/user/repo:tag"`. Matching is
/// case-insensitive and ignores any registry/namespace prefix before the
/// last `/`, so `"library/mistral"` is treated like `"mistral"`.
///
/// Unknown models default to 128K, the native window of most current model
/// families.
///
/// NOTE(review): this is the model's advertised window. Ollama's runtime
/// context is governed by `num_ctx` and may be smaller; confirm callers do
/// not treat this as the allocated size.
fn context_window_for_model(model: &str) -> u64 {
    // Drop any registry/namespace prefix ("library/", "hf.co/user/").
    let name = model
        .rsplit_once('/')
        .map_or(model, |(_, last)| last)
        .to_ascii_lowercase();
    // Split off the tag: "gemma3:1b" -> ("gemma3", "1b").
    let (family, tag) = name.split_once(':').unwrap_or((name.as_str(), ""));

    if family.starts_with("mistral") || family.starts_with("mixtral") {
        32_000
    } else if family.starts_with("gemma3") {
        // Gemma 3: 128K for 4B and larger, 32K for the 1B variant.
        if tag.starts_with("1b") {
            32_000
        } else {
            128_000
        }
    } else if family.starts_with("gemma") {
        // Gemma and Gemma 2 ship with an 8K window.
        8_192
    } else if family == "llama3" {
        // Original Llama 3 is 8K; llama3.1 and later are 128K (default branch).
        8_192
    } else if family.starts_with("llama2") {
        4_096
    } else {
        128_000
    }
}
179
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// Provider configured for `model`, everything else at its default.
    fn provider_for_model(model: &'static str) -> OllamaProvider {
        OllamaProvider::new(OllamaConfig {
            model: model.into(),
            ..Default::default()
        })
    }

    /// Provider pointed at `base_url`, everything else at its default.
    fn provider_at(base_url: &'static str) -> OllamaProvider {
        OllamaProvider::new(OllamaConfig {
            base_url: base_url.into(),
            ..Default::default()
        })
    }

    #[test]
    fn test_metadata() {
        let meta = provider_for_model("llama3.2").metadata();

        assert_eq!(meta.name, "ollama");
        assert_eq!(meta.model, "llama3.2");
        assert_eq!(meta.context_window, 128_000);
        for expected in [Capability::Tools, Capability::Vision] {
            assert!(meta.capabilities.contains(&expected));
        }
    }

    #[test]
    fn test_metadata_mistral() {
        let window = provider_for_model("mistral").metadata().context_window;
        assert_eq!(window, 32_000);
    }

    #[test]
    fn test_context_window_gemma() {
        assert_eq!(context_window_for_model("gemma2"), 8_192);
    }

    #[test]
    fn test_context_window_unknown() {
        assert_eq!(context_window_for_model("some-custom-model"), 128_000);
    }

    #[test]
    fn test_chat_url() {
        let url = provider_at("http://localhost:11434").chat_url();
        assert_eq!(url, "http://localhost:11434/api/chat");
    }

    #[test]
    fn test_chat_url_trailing_slash() {
        let url = provider_at("http://remote:11434/").chat_url();
        assert_eq!(url, "http://remote:11434/api/chat");
    }

    #[test]
    fn test_new_with_custom_client() {
        let shared = reqwest::Client::builder()
            .timeout(Duration::from_secs(10))
            .build()
            .unwrap();

        let config = OllamaConfig {
            client: Some(shared),
            ..Default::default()
        };
        assert_eq!(OllamaProvider::new(config).metadata().name, "ollama");
    }

    #[test]
    fn test_new_with_timeout() {
        let config = OllamaConfig {
            timeout: Some(Duration::from_secs(60)),
            ..Default::default()
        };
        assert_eq!(OllamaProvider::new(config).metadata().name, "ollama");
    }
}