frame_catalog/embeddings.rs

//! Embedding generation for text chunks
//!
//! Provides text embedding functionality for vector search

use crate::vector_store::EMBEDDING_DIM;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Embedding generator error
#[derive(Debug, thiserror::Error)]
pub enum EmbeddingError {
    #[error("Invalid input: {0}")]
    InvalidInput(String),

    #[error("Model error: {0}")]
    Model(String),
}

pub type Result<T> = std::result::Result<T, EmbeddingError>;

/// Trait for embedding generators
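///
/// # Example
///
/// A minimal usage sketch with the hash-based generator defined below
/// (the `frame_catalog::embeddings` module path is assumed):
///
/// ```ignore
/// use frame_catalog::embeddings::{EmbeddingGenerator, SimpleEmbeddingGenerator};
///
/// let generator = SimpleEmbeddingGenerator::new();
/// let embedding = generator.generate("hello world").unwrap();
/// assert_eq!(embedding.len(), generator.dimension());
/// ```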
pub trait EmbeddingGenerator: Send + Sync {
    /// Generate an embedding for the given text
    fn generate(&self, text: &str) -> Result<Vec<f32>>;

    /// Generate embeddings for multiple texts (batch processing)
    fn generate_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        texts.iter().map(|text| self.generate(text)).collect()
    }

    /// Get the embedding dimension
    fn dimension(&self) -> usize;

    /// Average multiple embeddings (for hierarchical parent chunks)
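    ///
    /// Returns the element-wise mean of the inputs, re-normalized to unit
    /// length. For example, averaging the unit vectors `[1, 0]` and `[0, 1]`
    /// gives `[0.5, 0.5]`, which re-normalizes to roughly `[0.707, 0.707]`.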
    fn average_embeddings(&self, embeddings: &[Vec<f32>]) -> Result<Vec<f32>> {
        if embeddings.is_empty() {
            return Err(EmbeddingError::InvalidInput(
                "No embeddings to average".to_string(),
            ));
        }

        let dim = self.dimension();
        for emb in embeddings {
            if emb.len() != dim {
                return Err(EmbeddingError::InvalidInput(format!(
                    "Embedding dimension mismatch: expected {}, got {}",
                    dim,
                    emb.len()
                )));
            }
        }

        // Sum each dimension across all embeddings
        let mut averaged = vec![0.0; dim];
        for emb in embeddings {
            for (i, val) in emb.iter().enumerate() {
                averaged[i] += val;
            }
        }

        // Divide by the count to get the mean
        let count = embeddings.len() as f32;
        for val in averaged.iter_mut() {
            *val /= count;
        }

        // Re-normalize the averaged vector to unit length
        let magnitude: f32 = averaged.iter().map(|x| x * x).sum::<f32>().sqrt();
        if magnitude > 0.0 {
            averaged.iter_mut().for_each(|x| *x /= magnitude);
        }

        Ok(averaged)
    }
}

/// Simple hash-based embedding generator (for testing/fallback)
///
/// This creates deterministic but not semantically meaningful embeddings.
/// In production, use a proper embedding model.
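///
/// ```ignore
/// // The same text always maps to the same unit-length vector:
/// let generator = SimpleEmbeddingGenerator::new();
/// assert_eq!(
///     generator.generate("hello").unwrap(),
///     generator.generate("hello").unwrap(),
/// );
/// ```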
pub struct SimpleEmbeddingGenerator {
    dimension: usize,
}

impl SimpleEmbeddingGenerator {
    pub fn new() -> Self {
        Self {
            dimension: EMBEDDING_DIM,
        }
    }

    /// Create embeddings using a deterministic hash function
    fn hash_to_embedding(&self, text: &str) -> Vec<f32> {
        let mut hasher = DefaultHasher::new();
        text.hash(&mut hasher);
        let base_hash = hasher.finish();

        // Generate deterministic values seeded by the hash
        let mut embedding = Vec::with_capacity(self.dimension);
        let mut seed = base_hash;

        for i in 0..self.dimension {
            // Step a simple LCG (Linear Congruential Generator); the constants
            // are the classic C-library rand() multiplier and increment
            seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
            // Take 16 bits of state and map them to [0, 1)
            let val = ((seed >> 16) & 0xFFFF) as f32 / 65536.0;

            // Map to [-1, 1) and vary the amplitude across dimensions so the
            // vector has some per-dimension structure
            let normalized = (val * 2.0 - 1.0) * (1.0 + (i as f32 / self.dimension as f32).sin());
            embedding.push(normalized);
        }

        // Normalize the vector to unit length
        let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if magnitude > 0.0 {
            embedding.iter_mut().for_each(|x| *x /= magnitude);
        }

        embedding
    }
}

impl Default for SimpleEmbeddingGenerator {
    fn default() -> Self {
        Self::new()
    }
}

impl EmbeddingGenerator for SimpleEmbeddingGenerator {
    fn generate(&self, text: &str) -> Result<Vec<f32>> {
        if text.is_empty() {
            return Err(EmbeddingError::InvalidInput("Empty text".to_string()));
        }

        Ok(self.hash_to_embedding(text))
    }

    fn dimension(&self) -> usize {
        self.dimension
    }
}

/// ONNX-based embedding generator using a sentence-transformers model
///
/// Uses the all-MiniLM-L6-v2 model to produce 384-dimensional semantic embeddings.
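///
/// # Example
///
/// Illustrative only; requires the model files described in
/// [`OnnxEmbeddingGenerator::new`]:
///
/// ```ignore
/// let generator = OnnxEmbeddingGenerator::new()?;
/// let embedding = generator.generate("How do I create a vector in Rust?")?;
/// assert_eq!(embedding.len(), 384);
/// ```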
#[cfg(feature = "onnx")]
pub struct OnnxEmbeddingGenerator {
    session: std::sync::Mutex<ort::session::Session>,
    tokenizer: rust_tokenizers::tokenizer::BertTokenizer,
    dimension: usize,
}

#[cfg(feature = "onnx")]
impl OnnxEmbeddingGenerator {
    /// Create a new ONNX embedding generator
    ///
    /// The model and tokenizer files are expected in the crate's `models/` directory.
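    ///
    /// The expected files are `all-minilm-l6-v2.onnx` and `vocab.txt`, e.g.
    /// an ONNX export of the Hugging Face
    /// `sentence-transformers/all-MiniLM-L6-v2` model and its vocabulary.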
    pub fn new() -> Result<Self> {
        let model_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("models")
            .join("all-minilm-l6-v2.onnx");

        // Load the ONNX model with the ort 2.0 API
        let session = ort::session::Session::builder()
            .map_err(|e| EmbeddingError::Model(format!("Failed to create session builder: {}", e)))?
            .commit_from_file(&model_path)
            .map_err(|e| EmbeddingError::Model(format!("Failed to load model: {}", e)))?;

        // Load the tokenizer using rust_tokenizers (pure Rust, no native dependencies)
        use rust_tokenizers::tokenizer::BertTokenizer;
        use rust_tokenizers::vocab::{BertVocab, Vocab};

        // rust_tokenizers reads a plain vocab.txt (not tokenizer.json), so we
        // load the vocabulary file directly
        let vocab_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("models")
            .join("vocab.txt");

        let vocab = BertVocab::from_file(&vocab_path)
            .map_err(|e| EmbeddingError::Model(format!("Failed to load vocab: {}", e)))?;

        // lower_case = true, strip_accents = true (matches the uncased MiniLM tokenizer)
        let tokenizer = BertTokenizer::from_existing_vocab(vocab, true, true);

        Ok(Self {
            session: std::sync::Mutex::new(session),
            tokenizer,
            dimension: 384, // all-MiniLM-L6-v2 output dimension
        })
    }

    /// Mean pooling over token embeddings
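    ///
    /// For each hidden dimension `j` this computes
    /// `pooled[j] = sum_i(mask[i] * emb[i][j]) / sum_i(mask[i])`
    /// over sequence positions `i`, then L2-normalizes the result, following
    /// the usual sentence-transformers pooling recipe.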
    fn mean_pooling(
        &self,
        token_embeddings: &ndarray::ArrayD<f32>,
        attention_mask: &[i64],
    ) -> Vec<f32> {
        let shape = token_embeddings.shape();
        let seq_len = shape[1];
        let hidden_dim = shape[2];

        let mut pooled = vec![0.0f32; hidden_dim];
        let mut mask_sum = 0.0f32;

        for i in 0..seq_len {
            let mask_val = attention_mask[i] as f32;
            mask_sum += mask_val;

            for j in 0..hidden_dim {
                pooled[j] += token_embeddings[[0, i, j]] * mask_val;
            }
        }

        // Divide by sum of mask (number of actual tokens)
        if mask_sum > 0.0 {
            for val in pooled.iter_mut() {
                *val /= mask_sum;
            }
        }

        // L2 normalize
        let norm: f32 = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            pooled.iter_mut().for_each(|x| *x /= norm);
        }

        pooled
    }
}

#[cfg(feature = "onnx")]
impl EmbeddingGenerator for OnnxEmbeddingGenerator {
    fn generate(&self, text: &str) -> Result<Vec<f32>> {
        if text.is_empty() {
            return Err(EmbeddingError::InvalidInput("Empty text".to_string()));
        }

        // Tokenize using the rust_tokenizers API
        use rust_tokenizers::tokenizer::{Tokenizer, TruncationStrategy};

        let tokenized = self.tokenizer.encode(
            text,
            None,
            512, // max_len
            &TruncationStrategy::LongestFirst,
            0, // stride
        );

        let input_ids: Vec<i64> = tokenized.token_ids.iter().map(|&x| x as i64).collect();
        // Single-sequence input with no padding: every position is a real token
        let attention_mask: Vec<i64> = tokenized.segment_ids.iter().map(|_| 1i64).collect();
        let token_type_ids: Vec<i64> = tokenized.segment_ids.iter().map(|&x| x as i64).collect();

        // Create ONNX input tensors using the ort 2.0 API
        let seq_len = input_ids.len();

        // Create a TensorRef for each input
        let input_ids_shape = ort::tensor::Shape::from(vec![1usize, seq_len]);
        let input_ids_ref =
            ort::value::TensorRef::from_array_view((input_ids_shape.clone(), input_ids.as_slice()))
                .map_err(|e| {
                    EmbeddingError::Model(format!("Failed to create input_ids tensor: {}", e))
                })?;

        let attention_mask_ref = ort::value::TensorRef::from_array_view((
            input_ids_shape.clone(),
            attention_mask.as_slice(),
        ))
        .map_err(|e| {
            EmbeddingError::Model(format!("Failed to create attention_mask tensor: {}", e))
        })?;

        let token_type_ids_ref =
            ort::value::TensorRef::from_array_view((input_ids_shape, token_type_ids.as_slice()))
                .map_err(|e| {
                    EmbeddingError::Model(format!("Failed to create token_type_ids tensor: {}", e))
                })?;

        // Lock the session mutex for inference
        let mut session = self
            .session
            .lock()
            .map_err(|e| EmbeddingError::Model(format!("Failed to lock session: {}", e)))?;

        // Run inference using the ort 2.0 inputs! macro with named inputs
        let outputs = session
            .run(ort::inputs![
                "input_ids" => input_ids_ref,
                "attention_mask" => attention_mask_ref,
                "token_type_ids" => token_type_ids_ref
            ])
            .map_err(|e| EmbeddingError::Model(format!("Inference failed: {}", e)))?;

        // Extract the token embeddings (last hidden state)
        let output_tensor = outputs
            .get("last_hidden_state")
            .or_else(|| outputs.get("output"))
            .unwrap_or(&outputs[0])
            .try_extract_tensor::<f32>()
            .map_err(|e| EmbeddingError::Model(format!("Failed to extract tensor: {}", e)))?;

        // Destructure the tuple: try_extract_tensor returns (&Shape, &[f32])
        let (_shape, data) = output_tensor;

        // Convert to ndarray for mean pooling
        use ndarray::ArrayD;
        let array = ArrayD::from_shape_vec(vec![1, seq_len, self.dimension], data.to_vec())
            .map_err(|e| EmbeddingError::Model(format!("Failed to reshape output: {}", e)))?;

        // Mean-pool the token embeddings into a single sentence vector
        let embedding = self.mean_pooling(&array, &attention_mask);

        Ok(embedding)
    }

    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Generate embeddings for multiple texts efficiently (batch processing)
    fn generate_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        // A single text needs no batch handling
        if texts.len() == 1 {
            return Ok(vec![self.generate(texts[0])?]);
        }

        // Process sequentially for now: true batched inference would require
        // padding every sequence to a common length and running the model once
        // with a batch dimension > 1, which complicates tokenization and masking.
        let mut embeddings = Vec::with_capacity(texts.len());
        for text in texts {
            embeddings.push(self.generate(text)?);
        }

        Ok(embeddings)
    }
}
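
// Illustrative sketch (not part of the original API): one way a caller might
// prefer the semantic ONNX generator and fall back to the hash-based one when
// the feature is disabled or the model fails to load. The `make_generator`
// name is hypothetical.
#[allow(dead_code)]
fn make_generator() -> Box<dyn EmbeddingGenerator> {
    #[cfg(feature = "onnx")]
    fn try_onnx() -> Option<Box<dyn EmbeddingGenerator>> {
        OnnxEmbeddingGenerator::new()
            .ok()
            .map(|g| Box::new(g) as Box<dyn EmbeddingGenerator>)
    }

    #[cfg(not(feature = "onnx"))]
    fn try_onnx() -> Option<Box<dyn EmbeddingGenerator>> {
        None
    }

    try_onnx().unwrap_or_else(|| Box::new(SimpleEmbeddingGenerator::new()))
}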

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_generator_dimension() {
        let generator = SimpleEmbeddingGenerator::new();
        assert_eq!(generator.dimension(), EMBEDDING_DIM);
    }

    #[test]
    fn test_simple_generator_basic() {
        let generator = SimpleEmbeddingGenerator::new();

        let text = "This is a test document";
        let embedding = generator.generate(text).unwrap();

        assert_eq!(embedding.len(), EMBEDDING_DIM);

        // Check that it's normalized (L2 norm should be close to 1)
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_simple_generator_deterministic() {
        let generator = SimpleEmbeddingGenerator::new();

        let text = "Hello world";
        let embedding1 = generator.generate(text).unwrap();
        let embedding2 = generator.generate(text).unwrap();

        // Should produce the same embedding for the same text
        assert_eq!(embedding1, embedding2);
    }

    #[test]
    fn test_simple_generator_different_texts() {
        let generator = SimpleEmbeddingGenerator::new();

        let text1 = "First document";
        let text2 = "Second document";

        let embedding1 = generator.generate(text1).unwrap();
        let embedding2 = generator.generate(text2).unwrap();

        // Different texts should produce different embeddings
        assert_ne!(embedding1, embedding2);
    }

    #[test]
    fn test_simple_generator_empty_text() {
        let generator = SimpleEmbeddingGenerator::new();

        let result = generator.generate("");
        assert!(result.is_err());
    }

    #[test]
    fn test_batch_generation() {
        let generator = SimpleEmbeddingGenerator::new();

        let texts = vec!["First", "Second", "Third"];
        let embeddings = generator.generate_batch(&texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        for embedding in embeddings {
            assert_eq!(embedding.len(), EMBEDDING_DIM);
        }
    }

    #[test]
    fn test_similar_texts_produce_similar_embeddings() {
        let generator = SimpleEmbeddingGenerator::new();

        let text1 = "The quick brown fox";
        let text2 = "The quick brown fox jumps";

        let embedding1 = generator.generate(text1).unwrap();
        let embedding2 = generator.generate(text2).unwrap();

        // Calculate cosine similarity
        let dot_product: f32 = embedding1
            .iter()
            .zip(embedding2.iter())
            .map(|(a, b)| a * b)
            .sum();

        // Embeddings are normalized, so the dot product equals cosine similarity.
        // Hash-based embeddings carry no semantics, so we only check the range.
        assert!(dot_product.abs() <= 1.0);
    }

    #[test]
    fn test_average_embeddings_basic() {
        let generator = SimpleEmbeddingGenerator::new();

        // Generate 3 embeddings
        let text1 = "First chunk";
        let text2 = "Second chunk";
        let text3 = "Third chunk";

        let emb1 = generator.generate(text1).unwrap();
        let emb2 = generator.generate(text2).unwrap();
        let emb3 = generator.generate(text3).unwrap();

        let embeddings = vec![emb1, emb2, emb3];
        let averaged = generator.average_embeddings(&embeddings).unwrap();

        // Check dimension
        assert_eq!(averaged.len(), EMBEDDING_DIM);

        // Check that it's normalized (L2 norm should be close to 1)
        let norm: f32 = averaged.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_average_embeddings_single() {
        let generator = SimpleEmbeddingGenerator::new();

        let text = "Single chunk";
        let embedding = generator.generate(text).unwrap();

        let embeddings = vec![embedding.clone()];
        let averaged = generator.average_embeddings(&embeddings).unwrap();

        // Averaging a single embedding should return the same embedding
        assert_eq!(averaged.len(), embedding.len());

        // Check that it's still normalized
        let norm: f32 = averaged.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_average_embeddings_empty() {
        let generator = SimpleEmbeddingGenerator::new();

        let embeddings: Vec<Vec<f32>> = vec![];
        let result = generator.average_embeddings(&embeddings);

        assert!(result.is_err());
        match result {
            Err(EmbeddingError::InvalidInput(msg)) => {
                assert_eq!(msg, "No embeddings to average");
            }
            _ => panic!("Expected InvalidInput error"),
        }
    }

    #[test]
    fn test_average_embeddings_dimension_mismatch() {
        let generator = SimpleEmbeddingGenerator::new();

        let emb1 = generator.generate("First").unwrap();
        let emb2 = vec![0.5; 128]; // Wrong dimension

        let embeddings = vec![emb1, emb2];
        let result = generator.average_embeddings(&embeddings);

        assert!(result.is_err());
        match result {
            Err(EmbeddingError::InvalidInput(msg)) => {
                assert!(msg.contains("dimension mismatch"));
            }
            _ => panic!("Expected InvalidInput error for dimension mismatch"),
        }
    }

    #[test]
    fn test_average_embeddings_hierarchical() {
        let generator = SimpleEmbeddingGenerator::new();

        // Simulate hierarchical scenario: parent chunk has 3 children
        let child1 = generator.generate("Child chunk 1 content").unwrap();
        let child2 = generator.generate("Child chunk 2 content").unwrap();
        let child3 = generator.generate("Child chunk 3 content").unwrap();

        let children = vec![child1.clone(), child2.clone(), child3.clone()];
        let parent_embedding = generator.average_embeddings(&children).unwrap();

        // Parent embedding should be normalized
        let norm: f32 = parent_embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);

        // Parent embedding should be different from any single child
        assert_ne!(parent_embedding, child1);
        assert_ne!(parent_embedding, child2);
        assert_ne!(parent_embedding, child3);

        // Calculate cosine similarity between parent and children
        let similarity1: f32 = parent_embedding
            .iter()
            .zip(child1.iter())
            .map(|(a, b)| a * b)
            .sum();
        let similarity2: f32 = parent_embedding
            .iter()
            .zip(child2.iter())
            .map(|(a, b)| a * b)
            .sum();
        let similarity3: f32 = parent_embedding
            .iter()
            .zip(child3.iter())
            .map(|(a, b)| a * b)
            .sum();

        // Parent should have reasonable similarity to all children
        assert!(similarity1 > 0.0 && similarity1 <= 1.0);
        assert!(similarity2 > 0.0 && similarity2 <= 1.0);
        assert!(similarity3 > 0.0 && similarity3 <= 1.0);
    }

    #[cfg(feature = "onnx")] // These tests only compile with the onnx feature
    #[test]
    #[ignore] // Requires model files, run with --ignored
    fn test_onnx_generator_basic() {
        let generator = OnnxEmbeddingGenerator::new().expect("Failed to create ONNX generator");

        let text = "This is a test sentence";
        let embedding = generator.generate(text).unwrap();

        // Check dimension (MiniLM-L6-v2 produces 384-dim embeddings)
        assert_eq!(embedding.len(), 384);

        // Check that it's normalized (L2 norm should be close to 1)
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001, "Norm was {}", norm);
    }

    #[cfg(feature = "onnx")]
    #[test]
    #[ignore] // Requires model files
    fn test_onnx_semantic_similarity() {
        let generator = OnnxEmbeddingGenerator::new().expect("Failed to create ONNX generator");

        // Similar sentences
        let text1 = "I love programming in Rust";
        let text2 = "Rust programming is great";

        // Dissimilar sentence
        let text3 = "The weather is sunny today";

        let emb1 = generator.generate(text1).unwrap();
        let emb2 = generator.generate(text2).unwrap();
        let emb3 = generator.generate(text3).unwrap();

        // Calculate cosine similarities
        let sim_1_2: f32 = emb1.iter().zip(emb2.iter()).map(|(a, b)| a * b).sum();
        let sim_1_3: f32 = emb1.iter().zip(emb3.iter()).map(|(a, b)| a * b).sum();

        // Similar sentences should have higher similarity than dissimilar ones
        assert!(
            sim_1_2 > sim_1_3,
            "Similar sentences should have higher cosine similarity"
        );
        println!("Similarity (Rust/Rust): {:.4}", sim_1_2);
        println!("Similarity (Rust/Weather): {:.4}", sim_1_3);

        // Typically, similar sentences should have similarity > 0.5
        assert!(
            sim_1_2 > 0.5,
            "Similar sentences should have similarity > 0.5"
        );
    }

    #[cfg(feature = "onnx")]
    #[test]
    #[ignore] // Requires model files
    fn test_onnx_vector_ops() {
        let generator = OnnxEmbeddingGenerator::new().expect("Failed to create ONNX generator");

        // Technical question about vectors
        let question = "How do I create a vector in Rust?";
        let answer1 = "Use Vec::new() to create an empty vector";
        let answer2 = "The vec! macro creates a vector with initial values";
        let unrelated = "Python is a popular programming language";

        let q_emb = generator.generate(question).unwrap();
        let a1_emb = generator.generate(answer1).unwrap();
        let a2_emb = generator.generate(answer2).unwrap();
        let un_emb = generator.generate(unrelated).unwrap();

        let sim_q_a1: f32 = q_emb.iter().zip(a1_emb.iter()).map(|(a, b)| a * b).sum();
        let sim_q_a2: f32 = q_emb.iter().zip(a2_emb.iter()).map(|(a, b)| a * b).sum();
        let sim_q_un: f32 = q_emb.iter().zip(un_emb.iter()).map(|(a, b)| a * b).sum();

        println!("Question-Answer1 similarity: {:.4}", sim_q_a1);
        println!("Question-Answer2 similarity: {:.4}", sim_q_a2);
        println!("Question-Unrelated similarity: {:.4}", sim_q_un);

        // Both answers should be more similar to the question than the unrelated text
        assert!(sim_q_a1 > sim_q_un);
        assert!(sim_q_a2 > sim_q_un);
    }
}