1mod cache;
16mod queue;
17mod tfidf;
18
19pub use cache::{EmbeddingCache, EmbeddingCacheStats};
20pub use queue::{get_embedding, get_embedding_status, EmbeddingQueue, EmbeddingWorker};
21pub use tfidf::TfIdfEmbedder;
22
23use std::sync::Arc;
24
25use crate::error::{EngramError, Result};
26use crate::types::EmbeddingConfig;
27
28pub trait Embedder: Send + Sync {
30 fn embed(&self, text: &str) -> Result<Vec<f32>>;
32
33 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
35 texts.iter().map(|t| self.embed(t)).collect()
36 }
37
38 fn dimensions(&self) -> usize;
40
41 fn model_name(&self) -> &str;
43}
44
45#[cfg(feature = "openai")]
50pub struct OpenAIEmbedder {
51 client: reqwest::Client,
52 api_key: String,
53 base_url: String,
54 model: String,
55 dimensions: usize,
56}
57
58#[cfg(feature = "openai")]
59impl OpenAIEmbedder {
60 pub fn new(api_key: String) -> Self {
62 Self {
63 client: reqwest::Client::new(),
64 api_key,
65 base_url: "https://api.openai.com/v1".to_string(),
66 model: "text-embedding-3-small".to_string(),
67 dimensions: 1536,
68 }
69 }
70
71 pub fn with_config(
79 api_key: String,
80 base_url: Option<String>,
81 model: Option<String>,
82 dimensions: Option<usize>,
83 ) -> Self {
84 Self {
85 client: reqwest::Client::new(),
86 api_key,
87 base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
88 model: model.unwrap_or_else(|| "text-embedding-3-small".to_string()),
89 dimensions: dimensions.unwrap_or(1536),
90 }
91 }
92
93 pub fn with_model(api_key: String, model: String, dimensions: usize) -> Self {
95 Self {
96 client: reqwest::Client::new(),
97 api_key,
98 base_url: "https://api.openai.com/v1".to_string(),
99 model,
100 dimensions,
101 }
102 }
103
104 pub async fn embed_async(&self, text: &str) -> Result<Vec<f32>> {
106 let url = format!("{}/embeddings", self.base_url);
107
108 let response = self
109 .client
110 .post(&url)
111 .header("Authorization", format!("Bearer {}", self.api_key))
112 .header("HTTP-Referer", "https://github.com/engram")
114 .header("X-Title", "Engram Memory")
116 .json(&serde_json::json!({
117 "input": text,
118 "model": self.model,
119 }))
120 .send()
121 .await?;
122
123 if !response.status().is_success() {
124 let status = response.status();
125 let text = response.text().await.unwrap_or_default();
126 return Err(EngramError::Embedding(format!(
127 "Embedding API error {}: {}",
128 status, text
129 )));
130 }
131
132 let data: serde_json::Value = response.json().await?;
133 let embedding: Vec<f32> = data["data"][0]["embedding"]
134 .as_array()
135 .ok_or_else(|| EngramError::Embedding("Invalid response format".to_string()))?
136 .iter()
137 .filter_map(|v| v.as_f64().map(|f| f as f32))
138 .collect();
139
140 if embedding.len() != self.dimensions {
142 return Err(EngramError::Embedding(format!(
143 "Embedding dimensions mismatch: expected {}, got {}. Set OPENAI_EMBEDDING_DIMENSIONS={} to match your model.",
144 self.dimensions, embedding.len(), embedding.len()
145 )));
146 }
147
148 Ok(embedding)
149 }
150
151 pub async fn embed_batch_async(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
153 if texts.is_empty() {
154 return Ok(vec![]);
155 }
156
157 let url = format!("{}/embeddings", self.base_url);
158
159 let mut all_embeddings = Vec::with_capacity(texts.len());
161
162 for chunk in texts.chunks(2048) {
163 let response = self
164 .client
165 .post(&url)
166 .header("Authorization", format!("Bearer {}", self.api_key))
167 .header("HTTP-Referer", "https://github.com/engram")
169 .header("X-Title", "Engram Memory")
170 .json(&serde_json::json!({
171 "input": chunk,
172 "model": self.model,
173 }))
174 .send()
175 .await?;
176
177 if !response.status().is_success() {
178 let status = response.status();
179 let text = response.text().await.unwrap_or_default();
180 return Err(EngramError::Embedding(format!(
181 "Embedding API error {}: {}",
182 status, text
183 )));
184 }
185
186 let data: serde_json::Value = response.json().await?;
187 let embeddings: Vec<Vec<f32>> = data["data"]
188 .as_array()
189 .ok_or_else(|| EngramError::Embedding("Invalid response format".to_string()))?
190 .iter()
191 .map(|item| {
192 item["embedding"]
193 .as_array()
194 .map(|arr| {
195 arr.iter()
196 .filter_map(|v| v.as_f64().map(|f| f as f32))
197 .collect()
198 })
199 .unwrap_or_default()
200 })
201 .collect();
202
203 if !embeddings.is_empty() && embeddings[0].len() != self.dimensions {
205 return Err(EngramError::Embedding(format!(
206 "Embedding dimensions mismatch: expected {}, got {}. Set OPENAI_EMBEDDING_DIMENSIONS={} to match your model.",
207 self.dimensions, embeddings[0].len(), embeddings[0].len()
208 )));
209 }
210
211 all_embeddings.extend(embeddings);
212 }
213
214 Ok(all_embeddings)
215 }
216}
217
218#[cfg(feature = "openai")]
219impl Embedder for OpenAIEmbedder {
220 fn embed(&self, text: &str) -> Result<Vec<f32>> {
221 tokio::task::block_in_place(|| {
223 tokio::runtime::Handle::current().block_on(self.embed_async(text))
224 })
225 }
226
227 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
228 tokio::task::block_in_place(|| {
229 tokio::runtime::Handle::current().block_on(self.embed_batch_async(texts))
230 })
231 }
232
233 fn dimensions(&self) -> usize {
234 self.dimensions
235 }
236
237 fn model_name(&self) -> &str {
238 &self.model
239 }
240}
241
242pub fn create_embedder(config: &EmbeddingConfig) -> Result<Arc<dyn Embedder>> {
253 match config.model.as_str() {
254 #[cfg(feature = "openai")]
255 "openai" => {
256 let api_key = config
257 .api_key
258 .clone()
259 .ok_or_else(|| EngramError::Config(
260 "OPENAI_API_KEY required when ENGRAM_EMBEDDING_MODEL=openai".to_string()
261 ))?;
262 Ok(Arc::new(OpenAIEmbedder::with_config(
263 api_key,
264 config.base_url.clone(),
265 config.embedding_model.clone(),
266 Some(config.dimensions),
267 )))
268 }
269 #[cfg(not(feature = "openai"))]
270 "openai" => Err(EngramError::Config(
271 "OpenAI embeddings require the 'openai' feature to be enabled. Build with: cargo build --features openai".to_string(),
272 )),
273 "tfidf" => Ok(Arc::new(TfIdfEmbedder::new(config.dimensions))),
274 _ => Err(EngramError::Config(format!(
275 "Unknown embedding model: '{}'. Use 'openai' or 'tfidf'",
276 config.model
277 ))),
278 }
279}
280
/// Cosine similarity of two vectors, in `[-1, 1]`.
///
/// Returns `0.0` for mismatched lengths, empty input, or a zero-magnitude
/// vector, so callers never divide by zero.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Guard clauses: degenerate inputs score zero.
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Single fused pass accumulates the dot product and both squared norms.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    // A sum of squares is zero exactly when its square root is zero,
    // so this matches checking the norms themselves.
    if sq_a == 0.0 || sq_b == 0.0 {
        return 0.0;
    }

    dot / (sq_a.sqrt() * sq_b.sqrt())
}
297
#[cfg(test)]
mod tests {
    use super::*;

    // Tolerance for floating-point comparisons.
    const EPS: f32 = 0.001;

    #[test]
    fn test_cosine_similarity() {
        let unit_x = vec![1.0, 0.0, 0.0];

        // Identical direction -> similarity 1.
        assert!((cosine_similarity(&unit_x, &unit_x) - 1.0).abs() < EPS);

        // Orthogonal vectors -> similarity 0.
        let unit_y = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&unit_x, &unit_y).abs() < EPS);

        // Opposite direction -> similarity -1.
        let neg_x = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&unit_x, &neg_x) + 1.0).abs() < EPS);
    }

    #[test]
    fn test_tfidf_embedder() {
        // The TF-IDF embedder must honour the requested dimensionality.
        let dims = 384;
        let vector = TfIdfEmbedder::new(dims).embed("Hello world").unwrap();
        assert_eq!(vector.len(), dims);
    }
}