Skip to main content

oxirs_vec/embeddings/
functions.rs

//! Auto-generated module
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use crate::Vector;
6use anyhow::Result;
7
8use super::types::{EmbeddableContent, EmbeddingConfig};
9
10/// Embedding generator trait
11pub trait EmbeddingGenerator: Send + Sync + AsAny {
12    /// Generate embedding for content
13    fn generate(&self, content: &EmbeddableContent) -> Result<Vector>;
14    /// Generate embeddings for multiple contents in batch
15    fn generate_batch(&self, contents: &[EmbeddableContent]) -> Result<Vec<Vector>> {
16        contents.iter().map(|c| self.generate(c)).collect()
17    }
18    /// Get the embedding dimensions
19    fn dimensions(&self) -> usize;
20    /// Get the model configuration
21    fn config(&self) -> &EmbeddingConfig;
22}
/// Downcasting support for trait objects.
///
/// Exposes a value as [`std::any::Any`], so a caller holding only a trait
/// object can recover the concrete type with `downcast_ref` / `downcast_mut`.
/// Implementations are expected to simply return `self`.
pub trait AsAny {
    /// Returns a shared `Any` view of this value.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Returns a mutable `Any` view of this value.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{SentenceTransformerGenerator, TransformerModelType};

    /// Euclidean (L2) norm of a slice of embedding components.
    fn l2_norm(values: &[f32]) -> f32 {
        values.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Each constructor yields the expected `TransformerModelType`; BERT and
    /// DistilBERT report 384 dimensions under the default config.
    #[test]
    fn test_transformer_model_types() {
        let config = EmbeddingConfig::default();
        let bert = SentenceTransformerGenerator::new(config.clone());
        assert!(matches!(bert.model_type(), TransformerModelType::BERT));
        assert_eq!(bert.dimensions(), 384);
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        assert!(matches!(
            roberta.model_type(),
            TransformerModelType::RoBERTa
        ));
        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        assert!(matches!(
            distilbert.model_type(),
            TransformerModelType::DistilBERT
        ));
        assert_eq!(distilbert.dimensions(), 384);
        let multibert = SentenceTransformerGenerator::multilingual_bert(config);
        assert!(matches!(
            multibert.model_type(),
            TransformerModelType::MultiBERT
        ));
    }

    /// Static model metadata (vocab size, layers, hidden size, languages, size
    /// and speed) matches the expected values for each model family.
    #[test]
    fn test_model_details() {
        let config = EmbeddingConfig::default();
        let bert = SentenceTransformerGenerator::new(config.clone());
        let bert_details = bert.model_details();
        assert_eq!(bert_details.vocab_size, 30522);
        assert_eq!(bert_details.num_layers, 12);
        assert_eq!(bert_details.hidden_size, 768);
        assert!(bert_details.supports_languages.contains(&"en".to_string()));

        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let roberta_details = roberta.model_details();
        assert_eq!(roberta_details.vocab_size, 50265);
        assert_eq!(roberta_details.max_position_embeddings, 514);

        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        let distilbert_details = distilbert.model_details();
        assert_eq!(distilbert_details.num_layers, 6);
        assert_eq!(distilbert_details.hidden_size, 384);
        assert!(distilbert_details.model_size_mb < bert_details.model_size_mb);
        assert!(
            distilbert_details.typical_inference_time_ms < bert_details.typical_inference_time_ms
        );

        let multibert = SentenceTransformerGenerator::multilingual_bert(config);
        let multibert_details = multibert.model_details();
        assert_eq!(multibert_details.vocab_size, 120000);
        assert!(multibert_details.supports_languages.len() > 10);
        assert!(multibert_details
            .supports_languages
            .contains(&"zh".to_string()));
        assert!(multibert_details
            .supports_languages
            .contains(&"de".to_string()));
    }

    /// English-only models reject other languages; the multilingual model
    /// accepts common languages but still rejects unknown codes.
    #[test]
    fn test_language_support() {
        let config = EmbeddingConfig::default();
        let bert = SentenceTransformerGenerator::new(config.clone());
        assert!(bert.supports_language("en"));
        assert!(!bert.supports_language("zh"));
        assert!(!bert.supports_language("de"));

        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        assert!(distilbert.supports_language("en"));
        assert!(!distilbert.supports_language("zh"));

        let multibert = SentenceTransformerGenerator::multilingual_bert(config);
        for lang in ["en", "zh", "de", "fr", "es"] {
            assert!(
                multibert.supports_language(lang),
                "multilingual BERT should support {:?}",
                lang
            );
        }
        assert!(!multibert.supports_language("unknown_lang"));
    }

    /// Efficiency ratings are strictly ordered:
    /// DistilBERT > BERT > RoBERTa > multilingual BERT.
    #[test]
    fn test_efficiency_ratings() {
        let config = EmbeddingConfig::default();
        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        let multibert = SentenceTransformerGenerator::multilingual_bert(config);
        assert!(distilbert.efficiency_rating() > bert.efficiency_rating());
        assert!(distilbert.efficiency_rating() > roberta.efficiency_rating());
        assert!(distilbert.efficiency_rating() > multibert.efficiency_rating());
        assert!(bert.efficiency_rating() > roberta.efficiency_rating());
        assert!(bert.efficiency_rating() > multibert.efficiency_rating());
        assert!(roberta.efficiency_rating() > multibert.efficiency_rating());
    }

    /// Estimated inference time grows with input length and is lower for
    /// DistilBERT than BERT at every tested length.
    #[test]
    fn test_inference_time_estimation() {
        let config = EmbeddingConfig::default();
        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        let bert = SentenceTransformerGenerator::new(config);
        let short_time_distilbert = distilbert.estimate_inference_time(50);
        let short_time_bert = bert.estimate_inference_time(50);
        let long_time_distilbert = distilbert.estimate_inference_time(500);
        let long_time_bert = bert.estimate_inference_time(500);
        assert!(short_time_distilbert < short_time_bert);
        assert!(long_time_distilbert < long_time_bert);
        assert!(long_time_distilbert > short_time_distilbert);
        assert!(long_time_bert > short_time_bert);
    }

    /// Each model family wraps text in its own special tokens. BERT and the
    /// multilingual model lowercase Latin text; RoBERTa preserves case; CJK
    /// text passes through unchanged.
    #[test]
    fn test_model_specific_text_preprocessing() {
        let config = EmbeddingConfig::default();
        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let multibert = SentenceTransformerGenerator::multilingual_bert(config);
        let text = "Hello World";

        let bert_processed = bert.preprocess_text_for_model(text, 512).unwrap();
        assert!(bert_processed.contains("[CLS]"));
        assert!(bert_processed.contains("[SEP]"));
        assert!(bert_processed.contains("hello world"));

        let roberta_processed = roberta.preprocess_text_for_model(text, 512).unwrap();
        assert!(roberta_processed.contains("<s>"));
        assert!(roberta_processed.contains("</s>"));
        assert!(roberta_processed.contains("Hello World"));

        let latin_text = "Hello World";
        let chinese_text = "你好世界";
        let latin_processed = multibert
            .preprocess_text_for_model(latin_text, 512)
            .unwrap();
        let chinese_processed = multibert
            .preprocess_text_for_model(chinese_text, 512)
            .unwrap();
        assert!(latin_processed.contains("hello world"));
        assert!(chinese_processed.contains("你好世界"));
    }

    /// Different model families produce distinct 384-dimensional embeddings for
    /// the same input. When normalization is enabled, each embedding is close
    /// to unit length.
    #[test]
    fn test_embedding_generation_differences() {
        let config = EmbeddingConfig::default();
        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        let content = EmbeddableContent::Text("This is a test sentence".to_string());

        let bert_embedding = bert.generate(&content).unwrap();
        let roberta_embedding = roberta.generate(&content).unwrap();
        let distilbert_embedding = distilbert.generate(&content).unwrap();

        assert_ne!(bert_embedding.as_f32(), roberta_embedding.as_f32());
        assert_ne!(bert_embedding.as_f32(), distilbert_embedding.as_f32());
        assert_ne!(roberta_embedding.as_f32(), distilbert_embedding.as_f32());
        assert_eq!(distilbert_embedding.dimensions, 384);
        assert_eq!(bert_embedding.dimensions, 384);
        assert_eq!(roberta_embedding.dimensions, 384);

        if config.normalize {
            for (name, embedding) in [
                ("bert", &bert_embedding),
                ("roberta", &roberta_embedding),
                ("distilbert", &distilbert_embedding),
            ] {
                let magnitude = l2_norm(&embedding.as_f32());
                assert!(
                    (magnitude - 1.0).abs() < 0.1,
                    "{} embedding magnitude {} is not ~1.0",
                    name,
                    magnitude
                );
            }
        }
    }

    /// Simulated tokenizers emit in-vocabulary ids. BPE produces at least as
    /// many tokens as WordPiece; the multilingual tokenizer produces no more.
    #[test]
    fn test_tokenization_differences() {
        let config = EmbeddingConfig::default();
        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let multibert = SentenceTransformerGenerator::multilingual_bert(config);
        let model_details_bert = bert.get_model_details();
        let model_details_roberta = roberta.get_model_details();
        let model_details_multibert = multibert.get_model_details();
        let complex_word = "preprocessing";

        let bert_tokens =
            bert.simulate_wordpiece_tokenization(complex_word, model_details_bert.vocab_size);
        let roberta_tokens =
            roberta.simulate_bpe_tokenization(complex_word, model_details_roberta.vocab_size);
        let multibert_tokens = multibert
            .simulate_multilingual_tokenization(complex_word, model_details_multibert.vocab_size);

        assert!(roberta_tokens.len() >= bert_tokens.len());
        assert!(multibert_tokens.len() <= bert_tokens.len());
        for token in &bert_tokens {
            assert!(*token < model_details_bert.vocab_size as u32);
        }
        for token in &roberta_tokens {
            assert!(*token < model_details_roberta.vocab_size as u32);
        }
        for token in &multibert_tokens {
            assert!(*token < model_details_multibert.vocab_size as u32);
        }
    }

    /// Model sizes are ordered:
    /// DistilBERT < BERT < RoBERTa < multilingual BERT.
    #[test]
    fn test_model_size_comparisons() {
        let config = EmbeddingConfig::default();
        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        let multibert = SentenceTransformerGenerator::multilingual_bert(config);
        let bert_size = bert.model_size_mb();
        let roberta_size = roberta.model_size_mb();
        let distilbert_size = distilbert.model_size_mb();
        let multibert_size = multibert.model_size_mb();
        assert!(distilbert_size < bert_size);
        assert!(distilbert_size < roberta_size);
        assert!(distilbert_size < multibert_size);
        assert!(multibert_size > bert_size);
        assert!(multibert_size > roberta_size);
        assert!(roberta_size > bert_size);
    }
}