Skip to main content

manx_cli/rag/providers/
ollama.rs

1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5
6use super::{EmbeddingProvider as ProviderTrait, ProviderInfo};
7
8/// Ollama API embedding provider
9pub struct OllamaProvider {
10    client: Client,
11    base_url: String,
12    model: String,
13    dimension: Option<usize>, // Cached dimension
14}
15
16#[derive(Serialize)]
17struct OllamaEmbeddingRequest {
18    model: String,
19    prompt: String,
20}
21
22#[derive(Deserialize)]
23struct OllamaEmbeddingResponse {
24    embedding: Vec<f32>,
25}
26
27#[derive(Serialize)]
28#[allow(dead_code)]
29struct OllamaShowRequest {
30    name: String,
31}
32
33#[derive(Deserialize)]
34#[allow(dead_code)]
35pub struct OllamaShowResponse {
36    pub details: Option<ModelDetails>,
37}
38
39#[derive(Deserialize)]
40#[allow(dead_code)]
41pub struct ModelDetails {
42    pub parameter_size: Option<String>,
43}
44
45impl OllamaProvider {
46    /// Create a new Ollama provider
47    pub fn new(model: String, base_url: Option<String>) -> Self {
48        let client = Client::builder()
49            .timeout(std::time::Duration::from_secs(60))
50            .build()
51            .unwrap();
52
53        let base_url = base_url.unwrap_or_else(|| "http://localhost:11434".to_string());
54
55        Self {
56            client,
57            base_url,
58            model,
59            dimension: None,
60        }
61    }
62
63    /// Detect dimension by making a test API call
64    #[allow(dead_code)]
65    pub async fn detect_dimension(&mut self) -> Result<usize> {
66        if let Some(dim) = self.dimension {
67            return Ok(dim);
68        }
69
70        log::info!(
71            "Detecting embedding dimension for Ollama model: {}",
72            self.model
73        );
74
75        let test_embedding = self.call_api("test").await?;
76        let dimension = test_embedding.len();
77
78        self.dimension = Some(dimension);
79        log::info!("Detected dimension: {} for model {}", dimension, self.model);
80
81        Ok(dimension)
82    }
83
84    /// Get model information from Ollama
85    #[allow(dead_code)]
86    pub async fn get_model_info(&self) -> Result<OllamaShowResponse> {
87        let request = OllamaShowRequest {
88            name: self.model.clone(),
89        };
90
91        let url = format!("{}/api/show", self.base_url);
92
93        let response = self
94            .client
95            .post(&url)
96            .header("Content-Type", "application/json")
97            .json(&request)
98            .send()
99            .await?;
100
101        let status = response.status();
102        if !status.is_success() {
103            let error_text = response.text().await.unwrap_or_default();
104            return Err(anyhow!(
105                "Ollama show API error: HTTP {} - {}",
106                status,
107                error_text
108            ));
109        }
110
111        let show_response: OllamaShowResponse = response.json().await?;
112        Ok(show_response)
113    }
114
115    /// Make API call to Ollama embeddings endpoint
116    async fn call_api(&self, text: &str) -> Result<Vec<f32>> {
117        let request = OllamaEmbeddingRequest {
118            model: self.model.clone(),
119            prompt: text.to_string(),
120        };
121
122        let url = format!("{}/api/embeddings", self.base_url);
123
124        let response = self
125            .client
126            .post(&url)
127            .header("Content-Type", "application/json")
128            .json(&request)
129            .send()
130            .await?;
131
132        let status = response.status();
133        if !status.is_success() {
134            let error_text = response.text().await.unwrap_or_default();
135            return Err(anyhow!(
136                "Ollama API error: HTTP {} - {}",
137                status,
138                error_text
139            ));
140        }
141
142        let embedding_response: OllamaEmbeddingResponse = response.json().await?;
143
144        if embedding_response.embedding.is_empty() {
145            return Err(anyhow!("No embeddings returned from Ollama API"));
146        }
147
148        Ok(embedding_response.embedding)
149    }
150
151    /// Check if Ollama server is available
152    pub async fn check_server(&self) -> Result<()> {
153        let url = format!("{}/api/version", self.base_url);
154
155        let response = self.client.get(&url).send().await.map_err(|e| {
156            anyhow!(
157                "Failed to connect to Ollama server at {}: {}",
158                self.base_url,
159                e
160            )
161        })?;
162
163        if !response.status().is_success() {
164            return Err(anyhow!(
165                "Ollama server returned error: HTTP {}",
166                response.status()
167            ));
168        }
169
170        Ok(())
171    }
172
173    /// Get common Ollama model information (dimension estimates)
174    pub fn get_common_model_info(model: &str) -> (Option<usize>, usize) {
175        match model {
176            "nomic-embed-text" => (Some(768), 2048),
177            "mxbai-embed-large" => (Some(1024), 512),
178            "all-minilm" => (Some(384), 512),
179            _ => (None, 2048), // Unknown model, will detect dynamically
180        }
181    }
182}
183
184#[async_trait]
185impl ProviderTrait for OllamaProvider {
186    async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
187        if text.trim().is_empty() {
188            return Err(anyhow!("Cannot embed empty text"));
189        }
190
191        // Ollama typically handles longer texts well, but let's be conservative
192        let (_, max_chars) = Self::get_common_model_info(&self.model);
193        let truncated_text = if text.len() > max_chars * 4 {
194            // Rough token approximation
195            &text[..max_chars * 4]
196        } else {
197            text
198        };
199
200        self.call_api(truncated_text).await
201    }
202
203    async fn get_dimension(&self) -> Result<usize> {
204        if let Some(dim) = self.dimension {
205            Ok(dim)
206        } else {
207            // Try to use known dimensions for common models
208            let (known_dim, _) = Self::get_common_model_info(&self.model);
209            if let Some(dim) = known_dim {
210                Ok(dim)
211            } else {
212                // Need to detect dynamically
213                Err(anyhow!(
214                    "Dimension not known for model {}. Use 'manx embedding test' to detect it.",
215                    self.model
216                ))
217            }
218        }
219    }
220
221    async fn health_check(&self) -> Result<()> {
222        self.check_server().await?;
223        self.call_api("test").await.map(|_| ())
224    }
225
226    fn get_info(&self) -> ProviderInfo {
227        let (_, max_length) = Self::get_common_model_info(&self.model);
228
229        ProviderInfo {
230            name: "Ollama Local Server".to_string(),
231            provider_type: "ollama".to_string(),
232            model_name: Some(self.model.clone()),
233            description: format!("Ollama model: {} ({})", self.model, self.base_url),
234            max_input_length: Some(max_length),
235        }
236    }
237
238    fn as_any(&self) -> &dyn std::any::Any {
239        self
240    }
241}