Skip to main content

codemem_embeddings/
ollama.rs

//! Ollama embedding provider for Codemem.
//!
//! Uses Ollama's HTTP API (`POST /api/embeddings`, by default on localhost) to generate embeddings.
//! Default model: nomic-embed-text (768 dimensions).
5
6use codemem_core::CodememError;
7
/// Default Ollama base URL: a local server on port 11434.
///
/// Used by [`OllamaProvider::with_defaults`]; pass a different URL to
/// [`OllamaProvider::new`] to target another host.
pub const DEFAULT_BASE_URL: &str = "http://localhost:11434";

/// Default Ollama embedding model, used by [`OllamaProvider::with_defaults`].
pub const DEFAULT_MODEL: &str = "nomic-embed-text";
13
14/// Ollama embedding provider.
15pub struct OllamaProvider {
16    base_url: String,
17    model: String,
18    dimensions: usize,
19    client: reqwest::blocking::Client,
20}
21
22impl OllamaProvider {
23    /// Create a new Ollama provider.
24    pub fn new(base_url: &str, model: &str, dimensions: usize) -> Self {
25        Self {
26            base_url: base_url.to_string(),
27            model: model.to_string(),
28            dimensions,
29            client: reqwest::blocking::Client::new(),
30        }
31    }
32
33    /// Create with default settings (localhost:11434, nomic-embed-text).
34    pub fn with_defaults() -> Self {
35        Self::new(DEFAULT_BASE_URL, DEFAULT_MODEL, 768)
36    }
37}
38
39impl super::EmbeddingProvider for OllamaProvider {
40    fn dimensions(&self) -> usize {
41        self.dimensions
42    }
43
44    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
45        let url = format!("{}/api/embeddings", self.base_url);
46        let body = serde_json::json!({
47            "model": self.model,
48            "prompt": text,
49        });
50
51        let response = self
52            .client
53            .post(&url)
54            .json(&body)
55            .send()
56            .map_err(|e| CodememError::Embedding(format!("Ollama request failed: {e}")))?;
57
58        if !response.status().is_success() {
59            return Err(CodememError::Embedding(format!(
60                "Ollama returned status {}",
61                response.status()
62            )));
63        }
64
65        let json: serde_json::Value = response
66            .json()
67            .map_err(|e| CodememError::Embedding(format!("Ollama response parse error: {e}")))?;
68
69        let embedding = json
70            .get("embedding")
71            .and_then(|v| v.as_array())
72            .ok_or_else(|| CodememError::Embedding("Missing 'embedding' field in response".into()))?
73            .iter()
74            .map(|v| v.as_f64().unwrap_or(0.0) as f32)
75            .collect();
76
77        Ok(embedding)
78    }
79
80    fn name(&self) -> &str {
81        "ollama"
82    }
83}
84
#[cfg(test)]
mod tests {
    use super::*;

    /// Check that `provider` holds exactly the given connection settings.
    fn assert_settings(provider: &OllamaProvider, url: &str, model_name: &str, dims: usize) {
        assert_eq!(provider.base_url, url);
        assert_eq!(provider.model, model_name);
        assert_eq!(provider.dimensions, dims);
    }

    #[test]
    fn ollama_provider_construction() {
        let defaults = OllamaProvider::with_defaults();
        assert_settings(&defaults, DEFAULT_BASE_URL, DEFAULT_MODEL, 768);
    }

    #[test]
    fn ollama_provider_custom() {
        let (url, model_name, dims) = ("http://myhost:11434", "mxbai-embed-large", 1024);
        let custom = OllamaProvider::new(url, model_name, dims);
        assert_settings(&custom, url, model_name, dims);
    }
}