use crate::core::{HttpClient, Provider};
use crate::error::LlmConnectorError;
use crate::types::{ChatRequest, ChatResponse, Choice, Message, Role};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::any::Any;

#[cfg(feature = "streaming")]
use crate::types::ChatStream;

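/// Provider for a local or remote Ollama server.
///
/// Wraps an [`HttpClient`] pointed at the server's base URL and implements the
/// [`Provider`] trait on top of Ollama's HTTP API (`/api/chat`, `/api/tags`, ...).
///
/// Example (illustrative sketch; assumes a reachable Ollama server and a model
/// you have already pulled):
///
/// ```ignore
/// let provider = OllamaProvider::new("http://localhost:11434")?;
/// let models = provider.models().await?;
/// println!("available models: {:?}", models);
/// ```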
#[derive(Clone, Debug)]
pub struct OllamaProvider {
    client: HttpClient,
    base_url: String,
}

impl OllamaProvider {
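    /// Creates a provider for the Ollama server at `base_url` with default
    /// HTTP client settings.
    ///
    /// Example (illustrative sketch; 11434 is Ollama's default port):
    ///
    /// ```ignore
    /// let provider = OllamaProvider::new("http://localhost:11434")?;
    /// ```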
    pub fn new(base_url: &str) -> Result<Self, LlmConnectorError> {
        let client = HttpClient::new(base_url)?;

        Ok(Self {
            client,
            base_url: base_url.to_string(),
        })
    }

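    /// Creates a provider with a custom request timeout (in seconds) and an
    /// optional HTTP proxy, both forwarded to [`HttpClient::with_config`].
    ///
    /// Example (illustrative sketch; the proxy URL is a placeholder):
    ///
    /// ```ignore
    /// let provider = OllamaProvider::with_config(
    ///     "http://localhost:11434",
    ///     Some(120),
    ///     Some("http://proxy.example.com:8080"),
    /// )?;
    /// ```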
    pub fn with_config(
        base_url: &str,
        timeout_secs: Option<u64>,
        proxy: Option<&str>,
    ) -> Result<Self, LlmConnectorError> {
        let client = HttpClient::with_config(base_url, timeout_secs, proxy)?;

        Ok(Self {
            client,
            base_url: base_url.to_string(),
        })
    }

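    /// Pulls `model_name` onto the server via `POST /api/pull`.
    ///
    /// The request is sent with `stream: false`, so the server responds once
    /// rather than streaming progress updates.
    ///
    /// Example (illustrative sketch; the model name is an example):
    ///
    /// ```ignore
    /// provider.pull_model("llama3.2").await?;
    /// ```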
    pub async fn pull_model(&self, model_name: &str) -> Result<(), LlmConnectorError> {
        let request = OllamaPullRequest {
            name: model_name.to_string(),
            stream: Some(false),
        };

        let url = format!("{}/api/pull", self.base_url);
        let response = self.client.post(&url, &request).await?;

        if !response.status().is_success() {
            let text = response
                .text()
                .await
                .map_err(|e| LlmConnectorError::NetworkError(e.to_string()))?;
            return Err(LlmConnectorError::ApiError(format!(
                "Failed to pull model: {}",
                text
            )));
        }

        Ok(())
    }

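    /// Deletes `model_name` from the server via `POST /api/delete`.
    ///
    /// Example (illustrative sketch; the model name is an example):
    ///
    /// ```ignore
    /// provider.delete_model("llama3.2").await?;
    /// ```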
    pub async fn delete_model(&self, model_name: &str) -> Result<(), LlmConnectorError> {
        let request = OllamaDeleteRequest {
            name: model_name.to_string(),
        };

        let url = format!("{}/api/delete", self.base_url);
        let response = self.client.post(&url, &request).await?;

        if !response.status().is_success() {
            let text = response
                .text()
                .await
                .map_err(|e| LlmConnectorError::NetworkError(e.to_string()))?;
            return Err(LlmConnectorError::ApiError(format!(
                "Failed to delete model: {}",
                text
            )));
        }

        Ok(())
    }

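    /// Fetches metadata for `model_name` via `POST /api/show` and parses it
    /// into an [`OllamaModelInfo`].
    ///
    /// Example (illustrative sketch; the model name is an example):
    ///
    /// ```ignore
    /// let info = provider.show_model("llama3.2").await?;
    /// println!("family: {}", info.details.family);
    /// ```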
    pub async fn show_model(&self, model_name: &str) -> Result<OllamaModelInfo, LlmConnectorError> {
        let request = OllamaShowRequest {
            name: model_name.to_string(),
        };

        let url = format!("{}/api/show", self.base_url);
        let response = self.client.post(&url, &request).await?;
        let status = response.status();
        let text = response
            .text()
            .await
            .map_err(|e| LlmConnectorError::NetworkError(e.to_string()))?;

        if !status.is_success() {
            return Err(LlmConnectorError::ApiError(format!(
                "Failed to show model: {}",
                text
            )));
        }

        serde_json::from_str(&text).map_err(|e| {
            LlmConnectorError::ParseError(format!("Failed to parse model info: {}", e))
        })
    }

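    /// Returns `true` if the server knows `model_name`, `false` if `/api/show`
    /// reports an API error for it, and propagates any other error (for
    /// example a network failure).
    ///
    /// Example (illustrative sketch; the model name is an example):
    ///
    /// ```ignore
    /// if !provider.model_exists("llama3.2").await? {
    ///     provider.pull_model("llama3.2").await?;
    /// }
    /// ```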
    pub async fn model_exists(&self, model_name: &str) -> Result<bool, LlmConnectorError> {
        match self.show_model(model_name).await {
            Ok(_) => Ok(true),
            Err(LlmConnectorError::ApiError(_)) => Ok(false),
            Err(e) => Err(e),
        }
    }
}

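// Provider trait implementation: model listing via `/api/tags` and chat
// completion via `/api/chat`, adapted to the crate's request/response types.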
#[async_trait]
impl Provider for OllamaProvider {
    fn name(&self) -> &str {
        "ollama"
    }

    async fn models(&self) -> Result<Vec<String>, LlmConnectorError> {
        let url = format!("{}/api/tags", self.base_url);
        let response = self.client.get(&url).await?;
        let status = response.status();
        let text = response
            .text()
            .await
            .map_err(|e| LlmConnectorError::NetworkError(e.to_string()))?;

        if !status.is_success() {
            return Err(LlmConnectorError::ApiError(format!(
                "Failed to get models: {}",
                text
            )));
        }

        let models_response: OllamaModelsResponse = serde_json::from_str(&text)
            .map_err(|e| LlmConnectorError::ParseError(format!("Failed to parse models: {}", e)))?;

        Ok(models_response.models.into_iter().map(|m| m.name).collect())
    }

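    // Non-streaming chat: translate the crate-level ChatRequest into Ollama's
    // /api/chat payload, then map the single JSON reply back into a
    // ChatResponse with one choice.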
    async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, LlmConnectorError> {
        let ollama_request = OllamaChatRequest {
            model: request.model.clone(),
            messages: request
                .messages
                .iter()
                .map(|msg| OllamaMessage {
                    role: match msg.role {
                        Role::User => "user".to_string(),
                        Role::Assistant => "assistant".to_string(),
                        Role::System => "system".to_string(),
                        Role::Tool => "user".to_string(),
                    },
                    content: msg.content_as_text(),
                })
                .collect(),
            stream: Some(false),
            options: Some(OllamaOptions {
                temperature: request.temperature,
                num_predict: request.max_tokens.map(|t| t as i32),
                top_p: request.top_p,
            }),
        };

        let url = format!("{}/api/chat", self.base_url);
        let response = self.client.post(&url, &ollama_request).await?;
        let status = response.status();
        let text = response
            .text()
            .await
            .map_err(|e| LlmConnectorError::NetworkError(e.to_string()))?;

        if !status.is_success() {
            return Err(LlmConnectorError::ApiError(format!(
                "Ollama chat failed: {}",
                text
            )));
        }

        let ollama_response: OllamaChatResponse = serde_json::from_str(&text).map_err(|e| {
            LlmConnectorError::ParseError(format!("Failed to parse Ollama response: {}", e))
        })?;

        let content = ollama_response.message.content.clone();

        let choices = vec![Choice {
            index: 0,
            message: Message {
                role: Role::Assistant,
                content: vec![crate::types::MessageBlock::text(&content)],
                name: None,
                tool_calls: None,
                tool_call_id: None,
                reasoning_content: None,
                reasoning: None,
                thought: None,
                thinking: None,
            },
            finish_reason: Some("stop".to_string()),
            logprobs: None,
        }];

        Ok(ChatResponse {
            id: "ollama-response".to_string(),
            object: "chat.completion".to_string(),
            created: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs(),
            model: ollama_response.model,
            choices,
            content,
            reasoning_content: None,
            usage: None,
            system_fingerprint: None,
        })
    }

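    // Streaming chat: same request mapping as `chat`, but with `stream: true`;
    // the raw response is handed to the crate's streaming adapter.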
    #[cfg(feature = "streaming")]
    async fn chat_stream(&self, request: &ChatRequest) -> Result<ChatStream, LlmConnectorError> {
        let ollama_request = OllamaChatRequest {
            model: request.model.clone(),
            messages: request
                .messages
                .iter()
                .map(|msg| OllamaMessage {
                    role: match msg.role {
                        Role::User => "user".to_string(),
                        Role::Assistant => "assistant".to_string(),
                        Role::System => "system".to_string(),
                        Role::Tool => "user".to_string(),
                    },
                    content: msg.content_as_text(),
                })
                .collect(),
            stream: Some(true),
            options: Some(OllamaOptions {
                temperature: request.temperature,
                num_predict: request.max_tokens.map(|t| t as i32),
                top_p: request.top_p,
            }),
        };

        let url = format!("{}/api/chat", self.base_url);
        let response = self.client.stream(&url, &ollama_request).await?;
        let status = response.status();

        if !status.is_success() {
            let text = response
                .text()
                .await
                .map_err(|e| LlmConnectorError::NetworkError(e.to_string()))?;
            return Err(LlmConnectorError::ApiError(format!(
                "Ollama stream failed: {}",
                text
            )));
        }

        Ok(crate::sse::sse_to_streaming_response(response))
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

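// Wire-format types for Ollama's HTTP API. Only the fields this module uses
// are modeled; unknown response fields are ignored by serde's defaults.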
#[derive(Serialize, Debug)]
struct OllamaChatRequest {
    model: String,
    messages: Vec<OllamaMessage>,
    stream: Option<bool>,
    options: Option<OllamaOptions>,
}

#[derive(Serialize, Debug)]
struct OllamaMessage {
    role: String,
    content: String,
}

#[derive(Serialize, Debug)]
struct OllamaOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    num_predict: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
}

#[derive(Deserialize, Debug)]
struct OllamaChatResponse {
    model: String,
    message: OllamaResponseMessage,
    #[allow(dead_code)]
    done: bool,
}

#[derive(Deserialize, Debug)]
struct OllamaResponseMessage {
    #[allow(dead_code)]
    role: String,
    content: String,
}

#[derive(Serialize, Debug)]
struct OllamaPullRequest {
    name: String,
    stream: Option<bool>,
}

#[derive(Serialize, Debug)]
struct OllamaDeleteRequest {
    name: String,
}

#[derive(Serialize, Debug)]
struct OllamaShowRequest {
    name: String,
}

#[derive(Deserialize, Debug)]
pub struct OllamaModelInfo {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: OllamaModelDetails,
}

#[derive(Deserialize, Debug)]
pub struct OllamaModelDetails {
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}

#[derive(Deserialize, Debug)]
struct OllamaModelsResponse {
    models: Vec<OllamaModel>,
}

#[derive(Deserialize, Debug)]
struct OllamaModel {
    name: String,
    #[allow(dead_code)]
    modified_at: String,
    #[allow(dead_code)]
    size: u64,
}

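/// Convenience constructor for an Ollama server on the default local address
/// (`http://localhost:11434`).
///
/// Example (illustrative sketch):
///
/// ```ignore
/// let provider = ollama()?;
/// ```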
pub fn ollama() -> Result<OllamaProvider, LlmConnectorError> {
    OllamaProvider::new("http://localhost:11434")
}

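/// Convenience constructor for an Ollama server at a custom base URL.
///
/// Example (illustrative sketch; the host is a placeholder):
///
/// ```ignore
/// let provider = ollama_with_base_url("http://192.168.1.50:11434")?;
/// ```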
pub fn ollama_with_base_url(base_url: &str) -> Result<OllamaProvider, LlmConnectorError> {
    OllamaProvider::new(base_url)
}

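/// Convenience constructor with a custom timeout and optional proxy; see
/// [`OllamaProvider::with_config`].
///
/// Example (illustrative sketch; timeout and proxy values are placeholders):
///
/// ```ignore
/// let provider = ollama_with_config("http://localhost:11434", Some(300), None)?;
/// ```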
pub fn ollama_with_config(
    base_url: &str,
    timeout_secs: Option<u64>,
    proxy: Option<&str>,
) -> Result<OllamaProvider, LlmConnectorError> {
    OllamaProvider::with_config(base_url, timeout_secs, proxy)
}