llm_stack_ollama/provider.rs

use std::collections::HashSet;

use llm_stack::ChatResponse;
use llm_stack::error::LlmError;
use llm_stack::provider::{Capability, ChatParams, Provider, ProviderMetadata};
use llm_stack::stream::ChatStream;
use tracing::instrument;

use crate::config::OllamaConfig;
use crate::convert;
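/// A [`Provider`] that sends chat requests to an Ollama server's
/// `/api/chat` endpoint.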
#[derive(Debug)]
pub struct OllamaProvider {
    config: OllamaConfig,
    client: reqwest::Client,
}

impl OllamaProvider {
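    /// Builds a provider from `config`, constructing a default
    /// `reqwest::Client` (with the configured timeout, if any) when the
    /// config does not supply its own client.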
    pub fn new(config: OllamaConfig) -> Self {
        let client = config.client.clone().unwrap_or_else(|| {
            let mut builder = reqwest::Client::builder();
            if let Some(timeout) = config.timeout {
                builder = builder.timeout(timeout);
            }
            builder.build().expect("failed to build HTTP client")
        });
        Self { config, client }
    }

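    /// Joins the configured base URL with the `/api/chat` path, tolerating a
    /// trailing slash on the base URL.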
    fn chat_url(&self) -> String {
        let base = self.config.base_url.trim_end_matches('/');
        format!("{base}/api/chat")
    }

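    /// Sends a chat request to Ollama, mapping transport failures and
    /// non-success statuses into [`LlmError`].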
    async fn send_request(
        &self,
        params: &ChatParams,
        stream: bool,
    ) -> Result<reqwest::Response, LlmError> {
        let request_body = convert::build_request(params, &self.config, stream)?;

        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            "content-type",
            reqwest::header::HeaderValue::from_static("application/json"),
        );
        if let Some(extra) = &params.extra_headers {
            headers.extend(extra.iter().map(|(k, v)| (k.clone(), v.clone())));
        }

        let mut req = self
            .client
            .post(self.chat_url())
            .headers(headers)
            .json(&request_body);

        // A per-request timeout overrides the client-level default.
        if let Some(timeout) = params.timeout {
            req = req.timeout(timeout);
        }

        // Map transport errors, distinguishing timeouts from other failures.
        let response = req.send().await.map_err(|e| {
            if e.is_timeout() {
                LlmError::Timeout {
                    elapsed_ms: params
                        .timeout
                        .or(self.config.timeout)
                        .map_or(0, |d| u64::try_from(d.as_millis()).unwrap_or(u64::MAX)),
                }
            } else {
                LlmError::Http {
                    status: e.status().map(|s| {
                        http::StatusCode::from_u16(s.as_u16())
                            .unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR)
                    }),
                    message: e.to_string(),
                    retryable: e.is_connect() || e.is_timeout(),
                }
            }
        })?;

        // Non-success statuses are translated into provider errors.
        let status = response.status();
        if !status.is_success() {
            let body = response.text().await.unwrap_or_default();
            let http_status = http::StatusCode::from_u16(status.as_u16())
                .unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR);
            return Err(convert::convert_error(http_status, &body));
        }

        Ok(response)
    }
}

impl Provider for OllamaProvider {
    #[instrument(skip_all, fields(model = %self.config.model))]
    async fn generate(&self, params: &ChatParams) -> Result<ChatResponse, LlmError> {
        let response = self.send_request(params, false).await?;

        let body = response
            .text()
            .await
            .map_err(|e| LlmError::ResponseFormat {
                message: format!("Failed to read Ollama response body: {e}"),
                raw: String::new(),
            })?;

        let api_response: crate::types::Response =
            serde_json::from_str(&body).map_err(|e| LlmError::ResponseFormat {
                message: format!("Failed to parse Ollama response: {e}"),
                raw: body,
            })?;

        Ok(convert::convert_response(api_response))
    }

    #[instrument(skip_all, fields(model = %self.config.model))]
    async fn stream(&self, params: &ChatParams) -> Result<ChatStream, LlmError> {
        let response = self.send_request(params, true).await?;
        Ok(crate::stream::into_stream(response))
    }

    fn metadata(&self) -> ProviderMetadata {
        let mut capabilities = HashSet::new();
        capabilities.insert(Capability::Tools);
        capabilities.insert(Capability::Vision);
        capabilities.insert(Capability::StructuredOutput);

        ProviderMetadata {
            name: "ollama".into(),
            model: self.config.model.clone(),
            context_window: context_window_for_model(&self.config.model),
            capabilities,
        }
    }
}

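/// Approximate context-window sizes keyed on the model-name prefix; unknown
/// models fall back to 128_000.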
fn context_window_for_model(model: &str) -> u64 {
    if model.starts_with("mistral") || model.starts_with("mixtral") {
        32_000
    } else if model.starts_with("gemma") {
        8_192
    } else {
        128_000
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    #[test]
    fn test_metadata() {
        let provider = OllamaProvider::new(OllamaConfig {
            model: "llama3.2".into(),
            ..Default::default()
        });
        let meta = provider.metadata();

        assert_eq!(meta.name, "ollama");
        assert_eq!(meta.model, "llama3.2");
        assert_eq!(meta.context_window, 128_000);
        assert!(meta.capabilities.contains(&Capability::Tools));
        assert!(meta.capabilities.contains(&Capability::Vision));
    }

    #[test]
    fn test_metadata_mistral() {
        let provider = OllamaProvider::new(OllamaConfig {
            model: "mistral".into(),
            ..Default::default()
        });
        let meta = provider.metadata();
        assert_eq!(meta.context_window, 32_000);
    }

    #[test]
    fn test_context_window_gemma() {
        assert_eq!(context_window_for_model("gemma2"), 8_192);
    }

    #[test]
    fn test_context_window_unknown() {
        assert_eq!(context_window_for_model("some-custom-model"), 128_000);
    }

    #[test]
    fn test_chat_url() {
        let provider = OllamaProvider::new(OllamaConfig {
            base_url: "http://localhost:11434".into(),
            ..Default::default()
        });
        assert_eq!(provider.chat_url(), "http://localhost:11434/api/chat");
    }

    #[test]
    fn test_chat_url_trailing_slash() {
        let provider = OllamaProvider::new(OllamaConfig {
            base_url: "http://remote:11434/".into(),
            ..Default::default()
        });
        assert_eq!(provider.chat_url(), "http://remote:11434/api/chat");
    }

    #[test]
    fn test_new_with_custom_client() {
        let custom_client = reqwest::Client::builder()
            .timeout(Duration::from_secs(10))
            .build()
            .unwrap();

        let provider = OllamaProvider::new(OllamaConfig {
            client: Some(custom_client),
            ..Default::default()
        });
        assert_eq!(provider.metadata().name, "ollama");
    }

    #[test]
    fn test_new_with_timeout() {
        let provider = OllamaProvider::new(OllamaConfig {
            timeout: Some(Duration::from_secs(60)),
            ..Default::default()
        });
        assert_eq!(provider.metadata().name, "ollama");
    }
}