oxirs_vec/embeddings/
functions.rs1use crate::Vector;
6use anyhow::Result;
7
8use super::types::{EmbeddableContent, EmbeddingConfig};
9
10pub trait EmbeddingGenerator: Send + Sync + AsAny {
12 fn generate(&self, content: &EmbeddableContent) -> Result<Vector>;
14 fn generate_batch(&self, contents: &[EmbeddableContent]) -> Result<Vec<Vector>> {
16 contents.iter().map(|c| self.generate(c)).collect()
17 }
18 fn dimensions(&self) -> usize;
20 fn config(&self) -> &EmbeddingConfig;
22}
/// Type-erasure hook that exposes the implementor as `dyn std::any::Any`.
///
/// Code holding only a trait object (for example `dyn EmbeddingGenerator`,
/// which lists this trait as a supertrait) can use the returned `Any` with
/// `downcast_ref` / `downcast_mut` to reach the concrete type.
pub trait AsAny {
    /// Mutably borrows `self` as `dyn Any`.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;

    /// Borrows `self` as `dyn Any`.
    fn as_any(&self) -> &dyn std::any::Any;
}
#[cfg(test)]
mod tests {
    //! Tests for the `SentenceTransformerGenerator` model variants: model
    //! metadata, language support, relative cost, preprocessing, tokenisation
    //! and generated embeddings.

    use super::*;
    use crate::{SentenceTransformerGenerator, TransformerModelType};

    /// Euclidean (L2) norm of `values`; used to check normalised embeddings.
    fn l2_norm(values: &[f32]) -> f32 {
        values.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Each constructor selects the expected model family.
    #[test]
    fn test_transformer_model_types() {
        let config = EmbeddingConfig::default();

        let bert = SentenceTransformerGenerator::new(config.clone());
        assert!(matches!(bert.model_type(), TransformerModelType::BERT));
        assert_eq!(bert.dimensions(), 384);

        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        assert!(matches!(
            roberta.model_type(),
            TransformerModelType::RoBERTa
        ));

        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        assert!(matches!(
            distilbert.model_type(),
            TransformerModelType::DistilBERT
        ));
        assert_eq!(distilbert.dimensions(), 384);

        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());
        assert!(matches!(
            multibert.model_type(),
            TransformerModelType::MultiBERT
        ));
    }

    /// Model metadata (vocabulary, depth, width, size, languages) per variant.
    #[test]
    fn test_model_details() {
        let config = EmbeddingConfig::default();

        let bert = SentenceTransformerGenerator::new(config.clone());
        let bert_details = bert.model_details();
        assert_eq!(bert_details.vocab_size, 30522);
        assert_eq!(bert_details.num_layers, 12);
        assert_eq!(bert_details.hidden_size, 768);
        assert!(bert_details.supports_languages.contains(&"en".to_string()));

        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let roberta_details = roberta.model_details();
        assert_eq!(roberta_details.vocab_size, 50265);
        assert_eq!(roberta_details.max_position_embeddings, 514);

        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        let distilbert_details = distilbert.model_details();
        assert_eq!(distilbert_details.num_layers, 6);
        assert_eq!(distilbert_details.hidden_size, 384);
        assert!(distilbert_details.model_size_mb < bert_details.model_size_mb);
        assert!(
            distilbert_details.typical_inference_time_ms < bert_details.typical_inference_time_ms
        );

        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());
        let multibert_details = multibert.model_details();
        assert_eq!(multibert_details.vocab_size, 120000);
        assert!(multibert_details.supports_languages.len() > 10);
        assert!(multibert_details
            .supports_languages
            .contains(&"zh".to_string()));
        assert!(multibert_details
            .supports_languages
            .contains(&"de".to_string()));
    }

    /// Only the multilingual variant accepts non-English language codes.
    #[test]
    fn test_language_support() {
        let config = EmbeddingConfig::default();

        let bert = SentenceTransformerGenerator::new(config.clone());
        assert!(bert.supports_language("en"));
        assert!(!bert.supports_language("zh"));
        assert!(!bert.supports_language("de"));

        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        assert!(distilbert.supports_language("en"));
        assert!(!distilbert.supports_language("zh"));

        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());
        assert!(multibert.supports_language("en"));
        assert!(multibert.supports_language("zh"));
        assert!(multibert.supports_language("de"));
        assert!(multibert.supports_language("fr"));
        assert!(multibert.supports_language("es"));
        assert!(!multibert.supports_language("unknown_lang"));
    }

    /// Efficiency ordering: DistilBERT > BERT > RoBERTa > multilingual BERT.
    #[test]
    fn test_efficiency_ratings() {
        let config = EmbeddingConfig::default();
        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());

        assert!(distilbert.efficiency_rating() > bert.efficiency_rating());
        assert!(distilbert.efficiency_rating() > roberta.efficiency_rating());
        assert!(distilbert.efficiency_rating() > multibert.efficiency_rating());
        assert!(bert.efficiency_rating() > roberta.efficiency_rating());
        assert!(bert.efficiency_rating() > multibert.efficiency_rating());
        assert!(roberta.efficiency_rating() > multibert.efficiency_rating());
    }

    /// Estimated inference time grows with input length and is lower for DistilBERT.
    #[test]
    fn test_inference_time_estimation() {
        let config = EmbeddingConfig::default();
        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        let bert = SentenceTransformerGenerator::new(config.clone());

        let short_time_distilbert = distilbert.estimate_inference_time(50);
        let short_time_bert = bert.estimate_inference_time(50);
        let long_time_distilbert = distilbert.estimate_inference_time(500);
        let long_time_bert = bert.estimate_inference_time(500);

        assert!(short_time_distilbert < short_time_bert);
        assert!(long_time_distilbert < long_time_bert);
        assert!(long_time_distilbert > short_time_distilbert);
        assert!(long_time_bert > short_time_bert);
    }

    /// Each model family wraps text in its own special tokens.
    #[test]
    fn test_model_specific_text_preprocessing() {
        let config = EmbeddingConfig::default();
        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());
        let text = "Hello World";

        let bert_processed = bert.preprocess_text_for_model(text, 512).unwrap();
        assert!(bert_processed.contains("[CLS]"));
        assert!(bert_processed.contains("[SEP]"));
        assert!(bert_processed.contains("hello world"));

        let roberta_processed = roberta.preprocess_text_for_model(text, 512).unwrap();
        assert!(roberta_processed.contains("<s>"));
        assert!(roberta_processed.contains("</s>"));
        assert!(roberta_processed.contains("Hello World"));

        let latin_text = "Hello World";
        // Real CJK input. This literal was previously mis-encoded (UTF-8 bytes
        // read back as Windows-1252), which turned it into Latin-script text:
        // the CJK path was never exercised, and the uppercase letter in the
        // garbled form could be lowercased like "Hello World" is below.
        let chinese_text = "你好世界";
        let latin_processed = multibert
            .preprocess_text_for_model(latin_text, 512)
            .unwrap();
        let chinese_processed = multibert
            .preprocess_text_for_model(chinese_text, 512)
            .unwrap();
        assert!(latin_processed.contains("hello world"));
        assert!(chinese_processed.contains(chinese_text));
    }

    /// Different model families yield different embeddings of equal dimensionality.
    #[test]
    fn test_embedding_generation_differences() {
        let config = EmbeddingConfig::default();
        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        let content = EmbeddableContent::Text("This is a test sentence".to_string());

        let bert_embedding = bert.generate(&content).unwrap();
        let roberta_embedding = roberta.generate(&content).unwrap();
        let distilbert_embedding = distilbert.generate(&content).unwrap();

        assert_ne!(bert_embedding.as_f32(), roberta_embedding.as_f32());
        assert_ne!(bert_embedding.as_f32(), distilbert_embedding.as_f32());
        assert_ne!(roberta_embedding.as_f32(), distilbert_embedding.as_f32());
        assert_eq!(distilbert_embedding.dimensions, 384);
        assert_eq!(bert_embedding.dimensions, 384);
        assert_eq!(roberta_embedding.dimensions, 384);

        // With normalisation enabled, each embedding should have (roughly) unit length.
        if config.normalize {
            for embedding in [&bert_embedding, &roberta_embedding, &distilbert_embedding] {
                let magnitude = l2_norm(&embedding.as_f32());
                assert!(
                    (magnitude - 1.0).abs() < 0.1,
                    "embedding magnitude {} is not ~1.0",
                    magnitude
                );
            }
        }
    }

    /// Simulated tokenizers differ in granularity and stay within each vocabulary.
    #[test]
    fn test_tokenization_differences() {
        let config = EmbeddingConfig::default();
        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());

        let model_details_bert = bert.get_model_details();
        let model_details_roberta = roberta.get_model_details();
        let model_details_multibert = multibert.get_model_details();

        let complex_word = "preprocessing";
        let bert_tokens =
            bert.simulate_wordpiece_tokenization(complex_word, model_details_bert.vocab_size);
        let roberta_tokens =
            roberta.simulate_bpe_tokenization(complex_word, model_details_roberta.vocab_size);
        let multibert_tokens = multibert
            .simulate_multilingual_tokenization(complex_word, model_details_multibert.vocab_size);

        assert!(roberta_tokens.len() >= bert_tokens.len());
        assert!(multibert_tokens.len() <= bert_tokens.len());

        // Every token id must be a valid index into its model's vocabulary.
        for token in &bert_tokens {
            assert!(*token < model_details_bert.vocab_size as u32);
        }
        for token in &roberta_tokens {
            assert!(*token < model_details_roberta.vocab_size as u32);
        }
        for token in &multibert_tokens {
            assert!(*token < model_details_multibert.vocab_size as u32);
        }
    }

    /// Size ordering: DistilBERT < BERT < RoBERTa < multilingual BERT.
    #[test]
    fn test_model_size_comparisons() {
        let config = EmbeddingConfig::default();
        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());

        let bert_size = bert.model_size_mb();
        let roberta_size = roberta.model_size_mb();
        let distilbert_size = distilbert.model_size_mb();
        let multibert_size = multibert.model_size_mb();

        assert!(distilbert_size < bert_size);
        assert!(distilbert_size < roberta_size);
        assert!(distilbert_size < multibert_size);
        assert!(multibert_size > bert_size);
        assert!(multibert_size > roberta_size);
        assert!(roberta_size > bert_size);
    }
}