use crate::vector_store::EMBEDDING_DIM;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

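/// Errors that can occur while generating or combining embeddings.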
#[derive(Debug, thiserror::Error)]
pub enum EmbeddingError {
    #[error("Invalid input: {0}")]
    InvalidInput(String),

    #[error("Model error: {0}")]
    Model(String),
}

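/// Convenience result type for embedding operations.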
pub type Result<T> = std::result::Result<T, EmbeddingError>;

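/// Common interface for embedding generators.
///
/// Implementations must be `Send + Sync` so they can be shared across threads.
/// A minimal usage sketch (assuming this module and an implementor are in scope):
///
/// ```ignore
/// let generator = SimpleEmbeddingGenerator::new();
/// let embedding = generator.generate("some text")?;
/// assert_eq!(embedding.len(), generator.dimension());
/// ```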
pub trait EmbeddingGenerator: Send + Sync {
    /// Generates an embedding vector for a single text.
    fn generate(&self, text: &str) -> Result<Vec<f32>>;

    /// Generates embeddings for multiple texts; the default implementation
    /// simply calls [`Self::generate`] for each text.
    fn generate_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        texts.iter().map(|text| self.generate(text)).collect()
    }

    /// Returns the dimensionality of the embeddings this generator produces.
    fn dimension(&self) -> usize;

    /// Averages a set of embeddings into a single unit-length vector,
    /// e.g. to derive a parent embedding from its child chunks.
    fn average_embeddings(&self, embeddings: &[Vec<f32>]) -> Result<Vec<f32>> {
        if embeddings.is_empty() {
            return Err(EmbeddingError::InvalidInput(
                "No embeddings to average".to_string(),
            ));
        }

        let dim = self.dimension();
        for emb in embeddings {
            if emb.len() != dim {
                return Err(EmbeddingError::InvalidInput(format!(
                    "Embedding dimension mismatch: expected {}, got {}",
                    dim,
                    emb.len()
                )));
            }
        }

        // Element-wise sum, then divide by the number of embeddings.
        let mut averaged = vec![0.0; dim];
        for emb in embeddings {
            for (i, val) in emb.iter().enumerate() {
                averaged[i] += val;
            }
        }

        let count = embeddings.len() as f32;
        for val in averaged.iter_mut() {
            *val /= count;
        }

        // Re-normalize to unit length so cosine similarity stays well-behaved.
        let magnitude: f32 = averaged.iter().map(|x| x * x).sum::<f32>().sqrt();
        if magnitude > 0.0 {
            averaged.iter_mut().for_each(|x| *x /= magnitude);
        }

        Ok(averaged)
    }
}

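/// A deterministic, hash-based embedding generator.
///
/// Useful for tests and offline development: the same input always yields the
/// same unit-length vector, but the values carry no semantic meaning.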
pub struct SimpleEmbeddingGenerator {
    dimension: usize,
}

impl SimpleEmbeddingGenerator {
    pub fn new() -> Self {
        Self {
            dimension: EMBEDDING_DIM,
        }
    }

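    /// Maps text to a pseudo-random but deterministic unit vector by hashing
    /// the input and expanding the hash with a linear congruential generator.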
    fn hash_to_embedding(&self, text: &str) -> Vec<f32> {
        let mut hasher = DefaultHasher::new();
        text.hash(&mut hasher);
        let base_hash = hasher.finish();

        let mut embedding = Vec::with_capacity(self.dimension);
        let mut seed = base_hash;

        for i in 0..self.dimension {
            // Advance a simple linear congruential generator to derive a new
            // pseudo-random value for this dimension.
            seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
            let val = ((seed >> 16) as f32) / 65536.0;
            // Shift and position-modulate the value; the normalization below
            // makes the absolute scale irrelevant.
            let normalized = (val * 2.0 - 1.0) * (1.0 + (i as f32 / self.dimension as f32).sin());
            embedding.push(normalized);
        }

        // Normalize to unit length so dot products behave like cosine similarity.
        let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if magnitude > 0.0 {
            embedding.iter_mut().for_each(|x| *x /= magnitude);
        }

        embedding
    }
}

impl Default for SimpleEmbeddingGenerator {
    fn default() -> Self {
        Self::new()
    }
}

impl EmbeddingGenerator for SimpleEmbeddingGenerator {
    fn generate(&self, text: &str) -> Result<Vec<f32>> {
        if text.is_empty() {
            return Err(EmbeddingError::InvalidInput("Empty text".to_string()));
        }

        Ok(self.hash_to_embedding(text))
    }

    fn dimension(&self) -> usize {
        self.dimension
    }
}

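/// Embedding generator backed by an ONNX sentence-transformer model
/// (all-MiniLM-L6-v2, 384-dimensional output) and a BERT WordPiece tokenizer.
/// Only available when the crate is built with the `onnx` feature.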
#[cfg(feature = "onnx")]
pub struct OnnxEmbeddingGenerator {
    session: std::sync::Mutex<ort::session::Session>,
    tokenizer: rust_tokenizers::tokenizer::BertTokenizer,
    dimension: usize,
}

#[cfg(feature = "onnx")]
impl OnnxEmbeddingGenerator {
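    /// Loads the ONNX model from `models/all-minilm-l6-v2.onnx` and the
    /// tokenizer vocabulary from `models/vocab.txt`, both resolved relative to
    /// the crate root (`CARGO_MANIFEST_DIR`).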
    pub fn new() -> Result<Self> {
        let model_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("models")
            .join("all-minilm-l6-v2.onnx");

        let session = ort::session::Session::builder()
            .map_err(|e| EmbeddingError::Model(format!("Failed to create session builder: {}", e)))?
            .commit_from_file(&model_path)
            .map_err(|e| EmbeddingError::Model(format!("Failed to load model: {}", e)))?;

        use rust_tokenizers::tokenizer::BertTokenizer;
        use rust_tokenizers::vocab::{BertVocab, Vocab};

        let vocab_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("models")
            .join("vocab.txt");

        let vocab = BertVocab::from_file(&vocab_path)
            .map_err(|e| EmbeddingError::Model(format!("Failed to load vocab: {}", e)))?;

        // lower_case = true, strip_accents = true.
        let tokenizer = BertTokenizer::from_existing_vocab(vocab, true, true);

        Ok(Self {
            session: std::sync::Mutex::new(session),
            tokenizer,
            dimension: 384, // output dimension of all-MiniLM-L6-v2
        })
    }

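    /// Mean-pools token embeddings weighted by the attention mask (so masked
    /// positions are ignored) and L2-normalizes the result, the pooling step
    /// commonly used with sentence-transformer models.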
    fn mean_pooling(
        &self,
        token_embeddings: &ndarray::ArrayD<f32>,
        attention_mask: &[i64],
    ) -> Vec<f32> {
        // Expected shape: [batch = 1, seq_len, hidden_dim].
        let shape = token_embeddings.shape();
        let seq_len = shape[1];
        let hidden_dim = shape[2];

        let mut pooled = vec![0.0f32; hidden_dim];
        let mut mask_sum = 0.0f32;

        for i in 0..seq_len {
            let mask_val = attention_mask[i] as f32;
            mask_sum += mask_val;

            for j in 0..hidden_dim {
                pooled[j] += token_embeddings[[0, i, j]] * mask_val;
            }
        }

        // Divide by the number of attended tokens to get the mean.
        if mask_sum > 0.0 {
            for val in pooled.iter_mut() {
                *val /= mask_sum;
            }
        }

        // L2-normalize so dot products correspond to cosine similarity.
        let norm: f32 = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            pooled.iter_mut().for_each(|x| *x /= norm);
        }

        pooled
    }
}

#[cfg(feature = "onnx")]
impl EmbeddingGenerator for OnnxEmbeddingGenerator {
    fn generate(&self, text: &str) -> Result<Vec<f32>> {
        if text.is_empty() {
            return Err(EmbeddingError::InvalidInput("Empty text".to_string()));
        }

        use rust_tokenizers::tokenizer::{Tokenizer, TruncationStrategy};

        let tokenized = self.tokenizer.encode(
            text,
            None,
            512, // maximum sequence length
            &TruncationStrategy::LongestFirst,
            0, // stride
        );

        let input_ids: Vec<i64> = tokenized.token_ids.iter().map(|&x| x as i64).collect();
        // A single unpadded sequence, so the attention mask is all ones.
        let attention_mask: Vec<i64> = tokenized.segment_ids.iter().map(|_| 1i64).collect();
        let token_type_ids: Vec<i64> = tokenized.segment_ids.iter().map(|&x| x as i64).collect();

        let seq_len = input_ids.len();

        // All three inputs share the shape [1, seq_len].
        let input_ids_shape = ort::tensor::Shape::from(vec![1usize, seq_len]);
        let input_ids_ref =
            ort::value::TensorRef::from_array_view((input_ids_shape.clone(), input_ids.as_slice()))
                .map_err(|e| {
                    EmbeddingError::Model(format!("Failed to create input_ids tensor: {}", e))
                })?;

        let attention_mask_ref = ort::value::TensorRef::from_array_view((
            input_ids_shape.clone(),
            attention_mask.as_slice(),
        ))
        .map_err(|e| {
            EmbeddingError::Model(format!("Failed to create attention_mask tensor: {}", e))
        })?;

        let token_type_ids_ref =
            ort::value::TensorRef::from_array_view((input_ids_shape, token_type_ids.as_slice()))
                .map_err(|e| {
                    EmbeddingError::Model(format!("Failed to create token_type_ids tensor: {}", e))
                })?;

        let mut session = self
            .session
            .lock()
            .map_err(|e| EmbeddingError::Model(format!("Failed to lock session: {}", e)))?;

        let outputs = session
            .run(ort::inputs![
                "input_ids" => input_ids_ref,
                "attention_mask" => attention_mask_ref,
                "token_type_ids" => token_type_ids_ref
            ])
            .map_err(|e| EmbeddingError::Model(format!("Inference failed: {}", e)))?;

        // The last hidden state may be exported under different output names.
        let output_tensor = outputs
            .get("last_hidden_state")
            .or_else(|| outputs.get("output"))
            .unwrap_or(&outputs[0])
            .try_extract_tensor::<f32>()
            .map_err(|e| EmbeddingError::Model(format!("Failed to extract tensor: {}", e)))?;

        let (_shape, data) = output_tensor;

        // Reshape the flat output into [1, seq_len, hidden_dim] for pooling.
        use ndarray::ArrayD;
        let array = ArrayD::from_shape_vec(vec![1, seq_len, self.dimension], data.to_vec())
            .map_err(|e| EmbeddingError::Model(format!("Failed to reshape output: {}", e)))?;

        let embedding = self.mean_pooling(&array, &attention_mask);

        Ok(embedding)
    }

    fn dimension(&self) -> usize {
        self.dimension
    }

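    /// Batch generation for the ONNX backend currently embeds each text with a
    /// separate forward pass rather than a single padded batch.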
    fn generate_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        if texts.len() == 1 {
            return Ok(vec![self.generate(texts[0])?]);
        }

        let mut embeddings = Vec::with_capacity(texts.len());
        for text in texts {
            embeddings.push(self.generate(text)?);
        }

        Ok(embeddings)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_generator_dimension() {
        let generator = SimpleEmbeddingGenerator::new();
        assert_eq!(generator.dimension(), EMBEDDING_DIM);
    }

    #[test]
    fn test_simple_generator_basic() {
        let generator = SimpleEmbeddingGenerator::new();

        let text = "This is a test document";
        let embedding = generator.generate(text).unwrap();

        assert_eq!(embedding.len(), EMBEDDING_DIM);

        // The embedding should be normalized to unit length.
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_simple_generator_deterministic() {
        let generator = SimpleEmbeddingGenerator::new();

        let text = "Hello world";
        let embedding1 = generator.generate(text).unwrap();
        let embedding2 = generator.generate(text).unwrap();

        assert_eq!(embedding1, embedding2);
    }

    #[test]
    fn test_simple_generator_different_texts() {
        let generator = SimpleEmbeddingGenerator::new();

        let text1 = "First document";
        let text2 = "Second document";

        let embedding1 = generator.generate(text1).unwrap();
        let embedding2 = generator.generate(text2).unwrap();

        assert_ne!(embedding1, embedding2);
    }

    #[test]
    fn test_simple_generator_empty_text() {
        let generator = SimpleEmbeddingGenerator::new();

        let result = generator.generate("");
        assert!(result.is_err());
    }

    #[test]
    fn test_batch_generation() {
        let generator = SimpleEmbeddingGenerator::new();

        let texts = vec!["First", "Second", "Third"];
        let embeddings = generator.generate_batch(&texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        for embedding in embeddings {
            assert_eq!(embedding.len(), EMBEDDING_DIM);
        }
    }

    #[test]
    fn test_similar_texts_produce_similar_embeddings() {
        let generator = SimpleEmbeddingGenerator::new();

        let text1 = "The quick brown fox";
        let text2 = "The quick brown fox jumps";

        let embedding1 = generator.generate(text1).unwrap();
        let embedding2 = generator.generate(text2).unwrap();

        let dot_product: f32 = embedding1
            .iter()
            .zip(embedding2.iter())
            .map(|(a, b)| a * b)
            .sum();

        // The hash-based generator is not semantic, so we only check that the
        // cosine similarity of two unit vectors stays within its valid range.
        assert!(dot_product.abs() <= 1.0);
    }

    #[test]
    fn test_average_embeddings_basic() {
        let generator = SimpleEmbeddingGenerator::new();

        let text1 = "First chunk";
        let text2 = "Second chunk";
        let text3 = "Third chunk";

        let emb1 = generator.generate(text1).unwrap();
        let emb2 = generator.generate(text2).unwrap();
        let emb3 = generator.generate(text3).unwrap();

        let embeddings = vec![emb1, emb2, emb3];
        let averaged = generator.average_embeddings(&embeddings).unwrap();

        assert_eq!(averaged.len(), EMBEDDING_DIM);

        // The averaged embedding should also be unit-length.
        let norm: f32 = averaged.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_average_embeddings_single() {
        let generator = SimpleEmbeddingGenerator::new();

        let text = "Single chunk";
        let embedding = generator.generate(text).unwrap();

        let embeddings = vec![embedding.clone()];
        let averaged = generator.average_embeddings(&embeddings).unwrap();

        assert_eq!(averaged.len(), embedding.len());

        let norm: f32 = averaged.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_average_embeddings_empty() {
        let generator = SimpleEmbeddingGenerator::new();

        let embeddings: Vec<Vec<f32>> = vec![];
        let result = generator.average_embeddings(&embeddings);

        assert!(result.is_err());
        match result {
            Err(EmbeddingError::InvalidInput(msg)) => {
                assert_eq!(msg, "No embeddings to average");
            }
            _ => panic!("Expected InvalidInput error"),
        }
    }

    #[test]
    fn test_average_embeddings_dimension_mismatch() {
        let generator = SimpleEmbeddingGenerator::new();

        let emb1 = generator.generate("First").unwrap();
        let emb2 = vec![0.5; 128]; // deliberately the wrong dimension

        let embeddings = vec![emb1, emb2];
        let result = generator.average_embeddings(&embeddings);

        assert!(result.is_err());
        match result {
            Err(EmbeddingError::InvalidInput(msg)) => {
                assert!(msg.contains("dimension mismatch"));
            }
            _ => panic!("Expected InvalidInput error for dimension mismatch"),
        }
    }

    #[test]
    fn test_average_embeddings_hierarchical() {
        let generator = SimpleEmbeddingGenerator::new();

        // Simulate deriving a parent embedding from three child chunks.
        let child1 = generator.generate("Child chunk 1 content").unwrap();
        let child2 = generator.generate("Child chunk 2 content").unwrap();
        let child3 = generator.generate("Child chunk 3 content").unwrap();

        let children = vec![child1.clone(), child2.clone(), child3.clone()];
        let parent_embedding = generator.average_embeddings(&children).unwrap();

        // The parent embedding is unit-length ...
        let norm: f32 = parent_embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);

        // ... distinct from each child ...
        assert_ne!(parent_embedding, child1);
        assert_ne!(parent_embedding, child2);
        assert_ne!(parent_embedding, child3);

        // ... and has a valid, positive cosine similarity to each of them.
        let similarity1: f32 = parent_embedding
            .iter()
            .zip(child1.iter())
            .map(|(a, b)| a * b)
            .sum();
        let similarity2: f32 = parent_embedding
            .iter()
            .zip(child2.iter())
            .map(|(a, b)| a * b)
            .sum();
        let similarity3: f32 = parent_embedding
            .iter()
            .zip(child3.iter())
            .map(|(a, b)| a * b)
            .sum();

        assert!(similarity1 > 0.0 && similarity1 <= 1.0);
        assert!(similarity2 > 0.0 && similarity2 <= 1.0);
        assert!(similarity3 > 0.0 && similarity3 <= 1.0);
    }

    #[cfg(feature = "onnx")]
    #[test]
    #[ignore] // requires the ONNX model and vocab files under models/
    fn test_onnx_generator_basic() {
        let generator = OnnxEmbeddingGenerator::new().expect("Failed to create ONNX generator");

        let text = "This is a test sentence";
        let embedding = generator.generate(text).unwrap();

        assert_eq!(embedding.len(), 384);

        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001, "Norm was {}", norm);
    }

    #[cfg(feature = "onnx")]
    #[test]
    #[ignore] // requires the ONNX model and vocab files under models/
    fn test_onnx_semantic_similarity() {
        let generator = OnnxEmbeddingGenerator::new().expect("Failed to create ONNX generator");

        let text1 = "I love programming in Rust";
        let text2 = "Rust programming is great";
        let text3 = "The weather is sunny today";

        let emb1 = generator.generate(text1).unwrap();
        let emb2 = generator.generate(text2).unwrap();
        let emb3 = generator.generate(text3).unwrap();

        let sim_1_2: f32 = emb1.iter().zip(emb2.iter()).map(|(a, b)| a * b).sum();
        let sim_1_3: f32 = emb1.iter().zip(emb3.iter()).map(|(a, b)| a * b).sum();

        assert!(
            sim_1_2 > sim_1_3,
            "Similar sentences should have higher cosine similarity"
        );
        println!("Similarity (Rust/Rust): {:.4}", sim_1_2);
        println!("Similarity (Rust/Weather): {:.4}", sim_1_3);

        assert!(
            sim_1_2 > 0.5,
            "Similar sentences should have similarity > 0.5"
        );
    }

    #[cfg(feature = "onnx")]
    #[test]
    #[ignore] // requires the ONNX model and vocab files under models/
    fn test_onnx_vector_ops() {
        let generator = OnnxEmbeddingGenerator::new().expect("Failed to create ONNX generator");

        let question = "How do I create a vector in Rust?";
        let answer1 = "Use Vec::new() to create an empty vector";
        let answer2 = "The vec! macro creates a vector with initial values";
        let unrelated = "Python is a popular programming language";

        let q_emb = generator.generate(question).unwrap();
        let a1_emb = generator.generate(answer1).unwrap();
        let a2_emb = generator.generate(answer2).unwrap();
        let un_emb = generator.generate(unrelated).unwrap();

        let sim_q_a1: f32 = q_emb.iter().zip(a1_emb.iter()).map(|(a, b)| a * b).sum();
        let sim_q_a2: f32 = q_emb.iter().zip(a2_emb.iter()).map(|(a, b)| a * b).sum();
        let sim_q_un: f32 = q_emb.iter().zip(un_emb.iter()).map(|(a, b)| a * b).sum();

        println!("Question-Answer1 similarity: {:.4}", sim_q_a1);
        println!("Question-Answer2 similarity: {:.4}", sim_q_a2);
        println!("Question-Unrelated similarity: {:.4}", sim_q_un);

        // Both genuine answers should be closer to the question than the unrelated text.
        assert!(sim_q_a1 > sim_q_un);
        assert!(sim_q_a2 > sim_q_un);
    }
}