use crate::{ModelConfig, ModelStats, TrainingStats};
use anyhow::Result;
use chrono::Utc;
use scirs2_core::ndarray_ext::Array1;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use uuid::Uuid;

/// Domain-specialized BERT-style text models.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum SpecializedTextModel {
    /// Scientific publications (SciBERT)
    SciBERT,
    /// Source code and natural-language/code pairs (CodeBERT)
    CodeBERT,
    /// Biomedical literature (BioBERT)
    BioBERT,
    /// Legal documents (Legal-BERT)
    LegalBERT,
    /// Financial text (FinBERT)
    FinBERT,
    /// Clinical notes (ClinicalBERT)
    ClinicalBERT,
    /// Chemistry and molecular text (ChemBERTa)
    ChemBERT,
}

impl SpecializedTextModel {
    /// Hugging Face model identifier for this checkpoint.
    pub fn model_name(&self) -> &'static str {
        match self {
            SpecializedTextModel::SciBERT => "allenai/scibert_scivocab_uncased",
            SpecializedTextModel::CodeBERT => "microsoft/codebert-base",
            SpecializedTextModel::BioBERT => "dmis-lab/biobert-base-cased-v1.2",
            SpecializedTextModel::LegalBERT => "nlpaueb/legal-bert-base-uncased",
            SpecializedTextModel::FinBERT => "ProsusAI/finbert",
            SpecializedTextModel::ClinicalBERT => "emilyalsentzer/Bio_ClinicalBERT",
            SpecializedTextModel::ChemBERT => "seyonec/ChemBERTa-zinc-base-v1",
        }
    }

    /// Vocabulary size of the model's tokenizer.
    pub fn vocab_size(&self) -> usize {
        match self {
            SpecializedTextModel::SciBERT => 31090,
            SpecializedTextModel::CodeBERT => 50265,
            SpecializedTextModel::BioBERT => 28996,
            SpecializedTextModel::LegalBERT => 30522,
            SpecializedTextModel::FinBERT => 30522,
            SpecializedTextModel::ClinicalBERT => 28996,
            SpecializedTextModel::ChemBERT => 600,
        }
    }

    /// Dimensionality of the output embeddings (768 for all BERT-base
    /// variants, 384 for ChemBERT).
    pub fn embedding_dim(&self) -> usize {
        match self {
            SpecializedTextModel::ChemBERT => 384,
            _ => 768,
        }
    }

    /// Maximum input sequence length in tokens (512 for all supported models).
    pub fn max_sequence_length(&self) -> usize {
        512
    }

    /// Preprocessing rules recommended for this model's domain.
    pub fn get_preprocessing_rules(&self) -> Vec<PreprocessingRule> {
        match self {
            SpecializedTextModel::SciBERT => vec![
                PreprocessingRule::NormalizeScientificNotation,
                PreprocessingRule::ExpandAbbreviations,
                PreprocessingRule::HandleChemicalFormulas,
                PreprocessingRule::PreserveCitations,
            ],
            SpecializedTextModel::CodeBERT => vec![
                PreprocessingRule::PreserveCodeTokens,
                PreprocessingRule::HandleCamelCase,
                PreprocessingRule::NormalizeWhitespace,
                PreprocessingRule::PreservePunctuation,
            ],
            SpecializedTextModel::BioBERT => vec![
                PreprocessingRule::NormalizeMedicalTerms,
                PreprocessingRule::HandleGeneNames,
                PreprocessingRule::ExpandMedicalAbbreviations,
                PreprocessingRule::PreserveDosages,
            ],
            SpecializedTextModel::LegalBERT => vec![
                PreprocessingRule::PreserveLegalCitations,
                PreprocessingRule::HandleLegalTerms,
                PreprocessingRule::NormalizeCaseReferences,
            ],
            SpecializedTextModel::FinBERT => vec![
                PreprocessingRule::NormalizeFinancialTerms,
                PreprocessingRule::HandleCurrencySymbols,
                PreprocessingRule::PreservePercentages,
            ],
            SpecializedTextModel::ClinicalBERT => vec![
                PreprocessingRule::NormalizeMedicalTerms,
                PreprocessingRule::HandleMedicalAbbreviations,
                PreprocessingRule::PreserveDosages,
                PreprocessingRule::NormalizeTimestamps,
            ],
            SpecializedTextModel::ChemBERT => vec![
                PreprocessingRule::HandleChemicalFormulas,
                PreprocessingRule::PreserveMolecularStructures,
                PreprocessingRule::NormalizeChemicalNames,
            ],
        }
    }
}
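
// A minimal usage sketch (`describe_model` is a hypothetical helper, not part
// of the original API): summarizing a variant's checkpoint name and shapes.
#[allow(dead_code)]
fn describe_model(model: &SpecializedTextModel) -> String {
    format!(
        "{} (vocab: {}, dim: {}, max tokens: {})",
        model.model_name(),
        model.vocab_size(),
        model.embedding_dim(),
        model.max_sequence_length()
    )
}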

/// Text preprocessing rules applied before tokenization, grouped by domain.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PreprocessingRule {
    // Scientific text
    NormalizeScientificNotation,
    ExpandAbbreviations,
    HandleChemicalFormulas,
    PreserveCitations,
    // Source code
    PreserveCodeTokens,
    HandleCamelCase,
    NormalizeWhitespace,
    PreservePunctuation,
    // Biomedical text
    NormalizeMedicalTerms,
    HandleGeneNames,
    ExpandMedicalAbbreviations,
    PreserveDosages,
    // Legal text
    PreserveLegalCitations,
    HandleLegalTerms,
    NormalizeCaseReferences,
    // Financial text
    NormalizeFinancialTerms,
    HandleCurrencySymbols,
    PreservePercentages,
    // Clinical text
    HandleMedicalAbbreviations,
    NormalizeTimestamps,
    // Chemistry
    PreserveMolecularStructures,
    NormalizeChemicalNames,
}

/// Configuration for a domain-specialized text embedding model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpecializedTextConfig {
    pub model_type: SpecializedTextModel,
    pub base_config: ModelConfig,
    pub fine_tune_config: FineTuningConfig,
    pub preprocessing_enabled: bool,
    pub vocab_augmentation: bool,
    pub domain_pretraining: bool,
}

/// Hyperparameters controlling how the base model is fine-tuned.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FineTuningConfig {
    pub learning_rate: f64,
    pub epochs: usize,
    pub freeze_base_layers: bool,
    pub frozen_layers: usize,
    pub gradual_unfreezing: bool,
    /// Per-layer-group learning rates for discriminative fine-tuning.
    pub discriminative_rates: Vec<f64>,
}

impl Default for FineTuningConfig {
    fn default() -> Self {
        Self {
            learning_rate: 2e-5,
            epochs: 3,
            freeze_base_layers: false,
            frozen_layers: 0,
            gradual_unfreezing: false,
            discriminative_rates: vec![],
        }
    }
}

impl Default for SpecializedTextConfig {
    fn default() -> Self {
        Self {
            model_type: SpecializedTextModel::BioBERT,
            base_config: ModelConfig::default(),
            fine_tune_config: FineTuningConfig::default(),
            preprocessing_enabled: true,
            vocab_augmentation: false,
            domain_pretraining: false,
        }
    }
}
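
// A minimal sketch of building a custom configuration on top of the defaults;
// `legal_config` is a hypothetical helper, not part of the original API.
#[allow(dead_code)]
fn legal_config() -> SpecializedTextConfig {
    SpecializedTextConfig {
        model_type: SpecializedTextModel::LegalBERT,
        fine_tune_config: FineTuningConfig {
            learning_rate: 1e-5,
            ..FineTuningConfig::default()
        },
        ..SpecializedTextConfig::default()
    }
}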

/// A text embedding model backed by a domain-specialized transformer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpecializedTextEmbedding {
    pub config: SpecializedTextConfig,
    pub model_id: Uuid,
    /// Cache of embeddings keyed by preprocessed text.
    pub text_embeddings: HashMap<String, Array1<f32>>,
    pub domain_vocab: HashSet<String>,
    pub preprocessing_rules: Vec<PreprocessingRule>,
    pub training_stats: TrainingStats,
    pub model_stats: ModelStats,
    pub is_trained: bool,
}

impl SpecializedTextEmbedding {
    /// Create a new, untrained model from the given configuration.
    pub fn new(config: SpecializedTextConfig) -> Self {
        let model_id = Uuid::new_v4();
        let now = Utc::now();
        let preprocessing_rules = config.model_type.get_preprocessing_rules();

        Self {
            model_id,
            text_embeddings: HashMap::new(),
            domain_vocab: HashSet::new(),
            preprocessing_rules,
            training_stats: TrainingStats::default(),
            model_stats: ModelStats {
                num_entities: 0,
                num_relations: 0,
                num_triples: 0,
                dimensions: config.model_type.embedding_dim(),
                is_trained: false,
                model_type: format!("SpecializedText_{:?}", config.model_type),
                creation_time: now,
                last_training_time: None,
            },
            is_trained: false,
            config,
        }
    }

    /// Preset configuration for SciBERT (scientific text).
    pub fn scibert_config() -> SpecializedTextConfig {
        SpecializedTextConfig {
            model_type: SpecializedTextModel::SciBERT,
            base_config: ModelConfig::default().with_dimensions(768),
            fine_tune_config: FineTuningConfig::default(),
            preprocessing_enabled: true,
            vocab_augmentation: true,
            domain_pretraining: true,
        }
    }

    /// Preset configuration for CodeBERT (source code).
    pub fn codebert_config() -> SpecializedTextConfig {
        SpecializedTextConfig {
            model_type: SpecializedTextModel::CodeBERT,
            base_config: ModelConfig::default().with_dimensions(768),
            fine_tune_config: FineTuningConfig::default(),
            preprocessing_enabled: true,
            vocab_augmentation: false,
            domain_pretraining: true,
        }
    }

    /// Preset configuration for BioBERT (biomedical text) with layer freezing,
    /// gradual unfreezing, and discriminative learning rates.
    pub fn biobert_config() -> SpecializedTextConfig {
        SpecializedTextConfig {
            model_type: SpecializedTextModel::BioBERT,
            base_config: ModelConfig::default().with_dimensions(768),
            fine_tune_config: FineTuningConfig {
                learning_rate: 1e-5,
                epochs: 5,
                freeze_base_layers: true,
                frozen_layers: 6,
                gradual_unfreezing: true,
                discriminative_rates: vec![1e-6, 5e-6, 1e-5, 2e-5],
            },
            preprocessing_enabled: true,
            vocab_augmentation: true,
            domain_pretraining: true,
        }
    }

    /// Apply this model's preprocessing rules to `text`, in order.
    pub fn preprocess_text(&self, text: &str) -> Result<String> {
        if !self.config.preprocessing_enabled {
            return Ok(text.to_string());
        }

        let mut processed = text.to_string();

        for rule in &self.preprocessing_rules {
            processed = self.apply_preprocessing_rule(&processed, rule)?;
        }

        Ok(processed)
    }

    fn apply_preprocessing_rule(&self, text: &str, rule: &PreprocessingRule) -> Result<String> {
        match rule {
            PreprocessingRule::NormalizeScientificNotation => {
                // Lowercase exponent markers, e.g. "1.5E+10" -> "1.5e+10".
                // Note: the bare "E" replacement also lowercases any other
                // capital E in the text.
                Ok(text
                    .replace("E+", "e+")
                    .replace("E-", "e-")
                    .replace("E", "e"))
            }
            PreprocessingRule::HandleChemicalFormulas => {
                // Simplified: tag a single well-known formula.
                Ok(text.replace("H2O", "[CHEM]H2O[/CHEM]"))
            }
            PreprocessingRule::HandleCamelCase => {
                // Insert a space at each lower-to-upper transition,
                // e.g. "getUserName" -> "get User Name".
                let mut result = String::new();
                let mut chars = text.chars().peekable();
                while let Some(c) = chars.next() {
                    result.push(c);
                    if c.is_lowercase() && chars.peek().is_some_and(|&next| next.is_uppercase()) {
                        result.push(' ');
                    }
                }
                Ok(result)
            }
            PreprocessingRule::NormalizeMedicalTerms => {
                // Expand common dosing abbreviations.
                let mut result = text.to_string();
                let replacements = [
                    ("mg/kg", "milligrams per kilogram"),
                    ("q.d.", "once daily"),
                    ("b.i.d.", "twice daily"),
                    ("t.i.d.", "three times daily"),
                    ("q.i.d.", "four times daily"),
                ];

                for (abbrev, expansion) in &replacements {
                    result = result.replace(abbrev, expansion);
                }
                Ok(result)
            }
            PreprocessingRule::HandleGeneNames => {
                // Simplified: tag two well-known gene symbols.
                Ok(text
                    .replace("BRCA1", "[GENE]BRCA1[/GENE]")
                    .replace("TP53", "[GENE]TP53[/GENE]"))
            }
            PreprocessingRule::PreserveCodeTokens => {
                Ok(text.replace("function", "[CODE]function[/CODE]"))
            }
            // Remaining rules are currently no-ops.
            _ => Ok(text.to_string()),
        }
    }

    /// Encode `text` into an embedding, reusing the cache when possible.
    pub async fn encode_text(&mut self, text: &str) -> Result<Array1<f32>> {
        let processed_text = self.preprocess_text(text)?;

        if let Some(cached_embedding) = self.text_embeddings.get(&processed_text) {
            return Ok(cached_embedding.clone());
        }

        let embedding = self.generate_specialized_embedding(&processed_text).await?;

        self.text_embeddings
            .insert(processed_text, embedding.clone());

        Ok(embedding)
    }

    async fn generate_specialized_embedding(&self, text: &str) -> Result<Array1<f32>> {
        let embedding_dim = self.config.model_type.embedding_dim();
        let mut embedding = vec![0.0; embedding_dim];

        // The first few dimensions carry hand-crafted, domain-specific features.
        match self.config.model_type {
            SpecializedTextModel::SciBERT => {
                embedding[0] = if text.contains("et al.") { 1.0 } else { 0.0 };
                embedding[1] = if text.contains("figure") || text.contains("table") {
                    1.0
                } else {
                    0.0
                };
                embedding[2] = text.matches(char::is_numeric).count() as f32 / text.len() as f32;
            }
            SpecializedTextModel::CodeBERT => {
                embedding[0] = if text.contains("function") || text.contains("def") {
                    1.0
                } else {
                    0.0
                };
                embedding[1] = if text.contains("class") || text.contains("struct") {
                    1.0
                } else {
                    0.0
                };
                embedding[2] =
                    text.matches(|c: char| "{}()[]".contains(c)).count() as f32 / text.len() as f32;
            }
            SpecializedTextModel::BioBERT => {
                embedding[0] = if text.contains("protein") || text.contains("gene") {
                    1.0
                } else {
                    0.0
                };
                embedding[1] = if text.contains("disease") || text.contains("syndrome") {
                    1.0
                } else {
                    0.0
                };
                embedding[2] = if text.contains("mg") || text.contains("dose") {
                    1.0
                } else {
                    0.0
                };
            }
            _ => {
                // Generic features: text length and average word density.
                embedding[0] = text.len() as f32 / 1000.0;
                embedding[1] = text.split_whitespace().count() as f32 / text.len() as f32;
            }
        }

        // Fill the remaining dimensions deterministically from the text bytes,
        // scaled into [-1.0, 1.0].
        for (i, item) in embedding.iter_mut().enumerate().take(embedding_dim).skip(3) {
            let byte_val = text.as_bytes().get(i % text.len()).copied().unwrap_or(0) as f32;
            *item = (byte_val / 255.0 - 0.5) * 2.0;
        }

        // Domain pretraining is modeled as a uniform amplification factor.
        if self.config.domain_pretraining {
            for val in &mut embedding {
                *val *= 1.2;
            }
        }

        Ok(Array1::from_vec(embedding))
    }

    /// Fine-tune on the given texts using a simplified variance-matching loss.
    pub async fn fine_tune(&mut self, training_texts: Vec<String>) -> Result<TrainingStats> {
        let start_time = std::time::Instant::now();
        let epochs = self.config.fine_tune_config.epochs;

        let mut loss_history = Vec::new();

        for epoch in 0..epochs {
            let mut epoch_loss = 0.0;

            for text in &training_texts {
                let embedding = self.encode_text(text).await?;

                // Simplified objective: penalize deviation from a target
                // embedding variance.
                let target_variance = 0.1;
                let actual_variance = embedding.var(0.0);
                let loss = (actual_variance - target_variance).powi(2);
                epoch_loss += loss;
            }

            epoch_loss /= training_texts.len() as f32;
            loss_history.push(epoch_loss as f64);

            if epoch % 10 == 0 {
                println!("Fine-tuning epoch {epoch}: Loss = {epoch_loss:.6}");
            }
        }

        let training_time = start_time.elapsed().as_secs_f64();

        self.training_stats = TrainingStats {
            epochs_completed: epochs,
            final_loss: loss_history.last().copied().unwrap_or(0.0),
            training_time_seconds: training_time,
            convergence_achieved: loss_history.last().is_some_and(|&loss| loss < 0.01),
            loss_history,
        };

        self.is_trained = true;
        self.model_stats.is_trained = true;
        self.model_stats.last_training_time = Some(Utc::now());

        Ok(self.training_stats.clone())
    }

    /// Return a snapshot of the model statistics.
    pub fn get_stats(&self) -> ModelStats {
        self.model_stats.clone()
    }

    /// Drop all cached text embeddings.
    pub fn clear_cache(&mut self) {
        self.text_embeddings.clear();
    }
}
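
// A minimal end-to-end usage sketch (`usage_sketch` is a hypothetical helper,
// not part of the original API; it assumes an async runtime such as tokio).
#[allow(dead_code)]
async fn usage_sketch() -> Result<()> {
    let mut model = SpecializedTextEmbedding::new(SpecializedTextEmbedding::biobert_config());

    // Encoding preprocesses the text, then caches the resulting vector.
    let vector = model.encode_text("BRCA1 expression in disease, 5 mg/kg b.i.d.").await?;
    assert_eq!(vector.len(), 768);

    // Fine-tuning runs the simplified variance-matching loop above;
    // biobert_config uses 5 epochs.
    let stats = model.fine_tune(vec!["gene therapy trial".to_string()]).await?;
    assert_eq!(stats.epochs_completed, 5);
    Ok(())
}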

// Minimal stand-in for the `regex` crate so this module compiles without the
// dependency; `replace_all` is a no-op that returns the input unchanged.
#[allow(dead_code)]
mod regex {
    #[allow(dead_code)]
    pub struct Regex(String);

    impl Regex {
        #[allow(dead_code)]
        pub fn new(pattern: &str) -> Result<Self, &'static str> {
            Ok(Regex(pattern.to_string()))
        }

        #[allow(dead_code)]
        pub fn replace_all<'a, F>(&self, text: &'a str, _rep: F) -> std::borrow::Cow<'a, str>
        where
            F: Fn(&str) -> String,
        {
            std::borrow::Cow::Borrowed(text)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::biomedical_embeddings::types::{
        BiomedicalEmbedding, BiomedicalEmbeddingConfig, BiomedicalEntityType,
    };

    #[test]
    fn test_biomedical_entity_type_from_iri() {
        assert_eq!(
            BiomedicalEntityType::from_iri("http://example.org/gene/BRCA1"),
            Some(BiomedicalEntityType::Gene)
        );
        assert_eq!(
            BiomedicalEntityType::from_iri("http://example.org/disease/cancer"),
            Some(BiomedicalEntityType::Disease)
        );
        assert_eq!(
            BiomedicalEntityType::from_iri("http://example.org/drug/aspirin"),
            Some(BiomedicalEntityType::Drug)
        );
    }

    #[test]
    fn test_biomedical_config_default() {
        let config = BiomedicalEmbeddingConfig::default();
        assert_eq!(config.gene_disease_weight, 2.0);
        assert_eq!(config.drug_target_weight, 1.5);
        assert!(config.use_sequence_similarity);
        assert_eq!(config.species_filter, Some("Homo sapiens".to_string()));
    }

    #[test]
    fn test_biomedical_embedding_creation() {
        let config = BiomedicalEmbeddingConfig::default();
        let model = BiomedicalEmbedding::new(config);

        assert_eq!(model.model_type(), "BiomedicalEmbedding");
        assert!(!model.is_trained());
        assert_eq!(model.gene_embeddings.len(), 0);
    }

    #[test]
    fn test_gene_disease_association() {
        let mut model = BiomedicalEmbedding::new(BiomedicalEmbeddingConfig::default());

        model.add_gene_disease_association("BRCA1", "breast_cancer", 0.8);

        assert_eq!(
            model
                .features
                .gene_disease_associations
                .get(&("BRCA1".to_string(), "breast_cancer".to_string())),
            Some(&0.8)
        );
    }

    #[test]
    fn test_drug_target_interaction() {
        let mut model = BiomedicalEmbedding::new(BiomedicalEmbeddingConfig::default());

        model.add_drug_target_interaction("aspirin", "COX1", 0.9);

        assert_eq!(
            model
                .features
                .drug_target_affinities
                .get(&("aspirin".to_string(), "COX1".to_string())),
            Some(&0.9)
        );
    }

    #[test]
    fn test_specialized_text_model_properties() {
        let scibert = SpecializedTextModel::SciBERT;
        assert_eq!(scibert.model_name(), "allenai/scibert_scivocab_uncased");
        assert_eq!(scibert.vocab_size(), 31090);
        assert_eq!(scibert.embedding_dim(), 768);
        assert_eq!(scibert.max_sequence_length(), 512);

        let codebert = SpecializedTextModel::CodeBERT;
        assert_eq!(codebert.model_name(), "microsoft/codebert-base");
        assert_eq!(codebert.vocab_size(), 50265);

        let biobert = SpecializedTextModel::BioBERT;
        assert_eq!(biobert.model_name(), "dmis-lab/biobert-base-cased-v1.2");
        assert_eq!(biobert.vocab_size(), 28996);
    }

    #[test]
    fn test_specialized_text_preprocessing_rules() {
        let scibert = SpecializedTextModel::SciBERT;
        let rules = scibert.get_preprocessing_rules();
        assert!(rules.contains(&PreprocessingRule::NormalizeScientificNotation));
        assert!(rules.contains(&PreprocessingRule::HandleChemicalFormulas));

        let codebert = SpecializedTextModel::CodeBERT;
        let rules = codebert.get_preprocessing_rules();
        assert!(rules.contains(&PreprocessingRule::PreserveCodeTokens));
        assert!(rules.contains(&PreprocessingRule::HandleCamelCase));

        let biobert = SpecializedTextModel::BioBERT;
        let rules = biobert.get_preprocessing_rules();
        assert!(rules.contains(&PreprocessingRule::NormalizeMedicalTerms));
        assert!(rules.contains(&PreprocessingRule::HandleGeneNames));
    }

    #[test]
    fn test_specialized_text_config_factory_methods() {
        let scibert_config = SpecializedTextEmbedding::scibert_config();
        assert_eq!(scibert_config.model_type, SpecializedTextModel::SciBERT);
        assert_eq!(scibert_config.base_config.dimensions, 768);
        assert!(scibert_config.preprocessing_enabled);
        assert!(scibert_config.vocab_augmentation);
        assert!(scibert_config.domain_pretraining);

        let codebert_config = SpecializedTextEmbedding::codebert_config();
        assert_eq!(codebert_config.model_type, SpecializedTextModel::CodeBERT);
        assert!(!codebert_config.vocab_augmentation);

        let biobert_config = SpecializedTextEmbedding::biobert_config();
        assert_eq!(biobert_config.model_type, SpecializedTextModel::BioBERT);
        assert!(biobert_config.fine_tune_config.freeze_base_layers);
        assert_eq!(biobert_config.fine_tune_config.frozen_layers, 6);
        assert!(biobert_config.fine_tune_config.gradual_unfreezing);
    }

    #[test]
    fn test_specialized_text_embedding_creation() {
        let config = SpecializedTextEmbedding::scibert_config();
        let model = SpecializedTextEmbedding::new(config);

        assert!(model.model_stats.model_type.contains("SciBERT"));
        assert_eq!(model.model_stats.dimensions, 768);
        assert!(!model.is_trained);
        assert_eq!(model.text_embeddings.len(), 0);
        // SciBERT defines four preprocessing rules.
        assert_eq!(model.preprocessing_rules.len(), 4);
    }

    #[test]
    fn test_preprocessing_medical_terms() {
        let config = SpecializedTextEmbedding::biobert_config();
        let model = SpecializedTextEmbedding::new(config);

        let text = "Patient takes 100 mg/kg b.i.d. for treatment";
        let processed = model.preprocess_text(text).unwrap();

        assert!(processed.contains("milligrams per kilogram"));
        assert!(processed.contains("twice daily"));
    }

    #[test]
    fn test_preprocessing_disabled() {
        let mut config = SpecializedTextEmbedding::biobert_config();
        config.preprocessing_enabled = false;
        let model = SpecializedTextEmbedding::new(config);

        let text = "Patient takes 100 mg/kg b.i.d. for treatment";
        let processed = model.preprocess_text(text).unwrap();

        // With preprocessing disabled the text passes through unchanged.
        assert_eq!(processed, text);
    }
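
    // Added sanity checks derived from `apply_preprocessing_rule` above:
    // camelCase splitting for CodeBERT and exponent lowercasing for SciBERT.
    #[test]
    fn test_preprocessing_camel_case_and_scientific_notation() {
        let code_model =
            SpecializedTextEmbedding::new(SpecializedTextEmbedding::codebert_config());
        let processed = code_model.preprocess_text("getUserName").unwrap();
        assert_eq!(processed, "get User Name");

        let sci_model = SpecializedTextEmbedding::new(SpecializedTextEmbedding::scibert_config());
        let processed = sci_model.preprocess_text("rate of 3.2E-5 per cell").unwrap();
        assert!(processed.contains("3.2e-5"));
    }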

    #[tokio::test]
    async fn test_specialized_text_encoding() {
        let config = SpecializedTextEmbedding::scibert_config();
        let mut model = SpecializedTextEmbedding::new(config);

        let text = "The protein folding study shows significant results with p < 0.001";
        let embedding = model.encode_text(text).await.unwrap();

        assert_eq!(embedding.len(), 768);

        // A second call must hit the cache and return an identical vector.
        let embedding2 = model.encode_text(text).await.unwrap();
        assert_eq!(embedding.to_vec(), embedding2.to_vec());
        assert_eq!(model.text_embeddings.len(), 1);
    }

    #[tokio::test]
    async fn test_domain_specific_features() {
        let config = SpecializedTextEmbedding::scibert_config();
        let mut model = SpecializedTextEmbedding::new(config);

        let scientific_text = "The study by Smith et al. shows figure 1 demonstrates the results";
        let embedding = model.encode_text(scientific_text).await.unwrap();

        // scibert_config enables domain pretraining, so the 1.0 citation and
        // figure/table features are scaled by 1.2.
        assert_eq!(embedding[0], 1.2);
        assert_eq!(embedding[1], 1.2);

        let config = SpecializedTextEmbedding::codebert_config();
        let mut model = SpecializedTextEmbedding::new(config);

        let code_text = "function calculateSum() { return a + b; }";
        let embedding = model.encode_text(code_text).await.unwrap();

        assert_eq!(embedding[0], 1.2); // "function" keyword feature, scaled by 1.2
        assert!(embedding[2] > 0.0); // bracket-density feature

        let config = SpecializedTextEmbedding::biobert_config();
        let mut model = SpecializedTextEmbedding::new(config);

        let biomedical_text =
            "The protein expression correlates with cancer disease progression, dose 100mg";
        let embedding = model.encode_text(biomedical_text).await.unwrap();

        assert_eq!(embedding[0], 1.2); // protein/gene feature, scaled by 1.2
        assert_eq!(embedding[1], 1.2); // disease/syndrome feature
        assert_eq!(embedding[2], 1.2); // dosage feature
    }

    #[tokio::test]
    async fn test_fine_tuning() {
        let config = SpecializedTextEmbedding::biobert_config();
        let mut model = SpecializedTextEmbedding::new(config);

        let training_texts = vec![
            "Gene expression analysis in cancer cells".to_string(),
            "Protein folding mechanisms in disease".to_string(),
            "Drug interaction with target proteins".to_string(),
        ];

        let stats = model.fine_tune(training_texts).await.unwrap();

        assert!(model.is_trained);
        assert_eq!(stats.epochs_completed, 5); // biobert_config uses 5 epochs
        assert!(stats.training_time_seconds > 0.0);
        assert!(!stats.loss_history.is_empty());
        assert!(model.model_stats.is_trained);
        assert!(model.model_stats.last_training_time.is_some());
    }
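
    // Additional coverage for the embedding cache, sketched from `encode_text`
    // and `clear_cache` above.
    #[tokio::test]
    async fn test_clear_cache() {
        let mut model = SpecializedTextEmbedding::new(SpecializedTextEmbedding::scibert_config());
        model.encode_text("cached entry").await.unwrap();
        assert_eq!(model.text_embeddings.len(), 1);

        model.clear_cache();
        assert!(model.text_embeddings.is_empty());
    }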
}