1use crate::error::{Error, Result};
7use serde::{Deserialize, Serialize};
8
9use super::config::{resolve_ollama_endpoint, resolve_ollama_model};
10use super::provider::EmbeddingProvider;
11use super::types::{ollama_models, ProviderInfo};
12
13pub struct OllamaProvider {
15 client: reqwest::Client,
16 endpoint: String,
17 model: String,
18 dimensions: usize,
19 max_chars: usize,
20}
21
22impl OllamaProvider {
23 pub fn new() -> Self {
25 Self::with_config(None, None)
26 }
27
28 pub fn with_config(endpoint: Option<String>, model: Option<String>) -> Self {
30 let endpoint = endpoint.unwrap_or_else(resolve_ollama_endpoint);
31 let model = model.unwrap_or_else(resolve_ollama_model);
32 let config = ollama_models::get_config(&model);
33
34 Self {
35 client: reqwest::Client::new(),
36 endpoint,
37 model,
38 dimensions: config.dimensions,
39 max_chars: config.max_chars,
40 }
41 }
42}
43
44impl Default for OllamaProvider {
45 fn default() -> Self {
46 Self::new()
47 }
48}
49
/// Body of the `GET /api/tags` response, as read by `is_available`.
#[derive(Debug, Deserialize)]
struct OllamaTagsResponse {
    /// Models reported by the server; `None` is treated as "no models" by `is_available`.
    models: Option<Vec<OllamaModel>>,
}
55
/// One entry of the model list in an [`OllamaTagsResponse`].
#[derive(Debug, Deserialize)]
struct OllamaModel {
    /// Model name as reported by the server; compared against the configured
    /// model either exactly or as `<model>:<suffix>`.
    name: String,
}
60
/// JSON body sent with `POST /api/embed`.
#[derive(Debug, Serialize)]
struct OllamaEmbedRequest<'a> {
    /// Name of the embedding model to run.
    model: &'a str,
    /// Text (or texts) to embed.
    input: EmbedInput<'a>,
}
67
/// The `input` field of an [`OllamaEmbedRequest`].
///
/// Because the enum is `untagged`, it serializes as the bare inner value:
/// a JSON string for `Single` and a JSON array of strings for `Batch`.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum EmbedInput<'a> {
    /// A single text to embed.
    Single(&'a str),
    /// Several texts to embed in one request.
    Batch(Vec<&'a str>),
}
74
/// Body of a successful `POST /api/embed` response.
#[derive(Debug, Deserialize)]
struct OllamaEmbedResponse {
    /// Returned embedding vectors.
    /// NOTE(review): presumably one vector per input, in input order — confirm against the Ollama API.
    embeddings: Vec<Vec<f32>>,
}
80
81impl EmbeddingProvider for OllamaProvider {
82 fn info(&self) -> ProviderInfo {
83 ProviderInfo {
84 name: "ollama".to_string(),
85 model: self.model.clone(),
86 dimensions: self.dimensions,
87 max_chars: self.max_chars,
88 available: false, }
90 }
91
92 async fn is_available(&self) -> bool {
93 let url = format!("{}/api/tags", self.endpoint);
94
95 let response = match self.client
96 .get(&url)
97 .timeout(std::time::Duration::from_secs(2))
98 .send()
99 .await
100 {
101 Ok(r) => r,
102 Err(_) => return false,
103 };
104
105 if !response.status().is_success() {
106 return false;
107 }
108
109 let data: OllamaTagsResponse = match response.json().await {
110 Ok(d) => d,
111 Err(_) => return false,
112 };
113
114 data.models.map_or(false, |models| {
116 models.iter().any(|m| {
117 m.name == self.model || m.name.starts_with(&format!("{}:", self.model))
118 })
119 })
120 }
121
122 async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
123 let url = format!("{}/api/embed", self.endpoint);
124
125 let request = OllamaEmbedRequest {
126 model: &self.model,
127 input: EmbedInput::Single(text),
128 };
129
130 let response = self.client
131 .post(&url)
132 .json(&request)
133 .send()
134 .await
135 .map_err(|e| Error::Embedding(format!("Ollama request failed: {e}")))?;
136
137 if !response.status().is_success() {
138 let error = response.text().await.unwrap_or_default();
139 return Err(Error::Embedding(format!("Ollama embedding failed: {error}")));
140 }
141
142 let data: OllamaEmbedResponse = response.json().await
143 .map_err(|e| Error::Embedding(format!("Failed to parse Ollama response: {e}")))?;
144
145 data.embeddings.into_iter().next()
146 .ok_or_else(|| Error::Embedding("No embeddings returned from Ollama".into()))
147 }
148
149 async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
150 let url = format!("{}/api/embed", self.endpoint);
151
152 let request = OllamaEmbedRequest {
153 model: &self.model,
154 input: EmbedInput::Batch(texts.to_vec()),
155 };
156
157 let response = self.client
158 .post(&url)
159 .json(&request)
160 .send()
161 .await
162 .map_err(|e| Error::Embedding(format!("Ollama batch request failed: {e}")))?;
163
164 if !response.status().is_success() {
165 let error = response.text().await.unwrap_or_default();
166 return Err(Error::Embedding(format!("Ollama batch embedding failed: {error}")));
167 }
168
169 let data: OllamaEmbedResponse = response.json().await
170 .map_err(|e| Error::Embedding(format!("Failed to parse Ollama response: {e}")))?;
171
172 Ok(data.embeddings)
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// A provider built with defaults reports the "ollama" name, a non-empty
    /// model and a positive dimension count.
    #[test]
    fn test_ollama_provider_creation() {
        let described = OllamaProvider::new().info();

        assert_eq!(described.name, "ollama");
        assert!(!described.model.is_empty());
        assert!(described.dimensions > 0);
    }

    /// An explicitly chosen model is reported back together with its
    /// configured dimensions.
    #[test]
    fn test_ollama_provider_custom_config() {
        let endpoint = Some("http://custom:11434".to_string());
        let model = Some("mxbai-embed-large".to_string());

        let described = OllamaProvider::with_config(endpoint, model).info();

        assert_eq!(described.model, "mxbai-embed-large");
        assert_eq!(described.dimensions, 1024);
    }
}