//! Ollama embedding provider (`codemem_embeddings/ollama.rs`).
use codemem_core::CodememError;
/// Base URL used by [`OllamaProvider::with_defaults`]: an Ollama server on the local machine.
pub const DEFAULT_BASE_URL: &str = "http://localhost:11434";

/// Model name used by [`OllamaProvider::with_defaults`] when requesting embeddings.
pub const DEFAULT_MODEL: &str = "nomic-embed-text";

14pub struct OllamaProvider {
16 base_url: String,
17 model: String,
18 dimensions: usize,
19 client: reqwest::blocking::Client,
20}
21
22impl OllamaProvider {
23 pub fn new(base_url: &str, model: &str, dimensions: usize) -> Self {
25 Self {
26 base_url: base_url.to_string(),
27 model: model.to_string(),
28 dimensions,
29 client: reqwest::blocking::Client::new(),
30 }
31 }
32
33 pub fn with_defaults() -> Self {
35 Self::new(DEFAULT_BASE_URL, DEFAULT_MODEL, 768)
36 }
37}
38
39impl super::EmbeddingProvider for OllamaProvider {
40 fn dimensions(&self) -> usize {
41 self.dimensions
42 }
43
44 fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
45 let url = format!("{}/api/embeddings", self.base_url);
46 let body = serde_json::json!({
47 "model": self.model,
48 "prompt": text,
49 });
50
51 let response = self
52 .client
53 .post(&url)
54 .json(&body)
55 .send()
56 .map_err(|e| CodememError::Embedding(format!("Ollama request failed: {e}")))?;
57
58 if !response.status().is_success() {
59 return Err(CodememError::Embedding(format!(
60 "Ollama returned status {}",
61 response.status()
62 )));
63 }
64
65 let json: serde_json::Value = response
66 .json()
67 .map_err(|e| CodememError::Embedding(format!("Ollama response parse error: {e}")))?;
68
69 let embedding = json
70 .get("embedding")
71 .and_then(|v| v.as_array())
72 .ok_or_else(|| CodememError::Embedding("Missing 'embedding' field in response".into()))?
73 .iter()
74 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
75 .collect();
76
77 Ok(embedding)
78 }
79
80 fn name(&self) -> &str {
81 "ollama"
82 }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EmbeddingProvider;

    /// Tolerance for comparing floats decoded from JSON.
    const EPS: f32 = 1e-6;

    #[test]
    fn ollama_provider_construction() {
        let p = OllamaProvider::with_defaults();
        assert_eq!(p.base_url, DEFAULT_BASE_URL);
        assert_eq!(p.model, DEFAULT_MODEL);
        assert_eq!(p.dimensions, 768);
    }

    #[test]
    fn ollama_provider_custom() {
        let (host, model, dims) = ("http://myhost:11434", "mxbai-embed-large", 1024);
        let p = OllamaProvider::new(host, model, dims);
        assert_eq!(p.base_url, host);
        assert_eq!(p.model, model);
        assert_eq!(p.dimensions, dims);
    }

    #[test]
    fn ollama_name_returns_ollama() {
        let p = OllamaProvider::with_defaults();
        assert_eq!(p.name(), "ollama");
    }

    #[test]
    fn ollama_dimensions_matches_constructor() {
        let p = OllamaProvider::new("http://localhost:11434", "nomic-embed-text", 512);
        assert_eq!(EmbeddingProvider::dimensions(&p), 512);
    }

    #[test]
    fn ollama_embed_success_mock() {
        let mut server = mockito::Server::new();
        let mock = server
            .mock("POST", "/api/embeddings")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"embedding": [0.1, 0.2, 0.3]}"#)
            .create();

        let outcome = OllamaProvider::new(&server.url(), "nomic-embed-text", 3).embed("test");
        mock.assert();

        let vector = outcome.unwrap();
        assert_eq!(vector.len(), 3);
        for (got, want) in vector.iter().zip([0.1_f32, 0.2, 0.3]) {
            assert!((got - want).abs() < EPS, "got {got}, want {want}");
        }
    }

    #[test]
    fn ollama_embed_server_error_mock() {
        let mut server = mockito::Server::new();
        let mock = server
            .mock("POST", "/api/embeddings")
            .with_status(500)
            .with_body("Internal Server Error")
            .create();

        let outcome = OllamaProvider::new(&server.url(), "nomic-embed-text", 768).embed("test");
        mock.assert();

        assert!(outcome.is_err());
    }
}