ricecoder_providers/providers/ollama.rs

//! Ollama provider implementation
//!
//! Supports local model execution via Ollama.
//! Ollama allows running large language models locally without sending code to external services.
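//!
//! A minimal usage sketch (the crate path in the `use` line is an assumption,
//! and a locally running Ollama instance is assumed; `chat` is omitted because
//! constructing a full `ChatRequest` depends on `crate::models`):
//!
//! ```ignore
//! use ricecoder_providers::providers::ollama::OllamaProvider;
//!
//! async fn demo() {
//!     if let Ok(provider) = OllamaProvider::with_default_endpoint() {
//!         if provider.detect_availability().await {
//!             let models = provider.get_models_with_fallback().await;
//!             println!("Ollama offers {} models", models.len());
//!         }
//!     }
//! }
//! ```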

use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::time::sleep;
use tracing::{debug, error, info, warn};

use super::ollama_config::OllamaConfig;
use crate::error::ProviderError;
use crate::models::{Capability, ChatRequest, ChatResponse, FinishReason, ModelInfo, TokenUsage};
use crate::provider::Provider;

/// Configuration for retry logic
const MAX_RETRIES: u32 = 3;
const INITIAL_BACKOFF_MS: u64 = 100;
const MAX_BACKOFF_MS: u64 = 400;

/// Cache for models with TTL
struct ModelCache {
    models: Option<Vec<ModelInfo>>,
    cached_at: Option<SystemTime>,
    ttl: Duration,
}

impl ModelCache {
    /// Create a new model cache with default TTL (5 minutes)
    fn new() -> Self {
        Self {
            models: None,
            cached_at: None,
            ttl: Duration::from_secs(300), // 5 minutes default
        }
    }

    /// Create a new model cache with custom TTL
    /// Reserved for future use when configurable TTL is needed
    #[allow(dead_code)]
    fn with_ttl(ttl: Duration) -> Self {
        Self {
            models: None,
            cached_at: None,
            ttl,
        }
    }

    /// Check if cache is still valid
    fn is_valid(&self) -> bool {
        if let (Some(cached_at), Some(_)) = (self.cached_at, &self.models) {
            if let Ok(elapsed) = cached_at.elapsed() {
                return elapsed < self.ttl;
            }
        }
        false
    }

    /// Get cached models if valid
    fn get(&self) -> Option<Vec<ModelInfo>> {
        if self.is_valid() {
            self.models.clone()
        } else {
            None
        }
    }

    /// Set cached models
    fn set(&mut self, models: Vec<ModelInfo>) {
        self.models = Some(models);
        self.cached_at = Some(SystemTime::now());
    }

    /// Get cached models even if expired (for fallback)
    fn get_stale(&self) -> Option<Vec<ModelInfo>> {
        self.models.clone()
    }

    /// Clear the cache
    /// Reserved for future use when cache invalidation is needed
    #[allow(dead_code)]
    fn clear(&mut self) {
        self.models = None;
        self.cached_at = None;
    }
}

/// Ollama provider implementation
pub struct OllamaProvider {
    client: Arc<Client>,
    base_url: String,
    available_models: Vec<ModelInfo>,
    model_cache: Arc<tokio::sync::Mutex<ModelCache>>,
}

/// Helper function to determine if an error is transient (retryable)
fn is_transient_error(err: &reqwest::Error) -> bool {
    err.is_timeout() || err.is_connect() || err.status().is_some_and(|s| s.is_server_error())
}

/// Execute a request with exponential backoff retry logic
/// Returns the response if successful, or the last error if all retries fail
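///
/// A usage sketch mirroring how this helper is called elsewhere in this file
/// (the closure clones its captures so it can be retried):
///
/// ```ignore
/// let response = execute_with_retry(|| {
///     let client = client.clone();
///     let url = format!("{}/api/tags", base_url);
///     async move { client.get(url).send().await }
/// })
/// .await?;
/// ```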
async fn execute_with_retry<F, Fut>(mut request_fn: F) -> Result<reqwest::Response, reqwest::Error>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<reqwest::Response, reqwest::Error>>,
{
    let mut attempt = 0;

    loop {
        match request_fn().await {
            Ok(response) => return Ok(response),
            Err(err) => {
                // Check if error is transient and we haven't exceeded max retries
                if is_transient_error(&err) && attempt < MAX_RETRIES {
                    // Calculate exponential backoff: 100ms, 200ms, 400ms
                    let backoff_ms = INITIAL_BACKOFF_MS * 2_u64.pow(attempt);
                    let backoff_ms = backoff_ms.min(MAX_BACKOFF_MS);

                    warn!(
                        "Transient error on attempt {}/{}, retrying after {}ms: {}",
                        attempt + 1,
                        MAX_RETRIES,
                        backoff_ms,
                        err
                    );

                    sleep(Duration::from_millis(backoff_ms)).await;
                    attempt += 1;
                } else {
                    // Permanent error or max retries exceeded
                    if attempt >= MAX_RETRIES {
                        debug!("Max retries ({}) exceeded for request", MAX_RETRIES);
                    }
                    return Err(err);
                }
            }
        }
    }
}

impl OllamaProvider {
    /// Create a new Ollama provider instance
    pub fn new(base_url: String) -> Result<Self, ProviderError> {
        if base_url.is_empty() {
            return Err(ProviderError::ConfigError(
                "Ollama base URL is required".to_string(),
            ));
        }

        Ok(Self {
            client: Arc::new(Client::new()),
            base_url,
            available_models: vec![],
            model_cache: Arc::new(tokio::sync::Mutex::new(ModelCache::new())),
        })
    }

    /// Create a new Ollama provider with default localhost endpoint
    pub fn with_default_endpoint() -> Result<Self, ProviderError> {
        Self::new("http://localhost:11434".to_string())
    }

    /// Create a new Ollama provider from configuration files
    /// Loads configuration with proper precedence:
    /// 1. Environment variables (highest priority)
    /// 2. Project config (.ricecoder/config.yaml)
    /// 3. Global config (~/.ricecoder/config.yaml)
    /// 4. Built-in defaults (lowest priority)
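    ///
    /// A sketch of what a config file might contain, assuming the YAML keys
    /// mirror the `OllamaConfig` fields used here (`base_url`, `default_model`):
    ///
    /// ```yaml
    /// base_url: "http://localhost:11434"
    /// default_model: "mistral"
    /// ```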
    pub fn from_config() -> Result<Self, ProviderError> {
        let config = OllamaConfig::load_with_precedence()?;
        debug!(
            "Creating OllamaProvider from configuration: base_url={}, default_model={}",
            config.base_url, config.default_model
        );
        Self::new(config.base_url)
    }

    /// Get the current configuration
    pub fn config(&self) -> Result<OllamaConfig, ProviderError> {
        OllamaConfig::load_with_precedence()
    }

    /// Detect if Ollama is available at startup
    /// Returns true if Ollama is running and accessible
    pub async fn detect_availability(&self) -> bool {
        debug!("Detecting Ollama availability at {}", self.base_url);

        match self.health_check().await {
            Ok(true) => {
                info!("Ollama is available at {}", self.base_url);
                true
            }
            Ok(false) => {
                warn!("Ollama health check returned false at {}", self.base_url);
                false
            }
            Err(e) => {
                warn!("Ollama is not available at {}: {}", self.base_url, e);
                false
            }
        }
    }

    /// Get models with offline fallback
    /// Returns cached models if available, or default models if offline
    pub async fn get_models_with_fallback(&self) -> Vec<ModelInfo> {
        let cache = self.model_cache.lock().await;

        // Try to get valid cached models
        if let Some(cached_models) = cache.get() {
            debug!("Returning cached models ({} models)", cached_models.len());
            return cached_models;
        }

        // Try to get stale cached models as fallback
        if let Some(stale_models) = cache.get_stale() {
            warn!(
                "Returning stale cached models ({} models) - cache expired",
                stale_models.len()
            );
            return stale_models;
        }

        // No cache available, return default models
        debug!("No cached models available, returning defaults for offline mode");
        vec![
            ModelInfo {
                id: "mistral".to_string(),
                name: "Mistral".to_string(),
                provider: "ollama".to_string(),
                context_window: 8192,
                capabilities: vec![Capability::Chat, Capability::Code, Capability::Streaming],
                pricing: None,
            },
            ModelInfo {
                id: "neural-chat".to_string(),
                name: "Neural Chat".to_string(),
                provider: "ollama".to_string(),
                context_window: 4096,
                capabilities: vec![Capability::Chat, Capability::Code, Capability::Streaming],
                pricing: None,
            },
            ModelInfo {
                id: "llama2".to_string(),
                name: "Llama 2".to_string(),
                provider: "ollama".to_string(),
                context_window: 4096,
                capabilities: vec![Capability::Chat, Capability::Code, Capability::Streaming],
                pricing: None,
            },
        ]
    }

    /// Fetch available models from Ollama with caching
    /// Returns cached models if available and not expired
    /// Falls back to cached models if Ollama is unavailable
    pub async fn fetch_models(&mut self) -> Result<(), ProviderError> {
        debug!("Fetching available models from Ollama");

        // Check if cache is valid
        let cache = self.model_cache.lock().await;
        if let Some(cached_models) = cache.get() {
            debug!("Using cached models ({} models)", cached_models.len());
            self.available_models = cached_models;
            return Ok(());
        }

        // Cache is invalid, fetch from Ollama
        drop(cache); // Release lock before making network request

        let base_url = self.base_url.clone();
        let client = self.client.clone();

        let response = execute_with_retry(|| {
            let client = client.clone();
            let url = format!("{}/api/tags", base_url);
            async move { client.get(url).send().await }
        })
        .await
        .map_err(|e| {
            error!("Failed to fetch models from Ollama after retries: {}", e);
            ProviderError::NetworkError
        })?;

        if !response.status().is_success() {
            return Err(ProviderError::ProviderError(format!(
                "Ollama API error: {}",
                response.status()
            )));
        }

        let tags_response: OllamaTagsResponse = response.json().await.map_err(|e| {
            error!("Failed to parse Ollama tags response: {}", e);
            ProviderError::ProviderError(format!("Failed to parse Ollama response: {}", e))
        })?;

        // Convert Ollama models to our ModelInfo format
        self.available_models = tags_response
            .models
            .unwrap_or_default()
            .into_iter()
            .map(|model| ModelInfo {
                id: model.name.clone(),
                name: model.name.clone(),
                provider: "ollama".to_string(),
                context_window: 4096, // Default context window for local models
                capabilities: vec![Capability::Chat, Capability::Code, Capability::Streaming],
                pricing: None, // Local models have no pricing
            })
            .collect();

        // Update cache
        let mut cache = self.model_cache.lock().await;
        cache.set(self.available_models.clone());

        debug!("Fetched {} models from Ollama", self.available_models.len());
        Ok(())
    }

    /// Convert Ollama API response to our ChatResponse
    fn convert_response(
        response: OllamaChatResponse,
        model: String,
    ) -> Result<ChatResponse, ProviderError> {
        Ok(ChatResponse {
            content: response.message.content,
            model,
            usage: TokenUsage {
                prompt_tokens: 0, // Token usage is not parsed from Ollama's response here
                completion_tokens: 0,
                total_tokens: 0,
            },
            finish_reason: if response.done {
                FinishReason::Stop
            } else {
                FinishReason::Error
            },
        })
    }
}

#[async_trait]
impl Provider for OllamaProvider {
    fn id(&self) -> &str {
        "ollama"
    }

    fn name(&self) -> &str {
        "Ollama"
    }

    fn models(&self) -> Vec<ModelInfo> {
        if self.available_models.is_empty() {
            // Return some common Ollama models as defaults
            vec![
                ModelInfo {
                    id: "mistral".to_string(),
                    name: "Mistral".to_string(),
                    provider: "ollama".to_string(),
                    context_window: 8192,
                    capabilities: vec![Capability::Chat, Capability::Code, Capability::Streaming],
                    pricing: None,
                },
                ModelInfo {
                    id: "neural-chat".to_string(),
                    name: "Neural Chat".to_string(),
                    provider: "ollama".to_string(),
                    context_window: 4096,
                    capabilities: vec![Capability::Chat, Capability::Code, Capability::Streaming],
                    pricing: None,
                },
                ModelInfo {
                    id: "llama2".to_string(),
                    name: "Llama 2".to_string(),
                    provider: "ollama".to_string(),
                    context_window: 4096,
                    capabilities: vec![Capability::Chat, Capability::Code, Capability::Streaming],
                    pricing: None,
                },
            ]
        } else {
            self.available_models.clone()
        }
    }

    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
        debug!(
            "Sending chat request to Ollama for model: {}",
            request.model
        );

        let ollama_request = OllamaChatRequest {
            model: request.model.clone(),
            messages: request
                .messages
                .iter()
                .map(|m| OllamaMessage {
                    role: m.role.clone(),
                    content: m.content.clone(),
                })
                .collect(),
            stream: false,
        };

        let base_url = self.base_url.clone();
        let client = self.client.clone();

        let response = execute_with_retry(|| {
            let client = client.clone();
            let url = format!("{}/api/chat", base_url);
            let req = ollama_request.clone();
            async move { client.post(url).json(&req).send().await }
        })
        .await
        .map_err(|e| {
            error!("Ollama API request failed after retries: {}", e);
            ProviderError::NetworkError
        })?;

        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            error!("Ollama API error ({}): {}", status, error_text);

            return Err(ProviderError::ProviderError(format!(
                "Ollama API error: {}",
                status
            )));
        }

        let ollama_response: OllamaChatResponse = response.json().await.map_err(|e| {
            error!("Failed to parse Ollama response: {}", e);
            ProviderError::ProviderError(format!("Failed to parse Ollama response: {}", e))
        })?;

        Self::convert_response(ollama_response, request.model)
    }

    async fn chat_stream(
        &self,
        request: ChatRequest,
    ) -> Result<crate::provider::ChatStream, ProviderError> {
        debug!(
            "Starting streaming chat request to Ollama for model: {}",
            request.model
        );

        let ollama_request = OllamaChatRequest {
            model: request.model.clone(),
            messages: request
                .messages
                .iter()
                .map(|m| OllamaMessage {
                    role: m.role.clone(),
                    content: m.content.clone(),
                })
                .collect(),
            stream: true,
        };

        let base_url = self.base_url.clone();
        let client = self.client.clone();
        let model = request.model.clone();

        let response = execute_with_retry(|| {
            let client = client.clone();
            let url = format!("{}/api/chat", base_url);
            let req = ollama_request.clone();
            async move { client.post(url).json(&req).send().await }
        })
        .await
        .map_err(|e| {
            error!("Ollama streaming request failed after retries: {}", e);
            ProviderError::NetworkError
        })?;

        let status = response.status();
        if !status.is_success() {
            return Err(ProviderError::ProviderError(format!(
                "Ollama API error: {}",
                status
            )));
        }

        // Read the entire response body up front and parse it line by line;
        // the parsed responses are then wrapped in a stream for the caller
        let body = response.text().await.map_err(|e| {
            error!("Failed to read streaming response body: {}", e);
            ProviderError::NetworkError
        })?;

        // Parse each line as a JSON object and create a stream
        let responses: Vec<Result<ChatResponse, ProviderError>> = body
            .lines()
            .filter(|line| !line.is_empty())
            .map(
                |line| match serde_json::from_str::<OllamaChatResponse>(line) {
                    Ok(ollama_response) => Ok(ChatResponse {
                        content: ollama_response.message.content,
                        model: model.clone(),
                        usage: TokenUsage {
                            prompt_tokens: 0,
                            completion_tokens: 0,
                            total_tokens: 0,
                        },
                        finish_reason: if ollama_response.done {
                            FinishReason::Stop
                        } else {
                            FinishReason::Error
                        },
                    }),
                    Err(e) => {
                        debug!("Failed to parse streaming response line: {}", e);
                        Err(ProviderError::ProviderError(format!(
                            "Failed to parse streaming response: {}",
                            e
                        )))
                    }
                },
            )
            .collect();

        // Convert to a stream
        let chat_stream = futures::stream::iter(responses);
        Ok(Box::new(chat_stream))
    }

    fn count_tokens(&self, content: &str, _model: &str) -> Result<usize, ProviderError> {
        // Ollama doesn't provide an exact token counting API
        // Use a reasonable approximation: 1 token ≈ 4 characters
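        // e.g. a 100-character string is estimated at ceil(100 / 4) = 25 tokens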
        let token_count = content.len().div_ceil(4);
        Ok(token_count)
    }

    async fn health_check(&self) -> Result<bool, ProviderError> {
        debug!("Performing health check for Ollama provider");

        let base_url = self.base_url.clone();
        let client = self.client.clone();

        let response = execute_with_retry(|| {
            let client = client.clone();
            let url = format!("{}/api/tags", base_url);
            async move { client.get(url).send().await }
        })
        .await
        .map_err(|e| {
            warn!("Ollama health check failed after retries: {}", e);
            ProviderError::NetworkError
        })?;

        match response.status().as_u16() {
            200 => {
                debug!("Ollama health check passed");
                Ok(true)
            }
            _ => {
                warn!(
                    "Ollama health check failed with status: {}",
                    response.status()
                );
                Ok(false)
            }
        }
    }
}

/// Ollama API chat request format
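///
/// A request serialized from this struct looks roughly like:
///
/// ```json
/// {"model": "mistral", "messages": [{"role": "user", "content": "Hi"}], "stream": false}
/// ```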
#[derive(Debug, Serialize, Clone)]
struct OllamaChatRequest {
    model: String,
    messages: Vec<OllamaMessage>,
    stream: bool,
}

/// Ollama API message format
#[derive(Debug, Serialize, Deserialize, Clone)]
struct OllamaMessage {
    role: String,
    content: String,
}

/// Ollama API chat response format
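///
/// Only the fields deserialized here are shown; a non-streaming response is
/// shaped roughly like:
///
/// ```json
/// {"message": {"role": "assistant", "content": "Hello!"}, "done": true}
/// ```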
#[derive(Debug, Deserialize)]
struct OllamaChatResponse {
    message: OllamaResponseMessage,
    done: bool,
}

/// Ollama API response message format
#[derive(Debug, Deserialize)]
struct OllamaResponseMessage {
    #[allow(dead_code)]
    role: String,
    content: String,
}

/// Ollama API tags response format
#[derive(Debug, Deserialize)]
struct OllamaTagsResponse {
    models: Option<Vec<OllamaModel>>,
}

/// Ollama model information
#[derive(Debug, Deserialize, Clone)]
struct OllamaModel {
    name: String,
}