// codemem_embeddings/ollama.rs
use codemem_core::CodememError;
7
/// Base URL passed to [`OllamaProvider::new`] by [`OllamaProvider::with_defaults`].
///
/// Written without a trailing slash because endpoint paths are appended to it
/// as `/api/...`.
pub const DEFAULT_BASE_URL: &str = "http://localhost:11434";

/// Model name passed to [`OllamaProvider::new`] by [`OllamaProvider::with_defaults`].
pub const DEFAULT_MODEL: &str = "nomic-embed-text";
13
14pub struct OllamaProvider {
16 base_url: String,
17 model: String,
18 dimensions: usize,
19 client: reqwest::blocking::Client,
20}
21
22impl OllamaProvider {
23 pub fn new(base_url: &str, model: &str, dimensions: usize) -> Self {
25 Self {
26 base_url: base_url.to_string(),
27 model: model.to_string(),
28 dimensions,
29 client: reqwest::blocking::Client::new(),
30 }
31 }
32
33 pub fn with_defaults() -> Self {
35 Self::new(DEFAULT_BASE_URL, DEFAULT_MODEL, 768)
36 }
37}
38
39impl super::EmbeddingProvider for OllamaProvider {
40 fn dimensions(&self) -> usize {
41 self.dimensions
42 }
43
44 fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
45 let url = format!("{}/api/embeddings", self.base_url);
46 let body = serde_json::json!({
47 "model": self.model,
48 "prompt": text,
49 });
50
51 let response = self
52 .client
53 .post(&url)
54 .json(&body)
55 .send()
56 .map_err(|e| CodememError::Embedding(format!("Ollama request failed: {e}")))?;
57
58 if !response.status().is_success() {
59 return Err(CodememError::Embedding(format!(
60 "Ollama returned status {}",
61 response.status()
62 )));
63 }
64
65 let json: serde_json::Value = response
66 .json()
67 .map_err(|e| CodememError::Embedding(format!("Ollama response parse error: {e}")))?;
68
69 let embedding = json
70 .get("embedding")
71 .and_then(|v| v.as_array())
72 .ok_or_else(|| CodememError::Embedding("Missing 'embedding' field in response".into()))?
73 .iter()
74 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
75 .collect();
76
77 Ok(embedding)
78 }
79
80 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
81 if texts.is_empty() {
82 return Ok(vec![]);
83 }
84
85 let url = format!("{}/api/embed", self.base_url);
87 let body = serde_json::json!({
88 "model": self.model,
89 "input": texts,
90 });
91
92 let response =
93 self.client.post(&url).json(&body).send().map_err(|e| {
94 CodememError::Embedding(format!("Ollama batch request failed: {e}"))
95 })?;
96
97 if !response.status().is_success() {
98 let mut results = Vec::with_capacity(texts.len());
100 for text in texts {
101 results.push(self.embed(text)?);
102 }
103 return Ok(results);
104 }
105
106 let json: serde_json::Value = response
107 .json()
108 .map_err(|e| CodememError::Embedding(format!("Ollama response parse error: {e}")))?;
109
110 let embeddings_arr = json
111 .get("embeddings")
112 .and_then(|v| v.as_array())
113 .ok_or_else(|| {
114 CodememError::Embedding("Missing 'embeddings' array in Ollama response".into())
115 })?;
116
117 if embeddings_arr.len() != texts.len() {
118 return Err(CodememError::Embedding(format!(
119 "Ollama returned {} embeddings, expected {}",
120 embeddings_arr.len(),
121 texts.len()
122 )));
123 }
124
125 let results: Vec<Vec<f32>> = embeddings_arr
126 .iter()
127 .map(|arr| {
128 arr.as_array()
129 .unwrap_or(&vec![])
130 .iter()
131 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
132 .collect()
133 })
134 .collect();
135
136 Ok(results)
137 }
138
139 fn name(&self) -> &str {
140 "ollama"
141 }
142}
143
// Unit tests are kept in a separate file and only compiled for test builds.
#[cfg(test)]
#[path = "tests/ollama_tests.rs"]
mod tests;