use crate::Vector;
use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::time::Duration;

/// Configuration for embedding generation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    pub model_name: String,
    pub dimensions: usize,
    pub max_sequence_length: usize,
    pub normalize: bool,
}

impl Default for EmbeddingConfig {
    fn default() -> Self {
        Self {
            model_name: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
            dimensions: 384,
            max_sequence_length: 512,
            normalize: true,
        }
    }
}
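
// Usage sketch (illustrative, not original API docs): individual fields can be
// overridden while keeping the remaining defaults via struct-update syntax.
//
//     let config = EmbeddingConfig {
//         dimensions: 768,
//         ..Default::default()
//     };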

/// Content types that can be embedded.
#[derive(Debug, Clone)]
pub enum EmbeddableContent {
    /// Plain text.
    Text(String),
    /// An RDF resource together with its metadata.
    RdfResource {
        uri: String,
        label: Option<String>,
        description: Option<String>,
        properties: HashMap<String, Vec<String>>,
    },
    /// A SPARQL query string.
    SparqlQuery(String),
    /// A graph pattern string.
    GraphPattern(String),
}

impl EmbeddableContent {
    /// Flatten the content into a single text string for embedding.
    pub fn to_text(&self) -> String {
        match self {
            EmbeddableContent::Text(text) => text.clone(),
            EmbeddableContent::RdfResource {
                uri,
                label,
                description,
                properties,
            } => {
                let mut text_parts = vec![uri.clone()];

                if let Some(label) = label {
                    text_parts.push(format!("label: {label}"));
                }

                if let Some(desc) = description {
                    text_parts.push(format!("description: {desc}"));
                }

                for (prop, values) in properties {
                    text_parts.push(format!("{prop}: {}", values.join(", ")));
                }

                text_parts.join(" ")
            }
            EmbeddableContent::SparqlQuery(query) => query.clone(),
            EmbeddableContent::GraphPattern(pattern) => pattern.clone(),
        }
    }

    /// Hash of the text representation, used as a cache key.
    pub fn content_hash(&self) -> u64 {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        self.to_text().hash(&mut hasher);
        hasher.finish()
    }
}
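
// Usage sketch for `to_text`: an RDF resource is flattened into a single
// string of "uri label: ... description: ... prop: values" parts.
//
//     let mut properties = HashMap::new();
//     properties.insert("type".to_string(), vec!["Person".to_string()]);
//     let resource = EmbeddableContent::RdfResource {
//         uri: "http://example.org/alice".to_string(),
//         label: Some("Alice".to_string()),
//         description: None,
//         properties,
//     };
//     // Yields "http://example.org/alice label: Alice type: Person".
//     let text = resource.to_text();
//
// Note that property order follows HashMap iteration order, so `content_hash`
// is not guaranteed stable across different maps with the same entries.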

/// Strategy for generating embeddings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EmbeddingStrategy {
    /// TF-IDF bag-of-words embeddings.
    TfIdf,
    /// Sentence-transformer embeddings with the default model.
    SentenceTransformer,
    /// Transformer embeddings with an explicit model type.
    Transformer(TransformerModelType),
    /// Word2Vec embeddings.
    Word2Vec(crate::word2vec::Word2VecConfig),
    /// OpenAI API embeddings.
    OpenAI(OpenAIConfig),
    /// Custom model, identified by a path or name.
    Custom(String),
}

/// Configuration for the OpenAI embeddings API.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIConfig {
    /// API key (read from the OPENAI_API_KEY environment variable by default).
    pub api_key: String,
    /// Embedding model name.
    pub model: String,
    /// Base URL of the API.
    pub base_url: String,
    /// Request timeout in seconds.
    pub timeout_seconds: u64,
    /// Maximum number of requests per minute.
    pub requests_per_minute: u32,
    /// Maximum number of texts per batch request.
    pub batch_size: usize,
    /// Whether to cache embeddings locally.
    pub enable_cache: bool,
    /// Maximum number of cached embeddings.
    pub cache_size: usize,
    /// Cache time-to-live in seconds.
    pub cache_ttl_seconds: u64,
    /// Maximum number of retry attempts.
    pub max_retries: u32,
    /// Base retry delay in milliseconds.
    pub retry_delay_ms: u64,
    /// Strategy used to space out retries.
    pub retry_strategy: RetryStrategy,
    /// Whether to track estimated API costs.
    pub track_costs: bool,
    /// Whether to collect usage metrics.
    pub enable_metrics: bool,
    /// User-Agent header for HTTP requests.
    pub user_agent: String,
}

/// How to space out retries after failed API requests.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum RetryStrategy {
    /// The same delay for every attempt.
    Fixed,
    /// The delay doubles with each attempt.
    ExponentialBackoff,
    /// The delay grows linearly with the attempt number.
    LinearBackoff,
}
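
// Worked example: with a base delay of 1000 ms, successive attempts wait
// roughly (before jitter):
//   Fixed:              1000, 1000, 1000 ms
//   LinearBackoff:      1000, 2000, 3000 ms
//   ExponentialBackoff: 1000, 2000, 4000 ms  (base * 2^(attempt - 1))
// See `calculate_retry_delay` below for the exact computation.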

impl Default for OpenAIConfig {
    fn default() -> Self {
        Self {
            api_key: std::env::var("OPENAI_API_KEY").unwrap_or_default(),
            model: "text-embedding-3-small".to_string(),
            base_url: "https://api.openai.com/v1".to_string(),
            timeout_seconds: 30,
            requests_per_minute: 3000,
            batch_size: 100,
            enable_cache: true,
            cache_size: 10000,
            cache_ttl_seconds: 3600,
            max_retries: 3,
            retry_delay_ms: 1000,
            retry_strategy: RetryStrategy::ExponentialBackoff,
            track_costs: true,
            enable_metrics: true,
            user_agent: "oxirs-vec/0.1.0".to_string(),
        }
    }
}

impl OpenAIConfig {
    /// Preset for production use: conservative rate limit, large long-lived
    /// cache, and more retries.
    pub fn production() -> Self {
        Self {
            requests_per_minute: 1000,
            cache_size: 50000,
            cache_ttl_seconds: 7200,
            max_retries: 5,
            retry_strategy: RetryStrategy::ExponentialBackoff,
            ..Default::default()
        }
    }

    /// Preset for development: low rate limit and a small, short-lived cache.
    pub fn development() -> Self {
        Self {
            requests_per_minute: 100,
            cache_size: 1000,
            cache_ttl_seconds: 300,
            max_retries: 2,
            ..Default::default()
        }
    }

    /// Validate the configuration, rejecting empty or zero-valued settings.
    pub fn validate(&self) -> Result<()> {
        if self.api_key.is_empty() {
            return Err(anyhow!("OpenAI API key is required"));
        }
        if self.requests_per_minute == 0 {
            return Err(anyhow!("requests_per_minute must be greater than 0"));
        }
        if self.batch_size == 0 {
            return Err(anyhow!("batch_size must be greater than 0"));
        }
        if self.timeout_seconds == 0 {
            return Err(anyhow!("timeout_seconds must be greater than 0"));
        }
        Ok(())
    }
}
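
// Usage sketch (assumes OPENAI_API_KEY is set in the environment):
//
//     let config = OpenAIConfig::production();
//     config.validate()?;
//     let generator = OpenAIEmbeddingGenerator::new(config)?;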

/// Trait implemented by all embedding generators.
pub trait EmbeddingGenerator: Send + Sync + AsAny {
    /// Generate an embedding for a single piece of content.
    fn generate(&self, content: &EmbeddableContent) -> Result<Vector>;

    /// Generate embeddings for a batch of contents; the default implementation
    /// maps `generate` over the slice.
    fn generate_batch(&self, contents: &[EmbeddableContent]) -> Result<Vec<Vector>> {
        contents.iter().map(|c| self.generate(c)).collect()
    }

    /// Dimensionality of the generated embeddings.
    fn dimensions(&self) -> usize;

    /// The embedding configuration in use.
    fn config(&self) -> &EmbeddingConfig;
}

/// TF-IDF based embedding generator.
pub struct TfIdfEmbeddingGenerator {
    config: EmbeddingConfig,
    vocabulary: HashMap<String, usize>,
    idf_scores: HashMap<String, f32>,
}

impl TfIdfEmbeddingGenerator {
    pub fn new(config: EmbeddingConfig) -> Self {
        Self {
            config,
            vocabulary: HashMap::new(),
            idf_scores: HashMap::new(),
        }
    }

    /// Build the vocabulary and IDF scores from a corpus of documents.
    pub fn build_vocabulary(&mut self, documents: &[String]) -> Result<()> {
        let mut word_counts: HashMap<String, usize> = HashMap::new();
        let mut doc_counts: HashMap<String, usize> = HashMap::new();

        for doc in documents {
            let words: Vec<String> = self.tokenize(doc);
            let unique_words: std::collections::HashSet<_> = words.iter().collect();

            for word in &words {
                *word_counts.entry(word.clone()).or_insert(0) += 1;
            }

            for word in unique_words {
                *doc_counts.entry(word.clone()).or_insert(0) += 1;
            }
        }

        // Keep the most frequent words, up to the configured dimensionality.
        let mut word_freq: Vec<(String, usize)> = word_counts.into_iter().collect();
        word_freq.sort_by(|a, b| b.1.cmp(&a.1));

        self.vocabulary = word_freq
            .into_iter()
            .take(self.config.dimensions)
            .enumerate()
            .map(|(idx, (word, _))| (word, idx))
            .collect();

        // Compute smoothed IDF scores for the retained vocabulary.
        let total_docs = documents.len() as f32;
        for word in self.vocabulary.keys() {
            let doc_freq = doc_counts.get(word).unwrap_or(&0);
            let idf = (total_docs / (*doc_freq as f32 + 1.0)).ln();
            self.idf_scores.insert(word.clone(), idf);
        }

        Ok(())
    }

    fn tokenize(&self, text: &str) -> Vec<String> {
        text.to_lowercase()
            .split_whitespace()
            .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
            .filter(|s| !s.is_empty())
            .map(String::from)
            .collect()
    }

    fn calculate_tf_idf(&self, text: &str) -> Vector {
        let words = self.tokenize(text);
        let mut tf_counts: HashMap<String, usize> = HashMap::new();

        for word in &words {
            *tf_counts.entry(word.clone()).or_insert(0) += 1;
        }

        let total_words = words.len() as f32;
        let mut embedding = vec![0.0; self.config.dimensions];

        for (word, count) in tf_counts {
            if let Some(&idx) = self.vocabulary.get(&word) {
                let tf = count as f32 / total_words;
                let idf = self.idf_scores.get(&word).unwrap_or(&0.0);
                embedding[idx] = tf * idf;
            }
        }

        if self.config.normalize {
            self.normalize_vector(&mut embedding);
        }

        Vector::new(embedding)
    }

    fn normalize_vector(&self, vector: &mut [f32]) {
        let magnitude: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        if magnitude > 0.0 {
            for value in vector {
                *value /= magnitude;
            }
        }
    }
}

impl EmbeddingGenerator for TfIdfEmbeddingGenerator {
    fn generate(&self, content: &EmbeddableContent) -> Result<Vector> {
        if self.vocabulary.is_empty() {
            return Err(anyhow!(
                "Vocabulary not built. Call build_vocabulary first."
            ));
        }

        let text = content.to_text();
        Ok(self.calculate_tf_idf(&text))
    }

    fn dimensions(&self) -> usize {
        self.config.dimensions
    }

    fn config(&self) -> &EmbeddingConfig {
        &self.config
    }
}
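
// Usage sketch: TF-IDF requires a vocabulary pass before `generate` works.
//
//     let mut generator = TfIdfEmbeddingGenerator::new(EmbeddingConfig::default());
//     generator.build_vocabulary(&[
//         "the quick brown fox".to_string(),
//         "the lazy dog".to_string(),
//     ])?;
//     let v = generator.generate(&EmbeddableContent::Text("quick dog".to_string()))?;
//
// Calling `generate` before `build_vocabulary` returns an error.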

/// Sentence-transformer style embedding generator with simulated inference.
pub struct SentenceTransformerGenerator {
    config: EmbeddingConfig,
    model_type: TransformerModelType,
}

/// Supported transformer architectures.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum TransformerModelType {
    /// BERT (the default).
    #[default]
    BERT,
    /// RoBERTa.
    RoBERTa,
    /// DistilBERT: a smaller, faster distilled BERT.
    DistilBERT,
    /// Multilingual BERT.
    MultiBERT,
    /// Custom model, identified by a path or name.
    Custom(String),
}

/// Static details describing a transformer model.
#[derive(Debug, Clone)]
pub struct ModelDetails {
    pub vocab_size: usize,
    pub num_layers: usize,
    pub num_attention_heads: usize,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub max_position_embeddings: usize,
    pub supports_languages: Vec<String>,
    pub model_size_mb: usize,
    pub typical_inference_time_ms: u64,
}

impl SentenceTransformerGenerator {
    pub fn new(config: EmbeddingConfig) -> Self {
        Self {
            config,
            model_type: TransformerModelType::default(),
        }
    }

    pub fn with_model_type(config: EmbeddingConfig, model_type: TransformerModelType) -> Self {
        Self { config, model_type }
    }

    /// Create a RoBERTa-based generator.
    pub fn roberta(config: EmbeddingConfig) -> Self {
        Self::with_model_type(config, TransformerModelType::RoBERTa)
    }

    /// Create a DistilBERT-based generator (384-dimensional embeddings).
    pub fn distilbert(config: EmbeddingConfig) -> Self {
        let adjusted_config = EmbeddingConfig {
            dimensions: 384,
            ..config
        };
        Self::with_model_type(adjusted_config, TransformerModelType::DistilBERT)
    }

    /// Create a multilingual BERT-based generator.
    pub fn multilingual_bert(config: EmbeddingConfig) -> Self {
        Self::with_model_type(config, TransformerModelType::MultiBERT)
    }

    pub fn model_type(&self) -> &TransformerModelType {
        &self.model_type
    }

    pub fn model_details(&self) -> ModelDetails {
        self.get_model_details()
    }

    /// Whether the model supports the given ISO language code.
    pub fn supports_language(&self, language_code: &str) -> bool {
        let details = self.get_model_details();
        details
            .supports_languages
            .contains(&language_code.to_string())
    }

    /// Rough inference-time estimate in milliseconds for a text of the given length.
    pub fn estimate_inference_time(&self, text_length: usize) -> u64 {
        let details = self.get_model_details();
        let base_time = details.typical_inference_time_ms;

        // Scale with the square root of the text length (sub-linear growth).
        let length_factor = (text_length as f64 / 100.0).sqrt().max(1.0);
        (base_time as f64 * length_factor) as u64
    }

    pub fn model_size_mb(&self) -> usize {
        self.get_model_details().model_size_mb
    }

    /// Relative efficiency rating (higher is faster; BERT = 1.0).
    pub fn efficiency_rating(&self) -> f32 {
        match &self.model_type {
            TransformerModelType::DistilBERT => 1.5,
            TransformerModelType::BERT => 1.0,
            TransformerModelType::RoBERTa => 0.95,
            TransformerModelType::MultiBERT => 0.8,
            TransformerModelType::Custom(_) => 1.0,
        }
    }

    /// Returns (dimensions, max sequence length, efficiency) for the model.
    fn get_model_config(&self) -> (usize, usize, f32) {
        match &self.model_type {
            TransformerModelType::BERT => (self.config.dimensions, 512, 1.0),
            TransformerModelType::RoBERTa => (self.config.dimensions, 514, 0.95),
            TransformerModelType::DistilBERT => (self.config.dimensions, 512, 1.5),
            TransformerModelType::MultiBERT => (self.config.dimensions, 512, 0.8),
            TransformerModelType::Custom(_) => {
                (self.config.dimensions, self.config.max_sequence_length, 1.0)
            }
        }
    }

    fn get_model_details(&self) -> ModelDetails {
        match &self.model_type {
            TransformerModelType::BERT => ModelDetails {
                vocab_size: 30522,
                num_layers: 12,
                num_attention_heads: 12,
                hidden_size: 768,
                intermediate_size: 3072,
                max_position_embeddings: 512,
                supports_languages: vec!["en".to_string()],
                model_size_mb: 440,
                typical_inference_time_ms: 50,
            },
            TransformerModelType::RoBERTa => ModelDetails {
                vocab_size: 50265,
                num_layers: 12,
                num_attention_heads: 12,
                hidden_size: 768,
                intermediate_size: 3072,
                max_position_embeddings: 514,
                supports_languages: vec!["en".to_string()],
                model_size_mb: 470,
                typical_inference_time_ms: 55,
            },
            TransformerModelType::DistilBERT => ModelDetails {
                vocab_size: 30522,
                num_layers: 6,
                num_attention_heads: 12,
                hidden_size: 384,
                intermediate_size: 1536,
                max_position_embeddings: 512,
                supports_languages: vec!["en".to_string()],
                model_size_mb: 250,
                typical_inference_time_ms: 25,
            },
            TransformerModelType::MultiBERT => ModelDetails {
                vocab_size: 120000,
                num_layers: 12,
                num_attention_heads: 12,
                hidden_size: 768,
                intermediate_size: 3072,
                max_position_embeddings: 512,
                supports_languages: vec![
                    "en".to_string(),
                    "de".to_string(),
                    "fr".to_string(),
                    "es".to_string(),
                    "it".to_string(),
                    "pt".to_string(),
                    "ru".to_string(),
                    "zh".to_string(),
                    "ja".to_string(),
                    "ko".to_string(),
                    "ar".to_string(),
                    "hi".to_string(),
                    "th".to_string(),
                    "tr".to_string(),
                    "pl".to_string(),
                    "nl".to_string(),
                    "sv".to_string(),
                    "da".to_string(),
                    "no".to_string(),
                    "fi".to_string(),
                ],
                model_size_mb: 670,
                typical_inference_time_ms: 70,
            },
            TransformerModelType::Custom(_path) => ModelDetails {
                vocab_size: 50000,
                num_layers: 12,
                num_attention_heads: 12,
                hidden_size: self.config.dimensions,
                intermediate_size: self.config.dimensions * 4,
                max_position_embeddings: self.config.max_sequence_length,
                supports_languages: vec!["unknown".to_string()],
                model_size_mb: 500,
                typical_inference_time_ms: 60,
            },
        }
    }

    fn generate_with_model(&self, text: &str) -> Result<Vector> {
        // Deterministic hash of the input (currently unused).
        let _text_hash = {
            use std::hash::{Hash, Hasher};
            let mut hasher = std::collections::hash_map::DefaultHasher::new();
            text.hash(&mut hasher);
            hasher.finish()
        };

        let (dimensions, max_len, _efficiency) = self.get_model_config();
        let model_details = self.get_model_details();

        // Apply model-specific preprocessing (special tokens, casing, truncation).
        let processed_text = self.preprocess_text_for_model(text, max_len)?;

        // Simulate tokenization into model-specific token IDs.
        let token_ids = self.simulate_tokenization(&processed_text, &model_details);

        // Generate deterministic embedding values from the token IDs.
        let values =
            self.generate_embeddings_from_tokens(&token_ids, dimensions, &model_details)?;

        if self.config.normalize {
            let magnitude: f32 = values.iter().map(|x| x * x).sum::<f32>().sqrt();
            if magnitude > 0.0 {
                let mut normalized_values = values;
                for value in &mut normalized_values {
                    *value /= magnitude;
                }
                return Ok(Vector::new(normalized_values));
            }
        }

        Ok(Vector::new(values))
    }

    /// Truncate `text` to at most `max_bytes` bytes without splitting a UTF-8
    /// character (plain byte slicing would panic on multi-byte boundaries).
    fn truncate_at_char_boundary(text: &str, max_bytes: usize) -> &str {
        if text.len() <= max_bytes {
            return text;
        }
        let mut end = max_bytes;
        while end > 0 && !text.is_char_boundary(end) {
            end -= 1;
        }
        &text[..end]
    }

    fn preprocess_text_for_model(&self, text: &str, max_len: usize) -> Result<String> {
        let processed = match &self.model_type {
            TransformerModelType::BERT => {
                // Reserve room for the [CLS] and [SEP] special tokens.
                let truncated = Self::truncate_at_char_boundary(text, max_len.saturating_sub(20));
                format!("[CLS] {} [SEP]", truncated.to_lowercase())
            }
            TransformerModelType::RoBERTa => {
                let truncated = Self::truncate_at_char_boundary(text, max_len.saturating_sub(10));
                // RoBERTa uses <s>/</s> delimiters and preserves case.
                format!("<s>{truncated}</s>")
            }
            TransformerModelType::DistilBERT => {
                let truncated = Self::truncate_at_char_boundary(text, max_len.saturating_sub(20));
                format!("[CLS] {} [SEP]", truncated.to_lowercase())
            }
            TransformerModelType::MultiBERT => {
                let truncated = Self::truncate_at_char_boundary(text, max_len.saturating_sub(20));
                // Preserve case for non-Latin scripts; lowercase otherwise.
                let has_non_latin = !text.is_ascii();
                if has_non_latin {
                    format!("[CLS] {truncated} [SEP]")
                } else {
                    format!("[CLS] {} [SEP]", truncated.to_lowercase())
                }
            }
            TransformerModelType::Custom(_) => {
                Self::truncate_at_char_boundary(text, max_len).to_string()
            }
        };

        Ok(processed)
    }

    /// Simulate tokenization of preprocessed text into token IDs.
    fn simulate_tokenization(&self, text: &str, model_details: &ModelDetails) -> Vec<u32> {
        let mut token_ids = Vec::new();

        let words: Vec<&str> = text.split_whitespace().collect();

        for word in words {
            let subwords = match &self.model_type {
                TransformerModelType::RoBERTa => {
                    // RoBERTa uses byte-pair encoding.
                    self.simulate_bpe_tokenization(word, model_details.vocab_size)
                }
                TransformerModelType::DistilBERT | TransformerModelType::BERT => {
                    // BERT-family models use WordPiece.
                    self.simulate_wordpiece_tokenization(word, model_details.vocab_size)
                }
                TransformerModelType::MultiBERT => {
                    self.simulate_multilingual_tokenization(word, model_details.vocab_size)
                }
                TransformerModelType::Custom(_) => {
                    vec![self.word_to_token_id(word, model_details.vocab_size)]
                }
            };

            token_ids.extend(subwords);
        }

        // Reserve two positions for the special tokens.
        token_ids.truncate(model_details.max_position_embeddings - 2);
        token_ids
    }

    /// Approximate BPE tokenization: split a word into chunks of up to four
    /// characters (character-based to stay on UTF-8 boundaries).
    fn simulate_bpe_tokenization(&self, word: &str, vocab_size: usize) -> Vec<u32> {
        let chars: Vec<char> = word.chars().collect();
        chars
            .chunks(4)
            .map(|chunk| {
                let piece: String = chunk.iter().collect();
                self.word_to_token_id(&piece, vocab_size)
            })
            .collect()
    }

    /// Approximate WordPiece tokenization: short words stay whole, longer
    /// words split into a stem and a "##"-prefixed suffix.
    fn simulate_wordpiece_tokenization(&self, word: &str, vocab_size: usize) -> Vec<u32> {
        let char_count = word.chars().count();
        if char_count <= 6 {
            vec![self.word_to_token_id(word, vocab_size)]
        } else {
            // Split at the character midpoint, not the byte midpoint, so
            // multi-byte characters are never cut in half.
            let mid = word
                .char_indices()
                .nth(char_count / 2)
                .map(|(idx, _)| idx)
                .unwrap_or(word.len());
            vec![
                self.word_to_token_id(&word[..mid], vocab_size),
                self.word_to_token_id(&format!("##{}", &word[mid..]), vocab_size),
            ]
        }
    }

    /// Approximate multilingual tokenization: longer words split in half.
    fn simulate_multilingual_tokenization(&self, word: &str, vocab_size: usize) -> Vec<u32> {
        let char_count = word.chars().count();
        if char_count <= 10 {
            vec![self.word_to_token_id(word, vocab_size)]
        } else {
            let mid = word
                .char_indices()
                .nth(char_count / 2)
                .map(|(idx, _)| idx)
                .unwrap_or(word.len());
            vec![
                self.word_to_token_id(&word[..mid], vocab_size),
                self.word_to_token_id(&word[mid..], vocab_size),
            ]
        }
    }

    /// Map a word deterministically to a token ID via hashing.
    fn word_to_token_id(&self, word: &str, vocab_size: usize) -> u32 {
        use std::hash::{Hash, Hasher};
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        word.hash(&mut hasher);
        (hasher.finish() % vocab_size as u64) as u32
    }

    fn generate_embeddings_from_tokens(
        &self,
        token_ids: &[u32],
        dimensions: usize,
        model_details: &ModelDetails,
    ) -> Result<Vec<f32>> {
        let mut values = vec![0.0; dimensions];

        match &self.model_type {
            TransformerModelType::BERT => {
                self.generate_bert_style_embeddings(token_ids, &mut values, model_details)
            }
            TransformerModelType::RoBERTa => {
                self.generate_roberta_style_embeddings(token_ids, &mut values, model_details)
            }
            TransformerModelType::DistilBERT => {
                self.generate_distilbert_style_embeddings(token_ids, &mut values, model_details)
            }
            TransformerModelType::MultiBERT => {
                self.generate_multibert_style_embeddings(token_ids, &mut values, model_details)
            }
            TransformerModelType::Custom(_) => {
                self.generate_custom_style_embeddings(token_ids, &mut values, model_details)
            }
        }

        Ok(values)
    }

    fn generate_bert_style_embeddings(
        &self,
        token_ids: &[u32],
        values: &mut [f32],
        _model_details: &ModelDetails,
    ) {
        for (i, &token_id) in token_ids.iter().enumerate() {
            // Linear congruential generator seeded by the token ID.
            let mut seed = token_id as u64;
            for value in values.iter_mut() {
                seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
                let normalized = (seed as f32) / (u64::MAX as f32);
                let position_encoding =
                    ((i as f32 / 512.0) * 2.0 * std::f32::consts::PI).sin() * 0.1;
                *value += ((normalized - 0.5) * 2.0) + position_encoding;
            }
        }

        // Mean-pool over the tokens.
        if !token_ids.is_empty() {
            for value in values.iter_mut() {
                *value /= token_ids.len() as f32;
            }
        }
    }

    fn generate_roberta_style_embeddings(
        &self,
        token_ids: &[u32],
        values: &mut [f32],
        _model_details: &ModelDetails,
    ) {
        for (i, &token_id) in token_ids.iter().enumerate() {
            // Widen to u64 before mixing so the LCG spans the full range; a
            // distinct multiplier keeps RoBERTa embeddings different from BERT.
            let mut seed = (token_id as u64).wrapping_mul(31415927);
            for value in values.iter_mut() {
                seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
                let normalized = (seed as f32) / (u64::MAX as f32);
                // RoBERTa position IDs start at 2, hence the offset.
                let position_encoding =
                    ((i as f32 + 2.0) / 514.0 * 2.0 * std::f32::consts::PI).cos() * 0.1;
                *value += ((normalized - 0.5) * 2.0) + position_encoding;
            }
        }

        if !token_ids.is_empty() {
            for value in values.iter_mut() {
                *value /= token_ids.len() as f32;
            }
        }
    }

    fn generate_distilbert_style_embeddings(
        &self,
        token_ids: &[u32],
        values: &mut [f32],
        _model_details: &ModelDetails,
    ) {
        for (i, &token_id) in token_ids.iter().enumerate() {
            let mut seed = token_id as u64;
            for value in values.iter_mut() {
                // Distinct multiplier from the BERT variant.
                seed = seed.wrapping_mul(982451653).wrapping_add(12345);
                let normalized = (seed as f32) / (u64::MAX as f32);
                // Weaker positional signal and a compressed value range.
                let position_encoding = (i as f32 / 512.0).sin() * 0.05;
                *value += ((normalized - 0.5) * 1.5) + position_encoding;
            }
        }

        if !token_ids.is_empty() {
            for value in values.iter_mut() {
                *value /= token_ids.len() as f32;
            }
        }
    }

    fn generate_multibert_style_embeddings(
        &self,
        token_ids: &[u32],
        values: &mut [f32],
        _model_details: &ModelDetails,
    ) {
        let len = values.len();
        for (i, &token_id) in token_ids.iter().enumerate() {
            // Widen to u64 before mixing (Knuth multiplicative hash constant).
            let mut seed = (token_id as u64).wrapping_mul(2654435761);
            for (j, value) in values.iter_mut().enumerate() {
                seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
                let normalized = (seed as f32) / (u64::MAX as f32);
                let position_encoding =
                    ((i as f32 / 512.0) * 2.0 * std::f32::consts::PI).sin() * 0.08;
                // Dimension-dependent bias simulating cross-lingual structure.
                let cross_lingual_bias =
                    (j as f32 / len as f32 * std::f32::consts::PI).cos() * 0.05;
                *value += ((normalized - 0.5) * 1.8) + position_encoding + cross_lingual_bias;
            }
        }

        if !token_ids.is_empty() {
            for value in values.iter_mut() {
                *value /= token_ids.len() as f32;
            }
        }
    }

    fn generate_custom_style_embeddings(
        &self,
        token_ids: &[u32],
        values: &mut [f32],
        _model_details: &ModelDetails,
    ) {
        for &token_id in token_ids {
            let mut seed = token_id as u64;
            for value in values.iter_mut() {
                seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
                let normalized = (seed as f32) / (u64::MAX as f32);
                *value += (normalized - 0.5) * 2.0;
            }
        }

        if !token_ids.is_empty() {
            for value in values.iter_mut() {
                *value /= token_ids.len() as f32;
            }
        }
    }
}
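
// Selection sketch: DistilBERT trades some quality for roughly half the size
// and latency of BERT, which the helpers above make easy to compare.
//
//     let config = EmbeddingConfig::default();
//     let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
//     let bert = SentenceTransformerGenerator::new(config);
//     assert!(distilbert.model_size_mb() < bert.model_size_mb());
//     assert!(distilbert.estimate_inference_time(200) < bert.estimate_inference_time(200));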

impl EmbeddingGenerator for SentenceTransformerGenerator {
    fn generate(&self, content: &EmbeddableContent) -> Result<Vector> {
        let text = content.to_text();
        self.generate_with_model(&text)
    }

    fn dimensions(&self) -> usize {
        self.config.dimensions
    }

    fn config(&self) -> &EmbeddingConfig {
        &self.config
    }
}

/// Simple LRU cache for embeddings, keyed by content hash.
pub struct EmbeddingCache {
    cache: HashMap<u64, Vector>,
    max_size: usize,
    access_order: Vec<u64>,
}

impl EmbeddingCache {
    pub fn new(max_size: usize) -> Self {
        Self {
            cache: HashMap::new(),
            max_size,
            access_order: Vec::new(),
        }
    }

    pub fn get(&mut self, content: &EmbeddableContent) -> Option<&Vector> {
        let hash = content.content_hash();
        if let Some(vector) = self.cache.get(&hash) {
            // Move the entry to the back (most recently used).
            if let Some(pos) = self.access_order.iter().position(|&x| x == hash) {
                self.access_order.remove(pos);
            }
            self.access_order.push(hash);
            Some(vector)
        } else {
            None
        }
    }

    pub fn insert(&mut self, content: &EmbeddableContent, vector: Vector) {
        let hash = content.content_hash();

        // Evict the least recently used entry if the cache is full.
        if self.cache.len() >= self.max_size && !self.cache.contains_key(&hash) {
            if let Some(&lru_hash) = self.access_order.first() {
                self.cache.remove(&lru_hash);
                self.access_order.remove(0);
            }
        }

        self.cache.insert(hash, vector);
        // Avoid duplicate access-order entries when re-inserting a key.
        if let Some(pos) = self.access_order.iter().position(|&x| x == hash) {
            self.access_order.remove(pos);
        }
        self.access_order.push(hash);
    }

    pub fn clear(&mut self) {
        self.cache.clear();
        self.access_order.clear();
    }

    pub fn size(&self) -> usize {
        self.cache.len()
    }
}
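
// Behavior sketch: with `max_size = 2`, inserting a third entry evicts the
// least recently used one, and `get` refreshes an entry's recency.
//
//     let mut cache = EmbeddingCache::new(2);
//     cache.insert(&a, va);
//     cache.insert(&b, vb);
//     cache.get(&a);          // `a` is now most recently used
//     cache.insert(&c, vc);   // evicts `b`, not `a`
//
// (`a`, `b`, `c` stand for placeholder `EmbeddableContent` values.)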

/// High-level manager pairing an embedding generator with an LRU cache.
pub struct EmbeddingManager {
    generator: Box<dyn EmbeddingGenerator>,
    cache: EmbeddingCache,
    strategy: EmbeddingStrategy,
}

impl EmbeddingManager {
    pub fn new(strategy: EmbeddingStrategy, cache_size: usize) -> Result<Self> {
        let generator: Box<dyn EmbeddingGenerator> = match &strategy {
            EmbeddingStrategy::TfIdf => {
                let config = EmbeddingConfig::default();
                Box::new(TfIdfEmbeddingGenerator::new(config))
            }
            EmbeddingStrategy::SentenceTransformer => {
                let config = EmbeddingConfig::default();
                Box::new(SentenceTransformerGenerator::new(config))
            }
            EmbeddingStrategy::Transformer(model_type) => {
                let config = EmbeddingConfig {
                    model_name: format!("{model_type:?}"),
                    dimensions: match model_type {
                        TransformerModelType::DistilBERT => 384,
                        _ => 768,
                    },
                    max_sequence_length: 512,
                    normalize: true,
                };
                Box::new(SentenceTransformerGenerator::with_model_type(
                    config,
                    model_type.clone(),
                ))
            }
            EmbeddingStrategy::Word2Vec(word2vec_config) => {
                let embedding_config = EmbeddingConfig {
                    model_name: "word2vec".to_string(),
                    dimensions: word2vec_config.dimensions,
                    max_sequence_length: 512,
                    normalize: word2vec_config.normalize,
                };
                Box::new(crate::word2vec::Word2VecEmbeddingGenerator::new(
                    word2vec_config.clone(),
                    embedding_config,
                )?)
            }
            EmbeddingStrategy::OpenAI(openai_config) => {
                Box::new(OpenAIEmbeddingGenerator::new(openai_config.clone())?)
            }
            EmbeddingStrategy::Custom(_model_path) => {
                // Custom models fall back to the sentence transformer for now.
                let config = EmbeddingConfig::default();
                Box::new(SentenceTransformerGenerator::new(config))
            }
        };

        Ok(Self {
            generator,
            cache: EmbeddingCache::new(cache_size),
            strategy,
        })
    }

    /// Get an embedding, consulting the cache first.
    pub fn get_embedding(&mut self, content: &EmbeddableContent) -> Result<Vector> {
        if let Some(cached) = self.cache.get(content) {
            return Ok(cached.clone());
        }

        let embedding = self.generator.generate(content)?;
        self.cache.insert(content, embedding.clone());
        Ok(embedding)
    }

    /// Generate and cache embeddings for a batch of contents.
    pub fn precompute_embeddings(&mut self, contents: &[EmbeddableContent]) -> Result<()> {
        let embeddings = self.generator.generate_batch(contents)?;

        for (content, embedding) in contents.iter().zip(embeddings) {
            self.cache.insert(content, embedding);
        }

        Ok(())
    }

    /// Build the vocabulary when using the TF-IDF strategy; a no-op otherwise.
    pub fn build_vocabulary(&mut self, documents: &[String]) -> Result<()> {
        if let EmbeddingStrategy::TfIdf = self.strategy {
            if let Some(tfidf_gen) = self
                .generator
                .as_any_mut()
                .downcast_mut::<TfIdfEmbeddingGenerator>()
            {
                tfidf_gen.build_vocabulary(documents)?;
            }
        }
        Ok(())
    }

    pub fn dimensions(&self) -> usize {
        self.generator.dimensions()
    }

    /// Returns (current cache size, maximum cache size).
    pub fn cache_stats(&self) -> (usize, usize) {
        (self.cache.size(), self.cache.max_size)
    }
}
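
// Usage sketch combining strategy selection, vocabulary building, and cached
// lookup (`documents` and `content` are placeholders):
//
//     let mut manager = EmbeddingManager::new(EmbeddingStrategy::TfIdf, 1000)?;
//     manager.build_vocabulary(&documents)?;
//     let v1 = manager.get_embedding(&content)?; // computed
//     let v2 = manager.get_embedding(&content)?; // served from the cache
//     let (used, capacity) = manager.cache_stats();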

/// Helper trait enabling downcasts from trait objects to concrete generators.
pub trait AsAny {
    fn as_any(&self) -> &dyn std::any::Any;
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
}

impl AsAny for TfIdfEmbeddingGenerator {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

impl AsAny for SentenceTransformerGenerator {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

/// Embedding generator backed by the OpenAI embeddings API.
pub struct OpenAIEmbeddingGenerator {
    config: EmbeddingConfig,
    openai_config: OpenAIConfig,
    client: reqwest::Client,
    rate_limiter: RateLimiter,
    request_cache: std::sync::Arc<std::sync::Mutex<lru::LruCache<u64, CachedEmbedding>>>,
    metrics: OpenAIMetrics,
}

/// A cached embedding together with bookkeeping metadata.
#[derive(Debug, Clone)]
pub struct CachedEmbedding {
    pub vector: Vector,
    pub cached_at: std::time::SystemTime,
    pub model: String,
    pub cost_usd: f64,
}

/// Usage metrics for the OpenAI embedding generator.
#[derive(Debug, Clone, Default)]
pub struct OpenAIMetrics {
    pub total_requests: u64,
    pub successful_requests: u64,
    pub failed_requests: u64,
    pub total_tokens_processed: u64,
    pub cache_hits: u64,
    pub cache_misses: u64,
    pub total_cost_usd: f64,
    pub retry_count: u64,
    pub rate_limit_waits: u64,
    pub average_response_time_ms: f64,
    pub last_request_time: Option<std::time::SystemTime>,
    pub requests_by_model: HashMap<String, u64>,
    pub errors_by_type: HashMap<String, u64>,
}

impl OpenAIMetrics {
    pub fn cache_hit_ratio(&self) -> f64 {
        if self.cache_hits + self.cache_misses == 0 {
            0.0
        } else {
            self.cache_hits as f64 / (self.cache_hits + self.cache_misses) as f64
        }
    }

    pub fn success_rate(&self) -> f64 {
        if self.total_requests == 0 {
            0.0
        } else {
            self.successful_requests as f64 / self.total_requests as f64
        }
    }

    pub fn average_cost_per_request(&self) -> f64 {
        if self.successful_requests == 0 {
            0.0
        } else {
            self.total_cost_usd / self.successful_requests as f64
        }
    }

    pub fn report(&self) -> String {
        format!(
            "OpenAI Metrics Report:\n\
             Total Requests: {}\n\
             Success Rate: {:.2}%\n\
             Cache Hit Ratio: {:.2}%\n\
             Total Cost: ${:.4}\n\
             Avg Cost/Request: ${:.6}\n\
             Avg Response Time: {:.2}ms\n\
             Retries: {}\n\
             Rate Limit Waits: {}",
            self.total_requests,
            self.success_rate() * 100.0,
            self.cache_hit_ratio() * 100.0,
            self.total_cost_usd,
            self.average_cost_per_request(),
            self.average_response_time_ms,
            self.retry_count,
            self.rate_limit_waits
        )
    }
}
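
// Reading metrics after a run (sketch; `generator` is a placeholder):
//
//     let metrics = generator.get_metrics();
//     println!("{}", metrics.report());
//     if metrics.success_rate() < 0.95 {
//         eprintln!("high failure rate: {:?}", metrics.errors_by_type);
//     }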

/// Sliding-window rate limiter: tracks request timestamps over the last
/// minute and sleeps when the per-minute budget is exhausted.
pub struct RateLimiter {
    requests_per_minute: u32,
    request_times: std::collections::VecDeque<std::time::Instant>,
}

impl RateLimiter {
    pub fn new(requests_per_minute: u32) -> Self {
        Self {
            requests_per_minute,
            request_times: std::collections::VecDeque::new(),
        }
    }

    pub async fn wait_if_needed(&mut self) {
        let now = std::time::Instant::now();
        let minute_ago = now - std::time::Duration::from_secs(60);

        // Drop timestamps that have aged out of the one-minute window.
        while let Some(&front_time) = self.request_times.front() {
            if front_time < minute_ago {
                self.request_times.pop_front();
            } else {
                break;
            }
        }

        // If the window is full, wait until the oldest request expires.
        if self.request_times.len() >= self.requests_per_minute as usize {
            if let Some(&oldest) = self.request_times.front() {
                let wait_time = oldest + std::time::Duration::from_secs(60) - now;
                if !wait_time.is_zero() {
                    tokio::time::sleep(wait_time).await;
                }
            }
        }

        self.request_times.push_back(now);
    }
}
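
// Worked example: with `requests_per_minute = 2`, the third call inside one
// minute sleeps until the oldest timestamp leaves the 60-second window.
//
//     let mut limiter = RateLimiter::new(2);
//     limiter.wait_if_needed().await; // immediate
//     limiter.wait_if_needed().await; // immediate
//     limiter.wait_if_needed().await; // sleeps until 60s after the first call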

impl OpenAIEmbeddingGenerator {
    pub fn new(openai_config: OpenAIConfig) -> Result<Self> {
        openai_config.validate()?;

        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(
                openai_config.timeout_seconds,
            ))
            .user_agent(&openai_config.user_agent)
            .build()
            .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;

        let embedding_config = EmbeddingConfig {
            model_name: openai_config.model.clone(),
            dimensions: Self::get_model_dimensions(&openai_config.model),
            // OpenAI embedding models accept up to 8191 input tokens.
            max_sequence_length: 8191,
            normalize: true,
        };

        let cache_size = if openai_config.enable_cache {
            std::num::NonZeroUsize::new(openai_config.cache_size)
                .unwrap_or(std::num::NonZeroUsize::new(1000).unwrap())
        } else {
            std::num::NonZeroUsize::new(1).unwrap()
        };

        Ok(Self {
            config: embedding_config,
            openai_config: openai_config.clone(),
            client,
            rate_limiter: RateLimiter::new(openai_config.requests_per_minute),
            request_cache: std::sync::Arc::new(std::sync::Mutex::new(lru::LruCache::new(
                cache_size,
            ))),
            metrics: OpenAIMetrics::default(),
        })
    }

    /// Embedding dimensions for the known OpenAI models.
    fn get_model_dimensions(model: &str) -> usize {
        match model {
            "text-embedding-ada-002" => 1536,
            "text-embedding-3-small" => 1536,
            "text-embedding-3-large" => 3072,
            "text-embedding-004" => 1536,
            // Default to 1536 for unknown models.
            _ => 1536,
        }
    }

    /// Cost per 1000 tokens in USD for the known OpenAI models.
    fn get_model_cost_per_1k_tokens(model: &str) -> f64 {
        match model {
            "text-embedding-ada-002" => 0.0001,
            "text-embedding-3-small" => 0.00002,
            "text-embedding-3-large" => 0.00013,
            "text-embedding-004" => 0.00002,
            _ => 0.0001,
        }
    }

    /// Estimate the cost of embedding the given texts (~4 characters per token).
    fn calculate_cost(&self, texts: &[String]) -> f64 {
        if !self.openai_config.track_costs {
            return 0.0;
        }

        let total_tokens: usize = texts.iter().map(|t| t.len() / 4).sum();
        let cost_per_1k = Self::get_model_cost_per_1k_tokens(&self.openai_config.model);
        (total_tokens as f64 / 1000.0) * cost_per_1k
    }

    fn is_cache_valid(&self, cached: &CachedEmbedding) -> bool {
        if self.openai_config.cache_ttl_seconds == 0 {
            // A TTL of zero is treated as "never expires" here.
            return true;
        }

        let elapsed = cached
            .cached_at
            .elapsed()
            .unwrap_or(std::time::Duration::from_secs(u64::MAX));

        elapsed.as_secs() < self.openai_config.cache_ttl_seconds
    }

    /// Make an embeddings request with retries, updating metrics along the way.
    async fn make_request(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let start_time = std::time::Instant::now();
        let mut attempts = 0;

        while attempts < self.openai_config.max_retries {
            match self.try_request(texts).await {
                Ok(embeddings) => {
                    if self.openai_config.enable_metrics {
                        let response_time = start_time.elapsed().as_millis() as f64;
                        self.update_response_time(response_time);

                        let cost = self.calculate_cost(texts);
                        self.metrics.total_cost_usd += cost;

                        *self
                            .metrics
                            .requests_by_model
                            .entry(self.openai_config.model.clone())
                            .or_insert(0) += 1;
                    }

                    return Ok(embeddings);
                }
                Err(e) => {
                    attempts += 1;
                    self.metrics.retry_count += 1;

                    // Classify the error for the metrics breakdown.
                    let error_type = if e.to_string().contains("rate_limit") {
                        "rate_limit"
                    } else if e.to_string().contains("timeout") {
                        "timeout"
                    } else if e.to_string().contains("401") {
                        "unauthorized"
                    } else if e.to_string().contains("400") {
                        "bad_request"
                    } else {
                        "other"
                    };

                    *self
                        .metrics
                        .errors_by_type
                        .entry(error_type.to_string())
                        .or_insert(0) += 1;

                    if attempts >= self.openai_config.max_retries {
                        return Err(e);
                    }

                    // Back off before the next attempt.
                    let delay = self.calculate_retry_delay(attempts);
                    tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
                }
            }
        }

        Err(anyhow!("Max retries exceeded"))
    }

    fn calculate_retry_delay(&self, attempt: u32) -> u64 {
        let base_delay = self.openai_config.retry_delay_ms;

        match self.openai_config.retry_strategy {
            RetryStrategy::Fixed => base_delay,
            RetryStrategy::LinearBackoff => base_delay * attempt as u64,
            RetryStrategy::ExponentialBackoff => {
                let delay = base_delay * (2_u64.pow(attempt - 1));
                let jitter = {
                    #[allow(unused_imports)]
                    use scirs2_core::random::{Random, Rng};
                    // Note: the fixed seed makes the jitter deterministic
                    // across calls.
                    let mut rng = Random::seed(42);
                    (delay as f64 * 0.25 * (rng.gen_range(0.0..1.0) - 0.5)) as u64
                };
                // Cap the total delay at 30 seconds.
                delay.saturating_add(jitter).min(30000)
            }
        }
    }

    /// Update the running average response time.
    fn update_response_time(&mut self, response_time_ms: f64) {
        if self.metrics.successful_requests == 0 {
            self.metrics.average_response_time_ms = response_time_ms;
        } else {
            let total =
                self.metrics.average_response_time_ms * self.metrics.successful_requests as f64;
            self.metrics.average_response_time_ms =
                (total + response_time_ms) / (self.metrics.successful_requests + 1) as f64;
        }
    }

    async fn try_request(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        self.rate_limiter.wait_if_needed().await;

        let request_body = serde_json::json!({
            "model": self.openai_config.model,
            "input": texts,
            "encoding_format": "float"
        });

        let response = self
            .client
            .post(format!("{}/embeddings", self.openai_config.base_url))
            .header(
                "Authorization",
                format!("Bearer {}", self.openai_config.api_key),
            )
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| anyhow!("Request failed: {}", e))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(anyhow!(
                "API request failed with status {}: {}",
                status,
                error_text
            ));
        }

        let response_data: serde_json::Value = response
            .json()
            .await
            .map_err(|e| anyhow!("Failed to parse response: {}", e))?;

        let embeddings_data = response_data["data"]
            .as_array()
            .ok_or_else(|| anyhow!("Invalid response format: missing data array"))?;

        let mut embeddings = Vec::new();
        for item in embeddings_data {
            let embedding = item["embedding"]
                .as_array()
                .ok_or_else(|| anyhow!("Invalid response format: missing embedding"))?;

            let vec: Result<Vec<f32>, _> = embedding
                .iter()
                .map(|v| {
                    v.as_f64()
                        .ok_or_else(|| anyhow!("Invalid embedding value"))
                        .map(|f| f as f32)
                })
                .collect();

            embeddings.push(vec?);
        }

        Ok(embeddings)
    }

    /// Generate an embedding asynchronously, consulting the cache first.
    pub async fn generate_async(&mut self, content: &EmbeddableContent) -> Result<Vector> {
        let text = content.to_text();

        if self.openai_config.enable_cache {
            let hash = content.content_hash();

            // Look up the cache, treating expired entries as misses.
            let cached_vector = match self.request_cache.lock() {
                Ok(mut cache) => {
                    if let Some(cached) = cache.get(&hash) {
                        let is_valid = cached.cached_at.elapsed().unwrap_or_default()
                            < Duration::from_secs(self.openai_config.cache_ttl_seconds);
                        if is_valid {
                            Some(cached.vector.clone())
                        } else {
                            None
                        }
                    } else {
                        None
                    }
                }
                _ => None,
            };

            if let Some(result) = cached_vector {
                self.update_cache_hit();
                return Ok(result);
            }

            // Remove any expired entry before fetching a fresh embedding.
            if let Ok(mut cache) = self.request_cache.lock() {
                cache.pop(&hash);
            }
            self.update_cache_miss();
        }

        let embeddings = match self.make_request(std::slice::from_ref(&text)).await {
            Ok(embeddings) => {
                self.update_metrics_success(std::slice::from_ref(&text));
                embeddings
            }
            Err(e) => {
                self.update_metrics_failure();
                return Err(e);
            }
        };

        if embeddings.is_empty() {
            self.update_metrics_failure();
            return Err(anyhow!("No embeddings returned from API"));
        }

        let vector = Vector::new(embeddings[0].clone());

        // Cache the fresh embedding together with its estimated cost.
        if self.openai_config.enable_cache {
            let hash = content.content_hash();
            let cost = self.calculate_cost(std::slice::from_ref(&text));
            let cached_embedding = CachedEmbedding {
                vector: vector.clone(),
                cached_at: std::time::SystemTime::now(),
                model: self.openai_config.model.clone(),
                cost_usd: cost,
            };
            if let Ok(mut cache) = self.request_cache.lock() {
                cache.put(hash, cached_embedding);
            }
        }

        Ok(vector)
    }

    /// Generate embeddings for multiple contents, batching the API requests.
    pub async fn generate_batch_async(
        &mut self,
        contents: &[EmbeddableContent],
    ) -> Result<Vec<Vector>> {
        if contents.is_empty() {
            return Ok(Vec::new());
        }

        let mut results = Vec::with_capacity(contents.len());
        let batch_size = self.openai_config.batch_size;

        for chunk in contents.chunks(batch_size) {
            let texts: Vec<String> = chunk.iter().map(|c| c.to_text()).collect();

            let embeddings = match self.make_request(&texts).await {
                Ok(embeddings) => {
                    self.update_metrics_success(&texts);
                    embeddings
                }
                Err(e) => {
                    self.update_metrics_failure();
                    return Err(e);
                }
            };

            if embeddings.len() != chunk.len() {
                self.update_metrics_failure();
                return Err(anyhow!("Mismatch between request and response sizes"));
            }

            // Attribute an equal share of the batch cost to each item.
            let batch_cost = self.calculate_cost(&texts) / chunk.len() as f64;

            for (content, embedding) in chunk.iter().zip(embeddings) {
                let vector = Vector::new(embedding);

                // Cache each embedding with its share of the cost.
                if self.openai_config.enable_cache {
                    let hash = content.content_hash();
                    let cached_embedding = CachedEmbedding {
                        vector: vector.clone(),
                        cached_at: std::time::SystemTime::now(),
                        model: self.openai_config.model.clone(),
                        cost_usd: batch_cost,
                    };
                    if let Ok(mut cache) = self.request_cache.lock() {
                        cache.put(hash, cached_embedding);
                    }
                }

                results.push(vector);
            }
        }

        Ok(results)
    }

    pub fn clear_cache(&mut self) {
        if let Ok(mut cache) = self.request_cache.lock() {
            cache.clear();
        }
    }

    /// Returns (current cache size, capacity) if the cache lock is available.
    pub fn cache_stats(&self) -> (usize, Option<usize>) {
        match self.request_cache.lock() {
            Ok(cache) => (cache.len(), Some(cache.cap().into())),
            _ => (0, None),
        }
    }

    /// Total estimated cost of all currently cached embeddings.
    pub fn get_cache_cost(&self) -> f64 {
        match self.request_cache.lock() {
            Ok(cache) => cache.iter().map(|(_, cached)| cached.cost_usd).sum(),
            _ => 0.0,
        }
    }

    pub fn get_metrics(&self) -> &OpenAIMetrics {
        &self.metrics
    }

    pub fn reset_metrics(&mut self) {
        self.metrics = OpenAIMetrics::default();
    }

    /// Rough token estimate: about four characters per token, at least one.
    fn estimate_tokens(&self, text: &str) -> u64 {
        (text.len() / 4).max(1) as u64
    }

    fn calculate_cost_from_tokens(&self, total_tokens: u64) -> f64 {
        let cost_per_1k_tokens = match self.openai_config.model.as_str() {
            "text-embedding-ada-002" => 0.0001,
            "text-embedding-3-small" => 0.00002,
            "text-embedding-3-large" => 0.00013,
            _ => 0.0001,
        };

        (total_tokens as f64 / 1000.0) * cost_per_1k_tokens
    }

    fn update_metrics_success(&mut self, texts: &[String]) {
        self.metrics.total_requests += 1;
        self.metrics.successful_requests += 1;

        let total_tokens: u64 = texts.iter().map(|text| self.estimate_tokens(text)).sum();

        self.metrics.total_tokens_processed += total_tokens;
        self.metrics.total_cost_usd += self.calculate_cost_from_tokens(total_tokens);
    }

    fn update_metrics_failure(&mut self) {
        self.metrics.total_requests += 1;
        self.metrics.failed_requests += 1;
    }

    fn update_cache_hit(&mut self) {
        self.metrics.cache_hits += 1;
    }

    fn update_cache_miss(&mut self) {
        self.metrics.cache_misses += 1;
    }
}
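
// Async usage sketch (assumes a valid API key; batching amortizes rate-limit
// waits across inputs):
//
//     let mut generator = OpenAIEmbeddingGenerator::new(OpenAIConfig::default())?;
//     let contents = vec![
//         EmbeddableContent::Text("first".to_string()),
//         EmbeddableContent::Text("second".to_string()),
//     ];
//     let vectors = generator.generate_batch_async(&contents).await?;
//     assert_eq!(vectors.len(), contents.len());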

impl EmbeddingGenerator for OpenAIEmbeddingGenerator {
    fn generate(&self, content: &EmbeddableContent) -> Result<Vector> {
        // Serve from the cache first to avoid a network round trip.
        if self.openai_config.enable_cache {
            let hash = content.content_hash();
            if let Ok(mut cache) = self.request_cache.lock() {
                if let Some(cached) = cache.get(&hash) {
                    return Ok(cached.vector.clone());
                }
            }
        }

        // The trait is synchronous, so create a runtime and block on the async
        // path. Note: this panics if called from inside an existing tokio
        // runtime; prefer `generate_async` in async contexts.
        let rt = tokio::runtime::Runtime::new()
            .map_err(|e| anyhow!("Failed to create async runtime: {}", e))?;

        // Work on a temporary clone: `generate` takes `&self`, but the async
        // path needs `&mut self`. The cache is shared through the `Arc`.
        let mut temp_generator = OpenAIEmbeddingGenerator {
            config: self.config.clone(),
            openai_config: self.openai_config.clone(),
            client: self.client.clone(),
            rate_limiter: RateLimiter::new(self.openai_config.requests_per_minute),
            request_cache: self.request_cache.clone(),
            metrics: self.metrics.clone(),
        };

        rt.block_on(temp_generator.generate_async(content))
    }

    fn dimensions(&self) -> usize {
        self.config.dimensions
    }

    fn config(&self) -> &EmbeddingConfig {
        &self.config
    }
}

impl AsAny for OpenAIEmbeddingGenerator {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

/// Deterministic mock embedding generator for tests: the same input text
/// always produces the same unit-length vector.
#[cfg(test)]
pub struct MockEmbeddingGenerator {
    config: EmbeddingConfig,
}

#[cfg(test)]
impl Default for MockEmbeddingGenerator {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
impl MockEmbeddingGenerator {
    pub fn new() -> Self {
        Self {
            config: EmbeddingConfig {
                dimensions: 128,
                ..Default::default()
            },
        }
    }

    pub fn with_dimensions(dimensions: usize) -> Self {
        Self {
            config: EmbeddingConfig {
                dimensions,
                ..Default::default()
            },
        }
    }
}

#[cfg(test)]
impl AsAny for MockEmbeddingGenerator {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

#[cfg(test)]
impl EmbeddingGenerator for MockEmbeddingGenerator {
    fn generate(&self, content: &EmbeddableContent) -> Result<crate::Vector> {
        let text = content.to_text();

        // Seed a simple LCG with the hash of the input text.
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        text.hash(&mut hasher);
        let hash = hasher.finish();

        let mut embedding = Vec::with_capacity(self.config.dimensions);
        let mut seed = hash;

        for _ in 0..self.config.dimensions {
            // Numerical Recipes LCG constants.
            seed = seed.wrapping_mul(1664525).wrapping_add(1013904223);
            let value = (seed as f64 / u64::MAX as f64) as f32;
            // Map from [0, 1] into [-1, 1].
            embedding.push(value * 2.0 - 1.0);
        }

        // Normalize to unit length.
        let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if magnitude > 0.0 {
            for value in &mut embedding {
                *value /= magnitude;
            }
        }

        Ok(crate::Vector::new(embedding))
    }

    fn dimensions(&self) -> usize {
        self.config.dimensions
    }

    fn config(&self) -> &EmbeddingConfig {
        &self.config
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transformer_model_types() {
        let config = EmbeddingConfig::default();

        let bert = SentenceTransformerGenerator::new(config.clone());
        assert!(matches!(bert.model_type(), TransformerModelType::BERT));
        assert_eq!(bert.dimensions(), 384);

        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        assert!(matches!(
            roberta.model_type(),
            TransformerModelType::RoBERTa
        ));

        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        assert!(matches!(
            distilbert.model_type(),
            TransformerModelType::DistilBERT
        ));
        assert_eq!(distilbert.dimensions(), 384);

        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());
        assert!(matches!(
            multibert.model_type(),
            TransformerModelType::MultiBERT
        ));
    }

    #[test]
    fn test_model_details() {
        let config = EmbeddingConfig::default();

        let bert = SentenceTransformerGenerator::new(config.clone());
        let bert_details = bert.model_details();
        assert_eq!(bert_details.vocab_size, 30522);
        assert_eq!(bert_details.num_layers, 12);
        assert_eq!(bert_details.hidden_size, 768);
        assert!(bert_details.supports_languages.contains(&"en".to_string()));

        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let roberta_details = roberta.model_details();
        assert_eq!(roberta_details.vocab_size, 50265);
        assert_eq!(roberta_details.max_position_embeddings, 514);

        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        let distilbert_details = distilbert.model_details();
        assert_eq!(distilbert_details.num_layers, 6);
        assert_eq!(distilbert_details.hidden_size, 384);
        assert!(distilbert_details.model_size_mb < bert_details.model_size_mb);
        assert!(
            distilbert_details.typical_inference_time_ms < bert_details.typical_inference_time_ms
        );

        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());
        let multibert_details = multibert.model_details();
        assert_eq!(multibert_details.vocab_size, 120000);
        assert!(multibert_details.supports_languages.len() > 10);
        assert!(multibert_details
            .supports_languages
            .contains(&"zh".to_string()));
        assert!(multibert_details
            .supports_languages
            .contains(&"de".to_string()));
    }

    #[test]
    fn test_language_support() {
        let config = EmbeddingConfig::default();

        // BERT and DistilBERT are English-only.
        let bert = SentenceTransformerGenerator::new(config.clone());
        assert!(bert.supports_language("en"));
        assert!(!bert.supports_language("zh"));
        assert!(!bert.supports_language("de"));

        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        assert!(distilbert.supports_language("en"));
        assert!(!distilbert.supports_language("zh"));

        // Multilingual BERT supports many languages.
        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());
        assert!(multibert.supports_language("en"));
        assert!(multibert.supports_language("zh"));
        assert!(multibert.supports_language("de"));
        assert!(multibert.supports_language("fr"));
        assert!(multibert.supports_language("es"));
        assert!(!multibert.supports_language("unknown_lang"));
    }

    #[test]
    fn test_efficiency_ratings() {
        let config = EmbeddingConfig::default();

        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());

        // DistilBERT is the most efficient model.
        assert!(distilbert.efficiency_rating() > bert.efficiency_rating());
        assert!(distilbert.efficiency_rating() > roberta.efficiency_rating());
        assert!(distilbert.efficiency_rating() > multibert.efficiency_rating());

        assert!(bert.efficiency_rating() > roberta.efficiency_rating());

        // Multilingual BERT is the least efficient.
        assert!(bert.efficiency_rating() > multibert.efficiency_rating());
        assert!(roberta.efficiency_rating() > multibert.efficiency_rating());
    }

    #[test]
    fn test_inference_time_estimation() {
        let config = EmbeddingConfig::default();
        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        let bert = SentenceTransformerGenerator::new(config.clone());

        let short_time_distilbert = distilbert.estimate_inference_time(50);
        let short_time_bert = bert.estimate_inference_time(50);

        let long_time_distilbert = distilbert.estimate_inference_time(500);
        let long_time_bert = bert.estimate_inference_time(500);

        // DistilBERT should be faster at every text length.
        assert!(short_time_distilbert < short_time_bert);
        assert!(long_time_distilbert < long_time_bert);

        // Longer texts should take longer.
        assert!(long_time_distilbert > short_time_distilbert);
        assert!(long_time_bert > short_time_bert);
    }

    #[test]
    fn test_model_specific_text_preprocessing() {
        let config = EmbeddingConfig::default();

        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());

        let text = "Hello World";

        // BERT lowercases the text and adds [CLS]/[SEP] tokens.
        let bert_processed = bert.preprocess_text_for_model(text, 512).unwrap();
        assert!(bert_processed.contains("[CLS]"));
        assert!(bert_processed.contains("[SEP]"));
        assert!(bert_processed.contains("hello world"));

        // RoBERTa preserves case and uses <s>/</s> delimiters.
        let roberta_processed = roberta.preprocess_text_for_model(text, 512).unwrap();
        assert!(roberta_processed.contains("<s>"));
        assert!(roberta_processed.contains("</s>"));
        assert!(roberta_processed.contains("Hello World"));

        // Multilingual BERT lowercases Latin text but preserves other scripts.
        let latin_text = "Hello World";
        let chinese_text = "你好世界";

        let latin_processed = multibert
            .preprocess_text_for_model(latin_text, 512)
            .unwrap();
        let chinese_processed = multibert
            .preprocess_text_for_model(chinese_text, 512)
            .unwrap();

        assert!(latin_processed.contains("hello world"));
        assert!(chinese_processed.contains("你好世界"));
    }

    #[test]
    fn test_embedding_generation_differences() {
        let config = EmbeddingConfig::default();

        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());

        let content = EmbeddableContent::Text("This is a test sentence".to_string());

        let bert_embedding = bert.generate(&content).unwrap();
        let roberta_embedding = roberta.generate(&content).unwrap();
        let distilbert_embedding = distilbert.generate(&content).unwrap();

        // Different models should produce different embeddings for the same input.
        assert_ne!(bert_embedding.as_f32(), roberta_embedding.as_f32());
        assert_ne!(bert_embedding.as_f32(), distilbert_embedding.as_f32());
        assert_ne!(roberta_embedding.as_f32(), distilbert_embedding.as_f32());

        // All generators here use the default 384-dimensional config.
        assert_eq!(distilbert_embedding.dimensions, 384);
        assert_eq!(bert_embedding.dimensions, 384);
        assert_eq!(roberta_embedding.dimensions, 384);

        // Normalized embeddings should have approximately unit magnitude.
        if config.normalize {
            let bert_magnitude: f32 = bert_embedding
                .as_f32()
                .iter()
                .map(|x| x * x)
                .sum::<f32>()
                .sqrt();
            let roberta_magnitude: f32 = roberta_embedding
                .as_f32()
                .iter()
                .map(|x| x * x)
                .sum::<f32>()
                .sqrt();
            let distilbert_magnitude: f32 = distilbert_embedding
                .as_f32()
                .iter()
                .map(|x| x * x)
                .sum::<f32>()
                .sqrt();

            assert!((bert_magnitude - 1.0).abs() < 0.1);
            assert!((roberta_magnitude - 1.0).abs() < 0.1);
            assert!((distilbert_magnitude - 1.0).abs() < 0.1);
        }
    }

    #[test]
    fn test_tokenization_differences() {
        let config = EmbeddingConfig::default();

        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());

        let model_details_bert = bert.get_model_details();
        let model_details_roberta = roberta.get_model_details();
        let model_details_multibert = multibert.get_model_details();

        let complex_word = "preprocessing";

        let bert_tokens =
            bert.simulate_wordpiece_tokenization(complex_word, model_details_bert.vocab_size);
        let roberta_tokens =
            roberta.simulate_bpe_tokenization(complex_word, model_details_roberta.vocab_size);
        let multibert_tokens = multibert
            .simulate_multilingual_tokenization(complex_word, model_details_multibert.vocab_size);

        // BPE splits into at least as many pieces as WordPiece.
        assert!(roberta_tokens.len() >= bert_tokens.len());

        // Multilingual tokenization is no more aggressive than WordPiece here.
        assert!(multibert_tokens.len() <= bert_tokens.len());

        // All token IDs must fall within the respective vocabularies.
        for token in &bert_tokens {
            assert!(*token < model_details_bert.vocab_size as u32);
        }
        for token in &roberta_tokens {
            assert!(*token < model_details_roberta.vocab_size as u32);
        }
        for token in &multibert_tokens {
            assert!(*token < model_details_multibert.vocab_size as u32);
        }
    }

    #[test]
    fn test_model_size_comparisons() {
        let config = EmbeddingConfig::default();

        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());

        let bert_size = bert.model_size_mb();
        let roberta_size = roberta.model_size_mb();
        let distilbert_size = distilbert.model_size_mb();
        let multibert_size = multibert.model_size_mb();

        // DistilBERT is the smallest model.
        assert!(distilbert_size < bert_size);
        assert!(distilbert_size < roberta_size);
        assert!(distilbert_size < multibert_size);

        // Multilingual BERT is the largest, due to its vocabulary.
        assert!(multibert_size > bert_size);
        assert!(multibert_size > roberta_size);
        assert!(multibert_size > distilbert_size);

        // RoBERTa is slightly larger than BERT.
        assert!(roberta_size > bert_size);
    }
}
2132}