oxirs_embed/biomedical_embeddings/
text_model.rs

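//! Specialized domain text embedding models (SciBERT, CodeBERT, BioBERT, LegalBERT, FinBERT, ClinicalBERT, ChemBERT) for the biomedical embeddings module.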
2
3use crate::{ModelConfig, ModelStats, TrainingStats};
4use anyhow::Result;
5use chrono::Utc;
6use scirs2_core::ndarray_ext::Array1;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9use uuid::Uuid;
10
/// Specialized text embedding models for domain-specific applications.
///
/// Per-model metadata (checkpoint name, vocabulary size, embedding width, maximum sequence
/// length and default preprocessing rules) is provided by the methods on this type.
///
/// NOTE(review): keep the variant order stable — serde's derived impls identify variants by
/// index in non-self-describing formats, so reordering would break previously serialized data.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum SpecializedTextModel {
    /// SciBERT for scientific literature
    SciBERT,
    /// CodeBERT for code and programming
    CodeBERT,
    /// BioBERT for biomedical literature
    BioBERT,
    /// LegalBERT for legal documents
    LegalBERT,
    /// FinBERT for financial texts
    FinBERT,
    /// ClinicalBERT for clinical notes
    ClinicalBERT,
    /// ChemBERT for chemical compounds
    ChemBERT,
}
29
30impl SpecializedTextModel {
31    /// Get the model name for loading pre-trained weights
32    pub fn model_name(&self) -> &'static str {
33        match self {
34            SpecializedTextModel::SciBERT => "allenai/scibert_scivocab_uncased",
35            SpecializedTextModel::CodeBERT => "microsoft/codebert-base",
36            SpecializedTextModel::BioBERT => "dmis-lab/biobert-base-cased-v1.2",
37            SpecializedTextModel::LegalBERT => "nlpaueb/legal-bert-base-uncased",
38            SpecializedTextModel::FinBERT => "ProsusAI/finbert",
39            SpecializedTextModel::ClinicalBERT => "emilyalsentzer/Bio_ClinicalBERT",
40            SpecializedTextModel::ChemBERT => "seyonec/ChemBERTa-zinc-base-v1",
41        }
42    }
43
44    /// Get the vocabulary size for the model
45    pub fn vocab_size(&self) -> usize {
46        match self {
47            SpecializedTextModel::SciBERT => 31090,
48            SpecializedTextModel::CodeBERT => 50265,
49            SpecializedTextModel::BioBERT => 28996,
50            SpecializedTextModel::LegalBERT => 30522,
51            SpecializedTextModel::FinBERT => 30522,
52            SpecializedTextModel::ClinicalBERT => 28996,
53            SpecializedTextModel::ChemBERT => 600,
54        }
55    }
56
57    /// Get the default embedding dimension
58    pub fn embedding_dim(&self) -> usize {
59        match self {
60            SpecializedTextModel::SciBERT => 768,
61            SpecializedTextModel::CodeBERT => 768,
62            SpecializedTextModel::BioBERT => 768,
63            SpecializedTextModel::LegalBERT => 768,
64            SpecializedTextModel::FinBERT => 768,
65            SpecializedTextModel::ClinicalBERT => 768,
66            SpecializedTextModel::ChemBERT => 384,
67        }
68    }
69
70    /// Get the maximum sequence length
71    pub fn max_sequence_length(&self) -> usize {
72        match self {
73            SpecializedTextModel::SciBERT => 512,
74            SpecializedTextModel::CodeBERT => 512,
75            SpecializedTextModel::BioBERT => 512,
76            SpecializedTextModel::LegalBERT => 512,
77            SpecializedTextModel::FinBERT => 512,
78            SpecializedTextModel::ClinicalBERT => 512,
79            SpecializedTextModel::ChemBERT => 512,
80        }
81    }
82
83    /// Get domain-specific preprocessing rules
84    pub fn get_preprocessing_rules(&self) -> Vec<PreprocessingRule> {
85        match self {
86            SpecializedTextModel::SciBERT => vec![
87                PreprocessingRule::NormalizeScientificNotation,
88                PreprocessingRule::ExpandAbbreviations,
89                PreprocessingRule::HandleChemicalFormulas,
90                PreprocessingRule::PreserveCitations,
91            ],
92            SpecializedTextModel::CodeBERT => vec![
93                PreprocessingRule::PreserveCodeTokens,
94                PreprocessingRule::HandleCamelCase,
95                PreprocessingRule::NormalizeWhitespace,
96                PreprocessingRule::PreservePunctuation,
97            ],
98            SpecializedTextModel::BioBERT => vec![
99                PreprocessingRule::NormalizeMedicalTerms,
100                PreprocessingRule::HandleGeneNames,
101                PreprocessingRule::ExpandMedicalAbbreviations,
102                PreprocessingRule::PreserveDosages,
103            ],
104            SpecializedTextModel::LegalBERT => vec![
105                PreprocessingRule::PreserveLegalCitations,
106                PreprocessingRule::HandleLegalTerms,
107                PreprocessingRule::NormalizeCaseReferences,
108            ],
109            SpecializedTextModel::FinBERT => vec![
110                PreprocessingRule::NormalizeFinancialTerms,
111                PreprocessingRule::HandleCurrencySymbols,
112                PreprocessingRule::PreservePercentages,
113            ],
114            SpecializedTextModel::ClinicalBERT => vec![
115                PreprocessingRule::NormalizeMedicalTerms,
116                PreprocessingRule::HandleMedicalAbbreviations,
117                PreprocessingRule::PreserveDosages,
118                PreprocessingRule::NormalizeTimestamps,
119            ],
120            SpecializedTextModel::ChemBERT => vec![
121                PreprocessingRule::HandleChemicalFormulas,
122                PreprocessingRule::PreserveMolecularStructures,
123                PreprocessingRule::NormalizeChemicalNames,
124            ],
125        }
126    }
127}
128
/// Preprocessing rules for specialized text models.
///
/// Rules are applied in order by `SpecializedTextEmbedding::preprocess_text`. Variants marked
/// "currently a no-op" are accepted but leave the text unchanged in this file's
/// implementation; the implemented ones are simplified (fixed string replacements).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PreprocessingRule {
    /// Normalize scientific notation (e.g., 1.23e-4)
    NormalizeScientificNotation,
    /// Expand domain-specific abbreviations (currently a no-op)
    ExpandAbbreviations,
    /// Handle chemical formulas and compounds
    HandleChemicalFormulas,
    /// Preserve citation formats (currently a no-op)
    PreserveCitations,
    /// Preserve code tokens and keywords
    PreserveCodeTokens,
    /// Split camelCase identifiers into separate words
    /// (snake_case is currently left unchanged)
    HandleCamelCase,
    /// Normalize whitespace patterns (currently a no-op)
    NormalizeWhitespace,
    /// Preserve punctuation in code (currently a no-op)
    PreservePunctuation,
    /// Normalize medical terminology; in practice this expands a few dosing abbreviations
    /// (mg/kg, q.d., b.i.d., t.i.d., q.i.d.)
    NormalizeMedicalTerms,
    /// Handle gene and protein names
    HandleGeneNames,
    /// Expand medical abbreviations (currently a no-op; see `NormalizeMedicalTerms`)
    ExpandMedicalAbbreviations,
    /// Preserve dosage information (currently a no-op)
    PreserveDosages,
    /// Preserve legal citations (currently a no-op)
    PreserveLegalCitations,
    /// Handle legal terminology (currently a no-op)
    HandleLegalTerms,
    /// Normalize case references (currently a no-op)
    NormalizeCaseReferences,
    /// Normalize financial terms (currently a no-op)
    NormalizeFinancialTerms,
    /// Handle currency symbols (currently a no-op)
    HandleCurrencySymbols,
    /// Preserve percentage values (currently a no-op)
    PreservePercentages,
    /// Handle medical abbreviations (currently a no-op)
    HandleMedicalAbbreviations,
    /// Normalize timestamps (currently a no-op)
    NormalizeTimestamps,
    /// Preserve molecular structures (currently a no-op)
    PreserveMolecularStructures,
    /// Normalize chemical names (currently a no-op)
    NormalizeChemicalNames,
}
177
/// Configuration for specialized text embeddings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpecializedTextConfig {
    /// Which domain model to emulate; also determines the preprocessing rules chosen by
    /// `SpecializedTextEmbedding::new`.
    pub model_type: SpecializedTextModel,
    /// Crate-wide generic embedding model configuration.
    pub base_config: ModelConfig,
    /// Fine-tuning configuration
    pub fine_tune_config: FineTuningConfig,
    /// Whether the domain preprocessing rules are applied; when `false`,
    /// `preprocess_text` returns its input unchanged.
    pub preprocessing_enabled: bool,
    /// Domain-specific vocabulary augmentation
    /// (NOTE(review): not read anywhere in this file)
    pub vocab_augmentation: bool,
    /// Use domain-specific pre-training; in this file's simulated encoder, `true` scales
    /// every embedding component by 1.2.
    pub domain_pretraining: bool,
}
192
/// Fine-tuning configuration for specialized models.
///
/// NOTE(review): within this file only `epochs` is read (by
/// `SpecializedTextEmbedding::fine_tune`); the other fields are currently informational.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FineTuningConfig {
    /// Learning rate for fine-tuning
    pub learning_rate: f64,
    /// Number of fine-tuning epochs
    pub epochs: usize,
    /// Freeze base model layers
    pub freeze_base_layers: bool,
    /// Number of layers to freeze (presumably only meaningful when `freeze_base_layers`
    /// is `true` — TODO confirm)
    pub frozen_layers: usize,
    /// Use gradual unfreezing
    pub gradual_unfreezing: bool,
    /// Discriminative fine-tuning rates (presumably one learning rate per layer group —
    /// TODO confirm; empty by default)
    pub discriminative_rates: Vec<f64>,
}
209
210impl Default for FineTuningConfig {
211    fn default() -> Self {
212        Self {
213            learning_rate: 2e-5,
214            epochs: 3,
215            freeze_base_layers: false,
216            frozen_layers: 0,
217            gradual_unfreezing: false,
218            discriminative_rates: vec![],
219        }
220    }
221}
222
223impl Default for SpecializedTextConfig {
224    fn default() -> Self {
225        Self {
226            model_type: SpecializedTextModel::BioBERT,
227            base_config: ModelConfig::default(),
228            fine_tune_config: FineTuningConfig::default(),
229            preprocessing_enabled: true,
230            vocab_augmentation: false,
231            domain_pretraining: false,
232        }
233    }
234}
235
/// Specialized text embedding processor.
///
/// Holds the configuration, the preprocessing pipeline, a cache of embeddings keyed by
/// preprocessed text, and training/model statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpecializedTextEmbedding {
    /// Configuration this model was created with.
    pub config: SpecializedTextConfig,
    /// Model identifier; `new` assigns a random (v4) UUID.
    pub model_id: Uuid,
    /// Text embeddings cache, keyed by the *preprocessed* text.
    /// NOTE(review): in this file entries are only removed by `clear_cache`, so the cache
    /// grows without bound.
    pub text_embeddings: HashMap<String, Array1<f32>>,
    /// Domain-specific vocabulary (NOTE(review): not populated anywhere in this file)
    pub domain_vocab: HashSet<String>,
    /// Preprocessing pipeline, applied in order
    pub preprocessing_rules: Vec<PreprocessingRule>,
    /// Training statistics
    pub training_stats: TrainingStats,
    /// Model statistics
    pub model_stats: ModelStats,
    /// Whether fine-tuning has completed; kept in sync with `model_stats.is_trained`.
    pub is_trained: bool,
}
253
254impl SpecializedTextEmbedding {
255    /// Create new specialized text embedding model
256    pub fn new(config: SpecializedTextConfig) -> Self {
257        let model_id = Uuid::new_v4();
258        let now = Utc::now();
259        let preprocessing_rules = config.model_type.get_preprocessing_rules();
260
261        Self {
262            model_id,
263            text_embeddings: HashMap::new(),
264            domain_vocab: HashSet::new(),
265            preprocessing_rules,
266            training_stats: TrainingStats::default(),
267            model_stats: ModelStats {
268                num_entities: 0,
269                num_relations: 0,
270                num_triples: 0,
271                dimensions: config.model_type.embedding_dim(),
272                is_trained: false,
273                model_type: format!("SpecializedText_{:?}", config.model_type),
274                creation_time: now,
275                last_training_time: None,
276            },
277            is_trained: false,
278            config,
279        }
280    }
281
282    /// Create SciBERT configuration
283    pub fn scibert_config() -> SpecializedTextConfig {
284        SpecializedTextConfig {
285            model_type: SpecializedTextModel::SciBERT,
286            base_config: ModelConfig::default().with_dimensions(768),
287            fine_tune_config: FineTuningConfig::default(),
288            preprocessing_enabled: true,
289            vocab_augmentation: true,
290            domain_pretraining: true,
291        }
292    }
293
294    /// Create CodeBERT configuration
295    pub fn codebert_config() -> SpecializedTextConfig {
296        SpecializedTextConfig {
297            model_type: SpecializedTextModel::CodeBERT,
298            base_config: ModelConfig::default().with_dimensions(768),
299            fine_tune_config: FineTuningConfig::default(),
300            preprocessing_enabled: true,
301            vocab_augmentation: false,
302            domain_pretraining: true,
303        }
304    }
305
306    /// Create BioBERT configuration
307    pub fn biobert_config() -> SpecializedTextConfig {
308        SpecializedTextConfig {
309            model_type: SpecializedTextModel::BioBERT,
310            base_config: ModelConfig::default().with_dimensions(768),
311            fine_tune_config: FineTuningConfig {
312                learning_rate: 1e-5,
313                epochs: 5,
314                freeze_base_layers: true,
315                frozen_layers: 6,
316                gradual_unfreezing: true,
317                discriminative_rates: vec![1e-6, 5e-6, 1e-5, 2e-5],
318            },
319            preprocessing_enabled: true,
320            vocab_augmentation: true,
321            domain_pretraining: true,
322        }
323    }
324
325    /// Preprocess text according to domain-specific rules
326    pub fn preprocess_text(&self, text: &str) -> Result<String> {
327        if !self.config.preprocessing_enabled {
328            return Ok(text.to_string());
329        }
330
331        let mut processed = text.to_string();
332
333        for rule in &self.preprocessing_rules {
334            processed = self.apply_preprocessing_rule(&processed, rule)?;
335        }
336
337        Ok(processed)
338    }
339
340    /// Apply a specific preprocessing rule
341    fn apply_preprocessing_rule(&self, text: &str, rule: &PreprocessingRule) -> Result<String> {
342        match rule {
343            PreprocessingRule::NormalizeScientificNotation => {
344                // Convert scientific notation to normalized form (simplified)
345                Ok(text
346                    .replace("E+", "e+")
347                    .replace("E-", "e-")
348                    .replace("E", "e"))
349            }
350            PreprocessingRule::HandleChemicalFormulas => {
351                // Preserve chemical formulas by adding special tokens (simplified)
352                Ok(text.replace("H2O", "[CHEM]H2O[/CHEM]"))
353            }
354            PreprocessingRule::HandleCamelCase => {
355                // Split camelCase into separate tokens (simplified)
356                let mut result = String::new();
357                let mut chars = text.chars().peekable();
358                while let Some(c) = chars.next() {
359                    result.push(c);
360                    if c.is_lowercase() && chars.peek().is_some_and(|&next| next.is_uppercase()) {
361                        result.push(' ');
362                    }
363                }
364                Ok(result)
365            }
366            PreprocessingRule::NormalizeMedicalTerms => {
367                // Normalize common medical abbreviations
368                let mut result = text.to_string();
369                let replacements = [
370                    ("mg/kg", "milligrams per kilogram"),
371                    ("q.d.", "once daily"),
372                    ("b.i.d.", "twice daily"),
373                    ("t.i.d.", "three times daily"),
374                    ("q.i.d.", "four times daily"),
375                ];
376
377                for (abbrev, expansion) in &replacements {
378                    result = result.replace(abbrev, expansion);
379                }
380                Ok(result)
381            }
382            PreprocessingRule::HandleGeneNames => {
383                // Standardize gene name formatting (simplified)
384                Ok(text
385                    .replace("BRCA1", "[GENE]BRCA1[/GENE]")
386                    .replace("TP53", "[GENE]TP53[/GENE]"))
387            }
388            PreprocessingRule::PreserveCodeTokens => {
389                // Preserve code-like tokens (simplified)
390                Ok(text.replace("function", "[CODE]function[/CODE]"))
391            }
392            _ => {
393                // Placeholder for other rules - would implement in production
394                Ok(text.to_string())
395            }
396        }
397    }
398
399    /// Generate embedding for text using specialized model
400    pub async fn encode_text(&mut self, text: &str) -> Result<Array1<f32>> {
401        // Preprocess the text
402        let processed_text = self.preprocess_text(text)?;
403
404        // Check cache first
405        if let Some(cached_embedding) = self.text_embeddings.get(&processed_text) {
406            return Ok(cached_embedding.clone());
407        }
408
409        // Generate embedding using domain-specific model
410        let embedding = self.generate_specialized_embedding(&processed_text).await?;
411
412        // Cache the result
413        self.text_embeddings
414            .insert(processed_text, embedding.clone());
415
416        Ok(embedding)
417    }
418
419    /// Generate specialized embedding for the specific domain
420    async fn generate_specialized_embedding(&self, text: &str) -> Result<Array1<f32>> {
421        // In a real implementation, this would use the actual pre-trained model
422        // For now, simulate domain-specific embeddings with enhanced features
423
424        let embedding_dim = self.config.model_type.embedding_dim();
425        let mut embedding = vec![0.0; embedding_dim];
426
427        // Domain-specific feature extraction
428        match self.config.model_type {
429            SpecializedTextModel::SciBERT => {
430                // Scientific text features: citations, formulas, terminology
431                embedding[0] = if text.contains("et al.") { 1.0 } else { 0.0 };
432                embedding[1] = if text.contains("figure") || text.contains("table") {
433                    1.0
434                } else {
435                    0.0
436                };
437                embedding[2] = text.matches(char::is_numeric).count() as f32 / text.len() as f32;
438            }
439            SpecializedTextModel::CodeBERT => {
440                // Code features: keywords, operators, structures
441                embedding[0] = if text.contains("function") || text.contains("def") {
442                    1.0
443                } else {
444                    0.0
445                };
446                embedding[1] = if text.contains("class") || text.contains("struct") {
447                    1.0
448                } else {
449                    0.0
450                };
451                embedding[2] =
452                    text.matches(|c: char| "{}()[]".contains(c)).count() as f32 / text.len() as f32;
453            }
454            SpecializedTextModel::BioBERT => {
455                // Biomedical features: genes, proteins, diseases
456                embedding[0] = if text.contains("protein") || text.contains("gene") {
457                    1.0
458                } else {
459                    0.0
460                };
461                embedding[1] = if text.contains("disease") || text.contains("syndrome") {
462                    1.0
463                } else {
464                    0.0
465                };
466                embedding[2] = if text.contains("mg") || text.contains("dose") {
467                    1.0
468                } else {
469                    0.0
470                };
471            }
472            _ => {
473                // Generic specialized features
474                embedding[0] = text.len() as f32 / 1000.0; // Length normalization
475                embedding[1] = text.split_whitespace().count() as f32 / text.len() as f32;
476                // Word density
477            }
478        }
479
480        // Fill remaining dimensions with text-based features
481        for (i, item) in embedding.iter_mut().enumerate().take(embedding_dim).skip(3) {
482            let byte_val = text.as_bytes().get(i % text.len()).copied().unwrap_or(0) as f32;
483            *item = (byte_val / 255.0 - 0.5) * 2.0; // Normalize to [-1, 1]
484        }
485
486        // Apply domain-specific transformations
487        if self.config.domain_pretraining {
488            for val in &mut embedding {
489                *val *= 1.2; // Amplify features for domain-pretrained models
490            }
491        }
492
493        Ok(Array1::from_vec(embedding))
494    }
495
496    /// Fine-tune the model on domain-specific data
497    pub async fn fine_tune(&mut self, training_texts: Vec<String>) -> Result<TrainingStats> {
498        let start_time = std::time::Instant::now();
499        let epochs = self.config.fine_tune_config.epochs;
500
501        let mut loss_history = Vec::new();
502
503        for epoch in 0..epochs {
504            let mut epoch_loss = 0.0;
505
506            for text in &training_texts {
507                // Generate embedding and compute loss
508                let embedding = self.encode_text(text).await?;
509
510                // Simplified fine-tuning loss computation
511                let target_variance = 0.1; // Target embedding variance
512                let actual_variance = embedding.var(0.0);
513                let loss = (actual_variance - target_variance).powi(2);
514                epoch_loss += loss;
515            }
516
517            epoch_loss /= training_texts.len() as f32;
518            loss_history.push(epoch_loss as f64);
519
520            if epoch % 10 == 0 {
521                println!("Fine-tuning epoch {epoch}: Loss = {epoch_loss:.6}");
522            }
523        }
524
525        let training_time = start_time.elapsed().as_secs_f64();
526
527        self.training_stats = TrainingStats {
528            epochs_completed: epochs,
529            final_loss: loss_history.last().copied().unwrap_or(0.0),
530            training_time_seconds: training_time,
531            convergence_achieved: loss_history.last().is_some_and(|&loss| loss < 0.01),
532            loss_history,
533        };
534
535        self.is_trained = true;
536        self.model_stats.is_trained = true;
537        self.model_stats.last_training_time = Some(Utc::now());
538
539        Ok(self.training_stats.clone())
540    }
541
542    /// Get model statistics
543    pub fn get_stats(&self) -> ModelStats {
544        self.model_stats.clone()
545    }
546
547    /// Clear cached embeddings
548    pub fn clear_cache(&mut self) {
549        self.text_embeddings.clear();
550    }
551}
552
// Minimal stand-in for a regex type, kept as a placeholder for future pattern-based
// preprocessing. It performs NO matching: `replace_all` always hands back the input
// unchanged. Nothing in this file uses it, hence the module-wide `dead_code` allowance
// (which also covers every item inside the module).
#[allow(dead_code)]
mod regex {
    use std::borrow::Cow;

    /// Placeholder regex that stores the raw pattern text; the pattern is never compiled.
    pub struct Regex(String);

    impl Regex {
        /// Wrap `pattern` without validating it, so this never returns `Err`.
        pub fn new(pattern: &str) -> Result<Self, &'static str> {
            let raw = String::from(pattern);
            Ok(Self(raw))
        }

        /// No-op "replacement": ignores both the stored pattern and `_rep` and returns
        /// `text` borrowed and unmodified.
        pub fn replace_all<'a, F>(&self, text: &'a str, _rep: F) -> Cow<'a, str>
        where
            F: Fn(&str) -> String,
        {
            Cow::Borrowed(text)
        }
    }
}
575
// Unit tests for this file. The first five tests exercise items imported from
// `crate::biomedical_embeddings::types` (BiomedicalEntityType, BiomedicalEmbeddingConfig,
// BiomedicalEmbedding) rather than the specialized text model defined here; the rest cover
// `SpecializedTextModel`, `SpecializedTextEmbedding` and their configuration factories.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::biomedical_embeddings::types::{
        BiomedicalEmbedding, BiomedicalEmbeddingConfig, BiomedicalEntityType,
    };

    // IRI path segments ("gene", "disease", "drug") should map to the matching entity type.
    #[test]
    fn test_biomedical_entity_type_from_iri() {
        assert_eq!(
            BiomedicalEntityType::from_iri("http://example.org/gene/BRCA1"),
            Some(BiomedicalEntityType::Gene)
        );
        assert_eq!(
            BiomedicalEntityType::from_iri("http://example.org/disease/cancer"),
            Some(BiomedicalEntityType::Disease)
        );
        assert_eq!(
            BiomedicalEntityType::from_iri("http://example.org/drug/aspirin"),
            Some(BiomedicalEntityType::Drug)
        );
    }

    // Pins the default weights and filters of the biomedical embedding configuration.
    #[test]
    fn test_biomedical_config_default() {
        let config = BiomedicalEmbeddingConfig::default();
        assert_eq!(config.gene_disease_weight, 2.0);
        assert_eq!(config.drug_target_weight, 1.5);
        assert!(config.use_sequence_similarity);
        assert_eq!(config.species_filter, Some("Homo sapiens".to_string()));
    }

    // A freshly constructed biomedical model is untrained and has no gene embeddings.
    #[test]
    fn test_biomedical_embedding_creation() {
        let config = BiomedicalEmbeddingConfig::default();
        let model = BiomedicalEmbedding::new(config);

        assert_eq!(model.model_type(), "BiomedicalEmbedding");
        assert!(!model.is_trained());
        assert_eq!(model.gene_embeddings.len(), 0);
    }

    // Associations are stored under the (gene, disease) key with the given score.
    #[test]
    fn test_gene_disease_association() {
        let mut model = BiomedicalEmbedding::new(BiomedicalEmbeddingConfig::default());

        model.add_gene_disease_association("BRCA1", "breast_cancer", 0.8);

        assert_eq!(
            model
                .features
                .gene_disease_associations
                .get(&("BRCA1".to_string(), "breast_cancer".to_string())),
            Some(&0.8)
        );
    }

    // Interactions are stored under the (drug, target) key with the given affinity.
    #[test]
    fn test_drug_target_interaction() {
        let mut model = BiomedicalEmbedding::new(BiomedicalEmbeddingConfig::default());

        model.add_drug_target_interaction("aspirin", "COX1", 0.9);

        assert_eq!(
            model
                .features
                .drug_target_affinities
                .get(&("aspirin".to_string(), "COX1".to_string())),
            Some(&0.9)
        );
    }

    // Static per-model metadata (checkpoint names, vocab sizes, dimensions).
    #[test]
    fn test_specialized_text_model_properties() {
        let scibert = SpecializedTextModel::SciBERT;
        assert_eq!(scibert.model_name(), "allenai/scibert_scivocab_uncased");
        assert_eq!(scibert.vocab_size(), 31090);
        assert_eq!(scibert.embedding_dim(), 768);
        assert_eq!(scibert.max_sequence_length(), 512);

        let codebert = SpecializedTextModel::CodeBERT;
        assert_eq!(codebert.model_name(), "microsoft/codebert-base");
        assert_eq!(codebert.vocab_size(), 50265);

        let biobert = SpecializedTextModel::BioBERT;
        assert_eq!(biobert.model_name(), "dmis-lab/biobert-base-cased-v1.2");
        assert_eq!(biobert.vocab_size(), 28996);
    }

    // Spot-checks membership of each model's default preprocessing rules.
    #[test]
    fn test_specialized_text_preprocessing_rules() {
        let scibert = SpecializedTextModel::SciBERT;
        let rules = scibert.get_preprocessing_rules();
        assert!(rules.contains(&PreprocessingRule::NormalizeScientificNotation));
        assert!(rules.contains(&PreprocessingRule::HandleChemicalFormulas));

        let codebert = SpecializedTextModel::CodeBERT;
        let rules = codebert.get_preprocessing_rules();
        assert!(rules.contains(&PreprocessingRule::PreserveCodeTokens));
        assert!(rules.contains(&PreprocessingRule::HandleCamelCase));

        let biobert = SpecializedTextModel::BioBERT;
        let rules = biobert.get_preprocessing_rules();
        assert!(rules.contains(&PreprocessingRule::NormalizeMedicalTerms));
        assert!(rules.contains(&PreprocessingRule::HandleGeneNames));
    }

    // The factory configurations enable the expected features and fine-tuning settings.
    #[test]
    fn test_specialized_text_config_factory_methods() {
        let scibert_config = SpecializedTextEmbedding::scibert_config();
        assert_eq!(scibert_config.model_type, SpecializedTextModel::SciBERT);
        assert_eq!(scibert_config.base_config.dimensions, 768);
        assert!(scibert_config.preprocessing_enabled);
        assert!(scibert_config.vocab_augmentation);
        assert!(scibert_config.domain_pretraining);

        let codebert_config = SpecializedTextEmbedding::codebert_config();
        assert_eq!(codebert_config.model_type, SpecializedTextModel::CodeBERT);
        assert!(!codebert_config.vocab_augmentation);

        let biobert_config = SpecializedTextEmbedding::biobert_config();
        assert_eq!(biobert_config.model_type, SpecializedTextModel::BioBERT);
        assert!(biobert_config.fine_tune_config.freeze_base_layers);
        assert_eq!(biobert_config.fine_tune_config.frozen_layers, 6);
        assert!(biobert_config.fine_tune_config.gradual_unfreezing);
    }

    // `new` derives stats and the preprocessing pipeline from the model type.
    #[test]
    fn test_specialized_text_embedding_creation() {
        let config = SpecializedTextEmbedding::scibert_config();
        let model = SpecializedTextEmbedding::new(config);

        assert!(model.model_stats.model_type.contains("SciBERT"));
        assert_eq!(model.model_stats.dimensions, 768);
        assert!(!model.is_trained);
        assert_eq!(model.text_embeddings.len(), 0);
        assert_eq!(model.preprocessing_rules.len(), 4); // SciBERT has 4 rules
    }

    // BioBERT's NormalizeMedicalTerms rule expands dosing abbreviations.
    #[test]
    fn test_preprocessing_medical_terms() {
        let config = SpecializedTextEmbedding::biobert_config();
        let model = SpecializedTextEmbedding::new(config);

        let text = "Patient takes 100 mg/kg b.i.d. for treatment";
        let processed = model.preprocess_text(text).unwrap();

        // Should expand medical abbreviations
        assert!(processed.contains("milligrams per kilogram"));
        assert!(processed.contains("twice daily"));
    }

    // With preprocessing disabled, `preprocess_text` is the identity.
    #[test]
    fn test_preprocessing_disabled() {
        let mut config = SpecializedTextEmbedding::biobert_config();
        config.preprocessing_enabled = false;
        let model = SpecializedTextEmbedding::new(config);

        let text = "Patient takes 100 mg/kg b.i.d. for treatment";
        let processed = model.preprocess_text(text).unwrap();

        // Should be unchanged when preprocessing is disabled
        assert_eq!(processed, text);
    }

    // Embeddings have the model's width and repeated calls are served from the cache.
    #[tokio::test]
    async fn test_specialized_text_encoding() {
        let config = SpecializedTextEmbedding::scibert_config();
        let mut model = SpecializedTextEmbedding::new(config);

        let text = "The protein folding study shows significant results with p < 0.001";
        let embedding = model.encode_text(text).await.unwrap();

        assert_eq!(embedding.len(), 768);

        // Test caching - second call should return cached result
        let embedding2 = model.encode_text(text).await.unwrap();
        assert_eq!(embedding.to_vec(), embedding2.to_vec());
        assert_eq!(model.text_embeddings.len(), 1);
    }

    // Keyword features land in components 0..3. Exact float equality is safe here because
    // each flag is exactly 1.0, and 1.0 * 1.2 is the same f32 as the literal 1.2.
    #[tokio::test]
    async fn test_domain_specific_features() {
        // Test SciBERT features
        let config = SpecializedTextEmbedding::scibert_config();
        let mut model = SpecializedTextEmbedding::new(config);

        let scientific_text = "The study by Smith et al. shows figure 1 demonstrates the results";
        let embedding = model.encode_text(scientific_text).await.unwrap();

        // Should detect scientific features (citations, figures)
        // Values are amplified by 1.2 due to domain pretraining
        assert_eq!(embedding[0], 1.2); // et al. detected, amplified
        assert_eq!(embedding[1], 1.2); // figure detected, amplified

        // Test CodeBERT features
        let config = SpecializedTextEmbedding::codebert_config();
        let mut model = SpecializedTextEmbedding::new(config);

        let code_text = "function calculateSum() { return a + b; }";
        let embedding = model.encode_text(code_text).await.unwrap();

        // Should detect code features (amplified by domain pretraining)
        assert_eq!(embedding[0], 1.2); // function detected, amplified
        assert!(embedding[2] > 0.0); // brackets detected (text-based features)

        // Test BioBERT features
        let config = SpecializedTextEmbedding::biobert_config();
        let mut model = SpecializedTextEmbedding::new(config);

        let biomedical_text =
            "The protein expression correlates with cancer disease progression, dose 100mg";
        let embedding = model.encode_text(biomedical_text).await.unwrap();

        // Should detect biomedical features (amplified by domain pretraining)
        assert_eq!(embedding[0], 1.2); // protein detected, amplified
        assert_eq!(embedding[1], 1.2); // disease detected, amplified
        assert_eq!(embedding[2], 1.2); // mg detected, amplified
    }

    // Simulated fine-tuning marks the model trained and records per-epoch losses.
    // NOTE(review): `training_time_seconds > 0.0` relies on the clock measuring a
    // non-zero duration for this small workload.
    #[tokio::test]
    async fn test_fine_tuning() {
        let config = SpecializedTextEmbedding::biobert_config();
        let mut model = SpecializedTextEmbedding::new(config);

        let training_texts = vec![
            "Gene expression analysis in cancer cells".to_string(),
            "Protein folding mechanisms in disease".to_string(),
            "Drug interaction with target proteins".to_string(),
        ];

        let stats = model.fine_tune(training_texts).await.unwrap();

        assert!(model.is_trained);
        assert_eq!(stats.epochs_completed, 5); // BioBERT config has 5 epochs
        assert!(stats.training_time_seconds > 0.0);
        assert!(!stats.loss_history.is_empty());
        assert!(model.model_stats.is_trained);
        assert!(model.model_stats.last_training_time.is_some());
    }
}