1mod openai;
29
30pub use openai::{OpenAIConfig, OpenAIEmbedding, OpenAIModel, UsageSnapshot, UsageStats};
31
32use crate::error::{Error, Result};
33use std::any::Any;
34use std::collections::hash_map::DefaultHasher;
35use std::hash::{Hash, Hasher};
36
37#[async_trait::async_trait]
39pub trait EmbeddingProvider: Any + Send + Sync {
40 async fn embed(&self, text: &str) -> Result<Vec<f32>>;
42
43 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
45 let mut embeddings = Vec::with_capacity(texts.len());
46 for text in texts {
47 embeddings.push(self.embed(text).await?);
48 }
49 Ok(embeddings)
50 }
51
52 fn dimensions(&self) -> usize;
54
55 fn as_any(&self) -> &dyn Any;
57}
58
59pub struct EmbeddingEngine {
63 provider: Box<dyn EmbeddingProvider>,
64}
65
66impl EmbeddingEngine {
67 pub fn new(dimensions: usize) -> Self {
72 Self {
73 provider: Box::new(HashEmbedding::new(dimensions)),
74 }
75 }
76
77 pub fn from_env() -> Self {
81 if let Ok(provider) = OpenAIEmbedding::from_env() {
82 Self {
83 provider: Box::new(provider),
84 }
85 } else {
86 tracing::warn!("OPENAI_API_KEY not set, falling back to hash embeddings");
87 Self::new(1536) }
89 }
90
91 pub fn from_env_required() -> Result<Self> {
93 let provider = OpenAIEmbedding::from_env()?;
94 Ok(Self {
95 provider: Box::new(provider),
96 })
97 }
98
99 pub fn with_openai(api_key: impl Into<String>, model: Option<String>) -> Self {
101 Self {
102 provider: Box::new(OpenAIEmbedding::new(api_key, model)),
103 }
104 }
105
106 pub fn with_openai_config(api_key: impl Into<String>, config: OpenAIConfig) -> Self {
108 Self {
109 provider: Box::new(OpenAIEmbedding::with_config(api_key, config)),
110 }
111 }
112
113 pub fn with_provider(provider: Box<dyn EmbeddingProvider>) -> Self {
115 Self { provider }
116 }
117
118 pub fn dimensions(&self) -> usize {
120 self.provider.dimensions()
121 }
122
123 pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
125 self.provider.embed(text).await
126 }
127
128 pub async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
130 self.provider.embed_batch(texts).await
131 }
132
133 pub fn similarity(&self, a: &[f32], b: &[f32]) -> f32 {
135 if a.len() != b.len() {
136 return 0.0;
137 }
138
139 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
140 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
141 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
142
143 if norm_a == 0.0 || norm_b == 0.0 {
144 0.0
145 } else {
146 dot / (norm_a * norm_b)
147 }
148 }
149}
150
/// Deterministic, dependency-free embedding provider based on feature hashing.
///
/// Words, word n-grams, and intra-word byte trigrams are hashed into a vector of
/// `dimensions` slots, which is then L2-normalized. It needs no network access and
/// is intended as a local fallback.
#[derive(Debug, Clone)]
pub struct HashEmbedding {
    /// Length of every vector this provider produces.
    dimensions: usize,
}
155
156impl HashEmbedding {
157 pub fn new(dimensions: usize) -> Self {
159 Self { dimensions }
160 }
161
162 fn embed_sync(&self, text: &str) -> Result<Vec<f32>> {
163 if text.is_empty() {
164 return Err(Error::embedding("Cannot embed empty text"));
165 }
166
167 let mut embedding = vec![0.0f32; self.dimensions];
168 let normalized_text = text.to_lowercase();
169
170 for word in normalized_text.split_whitespace() {
172 self.add_word_embedding(&mut embedding, word, 1.0);
173 }
174
175 let words: Vec<&str> = normalized_text.split_whitespace().collect();
177 for window in words.windows(2) {
178 let bigram = format!("{} {}", window[0], window[1]);
179 self.add_word_embedding(&mut embedding, &bigram, 0.5);
180 }
181
182 for window in words.windows(3) {
184 let trigram = format!("{} {} {}", window[0], window[1], window[2]);
185 self.add_word_embedding(&mut embedding, &trigram, 0.3);
186 }
187
188 for word in &words {
190 for char_ngram in word.as_bytes().windows(3) {
191 let hash = self.hash_bytes(char_ngram);
192 let idx = (hash as usize) % self.dimensions;
193 embedding[idx] += 0.1;
194 }
195 }
196
197 self.normalize(&mut embedding);
199
200 Ok(embedding)
201 }
202
203 fn add_word_embedding(&self, embedding: &mut [f32], text: &str, weight: f32) {
204 let hash = self.hash_text(text);
205 for i in 0..8 {
206 let idx = ((hash.wrapping_add(i * 0x9e37_79b9)) as usize) % self.dimensions;
207 let sign = if (hash >> i) & 1 == 0 { 1.0 } else { -1.0 };
208 embedding[idx] += sign * weight;
209 }
210 }
211
212 fn hash_text(&self, text: &str) -> u64 {
213 let mut hasher = DefaultHasher::new();
214 text.hash(&mut hasher);
215 hasher.finish()
216 }
217
218 fn hash_bytes(&self, bytes: &[u8]) -> u64 {
219 let mut hasher = DefaultHasher::new();
220 bytes.hash(&mut hasher);
221 hasher.finish()
222 }
223
224 fn normalize(&self, embedding: &mut [f32]) {
225 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
226 if norm > 0.0 {
227 for x in embedding.iter_mut() {
228 *x /= norm;
229 }
230 }
231 }
232}
233
234#[async_trait::async_trait]
235impl EmbeddingProvider for HashEmbedding {
236 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
237 self.embed_sync(text)
238 }
239
240 fn dimensions(&self) -> usize {
241 self.dimensions
242 }
243
244 fn as_any(&self) -> &dyn Any {
245 self
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_embedding_dimensions() {
        let engine = EmbeddingEngine::new(128);
        let embedding = engine.embed("test text").await.unwrap();
        assert_eq!(embedding.len(), 128);
    }

    #[tokio::test]
    async fn test_embedding_consistency() {
        let engine = EmbeddingEngine::new(64);
        let emb1 = engine.embed("hello world").await.unwrap();
        let emb2 = engine.embed("hello world").await.unwrap();
        assert_eq!(emb1, emb2);
    }

    #[tokio::test]
    async fn test_embedding_similarity() {
        let engine = EmbeddingEngine::new(128);

        let emb1 = engine.embed("rust programming language").await.unwrap();
        let emb2 = engine.embed("rust programming").await.unwrap();
        let emb3 = engine.embed("cooking recipes").await.unwrap();

        let sim_similar = engine.similarity(&emb1, &emb2);
        let sim_different = engine.similarity(&emb1, &emb3);

        assert!(sim_similar > sim_different);
    }

    #[tokio::test]
    async fn test_normalized_embeddings() {
        let engine = EmbeddingEngine::new(256);
        let embedding = engine.embed("some text here").await.unwrap();

        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    #[tokio::test]
    async fn test_empty_text_error() {
        let engine = EmbeddingEngine::new(64);
        assert!(engine.embed("").await.is_err());
    }

    /// The hash embedder lowercases input, so case must not change the vector.
    #[tokio::test]
    async fn test_embedding_case_insensitive() {
        let engine = EmbeddingEngine::new(64);
        let upper = engine.embed("Hello World").await.unwrap();
        let lower = engine.embed("hello world").await.unwrap();
        assert_eq!(upper, lower);
    }

    #[test]
    fn test_engine_reports_dimensions() {
        let engine = EmbeddingEngine::new(48);
        assert_eq!(engine.dimensions(), 48);
    }

    /// `embed_batch` must return one vector per input, in input order, each equal
    /// to what `embed` returns for that text.
    #[tokio::test]
    async fn test_embed_batch_matches_single_embeds() {
        let engine = EmbeddingEngine::new(64);
        let texts = ["alpha beta", "gamma delta epsilon"];
        let batch = engine.embed_batch(&texts).await.unwrap();
        assert_eq!(batch.len(), texts.len());
        for (text, vector) in texts.iter().zip(&batch) {
            assert_eq!(&engine.embed(text).await.unwrap(), vector);
        }
    }

    /// A single failing item fails the whole batch.
    #[tokio::test]
    async fn test_embed_batch_propagates_errors() {
        let engine = EmbeddingEngine::new(64);
        assert!(engine.embed_batch(&["ok", ""]).await.is_err());
    }

    /// Degenerate inputs map to 0.0; parallel and anti-parallel vectors map to ±1.
    #[test]
    fn test_similarity_edge_cases() {
        let engine = EmbeddingEngine::new(8);
        assert_eq!(engine.similarity(&[1.0, 0.0], &[1.0]), 0.0);
        assert_eq!(engine.similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
        assert_eq!(engine.similarity(&[], &[]), 0.0);
        assert!((engine.similarity(&[1.0, 2.0], &[2.0, 4.0]) - 1.0).abs() < 1e-6);
        assert!((engine.similarity(&[1.0, 0.0], &[-1.0, 0.0]) + 1.0).abs() < 1e-6);
    }
}