converge_knowledge/embedding/
mod.rs1mod openai;
29
30pub use openai::{OpenAIConfig, OpenAIEmbedding, OpenAIModel, UsageSnapshot, UsageStats};
31
32use crate::error::{Error, Result};
33use std::any::Any;
34use std::collections::hash_map::DefaultHasher;
35use std::hash::{Hash, Hasher};
36
37#[async_trait::async_trait]
39pub trait EmbeddingProvider: Any + Send + Sync {
40 async fn embed(&self, text: &str) -> Result<Vec<f32>>;
42
43 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
45 let mut embeddings = Vec::with_capacity(texts.len());
46 for text in texts {
47 embeddings.push(self.embed(text).await?);
48 }
49 Ok(embeddings)
50 }
51
52 fn dimensions(&self) -> usize;
54
55 fn as_any(&self) -> &dyn Any;
57}
58
59pub struct EmbeddingEngine {
63 provider: Box<dyn EmbeddingProvider>,
64}
65
66impl EmbeddingEngine {
67 pub fn new(dimensions: usize) -> Self {
72 Self {
73 provider: Box::new(HashEmbedding::new(dimensions)),
74 }
75 }
76
77 pub fn from_env() -> Self {
81 match OpenAIEmbedding::from_env() {
82 Ok(provider) => Self {
83 provider: Box::new(provider),
84 },
85 Err(_) => {
86 tracing::warn!("OPENAI_API_KEY not set, falling back to hash embeddings");
87 Self::new(1536) }
89 }
90 }
91
92 pub fn from_env_required() -> Result<Self> {
94 let provider = OpenAIEmbedding::from_env()?;
95 Ok(Self {
96 provider: Box::new(provider),
97 })
98 }
99
100 pub fn with_openai(api_key: impl Into<String>, model: Option<String>) -> Self {
102 Self {
103 provider: Box::new(OpenAIEmbedding::new(api_key, model)),
104 }
105 }
106
107 pub fn with_openai_config(api_key: impl Into<String>, config: OpenAIConfig) -> Self {
109 Self {
110 provider: Box::new(OpenAIEmbedding::with_config(api_key, config)),
111 }
112 }
113
114 pub fn with_provider(provider: Box<dyn EmbeddingProvider>) -> Self {
116 Self { provider }
117 }
118
119 pub fn dimensions(&self) -> usize {
121 self.provider.dimensions()
122 }
123
124 pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
126 if let Some(hash_provider) = self.as_hash_provider() {
129 hash_provider.embed_sync(text)
130 } else {
131 tokio::runtime::Handle::try_current()
134 .map(|h| h.block_on(self.provider.embed(text)))
135 .unwrap_or_else(|_| {
136 let hash = HashEmbedding::new(self.dimensions());
138 hash.embed_sync(text)
139 })
140 }
141 }
142
143 pub async fn embed_async(&self, text: &str) -> Result<Vec<f32>> {
145 self.provider.embed(text).await
146 }
147
148 pub async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
150 self.provider.embed_batch(texts).await
151 }
152
153 fn as_hash_provider(&self) -> Option<&HashEmbedding> {
155 self.provider.as_any().downcast_ref::<HashEmbedding>()
156 }
157
158 pub fn similarity(&self, a: &[f32], b: &[f32]) -> f32 {
160 if a.len() != b.len() {
161 return 0.0;
162 }
163
164 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
165 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
166 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
167
168 if norm_a == 0.0 || norm_b == 0.0 {
169 0.0
170 } else {
171 dot / (norm_a * norm_b)
172 }
173 }
174}
175
176pub struct HashEmbedding {
178 dimensions: usize,
179}
180
181impl HashEmbedding {
182 pub fn new(dimensions: usize) -> Self {
184 Self { dimensions }
185 }
186
187 pub fn embed_sync(&self, text: &str) -> Result<Vec<f32>> {
189 if text.is_empty() {
190 return Err(Error::embedding("Cannot embed empty text"));
191 }
192
193 let mut embedding = vec![0.0f32; self.dimensions];
194 let normalized_text = text.to_lowercase();
195
196 for word in normalized_text.split_whitespace() {
198 self.add_word_embedding(&mut embedding, word, 1.0);
199 }
200
201 let words: Vec<&str> = normalized_text.split_whitespace().collect();
203 for window in words.windows(2) {
204 let bigram = format!("{} {}", window[0], window[1]);
205 self.add_word_embedding(&mut embedding, &bigram, 0.5);
206 }
207
208 for window in words.windows(3) {
210 let trigram = format!("{} {} {}", window[0], window[1], window[2]);
211 self.add_word_embedding(&mut embedding, &trigram, 0.3);
212 }
213
214 for word in words.iter() {
216 for char_ngram in word.as_bytes().windows(3) {
217 let hash = self.hash_bytes(char_ngram);
218 let idx = (hash as usize) % self.dimensions;
219 embedding[idx] += 0.1;
220 }
221 }
222
223 self.normalize(&mut embedding);
225
226 Ok(embedding)
227 }
228
229 fn add_word_embedding(&self, embedding: &mut [f32], text: &str, weight: f32) {
230 let hash = self.hash_text(text);
231 for i in 0..8 {
232 let idx = ((hash.wrapping_add(i * 0x9e3779b9)) as usize) % self.dimensions;
233 let sign = if (hash >> i) & 1 == 0 { 1.0 } else { -1.0 };
234 embedding[idx] += sign * weight;
235 }
236 }
237
238 fn hash_text(&self, text: &str) -> u64 {
239 let mut hasher = DefaultHasher::new();
240 text.hash(&mut hasher);
241 hasher.finish()
242 }
243
244 fn hash_bytes(&self, bytes: &[u8]) -> u64 {
245 let mut hasher = DefaultHasher::new();
246 bytes.hash(&mut hasher);
247 hasher.finish()
248 }
249
250 fn normalize(&self, embedding: &mut [f32]) {
251 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
252 if norm > 0.0 {
253 for x in embedding.iter_mut() {
254 *x /= norm;
255 }
256 }
257 }
258}
259
260#[async_trait::async_trait]
261impl EmbeddingProvider for HashEmbedding {
262 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
263 self.embed_sync(text)
264 }
265
266 fn dimensions(&self) -> usize {
267 self.dimensions
268 }
269
270 fn as_any(&self) -> &dyn Any {
271 self
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_dimensions() {
        let engine = EmbeddingEngine::new(128);
        let embedding = engine.embed("test text").unwrap();
        assert_eq!(embedding.len(), 128);
    }

    #[test]
    fn test_engine_reports_dimensions() {
        assert_eq!(EmbeddingEngine::new(42).dimensions(), 42);
    }

    #[test]
    fn test_embedding_consistency() {
        let engine = EmbeddingEngine::new(64);
        let emb1 = engine.embed("hello world").unwrap();
        let emb2 = engine.embed("hello world").unwrap();
        assert_eq!(emb1, emb2);
    }

    #[test]
    fn test_case_and_whitespace_insensitive() {
        // Hash embeddings lowercase and split on whitespace before hashing.
        let engine = EmbeddingEngine::new(64);
        let reference = engine.embed("hello world").unwrap();
        assert_eq!(engine.embed("Hello WORLD").unwrap(), reference);
        assert_eq!(engine.embed("  hello \t world\n").unwrap(), reference);
    }

    #[test]
    fn test_embedding_similarity() {
        let engine = EmbeddingEngine::new(128);

        let emb1 = engine.embed("rust programming language").unwrap();
        let emb2 = engine.embed("rust programming").unwrap();
        let emb3 = engine.embed("cooking recipes").unwrap();

        let sim_similar = engine.similarity(&emb1, &emb2);
        let sim_different = engine.similarity(&emb1, &emb3);

        assert!(sim_similar > sim_different);
    }

    #[test]
    fn test_self_similarity_is_one() {
        let engine = EmbeddingEngine::new(128);
        let emb = engine.embed("knowledge graph").unwrap();
        assert!((engine.similarity(&emb, &emb) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_similarity_length_mismatch_is_zero() {
        let engine = EmbeddingEngine::new(8);
        assert_eq!(engine.similarity(&[1.0, 0.0], &[1.0, 0.0, 0.0]), 0.0);
    }

    #[test]
    fn test_similarity_zero_vector_is_zero() {
        let engine = EmbeddingEngine::new(8);
        assert_eq!(engine.similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
        assert_eq!(engine.similarity(&[1.0, 2.0], &[0.0, 0.0]), 0.0);
    }

    #[test]
    fn test_normalized_embeddings() {
        let engine = EmbeddingEngine::new(256);
        let embedding = engine.embed("some text here").unwrap();

        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_empty_text_error() {
        let engine = EmbeddingEngine::new(64);
        assert!(engine.embed("").is_err());
    }
}