Skip to main content

codemem_embeddings/
ollama.rs

1//! Ollama embedding provider for Codemem.
2//!
3//! Uses Ollama's local API to generate embeddings.
4//! Default model: nomic-embed-text (768 dimensions).
5
6use codemem_core::CodememError;
7
/// Base URL used by `OllamaProvider::with_defaults`: an Ollama server on
/// `localhost`, port `11434`.
pub const DEFAULT_BASE_URL: &str = "http://localhost:11434";

/// Embedding model requested by `OllamaProvider::with_defaults`.
pub const DEFAULT_MODEL: &str = "nomic-embed-text";
13
14/// Ollama embedding provider.
15pub struct OllamaProvider {
16    base_url: String,
17    model: String,
18    dimensions: usize,
19    client: reqwest::blocking::Client,
20}
21
22impl OllamaProvider {
23    /// Create a new Ollama provider.
24    pub fn new(base_url: &str, model: &str, dimensions: usize) -> Self {
25        Self {
26            base_url: base_url.to_string(),
27            model: model.to_string(),
28            dimensions,
29            client: reqwest::blocking::Client::new(),
30        }
31    }
32
33    /// Create with default settings (localhost:11434, nomic-embed-text).
34    pub fn with_defaults() -> Self {
35        Self::new(DEFAULT_BASE_URL, DEFAULT_MODEL, 768)
36    }
37}
38
39impl super::EmbeddingProvider for OllamaProvider {
40    fn dimensions(&self) -> usize {
41        self.dimensions
42    }
43
44    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
45        let url = format!("{}/api/embeddings", self.base_url);
46        let body = serde_json::json!({
47            "model": self.model,
48            "prompt": text,
49        });
50
51        let response = self
52            .client
53            .post(&url)
54            .json(&body)
55            .send()
56            .map_err(|e| CodememError::Embedding(format!("Ollama request failed: {e}")))?;
57
58        if !response.status().is_success() {
59            return Err(CodememError::Embedding(format!(
60                "Ollama returned status {}",
61                response.status()
62            )));
63        }
64
65        let json: serde_json::Value = response
66            .json()
67            .map_err(|e| CodememError::Embedding(format!("Ollama response parse error: {e}")))?;
68
69        let embedding = json
70            .get("embedding")
71            .and_then(|v| v.as_array())
72            .ok_or_else(|| CodememError::Embedding("Missing 'embedding' field in response".into()))?
73            .iter()
74            .map(|v| v.as_f64().unwrap_or(0.0) as f32)
75            .collect();
76
77        Ok(embedding)
78    }
79
80    fn name(&self) -> &str {
81        "ollama"
82    }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EmbeddingProvider;

    /// Path the provider posts embedding requests to.
    const EMBEDDINGS_PATH: &str = "/api/embeddings";

    #[test]
    fn ollama_provider_construction() {
        let p = OllamaProvider::with_defaults();
        assert_eq!(
            (p.base_url.as_str(), p.model.as_str(), p.dimensions),
            (DEFAULT_BASE_URL, DEFAULT_MODEL, 768)
        );
    }

    #[test]
    fn ollama_provider_custom() {
        let p = OllamaProvider::new("http://myhost:11434", "mxbai-embed-large", 1024);
        assert_eq!(
            (p.base_url.as_str(), p.model.as_str(), p.dimensions),
            ("http://myhost:11434", "mxbai-embed-large", 1024)
        );
    }

    #[test]
    fn ollama_name_returns_ollama() {
        assert_eq!(OllamaProvider::with_defaults().name(), "ollama");
    }

    #[test]
    fn ollama_dimensions_matches_constructor() {
        let p = OllamaProvider::new("http://localhost:11434", "nomic-embed-text", 512);
        assert_eq!(p.dimensions(), 512);
    }

    #[test]
    fn ollama_embed_success_mock() {
        let mut server = mockito::Server::new();
        let endpoint = server
            .mock("POST", EMBEDDINGS_PATH)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"embedding": [0.1, 0.2, 0.3]}"#)
            .create();

        let p = OllamaProvider::new(&server.url(), "nomic-embed-text", 3);
        let outcome = p.embed("test");
        endpoint.assert();

        let vector = outcome.unwrap();
        let expected = [0.1_f32, 0.2, 0.3];
        assert_eq!(vector.len(), expected.len());
        for (got, want) in vector.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    fn ollama_embed_server_error_mock() {
        let mut server = mockito::Server::new();
        let endpoint = server
            .mock("POST", EMBEDDINGS_PATH)
            .with_status(500)
            .with_body("Internal Server Error")
            .create();

        let p = OllamaProvider::new(&server.url(), "nomic-embed-text", 768);
        let outcome = p.embed("test");
        endpoint.assert();

        assert!(outcome.is_err());
    }
}