use crate::Vector;
use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
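/// Configuration shared by all embedding generators: model identifier, output
/// dimensionality, maximum input sequence length, and whether generated
/// vectors are L2-normalized.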
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    pub model_name: String,
    pub dimensions: usize,
    pub max_sequence_length: usize,
    pub normalize: bool,
}

impl Default for EmbeddingConfig {
    fn default() -> Self {
        Self {
            model_name: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
            dimensions: 384,
            max_sequence_length: 512,
            normalize: true,
        }
    }
}

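/// Content that can be rendered to text and embedded: plain text, an RDF
/// resource (URI plus optional label, description, and properties), a SPARQL
/// query, or a graph pattern.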
#[derive(Debug, Clone)]
pub enum EmbeddableContent {
    Text(String),
    RdfResource {
        uri: String,
        label: Option<String>,
        description: Option<String>,
        properties: HashMap<String, Vec<String>>,
    },
    SparqlQuery(String),
    GraphPattern(String),
}

impl EmbeddableContent {
    /// Renders the content as a single text string suitable for embedding.
    pub fn to_text(&self) -> String {
        match self {
            EmbeddableContent::Text(text) => text.clone(),
            EmbeddableContent::RdfResource {
                uri,
                label,
                description,
                properties,
            } => {
                let mut text_parts = vec![uri.clone()];

                if let Some(label) = label {
                    text_parts.push(format!("label: {label}"));
                }

                if let Some(desc) = description {
                    text_parts.push(format!("description: {desc}"));
                }

                for (prop, values) in properties {
                    text_parts.push(format!("{prop}: {}", values.join(", ")));
                }

                text_parts.join(" ")
            }
            EmbeddableContent::SparqlQuery(query) => query.clone(),
            EmbeddableContent::GraphPattern(pattern) => pattern.clone(),
        }
    }

    /// Stable hash of the rendered text, used as a cache key.
    pub fn content_hash(&self) -> u64 {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        self.to_text().hash(&mut hasher);
        hasher.finish()
    }
}

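/// Strategy selecting which embedding backend `EmbeddingManager` constructs.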
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EmbeddingStrategy {
    TfIdf,
    SentenceTransformer,
    Transformer(TransformerModelType),
    Word2Vec(crate::word2vec::Word2VecConfig),
    OpenAI(OpenAIConfig),
    Custom(String),
}

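/// Configuration for the OpenAI embeddings backend: credentials, model
/// selection, rate limiting, caching, retry behavior, and telemetry knobs.
/// `api_key` defaults to the `OPENAI_API_KEY` environment variable; call
/// `validate()` before use.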
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIConfig {
    pub api_key: String,
    pub model: String,
    pub base_url: String,
    pub timeout_seconds: u64,
    pub requests_per_minute: u32,
    pub batch_size: usize,
    pub enable_cache: bool,
    pub cache_size: usize,
    pub cache_ttl_seconds: u64,
    pub max_retries: u32,
    pub retry_delay_ms: u64,
    pub retry_strategy: RetryStrategy,
    pub track_costs: bool,
    pub enable_metrics: bool,
    pub user_agent: String,
}

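/// Back-off behavior applied between retries of failed API requests.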
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum RetryStrategy {
    Fixed,
    ExponentialBackoff,
    LinearBackoff,
}

impl Default for OpenAIConfig {
    fn default() -> Self {
        Self {
            api_key: std::env::var("OPENAI_API_KEY").unwrap_or_default(),
            model: "text-embedding-3-small".to_string(),
            base_url: "https://api.openai.com/v1".to_string(),
            timeout_seconds: 30,
            requests_per_minute: 3000,
            batch_size: 100,
            enable_cache: true,
            cache_size: 10000,
            cache_ttl_seconds: 3600,
            max_retries: 3,
            retry_delay_ms: 1000,
            retry_strategy: RetryStrategy::ExponentialBackoff,
            track_costs: true,
            enable_metrics: true,
            user_agent: "oxirs-vec/0.1.0".to_string(),
        }
    }
}

impl OpenAIConfig {
    /// Preset tuned for production use: conservative rate limit, larger
    /// cache, longer TTL, and more retries.
    pub fn production() -> Self {
        Self {
            requests_per_minute: 1000,
            cache_size: 50000,
            cache_ttl_seconds: 7200,
            max_retries: 5,
            retry_strategy: RetryStrategy::ExponentialBackoff,
            ..Default::default()
        }
    }

    /// Preset tuned for development: low rate limit, small cache, short TTL.
    pub fn development() -> Self {
        Self {
            requests_per_minute: 100,
            cache_size: 1000,
            cache_ttl_seconds: 300,
            max_retries: 2,
            ..Default::default()
        }
    }

    /// Checks that the required fields are present and non-zero.
    pub fn validate(&self) -> Result<()> {
        if self.api_key.is_empty() {
            return Err(anyhow!("OpenAI API key is required"));
        }
        if self.requests_per_minute == 0 {
            return Err(anyhow!("requests_per_minute must be greater than 0"));
        }
        if self.batch_size == 0 {
            return Err(anyhow!("batch_size must be greater than 0"));
        }
        if self.timeout_seconds == 0 {
            return Err(anyhow!("timeout_seconds must be greater than 0"));
        }
        Ok(())
    }
}

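/// Common interface implemented by every embedding backend. The default
/// `generate_batch` simply calls `generate` per item; backends with a native
/// batch API (such as the OpenAI generator) provide async batch variants.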
pub trait EmbeddingGenerator: Send + Sync + AsAny {
    fn generate(&self, content: &EmbeddableContent) -> Result<Vector>;

    fn generate_batch(&self, contents: &[EmbeddableContent]) -> Result<Vec<Vector>> {
        contents.iter().map(|c| self.generate(c)).collect()
    }

    fn dimensions(&self) -> usize;

    fn config(&self) -> &EmbeddingConfig;
}

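/// TF-IDF embedding generator over a fixed-size vocabulary. The vocabulary
/// must be built from a corpus via `build_vocabulary` before `generate` is
/// called.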
pub struct TfIdfEmbeddingGenerator {
    config: EmbeddingConfig,
    vocabulary: HashMap<String, usize>,
    idf_scores: HashMap<String, f32>,
}

impl TfIdfEmbeddingGenerator {
    pub fn new(config: EmbeddingConfig) -> Self {
        Self {
            config,
            vocabulary: HashMap::new(),
            idf_scores: HashMap::new(),
        }
    }

    /// Builds the vocabulary (the `dimensions` most frequent terms) and the
    /// per-term IDF scores from a document corpus.
    pub fn build_vocabulary(&mut self, documents: &[String]) -> Result<()> {
        let mut word_counts: HashMap<String, usize> = HashMap::new();
        let mut doc_counts: HashMap<String, usize> = HashMap::new();

        for doc in documents {
            let words: Vec<String> = self.tokenize(doc);
            let unique_words: std::collections::HashSet<_> = words.iter().collect();

            for word in &words {
                *word_counts.entry(word.clone()).or_insert(0) += 1;
            }

            for word in unique_words {
                *doc_counts.entry(word.clone()).or_insert(0) += 1;
            }
        }

        // Keep the most frequent terms, one vocabulary slot per dimension.
        let mut word_freq: Vec<(String, usize)> = word_counts.into_iter().collect();
        word_freq.sort_by(|a, b| b.1.cmp(&a.1));

        self.vocabulary = word_freq
            .into_iter()
            .take(self.config.dimensions)
            .enumerate()
            .map(|(idx, (word, _))| (word, idx))
            .collect();

        // Smoothed inverse document frequency: ln(N / (df + 1)).
        let total_docs = documents.len() as f32;
        for word in self.vocabulary.keys() {
            let doc_freq = doc_counts.get(word).unwrap_or(&0);
            let idf = (total_docs / (*doc_freq as f32 + 1.0)).ln();
            self.idf_scores.insert(word.clone(), idf);
        }

        Ok(())
    }

    fn tokenize(&self, text: &str) -> Vec<String> {
        text.to_lowercase()
            .split_whitespace()
            .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
            .filter(|s| !s.is_empty())
            .map(String::from)
            .collect()
    }

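    /// Computes the tf-idf weight for each vocabulary term in `text`:
    /// tf = count / total_words, idf = ln(N / (df + 1)), weight = tf * idf.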
    fn calculate_tf_idf(&self, text: &str) -> Vector {
        let words = self.tokenize(text);
        let mut tf_counts: HashMap<String, usize> = HashMap::new();

        for word in &words {
            *tf_counts.entry(word.clone()).or_insert(0) += 1;
        }

        let total_words = words.len() as f32;
        let mut embedding = vec![0.0; self.config.dimensions];

        for (word, count) in tf_counts {
            if let Some(&idx) = self.vocabulary.get(&word) {
                let tf = count as f32 / total_words;
                let idf = self.idf_scores.get(&word).unwrap_or(&0.0);
                embedding[idx] = tf * idf;
            }
        }

        if self.config.normalize {
            self.normalize_vector(&mut embedding);
        }

        Vector::new(embedding)
    }

    fn normalize_vector(&self, vector: &mut [f32]) {
        let magnitude: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        if magnitude > 0.0 {
            for value in vector {
                *value /= magnitude;
            }
        }
    }
}

impl EmbeddingGenerator for TfIdfEmbeddingGenerator {
    fn generate(&self, content: &EmbeddableContent) -> Result<Vector> {
        if self.vocabulary.is_empty() {
            return Err(anyhow!(
                "Vocabulary not built. Call build_vocabulary first."
            ));
        }

        let text = content.to_text();
        Ok(self.calculate_tf_idf(&text))
    }

    fn dimensions(&self) -> usize {
        self.config.dimensions
    }

    fn config(&self) -> &EmbeddingConfig {
        &self.config
    }
}

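/// Sentence-transformer style generator. Note that this implementation
/// simulates tokenization and embedding generation deterministically from
/// content hashes; it does not load or run an actual transformer model.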
pub struct SentenceTransformerGenerator {
    config: EmbeddingConfig,
    model_type: TransformerModelType,
}

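/// Transformer architectures with distinct simulated tokenization and
/// embedding behavior.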
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum TransformerModelType {
    #[default]
    BERT,
    RoBERTa,
    DistilBERT,
    MultiBERT,
    Custom(String),
}

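/// Static metadata describing a transformer model: architecture sizes,
/// supported languages, on-disk size, and typical inference latency.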
#[derive(Debug, Clone)]
pub struct ModelDetails {
    pub vocab_size: usize,
    pub num_layers: usize,
    pub num_attention_heads: usize,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub max_position_embeddings: usize,
    pub supports_languages: Vec<String>,
    pub model_size_mb: usize,
    pub typical_inference_time_ms: u64,
}

impl SentenceTransformerGenerator {
    pub fn new(config: EmbeddingConfig) -> Self {
        Self {
            config,
            model_type: TransformerModelType::default(),
        }
    }

    pub fn with_model_type(config: EmbeddingConfig, model_type: TransformerModelType) -> Self {
        Self { config, model_type }
    }

    pub fn roberta(config: EmbeddingConfig) -> Self {
        Self::with_model_type(config, TransformerModelType::RoBERTa)
    }

    pub fn distilbert(config: EmbeddingConfig) -> Self {
        // DistilBERT produces 384-dimensional sentence embeddings here.
        let adjusted_config = EmbeddingConfig {
            dimensions: 384,
            ..config
        };
        Self::with_model_type(adjusted_config, TransformerModelType::DistilBERT)
    }

    pub fn multilingual_bert(config: EmbeddingConfig) -> Self {
        Self::with_model_type(config, TransformerModelType::MultiBERT)
    }

    pub fn model_type(&self) -> &TransformerModelType {
        &self.model_type
    }

    pub fn model_details(&self) -> ModelDetails {
        self.get_model_details()
    }

    pub fn supports_language(&self, language_code: &str) -> bool {
        let details = self.get_model_details();
        details
            .supports_languages
            .contains(&language_code.to_string())
    }

    /// Scales the model's typical latency by the square root of the text
    /// length (relative to a 100-character baseline).
    pub fn estimate_inference_time(&self, text_length: usize) -> u64 {
        let details = self.get_model_details();
        let base_time = details.typical_inference_time_ms;

        let length_factor = (text_length as f64 / 100.0).sqrt().max(1.0);
        (base_time as f64 * length_factor) as u64
    }

    pub fn model_size_mb(&self) -> usize {
        self.get_model_details().model_size_mb
    }

    /// Relative throughput rating; higher is faster (BERT = 1.0 baseline).
    pub fn efficiency_rating(&self) -> f32 {
        match &self.model_type {
            TransformerModelType::DistilBERT => 1.5,
            TransformerModelType::BERT => 1.0,
            TransformerModelType::RoBERTa => 0.95,
            TransformerModelType::MultiBERT => 0.8,
            TransformerModelType::Custom(_) => 1.0,
        }
    }

    /// Returns (embedding dimensions, max sequence length, efficiency).
    fn get_model_config(&self) -> (usize, usize, f32) {
        match &self.model_type {
            TransformerModelType::BERT => (self.config.dimensions, 512, 1.0),
            TransformerModelType::RoBERTa => (self.config.dimensions, 514, 0.95),
            TransformerModelType::DistilBERT => (self.config.dimensions, 512, 1.5),
            TransformerModelType::MultiBERT => (self.config.dimensions, 512, 0.8),
            TransformerModelType::Custom(_) => {
                (self.config.dimensions, self.config.max_sequence_length, 1.0)
            }
        }
    }

    fn get_model_details(&self) -> ModelDetails {
        match &self.model_type {
            TransformerModelType::BERT => ModelDetails {
                vocab_size: 30522,
                num_layers: 12,
                num_attention_heads: 12,
                hidden_size: 768,
                intermediate_size: 3072,
                max_position_embeddings: 512,
                supports_languages: vec!["en".to_string()],
                model_size_mb: 440,
                typical_inference_time_ms: 50,
            },
            TransformerModelType::RoBERTa => ModelDetails {
                vocab_size: 50265,
                num_layers: 12,
                num_attention_heads: 12,
                hidden_size: 768,
                intermediate_size: 3072,
                max_position_embeddings: 514,
                supports_languages: vec!["en".to_string()],
                model_size_mb: 470,
                typical_inference_time_ms: 55,
            },
            TransformerModelType::DistilBERT => ModelDetails {
                vocab_size: 30522,
                num_layers: 6,
                num_attention_heads: 12,
                hidden_size: 384,
                intermediate_size: 1536,
                max_position_embeddings: 512,
                supports_languages: vec!["en".to_string()],
                model_size_mb: 250,
                typical_inference_time_ms: 25,
            },
            TransformerModelType::MultiBERT => ModelDetails {
                vocab_size: 120000,
                num_layers: 12,
                num_attention_heads: 12,
                hidden_size: 768,
                intermediate_size: 3072,
                max_position_embeddings: 512,
                supports_languages: vec![
                    "en".to_string(),
                    "de".to_string(),
                    "fr".to_string(),
                    "es".to_string(),
                    "it".to_string(),
                    "pt".to_string(),
                    "ru".to_string(),
                    "zh".to_string(),
                    "ja".to_string(),
                    "ko".to_string(),
                    "ar".to_string(),
                    "hi".to_string(),
                    "th".to_string(),
                    "tr".to_string(),
                    "pl".to_string(),
                    "nl".to_string(),
                    "sv".to_string(),
                    "da".to_string(),
                    "no".to_string(),
                    "fi".to_string(),
                ],
                model_size_mb: 670,
                typical_inference_time_ms: 70,
            },
            TransformerModelType::Custom(_path) => ModelDetails {
                vocab_size: 50000,
                num_layers: 12,
                num_attention_heads: 12,
                hidden_size: self.config.dimensions,
                intermediate_size: self.config.dimensions * 4,
                max_position_embeddings: self.config.max_sequence_length,
                supports_languages: vec!["unknown".to_string()],
                model_size_mb: 500,
                typical_inference_time_ms: 60,
            },
        }
    }

    fn generate_with_model(&self, text: &str) -> Result<Vector> {
        let _text_hash = {
            use std::hash::{Hash, Hasher};
            let mut hasher = std::collections::hash_map::DefaultHasher::new();
            text.hash(&mut hasher);
            hasher.finish()
        };

        let (dimensions, max_len, _efficiency) = self.get_model_config();
        let model_details = self.get_model_details();

        let processed_text = self.preprocess_text_for_model(text, max_len)?;

        let token_ids = self.simulate_tokenization(&processed_text, &model_details);

        let values =
            self.generate_embeddings_from_tokens(&token_ids, dimensions, &model_details)?;

        if self.config.normalize {
            let magnitude: f32 = values.iter().map(|x| x * x).sum::<f32>().sqrt();
            if magnitude > 0.0 {
                let mut normalized_values = values;
                for value in &mut normalized_values {
                    *value /= magnitude;
                }
                return Ok(Vector::new(normalized_values));
            }
        }

        Ok(Vector::new(values))
    }

    /// Truncates `text` to at most `max_bytes` bytes, backing off to a UTF-8
    /// character boundary so that slicing cannot panic on multi-byte input.
    fn truncate_on_char_boundary(text: &str, max_bytes: usize) -> &str {
        if text.len() <= max_bytes {
            return text;
        }
        let mut end = max_bytes;
        while !text.is_char_boundary(end) {
            end -= 1;
        }
        &text[..end]
    }

    fn preprocess_text_for_model(&self, text: &str, max_len: usize) -> Result<String> {
        let processed = match &self.model_type {
            TransformerModelType::BERT => {
                let truncated = Self::truncate_on_char_boundary(text, max_len.saturating_sub(20));
                format!("[CLS] {} [SEP]", truncated.to_lowercase())
            }
            TransformerModelType::RoBERTa => {
                let truncated = Self::truncate_on_char_boundary(text, max_len.saturating_sub(10));
                format!("<s>{truncated}</s>")
            }
            TransformerModelType::DistilBERT => {
                let truncated = Self::truncate_on_char_boundary(text, max_len.saturating_sub(20));
                format!("[CLS] {} [SEP]", truncated.to_lowercase())
            }
            TransformerModelType::MultiBERT => {
                let truncated = Self::truncate_on_char_boundary(text, max_len.saturating_sub(20));
                // Preserve case for non-Latin scripts; lowercase otherwise.
                let has_non_latin = !text.is_ascii();
                if has_non_latin {
                    format!("[CLS] {truncated} [SEP]")
                } else {
                    format!("[CLS] {} [SEP]", truncated.to_lowercase())
                }
            }
            TransformerModelType::Custom(_) => {
                Self::truncate_on_char_boundary(text, max_len).to_string()
            }
        };

        Ok(processed)
    }

    fn simulate_tokenization(&self, text: &str, model_details: &ModelDetails) -> Vec<u32> {
        let mut token_ids = Vec::new();

        let words: Vec<&str> = text.split_whitespace().collect();

        for word in words {
            let subwords = match &self.model_type {
                TransformerModelType::RoBERTa => {
                    self.simulate_bpe_tokenization(word, model_details.vocab_size)
                }
                TransformerModelType::DistilBERT | TransformerModelType::BERT => {
                    self.simulate_wordpiece_tokenization(word, model_details.vocab_size)
                }
                TransformerModelType::MultiBERT => {
                    self.simulate_multilingual_tokenization(word, model_details.vocab_size)
                }
                TransformerModelType::Custom(_) => {
                    vec![self.word_to_token_id(word, model_details.vocab_size)]
                }
            };

            token_ids.extend(subwords);
        }

        // Reserve two positions for the special tokens added in preprocessing.
        token_ids.truncate(model_details.max_position_embeddings - 2);
        token_ids
    }

    /// Simulates byte-pair encoding by splitting words into chunks of at
    /// most four bytes (backed off to a character boundary).
    fn simulate_bpe_tokenization(&self, word: &str, vocab_size: usize) -> Vec<u32> {
        let mut tokens = Vec::new();
        let mut remaining = word;

        while !remaining.is_empty() {
            let mut chunk_size = remaining.len().min(4);
            // Back off to a character boundary so slicing cannot panic.
            while !remaining.is_char_boundary(chunk_size) {
                chunk_size -= 1;
            }
            let chunk = &remaining[..chunk_size];
            tokens.push(self.word_to_token_id(chunk, vocab_size));
            remaining = &remaining[chunk_size..];
        }

        tokens
    }

    /// Simulates WordPiece: short words stay whole; longer words split into
    /// a prefix and a "##"-marked suffix.
    fn simulate_wordpiece_tokenization(&self, word: &str, vocab_size: usize) -> Vec<u32> {
        if word.len() <= 6 {
            vec![self.word_to_token_id(word, vocab_size)]
        } else {
            // Split near the middle, on a character boundary.
            let mut mid = word.len() / 2;
            while !word.is_char_boundary(mid) {
                mid += 1;
            }
            vec![
                self.word_to_token_id(&word[..mid], vocab_size),
                self.word_to_token_id(&format!("##{}", &word[mid..]), vocab_size),
            ]
        }
    }

    /// Simulates multilingual tokenization: words up to ten bytes stay
    /// whole; longer words split near the middle.
    fn simulate_multilingual_tokenization(&self, word: &str, vocab_size: usize) -> Vec<u32> {
        if word.len() <= 10 {
            vec![self.word_to_token_id(word, vocab_size)]
        } else {
            // Split near the middle, on a character boundary.
            let mut mid = word.len() / 2;
            while !word.is_char_boundary(mid) {
                mid += 1;
            }
            vec![
                self.word_to_token_id(&word[..mid], vocab_size),
                self.word_to_token_id(&word[mid..], vocab_size),
            ]
        }
    }

    /// Maps a word deterministically into the model's vocabulary range.
    fn word_to_token_id(&self, word: &str, vocab_size: usize) -> u32 {
        use std::hash::{Hash, Hasher};
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        word.hash(&mut hasher);
        (hasher.finish() % vocab_size as u64) as u32
    }

    fn generate_embeddings_from_tokens(
        &self,
        token_ids: &[u32],
        dimensions: usize,
        model_details: &ModelDetails,
    ) -> Result<Vec<f32>> {
        let mut values = vec![0.0; dimensions];

        match &self.model_type {
            TransformerModelType::BERT => {
                self.generate_bert_style_embeddings(token_ids, &mut values, model_details)
            }
            TransformerModelType::RoBERTa => {
                self.generate_roberta_style_embeddings(token_ids, &mut values, model_details)
            }
            TransformerModelType::DistilBERT => {
                self.generate_distilbert_style_embeddings(token_ids, &mut values, model_details)
            }
            TransformerModelType::MultiBERT => {
                self.generate_multibert_style_embeddings(token_ids, &mut values, model_details)
            }
            TransformerModelType::Custom(_) => {
                self.generate_custom_style_embeddings(token_ids, &mut values, model_details)
            }
        }

        Ok(values)
    }

    fn generate_bert_style_embeddings(
        &self,
        token_ids: &[u32],
        values: &mut [f32],
        _model_details: &ModelDetails,
    ) {
        for (i, &token_id) in token_ids.iter().enumerate() {
            let mut seed = token_id as u64;
            for value in values.iter_mut() {
                // Deterministic pseudo-random walk (ANSI C LCG constants)
                // plus a sinusoidal position encoding.
                seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
                let normalized = (seed as f32) / (u64::MAX as f32);
                let position_encoding =
                    ((i as f32 / 512.0) * 2.0 * std::f32::consts::PI).sin() * 0.1;
                *value += ((normalized - 0.5) * 2.0) + position_encoding;
            }
        }

        // Mean pooling over tokens.
        if !token_ids.is_empty() {
            for value in values.iter_mut() {
                *value /= token_ids.len() as f32;
            }
        }
    }

    fn generate_roberta_style_embeddings(
        &self,
        token_ids: &[u32],
        values: &mut [f32],
        _model_details: &ModelDetails,
    ) {
        for (i, &token_id) in token_ids.iter().enumerate() {
            // Model-specific seed, widened to u64 so the division by
            // u64::MAX below yields values spanning (0, 1) rather than ~0.
            let mut seed = (token_id as u64).wrapping_mul(31415927);
            for value in values.iter_mut() {
                seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
                let normalized = (seed as f32) / (u64::MAX as f32);
                let position_encoding =
                    ((i as f32 + 2.0) / 514.0 * 2.0 * std::f32::consts::PI).cos() * 0.1;
                *value += ((normalized - 0.5) * 2.0) + position_encoding;
            }
        }

        if !token_ids.is_empty() {
            for value in values.iter_mut() {
                *value /= token_ids.len() as f32;
            }
        }
    }

    fn generate_distilbert_style_embeddings(
        &self,
        token_ids: &[u32],
        values: &mut [f32],
        _model_details: &ModelDetails,
    ) {
        for (i, &token_id) in token_ids.iter().enumerate() {
            let mut seed = token_id as u64;
            for value in values.iter_mut() {
                seed = seed.wrapping_mul(982451653).wrapping_add(12345);
                let normalized = (seed as f32) / (u64::MAX as f32);
                let position_encoding = (i as f32 / 512.0).sin() * 0.05;
                *value += ((normalized - 0.5) * 1.5) + position_encoding;
            }
        }

        if !token_ids.is_empty() {
            for value in values.iter_mut() {
                *value /= token_ids.len() as f32;
            }
        }
    }

    fn generate_multibert_style_embeddings(
        &self,
        token_ids: &[u32],
        values: &mut [f32],
        _model_details: &ModelDetails,
    ) {
        let dims = values.len() as f32;
        for (i, &token_id) in token_ids.iter().enumerate() {
            // Widen to u64 before mixing so normalization spans (0, 1).
            let mut seed = (token_id as u64).wrapping_mul(2654435761);
            for (j, value) in values.iter_mut().enumerate() {
                seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
                let normalized = (seed as f32) / (u64::MAX as f32);
                let position_encoding =
                    ((i as f32 / 512.0) * 2.0 * std::f32::consts::PI).sin() * 0.08;
                let cross_lingual_bias =
                    (j as f32 / dims * std::f32::consts::PI).cos() * 0.05;
                *value += ((normalized - 0.5) * 1.8) + position_encoding + cross_lingual_bias;
            }
        }

        if !token_ids.is_empty() {
            for value in values.iter_mut() {
                *value /= token_ids.len() as f32;
            }
        }
    }

    fn generate_custom_style_embeddings(
        &self,
        token_ids: &[u32],
        values: &mut [f32],
        _model_details: &ModelDetails,
    ) {
        for &token_id in token_ids {
            let mut seed = token_id as u64;
            for value in values.iter_mut() {
                seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
                let normalized = (seed as f32) / (u64::MAX as f32);
                *value += (normalized - 0.5) * 2.0;
            }
        }

        if !token_ids.is_empty() {
            for value in values.iter_mut() {
                *value /= token_ids.len() as f32;
            }
        }
    }
}

impl EmbeddingGenerator for SentenceTransformerGenerator {
    fn generate(&self, content: &EmbeddableContent) -> Result<Vector> {
        let text = content.to_text();
        self.generate_with_model(&text)
    }

    fn dimensions(&self) -> usize {
        self.config.dimensions
    }

    fn config(&self) -> &EmbeddingConfig {
        &self.config
    }
}

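/// Simple LRU cache from content hashes to embedding vectors. `access_order`
/// keeps hashes from least to most recently used.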
pub struct EmbeddingCache {
    cache: HashMap<u64, Vector>,
    max_size: usize,
    access_order: Vec<u64>,
}

impl EmbeddingCache {
    pub fn new(max_size: usize) -> Self {
        Self {
            cache: HashMap::new(),
            max_size,
            access_order: Vec::new(),
        }
    }

    pub fn get(&mut self, content: &EmbeddableContent) -> Option<&Vector> {
        let hash = content.content_hash();
        if let Some(vector) = self.cache.get(&hash) {
            // Move the hash to the back (most recently used).
            if let Some(pos) = self.access_order.iter().position(|&x| x == hash) {
                self.access_order.remove(pos);
            }
            self.access_order.push(hash);
            Some(vector)
        } else {
            None
        }
    }

    pub fn insert(&mut self, content: &EmbeddableContent, vector: Vector) {
        let hash = content.content_hash();

        if self.cache.len() >= self.max_size && !self.cache.contains_key(&hash) {
            // Evict the least recently used entry (front of the order).
            if let Some(&lru_hash) = self.access_order.first() {
                self.cache.remove(&lru_hash);
                self.access_order.remove(0);
            }
        }

        // Drop any stale position for this key so the access order holds at
        // most one entry per hash; otherwise re-inserting a key could later
        // evict a recently used item.
        if let Some(pos) = self.access_order.iter().position(|&x| x == hash) {
            self.access_order.remove(pos);
        }

        self.cache.insert(hash, vector);
        self.access_order.push(hash);
    }

    pub fn clear(&mut self) {
        self.cache.clear();
        self.access_order.clear();
    }

    pub fn size(&self) -> usize {
        self.cache.len()
    }
}

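/// High-level entry point that owns a boxed generator plus an LRU cache and
/// routes embedding requests through both.
///
/// A minimal usage sketch (marked `ignore` because the import path below is
/// an assumption about how the crate re-exports these items):
/// ```ignore
/// use oxirs_vec::embeddings::{EmbeddableContent, EmbeddingManager, EmbeddingStrategy};
///
/// let mut manager = EmbeddingManager::new(EmbeddingStrategy::TfIdf, 1000)?;
/// manager.build_vocabulary(&["example document".to_string()])?;
/// let vector = manager.get_embedding(&EmbeddableContent::Text("example".to_string()))?;
/// assert_eq!(vector.dimensions, manager.dimensions());
/// ```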
pub struct EmbeddingManager {
    generator: Box<dyn EmbeddingGenerator>,
    cache: EmbeddingCache,
    strategy: EmbeddingStrategy,
}

impl EmbeddingManager {
    pub fn new(strategy: EmbeddingStrategy, cache_size: usize) -> Result<Self> {
        let generator: Box<dyn EmbeddingGenerator> = match &strategy {
            EmbeddingStrategy::TfIdf => {
                let config = EmbeddingConfig::default();
                Box::new(TfIdfEmbeddingGenerator::new(config))
            }
            EmbeddingStrategy::SentenceTransformer => {
                let config = EmbeddingConfig::default();
                Box::new(SentenceTransformerGenerator::new(config))
            }
            EmbeddingStrategy::Transformer(model_type) => {
                let config = EmbeddingConfig {
                    model_name: format!("{model_type:?}"),
                    dimensions: match model_type {
                        TransformerModelType::DistilBERT => 384,
                        _ => 768,
                    },
                    max_sequence_length: 512,
                    normalize: true,
                };
                Box::new(SentenceTransformerGenerator::with_model_type(
                    config,
                    model_type.clone(),
                ))
            }
            EmbeddingStrategy::Word2Vec(word2vec_config) => {
                let embedding_config = EmbeddingConfig {
                    model_name: "word2vec".to_string(),
                    dimensions: word2vec_config.dimensions,
                    max_sequence_length: 512,
                    normalize: word2vec_config.normalize,
                };
                Box::new(crate::word2vec::Word2VecEmbeddingGenerator::new(
                    word2vec_config.clone(),
                    embedding_config,
                )?)
            }
            EmbeddingStrategy::OpenAI(openai_config) => {
                Box::new(OpenAIEmbeddingGenerator::new(openai_config.clone())?)
            }
            EmbeddingStrategy::Custom(_model_path) => {
                // Fall back to the sentence-transformer generator for
                // custom model paths.
                let config = EmbeddingConfig::default();
                Box::new(SentenceTransformerGenerator::new(config))
            }
        };

        Ok(Self {
            generator,
            cache: EmbeddingCache::new(cache_size),
            strategy,
        })
    }

    /// Returns a cached embedding if available, otherwise generates and
    /// caches a new one.
    pub fn get_embedding(&mut self, content: &EmbeddableContent) -> Result<Vector> {
        if let Some(cached) = self.cache.get(content) {
            return Ok(cached.clone());
        }

        let embedding = self.generator.generate(content)?;
        self.cache.insert(content, embedding.clone());
        Ok(embedding)
    }

    /// Generates embeddings for all contents and warms the cache.
    pub fn precompute_embeddings(&mut self, contents: &[EmbeddableContent]) -> Result<()> {
        let embeddings = self.generator.generate_batch(contents)?;

        for (content, embedding) in contents.iter().zip(embeddings) {
            self.cache.insert(content, embedding);
        }

        Ok(())
    }

    /// Builds the TF-IDF vocabulary; a no-op for other strategies.
    pub fn build_vocabulary(&mut self, documents: &[String]) -> Result<()> {
        if let EmbeddingStrategy::TfIdf = self.strategy {
            if let Some(tfidf_gen) = self
                .generator
                .as_any_mut()
                .downcast_mut::<TfIdfEmbeddingGenerator>()
            {
                tfidf_gen.build_vocabulary(documents)?;
            }
        }
        Ok(())
    }

    pub fn dimensions(&self) -> usize {
        self.generator.dimensions()
    }

    /// Returns (current entries, maximum capacity).
    pub fn cache_stats(&self) -> (usize, usize) {
        (self.cache.size(), self.cache.max_size)
    }
}

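/// Downcasting support so callers can reach a concrete generator type (for
/// example, to build a TF-IDF vocabulary) through `Box<dyn EmbeddingGenerator>`.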
pub trait AsAny {
    fn as_any(&self) -> &dyn std::any::Any;
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
}

impl AsAny for TfIdfEmbeddingGenerator {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

impl AsAny for SentenceTransformerGenerator {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

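/// Embedding generator backed by the OpenAI embeddings API, with rate
/// limiting, retries, an LRU response cache, and cost/metrics tracking.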
pub struct OpenAIEmbeddingGenerator {
    config: EmbeddingConfig,
    openai_config: OpenAIConfig,
    client: reqwest::Client,
    rate_limiter: RateLimiter,
    request_cache: std::sync::Arc<std::sync::Mutex<lru::LruCache<u64, CachedEmbedding>>>,
    metrics: OpenAIMetrics,
}

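/// A cached embedding together with when it was cached, which model produced
/// it, and its estimated cost.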
#[derive(Debug, Clone)]
pub struct CachedEmbedding {
    pub vector: Vector,
    pub cached_at: std::time::SystemTime,
    pub model: String,
    pub cost_usd: f64,
}

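/// Running counters describing OpenAI API usage: request outcomes, token and
/// cost totals, cache behavior, retries, and latency.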
#[derive(Debug, Clone, Default)]
pub struct OpenAIMetrics {
    pub total_requests: u64,
    pub successful_requests: u64,
    pub failed_requests: u64,
    pub total_tokens_processed: u64,
    pub cache_hits: u64,
    pub cache_misses: u64,
    pub total_cost_usd: f64,
    pub retry_count: u64,
    pub rate_limit_waits: u64,
    pub average_response_time_ms: f64,
    pub last_request_time: Option<std::time::SystemTime>,
    pub requests_by_model: HashMap<String, u64>,
    pub errors_by_type: HashMap<String, u64>,
}

impl OpenAIMetrics {
    pub fn cache_hit_ratio(&self) -> f64 {
        if self.cache_hits + self.cache_misses == 0 {
            0.0
        } else {
            self.cache_hits as f64 / (self.cache_hits + self.cache_misses) as f64
        }
    }

    pub fn success_rate(&self) -> f64 {
        if self.total_requests == 0 {
            0.0
        } else {
            self.successful_requests as f64 / self.total_requests as f64
        }
    }

    pub fn average_cost_per_request(&self) -> f64 {
        if self.successful_requests == 0 {
            0.0
        } else {
            self.total_cost_usd / self.successful_requests as f64
        }
    }

    pub fn report(&self) -> String {
        format!(
            "OpenAI Metrics Report:\n\
             Total Requests: {}\n\
             Success Rate: {:.2}%\n\
             Cache Hit Ratio: {:.2}%\n\
             Total Cost: ${:.4}\n\
             Avg Cost/Request: ${:.6}\n\
             Avg Response Time: {:.2}ms\n\
             Retries: {}\n\
             Rate Limit Waits: {}",
            self.total_requests,
            self.success_rate() * 100.0,
            self.cache_hit_ratio() * 100.0,
            self.total_cost_usd,
            self.average_cost_per_request(),
            self.average_response_time_ms,
            self.retry_count,
            self.rate_limit_waits
        )
    }
}

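/// Sliding-window rate limiter: records request timestamps within the last
/// minute and sleeps until the oldest falls outside the window.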
pub struct RateLimiter {
    requests_per_minute: u32,
    request_times: std::collections::VecDeque<std::time::Instant>,
}

impl RateLimiter {
    pub fn new(requests_per_minute: u32) -> Self {
        Self {
            requests_per_minute,
            request_times: std::collections::VecDeque::new(),
        }
    }

    pub async fn wait_if_needed(&mut self) {
        let now = std::time::Instant::now();
        let minute_ago = now - std::time::Duration::from_secs(60);

        // Drop timestamps that have fallen outside the one-minute window.
        while let Some(&front_time) = self.request_times.front() {
            if front_time < minute_ago {
                self.request_times.pop_front();
            } else {
                break;
            }
        }

        // If the window is full, sleep until the oldest request expires.
        if self.request_times.len() >= self.requests_per_minute as usize {
            if let Some(&oldest) = self.request_times.front() {
                let wait_time = oldest + std::time::Duration::from_secs(60) - now;
                if !wait_time.is_zero() {
                    tokio::time::sleep(wait_time).await;
                }
            }
        }

        self.request_times.push_back(now);
    }
}

impl OpenAIEmbeddingGenerator {
    pub fn new(openai_config: OpenAIConfig) -> Result<Self> {
        openai_config.validate()?;

        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(
                openai_config.timeout_seconds,
            ))
            .user_agent(&openai_config.user_agent)
            .build()
            .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;

        let embedding_config = EmbeddingConfig {
            model_name: openai_config.model.clone(),
            dimensions: Self::get_model_dimensions(&openai_config.model),
            // OpenAI embedding models accept up to 8191 input tokens.
            max_sequence_length: 8191,
            normalize: true,
        };

        let cache_size = if openai_config.enable_cache {
            std::num::NonZeroUsize::new(openai_config.cache_size)
                .unwrap_or(std::num::NonZeroUsize::new(1000).unwrap())
        } else {
            std::num::NonZeroUsize::new(1).unwrap()
        };

        Ok(Self {
            config: embedding_config,
            openai_config: openai_config.clone(),
            client,
            rate_limiter: RateLimiter::new(openai_config.requests_per_minute),
            request_cache: std::sync::Arc::new(std::sync::Mutex::new(lru::LruCache::new(
                cache_size,
            ))),
            metrics: OpenAIMetrics::default(),
        })
    }

    fn get_model_dimensions(model: &str) -> usize {
        match model {
            "text-embedding-ada-002" => 1536,
            "text-embedding-3-small" => 1536,
            "text-embedding-3-large" => 3072,
            "text-embedding-004" => 1536,
            _ => 1536,
        }
    }

    fn get_model_cost_per_1k_tokens(model: &str) -> f64 {
        match model {
            "text-embedding-ada-002" => 0.0001,
            "text-embedding-3-small" => 0.00002,
            "text-embedding-3-large" => 0.00013,
            "text-embedding-004" => 0.00002,
            _ => 0.0001,
        }
    }

    fn calculate_cost(&self, texts: &[String]) -> f64 {
        if !self.openai_config.track_costs {
            return 0.0;
        }

        // Rough token estimate: about four characters per token.
        let total_tokens: usize = texts.iter().map(|t| t.len() / 4).sum();
        let cost_per_1k = Self::get_model_cost_per_1k_tokens(&self.openai_config.model);
        (total_tokens as f64 / 1000.0) * cost_per_1k
    }

    fn is_cache_valid(&self, cached: &CachedEmbedding) -> bool {
        if self.openai_config.cache_ttl_seconds == 0 {
            // A TTL of zero means entries never expire.
            return true;
        }

        let elapsed = cached
            .cached_at
            .elapsed()
            .unwrap_or(std::time::Duration::from_secs(u64::MAX));

        elapsed.as_secs() < self.openai_config.cache_ttl_seconds
    }

    async fn make_request(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let start_time = std::time::Instant::now();
        let mut attempts = 0;

        while attempts < self.openai_config.max_retries {
            match self.try_request(texts).await {
                Ok(embeddings) => {
                    if self.openai_config.enable_metrics {
                        let response_time = start_time.elapsed().as_millis() as f64;
                        self.update_response_time(response_time);

                        let cost = self.calculate_cost(texts);
                        self.metrics.total_cost_usd += cost;

                        *self
                            .metrics
                            .requests_by_model
                            .entry(self.openai_config.model.clone())
                            .or_insert(0) += 1;
                    }

                    return Ok(embeddings);
                }
                Err(e) => {
                    attempts += 1;
                    self.metrics.retry_count += 1;

                    // Classify the error for the metrics breakdown.
                    let error_type = if e.to_string().contains("rate_limit") {
                        "rate_limit"
                    } else if e.to_string().contains("timeout") {
                        "timeout"
                    } else if e.to_string().contains("401") {
                        "unauthorized"
                    } else if e.to_string().contains("400") {
                        "bad_request"
                    } else {
                        "other"
                    };

                    *self
                        .metrics
                        .errors_by_type
                        .entry(error_type.to_string())
                        .or_insert(0) += 1;

                    if attempts >= self.openai_config.max_retries {
                        return Err(e);
                    }

                    let delay = self.calculate_retry_delay(attempts);
                    tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
                }
            }
        }

        Err(anyhow!("Max retries exceeded"))
    }

    fn calculate_retry_delay(&self, attempt: u32) -> u64 {
        let base_delay = self.openai_config.retry_delay_ms;

        match self.openai_config.retry_strategy {
            RetryStrategy::Fixed => base_delay,
            RetryStrategy::LinearBackoff => base_delay * attempt as u64,
            RetryStrategy::ExponentialBackoff => {
                let delay = base_delay * (2_u64.pow(attempt - 1));
                let jitter = {
                    use scirs2_core::random::{Random, Rng};
                    let mut rng = Random::seed(42);
                    (delay as f64 * 0.25 * (rng.gen_range(0.0..1.0) - 0.5)) as u64
                };
                // Cap the jittered delay at 30 seconds.
                delay.saturating_add(jitter).min(30000)
            }
        }
    }

    fn update_response_time(&mut self, response_time_ms: f64) {
        if self.metrics.successful_requests == 0 {
            self.metrics.average_response_time_ms = response_time_ms;
        } else {
            // Incremental running average over successful requests.
            let total =
                self.metrics.average_response_time_ms * self.metrics.successful_requests as f64;
            self.metrics.average_response_time_ms =
                (total + response_time_ms) / (self.metrics.successful_requests + 1) as f64;
        }
    }

    async fn try_request(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        self.rate_limiter.wait_if_needed().await;

        let request_body = serde_json::json!({
            "model": self.openai_config.model,
            "input": texts,
            "encoding_format": "float"
        });

        let response = self
            .client
            .post(format!("{}/embeddings", self.openai_config.base_url))
            .header(
                "Authorization",
                format!("Bearer {}", self.openai_config.api_key),
            )
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| anyhow!("Request failed: {}", e))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(anyhow!(
                "API request failed with status {}: {}",
                status,
                error_text
            ));
        }

        let response_data: serde_json::Value = response
            .json()
            .await
            .map_err(|e| anyhow!("Failed to parse response: {}", e))?;

        let embeddings_data = response_data["data"]
            .as_array()
            .ok_or_else(|| anyhow!("Invalid response format: missing data array"))?;

        let mut embeddings = Vec::new();
        for item in embeddings_data {
            let embedding = item["embedding"]
                .as_array()
                .ok_or_else(|| anyhow!("Invalid response format: missing embedding"))?;

            let vec: Result<Vec<f32>, _> = embedding
                .iter()
                .map(|v| {
                    v.as_f64()
                        .ok_or_else(|| anyhow!("Invalid embedding value"))
                        .map(|f| f as f32)
                })
                .collect();

            embeddings.push(vec?);
        }

        Ok(embeddings)
    }

    pub async fn generate_async(&mut self, content: &EmbeddableContent) -> Result<Vector> {
        let text = content.to_text();

        if self.openai_config.enable_cache {
            let hash = content.content_hash();

            // Look up the cache, honoring the configured TTL.
            let cached_vector = match self.request_cache.lock() {
                Ok(mut cache) => {
                    if let Some(cached) = cache.get(&hash) {
                        if self.is_cache_valid(cached) {
                            Some(cached.vector.clone())
                        } else {
                            None
                        }
                    } else {
                        None
                    }
                }
                _ => None,
            };

            if let Some(result) = cached_vector {
                self.update_cache_hit();
                return Ok(result);
            } else {
                // Evict any expired entry for this key before regenerating.
                if let Ok(mut cache) = self.request_cache.lock() {
                    cache.pop(&hash);
                }
                self.update_cache_miss();
            }
        }

        let embeddings = match self.make_request(&[text.clone()]).await {
            Ok(embeddings) => {
                self.update_metrics_success(&[text.clone()]);
                embeddings
            }
            Err(e) => {
                self.update_metrics_failure();
                return Err(e);
            }
        };

        if embeddings.is_empty() {
            self.update_metrics_failure();
            return Err(anyhow!("No embeddings returned from API"));
        }

        let vector = Vector::new(embeddings[0].clone());

        if self.openai_config.enable_cache {
            let hash = content.content_hash();
            let cost = self.calculate_cost(&[text.clone()]);
            let cached_embedding = CachedEmbedding {
                vector: vector.clone(),
                cached_at: std::time::SystemTime::now(),
                model: self.openai_config.model.clone(),
                cost_usd: cost,
            };
            if let Ok(mut cache) = self.request_cache.lock() {
                cache.put(hash, cached_embedding);
            }
        }

        Ok(vector)
    }

    pub async fn generate_batch_async(
        &mut self,
        contents: &[EmbeddableContent],
    ) -> Result<Vec<Vector>> {
        if contents.is_empty() {
            return Ok(Vec::new());
        }

        let mut results = Vec::with_capacity(contents.len());
        let batch_size = self.openai_config.batch_size;

        for chunk in contents.chunks(batch_size) {
            let texts: Vec<String> = chunk.iter().map(|c| c.to_text()).collect();

            let embeddings = match self.make_request(&texts).await {
                Ok(embeddings) => {
                    self.update_metrics_success(&texts);
                    embeddings
                }
                Err(e) => {
                    self.update_metrics_failure();
                    return Err(e);
                }
            };

            if embeddings.len() != chunk.len() {
                self.update_metrics_failure();
                return Err(anyhow!("Mismatch between request and response sizes"));
            }

            // Attribute an equal share of the batch cost to each item.
            let batch_cost = self.calculate_cost(&texts) / chunk.len() as f64;

            for (content, embedding) in chunk.iter().zip(embeddings) {
                let vector = Vector::new(embedding);

                if self.openai_config.enable_cache {
                    let hash = content.content_hash();
                    let cached_embedding = CachedEmbedding {
                        vector: vector.clone(),
                        cached_at: std::time::SystemTime::now(),
                        model: self.openai_config.model.clone(),
                        cost_usd: batch_cost,
                    };
                    if let Ok(mut cache) = self.request_cache.lock() {
                        cache.put(hash, cached_embedding);
                    }
                }

                results.push(vector);
            }
        }

        Ok(results)
    }

    pub fn clear_cache(&mut self) {
        if let Ok(mut cache) = self.request_cache.lock() {
            cache.clear();
        }
    }

    /// Returns (current entries, capacity); capacity is `None` if the cache
    /// lock is poisoned.
    pub fn cache_stats(&self) -> (usize, Option<usize>) {
        match self.request_cache.lock() {
            Ok(cache) => (cache.len(), Some(cache.cap().into())),
            _ => (0, None),
        }
    }

    /// Total estimated cost of all currently cached embeddings.
    pub fn get_cache_cost(&self) -> f64 {
        match self.request_cache.lock() {
            Ok(cache) => cache.iter().map(|(_, cached)| cached.cost_usd).sum(),
            _ => 0.0,
        }
    }

    pub fn get_metrics(&self) -> &OpenAIMetrics {
        &self.metrics
    }

    pub fn reset_metrics(&mut self) {
        self.metrics = OpenAIMetrics::default();
    }

    fn estimate_tokens(&self, text: &str) -> u64 {
        // Rough heuristic: about four characters per token, minimum one.
        (text.len() / 4).max(1) as u64
    }

    fn calculate_cost_from_tokens(&self, total_tokens: u64) -> f64 {
        let cost_per_1k_tokens = match self.openai_config.model.as_str() {
            "text-embedding-ada-002" => 0.0001,
            "text-embedding-3-small" => 0.00002,
            "text-embedding-3-large" => 0.00013,
            _ => 0.0001,
        };

        (total_tokens as f64 / 1000.0) * cost_per_1k_tokens
    }

    fn update_metrics_success(&mut self, texts: &[String]) {
        self.metrics.total_requests += 1;
        self.metrics.successful_requests += 1;

        let total_tokens: u64 = texts.iter().map(|text| self.estimate_tokens(text)).sum();

        self.metrics.total_tokens_processed += total_tokens;
        self.metrics.total_cost_usd += self.calculate_cost_from_tokens(total_tokens);
    }

    fn update_metrics_failure(&mut self) {
        self.metrics.total_requests += 1;
        self.metrics.failed_requests += 1;
    }

    fn update_cache_hit(&mut self) {
        self.metrics.cache_hits += 1;
    }

    fn update_cache_miss(&mut self) {
        self.metrics.cache_misses += 1;
    }
}

impl EmbeddingGenerator for OpenAIEmbeddingGenerator {
    fn generate(&self, content: &EmbeddableContent) -> Result<Vector> {
        if self.openai_config.enable_cache {
            let hash = content.content_hash();
            if let Ok(mut cache) = self.request_cache.lock() {
                if let Some(cached) = cache.get(&hash) {
                    if self.is_cache_valid(cached) {
                        return Ok(cached.vector.clone());
                    }
                }
            }
        }

        // Blocking fallback: spin up a runtime and a temporary generator so
        // the async path can be reused from this synchronous trait method.
        // Note: this panics if called from inside an existing tokio runtime.
        let rt = tokio::runtime::Runtime::new()
            .map_err(|e| anyhow!("Failed to create async runtime: {}", e))?;

        let mut temp_generator = OpenAIEmbeddingGenerator {
            config: self.config.clone(),
            openai_config: self.openai_config.clone(),
            client: self.client.clone(),
            rate_limiter: RateLimiter::new(self.openai_config.requests_per_minute),
            request_cache: self.request_cache.clone(),
            metrics: self.metrics.clone(),
        };

        rt.block_on(temp_generator.generate_async(content))
    }

    fn dimensions(&self) -> usize {
        self.config.dimensions
    }

    fn config(&self) -> &EmbeddingConfig {
        &self.config
    }
}

impl AsAny for OpenAIEmbeddingGenerator {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

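/// Deterministic, hash-based embedding generator for tests: the same input
/// always produces the same unit-length vector.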
#[cfg(test)]
pub struct MockEmbeddingGenerator {
    config: EmbeddingConfig,
}

#[cfg(test)]
impl Default for MockEmbeddingGenerator {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
impl MockEmbeddingGenerator {
    pub fn new() -> Self {
        Self {
            config: EmbeddingConfig {
                dimensions: 128,
                ..Default::default()
            },
        }
    }

    pub fn with_dimensions(dimensions: usize) -> Self {
        Self {
            config: EmbeddingConfig {
                dimensions,
                ..Default::default()
            },
        }
    }
}

#[cfg(test)]
impl AsAny for MockEmbeddingGenerator {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

#[cfg(test)]
impl EmbeddingGenerator for MockEmbeddingGenerator {
    fn generate(&self, content: &EmbeddableContent) -> Result<crate::Vector> {
        let text = content.to_text();

        // Seed a simple LCG from the text hash so output is deterministic.
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        text.hash(&mut hasher);
        let hash = hasher.finish();

        let mut embedding = Vec::with_capacity(self.config.dimensions);
        let mut seed = hash;

        for _ in 0..self.config.dimensions {
            seed = seed.wrapping_mul(1664525).wrapping_add(1013904223);
            // Map the generator state into [-1, 1).
            let value = (seed as f64 / u64::MAX as f64) as f32;
            embedding.push(value * 2.0 - 1.0);
        }

        // Normalize to unit length.
        let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if magnitude > 0.0 {
            for value in &mut embedding {
                *value /= magnitude;
            }
        }

        Ok(crate::Vector::new(embedding))
    }

    fn dimensions(&self) -> usize {
        self.config.dimensions
    }

    fn config(&self) -> &EmbeddingConfig {
        &self.config
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transformer_model_types() {
        let config = EmbeddingConfig::default();

        let bert = SentenceTransformerGenerator::new(config.clone());
        assert!(matches!(bert.model_type(), TransformerModelType::BERT));
        assert_eq!(bert.dimensions(), 384);

        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        assert!(matches!(
            roberta.model_type(),
            TransformerModelType::RoBERTa
        ));

        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        assert!(matches!(
            distilbert.model_type(),
            TransformerModelType::DistilBERT
        ));
        assert_eq!(distilbert.dimensions(), 384);

        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());
        assert!(matches!(
            multibert.model_type(),
            TransformerModelType::MultiBERT
        ));
    }

    #[test]
    fn test_model_details() {
        let config = EmbeddingConfig::default();

        let bert = SentenceTransformerGenerator::new(config.clone());
        let bert_details = bert.model_details();
        assert_eq!(bert_details.vocab_size, 30522);
        assert_eq!(bert_details.num_layers, 12);
        assert_eq!(bert_details.hidden_size, 768);
        assert!(bert_details.supports_languages.contains(&"en".to_string()));

        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let roberta_details = roberta.model_details();
        assert_eq!(roberta_details.vocab_size, 50265);
        assert_eq!(roberta_details.max_position_embeddings, 514);

        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        let distilbert_details = distilbert.model_details();
        assert_eq!(distilbert_details.num_layers, 6);
        assert_eq!(distilbert_details.hidden_size, 384);
        assert!(distilbert_details.model_size_mb < bert_details.model_size_mb);
        assert!(
            distilbert_details.typical_inference_time_ms < bert_details.typical_inference_time_ms
        );

        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());
        let multibert_details = multibert.model_details();
        assert_eq!(multibert_details.vocab_size, 120000);
        assert!(multibert_details.supports_languages.len() > 10);
        assert!(multibert_details
            .supports_languages
            .contains(&"zh".to_string()));
        assert!(multibert_details
            .supports_languages
            .contains(&"de".to_string()));
    }

    #[test]
    fn test_language_support() {
        let config = EmbeddingConfig::default();

        let bert = SentenceTransformerGenerator::new(config.clone());
        assert!(bert.supports_language("en"));
        assert!(!bert.supports_language("zh"));
        assert!(!bert.supports_language("de"));

        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        assert!(distilbert.supports_language("en"));
        assert!(!distilbert.supports_language("zh"));

        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());
        assert!(multibert.supports_language("en"));
        assert!(multibert.supports_language("zh"));
        assert!(multibert.supports_language("de"));
        assert!(multibert.supports_language("fr"));
        assert!(multibert.supports_language("es"));
        assert!(!multibert.supports_language("unknown_lang"));
    }

    #[test]
    fn test_efficiency_ratings() {
        let config = EmbeddingConfig::default();

        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());

        assert!(distilbert.efficiency_rating() > bert.efficiency_rating());
        assert!(distilbert.efficiency_rating() > roberta.efficiency_rating());
        assert!(distilbert.efficiency_rating() > multibert.efficiency_rating());

        assert!(bert.efficiency_rating() > roberta.efficiency_rating());

        assert!(bert.efficiency_rating() > multibert.efficiency_rating());
        assert!(roberta.efficiency_rating() > multibert.efficiency_rating());
    }

    #[test]
    fn test_inference_time_estimation() {
        let config = EmbeddingConfig::default();
        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        let bert = SentenceTransformerGenerator::new(config.clone());

        let short_time_distilbert = distilbert.estimate_inference_time(50);
        let short_time_bert = bert.estimate_inference_time(50);

        let long_time_distilbert = distilbert.estimate_inference_time(500);
        let long_time_bert = bert.estimate_inference_time(500);

        assert!(short_time_distilbert < short_time_bert);
        assert!(long_time_distilbert < long_time_bert);

        assert!(long_time_distilbert > short_time_distilbert);
        assert!(long_time_bert > short_time_bert);
    }

    #[test]
    fn test_model_specific_text_preprocessing() {
        let config = EmbeddingConfig::default();

        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());

        let text = "Hello World";

        let bert_processed = bert.preprocess_text_for_model(text, 512).unwrap();
        assert!(bert_processed.contains("[CLS]"));
        assert!(bert_processed.contains("[SEP]"));
        assert!(bert_processed.contains("hello world"));

        let roberta_processed = roberta.preprocess_text_for_model(text, 512).unwrap();
        assert!(roberta_processed.contains("<s>"));
        assert!(roberta_processed.contains("</s>"));
        assert!(roberta_processed.contains("Hello World"));

        let latin_text = "Hello World";
        let chinese_text = "你好世界";

        let latin_processed = multibert
            .preprocess_text_for_model(latin_text, 512)
            .unwrap();
        let chinese_processed = multibert
            .preprocess_text_for_model(chinese_text, 512)
            .unwrap();

        assert!(latin_processed.contains("hello world"));
        assert!(chinese_processed.contains("你好世界"));
    }

    #[test]
    fn test_embedding_generation_differences() {
        let config = EmbeddingConfig::default();

        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());

        let content = EmbeddableContent::Text("This is a test sentence".to_string());

        let bert_embedding = bert.generate(&content).unwrap();
        let roberta_embedding = roberta.generate(&content).unwrap();
        let distilbert_embedding = distilbert.generate(&content).unwrap();

        assert_ne!(bert_embedding.as_f32(), roberta_embedding.as_f32());
        assert_ne!(bert_embedding.as_f32(), distilbert_embedding.as_f32());
        assert_ne!(roberta_embedding.as_f32(), distilbert_embedding.as_f32());

        assert_eq!(distilbert_embedding.dimensions, 384);
        assert_eq!(bert_embedding.dimensions, 384);
        assert_eq!(roberta_embedding.dimensions, 384);

        if config.normalize {
            let bert_magnitude: f32 = bert_embedding
                .as_f32()
                .iter()
                .map(|x| x * x)
                .sum::<f32>()
                .sqrt();
            let roberta_magnitude: f32 = roberta_embedding
                .as_f32()
                .iter()
                .map(|x| x * x)
                .sum::<f32>()
                .sqrt();
            let distilbert_magnitude: f32 = distilbert_embedding
                .as_f32()
                .iter()
                .map(|x| x * x)
                .sum::<f32>()
                .sqrt();

            assert!((bert_magnitude - 1.0).abs() < 0.1);
            assert!((roberta_magnitude - 1.0).abs() < 0.1);
            assert!((distilbert_magnitude - 1.0).abs() < 0.1);
        }
    }

    #[test]
    fn test_tokenization_differences() {
        let config = EmbeddingConfig::default();

        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());

        let model_details_bert = bert.get_model_details();
        let model_details_roberta = roberta.get_model_details();
        let model_details_multibert = multibert.get_model_details();

        let complex_word = "preprocessing";

        let bert_tokens =
            bert.simulate_wordpiece_tokenization(complex_word, model_details_bert.vocab_size);
        let roberta_tokens =
            roberta.simulate_bpe_tokenization(complex_word, model_details_roberta.vocab_size);
        let multibert_tokens = multibert
            .simulate_multilingual_tokenization(complex_word, model_details_multibert.vocab_size);

        assert!(roberta_tokens.len() >= bert_tokens.len());

        assert!(multibert_tokens.len() <= bert_tokens.len());

        for token in &bert_tokens {
            assert!(*token < model_details_bert.vocab_size as u32);
        }
        for token in &roberta_tokens {
            assert!(*token < model_details_roberta.vocab_size as u32);
        }
        for token in &multibert_tokens {
            assert!(*token < model_details_multibert.vocab_size as u32);
        }
    }

    #[test]
    fn test_model_size_comparisons() {
        let config = EmbeddingConfig::default();

        let bert = SentenceTransformerGenerator::new(config.clone());
        let roberta = SentenceTransformerGenerator::roberta(config.clone());
        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());

        let bert_size = bert.model_size_mb();
        let roberta_size = roberta.model_size_mb();
        let distilbert_size = distilbert.model_size_mb();
        let multibert_size = multibert.model_size_mb();

        assert!(distilbert_size < bert_size);
        assert!(distilbert_size < roberta_size);
        assert!(distilbert_size < multibert_size);

        assert!(multibert_size > bert_size);
        assert!(multibert_size > roberta_size);
        assert!(multibert_size > distilbert_size);

        assert!(roberta_size > bert_size);
    }
}
2131}