oxirs_vec/embeddings.rs

//! Embedding generation and management for RDF resources and text content

use crate::Vector;
use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::time::Duration;
// AsAny trait will be defined locally

/// Embedding model configuration
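///
/// A minimal usage sketch (illustrative; the field values shown override the
/// crate's defaults):
///
/// ```ignore
/// let config = EmbeddingConfig {
///     dimensions: 768,
///     ..EmbeddingConfig::default()
/// };
/// assert!(config.normalize);
/// ```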
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    pub model_name: String,
    pub dimensions: usize,
    pub max_sequence_length: usize,
    pub normalize: bool,
}

impl Default for EmbeddingConfig {
    fn default() -> Self {
        Self {
            model_name: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
            dimensions: 384,
            max_sequence_length: 512,
            normalize: true,
        }
    }
}

/// Content to be embedded
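///
/// A minimal sketch of building a resource and flattening it to text
/// (illustrative; `to_text` joins the URI, label, description, and properties):
///
/// ```ignore
/// use std::collections::HashMap;
///
/// let content = EmbeddableContent::RdfResource {
///     uri: "http://example.org/alice".to_string(),
///     label: Some("Alice".to_string()),
///     description: None,
///     properties: HashMap::new(),
/// };
/// assert!(content.to_text().starts_with("http://example.org/alice"));
/// ```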
#[derive(Debug, Clone)]
pub enum EmbeddableContent {
    /// Plain text content
    Text(String),
    /// RDF resource with properties
    RdfResource {
        uri: String,
        label: Option<String>,
        description: Option<String>,
        properties: HashMap<String, Vec<String>>,
    },
    /// SPARQL query or query fragment
    SparqlQuery(String),
    /// Knowledge graph path or pattern
    GraphPattern(String),
}

impl EmbeddableContent {
    /// Convert content to text representation for embedding
    pub fn to_text(&self) -> String {
        match self {
            EmbeddableContent::Text(text) => text.clone(),
            EmbeddableContent::RdfResource {
                uri,
                label,
                description,
                properties,
            } => {
                let mut text_parts = vec![uri.clone()];

                if let Some(label) = label {
                    text_parts.push(format!("label: {label}"));
                }

                if let Some(desc) = description {
                    text_parts.push(format!("description: {desc}"));
                }

                for (prop, values) in properties {
                    text_parts.push(format!("{prop}: {}", values.join(", ")));
                }

                text_parts.join(" ")
            }
            EmbeddableContent::SparqlQuery(query) => query.clone(),
            EmbeddableContent::GraphPattern(pattern) => pattern.clone(),
        }
    }

    /// Get a unique identifier for this content
    pub fn content_hash(&self) -> u64 {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        self.to_text().hash(&mut hasher);
        hasher.finish()
    }
}

/// Embedding generation strategy
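///
/// A minimal sketch of selecting a strategy for an `EmbeddingManager`
/// (illustrative; see `EmbeddingManager::new` below):
///
/// ```ignore
/// let manager = EmbeddingManager::new(EmbeddingStrategy::TfIdf, 1000)?;
/// ```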
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EmbeddingStrategy {
    /// Simple TF-IDF based embeddings (for testing/fallback)
    TfIdf,
    /// Sentence transformer embeddings (requires external service)
    SentenceTransformer,
    /// BERT-based transformer models
    Transformer(TransformerModelType),
    /// Word2Vec embeddings
    Word2Vec(crate::word2vec::Word2VecConfig),
    /// OpenAI embeddings (requires API key)
    OpenAI(OpenAIConfig),
    /// Custom embedding model
    Custom(String),
}

/// OpenAI embeddings configuration
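///
/// A minimal sketch (illustrative; `Default` reads `OPENAI_API_KEY` from the
/// environment and `production()` tightens limits):
///
/// ```ignore
/// let config = OpenAIConfig::production();
/// config.validate()?; // fails if no API key is set
/// ```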
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIConfig {
    /// API key for OpenAI service
    pub api_key: String,
    /// Model to use (e.g., "text-embedding-ada-002", "text-embedding-3-small")
    pub model: String,
    /// Base URL for API calls (default: `https://api.openai.com/v1`)
    pub base_url: String,
    /// Request timeout in seconds
    pub timeout_seconds: u64,
    /// Rate limiting: requests per minute
    pub requests_per_minute: u32,
    /// Batch size for batch processing
    pub batch_size: usize,
    /// Enable local caching
    pub enable_cache: bool,
    /// Cache size (number of embeddings to cache)
    pub cache_size: usize,
    /// Cache TTL in seconds (0 for no expiration)
    pub cache_ttl_seconds: u64,
    /// Maximum retries for failed requests
    pub max_retries: u32,
    /// Retry delay in milliseconds
    pub retry_delay_ms: u64,
    /// Retry strategy
    pub retry_strategy: RetryStrategy,
    /// Enable cost tracking
    pub track_costs: bool,
    /// Enable detailed metrics
    pub enable_metrics: bool,
    /// User agent for requests
    pub user_agent: String,
}

/// Retry strategy for failed requests
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum RetryStrategy {
    /// Fixed delay between retries
    Fixed,
    /// Exponential backoff with jitter
    ExponentialBackoff,
    /// Linear backoff
    LinearBackoff,
}

impl Default for OpenAIConfig {
    fn default() -> Self {
        Self {
            api_key: std::env::var("OPENAI_API_KEY").unwrap_or_default(),
            model: "text-embedding-3-small".to_string(),
            base_url: "https://api.openai.com/v1".to_string(),
            timeout_seconds: 30,
            requests_per_minute: 3000,
            batch_size: 100,
            enable_cache: true,
            cache_size: 10000,
            cache_ttl_seconds: 3600, // 1 hour
            max_retries: 3,
            retry_delay_ms: 1000,
            retry_strategy: RetryStrategy::ExponentialBackoff,
            track_costs: true,
            enable_metrics: true,
            user_agent: "oxirs-vec/0.1.0".to_string(),
        }
    }
}

impl OpenAIConfig {
    /// Create config for production use
    pub fn production() -> Self {
        Self {
            requests_per_minute: 1000, // More conservative for production
            cache_size: 50000,
            cache_ttl_seconds: 7200, // 2 hours
            max_retries: 5,
            retry_strategy: RetryStrategy::ExponentialBackoff,
            ..Default::default()
        }
    }

    /// Create config for development/testing
    pub fn development() -> Self {
        Self {
            requests_per_minute: 100,
            cache_size: 1000,
            cache_ttl_seconds: 300, // 5 minutes
            max_retries: 2,
            ..Default::default()
        }
    }

    /// Validate configuration
    pub fn validate(&self) -> Result<()> {
        if self.api_key.is_empty() {
            return Err(anyhow!("OpenAI API key is required"));
        }
        if self.requests_per_minute == 0 {
            return Err(anyhow!("requests_per_minute must be greater than 0"));
        }
        if self.batch_size == 0 {
            return Err(anyhow!("batch_size must be greater than 0"));
        }
        if self.timeout_seconds == 0 {
            return Err(anyhow!("timeout_seconds must be greater than 0"));
        }
        Ok(())
    }
}

/// Embedding generator trait
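///
/// A minimal sketch of using the trait generically (illustrative):
///
/// ```ignore
/// fn embed_all(gen: &dyn EmbeddingGenerator, texts: &[String]) -> Result<Vec<Vector>> {
///     let contents: Vec<EmbeddableContent> =
///         texts.iter().cloned().map(EmbeddableContent::Text).collect();
///     gen.generate_batch(&contents)
/// }
/// ```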
pub trait EmbeddingGenerator: Send + Sync + AsAny {
    /// Generate embedding for content
    fn generate(&self, content: &EmbeddableContent) -> Result<Vector>;

    /// Generate embeddings for multiple contents in batch
    fn generate_batch(&self, contents: &[EmbeddableContent]) -> Result<Vec<Vector>> {
        contents.iter().map(|c| self.generate(c)).collect()
    }

    /// Get the embedding dimensions
    fn dimensions(&self) -> usize;

    /// Get the model configuration
    fn config(&self) -> &EmbeddingConfig;
}

/// Simple TF-IDF based embedding generator
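///
/// A minimal sketch (illustrative; the vocabulary must be built before
/// `generate` is called, otherwise it returns an error):
///
/// ```ignore
/// let mut gen = TfIdfEmbeddingGenerator::new(EmbeddingConfig::default());
/// gen.build_vocabulary(&["rdf graph store".to_string()])?;
/// let vector = gen.generate(&EmbeddableContent::Text("rdf store".to_string()))?;
/// ```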
pub struct TfIdfEmbeddingGenerator {
    config: EmbeddingConfig,
    vocabulary: HashMap<String, usize>,
    idf_scores: HashMap<String, f32>,
}

impl TfIdfEmbeddingGenerator {
    pub fn new(config: EmbeddingConfig) -> Self {
        Self {
            config,
            vocabulary: HashMap::new(),
            idf_scores: HashMap::new(),
        }
    }

    /// Build vocabulary from a corpus of documents
    pub fn build_vocabulary(&mut self, documents: &[String]) -> Result<()> {
        let mut word_counts: HashMap<String, usize> = HashMap::new();
        let mut doc_counts: HashMap<String, usize> = HashMap::new();

        for doc in documents {
            let words: Vec<String> = self.tokenize(doc);
            let unique_words: std::collections::HashSet<_> = words.iter().collect();

            for word in &words {
                *word_counts.entry(word.clone()).or_insert(0) += 1;
            }

            for word in unique_words {
                *doc_counts.entry(word.clone()).or_insert(0) += 1;
            }
        }

        // Build vocabulary with most frequent words
        let mut word_freq: Vec<(String, usize)> = word_counts.into_iter().collect();
        word_freq.sort_by(|a, b| b.1.cmp(&a.1));

        self.vocabulary = word_freq
            .into_iter()
            .take(self.config.dimensions)
            .enumerate()
            .map(|(idx, (word, _))| (word, idx))
            .collect();

        // Calculate IDF scores
        let total_docs = documents.len() as f32;
        for word in self.vocabulary.keys() {
            let doc_freq = doc_counts.get(word).unwrap_or(&0);
            let idf = (total_docs / (*doc_freq as f32 + 1.0)).ln();
            self.idf_scores.insert(word.clone(), idf);
        }

        Ok(())
    }

    fn tokenize(&self, text: &str) -> Vec<String> {
        text.to_lowercase()
            .split_whitespace()
            .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
            .filter(|s| !s.is_empty())
            .map(String::from)
            .collect()
    }

    fn calculate_tf_idf(&self, text: &str) -> Vector {
        let words = self.tokenize(text);
        let mut tf_counts: HashMap<String, usize> = HashMap::new();

        for word in &words {
            *tf_counts.entry(word.clone()).or_insert(0) += 1;
        }

        let total_words = words.len() as f32;
        let mut embedding = vec![0.0; self.config.dimensions];

        for (word, count) in tf_counts {
            if let Some(&idx) = self.vocabulary.get(&word) {
                let tf = count as f32 / total_words;
                let idf = self.idf_scores.get(&word).unwrap_or(&0.0);
                embedding[idx] = tf * idf;
            }
        }

        if self.config.normalize {
            self.normalize_vector(&mut embedding);
        }

        Vector::new(embedding)
    }

    fn normalize_vector(&self, vector: &mut [f32]) {
        let magnitude: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        if magnitude > 0.0 {
            for value in vector {
                *value /= magnitude;
            }
        }
    }
}

impl EmbeddingGenerator for TfIdfEmbeddingGenerator {
    fn generate(&self, content: &EmbeddableContent) -> Result<Vector> {
        if self.vocabulary.is_empty() {
            return Err(anyhow!(
                "Vocabulary not built. Call build_vocabulary first."
            ));
        }

        let text = content.to_text();
        Ok(self.calculate_tf_idf(&text))
    }

    fn dimensions(&self) -> usize {
        self.config.dimensions
    }

    fn config(&self) -> &EmbeddingConfig {
        &self.config
    }
}

/// Transformer-based embedding generator supporting multiple models
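///
/// A minimal sketch (illustrative; `distilbert` adjusts dimensions to 384):
///
/// ```ignore
/// let gen = SentenceTransformerGenerator::distilbert(EmbeddingConfig::default());
/// assert_eq!(gen.model_details().num_layers, 6);
/// ```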
pub struct SentenceTransformerGenerator {
    config: EmbeddingConfig,
    model_type: TransformerModelType,
}

/// Supported transformer model types
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum TransformerModelType {
    /// Basic BERT-based model
    #[default]
    BERT,
    /// RoBERTa model with improved training
    RoBERTa,
    /// DistilBERT for efficiency
    DistilBERT,
    /// Multilingual BERT
    MultiBERT,
    /// Custom model path
    Custom(String),
}

/// Detailed information about a transformer model
#[derive(Debug, Clone)]
pub struct ModelDetails {
    pub vocab_size: usize,
    pub num_layers: usize,
    pub num_attention_heads: usize,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub max_position_embeddings: usize,
    pub supports_languages: Vec<String>,
    pub model_size_mb: usize,
    pub typical_inference_time_ms: u64,
}

impl SentenceTransformerGenerator {
    pub fn new(config: EmbeddingConfig) -> Self {
        Self {
            config,
            model_type: TransformerModelType::default(),
        }
    }

    pub fn with_model_type(config: EmbeddingConfig, model_type: TransformerModelType) -> Self {
        Self { config, model_type }
    }

    /// Create a new RoBERTa model generator
    pub fn roberta(config: EmbeddingConfig) -> Self {
        Self::with_model_type(config, TransformerModelType::RoBERTa)
    }

    /// Create a new DistilBERT model generator
    pub fn distilbert(config: EmbeddingConfig) -> Self {
        let adjusted_config = EmbeddingConfig {
            dimensions: 384, // DistilBERT has smaller dimensions
            ..config
        };
        Self::with_model_type(adjusted_config, TransformerModelType::DistilBERT)
    }

    /// Create a new multilingual BERT model generator
    pub fn multilingual_bert(config: EmbeddingConfig) -> Self {
        Self::with_model_type(config, TransformerModelType::MultiBERT)
    }

    /// Get the current model type
    pub fn model_type(&self) -> &TransformerModelType {
        &self.model_type
    }

    /// Get detailed information about the current model
    pub fn model_details(&self) -> ModelDetails {
        self.get_model_details()
    }

    /// Check if the model supports a specific language
    pub fn supports_language(&self, language_code: &str) -> bool {
        let details = self.get_model_details();
        details
            .supports_languages
            .contains(&language_code.to_string())
    }

    /// Get the estimated inference time for a given text length
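    ///
    /// Worked example (illustrative): with BERT's ~50 ms base time, a
    /// 400-character text gives `length_factor = sqrt(400 / 100) = 2.0`,
    /// so the estimate is `50 * 2.0 = 100` ms.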
    pub fn estimate_inference_time(&self, text_length: usize) -> u64 {
        let details = self.get_model_details();
        let base_time = details.typical_inference_time_ms;

        // Rough estimation: longer texts take more time
        let length_factor = (text_length as f64 / 100.0).sqrt().max(1.0);
        (base_time as f64 * length_factor) as u64
    }

    /// Get the memory footprint of the model in MB
    pub fn model_size_mb(&self) -> usize {
        self.get_model_details().model_size_mb
    }

    /// Get efficiency rating (higher is better/faster)
    pub fn efficiency_rating(&self) -> f32 {
        match &self.model_type {
            TransformerModelType::DistilBERT => 1.5, // Fastest
            TransformerModelType::BERT => 1.0,       // Baseline
            TransformerModelType::RoBERTa => 0.95,   // Slightly slower
            TransformerModelType::MultiBERT => 0.8,  // Slowest due to multilingual complexity
            TransformerModelType::Custom(_) => 1.0,  // Unknown, assume baseline
        }
    }

    /// Get model-specific configuration adjustments
    fn get_model_config(&self) -> (usize, usize, f32) {
        match &self.model_type {
            TransformerModelType::BERT => (self.config.dimensions, 512, 1.0), // Use config dimensions
            TransformerModelType::RoBERTa => (self.config.dimensions, 514, 0.95), // Use config dimensions
            TransformerModelType::DistilBERT => (self.config.dimensions, 512, 1.5), // Use config dimensions
            TransformerModelType::MultiBERT => (self.config.dimensions, 512, 0.8), // Use config dimensions
            TransformerModelType::Custom(_) => {
                (self.config.dimensions, self.config.max_sequence_length, 1.0)
            }
        }
    }

    /// Get model-specific vocabulary size and training details
    fn get_model_details(&self) -> ModelDetails {
        match &self.model_type {
            TransformerModelType::BERT => ModelDetails {
                vocab_size: 30522,
                num_layers: 12,
                num_attention_heads: 12,
                hidden_size: 768,
                intermediate_size: 3072,
                max_position_embeddings: 512,
                supports_languages: vec!["en".to_string()],
                model_size_mb: 440,
                typical_inference_time_ms: 50,
            },
            TransformerModelType::RoBERTa => ModelDetails {
                vocab_size: 50265,
                num_layers: 12,
                num_attention_heads: 12,
                hidden_size: 768,
                intermediate_size: 3072,
                max_position_embeddings: 514,
                supports_languages: vec!["en".to_string()],
                model_size_mb: 470,
                typical_inference_time_ms: 55, // Slightly slower due to different training
            },
            TransformerModelType::DistilBERT => ModelDetails {
                vocab_size: 30522,
                num_layers: 6, // Half the layers of BERT
                num_attention_heads: 12,
                hidden_size: 384, // Smaller hidden size
                intermediate_size: 1536,
                max_position_embeddings: 512,
                supports_languages: vec!["en".to_string()],
                model_size_mb: 250,            // Much smaller
                typical_inference_time_ms: 25, // Much faster
            },
            TransformerModelType::MultiBERT => ModelDetails {
                vocab_size: 120000, // Larger vocabulary for multilingual support
                num_layers: 12,
                num_attention_heads: 12,
                hidden_size: 768,
                intermediate_size: 3072,
                max_position_embeddings: 512,
                supports_languages: vec![
                    "en".to_string(),
                    "de".to_string(),
                    "fr".to_string(),
                    "es".to_string(),
                    "it".to_string(),
                    "pt".to_string(),
                    "ru".to_string(),
                    "zh".to_string(),
                    "ja".to_string(),
                    "ko".to_string(),
                    "ar".to_string(),
                    "hi".to_string(),
                    "th".to_string(),
                    "tr".to_string(),
                    "pl".to_string(),
                    "nl".to_string(),
                    "sv".to_string(),
                    "da".to_string(),
                    "no".to_string(),
                    "fi".to_string(),
                ], // Top 20 languages supported
                model_size_mb: 670, // Larger due to multilingual vocabulary
                typical_inference_time_ms: 70, // Slower due to larger vocabulary
            },
            TransformerModelType::Custom(_path) => ModelDetails {
                vocab_size: 50000, // Default assumption
                num_layers: 12,
                num_attention_heads: 12,
                hidden_size: self.config.dimensions,
                intermediate_size: self.config.dimensions * 4,
                max_position_embeddings: self.config.max_sequence_length,
                supports_languages: vec!["unknown".to_string()],
                model_size_mb: 500, // Estimate
                typical_inference_time_ms: 60,
            },
        }
    }

    /// Generate embedding with model-specific processing
    fn generate_with_model(&self, text: &str) -> Result<Vector> {
        let _text_hash = {
            use std::hash::{Hash, Hasher};
            let mut hasher = std::collections::hash_map::DefaultHasher::new();
            text.hash(&mut hasher);
            hasher.finish()
        };

        let (dimensions, max_len, _efficiency) = self.get_model_config();
        let model_details = self.get_model_details();

        // Apply model-specific text preprocessing
        let processed_text = self.preprocess_text_for_model(text, max_len)?;

        // Simulate tokenization differences between models
        let token_ids = self.simulate_tokenization(&processed_text, &model_details);

        // Generate model-specific embeddings
        let values =
            self.generate_embeddings_from_tokens(&token_ids, dimensions, &model_details)?;

        if self.config.normalize {
            let magnitude: f32 = values.iter().map(|x| x * x).sum::<f32>().sqrt();
            if magnitude > 0.0 {
                let mut normalized_values = values;
                for value in &mut normalized_values {
                    *value /= magnitude;
                }
                return Ok(Vector::new(normalized_values));
            }
        }

        Ok(Vector::new(values))
    }

    /// Preprocess text according to model-specific requirements
    fn preprocess_text_for_model(&self, text: &str, max_len: usize) -> Result<String> {
        let processed = match &self.model_type {
            TransformerModelType::BERT => {
                // BERT uses [CLS] and [SEP] tokens; reserve space for them
                let truncated = Self::truncate_at_char_boundary(text, max_len.saturating_sub(20));
                format!("[CLS] {} [SEP]", truncated.to_lowercase())
            }
            TransformerModelType::RoBERTa => {
                // RoBERTa uses <s> and </s> tokens and preserves case
                let truncated = Self::truncate_at_char_boundary(text, max_len.saturating_sub(10));
                format!("<s>{truncated}</s>")
            }
            TransformerModelType::DistilBERT => {
                // DistilBERT uses the same [CLS]/[SEP] format as BERT
                let truncated = Self::truncate_at_char_boundary(text, max_len.saturating_sub(20));
                format!("[CLS] {} [SEP]", truncated.to_lowercase())
            }
            TransformerModelType::MultiBERT => {
                // Multilingual BERT handles multiple languages, no case conversion for non-Latin scripts
                let truncated = Self::truncate_at_char_boundary(text, max_len.saturating_sub(20));
                // Detect if text contains non-Latin characters
                let has_non_latin = !text.is_ascii();
                if has_non_latin {
                    format!("[CLS] {truncated} [SEP]") // Preserve case for non-Latin
                } else {
                    format!("[CLS] {} [SEP]", truncated.to_lowercase()) // Lowercase for Latin
                }
            }
            TransformerModelType::Custom(_) => {
                // Basic preprocessing for custom models
                Self::truncate_at_char_boundary(text, max_len).to_string()
            }
        };

        Ok(processed)
    }

    /// Truncate to at most `max_bytes`, backing up to a UTF-8 char boundary
    /// so byte slicing never panics on multi-byte characters.
    fn truncate_at_char_boundary(text: &str, max_bytes: usize) -> &str {
        if text.len() <= max_bytes {
            return text;
        }
        let mut end = max_bytes;
        while end > 0 && !text.is_char_boundary(end) {
            end -= 1;
        }
        &text[..end]
    }

    /// Simulate tokenization process for different models
    fn simulate_tokenization(&self, text: &str, model_details: &ModelDetails) -> Vec<u32> {
        let mut token_ids = Vec::new();

        // Simple word-based tokenization simulation
        let words: Vec<&str> = text.split_whitespace().collect();

        for word in words {
            // Simulate subword tokenization
            let subwords = match &self.model_type {
                TransformerModelType::RoBERTa => {
                    // RoBERTa uses byte-pair encoding, tends to create more subwords
                    self.simulate_bpe_tokenization(word, model_details.vocab_size)
                }
                TransformerModelType::DistilBERT | TransformerModelType::BERT => {
                    // BERT uses WordPiece tokenization
                    self.simulate_wordpiece_tokenization(word, model_details.vocab_size)
                }
                TransformerModelType::MultiBERT => {
                    // Multilingual BERT has larger vocabulary, fewer subwords
                    self.simulate_multilingual_tokenization(word, model_details.vocab_size)
                }
                TransformerModelType::Custom(_) => {
                    // Simple tokenization for custom models
                    vec![self.word_to_token_id(word, model_details.vocab_size)]
                }
            };

            token_ids.extend(subwords);
        }

        // Truncate to max sequence length
        token_ids.truncate(model_details.max_position_embeddings - 2); // Reserve space for special tokens
        token_ids
    }

    /// Simulate BPE tokenization (used by RoBERTa)
    fn simulate_bpe_tokenization(&self, word: &str, vocab_size: usize) -> Vec<u32> {
        // Chunk by characters rather than bytes so multi-byte UTF-8 input cannot panic
        let chars: Vec<char> = word.chars().collect();
        chars
            .chunks(4)
            .map(|chunk| {
                let piece: String = chunk.iter().collect();
                self.word_to_token_id(&piece, vocab_size)
            })
            .collect()
    }

    /// Simulate WordPiece tokenization (used by BERT)
    fn simulate_wordpiece_tokenization(&self, word: &str, vocab_size: usize) -> Vec<u32> {
        let char_count = word.chars().count();
        if char_count <= 6 {
            vec![self.word_to_token_id(word, vocab_size)]
        } else {
            // Split at a char boundary so byte slicing stays panic-free on multi-byte input
            let split = word
                .char_indices()
                .nth(char_count / 2)
                .map(|(i, _)| i)
                .unwrap_or(word.len());
            vec![
                self.word_to_token_id(&word[..split], vocab_size),
                self.word_to_token_id(&format!("##{}", &word[split..]), vocab_size), // ## prefix for subwords
            ]
        }
    }

    /// Simulate multilingual tokenization (larger vocab = fewer subwords)
    fn simulate_multilingual_tokenization(&self, word: &str, vocab_size: usize) -> Vec<u32> {
        // Multilingual models have larger vocabularies, so less subword splitting
        let char_count = word.chars().count();
        if char_count <= 10 {
            vec![self.word_to_token_id(word, vocab_size)]
        } else {
            // Split at a char boundary so byte slicing stays panic-free on multi-byte input
            let split = word
                .char_indices()
                .nth(char_count / 2)
                .map(|(i, _)| i)
                .unwrap_or(word.len());
            vec![
                self.word_to_token_id(&word[..split], vocab_size),
                self.word_to_token_id(&word[split..], vocab_size),
            ]
        }
    }

    /// Convert word to token ID
    fn word_to_token_id(&self, word: &str, vocab_size: usize) -> u32 {
        use std::hash::{Hash, Hasher};
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        word.hash(&mut hasher);
        (hasher.finish() % vocab_size as u64) as u32
    }

    /// Generate embeddings from token IDs using model-specific patterns
    fn generate_embeddings_from_tokens(
        &self,
        token_ids: &[u32],
        dimensions: usize,
        model_details: &ModelDetails,
    ) -> Result<Vec<f32>> {
        let mut values = vec![0.0; dimensions];

        // Model-specific embedding generation
        match &self.model_type {
            TransformerModelType::BERT => {
                self.generate_bert_style_embeddings(token_ids, &mut values, model_details)
            }
            TransformerModelType::RoBERTa => {
                self.generate_roberta_style_embeddings(token_ids, &mut values, model_details)
            }
            TransformerModelType::DistilBERT => {
                self.generate_distilbert_style_embeddings(token_ids, &mut values, model_details)
            }
            TransformerModelType::MultiBERT => {
                self.generate_multibert_style_embeddings(token_ids, &mut values, model_details)
            }
            TransformerModelType::Custom(_) => {
                self.generate_custom_style_embeddings(token_ids, &mut values, model_details)
            }
        }

        Ok(values)
    }

    /// Generate BERT-style embeddings
    fn generate_bert_style_embeddings(
        &self,
        token_ids: &[u32],
        values: &mut [f32],
        _model_details: &ModelDetails,
    ) {
        for (i, &token_id) in token_ids.iter().enumerate() {
            let mut seed = token_id as u64;
            for value in values.iter_mut() {
                seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
                let normalized = (seed as f32) / (u64::MAX as f32);
                let position_encoding =
                    ((i as f32 / 512.0) * 2.0 * std::f32::consts::PI).sin() * 0.1;
                *value += ((normalized - 0.5) * 2.0) + position_encoding;
            }
        }

        // Average the contributions from all tokens
        if !token_ids.is_empty() {
            for value in values.iter_mut() {
                *value /= token_ids.len() as f32;
            }
        }
    }

    /// Generate RoBERTa-style embeddings (no segment embeddings, different position encoding)
    fn generate_roberta_style_embeddings(
        &self,
        token_ids: &[u32],
        values: &mut [f32],
        _model_details: &ModelDetails,
    ) {
        for (i, &token_id) in token_ids.iter().enumerate() {
            // Different seed pattern; widen to u64 so the u64::MAX normalization below is meaningful
            let mut seed = (token_id as u64).wrapping_mul(31415927);
            for value in values.iter_mut() {
                seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
                let normalized = (seed as f32) / (u64::MAX as f32);
                // RoBERTa uses learned position embeddings starting from index 2
                let position_encoding =
                    ((i as f32 + 2.0) / 514.0 * 2.0 * std::f32::consts::PI).cos() * 0.1;
                *value += ((normalized - 0.5) * 2.0) + position_encoding;
            }
        }

        if !token_ids.is_empty() {
            for value in values.iter_mut() {
                *value /= token_ids.len() as f32;
            }
        }
    }

    /// Generate DistilBERT-style embeddings (simpler, faster)
    fn generate_distilbert_style_embeddings(
        &self,
        token_ids: &[u32],
        values: &mut [f32],
        _model_details: &ModelDetails,
    ) {
        // DistilBERT has fewer layers and smaller hidden size
        for (i, &token_id) in token_ids.iter().enumerate() {
            let mut seed = token_id as u64;
            for value in values.iter_mut() {
                seed = seed.wrapping_mul(982451653).wrapping_add(12345); // Faster computation
                let normalized = (seed as f32) / (u64::MAX as f32);
                // Simpler position encoding
                let position_encoding = (i as f32 / 512.0).sin() * 0.05;
                *value += ((normalized - 0.5) * 1.5) + position_encoding; // Slightly different scale
            }
        }

        if !token_ids.is_empty() {
            for value in values.iter_mut() {
                *value /= token_ids.len() as f32;
            }
        }
    }

    /// Generate multilingual BERT-style embeddings
    fn generate_multibert_style_embeddings(
        &self,
        token_ids: &[u32],
        values: &mut [f32],
        _model_details: &ModelDetails,
    ) {
        for (i, &token_id) in token_ids.iter().enumerate() {
            // Multilingual models have different patterns due to cross-lingual training;
            // widen to u64 so the u64::MAX normalization below is meaningful
            let mut seed = (token_id as u64).wrapping_mul(2654435761);
            for j in 0..values.len() {
                seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
                let normalized = (seed as f32) / (u64::MAX as f32);
                let position_encoding =
                    ((i as f32 / 512.0) * 2.0 * std::f32::consts::PI).sin() * 0.08;
                // Add language-agnostic patterns
                let cross_lingual_bias =
                    (j as f32 / values.len() as f32 * std::f32::consts::PI).cos() * 0.05;
                values[j] += ((normalized - 0.5) * 1.8) + position_encoding + cross_lingual_bias;
            }
        }

        if !token_ids.is_empty() {
            for value in values.iter_mut() {
                *value /= token_ids.len() as f32;
            }
        }
    }

    /// Generate custom model embeddings
    fn generate_custom_style_embeddings(
        &self,
        token_ids: &[u32],
        values: &mut [f32],
        _model_details: &ModelDetails,
    ) {
        // Simple approach for custom models
        for &token_id in token_ids {
            let mut seed = token_id as u64;
            for value in values.iter_mut() {
                seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
                let normalized = (seed as f32) / (u64::MAX as f32);
                *value += (normalized - 0.5) * 2.0;
            }
        }

        if !token_ids.is_empty() {
            for value in values.iter_mut() {
                *value /= token_ids.len() as f32;
            }
        }
    }
}

impl EmbeddingGenerator for SentenceTransformerGenerator {
    fn generate(&self, content: &EmbeddableContent) -> Result<Vector> {
        let text = content.to_text();
        self.generate_with_model(&text)
    }

    fn dimensions(&self) -> usize {
        self.config.dimensions
    }

    fn config(&self) -> &EmbeddingConfig {
        &self.config
    }
}

/// Embedding cache for frequently accessed embeddings
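///
/// A minimal LRU sketch (illustrative; with capacity 1, inserting a second
/// item evicts the first):
///
/// ```ignore
/// let mut cache = EmbeddingCache::new(1);
/// let a = EmbeddableContent::Text("a".to_string());
/// cache.insert(&a, Vector::new(vec![1.0]));
/// assert!(cache.get(&a).is_some());
/// ```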
pub struct EmbeddingCache {
    cache: HashMap<u64, Vector>,
    max_size: usize,
    access_order: Vec<u64>,
}

impl EmbeddingCache {
    pub fn new(max_size: usize) -> Self {
        Self {
            cache: HashMap::new(),
            max_size,
            access_order: Vec::new(),
        }
    }

    pub fn get(&mut self, content: &EmbeddableContent) -> Option<&Vector> {
        let hash = content.content_hash();
        if let Some(vector) = self.cache.get(&hash) {
            // Move to end (most recently used)
            if let Some(pos) = self.access_order.iter().position(|&x| x == hash) {
                self.access_order.remove(pos);
            }
            self.access_order.push(hash);
            Some(vector)
        } else {
            None
        }
    }

    pub fn insert(&mut self, content: &EmbeddableContent, vector: Vector) {
        let hash = content.content_hash();

        // Remove least recently used if at capacity
        if self.cache.len() >= self.max_size && !self.cache.contains_key(&hash) {
            if let Some(&lru_hash) = self.access_order.first() {
                self.cache.remove(&lru_hash);
                self.access_order.remove(0);
            }
        }

        // Drop any existing access entry so re-inserting a key cannot
        // leave duplicate hashes in the LRU order
        if let Some(pos) = self.access_order.iter().position(|&x| x == hash) {
            self.access_order.remove(pos);
        }

        self.cache.insert(hash, vector);
        self.access_order.push(hash);
    }

    pub fn clear(&mut self) {
        self.cache.clear();
        self.access_order.clear();
    }

    pub fn size(&self) -> usize {
        self.cache.len()
    }
}

/// Embedding manager that combines generation, caching, and persistence
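///
/// A minimal end-to-end sketch (illustrative; the TF-IDF strategy needs a
/// vocabulary before embeddings can be generated):
///
/// ```ignore
/// let mut manager = EmbeddingManager::new(EmbeddingStrategy::TfIdf, 1000)?;
/// manager.build_vocabulary(&["semantic web data".to_string()])?;
/// let vec = manager.get_embedding(&EmbeddableContent::Text("web data".to_string()))?;
/// let (used, capacity) = manager.cache_stats();
/// ```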
pub struct EmbeddingManager {
    generator: Box<dyn EmbeddingGenerator>,
    cache: EmbeddingCache,
    strategy: EmbeddingStrategy,
}

impl EmbeddingManager {
    pub fn new(strategy: EmbeddingStrategy, cache_size: usize) -> Result<Self> {
        let generator: Box<dyn EmbeddingGenerator> = match &strategy {
            EmbeddingStrategy::TfIdf => {
                let config = EmbeddingConfig::default();
                Box::new(TfIdfEmbeddingGenerator::new(config))
            }
            EmbeddingStrategy::SentenceTransformer => {
                let config = EmbeddingConfig::default();
                Box::new(SentenceTransformerGenerator::new(config))
            }
            EmbeddingStrategy::Transformer(model_type) => {
                let config = EmbeddingConfig {
                    model_name: format!("{model_type:?}"),
                    dimensions: match model_type {
                        TransformerModelType::DistilBERT => 384, // DistilBERT is smaller
                        _ => 768,                                // Most BERT variants
                    },
                    max_sequence_length: 512,
                    normalize: true,
                };
                Box::new(SentenceTransformerGenerator::with_model_type(
                    config,
                    model_type.clone(),
                ))
            }
            EmbeddingStrategy::Word2Vec(word2vec_config) => {
                let embedding_config = EmbeddingConfig {
                    model_name: "word2vec".to_string(),
                    dimensions: word2vec_config.dimensions,
                    max_sequence_length: 512,
                    normalize: word2vec_config.normalize,
                };
                Box::new(crate::word2vec::Word2VecEmbeddingGenerator::new(
                    word2vec_config.clone(),
                    embedding_config,
                )?)
            }
            EmbeddingStrategy::OpenAI(openai_config) => {
                Box::new(OpenAIEmbeddingGenerator::new(openai_config.clone())?)
            }
            EmbeddingStrategy::Custom(_model_path) => {
                // For now, fall back to sentence transformer
                let config = EmbeddingConfig::default();
                Box::new(SentenceTransformerGenerator::new(config))
            }
        };

        Ok(Self {
            generator,
            cache: EmbeddingCache::new(cache_size),
            strategy,
        })
    }

    /// Get or generate embedding for content
    pub fn get_embedding(&mut self, content: &EmbeddableContent) -> Result<Vector> {
        if let Some(cached) = self.cache.get(content) {
            return Ok(cached.clone());
        }

        let embedding = self.generator.generate(content)?;
        self.cache.insert(content, embedding.clone());
        Ok(embedding)
    }

    /// Pre-compute embeddings for a batch of content
    pub fn precompute_embeddings(&mut self, contents: &[EmbeddableContent]) -> Result<()> {
        let embeddings = self.generator.generate_batch(contents)?;

        for (content, embedding) in contents.iter().zip(embeddings) {
            self.cache.insert(content, embedding);
        }

        Ok(())
    }

    /// Build vocabulary for TF-IDF strategy
    pub fn build_vocabulary(&mut self, documents: &[String]) -> Result<()> {
        if let EmbeddingStrategy::TfIdf = self.strategy {
            if let Some(tfidf_gen) = self
                .generator
                .as_any_mut()
                .downcast_mut::<TfIdfEmbeddingGenerator>()
            {
                tfidf_gen.build_vocabulary(documents)?;
            }
        }
        Ok(())
    }

    pub fn dimensions(&self) -> usize {
        self.generator.dimensions()
    }

    pub fn cache_stats(&self) -> (usize, usize) {
        (self.cache.size(), self.cache.max_size)
    }
}

/// Extension trait to add downcast functionality
pub trait AsAny {
    fn as_any(&self) -> &dyn std::any::Any;
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
}

impl AsAny for TfIdfEmbeddingGenerator {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

impl AsAny for SentenceTransformerGenerator {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

/// OpenAI embeddings generator with rate limiting and retry logic
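///
/// A minimal async sketch (illustrative; assumes a Tokio runtime and a valid
/// `OPENAI_API_KEY` in the environment):
///
/// ```ignore
/// let mut gen = OpenAIEmbeddingGenerator::new(OpenAIConfig::default())?;
/// let vector = gen
///     .generate_async(&EmbeddableContent::Text("hello".to_string()))
///     .await?;
/// ```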
pub struct OpenAIEmbeddingGenerator {
    config: EmbeddingConfig,
    openai_config: OpenAIConfig,
    client: reqwest::Client,
    rate_limiter: RateLimiter,
    request_cache: std::sync::Arc<std::sync::Mutex<lru::LruCache<u64, CachedEmbedding>>>,
    metrics: OpenAIMetrics,
}

/// Cached embedding with metadata
#[derive(Debug, Clone)]
pub struct CachedEmbedding {
    pub vector: Vector,
    pub cached_at: std::time::SystemTime,
    pub model: String,
    pub cost_usd: f64,
}

/// Metrics for OpenAI API usage
#[derive(Debug, Clone, Default)]
pub struct OpenAIMetrics {
    pub total_requests: u64,
    pub successful_requests: u64,
    pub failed_requests: u64,
    pub total_tokens_processed: u64,
    pub cache_hits: u64,
    pub cache_misses: u64,
    pub total_cost_usd: f64,
    pub retry_count: u64,
    pub rate_limit_waits: u64,
    pub average_response_time_ms: f64,
    pub last_request_time: Option<std::time::SystemTime>,
    pub requests_by_model: HashMap<String, u64>,
    pub errors_by_type: HashMap<String, u64>,
}

impl OpenAIMetrics {
    /// Calculate cache hit ratio
    pub fn cache_hit_ratio(&self) -> f64 {
        if self.cache_hits + self.cache_misses == 0 {
            0.0
        } else {
            self.cache_hits as f64 / (self.cache_hits + self.cache_misses) as f64
        }
    }

    /// Calculate success rate
    pub fn success_rate(&self) -> f64 {
        if self.total_requests == 0 {
            0.0
        } else {
            self.successful_requests as f64 / self.total_requests as f64
        }
    }

    /// Calculate average cost per request
    pub fn average_cost_per_request(&self) -> f64 {
        if self.successful_requests == 0 {
            0.0
        } else {
            self.total_cost_usd / self.successful_requests as f64
        }
    }

    /// Get formatted metrics report
    pub fn report(&self) -> String {
        format!(
            "OpenAI Metrics Report:\n\
            Total Requests: {}\n\
            Success Rate: {:.2}%\n\
            Cache Hit Ratio: {:.2}%\n\
            Total Cost: ${:.4}\n\
            Avg Cost/Request: ${:.6}\n\
            Avg Response Time: {:.2}ms\n\
            Retries: {}\n\
            Rate Limit Waits: {}",
            self.total_requests,
            self.success_rate() * 100.0,
            self.cache_hit_ratio() * 100.0,
            self.total_cost_usd,
            self.average_cost_per_request(),
            self.average_response_time_ms,
            self.retry_count,
            self.rate_limit_waits
        )
    }
}

/// Simple rate limiter implementation
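///
/// A minimal sketch of the sliding one-minute window (illustrative; requires
/// an async runtime such as Tokio):
///
/// ```ignore
/// let mut limiter = RateLimiter::new(60); // 60 requests/minute
/// limiter.wait_if_needed().await; // sleeps only once the window is full
/// ```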
pub struct RateLimiter {
    requests_per_minute: u32,
    request_times: std::collections::VecDeque<std::time::Instant>,
}

impl RateLimiter {
    pub fn new(requests_per_minute: u32) -> Self {
        Self {
            requests_per_minute,
            request_times: std::collections::VecDeque::new(),
        }
    }

    pub async fn wait_if_needed(&mut self) {
        let now = std::time::Instant::now();
        let minute_ago = now - std::time::Duration::from_secs(60);

        // Remove requests older than 1 minute
        while let Some(&front_time) = self.request_times.front() {
            if front_time < minute_ago {
                self.request_times.pop_front();
            } else {
                break;
            }
        }

        // If we're at the rate limit, wait
        if self.request_times.len() >= self.requests_per_minute as usize {
            if let Some(&oldest) = self.request_times.front() {
                let wait_time = oldest + std::time::Duration::from_secs(60) - now;
                if !wait_time.is_zero() {
                    tokio::time::sleep(wait_time).await;
                }
            }
        }

        self.request_times.push_back(now);
    }
}

impl OpenAIEmbeddingGenerator {
    pub fn new(openai_config: OpenAIConfig) -> Result<Self> {
        openai_config.validate()?;

        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(
                openai_config.timeout_seconds,
            ))
            .user_agent(&openai_config.user_agent)
            .build()
            .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;

        let embedding_config = EmbeddingConfig {
            model_name: openai_config.model.clone(),
            dimensions: Self::get_model_dimensions(&openai_config.model),
            max_sequence_length: 8191, // OpenAI limit
            normalize: true,
        };

        let cache_size = if openai_config.enable_cache {
            std::num::NonZeroUsize::new(openai_config.cache_size)
                .unwrap_or(std::num::NonZeroUsize::new(1000).unwrap())
        } else {
            std::num::NonZeroUsize::new(1).unwrap()
        };

        Ok(Self {
            config: embedding_config,
            openai_config: openai_config.clone(),
            client,
            rate_limiter: RateLimiter::new(openai_config.requests_per_minute),
            request_cache: std::sync::Arc::new(std::sync::Mutex::new(lru::LruCache::new(
                cache_size,
            ))),
            metrics: OpenAIMetrics::default(),
        })
    }

    /// Get dimensions for different OpenAI models
    fn get_model_dimensions(model: &str) -> usize {
        match model {
            "text-embedding-ada-002" => 1536,
            "text-embedding-3-small" => 1536,
            "text-embedding-3-large" => 3072,
            "text-embedding-004" => 1536,
            _ => 1536, // Default
        }
    }

    /// Get cost per 1k tokens for different models (in USD)
    fn get_model_cost_per_1k_tokens(model: &str) -> f64 {
        match model {
            "text-embedding-ada-002" => 0.0001,
            "text-embedding-3-small" => 0.00002,
            "text-embedding-3-large" => 0.00013,
            "text-embedding-004" => 0.00002,
            _ => 0.0001, // Conservative default
        }
    }

    /// Calculate cost for processing texts
    fn calculate_cost(&self, texts: &[String]) -> f64 {
        if !self.openai_config.track_costs {
            return 0.0;
        }

        let total_tokens: usize = texts.iter().map(|t| t.len() / 4).sum(); // Rough token estimation
        let cost_per_1k = Self::get_model_cost_per_1k_tokens(&self.openai_config.model);
        (total_tokens as f64 / 1000.0) * cost_per_1k
    }

    /// Check if cached embedding is still valid
    fn is_cache_valid(&self, cached: &CachedEmbedding) -> bool {
        if self.openai_config.cache_ttl_seconds == 0 {
            return true; // No expiration
        }

        let elapsed = cached
            .cached_at
            .elapsed()
            .unwrap_or(std::time::Duration::from_secs(u64::MAX));

        elapsed.as_secs() < self.openai_config.cache_ttl_seconds
    }

    /// Make API request to OpenAI with retry logic
    async fn make_request(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let start_time = std::time::Instant::now();
        let mut attempts = 0;

        while attempts < self.openai_config.max_retries {
            match self.try_request(texts).await {
                Ok(embeddings) => {
                    // Update metrics
                    if self.openai_config.enable_metrics {
                        let response_time = start_time.elapsed().as_millis() as f64;
                        self.update_response_time(response_time);

                        let cost = self.calculate_cost(texts);
                        self.metrics.total_cost_usd += cost;

                        *self
                            .metrics
                            .requests_by_model
                            .entry(self.openai_config.model.clone())
                            .or_insert(0) += 1;
                    }

                    return Ok(embeddings);
                }
                Err(e) => {
                    attempts += 1;
                    self.metrics.retry_count += 1;

                    // Track error types
                    let error_type = if e.to_string().contains("rate_limit") {
                        "rate_limit"
                    } else if e.to_string().contains("timeout") {
                        "timeout"
                    } else if e.to_string().contains("401") {
                        "unauthorized"
                    } else if e.to_string().contains("400") {
                        "bad_request"
                    } else {
                        "other"
                    };

                    *self
                        .metrics
                        .errors_by_type
                        .entry(error_type.to_string())
                        .or_insert(0) += 1;

                    if attempts >= self.openai_config.max_retries {
                        return Err(e);
                    }

                    // Calculate delay based on retry strategy
                    let delay = self.calculate_retry_delay(attempts);
                    tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
                }
            }
        }

        Err(anyhow!("Max retries exceeded"))
    }

    /// Calculate retry delay based on strategy
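    ///
    /// Worked example (illustrative): with `retry_delay_ms = 1000`,
    /// `ExponentialBackoff` yields base delays of 1000 ms, 2000 ms, and
    /// 4000 ms for attempts 1-3 (plus jitter, capped at 30 s).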
    fn calculate_retry_delay(&self, attempt: u32) -> u64 {
        let base_delay = self.openai_config.retry_delay_ms;

        match self.openai_config.retry_strategy {
            RetryStrategy::Fixed => base_delay,
            RetryStrategy::LinearBackoff => base_delay * attempt as u64,
            RetryStrategy::ExponentialBackoff => {
                let delay = base_delay * (2_u64.pow(attempt - 1));
                // Add jitter (up to ±12.5%); keep it signed so negative
                // jitter actually shortens the delay
                let jitter = {
                    #[allow(unused_imports)]
                    use scirs2_core::random::{Random, Rng};
                    let mut rng = Random::seed(42);
                    (delay as f64 * 0.25 * (rng.gen_range(0.0..1.0) - 0.5)) as i64
                };
                delay.saturating_add_signed(jitter).min(30000) // Max 30 seconds
            }
        }
    }

    /// Update response time metrics
    fn update_response_time(&mut self, response_time_ms: f64) {
        if self.metrics.successful_requests == 0 {
            self.metrics.average_response_time_ms = response_time_ms;
        } else {
            // Running average
            let total =
                self.metrics.average_response_time_ms * self.metrics.successful_requests as f64;
            self.metrics.average_response_time_ms =
                (total + response_time_ms) / (self.metrics.successful_requests + 1) as f64;
        }
    }

    async fn try_request(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        self.rate_limiter.wait_if_needed().await;

        let request_body = serde_json::json!({
            "model": self.openai_config.model,
            "input": texts,
            "encoding_format": "float"
        });

        let response = self
            .client
            .post(format!("{}/embeddings", self.openai_config.base_url))
            .header(
                "Authorization",
                format!("Bearer {}", self.openai_config.api_key),
            )
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| anyhow!("Request failed: {}", e))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(anyhow!(
                "API request failed with status {}: {}",
                status,
                error_text
            ));
        }

        let response_data: serde_json::Value = response
            .json()
            .await
            .map_err(|e| anyhow!("Failed to parse response: {}", e))?;

        let embeddings_data = response_data["data"]
            .as_array()
            .ok_or_else(|| anyhow!("Invalid response format: missing data array"))?;

        let mut embeddings = Vec::new();
        for item in embeddings_data {
            let embedding = item["embedding"]
                .as_array()
                .ok_or_else(|| anyhow!("Invalid response format: missing embedding"))?;

            let vec: Result<Vec<f32>, _> = embedding
                .iter()
                .map(|v| {
                    v.as_f64()
                        .ok_or_else(|| anyhow!("Invalid embedding value"))
                        .map(|f| f as f32)
                })
                .collect();

            embeddings.push(vec?);
        }

        Ok(embeddings)
    }

    /// Generate an embedding for a single piece of content (async, cached)
    pub async fn generate_async(&mut self, content: &EmbeddableContent) -> Result<Vector> {
        let text = content.to_text();

        // Check cache first
        if self.openai_config.enable_cache {
            let hash = content.content_hash();

            // Check if cached entry exists and is valid
1495            let cached_vector = match self.request_cache.lock() {
1496                Ok(mut cache) => {
1497                    if let Some(cached) = cache.get(&hash) {
1498                        let is_valid = cached.cached_at.elapsed().unwrap_or_default()
1499                            < Duration::from_secs(self.openai_config.cache_ttl_seconds);
1500                        if is_valid {
1501                            Some(cached.vector.clone())
1502                        } else {
1503                            None
1504                        }
1505                    } else {
1506                        None
1507                    }
1508                }
1509                _ => None,
1510            };
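            // Note: if the system clock moves backwards, `elapsed()` errors and
            // `unwrap_or_default()` yields a zero age, so the entry is treated
            // as still valid.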
1511
1512            if let Some(result) = cached_vector {
1513                self.update_cache_hit();
1514                return Ok(result);
1515            } else {
1516                // Drop any stale entry and record the miss
1517                if let Ok(mut cache) = self.request_cache.lock() {
1518                    cache.pop(&hash);
1519                }
1520                self.update_cache_miss();
1521            }
1522        }
1523
1524        let embeddings = match self.make_request(std::slice::from_ref(&text)).await {
1525            Ok(embeddings) => embeddings,
1526            Err(e) => {
1527                self.update_metrics_failure();
1528                return Err(e);
1529            }
1530        };
1531
1532        // Validate the payload before recording success so one request is
1533        // never counted as both a success and a failure in the metrics.
1534        if embeddings.is_empty() {
1535            self.update_metrics_failure();
1536            return Err(anyhow!("No embeddings returned from API"));
1537        }
1538        self.update_metrics_success(std::slice::from_ref(&text));
1539
1540        let vector = Vector::new(embeddings[0].clone());
1541
1542        // Cache the result
1543        if self.openai_config.enable_cache {
1544            let hash = content.content_hash();
1545            let cost = self.calculate_cost(std::slice::from_ref(&text));
1546            let cached_embedding = CachedEmbedding {
1547                vector: vector.clone(),
1548                cached_at: std::time::SystemTime::now(),
1549                model: self.openai_config.model.clone(),
1550                cost_usd: cost,
1551            };
1552            if let Ok(mut cache) = self.request_cache.lock() {
1553                cache.put(hash, cached_embedding);
1554            }
1555        }
1556
1557        Ok(vector)
1558    }
1559
1560    /// Generate embeddings for multiple texts in batch
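    ///
    /// # Example (sketch; `gen` as above)
    /// ```ignore
    /// let contents = vec![
    ///     EmbeddableContent::Text("first".to_string()),
    ///     EmbeddableContent::Text("second".to_string()),
    /// ];
    /// let vectors = gen.generate_batch_async(&contents).await?;
    /// assert_eq!(vectors.len(), contents.len());
    /// ```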
1561    pub async fn generate_batch_async(
1562        &mut self,
1563        contents: &[EmbeddableContent],
1564    ) -> Result<Vec<Vector>> {
1565        if contents.is_empty() {
1566            return Ok(Vec::new());
1567        }
1568
1569        let mut results = Vec::with_capacity(contents.len());
1570        let batch_size = self.openai_config.batch_size;
1571
1572        for chunk in contents.chunks(batch_size) {
1573            let texts: Vec<String> = chunk.iter().map(|c| c.to_text()).collect();
1574
1575            let embeddings = match self.make_request(&texts).await {
1576                Ok(embeddings) => embeddings,
1577                Err(e) => {
1578                    self.update_metrics_failure();
1579                    return Err(e);
1580                }
1581            };
1582
1583            // Validate before recording success (see generate_async).
1584            if embeddings.len() != chunk.len() {
1585                self.update_metrics_failure();
1586                return Err(anyhow!("Mismatch: sent {} texts, got {} embeddings", chunk.len(), embeddings.len()));
1587            }
1588            self.update_metrics_success(&texts);
1589
1590            // Amortize the batch cost evenly across items for per-entry cache accounting
1591            let batch_cost = self.calculate_cost(&texts) / chunk.len() as f64;
1592
1593            for (content, embedding) in chunk.iter().zip(embeddings) {
1594                let vector = Vector::new(embedding);
1595
1596                // Cache the result
1597                if self.openai_config.enable_cache {
1598                    let hash = content.content_hash();
1599                    let cached_embedding = CachedEmbedding {
1600                        vector: vector.clone(),
1601                        cached_at: std::time::SystemTime::now(),
1602                        model: self.openai_config.model.clone(),
1603                        cost_usd: batch_cost,
1604                    };
1605                    if let Ok(mut cache) = self.request_cache.lock() {
1606                        cache.put(hash, cached_embedding);
1607                    }
1608                }
1609
1610                results.push(vector);
1611            }
1612        }
1613
1614        Ok(results)
1615    }
1616
1617    /// Clear the request cache
1618    pub fn clear_cache(&mut self) {
1619        if let Ok(mut cache) = self.request_cache.lock() {
1620            cache.clear();
1621        }
1622    }
1623
1624    /// Get cache statistics
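    ///
    /// Returns `(current_len, Some(capacity))`, or `(0, None)` if the cache
    /// mutex is poisoned.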
1625    pub fn cache_stats(&self) -> (usize, Option<usize>) {
1626        match self.request_cache.lock() {
1627            Ok(cache) => (cache.len(), Some(cache.cap().into())),
1628            _ => (0, None),
1629        }
1630    }
1631
1632    /// Get total cache cost
1633    pub fn get_cache_cost(&self) -> f64 {
1634        match self.request_cache.lock() {
1635            Ok(cache) => cache.iter().map(|(_, cached)| cached.cost_usd).sum(),
1636            _ => 0.0,
1637        }
1638    }
1639
1640    /// Get API usage metrics
1641    pub fn get_metrics(&self) -> &OpenAIMetrics {
1642        &self.metrics
1643    }
1644
1645    /// Reset metrics
1646    pub fn reset_metrics(&mut self) {
1647        self.metrics = OpenAIMetrics::default();
1648    }
1649
1650    /// Estimate token count for text (approximate)
1651    fn estimate_tokens(&self, text: &str) -> u64 {
1652        // Rough estimation: ~4 characters per token on average
1653        // This is an approximation - actual tokenization depends on the model
1654        (text.len() / 4).max(1) as u64
1655    }
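    // estimate_tokens example: a 400-character input is estimated at
    // 100 tokens; `.max(1)` ensures even an empty string counts as one token.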
1656
1657    /// Calculate cost for embeddings request
1658    fn calculate_cost_from_tokens(&self, total_tokens: u64) -> f64 {
1659        // OpenAI pricing (as of 2024) - these should be configurable
1660        let cost_per_1k_tokens = match self.openai_config.model.as_str() {
1661            "text-embedding-ada-002" => 0.0001,  // $0.0001 per 1K tokens
1662            "text-embedding-3-small" => 0.00002, // $0.00002 per 1K tokens
1663            "text-embedding-3-large" => 0.00013, // $0.00013 per 1K tokens
1664            _ => 0.0001,                         // Default to ada-002 pricing
1665        };
1666
1667        (total_tokens as f64 / 1000.0) * cost_per_1k_tokens
1668    }
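    // Worked example: 1,000,000 tokens with "text-embedding-3-small" costs
    // (1_000_000 / 1000) * 0.00002 = $0.02.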
1669
1670    /// Update metrics after successful request
1671    fn update_metrics_success(&mut self, texts: &[String]) {
1672        self.metrics.total_requests += 1;
1673        self.metrics.successful_requests += 1;
1674
1675        let total_tokens: u64 = texts.iter().map(|text| self.estimate_tokens(text)).sum();
1676
1677        self.metrics.total_tokens_processed += total_tokens;
1678        self.metrics.total_cost_usd += self.calculate_cost_from_tokens(total_tokens);
1679    }
1680
1681    /// Update metrics after failed request
1682    fn update_metrics_failure(&mut self) {
1683        self.metrics.total_requests += 1;
1684        self.metrics.failed_requests += 1;
1685    }
1686
1687    /// Update cache metrics
1688    fn update_cache_hit(&mut self) {
1689        self.metrics.cache_hits += 1;
1690    }
1691
1692    fn update_cache_miss(&mut self) {
1693        self.metrics.cache_misses += 1;
1694    }
1695}
1696
1697impl EmbeddingGenerator for OpenAIEmbeddingGenerator {
1698    fn generate(&self, content: &EmbeddableContent) -> Result<Vector> {
1699        // Check cache first; the LRU's get() updates recency, so a mutable guard is required
1700        if self.openai_config.enable_cache {
1701            let hash = content.content_hash();
1702            if let Ok(mut cache) = self.request_cache.lock() {
1703                if let Some(cached) = cache.get(&hash) {
1704                    return Ok(cached.vector.clone());
1705                }
1706            }
1707        }
1708
1709        // For a synchronous interface over an async HTTP client we need an
1710        // ad-hoc runtime; this is a workaround for the trait design limitation.
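        // NOTE: `block_on` panics if invoked from inside an existing Tokio
        // runtime; prefer calling `generate_async` directly in async contexts.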
1711        let rt = tokio::runtime::Runtime::new()
1712            .map_err(|e| anyhow!("Failed to create async runtime: {}", e))?;
1713
1714        // Create a temporary mutable copy; the Arc-backed cache handle is shared, but metrics recorded on the copy are not written back to self
1715        let mut temp_generator = OpenAIEmbeddingGenerator {
1716            config: self.config.clone(),
1717            openai_config: self.openai_config.clone(),
1718            client: self.client.clone(),
1719            rate_limiter: RateLimiter::new(self.openai_config.requests_per_minute),
1720            request_cache: self.request_cache.clone(),
1721            metrics: self.metrics.clone(),
1722        };
1723
1724        rt.block_on(temp_generator.generate_async(content))
1725    }
1726
1727    fn dimensions(&self) -> usize {
1728        self.config.dimensions
1729    }
1730
1731    fn config(&self) -> &EmbeddingConfig {
1732        &self.config
1733    }
1734}
1735
1736impl AsAny for OpenAIEmbeddingGenerator {
1737    fn as_any(&self) -> &dyn std::any::Any {
1738        self
1739    }
1740
1741    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
1742        self
1743    }
1744}
1745
1746/// Mock embedding generator for testing
1747#[cfg(test)]
1748pub struct MockEmbeddingGenerator {
1749    config: EmbeddingConfig,
1750}
1751
1752#[cfg(test)]
1753impl Default for MockEmbeddingGenerator {
1754    fn default() -> Self {
1755        Self::new()
1756    }
1757}
1758
1759#[cfg(test)]
1760impl MockEmbeddingGenerator {
1761    pub fn new() -> Self {
1762        Self {
1763            config: EmbeddingConfig {
1764                dimensions: 128,
1765                ..Default::default()
1766            },
1767        }
1768    }
1769
1770    pub fn with_dimensions(dimensions: usize) -> Self {
1771        Self {
1772            config: EmbeddingConfig {
1773                dimensions,
1774                ..Default::default()
1775            },
1776        }
1777    }
1778}
1779
1780#[cfg(test)]
1781impl AsAny for MockEmbeddingGenerator {
1782    fn as_any(&self) -> &dyn std::any::Any {
1783        self
1784    }
1785
1786    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
1787        self
1788    }
1789}
1790
1791#[cfg(test)]
1792impl EmbeddingGenerator for MockEmbeddingGenerator {
1793    fn generate(&self, content: &EmbeddableContent) -> Result<crate::Vector> {
1794        let text = content.to_text();
1795
1796        // Generate deterministic mock embedding based on content hash
1797        let mut hasher = std::collections::hash_map::DefaultHasher::new();
1798        text.hash(&mut hasher);
1799        let hash = hasher.finish();
1800
1801        let mut embedding = Vec::with_capacity(self.config.dimensions);
1802        let mut seed = hash;
1803
1804        for _ in 0..self.config.dimensions {
1805            // Simple LCG (Numerical Recipes constants) for deterministic values
1806            seed = seed.wrapping_mul(1664525).wrapping_add(1013904223);
1807            let value = (seed as f64 / u64::MAX as f64) as f32;
1808            embedding.push(value * 2.0 - 1.0); // Range [-1, 1]
1809        }
1810
1811        // Normalize to unit vector
1812        let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
1813        if magnitude > 0.0 {
1814            for value in &mut embedding {
1815                *value /= magnitude;
1816            }
1817        }
1818
1819        Ok(crate::Vector::new(embedding))
1820    }
1821
1822    fn dimensions(&self) -> usize {
1823        self.config.dimensions
1824    }
1825
1826    fn config(&self) -> &EmbeddingConfig {
1827        &self.config
1828    }
1829}
1830
1831#[cfg(test)]
1832mod tests {
1833    use super::*;
1834
1835    #[test]
1836    fn test_transformer_model_types() {
1837        let config = EmbeddingConfig::default();
1838
1839        // Test BERT
1840        let bert = SentenceTransformerGenerator::new(config.clone());
1841        assert!(matches!(bert.model_type(), TransformerModelType::BERT));
1842        assert_eq!(bert.dimensions(), 384); // Default config dimensions
1843
1844        // Test RoBERTa
1845        let roberta = SentenceTransformerGenerator::roberta(config.clone());
1846        assert!(matches!(
1847            roberta.model_type(),
1848            TransformerModelType::RoBERTa
1849        ));
1850
1851        // Test DistilBERT
1852        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
1853        assert!(matches!(
1854            distilbert.model_type(),
1855            TransformerModelType::DistilBERT
1856        ));
1857        assert_eq!(distilbert.dimensions(), 384); // Output follows the shared config dimensions
1858
1859        // Test Multilingual BERT
1860        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());
1861        assert!(matches!(
1862            multibert.model_type(),
1863            TransformerModelType::MultiBERT
1864        ));
1865    }
1866
1867    #[test]
1868    fn test_model_details() {
1869        let config = EmbeddingConfig::default();
1870
1871        // Test BERT details
1872        let bert = SentenceTransformerGenerator::new(config.clone());
1873        let bert_details = bert.model_details();
1874        assert_eq!(bert_details.vocab_size, 30522);
1875        assert_eq!(bert_details.num_layers, 12);
1876        assert_eq!(bert_details.hidden_size, 768);
1877        assert!(bert_details.supports_languages.contains(&"en".to_string()));
1878
1879        // Test RoBERTa details
1880        let roberta = SentenceTransformerGenerator::roberta(config.clone());
1881        let roberta_details = roberta.model_details();
1882        assert_eq!(roberta_details.vocab_size, 50265); // Larger vocab than BERT
1883        assert_eq!(roberta_details.max_position_embeddings, 514); // 514 = 512 usable positions + 2 reserved offsets
1884
1885        // Test DistilBERT details
1886        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
1887        let distilbert_details = distilbert.model_details();
1888        assert_eq!(distilbert_details.num_layers, 6); // Half the layers of BERT
1889        assert_eq!(distilbert_details.hidden_size, 384); // Smaller hidden size
1890        assert!(distilbert_details.model_size_mb < bert_details.model_size_mb); // Smaller model
1891        assert!(
1892            distilbert_details.typical_inference_time_ms < bert_details.typical_inference_time_ms
1893        ); // Faster
1894
1895        // Test Multilingual BERT details
1896        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());
1897        let multibert_details = multibert.model_details();
1898        assert_eq!(multibert_details.vocab_size, 120000); // Much larger vocabulary
1899        assert!(multibert_details.supports_languages.len() > 10); // Supports many languages
1900        assert!(multibert_details
1901            .supports_languages
1902            .contains(&"zh".to_string())); // Chinese
1903        assert!(multibert_details
1904            .supports_languages
1905            .contains(&"de".to_string())); // German
1906    }
1907
1908    #[test]
1909    fn test_language_support() {
1910        let config = EmbeddingConfig::default();
1911
1912        // BERT and DistilBERT only support English
1913        let bert = SentenceTransformerGenerator::new(config.clone());
1914        assert!(bert.supports_language("en"));
1915        assert!(!bert.supports_language("zh"));
1916        assert!(!bert.supports_language("de"));
1917
1918        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
1919        assert!(distilbert.supports_language("en"));
1920        assert!(!distilbert.supports_language("zh"));
1921
1922        // Multilingual BERT supports many languages
1923        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());
1924        assert!(multibert.supports_language("en"));
1925        assert!(multibert.supports_language("zh"));
1926        assert!(multibert.supports_language("de"));
1927        assert!(multibert.supports_language("fr"));
1928        assert!(multibert.supports_language("es"));
1929        assert!(!multibert.supports_language("unknown_lang"));
1930    }
1931
1932    #[test]
1933    fn test_efficiency_ratings() {
1934        let config = EmbeddingConfig::default();
1935
1936        let bert = SentenceTransformerGenerator::new(config.clone());
1937        let roberta = SentenceTransformerGenerator::roberta(config.clone());
1938        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
1939        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());
1940
1941        // DistilBERT should be the most efficient
1942        assert!(distilbert.efficiency_rating() > bert.efficiency_rating());
1943        assert!(distilbert.efficiency_rating() > roberta.efficiency_rating());
1944        assert!(distilbert.efficiency_rating() > multibert.efficiency_rating());
1945
1946        // RoBERTa should be slightly less efficient than BERT
1947        assert!(bert.efficiency_rating() > roberta.efficiency_rating());
1948
1949        // Multilingual BERT should be the least efficient
1950        assert!(bert.efficiency_rating() > multibert.efficiency_rating());
1951        assert!(roberta.efficiency_rating() > multibert.efficiency_rating());
1952    }
1953
1954    #[test]
1955    fn test_inference_time_estimation() {
1956        let config = EmbeddingConfig::default();
1957        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
1958        let bert = SentenceTransformerGenerator::new(config.clone());
1959
1960        // Short text
1961        let short_time_distilbert = distilbert.estimate_inference_time(50);
1962        let short_time_bert = bert.estimate_inference_time(50);
1963
1964        // Long text
1965        let long_time_distilbert = distilbert.estimate_inference_time(500);
1966        let long_time_bert = bert.estimate_inference_time(500);
1967
1968        // DistilBERT should be faster for both short and long texts
1969        assert!(short_time_distilbert < short_time_bert);
1970        assert!(long_time_distilbert < long_time_bert);
1971
1972        // Longer texts should take more time
1973        assert!(long_time_distilbert > short_time_distilbert);
1974        assert!(long_time_bert > short_time_bert);
1975    }
1976
1977    #[test]
1978    fn test_model_specific_text_preprocessing() {
1979        let config = EmbeddingConfig::default();
1980
1981        let bert = SentenceTransformerGenerator::new(config.clone());
1982        let roberta = SentenceTransformerGenerator::roberta(config.clone());
1983        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());
1984
1985        let text = "Hello World";
1986
1987        // BERT should use [CLS] and [SEP] tokens and lowercase
1988        let bert_processed = bert.preprocess_text_for_model(text, 512).unwrap();
1989        assert!(bert_processed.contains("[CLS]"));
1990        assert!(bert_processed.contains("[SEP]"));
1991        assert!(bert_processed.contains("hello world")); // Should be lowercase
1992
1993        // RoBERTa should use <s> and </s> tokens and preserve case
1994        let roberta_processed = roberta.preprocess_text_for_model(text, 512).unwrap();
1995        assert!(roberta_processed.contains("<s>"));
1996        assert!(roberta_processed.contains("</s>"));
1997        assert!(roberta_processed.contains("Hello World")); // Should preserve case
1998
1999        // Multilingual BERT should handle different scripts appropriately
2000        let latin_text = "Hello World";
2001        let chinese_text = "你好世界";
2002
2003        let latin_processed = multibert
2004            .preprocess_text_for_model(latin_text, 512)
2005            .unwrap();
2006        let chinese_processed = multibert
2007            .preprocess_text_for_model(chinese_text, 512)
2008            .unwrap();
2009
2010        assert!(latin_processed.contains("hello world")); // Latin should be lowercase
2011        assert!(chinese_processed.contains("你好世界")); // Chinese should preserve characters
2012    }
2013
2014    #[test]
2015    fn test_embedding_generation_differences() {
2016        let config = EmbeddingConfig::default();
2017
2018        let bert = SentenceTransformerGenerator::new(config.clone());
2019        let roberta = SentenceTransformerGenerator::roberta(config.clone());
2020        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
2021
2022        let content = EmbeddableContent::Text("This is a test sentence".to_string());
2023
2024        let bert_embedding = bert.generate(&content).unwrap();
2025        let roberta_embedding = roberta.generate(&content).unwrap();
2026        let distilbert_embedding = distilbert.generate(&content).unwrap();
2027
2028        // Embeddings should be different between models
2029        assert_ne!(bert_embedding.as_f32(), roberta_embedding.as_f32());
2030        assert_ne!(bert_embedding.as_f32(), distilbert_embedding.as_f32());
2031        assert_ne!(roberta_embedding.as_f32(), distilbert_embedding.as_f32());
2032
2033        // Every generator emits vectors at the configured dimensionality
2034        assert_eq!(distilbert_embedding.dimensions, 384);
2035        assert_eq!(bert_embedding.dimensions, 384); // Using default config dimensions
2036        assert_eq!(roberta_embedding.dimensions, 384);
2037
2038        // All embeddings should be normalized if config specifies it
2039        if config.normalize {
2040            let bert_magnitude: f32 = bert_embedding
2041                .as_f32()
2042                .iter()
2043                .map(|x| x * x)
2044                .sum::<f32>()
2045                .sqrt();
2046            let roberta_magnitude: f32 = roberta_embedding
2047                .as_f32()
2048                .iter()
2049                .map(|x| x * x)
2050                .sum::<f32>()
2051                .sqrt();
2052            let distilbert_magnitude: f32 = distilbert_embedding
2053                .as_f32()
2054                .iter()
2055                .map(|x| x * x)
2056                .sum::<f32>()
2057                .sqrt();
2058
2059            assert!((bert_magnitude - 1.0).abs() < 0.1);
2060            assert!((roberta_magnitude - 1.0).abs() < 0.1);
2061            assert!((distilbert_magnitude - 1.0).abs() < 0.1);
2062        }
2063    }
2064
2065    #[test]
2066    fn test_tokenization_differences() {
2067        let config = EmbeddingConfig::default();
2068
2069        let bert = SentenceTransformerGenerator::new(config.clone());
2070        let roberta = SentenceTransformerGenerator::roberta(config.clone());
2071        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());
2072
2073        let model_details_bert = bert.model_details();
2074        let model_details_roberta = roberta.model_details();
2075        let model_details_multibert = multibert.model_details();
2076
2077        let complex_word = "preprocessing";
2078
2079        // Test different tokenization approaches
2080        let bert_tokens =
2081            bert.simulate_wordpiece_tokenization(complex_word, model_details_bert.vocab_size);
2082        let roberta_tokens =
2083            roberta.simulate_bpe_tokenization(complex_word, model_details_roberta.vocab_size);
2084        let multibert_tokens = multibert
2085            .simulate_multilingual_tokenization(complex_word, model_details_multibert.vocab_size);
2086
2087        // RoBERTa's BPE should produce at least as many subword tokens for complex words
2088        assert!(roberta_tokens.len() >= bert_tokens.len());
2089
2090        // Multilingual BERT should produce no more subwords than BERT, given its larger vocabulary
2091        assert!(multibert_tokens.len() <= bert_tokens.len());
2092
2093        // All tokenizations should produce valid token IDs
2094        for token in &bert_tokens {
2095            assert!(*token < model_details_bert.vocab_size as u32);
2096        }
2097        for token in &roberta_tokens {
2098            assert!(*token < model_details_roberta.vocab_size as u32);
2099        }
2100        for token in &multibert_tokens {
2101            assert!(*token < model_details_multibert.vocab_size as u32);
2102        }
2103    }
2104
2105    #[test]
2106    fn test_model_size_comparisons() {
2107        let config = EmbeddingConfig::default();
2108
2109        let bert = SentenceTransformerGenerator::new(config.clone());
2110        let roberta = SentenceTransformerGenerator::roberta(config.clone());
2111        let distilbert = SentenceTransformerGenerator::distilbert(config.clone());
2112        let multibert = SentenceTransformerGenerator::multilingual_bert(config.clone());
2113
2114        let bert_size = bert.model_size_mb();
2115        let roberta_size = roberta.model_size_mb();
2116        let distilbert_size = distilbert.model_size_mb();
2117        let multibert_size = multibert.model_size_mb();
2118
2119        // DistilBERT should be the smallest
2120        assert!(distilbert_size < bert_size);
2121        assert!(distilbert_size < roberta_size);
2122        assert!(distilbert_size < multibert_size);
2123
2124        // Multilingual BERT should be the largest
2125        assert!(multibert_size > bert_size);
2126        assert!(multibert_size > roberta_size);
2127        assert!(multibert_size > distilbert_size);
2128
2129        // RoBERTa should be slightly larger than BERT due to larger vocabulary
2130        assert!(roberta_size > bert_size);
2131    }
2132}