use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::time::sleep;
use tracing::{debug, error, info, warn};

use super::ollama_config::OllamaConfig;
use crate::error::ProviderError;
use crate::models::{Capability, ChatRequest, ChatResponse, FinishReason, ModelInfo, TokenUsage};
use crate::provider::Provider;

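// Retry policy for transient failures: exponential backoff starting at
// INITIAL_BACKOFF_MS, doubling each attempt and capped at MAX_BACKOFF_MS.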
const MAX_RETRIES: u32 = 3;
const INITIAL_BACKOFF_MS: u64 = 100;
const MAX_BACKOFF_MS: u64 = 400;

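/// In-memory cache of the model list reported by the Ollama server, guarded by a TTL.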
struct ModelCache {
    models: Option<Vec<ModelInfo>>,
    cached_at: Option<SystemTime>,
    ttl: Duration,
}

impl ModelCache {
    fn new() -> Self {
        Self {
            models: None,
            cached_at: None,
            ttl: Duration::from_secs(300),
        }
    }

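    /// Builds an empty cache with a caller-supplied TTL.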
    #[allow(dead_code)]
    fn with_ttl(ttl: Duration) -> Self {
        Self {
            models: None,
            cached_at: None,
            ttl,
        }
    }

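    /// Returns true when models are cached and the entry is younger than the TTL.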
    fn is_valid(&self) -> bool {
        if let (Some(cached_at), Some(_)) = (self.cached_at, &self.models) {
            if let Ok(elapsed) = cached_at.elapsed() {
                return elapsed < self.ttl;
            }
        }
        false
    }

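    /// Returns the cached models only if the entry is still valid.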
    fn get(&self) -> Option<Vec<ModelInfo>> {
        if self.is_valid() {
            self.models.clone()
        } else {
            None
        }
    }

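    /// Stores a fresh model list and resets the cache timestamp.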
    fn set(&mut self, models: Vec<ModelInfo>) {
        self.models = Some(models);
        self.cached_at = Some(SystemTime::now());
    }

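    /// Returns whatever is cached, even if expired, for use as a stale fallback.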
    fn get_stale(&self) -> Option<Vec<ModelInfo>> {
        self.models.clone()
    }

    #[allow(dead_code)]
    fn clear(&mut self) {
        self.models = None;
        self.cached_at = None;
    }
}

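/// Chat provider backed by a local Ollama server's HTTP API.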
pub struct OllamaProvider {
    client: Arc<Client>,
    base_url: String,
    available_models: Vec<ModelInfo>,
    model_cache: Arc<tokio::sync::Mutex<ModelCache>>,
}

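/// Treats timeouts, connection failures, and 5xx responses as retryable.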
fn is_transient_error(err: &reqwest::Error) -> bool {
    err.is_timeout() || err.is_connect() || err.status().is_some_and(|s| s.is_server_error())
}

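/// Calls `request_fn` until it succeeds, retrying transient errors up to MAX_RETRIES
/// times with capped exponential backoff.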
async fn execute_with_retry<F, Fut>(mut request_fn: F) -> Result<reqwest::Response, reqwest::Error>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<reqwest::Response, reqwest::Error>>,
{
    let mut attempt = 0;

    loop {
        match request_fn().await {
            Ok(response) => return Ok(response),
            Err(err) => {
                if is_transient_error(&err) && attempt < MAX_RETRIES {
                    let backoff_ms = INITIAL_BACKOFF_MS * 2_u64.pow(attempt);
                    let backoff_ms = backoff_ms.min(MAX_BACKOFF_MS);

                    warn!(
                        "Transient error on attempt {}/{}, retrying after {}ms: {}",
                        attempt + 1,
                        MAX_RETRIES,
                        backoff_ms,
                        err
                    );

                    sleep(Duration::from_millis(backoff_ms)).await;
                    attempt += 1;
                } else {
                    if attempt >= MAX_RETRIES {
                        debug!("Max retries ({}) exceeded for request", MAX_RETRIES);
                    }
                    return Err(err);
                }
            }
        }
    }
}

impl OllamaProvider {
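    /// Creates a provider for the given base URL; the URL must be non-empty.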
    pub fn new(base_url: String) -> Result<Self, ProviderError> {
        if base_url.is_empty() {
            return Err(ProviderError::ConfigError(
                "Ollama base URL is required".to_string(),
            ));
        }

        Ok(Self {
            client: Arc::new(Client::new()),
            base_url,
            available_models: vec![],
            model_cache: Arc::new(tokio::sync::Mutex::new(ModelCache::new())),
        })
    }

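    /// Creates a provider pointing at Ollama's default local endpoint.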
    pub fn with_default_endpoint() -> Result<Self, ProviderError> {
        Self::new("http://localhost:11434".to_string())
    }

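    /// Creates a provider using the base URL from `OllamaConfig::load_with_precedence`.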
    pub fn from_config() -> Result<Self, ProviderError> {
        let config = OllamaConfig::load_with_precedence()?;
        debug!(
            "Creating OllamaProvider from configuration: base_url={}, default_model={}",
            config.base_url, config.default_model
        );
        Self::new(config.base_url)
    }

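    /// Loads the Ollama configuration with the same precedence rules as `from_config`.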
    pub fn config(&self) -> Result<OllamaConfig, ProviderError> {
        OllamaConfig::load_with_precedence()
    }

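    /// Returns true if the Ollama server answers the health check, logging the outcome either way.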
    pub async fn detect_availability(&self) -> bool {
        debug!("Detecting Ollama availability at {}", self.base_url);

        match self.health_check().await {
            Ok(true) => {
                info!("Ollama is available at {}", self.base_url);
                true
            }
            Ok(false) => {
                warn!("Ollama health check returned false at {}", self.base_url);
                false
            }
            Err(e) => {
                warn!("Ollama is not available at {}: {}", self.base_url, e);
                false
            }
        }
    }

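    /// Returns models from the cache, falling back to stale entries and finally to a
    /// built-in default list when nothing has been fetched (offline mode).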
    pub async fn get_models_with_fallback(&self) -> Vec<ModelInfo> {
        let cache = self.model_cache.lock().await;

        if let Some(cached_models) = cache.get() {
            debug!("Returning cached models ({} models)", cached_models.len());
            return cached_models;
        }

        if let Some(stale_models) = cache.get_stale() {
            warn!(
                "Returning stale cached models ({} models) - cache expired",
                stale_models.len()
            );
            return stale_models;
        }

        debug!("No cached models available, returning defaults for offline mode");
        vec![
            ModelInfo {
                id: "mistral".to_string(),
                name: "Mistral".to_string(),
                provider: "ollama".to_string(),
                context_window: 8192,
                capabilities: vec![Capability::Chat, Capability::Code, Capability::Streaming],
                pricing: None,
            },
            ModelInfo {
                id: "neural-chat".to_string(),
                name: "Neural Chat".to_string(),
                provider: "ollama".to_string(),
                context_window: 4096,
                capabilities: vec![Capability::Chat, Capability::Code, Capability::Streaming],
                pricing: None,
            },
            ModelInfo {
                id: "llama2".to_string(),
                name: "Llama 2".to_string(),
                provider: "ollama".to_string(),
                context_window: 4096,
                capabilities: vec![Capability::Chat, Capability::Code, Capability::Streaming],
                pricing: None,
            },
        ]
    }

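    /// Fetches the model list from `/api/tags`, serving from the cache when it is still
    /// valid and updating it after a successful fetch.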
    pub async fn fetch_models(&mut self) -> Result<(), ProviderError> {
        debug!("Fetching available models from Ollama");

        let cache = self.model_cache.lock().await;
        if let Some(cached_models) = cache.get() {
            debug!("Using cached models ({} models)", cached_models.len());
            self.available_models = cached_models;
            return Ok(());
        }

        drop(cache);

        let base_url = self.base_url.clone();
        let client = self.client.clone();

        let response = execute_with_retry(|| {
            let client = client.clone();
            let url = format!("{}/api/tags", base_url);
            async move { client.get(url).send().await }
        })
        .await
        .map_err(|e| {
            error!("Failed to fetch models from Ollama after retries: {}", e);
            ProviderError::NetworkError
        })?;

        if !response.status().is_success() {
            return Err(ProviderError::ProviderError(format!(
                "Ollama API error: {}",
                response.status()
            )));
        }

        let tags_response: OllamaTagsResponse = response.json().await.map_err(|e| {
            error!("Failed to parse Ollama tags response: {}", e);
            ProviderError::ProviderError(format!("Failed to parse Ollama response: {}", e))
        })?;

        self.available_models = tags_response
            .models
            .unwrap_or_default()
            .into_iter()
            .map(|model| ModelInfo {
                id: model.name.clone(),
                name: model.name.clone(),
                provider: "ollama".to_string(),
                context_window: 4096,
                capabilities: vec![Capability::Chat, Capability::Code, Capability::Streaming],
                pricing: None,
            })
            .collect();

        let mut cache = self.model_cache.lock().await;
        cache.set(self.available_models.clone());

        debug!("Fetched {} models from Ollama", self.available_models.len());
        Ok(())
    }

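    /// Converts an Ollama chat response into the shared `ChatResponse` type; token usage
    /// is not extracted from the response, so all counts are reported as zero.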
    fn convert_response(
        response: OllamaChatResponse,
        model: String,
    ) -> Result<ChatResponse, ProviderError> {
        Ok(ChatResponse {
            content: response.message.content,
            model,
            usage: TokenUsage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
            finish_reason: if response.done {
                FinishReason::Stop
            } else {
                FinishReason::Error
            },
        })
    }
}

#[async_trait]
impl Provider for OllamaProvider {
    fn id(&self) -> &str {
        "ollama"
    }

    fn name(&self) -> &str {
        "Ollama"
    }

    fn models(&self) -> Vec<ModelInfo> {
        if self.available_models.is_empty() {
            vec![
                ModelInfo {
                    id: "mistral".to_string(),
                    name: "Mistral".to_string(),
                    provider: "ollama".to_string(),
                    context_window: 8192,
                    capabilities: vec![Capability::Chat, Capability::Code, Capability::Streaming],
                    pricing: None,
                },
                ModelInfo {
                    id: "neural-chat".to_string(),
                    name: "Neural Chat".to_string(),
                    provider: "ollama".to_string(),
                    context_window: 4096,
                    capabilities: vec![Capability::Chat, Capability::Code, Capability::Streaming],
                    pricing: None,
                },
                ModelInfo {
                    id: "llama2".to_string(),
                    name: "Llama 2".to_string(),
                    provider: "ollama".to_string(),
                    context_window: 4096,
                    capabilities: vec![Capability::Chat, Capability::Code, Capability::Streaming],
                    pricing: None,
                },
            ]
        } else {
            self.available_models.clone()
        }
    }

    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
        debug!(
            "Sending chat request to Ollama for model: {}",
            request.model
        );

        let ollama_request = OllamaChatRequest {
            model: request.model.clone(),
            messages: request
                .messages
                .iter()
                .map(|m| OllamaMessage {
                    role: m.role.clone(),
                    content: m.content.clone(),
                })
                .collect(),
            stream: false,
        };

        let base_url = self.base_url.clone();
        let client = self.client.clone();

        let response = execute_with_retry(|| {
            let client = client.clone();
            let url = format!("{}/api/chat", base_url);
            let req = ollama_request.clone();
            async move { client.post(url).json(&req).send().await }
        })
        .await
        .map_err(|e| {
            error!("Ollama API request failed after retries: {}", e);
            ProviderError::NetworkError
        })?;

        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            error!("Ollama API error ({}): {}", status, error_text);

            return Err(ProviderError::ProviderError(format!(
                "Ollama API error: {}",
                status
            )));
        }

        let ollama_response: OllamaChatResponse = response.json().await.map_err(|e| {
            error!("Failed to parse Ollama response: {}", e);
            ProviderError::ProviderError(format!("Failed to parse Ollama response: {}", e))
        })?;

        Self::convert_response(ollama_response, request.model)
    }

    async fn chat_stream(
        &self,
        request: ChatRequest,
    ) -> Result<crate::provider::ChatStream, ProviderError> {
        debug!(
            "Starting streaming chat request to Ollama for model: {}",
            request.model
        );

        let ollama_request = OllamaChatRequest {
            model: request.model.clone(),
            messages: request
                .messages
                .iter()
                .map(|m| OllamaMessage {
                    role: m.role.clone(),
                    content: m.content.clone(),
                })
                .collect(),
            stream: true,
        };

        let base_url = self.base_url.clone();
        let client = self.client.clone();
        let model = request.model.clone();

        let response = execute_with_retry(|| {
            let client = client.clone();
            let url = format!("{}/api/chat", base_url);
            let req = ollama_request.clone();
            async move { client.post(url).json(&req).send().await }
        })
        .await
        .map_err(|e| {
            error!("Ollama streaming request failed after retries: {}", e);
            ProviderError::NetworkError
        })?;

        let status = response.status();
        if !status.is_success() {
            return Err(ProviderError::ProviderError(format!(
                "Ollama API error: {}",
                status
            )));
        }

        let body = response.text().await.map_err(|e| {
            error!("Failed to read streaming response body: {}", e);
            ProviderError::NetworkError
        })?;

        let responses: Vec<Result<ChatResponse, ProviderError>> = body
            .lines()
            .filter(|line| !line.is_empty())
            .map(
                |line| match serde_json::from_str::<OllamaChatResponse>(line) {
                    Ok(ollama_response) => Ok(ChatResponse {
                        content: ollama_response.message.content,
                        model: model.clone(),
                        usage: TokenUsage {
                            prompt_tokens: 0,
                            completion_tokens: 0,
                            total_tokens: 0,
                        },
                        finish_reason: if ollama_response.done {
                            FinishReason::Stop
                        } else {
                            FinishReason::Error
                        },
                    }),
                    Err(e) => {
                        debug!("Failed to parse streaming response line: {}", e);
                        Err(ProviderError::ProviderError(format!(
                            "Failed to parse streaming response: {}",
                            e
                        )))
                    }
                },
            )
            .collect();

        let chat_stream = futures::stream::iter(responses);
        Ok(Box::new(chat_stream))
    }

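    /// Estimates the token count with a rough heuristic of about four characters per token.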
    fn count_tokens(&self, content: &str, _model: &str) -> Result<usize, ProviderError> {
        let token_count = content.len().div_ceil(4);
        Ok(token_count)
    }

    async fn health_check(&self) -> Result<bool, ProviderError> {
        debug!("Performing health check for Ollama provider");

        let base_url = self.base_url.clone();
        let client = self.client.clone();

        let response = execute_with_retry(|| {
            let client = client.clone();
            let url = format!("{}/api/tags", base_url);
            async move { client.get(url).send().await }
        })
        .await
        .map_err(|e| {
            warn!("Ollama health check failed after retries: {}", e);
            ProviderError::NetworkError
        })?;

        match response.status().as_u16() {
            200 => {
                debug!("Ollama health check passed");
                Ok(true)
            }
            _ => {
                warn!(
                    "Ollama health check failed with status: {}",
                    response.status()
                );
                Ok(false)
            }
        }
    }
}

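// Wire types mirroring the Ollama HTTP API payloads for `/api/chat` and `/api/tags`.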
#[derive(Debug, Serialize, Clone)]
struct OllamaChatRequest {
    model: String,
    messages: Vec<OllamaMessage>,
    stream: bool,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
struct OllamaMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct OllamaChatResponse {
    message: OllamaResponseMessage,
    done: bool,
}

#[derive(Debug, Deserialize)]
struct OllamaResponseMessage {
    #[allow(dead_code)]
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct OllamaTagsResponse {
    models: Option<Vec<OllamaModel>>,
}

#[derive(Debug, Deserialize, Clone)]
struct OllamaModel {
    name: String,
}