manx_cli/rag/providers/
ollama.rs1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5
6use super::{EmbeddingProvider as ProviderTrait, ProviderInfo};
7
/// Embedding provider backed by a locally running Ollama server,
/// talking to its HTTP API (`/api/embeddings`, `/api/show`, `/api/version`).
pub struct OllamaProvider {
    // Reused HTTP client (configured with a 60s timeout in `new`).
    client: Client,
    // Base URL of the Ollama server, e.g. "http://localhost:11434".
    base_url: String,
    // Name of the embedding model to request, e.g. "nomic-embed-text".
    model: String,
    // Cached embedding dimension; `None` until detected via `detect_dimension`.
    dimension: Option<usize>,
}
15
/// JSON request body for Ollama's `/api/embeddings` endpoint.
#[derive(Serialize)]
struct OllamaEmbeddingRequest {
    // Model name as known to the Ollama server.
    model: String,
    // Text to embed.
    prompt: String,
}
21
/// JSON response body from Ollama's `/api/embeddings` endpoint.
#[derive(Deserialize)]
struct OllamaEmbeddingResponse {
    // The embedding vector; empty on some failure modes (checked by callers).
    embedding: Vec<f32>,
}
26
/// JSON request body for Ollama's `/api/show` endpoint (model metadata lookup).
#[derive(Serialize)]
#[allow(dead_code)]
struct OllamaShowRequest {
    // Model name to look up.
    name: String,
}
32
/// Subset of the `/api/show` response that this provider cares about.
#[derive(Deserialize)]
#[allow(dead_code)]
pub struct OllamaShowResponse {
    // Model details; optional because older servers may omit the field.
    pub details: Option<ModelDetails>,
}
38
/// Model detail fields reported by `/api/show`.
#[derive(Deserialize)]
#[allow(dead_code)]
pub struct ModelDetails {
    // Human-readable parameter count, e.g. "7B"; optional in the API response.
    pub parameter_size: Option<String>,
}
44
45impl OllamaProvider {
46 pub fn new(model: String, base_url: Option<String>) -> Self {
48 let client = Client::builder()
49 .timeout(std::time::Duration::from_secs(60))
50 .build()
51 .unwrap();
52
53 let base_url = base_url.unwrap_or_else(|| "http://localhost:11434".to_string());
54
55 Self {
56 client,
57 base_url,
58 model,
59 dimension: None,
60 }
61 }
62
63 #[allow(dead_code)]
65 pub async fn detect_dimension(&mut self) -> Result<usize> {
66 if let Some(dim) = self.dimension {
67 return Ok(dim);
68 }
69
70 log::info!(
71 "Detecting embedding dimension for Ollama model: {}",
72 self.model
73 );
74
75 let test_embedding = self.call_api("test").await?;
76 let dimension = test_embedding.len();
77
78 self.dimension = Some(dimension);
79 log::info!("Detected dimension: {} for model {}", dimension, self.model);
80
81 Ok(dimension)
82 }
83
84 #[allow(dead_code)]
86 pub async fn get_model_info(&self) -> Result<OllamaShowResponse> {
87 let request = OllamaShowRequest {
88 name: self.model.clone(),
89 };
90
91 let url = format!("{}/api/show", self.base_url);
92
93 let response = self
94 .client
95 .post(&url)
96 .header("Content-Type", "application/json")
97 .json(&request)
98 .send()
99 .await?;
100
101 let status = response.status();
102 if !status.is_success() {
103 let error_text = response.text().await.unwrap_or_default();
104 return Err(anyhow!(
105 "Ollama show API error: HTTP {} - {}",
106 status,
107 error_text
108 ));
109 }
110
111 let show_response: OllamaShowResponse = response.json().await?;
112 Ok(show_response)
113 }
114
115 async fn call_api(&self, text: &str) -> Result<Vec<f32>> {
117 let request = OllamaEmbeddingRequest {
118 model: self.model.clone(),
119 prompt: text.to_string(),
120 };
121
122 let url = format!("{}/api/embeddings", self.base_url);
123
124 let response = self
125 .client
126 .post(&url)
127 .header("Content-Type", "application/json")
128 .json(&request)
129 .send()
130 .await?;
131
132 let status = response.status();
133 if !status.is_success() {
134 let error_text = response.text().await.unwrap_or_default();
135 return Err(anyhow!(
136 "Ollama API error: HTTP {} - {}",
137 status,
138 error_text
139 ));
140 }
141
142 let embedding_response: OllamaEmbeddingResponse = response.json().await?;
143
144 if embedding_response.embedding.is_empty() {
145 return Err(anyhow!("No embeddings returned from Ollama API"));
146 }
147
148 Ok(embedding_response.embedding)
149 }
150
151 pub async fn check_server(&self) -> Result<()> {
153 let url = format!("{}/api/version", self.base_url);
154
155 let response = self.client.get(&url).send().await.map_err(|e| {
156 anyhow!(
157 "Failed to connect to Ollama server at {}: {}",
158 self.base_url,
159 e
160 )
161 })?;
162
163 if !response.status().is_success() {
164 return Err(anyhow!(
165 "Ollama server returned error: HTTP {}",
166 response.status()
167 ));
168 }
169
170 Ok(())
171 }
172
173 pub fn get_common_model_info(model: &str) -> (Option<usize>, usize) {
175 match model {
176 "nomic-embed-text" => (Some(768), 2048),
177 "mxbai-embed-large" => (Some(1024), 512),
178 "all-minilm" => (Some(384), 512),
179 _ => (None, 2048), }
181 }
182}
183
184#[async_trait]
185impl ProviderTrait for OllamaProvider {
186 async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
187 if text.trim().is_empty() {
188 return Err(anyhow!("Cannot embed empty text"));
189 }
190
191 let (_, max_chars) = Self::get_common_model_info(&self.model);
193 let truncated_text = if text.len() > max_chars * 4 {
194 &text[..max_chars * 4]
196 } else {
197 text
198 };
199
200 self.call_api(truncated_text).await
201 }
202
203 async fn get_dimension(&self) -> Result<usize> {
204 if let Some(dim) = self.dimension {
205 Ok(dim)
206 } else {
207 let (known_dim, _) = Self::get_common_model_info(&self.model);
209 if let Some(dim) = known_dim {
210 Ok(dim)
211 } else {
212 Err(anyhow!(
214 "Dimension not known for model {}. Use 'manx embedding test' to detect it.",
215 self.model
216 ))
217 }
218 }
219 }
220
221 async fn health_check(&self) -> Result<()> {
222 self.check_server().await?;
223 self.call_api("test").await.map(|_| ())
224 }
225
226 fn get_info(&self) -> ProviderInfo {
227 let (_, max_length) = Self::get_common_model_info(&self.model);
228
229 ProviderInfo {
230 name: "Ollama Local Server".to_string(),
231 provider_type: "ollama".to_string(),
232 model_name: Some(self.model.clone()),
233 description: format!("Ollama model: {} ({})", self.model, self.base_url),
234 max_input_length: Some(max_length),
235 }
236 }
237
238 fn as_any(&self) -> &dyn std::any::Any {
239 self
240 }
241}