converge_knowledge/embedding/
mod.rs1mod openai;
29
30pub use openai::{OpenAIConfig, OpenAIEmbedding, OpenAIModel, UsageSnapshot, UsageStats};
31
32use crate::error::{Error, Result};
33use std::collections::hash_map::DefaultHasher;
34use std::hash::{Hash, Hasher};
35
36#[async_trait::async_trait]
38pub trait EmbeddingProvider: Send + Sync {
39 async fn embed(&self, text: &str) -> Result<Vec<f32>>;
41
42 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
44 let mut embeddings = Vec::with_capacity(texts.len());
45 for text in texts {
46 embeddings.push(self.embed(text).await?);
47 }
48 Ok(embeddings)
49 }
50
51 fn dimensions(&self) -> usize;
53}
54
55pub struct EmbeddingEngine {
59 provider: Box<dyn EmbeddingProvider>,
60}
61
62impl EmbeddingEngine {
63 pub fn new(dimensions: usize) -> Self {
68 Self {
69 provider: Box::new(HashEmbedding::new(dimensions)),
70 }
71 }
72
73 pub fn from_env() -> Self {
77 match OpenAIEmbedding::from_env() {
78 Ok(provider) => Self {
79 provider: Box::new(provider),
80 },
81 Err(_) => {
82 tracing::warn!("OPENAI_API_KEY not set, falling back to hash embeddings");
83 Self::new(1536) }
85 }
86 }
87
88 pub fn from_env_required() -> Result<Self> {
90 let provider = OpenAIEmbedding::from_env()?;
91 Ok(Self {
92 provider: Box::new(provider),
93 })
94 }
95
96 pub fn with_openai(api_key: impl Into<String>, model: Option<String>) -> Self {
98 Self {
99 provider: Box::new(OpenAIEmbedding::new(api_key, model)),
100 }
101 }
102
103 pub fn with_openai_config(api_key: impl Into<String>, config: OpenAIConfig) -> Self {
105 Self {
106 provider: Box::new(OpenAIEmbedding::with_config(api_key, config)),
107 }
108 }
109
110 pub fn with_provider(provider: Box<dyn EmbeddingProvider>) -> Self {
112 Self { provider }
113 }
114
115 pub fn dimensions(&self) -> usize {
117 self.provider.dimensions()
118 }
119
120 pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
122 if let Some(hash_provider) = self.as_hash_provider() {
125 hash_provider.embed_sync(text)
126 } else {
127 let rt = tokio::runtime::Handle::try_current()
129 .map(|h| h.block_on(self.provider.embed(text)))
130 .unwrap_or_else(|_| {
131 let hash = HashEmbedding::new(self.dimensions());
133 hash.embed_sync(text)
134 });
135 rt
136 }
137 }
138
139 pub async fn embed_async(&self, text: &str) -> Result<Vec<f32>> {
141 self.provider.embed(text).await
142 }
143
144 pub async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
146 self.provider.embed_batch(texts).await
147 }
148
149 fn as_hash_provider(&self) -> Option<&HashEmbedding> {
151 None }
154
155 pub fn similarity(&self, a: &[f32], b: &[f32]) -> f32 {
157 if a.len() != b.len() {
158 return 0.0;
159 }
160
161 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
162 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
163 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
164
165 if norm_a == 0.0 || norm_b == 0.0 {
166 0.0
167 } else {
168 dot / (norm_a * norm_b)
169 }
170 }
171}
172
173pub struct HashEmbedding {
175 dimensions: usize,
176}
177
178impl HashEmbedding {
179 pub fn new(dimensions: usize) -> Self {
181 Self { dimensions }
182 }
183
184 pub fn embed_sync(&self, text: &str) -> Result<Vec<f32>> {
186 if text.is_empty() {
187 return Err(Error::embedding("Cannot embed empty text"));
188 }
189
190 let mut embedding = vec![0.0f32; self.dimensions];
191 let normalized_text = text.to_lowercase();
192
193 for word in normalized_text.split_whitespace() {
195 self.add_word_embedding(&mut embedding, word, 1.0);
196 }
197
198 let words: Vec<&str> = normalized_text.split_whitespace().collect();
200 for window in words.windows(2) {
201 let bigram = format!("{} {}", window[0], window[1]);
202 self.add_word_embedding(&mut embedding, &bigram, 0.5);
203 }
204
205 for window in words.windows(3) {
207 let trigram = format!("{} {} {}", window[0], window[1], window[2]);
208 self.add_word_embedding(&mut embedding, &trigram, 0.3);
209 }
210
211 for word in words.iter() {
213 for char_ngram in word.as_bytes().windows(3) {
214 let hash = self.hash_bytes(char_ngram);
215 let idx = (hash as usize) % self.dimensions;
216 embedding[idx] += 0.1;
217 }
218 }
219
220 self.normalize(&mut embedding);
222
223 Ok(embedding)
224 }
225
226 fn add_word_embedding(&self, embedding: &mut [f32], text: &str, weight: f32) {
227 let hash = self.hash_text(text);
228 for i in 0..8 {
229 let idx = ((hash.wrapping_add(i * 0x9e3779b9)) as usize) % self.dimensions;
230 let sign = if (hash >> i) & 1 == 0 { 1.0 } else { -1.0 };
231 embedding[idx] += sign * weight;
232 }
233 }
234
235 fn hash_text(&self, text: &str) -> u64 {
236 let mut hasher = DefaultHasher::new();
237 text.hash(&mut hasher);
238 hasher.finish()
239 }
240
241 fn hash_bytes(&self, bytes: &[u8]) -> u64 {
242 let mut hasher = DefaultHasher::new();
243 bytes.hash(&mut hasher);
244 hasher.finish()
245 }
246
247 fn normalize(&self, embedding: &mut [f32]) {
248 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
249 if norm > 0.0 {
250 for x in embedding.iter_mut() {
251 *x /= norm;
252 }
253 }
254 }
255}
256
257#[async_trait::async_trait]
258impl EmbeddingProvider for HashEmbedding {
259 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
260 self.embed_sync(text)
261 }
262
263 fn dimensions(&self) -> usize {
264 self.dimensions
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean (L2) length of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    #[test]
    fn test_embedding_dimensions() {
        let vector = EmbeddingEngine::new(128)
            .embed("test text")
            .expect("non-empty text should embed");
        assert_eq!(vector.len(), 128);
    }

    #[test]
    fn test_embedding_consistency() {
        let engine = EmbeddingEngine::new(64);
        let first = engine.embed("hello world").expect("first embed");
        let second = engine.embed("hello world").expect("second embed");
        assert_eq!(first, second);
    }

    #[test]
    fn test_embedding_similarity() {
        let engine = EmbeddingEngine::new(128);
        let embed = |s: &str| engine.embed(s).expect("non-empty text should embed");

        let anchor = embed("rust programming language");
        let close = embed("rust programming");
        let far = embed("cooking recipes");

        let near_score = engine.similarity(&anchor, &close);
        let far_score = engine.similarity(&anchor, &far);
        assert!(near_score > far_score);
    }

    #[test]
    fn test_normalized_embeddings() {
        let vector = EmbeddingEngine::new(256)
            .embed("some text here")
            .expect("non-empty text should embed");
        assert!((l2_norm(&vector) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_empty_text_error() {
        let outcome = EmbeddingEngine::new(64).embed("");
        assert!(outcome.is_err());
    }
}