1mod cache;
16mod provider;
17mod queue;
18mod tfidf;
19
20#[cfg(feature = "cohere")]
21pub mod cohere;
22#[cfg(feature = "multimodal")]
23pub mod clip;
24#[cfg(feature = "ollama")]
25pub mod ollama;
26#[cfg(feature = "onnx-embed")]
27pub mod onnx;
28#[cfg(feature = "voyage")]
29pub mod voyage;
30
31pub use cache::{EmbeddingCache, EmbeddingCacheStats};
32#[cfg(feature = "multimodal")]
33pub use clip::{ClipEmbedder, MultimodalEmbedder, CLIP_PROVIDER_NAME};
34pub use provider::{EmbeddingProvider, EmbeddingProviderInfo, EmbeddingRegistry};
35pub use queue::{get_embedding, get_embedding_status, EmbeddingQueue, EmbeddingWorker};
36pub use tfidf::TfIdfEmbedder;
37
38use std::sync::Arc;
39
40use crate::error::{EngramError, Result};
41use crate::types::EmbeddingConfig;
42
43pub trait Embedder: Send + Sync {
45 fn embed(&self, text: &str) -> Result<Vec<f32>>;
47
48 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
50 texts.iter().map(|t| self.embed(t)).collect()
51 }
52
53 fn dimensions(&self) -> usize;
55
56 fn model_name(&self) -> &str;
58}
59
/// Embedding provider backed by the OpenAI embeddings HTTP API.
#[cfg(feature = "openai")]
pub struct OpenAIEmbedder {
    // Shared HTTP client reused across requests.
    client: reqwest::Client,
    // Bearer token sent in the Authorization header.
    api_key: String,
    // API root, e.g. "https://api.openai.com/v1"; overridable for
    // OpenAI-compatible gateways.
    base_url: String,
    // Model identifier sent with each request.
    model: String,
    // Expected output dimensionality; responses are validated against it.
    dimensions: usize,
}
72
#[cfg(feature = "openai")]
impl OpenAIEmbedder {
    /// Default OpenAI API root.
    const DEFAULT_BASE_URL: &'static str = "https://api.openai.com/v1";
    /// Default embedding model.
    const DEFAULT_MODEL: &'static str = "text-embedding-3-small";
    /// Output dimensionality of the default model.
    const DEFAULT_DIMENSIONS: usize = 1536;

    /// Maximum number of inputs OpenAI accepts in one embeddings request.
    const MAX_BATCH_SIZE: usize = 2048;

    /// Creates an embedder with the default base URL, model, and dimensions.
    pub fn new(api_key: String) -> Self {
        Self::with_config(api_key, None, None, None)
    }

    /// Creates an embedder, falling back to the defaults for any `None` option.
    pub fn with_config(
        api_key: String,
        base_url: Option<String>,
        model: Option<String>,
        dimensions: Option<usize>,
    ) -> Self {
        Self {
            client: reqwest::Client::new(),
            api_key,
            base_url: base_url.unwrap_or_else(|| Self::DEFAULT_BASE_URL.to_string()),
            model: model.unwrap_or_else(|| Self::DEFAULT_MODEL.to_string()),
            dimensions: dimensions.unwrap_or(Self::DEFAULT_DIMENSIONS),
        }
    }

    /// Creates an embedder for a specific model against the default base URL.
    pub fn with_model(api_key: String, model: String, dimensions: usize) -> Self {
        Self::with_config(api_key, None, Some(model), Some(dimensions))
    }

    /// POSTs `input` (a single string or an array of strings) to the
    /// embeddings endpoint and returns the parsed JSON body.
    ///
    /// Returns `EngramError::Embedding` on any non-2xx HTTP status.
    async fn request_embeddings(&self, input: serde_json::Value) -> Result<serde_json::Value> {
        let url = format!("{}/embeddings", self.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            // Attribution headers used by OpenAI-compatible gateways
            // (e.g. OpenRouter); the upstream OpenAI API ignores them.
            .header("HTTP-Referer", "https://github.com/engram")
            .header("X-Title", "Engram Memory")
            .json(&serde_json::json!({
                "input": input,
                "model": self.model,
            }))
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            return Err(EngramError::Embedding(format!(
                "Embedding API error {}: {}",
                status, text
            )));
        }

        Ok(response.json::<serde_json::Value>().await?)
    }

    /// Verifies one embedding against the configured dimensionality.
    fn check_dimensions(&self, embedding: &[f32]) -> Result<()> {
        if embedding.len() != self.dimensions {
            return Err(EngramError::Embedding(format!(
                "Embedding dimensions mismatch: expected {}, got {}. Set OPENAI_EMBEDDING_DIMENSIONS={} to match your model.",
                self.dimensions, embedding.len(), embedding.len()
            )));
        }
        Ok(())
    }

    /// Embeds a single text via the API.
    ///
    /// # Errors
    /// Returns `EngramError::Embedding` on transport failure, a non-success
    /// HTTP status, a malformed response body, or a dimensionality mismatch.
    pub async fn embed_async(&self, text: &str) -> Result<Vec<f32>> {
        let data = self.request_embeddings(serde_json::json!(text)).await?;

        let embedding: Vec<f32> = data["data"][0]["embedding"]
            .as_array()
            .ok_or_else(|| EngramError::Embedding("Invalid response format".to_string()))?
            .iter()
            // Non-numeric entries are skipped; the dimension check below
            // catches any resulting shortfall.
            .filter_map(|v| v.as_f64().map(|f| f as f32))
            .collect();

        self.check_dimensions(&embedding)?;
        Ok(embedding)
    }

    /// Embeds many texts, splitting them into API-sized chunks.
    ///
    /// Results are returned in input order. Every returned embedding is
    /// validated against the configured dimensions (not just the first
    /// of each chunk).
    ///
    /// # Errors
    /// Same failure modes as [`Self::embed_async`]; the first failing chunk
    /// aborts the whole batch.
    pub async fn embed_batch_async(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        let mut all_embeddings = Vec::with_capacity(texts.len());

        for chunk in texts.chunks(Self::MAX_BATCH_SIZE) {
            let data = self.request_embeddings(serde_json::json!(chunk)).await?;

            let embeddings: Vec<Vec<f32>> = data["data"]
                .as_array()
                .ok_or_else(|| EngramError::Embedding("Invalid response format".to_string()))?
                .iter()
                .map(|item| {
                    item["embedding"]
                        .as_array()
                        .map(|arr| {
                            arr.iter()
                                .filter_map(|v| v.as_f64().map(|f| f as f32))
                                .collect()
                        })
                        // A missing "embedding" field yields an empty vector,
                        // which the dimension check below rejects.
                        .unwrap_or_default()
                })
                .collect();

            for embedding in &embeddings {
                self.check_dimensions(embedding)?;
            }

            all_embeddings.extend(embeddings);
        }

        Ok(all_embeddings)
    }
}
232
#[cfg(feature = "openai")]
impl Embedder for OpenAIEmbedder {
    /// Blocking wrapper over [`OpenAIEmbedder::embed_async`].
    ///
    /// NOTE(review): `block_in_place` and `Handle::current` both panic unless
    /// called from inside a multi-threaded tokio runtime — confirm all callers
    /// satisfy that.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        tokio::task::block_in_place(|| {
            let handle = tokio::runtime::Handle::current();
            handle.block_on(self.embed_async(text))
        })
    }

    /// Blocking wrapper over [`OpenAIEmbedder::embed_batch_async`].
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        tokio::task::block_in_place(|| {
            let handle = tokio::runtime::Handle::current();
            handle.block_on(self.embed_batch_async(texts))
        })
    }

    /// Configured output dimensionality.
    fn dimensions(&self) -> usize {
        self.dimensions
    }

    /// Configured model identifier.
    fn model_name(&self) -> &str {
        &self.model
    }
}
256
257pub fn create_embedder(config: &EmbeddingConfig) -> Result<Arc<dyn Embedder>> {
268 match config.model.as_str() {
269 #[cfg(feature = "multimodal")]
270 "clip" => {
271 clip::create_clip_embedder()
272 .map(|e| e as Arc<dyn Embedder>)
273 }
274 #[cfg(feature = "openai")]
275 "openai" => {
276 let api_key = config
277 .api_key
278 .clone()
279 .ok_or_else(|| EngramError::Config(
280 "OPENAI_API_KEY required when ENGRAM_EMBEDDING_MODEL=openai".to_string()
281 ))?;
282 Ok(Arc::new(OpenAIEmbedder::with_config(
283 api_key,
284 config.base_url.clone(),
285 config.embedding_model.clone(),
286 Some(config.dimensions),
287 )))
288 }
289 #[cfg(not(feature = "openai"))]
290 "openai" => Err(EngramError::Config(
291 "OpenAI embeddings require the 'openai' feature to be enabled. Build with: cargo build --features openai".to_string(),
292 )),
293 "tfidf" => Ok(Arc::new(TfIdfEmbedder::new(config.dimensions))),
294 _ => Err(EngramError::Config(format!(
295 "Unknown embedding model: '{}'. Use 'openai' or 'tfidf'",
296 config.model
297 ))),
298 }
299}
300
/// Computes the cosine similarity between two vectors.
///
/// Returns a value in `[-1.0, 1.0]`, or `0.0` when the slices differ in
/// length, are empty, or either has zero magnitude.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Single pass accumulating the dot product and both squared norms.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    dot / (norm_a * norm_b)
}
317
#[cfg(test)]
mod tests {
    use super::*;

    // Identical, orthogonal, and opposite vectors should score 1, 0, and -1.
    #[test]
    fn test_cosine_similarity() {
        let unit_x = [1.0, 0.0, 0.0];

        // Identical vectors -> similarity of 1.
        assert!((cosine_similarity(&unit_x, &unit_x) - 1.0).abs() < 0.001);

        // Orthogonal vectors -> similarity of 0.
        let unit_y = [0.0, 1.0, 0.0];
        assert!(cosine_similarity(&unit_x, &unit_y).abs() < 0.001);

        // Opposite vectors -> similarity of -1.
        let neg_x = [-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&unit_x, &neg_x) + 1.0).abs() < 0.001);
    }

    // A TF-IDF embedding must match the requested dimensionality.
    #[test]
    fn test_tfidf_embedder() {
        let dims = 384;
        let embedder = TfIdfEmbedder::new(dims);
        let vector = embedder.embed("Hello world").unwrap();
        assert_eq!(vector.len(), dims);
    }
}