oxirs_vec/embeddings/
functions.rs1use crate::Vector;
6use anyhow::Result;
7
8use super::types::{EmbeddableContent, EmbeddingConfig};
9
10pub trait EmbeddingGenerator: Send + Sync + AsAny {
12 fn generate(&self, content: &EmbeddableContent) -> Result<Vector>;
14 fn generate_batch(&self, contents: &[EmbeddableContent]) -> Result<Vec<Vector>> {
16 contents.iter().map(|c| self.generate(c)).collect()
17 }
18 fn dimensions(&self) -> usize;
20 fn config(&self) -> &EmbeddingConfig;
22}
/// Upcasting helper that exposes an implementor as a type-erased `std::any::Any`.
///
/// A trait object cannot be downcast directly. Going through these methods yields a
/// `dyn Any`, on which `downcast_ref` / `downcast_mut` can be called.
pub trait AsAny {
    /// Borrows `self` immutably as `dyn Any`.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Borrows `self` mutably as `dyn Any`.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
}
#[cfg(test)]
mod tests {
    //! Tests for the `SentenceTransformerGenerator` model variants (BERT, RoBERTa,
    //! DistilBERT, multilingual BERT), exercised both through the `EmbeddingGenerator`
    //! trait and through their inherent helper methods.

    use super::*;
    use crate::{SentenceTransformerGenerator, TransformerModelType};

    /// Euclidean (L2) norm of a slice of `f32` components.
    fn l2_norm(values: &[f32]) -> f32 {
        values.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Each constructor selects its model type. BERT and DistilBERT report
    /// 384 dimensions under the default config.
    #[test]
    fn test_transformer_model_types() {
        let config = EmbeddingConfig::default();

        let bert = SentenceTransformerGenerator::new(config.clone());
        assert!(matches!(bert.model_type(), TransformerModelType::BERT));
        assert_eq!(bert.dimensions(), 384);

        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        assert!(matches!(roberta.model_type(), TransformerModelType::RoBERTa));

        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        assert!(matches!(
            distilbert.model_type(),
            TransformerModelType::DistilBERT
        ));
        assert_eq!(distilbert.dimensions(), 384);

        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());
        assert!(matches!(
            multibert.model_type(),
            TransformerModelType::MultiBERT
        ));
    }

    /// Static model metadata (vocabulary, layers, sizes, languages) matches each variant.
    #[test]
    fn test_model_details() {
        let config = EmbeddingConfig::default();

        let bert = SentenceTransformerGenerator::new(config.clone());
        let bert_details = bert.model_details();
        assert_eq!(bert_details.vocab_size, 30522);
        assert_eq!(bert_details.num_layers, 12);
        assert_eq!(bert_details.hidden_size, 768);
        assert!(bert_details.supports_languages.contains(&"en".to_string()));

        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let roberta_details = roberta.model_details();
        assert_eq!(roberta_details.vocab_size, 50265);
        assert_eq!(roberta_details.max_position_embeddings, 514);

        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        let distilbert_details = distilbert.model_details();
        assert_eq!(distilbert_details.num_layers, 6);
        assert_eq!(distilbert_details.hidden_size, 384);
        // DistilBERT must be both smaller and faster than full BERT.
        assert!(distilbert_details.model_size_mb < bert_details.model_size_mb);
        assert!(
            distilbert_details.typical_inference_time_ms < bert_details.typical_inference_time_ms
        );

        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());
        let multibert_details = multibert.model_details();
        assert_eq!(multibert_details.vocab_size, 120000);
        assert!(multibert_details.supports_languages.len() > 10);
        assert!(multibert_details
            .supports_languages
            .contains(&"zh".to_string()));
        assert!(multibert_details
            .supports_languages
            .contains(&"de".to_string()));
    }

    /// English-only models reject other languages. The multilingual model accepts
    /// several languages but not an unknown code.
    #[test]
    fn test_language_support() {
        let config = EmbeddingConfig::default();

        let bert = SentenceTransformerGenerator::new(config.clone());
        assert!(bert.supports_language("en"));
        assert!(!bert.supports_language("zh"));
        assert!(!bert.supports_language("de"));

        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        assert!(distilbert.supports_language("en"));
        assert!(!distilbert.supports_language("zh"));

        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());
        assert!(multibert.supports_language("en"));
        assert!(multibert.supports_language("zh"));
        assert!(multibert.supports_language("de"));
        assert!(multibert.supports_language("fr"));
        assert!(multibert.supports_language("es"));
        assert!(!multibert.supports_language("unknown_lang"));
    }

    /// Efficiency ordering: DistilBERT > BERT > RoBERTa > multilingual BERT.
    #[test]
    fn test_efficiency_ratings() {
        let config = EmbeddingConfig::default();
        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());

        assert!(distilbert.efficiency_rating() > bert.efficiency_rating());
        assert!(distilbert.efficiency_rating() > roberta.efficiency_rating());
        assert!(distilbert.efficiency_rating() > multibert.efficiency_rating());
        assert!(bert.efficiency_rating() > roberta.efficiency_rating());
        assert!(bert.efficiency_rating() > multibert.efficiency_rating());
        assert!(roberta.efficiency_rating() > multibert.efficiency_rating());
    }

    /// Estimated inference time grows with input length, and DistilBERT is faster than
    /// BERT at both lengths.
    #[test]
    fn test_inference_time_estimation() {
        let config = EmbeddingConfig::default();
        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        let bert = SentenceTransformerGenerator::new(config.clone());

        let short_time_distilbert = distilbert.estimate_inference_time(50);
        let short_time_bert = bert.estimate_inference_time(50);
        let long_time_distilbert = distilbert.estimate_inference_time(500);
        let long_time_bert = bert.estimate_inference_time(500);

        assert!(short_time_distilbert < short_time_bert);
        assert!(long_time_distilbert < long_time_bert);
        assert!(long_time_distilbert > short_time_distilbert);
        assert!(long_time_bert > short_time_bert);
    }

    /// Model-specific preprocessing behaves as follows:
    /// - BERT adds `[CLS]`/`[SEP]` and lowercases.
    /// - RoBERTa adds `<s>`/`</s>` and keeps case.
    /// - Multilingual BERT lowercases Latin text and leaves CJK text intact.
    #[test]
    fn test_model_specific_text_preprocessing() -> Result<()> {
        let config = EmbeddingConfig::default();
        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());

        let text = "Hello World";
        let bert_processed = bert.preprocess_text_for_model(text, 512)?;
        assert!(bert_processed.contains("[CLS]"));
        assert!(bert_processed.contains("[SEP]"));
        assert!(bert_processed.contains("hello world"));

        let roberta_processed = roberta.preprocess_text_for_model(text, 512)?;
        assert!(roberta_processed.contains("<s>"));
        assert!(roberta_processed.contains("</s>"));
        assert!(roberta_processed.contains("Hello World"));

        let latin_text = "Hello World";
        // Must be genuine CJK characters ("hello world" in Chinese). A mis-decoded
        // (mojibake) version would be Latin-1 look-alikes and would not exercise the
        // non-Latin path at all.
        let chinese_text = "你好世界";
        let latin_processed = multibert.preprocess_text_for_model(latin_text, 512)?;
        let chinese_processed = multibert.preprocess_text_for_model(chinese_text, 512)?;
        assert!(latin_processed.contains("hello world"));
        assert!(chinese_processed.contains("你好世界"));
        Ok(())
    }

    /// Different model variants produce different 384-dimensional embeddings for the
    /// same input. When the config requests normalization, each has roughly unit L2 norm.
    #[test]
    fn test_embedding_generation_differences() -> Result<()> {
        let config = EmbeddingConfig::default();
        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());

        let content = EmbeddableContent::Text("This is a test sentence".to_string());
        let bert_embedding = bert.generate(&content)?;
        let roberta_embedding = roberta.generate(&content)?;
        let distilbert_embedding = distilbert.generate(&content)?;

        assert_ne!(bert_embedding.as_f32(), roberta_embedding.as_f32());
        assert_ne!(bert_embedding.as_f32(), distilbert_embedding.as_f32());
        assert_ne!(roberta_embedding.as_f32(), distilbert_embedding.as_f32());
        assert_eq!(distilbert_embedding.dimensions, 384);
        assert_eq!(bert_embedding.dimensions, 384);
        assert_eq!(roberta_embedding.dimensions, 384);

        if config.normalize {
            let bert_magnitude = l2_norm(&bert_embedding.as_f32());
            let roberta_magnitude = l2_norm(&roberta_embedding.as_f32());
            let distilbert_magnitude = l2_norm(&distilbert_embedding.as_f32());
            assert!((bert_magnitude - 1.0).abs() < 0.1);
            assert!((roberta_magnitude - 1.0).abs() < 0.1);
            assert!((distilbert_magnitude - 1.0).abs() < 0.1);
        }
        Ok(())
    }

    /// Simulated tokenizers emit ids inside each model's vocabulary. For this word,
    /// RoBERTa's BPE yields at least as many tokens as BERT's WordPiece, and
    /// multilingual BERT yields at most as many.
    #[test]
    fn test_tokenization_differences() {
        let config = EmbeddingConfig::default();
        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());

        let model_details_bert = bert.get_model_details();
        let model_details_roberta = roberta.get_model_details();
        let model_details_multibert = multibert.get_model_details();

        let complex_word = "preprocessing";
        let bert_tokens =
            bert.simulate_wordpiece_tokenization(complex_word, model_details_bert.vocab_size);
        let roberta_tokens =
            roberta.simulate_bpe_tokenization(complex_word, model_details_roberta.vocab_size);
        let multibert_tokens = multibert
            .simulate_multilingual_tokenization(complex_word, model_details_multibert.vocab_size);

        assert!(roberta_tokens.len() >= bert_tokens.len());
        assert!(multibert_tokens.len() <= bert_tokens.len());

        for token in &bert_tokens {
            assert!(*token < model_details_bert.vocab_size as u32);
        }
        for token in &roberta_tokens {
            assert!(*token < model_details_roberta.vocab_size as u32);
        }
        for token in &multibert_tokens {
            assert!(*token < model_details_multibert.vocab_size as u32);
        }
    }

    /// Size ordering: DistilBERT is smallest, then BERT, then RoBERTa, and multilingual
    /// BERT is largest.
    #[test]
    fn test_model_size_comparisons() {
        let config = EmbeddingConfig::default();
        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());

        let bert_size = bert.model_size_mb();
        let roberta_size = roberta.model_size_mb();
        let distilbert_size = distilbert.model_size_mb();
        let multibert_size = multibert.model_size_mb();

        assert!(distilbert_size < bert_size);
        assert!(distilbert_size < roberta_size);
        assert!(distilbert_size < multibert_size);
        assert!(multibert_size > bert_size);
        assert!(multibert_size > roberta_size);
        assert!(multibert_size > distilbert_size);
        assert!(roberta_size > bert_size);
    }
}