Skip to main content

codemem_embeddings/
ollama.rs

//! Ollama embedding provider for Codemem.
//!
//! Uses Ollama's local HTTP API to generate embeddings.
//! Default model: nomic-embed-text (768 dimensions).

6use codemem_core::CodememError;
7
/// Base URL used when no Ollama server address is configured.
pub const DEFAULT_BASE_URL: &str = "http://localhost:11434";

/// Embedding model used when none is configured (768-dimensional output).
pub const DEFAULT_MODEL: &str = "nomic-embed-text";

14/// Ollama embedding provider.
15pub struct OllamaProvider {
16    base_url: String,
17    model: String,
18    dimensions: usize,
19    client: reqwest::blocking::Client,
20}
21
22impl OllamaProvider {
23    /// Create a new Ollama provider.
24    pub fn new(base_url: &str, model: &str, dimensions: usize) -> Self {
25        Self {
26            base_url: base_url.to_string(),
27            model: model.to_string(),
28            dimensions,
29            client: reqwest::blocking::Client::new(),
30        }
31    }
32
33    /// Create with default settings (localhost:11434, nomic-embed-text).
34    pub fn with_defaults() -> Self {
35        Self::new(DEFAULT_BASE_URL, DEFAULT_MODEL, 768)
36    }
37}
38
39impl super::EmbeddingProvider for OllamaProvider {
40    fn dimensions(&self) -> usize {
41        self.dimensions
42    }
43
44    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
45        let url = format!("{}/api/embeddings", self.base_url);
46        let body = serde_json::json!({
47            "model": self.model,
48            "prompt": text,
49        });
50
51        let response = self
52            .client
53            .post(&url)
54            .json(&body)
55            .send()
56            .map_err(|e| CodememError::Embedding(format!("Ollama request failed: {e}")))?;
57
58        if !response.status().is_success() {
59            return Err(CodememError::Embedding(format!(
60                "Ollama returned status {}",
61                response.status()
62            )));
63        }
64
65        let json: serde_json::Value = response
66            .json()
67            .map_err(|e| CodememError::Embedding(format!("Ollama response parse error: {e}")))?;
68
69        let embedding = json
70            .get("embedding")
71            .and_then(|v| v.as_array())
72            .ok_or_else(|| CodememError::Embedding("Missing 'embedding' field in response".into()))?
73            .iter()
74            .map(|v| v.as_f64().unwrap_or(0.0) as f32)
75            .collect();
76
77        Ok(embedding)
78    }
79
80    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
81        if texts.is_empty() {
82            return Ok(vec![]);
83        }
84
85        // Ollama /api/embed supports batch via "input" array (Ollama >= 0.3)
86        let url = format!("{}/api/embed", self.base_url);
87        let body = serde_json::json!({
88            "model": self.model,
89            "input": texts,
90        });
91
92        let response =
93            self.client.post(&url).json(&body).send().map_err(|e| {
94                CodememError::Embedding(format!("Ollama batch request failed: {e}"))
95            })?;
96
97        if !response.status().is_success() {
98            // Fall back to sequential calls if batch endpoint unavailable
99            let mut results = Vec::with_capacity(texts.len());
100            for text in texts {
101                results.push(self.embed(text)?);
102            }
103            return Ok(results);
104        }
105
106        let json: serde_json::Value = response
107            .json()
108            .map_err(|e| CodememError::Embedding(format!("Ollama response parse error: {e}")))?;
109
110        let embeddings_arr = json
111            .get("embeddings")
112            .and_then(|v| v.as_array())
113            .ok_or_else(|| {
114                CodememError::Embedding("Missing 'embeddings' array in Ollama response".into())
115            })?;
116
117        if embeddings_arr.len() != texts.len() {
118            return Err(CodememError::Embedding(format!(
119                "Ollama returned {} embeddings, expected {}",
120                embeddings_arr.len(),
121                texts.len()
122            )));
123        }
124
125        let results: Vec<Vec<f32>> = embeddings_arr
126            .iter()
127            .map(|arr| {
128                arr.as_array()
129                    .unwrap_or(&vec![])
130                    .iter()
131                    .map(|v| v.as_f64().unwrap_or(0.0) as f32)
132                    .collect()
133            })
134            .collect();
135
136        Ok(results)
137    }
138
139    fn name(&self) -> &str {
140        "ollama"
141    }
142}
143
// Unit tests for this module are kept in a separate file,
// `tests/ollama_tests.rs`, and are compiled only under `cfg(test)`.
#[cfg(test)]
#[path = "tests/ollama_tests.rs"]
mod tests;