oxirs_embed/biomedical_embeddings/
text_model.rs

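//! Domain-specialized text embedding models (SciBERT, CodeBERT, BioBERT, LegalBERT,
//! FinBERT, ClinicalBERT, ChemBERT) and their preprocessing pipelines, used by the
//! biomedical embeddings module.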
2
3use crate::{ModelConfig, ModelStats, TrainingStats};
4use anyhow::Result;
5use chrono::Utc;
6use scirs2_core::ndarray_ext::Array1;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9use uuid::Uuid;
10
11/// Specialized text embedding models for domain-specific applications
12#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
13pub enum SpecializedTextModel {
14    /// SciBERT for scientific literature
15    SciBERT,
16    /// CodeBERT for code and programming
17    CodeBERT,
18    /// BioBERT for biomedical literature
19    BioBERT,
20    /// LegalBERT for legal documents
21    LegalBERT,
22    /// FinBERT for financial texts
23    FinBERT,
24    /// ClinicalBERT for clinical notes
25    ClinicalBERT,
26    /// ChemBERT for chemical compounds
27    ChemBERT,
28}
29
30impl SpecializedTextModel {
31    /// Get the model name for loading pre-trained weights
32    pub fn model_name(&self) -> &'static str {
33        match self {
34            SpecializedTextModel::SciBERT => "allenai/scibert_scivocab_uncased",
35            SpecializedTextModel::CodeBERT => "microsoft/codebert-base",
36            SpecializedTextModel::BioBERT => "dmis-lab/biobert-base-cased-v1.2",
37            SpecializedTextModel::LegalBERT => "nlpaueb/legal-bert-base-uncased",
38            SpecializedTextModel::FinBERT => "ProsusAI/finbert",
39            SpecializedTextModel::ClinicalBERT => "emilyalsentzer/Bio_ClinicalBERT",
40            SpecializedTextModel::ChemBERT => "seyonec/ChemBERTa-zinc-base-v1",
41        }
42    }
43
44    /// Get the vocabulary size for the model
45    pub fn vocab_size(&self) -> usize {
46        match self {
47            SpecializedTextModel::SciBERT => 31090,
48            SpecializedTextModel::CodeBERT => 50265,
49            SpecializedTextModel::BioBERT => 28996,
50            SpecializedTextModel::LegalBERT => 30522,
51            SpecializedTextModel::FinBERT => 30522,
52            SpecializedTextModel::ClinicalBERT => 28996,
53            SpecializedTextModel::ChemBERT => 600,
54        }
55    }
56
57    /// Get the default embedding dimension
58    pub fn embedding_dim(&self) -> usize {
59        match self {
60            SpecializedTextModel::SciBERT => 768,
61            SpecializedTextModel::CodeBERT => 768,
62            SpecializedTextModel::BioBERT => 768,
63            SpecializedTextModel::LegalBERT => 768,
64            SpecializedTextModel::FinBERT => 768,
65            SpecializedTextModel::ClinicalBERT => 768,
66            SpecializedTextModel::ChemBERT => 384,
67        }
68    }
69
70    /// Get the maximum sequence length
71    pub fn max_sequence_length(&self) -> usize {
72        match self {
73            SpecializedTextModel::SciBERT => 512,
74            SpecializedTextModel::CodeBERT => 512,
75            SpecializedTextModel::BioBERT => 512,
76            SpecializedTextModel::LegalBERT => 512,
77            SpecializedTextModel::FinBERT => 512,
78            SpecializedTextModel::ClinicalBERT => 512,
79            SpecializedTextModel::ChemBERT => 512,
80        }
81    }
82
83    /// Get domain-specific preprocessing rules
84    pub fn get_preprocessing_rules(&self) -> Vec<PreprocessingRule> {
85        match self {
86            SpecializedTextModel::SciBERT => vec![
87                PreprocessingRule::NormalizeScientificNotation,
88                PreprocessingRule::ExpandAbbreviations,
89                PreprocessingRule::HandleChemicalFormulas,
90                PreprocessingRule::PreserveCitations,
91            ],
92            SpecializedTextModel::CodeBERT => vec![
93                PreprocessingRule::PreserveCodeTokens,
94                PreprocessingRule::HandleCamelCase,
95                PreprocessingRule::NormalizeWhitespace,
96                PreprocessingRule::PreservePunctuation,
97            ],
98            SpecializedTextModel::BioBERT => vec![
99                PreprocessingRule::NormalizeMedicalTerms,
100                PreprocessingRule::HandleGeneNames,
101                PreprocessingRule::ExpandMedicalAbbreviations,
102                PreprocessingRule::PreserveDosages,
103            ],
104            SpecializedTextModel::LegalBERT => vec![
105                PreprocessingRule::PreserveLegalCitations,
106                PreprocessingRule::HandleLegalTerms,
107                PreprocessingRule::NormalizeCaseReferences,
108            ],
109            SpecializedTextModel::FinBERT => vec![
110                PreprocessingRule::NormalizeFinancialTerms,
111                PreprocessingRule::HandleCurrencySymbols,
112                PreprocessingRule::PreservePercentages,
113            ],
114            SpecializedTextModel::ClinicalBERT => vec![
115                PreprocessingRule::NormalizeMedicalTerms,
116                PreprocessingRule::HandleMedicalAbbreviations,
117                PreprocessingRule::PreserveDosages,
118                PreprocessingRule::NormalizeTimestamps,
119            ],
120            SpecializedTextModel::ChemBERT => vec![
121                PreprocessingRule::HandleChemicalFormulas,
122                PreprocessingRule::PreserveMolecularStructures,
123                PreprocessingRule::NormalizeChemicalNames,
124            ],
125        }
126    }
127}
128
129/// Preprocessing rules for specialized text models
130#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
131pub enum PreprocessingRule {
132    /// Normalize scientific notation (e.g., 1.23e-4)
133    NormalizeScientificNotation,
134    /// Expand domain-specific abbreviations
135    ExpandAbbreviations,
136    /// Handle chemical formulas and compounds
137    HandleChemicalFormulas,
138    /// Preserve citation formats
139    PreserveCitations,
140    /// Preserve code tokens and keywords
141    PreserveCodeTokens,
142    /// Handle camelCase and snake_case
143    HandleCamelCase,
144    /// Normalize whitespace patterns
145    NormalizeWhitespace,
146    /// Preserve punctuation in code
147    PreservePunctuation,
148    /// Normalize medical terminology
149    NormalizeMedicalTerms,
150    /// Handle gene and protein names
151    HandleGeneNames,
152    /// Expand medical abbreviations
153    ExpandMedicalAbbreviations,
154    /// Preserve dosage information
155    PreserveDosages,
156    /// Preserve legal citations
157    PreserveLegalCitations,
158    /// Handle legal terminology
159    HandleLegalTerms,
160    /// Normalize case references
161    NormalizeCaseReferences,
162    /// Normalize financial terms
163    NormalizeFinancialTerms,
164    /// Handle currency symbols
165    HandleCurrencySymbols,
166    /// Preserve percentage values
167    PreservePercentages,
168    /// Handle medical abbreviations
169    HandleMedicalAbbreviations,
170    /// Normalize timestamps
171    NormalizeTimestamps,
172    /// Preserve molecular structures
173    PreserveMolecularStructures,
174    /// Normalize chemical names
175    NormalizeChemicalNames,
176}
177
178/// Configuration for specialized text embeddings
179#[derive(Debug, Clone, Serialize, Deserialize)]
180pub struct SpecializedTextConfig {
181    pub model_type: SpecializedTextModel,
182    pub base_config: ModelConfig,
183    /// Fine-tuning configuration
184    pub fine_tune_config: FineTuningConfig,
185    /// Preprocessing configuration
186    pub preprocessing_enabled: bool,
187    /// Domain-specific vocabulary augmentation
188    pub vocab_augmentation: bool,
189    /// Use domain-specific pre-training
190    pub domain_pretraining: bool,
191}
192
193/// Fine-tuning configuration for specialized models
194#[derive(Debug, Clone, Serialize, Deserialize)]
195pub struct FineTuningConfig {
196    /// Learning rate for fine-tuning
197    pub learning_rate: f64,
198    /// Number of fine-tuning epochs
199    pub epochs: usize,
200    /// Freeze base model layers
201    pub freeze_base_layers: bool,
202    /// Number of layers to freeze
203    pub frozen_layers: usize,
204    /// Use gradual unfreezing
205    pub gradual_unfreezing: bool,
206    /// Discriminative fine-tuning rates
207    pub discriminative_rates: Vec<f64>,
208}
209
210impl Default for FineTuningConfig {
211    fn default() -> Self {
212        Self {
213            learning_rate: 2e-5,
214            epochs: 3,
215            freeze_base_layers: false,
216            frozen_layers: 0,
217            gradual_unfreezing: false,
218            discriminative_rates: vec![],
219        }
220    }
221}
222
223impl Default for SpecializedTextConfig {
224    fn default() -> Self {
225        Self {
226            model_type: SpecializedTextModel::BioBERT,
227            base_config: ModelConfig::default(),
228            fine_tune_config: FineTuningConfig::default(),
229            preprocessing_enabled: true,
230            vocab_augmentation: false,
231            domain_pretraining: false,
232        }
233    }
234}
235
236/// Specialized text embedding processor
237#[derive(Debug, Clone, Serialize, Deserialize)]
238pub struct SpecializedTextEmbedding {
239    pub config: SpecializedTextConfig,
240    pub model_id: Uuid,
241    /// Text embeddings cache
242    pub text_embeddings: HashMap<String, Array1<f32>>,
243    /// Domain-specific vocabulary
244    pub domain_vocab: HashSet<String>,
245    /// Preprocessing pipeline
246    pub preprocessing_rules: Vec<PreprocessingRule>,
247    /// Training statistics
248    pub training_stats: TrainingStats,
249    /// Model statistics
250    pub model_stats: ModelStats,
251    pub is_trained: bool,
252}
253
254impl SpecializedTextEmbedding {
255    /// Create new specialized text embedding model
256    pub fn new(config: SpecializedTextConfig) -> Self {
257        let model_id = Uuid::new_v4();
258        let now = Utc::now();
259        let preprocessing_rules = config.model_type.get_preprocessing_rules();
260
261        Self {
262            model_id,
263            text_embeddings: HashMap::new(),
264            domain_vocab: HashSet::new(),
265            preprocessing_rules,
266            training_stats: TrainingStats::default(),
267            model_stats: ModelStats {
268                num_entities: 0,
269                num_relations: 0,
270                num_triples: 0,
271                dimensions: config.model_type.embedding_dim(),
272                is_trained: false,
273                model_type: format!("SpecializedText_{:?}", config.model_type),
274                creation_time: now,
275                last_training_time: None,
276            },
277            is_trained: false,
278            config,
279        }
280    }
281
282    /// Create SciBERT configuration
283    pub fn scibert_config() -> SpecializedTextConfig {
284        SpecializedTextConfig {
285            model_type: SpecializedTextModel::SciBERT,
286            base_config: ModelConfig::default().with_dimensions(768),
287            fine_tune_config: FineTuningConfig::default(),
288            preprocessing_enabled: true,
289            vocab_augmentation: true,
290            domain_pretraining: true,
291        }
292    }
293
294    /// Create CodeBERT configuration
295    pub fn codebert_config() -> SpecializedTextConfig {
296        SpecializedTextConfig {
297            model_type: SpecializedTextModel::CodeBERT,
298            base_config: ModelConfig::default().with_dimensions(768),
299            fine_tune_config: FineTuningConfig::default(),
300            preprocessing_enabled: true,
301            vocab_augmentation: false,
302            domain_pretraining: true,
303        }
304    }
305
306    /// Create BioBERT configuration
307    pub fn biobert_config() -> SpecializedTextConfig {
308        SpecializedTextConfig {
309            model_type: SpecializedTextModel::BioBERT,
310            base_config: ModelConfig::default().with_dimensions(768),
311            fine_tune_config: FineTuningConfig {
312                learning_rate: 1e-5,
313                epochs: 5,
314                freeze_base_layers: true,
315                frozen_layers: 6,
316                gradual_unfreezing: true,
317                discriminative_rates: vec![1e-6, 5e-6, 1e-5, 2e-5],
318            },
319            preprocessing_enabled: true,
320            vocab_augmentation: true,
321            domain_pretraining: true,
322        }
323    }
324
325    /// Preprocess text according to domain-specific rules
326    pub fn preprocess_text(&self, text: &str) -> Result<String> {
327        if !self.config.preprocessing_enabled {
328            return Ok(text.to_string());
329        }
330
331        let mut processed = text.to_string();
332
333        for rule in &self.preprocessing_rules {
334            processed = self.apply_preprocessing_rule(&processed, rule)?;
335        }
336
337        Ok(processed)
338    }
339
340    /// Apply a specific preprocessing rule
341    fn apply_preprocessing_rule(&self, text: &str, rule: &PreprocessingRule) -> Result<String> {
342        match rule {
343            PreprocessingRule::NormalizeScientificNotation => {
344                // Convert scientific notation to normalized form (simplified)
345                Ok(text
346                    .replace("E+", "e+")
347                    .replace("E-", "e-")
348                    .replace("E", "e"))
349            }
350            PreprocessingRule::HandleChemicalFormulas => {
351                // Preserve chemical formulas by adding special tokens (simplified)
352                Ok(text.replace("H2O", "[CHEM]H2O[/CHEM]"))
353            }
354            PreprocessingRule::HandleCamelCase => {
355                // Split camelCase into separate tokens (simplified)
356                let mut result = String::new();
357                let mut chars = text.chars().peekable();
358                while let Some(c) = chars.next() {
359                    result.push(c);
360                    if c.is_lowercase() && chars.peek().is_some_and(|&next| next.is_uppercase()) {
361                        result.push(' ');
362                    }
363                }
364                Ok(result)
365            }
366            PreprocessingRule::NormalizeMedicalTerms => {
367                // Normalize common medical abbreviations
368                let mut result = text.to_string();
369                let replacements = [
370                    ("mg/kg", "milligrams per kilogram"),
371                    ("q.d.", "once daily"),
372                    ("b.i.d.", "twice daily"),
373                    ("t.i.d.", "three times daily"),
374                    ("q.i.d.", "four times daily"),
375                ];
376
377                for (abbrev, expansion) in &replacements {
378                    result = result.replace(abbrev, expansion);
379                }
380                Ok(result)
381            }
382            PreprocessingRule::HandleGeneNames => {
383                // Standardize gene name formatting (simplified)
384                Ok(text
385                    .replace("BRCA1", "[GENE]BRCA1[/GENE]")
386                    .replace("TP53", "[GENE]TP53[/GENE]"))
387            }
388            PreprocessingRule::PreserveCodeTokens => {
389                // Preserve code-like tokens (simplified)
390                Ok(text.replace("function", "[CODE]function[/CODE]"))
391            }
392            _ => {
393                // Placeholder for other rules - would implement in production
394                Ok(text.to_string())
395            }
396        }
397    }
398
399    /// Generate embedding for text using specialized model
400    pub async fn encode_text(&mut self, text: &str) -> Result<Array1<f32>> {
401        // Preprocess the text
402        let processed_text = self.preprocess_text(text)?;
403
404        // Check cache first
405        if let Some(cached_embedding) = self.text_embeddings.get(&processed_text) {
406            return Ok(cached_embedding.clone());
407        }
408
409        // Generate embedding using domain-specific model
410        let embedding = self.generate_specialized_embedding(&processed_text).await?;
411
412        // Cache the result
413        self.text_embeddings
414            .insert(processed_text, embedding.clone());
415
416        Ok(embedding)
417    }
418
419    /// Generate specialized embedding for the specific domain
420    async fn generate_specialized_embedding(&self, text: &str) -> Result<Array1<f32>> {
421        // In a real implementation, this would use the actual pre-trained model
422        // For now, simulate domain-specific embeddings with enhanced features
423
424        let embedding_dim = self.config.model_type.embedding_dim();
425        let mut embedding = vec![0.0; embedding_dim];
426
427        // Domain-specific feature extraction
428        match self.config.model_type {
429            SpecializedTextModel::SciBERT => {
430                // Scientific text features: citations, formulas, terminology
431                embedding[0] = if text.contains("et al.") { 1.0 } else { 0.0 };
432                embedding[1] = if text.contains("figure") || text.contains("table") {
433                    1.0
434                } else {
435                    0.0
436                };
437                embedding[2] = text.matches(char::is_numeric).count() as f32 / text.len() as f32;
438            }
439            SpecializedTextModel::CodeBERT => {
440                // Code features: keywords, operators, structures
441                embedding[0] = if text.contains("function") || text.contains("def") {
442                    1.0
443                } else {
444                    0.0
445                };
446                embedding[1] = if text.contains("class") || text.contains("struct") {
447                    1.0
448                } else {
449                    0.0
450                };
451                embedding[2] =
452                    text.matches(|c: char| "{}()[]".contains(c)).count() as f32 / text.len() as f32;
453            }
454            SpecializedTextModel::BioBERT => {
455                // Biomedical features: genes, proteins, diseases
456                embedding[0] = if text.contains("protein") || text.contains("gene") {
457                    1.0
458                } else {
459                    0.0
460                };
461                embedding[1] = if text.contains("disease") || text.contains("syndrome") {
462                    1.0
463                } else {
464                    0.0
465                };
466                embedding[2] = if text.contains("mg") || text.contains("dose") {
467                    1.0
468                } else {
469                    0.0
470                };
471            }
472            _ => {
473                // Generic specialized features
474                embedding[0] = text.len() as f32 / 1000.0; // Length normalization
475                embedding[1] = text.split_whitespace().count() as f32 / text.len() as f32;
476                // Word density
477            }
478        }
479
480        // Fill remaining dimensions with text-based features
481        for (i, item) in embedding.iter_mut().enumerate().take(embedding_dim).skip(3) {
482            let byte_val = text.as_bytes().get(i % text.len()).copied().unwrap_or(0) as f32;
483            *item = (byte_val / 255.0 - 0.5) * 2.0; // Normalize to [-1, 1]
484        }
485
486        // Apply domain-specific transformations
487        if self.config.domain_pretraining {
488            for val in &mut embedding {
489                *val *= 1.2; // Amplify features for domain-pretrained models
490            }
491        }
492
493        Ok(Array1::from_vec(embedding))
494    }
495
496    /// Fine-tune the model on domain-specific data
497    pub async fn fine_tune(&mut self, training_texts: Vec<String>) -> Result<TrainingStats> {
498        let start_time = std::time::Instant::now();
499        let epochs = self.config.fine_tune_config.epochs;
500
501        let mut loss_history = Vec::new();
502
503        for epoch in 0..epochs {
504            let mut epoch_loss = 0.0;
505
506            for text in &training_texts {
507                // Generate embedding and compute loss
508                let embedding = self.encode_text(text).await?;
509
510                // Simplified fine-tuning loss computation
511                let target_variance = 0.1; // Target embedding variance
512                let actual_variance = embedding.var(0.0);
513                let loss = (actual_variance - target_variance).powi(2);
514                epoch_loss += loss;
515            }
516
517            epoch_loss /= training_texts.len() as f32;
518            loss_history.push(epoch_loss as f64);
519
520            if epoch % 10 == 0 {
521                println!("Fine-tuning epoch {epoch}: Loss = {epoch_loss:.6}");
522            }
523        }
524
525        let training_time = start_time.elapsed().as_secs_f64();
526
527        self.training_stats = TrainingStats {
528            epochs_completed: epochs,
529            final_loss: loss_history.last().copied().unwrap_or(0.0),
530            training_time_seconds: training_time,
531            convergence_achieved: loss_history.last().is_some_and(|&loss| loss < 0.01),
532            loss_history,
533        };
534
535        self.is_trained = true;
536        self.model_stats.is_trained = true;
537        self.model_stats.last_training_time = Some(Utc::now());
538
539        Ok(self.training_stats.clone())
540    }
541
542    /// Get model statistics
543    pub fn get_stats(&self) -> ModelStats {
544        self.model_stats.clone()
545    }
546
547    /// Clear cached embeddings
548    pub fn clear_cache(&mut self) {
549        self.text_embeddings.clear();
550    }
551}
552
// Placeholder for regex-style preprocessing. It is not backed by a regular
// expression engine, and it appears unused, which is why `dead_code` is allowed.
#[allow(dead_code)] // unused placeholder module, see comment above
mod regex {
    use std::borrow::Cow;

    /// Holds a pattern string verbatim; the pattern is never compiled or matched.
    #[allow(dead_code)] // kept for future callers
    pub struct Regex(String);

    impl Regex {
        /// Stores `pattern` without validating it; never returns `Err`.
        #[allow(dead_code)] // kept for future callers
        pub fn new(pattern: &str) -> Result<Self, &'static str> {
            let stored = String::from(pattern);
            Ok(Self(stored))
        }

        /// No-op stand-in for regex replacement: returns `text` unchanged as a
        /// borrowed `Cow` and never calls `_rep`.
        #[allow(dead_code)] // kept for future callers
        pub fn replace_all<'a, F>(&self, text: &'a str, _rep: F) -> Cow<'a, str>
        where
            F: Fn(&str) -> String,
        {
            Cow::from(text)
        }
    }
}
575
#[cfg(test)]
mod tests {
    use super::*;
    use crate::biomedical_embeddings::types::{
        BiomedicalEmbedding, BiomedicalEmbeddingConfig, BiomedicalEntityType,
    };

    /// Builds a fresh model from `config` and encodes `text` with it.
    async fn encode_with(config: SpecializedTextConfig, text: &str) -> Array1<f32> {
        let mut model = SpecializedTextEmbedding::new(config);
        model.encode_text(text).await.unwrap()
    }

    #[test]
    fn test_biomedical_entity_type_from_iri() {
        let cases = [
            ("http://example.org/gene/BRCA1", BiomedicalEntityType::Gene),
            ("http://example.org/disease/cancer", BiomedicalEntityType::Disease),
            ("http://example.org/drug/aspirin", BiomedicalEntityType::Drug),
        ];
        for (iri, expected) in cases {
            assert_eq!(BiomedicalEntityType::from_iri(iri), Some(expected));
        }
    }

    #[test]
    fn test_biomedical_config_default() {
        let config = BiomedicalEmbeddingConfig::default();
        assert!(config.use_sequence_similarity);
        assert_eq!(config.gene_disease_weight, 2.0);
        assert_eq!(config.drug_target_weight, 1.5);
        assert_eq!(config.species_filter.as_deref(), Some("Homo sapiens"));
    }

    #[test]
    fn test_biomedical_embedding_creation() {
        let model = BiomedicalEmbedding::new(BiomedicalEmbeddingConfig::default());

        assert!(!model.is_trained());
        assert_eq!(model.gene_embeddings.len(), 0);
        assert_eq!(model.model_type(), "BiomedicalEmbedding");
    }

    #[test]
    fn test_gene_disease_association() {
        let mut model = BiomedicalEmbedding::new(BiomedicalEmbeddingConfig::default());
        model.add_gene_disease_association("BRCA1", "breast_cancer", 0.8);

        let key = ("BRCA1".to_string(), "breast_cancer".to_string());
        let stored = model.features.gene_disease_associations.get(&key);
        assert_eq!(stored, Some(&0.8));
    }

    #[test]
    fn test_drug_target_interaction() {
        let mut model = BiomedicalEmbedding::new(BiomedicalEmbeddingConfig::default());
        model.add_drug_target_interaction("aspirin", "COX1", 0.9);

        let key = ("aspirin".to_string(), "COX1".to_string());
        let stored = model.features.drug_target_affinities.get(&key);
        assert_eq!(stored, Some(&0.9));
    }

    #[test]
    fn test_specialized_text_model_properties() {
        let scibert = SpecializedTextModel::SciBERT;
        assert_eq!(
            (
                scibert.model_name(),
                scibert.vocab_size(),
                scibert.embedding_dim(),
                scibert.max_sequence_length(),
            ),
            ("allenai/scibert_scivocab_uncased", 31090, 768, 512)
        );

        for (model, name, vocab) in [
            (SpecializedTextModel::CodeBERT, "microsoft/codebert-base", 50265),
            (SpecializedTextModel::BioBERT, "dmis-lab/biobert-base-cased-v1.2", 28996),
        ] {
            assert_eq!(model.model_name(), name);
            assert_eq!(model.vocab_size(), vocab);
        }
    }

    #[test]
    fn test_specialized_text_preprocessing_rules() {
        let expectations = [
            (
                SpecializedTextModel::SciBERT,
                [
                    PreprocessingRule::NormalizeScientificNotation,
                    PreprocessingRule::HandleChemicalFormulas,
                ],
            ),
            (
                SpecializedTextModel::CodeBERT,
                [
                    PreprocessingRule::PreserveCodeTokens,
                    PreprocessingRule::HandleCamelCase,
                ],
            ),
            (
                SpecializedTextModel::BioBERT,
                [
                    PreprocessingRule::NormalizeMedicalTerms,
                    PreprocessingRule::HandleGeneNames,
                ],
            ),
        ];

        for (model, required) in expectations {
            let rules = model.get_preprocessing_rules();
            for rule in &required {
                assert!(rules.contains(rule), "{model:?} should include {rule:?}");
            }
        }
    }

    #[test]
    fn test_specialized_text_config_factory_methods() {
        let sci = SpecializedTextEmbedding::scibert_config();
        assert_eq!(sci.model_type, SpecializedTextModel::SciBERT);
        assert_eq!(sci.base_config.dimensions, 768);
        assert!(sci.preprocessing_enabled);
        assert!(sci.vocab_augmentation);
        assert!(sci.domain_pretraining);

        let code = SpecializedTextEmbedding::codebert_config();
        assert_eq!(code.model_type, SpecializedTextModel::CodeBERT);
        assert!(!code.vocab_augmentation);

        let bio = SpecializedTextEmbedding::biobert_config();
        assert_eq!(bio.model_type, SpecializedTextModel::BioBERT);
        let FineTuningConfig {
            freeze_base_layers,
            frozen_layers,
            gradual_unfreezing,
            ..
        } = bio.fine_tune_config;
        assert!(freeze_base_layers);
        assert_eq!(frozen_layers, 6);
        assert!(gradual_unfreezing);
    }

    #[test]
    fn test_specialized_text_embedding_creation() {
        let model = SpecializedTextEmbedding::new(SpecializedTextEmbedding::scibert_config());

        assert!(!model.is_trained);
        assert!(model.text_embeddings.is_empty());
        assert!(model.model_stats.model_type.contains("SciBERT"));
        assert_eq!(model.model_stats.dimensions, 768);
        assert_eq!(model.preprocessing_rules.len(), 4); // SciBERT has 4 rules
    }

    #[test]
    fn test_preprocessing_medical_terms() {
        let model = SpecializedTextEmbedding::new(SpecializedTextEmbedding::biobert_config());

        let processed = model
            .preprocess_text("Patient takes 100 mg/kg b.i.d. for treatment")
            .unwrap();

        // Medical abbreviations should be expanded.
        for expansion in ["milligrams per kilogram", "twice daily"] {
            assert!(
                processed.contains(expansion),
                "missing {expansion:?} in {processed:?}"
            );
        }
    }

    #[test]
    fn test_preprocessing_disabled() {
        let config = SpecializedTextConfig {
            preprocessing_enabled: false,
            ..SpecializedTextEmbedding::biobert_config()
        };
        let model = SpecializedTextEmbedding::new(config);

        // With preprocessing disabled the text must come back unchanged.
        let text = "Patient takes 100 mg/kg b.i.d. for treatment";
        assert_eq!(model.preprocess_text(text).unwrap(), text);
    }

    #[tokio::test]
    async fn test_specialized_text_encoding() {
        let mut model = SpecializedTextEmbedding::new(SpecializedTextEmbedding::scibert_config());
        let text = "The protein folding study shows significant results with p < 0.001";

        let first = model.encode_text(text).await.unwrap();
        assert_eq!(first.len(), 768);

        // A repeated request must be answered from the cache.
        let second = model.encode_text(text).await.unwrap();
        assert_eq!(model.text_embeddings.len(), 1);
        assert_eq!(first.to_vec(), second.to_vec());
    }

    #[tokio::test]
    async fn test_domain_specific_features() {
        // Every factory config enables domain pre-training, so detected features
        // are amplified from 1.0 to 1.2.
        let sci = encode_with(
            SpecializedTextEmbedding::scibert_config(),
            "The study by Smith et al. shows figure 1 demonstrates the results",
        )
        .await;
        assert_eq!(sci[0], 1.2); // "et al." citation
        assert_eq!(sci[1], 1.2); // "figure" reference

        let code = encode_with(
            SpecializedTextEmbedding::codebert_config(),
            "function calculateSum() { return a + b; }",
        )
        .await;
        assert_eq!(code[0], 1.2); // "function" keyword
        assert!(code[2] > 0.0); // bracket density

        let bio = encode_with(
            SpecializedTextEmbedding::biobert_config(),
            "The protein expression correlates with cancer disease progression, dose 100mg",
        )
        .await;
        // protein, disease and dosage markers are all present
        assert_eq!([bio[0], bio[1], bio[2]], [1.2, 1.2, 1.2]);
    }

    #[tokio::test]
    async fn test_fine_tuning() {
        let mut model = SpecializedTextEmbedding::new(SpecializedTextEmbedding::biobert_config());

        let training_texts: Vec<String> = [
            "Gene expression analysis in cancer cells",
            "Protein folding mechanisms in disease",
            "Drug interaction with target proteins",
        ]
        .iter()
        .map(ToString::to_string)
        .collect();

        let stats = model.fine_tune(training_texts).await.unwrap();

        assert_eq!(stats.epochs_completed, 5); // BioBERT config has 5 epochs
        assert!(!stats.loss_history.is_empty());
        assert!(stats.training_time_seconds > 0.0);
        assert!(model.is_trained);
        assert!(model.model_stats.is_trained);
        assert!(model.model_stats.last_training_time.is_some());
    }
}