1mod cache;
16mod provider;
17mod queue;
18mod tfidf;
19
20#[cfg(feature = "cohere")]
21pub mod cohere;
22#[cfg(feature = "ollama")]
23pub mod ollama;
24#[cfg(feature = "onnx-embed")]
25pub mod onnx;
26#[cfg(feature = "voyage")]
27pub mod voyage;
28
29pub use cache::{EmbeddingCache, EmbeddingCacheStats};
30pub use provider::{EmbeddingProvider, EmbeddingProviderInfo, EmbeddingRegistry};
31pub use queue::{get_embedding, get_embedding_status, EmbeddingQueue, EmbeddingWorker};
32pub use tfidf::TfIdfEmbedder;
33
34use std::sync::Arc;
35
36use crate::error::{EngramError, Result};
37use crate::types::EmbeddingConfig;
38
39pub trait Embedder: Send + Sync {
41 fn embed(&self, text: &str) -> Result<Vec<f32>>;
43
44 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
46 texts.iter().map(|t| self.embed(t)).collect()
47 }
48
49 fn dimensions(&self) -> usize;
51
52 fn model_name(&self) -> &str;
54}
55
/// Embedding client for OpenAI-compatible `/embeddings` HTTP endpoints.
#[cfg(feature = "openai")]
pub struct OpenAIEmbedder {
    // Reused HTTP client (connection pooling across requests).
    client: reqwest::Client,
    // Bearer token sent in the Authorization header.
    api_key: String,
    // API root, e.g. "https://api.openai.com/v1"; "/embeddings" is appended per request.
    base_url: String,
    // Model name sent with each request, e.g. "text-embedding-3-small".
    model: String,
    // Expected vector length; responses with a different length are rejected.
    dimensions: usize,
}
68
#[cfg(feature = "openai")]
impl OpenAIEmbedder {
    /// Create an embedder against the official OpenAI endpoint with the
    /// default model ("text-embedding-3-small", 1536 dimensions).
    pub fn new(api_key: String) -> Self {
        Self::with_config(api_key, None, None, None)
    }

    /// Create an embedder with optional overrides; any `None` falls back to
    /// the defaults used by [`OpenAIEmbedder::new`].
    pub fn with_config(
        api_key: String,
        base_url: Option<String>,
        model: Option<String>,
        dimensions: Option<usize>,
    ) -> Self {
        Self {
            client: reqwest::Client::new(),
            api_key,
            base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
            model: model.unwrap_or_else(|| "text-embedding-3-small".to_string()),
            dimensions: dimensions.unwrap_or(1536),
        }
    }

    /// Create an embedder against the official endpoint with an explicit
    /// model and dimension count.
    pub fn with_model(api_key: String, model: String, dimensions: usize) -> Self {
        Self::with_config(api_key, None, Some(model), Some(dimensions))
    }

    /// Send one `/embeddings` request with the given `input` payload
    /// (a JSON string for single texts, a JSON array for batches) and
    /// return the parsed response body.
    ///
    /// Returns `EngramError::Embedding` on any non-success HTTP status,
    /// including the response text for diagnosis.
    async fn request_embeddings(&self, input: serde_json::Value) -> Result<serde_json::Value> {
        let url = format!("{}/embeddings", self.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("HTTP-Referer", "https://github.com/engram")
            .header("X-Title", "Engram Memory")
            .json(&serde_json::json!({
                "input": input,
                "model": self.model,
            }))
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            return Err(EngramError::Embedding(format!(
                "Embedding API error {}: {}",
                status, text
            )));
        }

        Ok(response.json().await?)
    }

    /// Convert one `"embedding"` JSON array into a `Vec<f32>`, rejecting
    /// missing/malformed arrays and any vector whose length does not match
    /// the configured dimensions (previously, malformed batch items were
    /// silently replaced with an empty vector).
    fn parse_embedding(&self, value: &serde_json::Value) -> Result<Vec<f32>> {
        let embedding: Vec<f32> = value
            .as_array()
            .ok_or_else(|| EngramError::Embedding("Invalid response format".to_string()))?
            .iter()
            .filter_map(|v| v.as_f64().map(|f| f as f32))
            .collect();

        if embedding.len() != self.dimensions {
            return Err(EngramError::Embedding(format!(
                "Embedding dimensions mismatch: expected {}, got {}. Set OPENAI_EMBEDDING_DIMENSIONS={} to match your model.",
                self.dimensions, embedding.len(), embedding.len()
            )));
        }

        Ok(embedding)
    }

    /// Embed a single text asynchronously.
    pub async fn embed_async(&self, text: &str) -> Result<Vec<f32>> {
        let data = self.request_embeddings(serde_json::json!(text)).await?;
        self.parse_embedding(&data["data"][0]["embedding"])
    }

    /// Embed many texts asynchronously, batching requests in chunks of 2048
    /// (the API's maximum inputs per request).
    ///
    /// Every returned embedding is dimension-checked, not just the first of
    /// each chunk, so a partially malformed response cannot leak through.
    pub async fn embed_batch_async(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        let mut all_embeddings = Vec::with_capacity(texts.len());

        for chunk in texts.chunks(2048) {
            let data = self.request_embeddings(serde_json::json!(chunk)).await?;
            let items = data["data"]
                .as_array()
                .ok_or_else(|| EngramError::Embedding("Invalid response format".to_string()))?;

            for item in items {
                all_embeddings.push(self.parse_embedding(&item["embedding"])?);
            }
        }

        Ok(all_embeddings)
    }
}
228
#[cfg(feature = "openai")]
impl Embedder for OpenAIEmbedder {
    /// Synchronous bridge to [`OpenAIEmbedder::embed_async`].
    ///
    /// `block_in_place` tells Tokio this thread will block, then
    /// `Handle::current().block_on` drives the async call to completion.
    /// NOTE(review): per Tokio's documentation, `Handle::current()` panics
    /// outside a Tokio runtime and `block_in_place` panics on a
    /// `current_thread` runtime — callers must be on a multi-threaded
    /// Tokio runtime. Confirm this holds for all call sites.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current().block_on(self.embed_async(text))
        })
    }

    /// Synchronous bridge to [`OpenAIEmbedder::embed_batch_async`];
    /// same runtime requirements as `embed` above.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current().block_on(self.embed_batch_async(texts))
        })
    }

    /// Configured vector length (responses are validated against this).
    fn dimensions(&self) -> usize {
        self.dimensions
    }

    /// Model identifier sent with each API request.
    fn model_name(&self) -> &str {
        &self.model
    }
}
252
253pub fn create_embedder(config: &EmbeddingConfig) -> Result<Arc<dyn Embedder>> {
264 match config.model.as_str() {
265 #[cfg(feature = "openai")]
266 "openai" => {
267 let api_key = config
268 .api_key
269 .clone()
270 .ok_or_else(|| EngramError::Config(
271 "OPENAI_API_KEY required when ENGRAM_EMBEDDING_MODEL=openai".to_string()
272 ))?;
273 Ok(Arc::new(OpenAIEmbedder::with_config(
274 api_key,
275 config.base_url.clone(),
276 config.embedding_model.clone(),
277 Some(config.dimensions),
278 )))
279 }
280 #[cfg(not(feature = "openai"))]
281 "openai" => Err(EngramError::Config(
282 "OpenAI embeddings require the 'openai' feature to be enabled. Build with: cargo build --features openai".to_string(),
283 )),
284 "tfidf" => Ok(Arc::new(TfIdfEmbedder::new(config.dimensions))),
285 _ => Err(EngramError::Config(format!(
286 "Unknown embedding model: '{}'. Use 'openai' or 'tfidf'",
287 config.model
288 ))),
289 }
290}
291
/// Cosine similarity between two vectors, in `[-1.0, 1.0]`.
///
/// Returns `0.0` for mismatched lengths, empty inputs, or when either
/// vector has zero magnitude (similarity is undefined in those cases).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Single pass: dot product plus both squared norms.
    let (dot, sq_a, sq_b) = a
        .iter()
        .zip(b.iter())
        .fold((0.0f32, 0.0f32, 0.0f32), |(d, na, nb), (x, y)| {
            (d + x * y, na + x * x, nb + y * y)
        });

    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();

    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
308
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_similarity() {
        let unit_x = [1.0, 0.0, 0.0];

        // A vector compared with itself has similarity 1.
        assert!((cosine_similarity(&unit_x, &unit_x) - 1.0).abs() < 0.001);

        // Orthogonal vectors have similarity 0.
        let unit_y = [0.0, 1.0, 0.0];
        assert!(cosine_similarity(&unit_x, &unit_y).abs() < 0.001);

        // Opposite vectors have similarity -1.
        let neg_x = [-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&unit_x, &neg_x) + 1.0).abs() < 0.001);
    }

    #[test]
    fn test_tfidf_embedder() {
        // The TF-IDF embedder must honor the requested dimensionality.
        let embedder = TfIdfEmbedder::new(384);
        let vector = embedder.embed("Hello world").unwrap();
        assert_eq!(vector.len(), 384);
    }
}