1use crate::{ModelConfig, ModelStats, TrainingStats};
4use anyhow::Result;
5use chrono::Utc;
6use scirs2_core::ndarray_ext::Array1;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9use uuid::Uuid;
10
/// Domain-specialized BERT-family text models supported by this module.
///
/// Each variant maps to a concrete pretrained checkpoint (see
/// `model_name`) together with its vocabulary size, embedding
/// dimensionality, and domain-specific preprocessing rules.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum SpecializedTextModel {
    /// Scientific-literature BERT (`allenai/scibert_scivocab_uncased`).
    SciBERT,
    /// Source-code BERT (`microsoft/codebert-base`).
    CodeBERT,
    /// Biomedical-literature BERT (`dmis-lab/biobert-base-cased-v1.2`).
    BioBERT,
    /// Legal-domain BERT (`nlpaueb/legal-bert-base-uncased`).
    LegalBERT,
    /// Financial-text BERT (`ProsusAI/finbert`).
    FinBERT,
    /// Clinical-notes BERT (`emilyalsentzer/Bio_ClinicalBERT`).
    ClinicalBERT,
    /// Chemistry model (`seyonec/ChemBERTa-zinc-base-v1`).
    ChemBERT,
}
29
30impl SpecializedTextModel {
31 pub fn model_name(&self) -> &'static str {
33 match self {
34 SpecializedTextModel::SciBERT => "allenai/scibert_scivocab_uncased",
35 SpecializedTextModel::CodeBERT => "microsoft/codebert-base",
36 SpecializedTextModel::BioBERT => "dmis-lab/biobert-base-cased-v1.2",
37 SpecializedTextModel::LegalBERT => "nlpaueb/legal-bert-base-uncased",
38 SpecializedTextModel::FinBERT => "ProsusAI/finbert",
39 SpecializedTextModel::ClinicalBERT => "emilyalsentzer/Bio_ClinicalBERT",
40 SpecializedTextModel::ChemBERT => "seyonec/ChemBERTa-zinc-base-v1",
41 }
42 }
43
44 pub fn vocab_size(&self) -> usize {
46 match self {
47 SpecializedTextModel::SciBERT => 31090,
48 SpecializedTextModel::CodeBERT => 50265,
49 SpecializedTextModel::BioBERT => 28996,
50 SpecializedTextModel::LegalBERT => 30522,
51 SpecializedTextModel::FinBERT => 30522,
52 SpecializedTextModel::ClinicalBERT => 28996,
53 SpecializedTextModel::ChemBERT => 600,
54 }
55 }
56
57 pub fn embedding_dim(&self) -> usize {
59 match self {
60 SpecializedTextModel::SciBERT => 768,
61 SpecializedTextModel::CodeBERT => 768,
62 SpecializedTextModel::BioBERT => 768,
63 SpecializedTextModel::LegalBERT => 768,
64 SpecializedTextModel::FinBERT => 768,
65 SpecializedTextModel::ClinicalBERT => 768,
66 SpecializedTextModel::ChemBERT => 384,
67 }
68 }
69
70 pub fn max_sequence_length(&self) -> usize {
72 match self {
73 SpecializedTextModel::SciBERT => 512,
74 SpecializedTextModel::CodeBERT => 512,
75 SpecializedTextModel::BioBERT => 512,
76 SpecializedTextModel::LegalBERT => 512,
77 SpecializedTextModel::FinBERT => 512,
78 SpecializedTextModel::ClinicalBERT => 512,
79 SpecializedTextModel::ChemBERT => 512,
80 }
81 }
82
83 pub fn get_preprocessing_rules(&self) -> Vec<PreprocessingRule> {
85 match self {
86 SpecializedTextModel::SciBERT => vec![
87 PreprocessingRule::NormalizeScientificNotation,
88 PreprocessingRule::ExpandAbbreviations,
89 PreprocessingRule::HandleChemicalFormulas,
90 PreprocessingRule::PreserveCitations,
91 ],
92 SpecializedTextModel::CodeBERT => vec![
93 PreprocessingRule::PreserveCodeTokens,
94 PreprocessingRule::HandleCamelCase,
95 PreprocessingRule::NormalizeWhitespace,
96 PreprocessingRule::PreservePunctuation,
97 ],
98 SpecializedTextModel::BioBERT => vec![
99 PreprocessingRule::NormalizeMedicalTerms,
100 PreprocessingRule::HandleGeneNames,
101 PreprocessingRule::ExpandMedicalAbbreviations,
102 PreprocessingRule::PreserveDosages,
103 ],
104 SpecializedTextModel::LegalBERT => vec![
105 PreprocessingRule::PreserveLegalCitations,
106 PreprocessingRule::HandleLegalTerms,
107 PreprocessingRule::NormalizeCaseReferences,
108 ],
109 SpecializedTextModel::FinBERT => vec![
110 PreprocessingRule::NormalizeFinancialTerms,
111 PreprocessingRule::HandleCurrencySymbols,
112 PreprocessingRule::PreservePercentages,
113 ],
114 SpecializedTextModel::ClinicalBERT => vec![
115 PreprocessingRule::NormalizeMedicalTerms,
116 PreprocessingRule::HandleMedicalAbbreviations,
117 PreprocessingRule::PreserveDosages,
118 PreprocessingRule::NormalizeTimestamps,
119 ],
120 SpecializedTextModel::ChemBERT => vec![
121 PreprocessingRule::HandleChemicalFormulas,
122 PreprocessingRule::PreserveMolecularStructures,
123 PreprocessingRule::NormalizeChemicalNames,
124 ],
125 }
126 }
127}
128
/// Text preprocessing steps applied before encoding.
///
/// Only a subset currently has a concrete implementation in
/// `apply_preprocessing_rule`; the remainder pass text through unchanged
/// and act as placeholders for future rules.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PreprocessingRule {
    // Scientific text -----------------------------------------------------
    /// Lowercase the `E` in scientific notation (e.g. `1E+5` -> `1e+5`).
    NormalizeScientificNotation,
    /// Expand common abbreviations (placeholder; currently a no-op).
    ExpandAbbreviations,
    /// Wrap known chemical formulas in `[CHEM]...[/CHEM]` markers.
    HandleChemicalFormulas,
    /// Keep citation markers intact (placeholder; currently a no-op).
    PreserveCitations,
    // Source code ----------------------------------------------------------
    /// Wrap code keywords in `[CODE]...[/CODE]` markers.
    PreserveCodeTokens,
    /// Insert spaces at lower->upper camelCase boundaries.
    HandleCamelCase,
    /// Normalize whitespace runs (placeholder; currently a no-op).
    NormalizeWhitespace,
    /// Keep punctuation intact (placeholder; currently a no-op).
    PreservePunctuation,
    // Biomedical / clinical -------------------------------------------------
    /// Expand dosage/frequency abbreviations (e.g. `b.i.d.` -> `twice daily`).
    NormalizeMedicalTerms,
    /// Wrap known gene symbols in `[GENE]...[/GENE]` markers.
    HandleGeneNames,
    /// Expand medical abbreviations (placeholder; currently a no-op).
    ExpandMedicalAbbreviations,
    /// Keep dosage expressions intact (placeholder; currently a no-op).
    PreserveDosages,
    // Legal ------------------------------------------------------------------
    /// Keep legal citations intact (placeholder; currently a no-op).
    PreserveLegalCitations,
    /// Normalize legal terminology (placeholder; currently a no-op).
    HandleLegalTerms,
    /// Normalize case references (placeholder; currently a no-op).
    NormalizeCaseReferences,
    // Financial ---------------------------------------------------------------
    /// Normalize financial terminology (placeholder; currently a no-op).
    NormalizeFinancialTerms,
    /// Handle currency symbols (placeholder; currently a no-op).
    HandleCurrencySymbols,
    /// Keep percentage expressions intact (placeholder; currently a no-op).
    PreservePercentages,
    // Clinical-specific -------------------------------------------------------
    /// Handle clinical abbreviations (placeholder; currently a no-op).
    HandleMedicalAbbreviations,
    /// Normalize timestamps in notes (placeholder; currently a no-op).
    NormalizeTimestamps,
    // Chemistry ----------------------------------------------------------------
    /// Keep molecular structure notations intact (placeholder; currently a no-op).
    PreserveMolecularStructures,
    /// Normalize chemical names (placeholder; currently a no-op).
    NormalizeChemicalNames,
}
177
/// Configuration for a `SpecializedTextEmbedding` model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpecializedTextConfig {
    /// Which specialized pretrained model to emulate.
    pub model_type: SpecializedTextModel,
    /// Shared embedding-model settings (dimensions, etc.).
    pub base_config: ModelConfig,
    /// Hyperparameters for the (simulated) fine-tuning loop.
    pub fine_tune_config: FineTuningConfig,
    /// When false, `preprocess_text` returns input unchanged.
    pub preprocessing_enabled: bool,
    /// Whether to augment the vocabulary with domain terms.
    /// NOTE(review): not read by the code in this file — confirm consumers.
    pub vocab_augmentation: bool,
    /// When true, generated embeddings are uniformly scaled by 1.2
    /// (see `generate_specialized_embedding`).
    pub domain_pretraining: bool,
}
192
/// Fine-tuning hyperparameters for a specialized text model.
///
/// NOTE(review): of these fields, only `epochs` is read by the code in
/// this file (`fine_tune`); the rest are carried as configuration for
/// external consumers — confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FineTuningConfig {
    /// Base optimizer learning rate.
    pub learning_rate: f64,
    /// Number of fine-tuning passes over the training texts.
    pub epochs: usize,
    /// Whether lower encoder layers are kept frozen during fine-tuning.
    pub freeze_base_layers: bool,
    /// How many layers stay frozen when freezing is enabled.
    pub frozen_layers: usize,
    /// Whether layers are unfrozen progressively during training.
    pub gradual_unfreezing: bool,
    /// Per-layer-group learning rates for discriminative fine-tuning.
    pub discriminative_rates: Vec<f64>,
}
209
210impl Default for FineTuningConfig {
211 fn default() -> Self {
212 Self {
213 learning_rate: 2e-5,
214 epochs: 3,
215 freeze_base_layers: false,
216 frozen_layers: 0,
217 gradual_unfreezing: false,
218 discriminative_rates: vec![],
219 }
220 }
221}
222
223impl Default for SpecializedTextConfig {
224 fn default() -> Self {
225 Self {
226 model_type: SpecializedTextModel::BioBERT,
227 base_config: ModelConfig::default(),
228 fine_tune_config: FineTuningConfig::default(),
229 preprocessing_enabled: true,
230 vocab_augmentation: false,
231 domain_pretraining: false,
232 }
233 }
234}
235
/// A domain-specialized text embedding model with a per-text result cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpecializedTextEmbedding {
    /// Model configuration (model type, preprocessing, fine-tuning).
    pub config: SpecializedTextConfig,
    /// Unique identifier assigned at construction.
    pub model_id: Uuid,
    /// Cache of embeddings keyed by *preprocessed* text.
    pub text_embeddings: HashMap<String, Array1<f32>>,
    /// Domain-specific vocabulary terms.
    /// NOTE(review): populated nowhere in this file — confirm consumers.
    pub domain_vocab: HashSet<String>,
    /// Preprocessing rules derived from `config.model_type` at construction.
    pub preprocessing_rules: Vec<PreprocessingRule>,
    /// Statistics from the most recent `fine_tune` run.
    pub training_stats: TrainingStats,
    /// General model metadata/statistics.
    pub model_stats: ModelStats,
    /// Set to true once `fine_tune` completes.
    pub is_trained: bool,
}
253
254impl SpecializedTextEmbedding {
255 pub fn new(config: SpecializedTextConfig) -> Self {
257 let model_id = Uuid::new_v4();
258 let now = Utc::now();
259 let preprocessing_rules = config.model_type.get_preprocessing_rules();
260
261 Self {
262 model_id,
263 text_embeddings: HashMap::new(),
264 domain_vocab: HashSet::new(),
265 preprocessing_rules,
266 training_stats: TrainingStats::default(),
267 model_stats: ModelStats {
268 num_entities: 0,
269 num_relations: 0,
270 num_triples: 0,
271 dimensions: config.model_type.embedding_dim(),
272 is_trained: false,
273 model_type: format!("SpecializedText_{:?}", config.model_type),
274 creation_time: now,
275 last_training_time: None,
276 },
277 is_trained: false,
278 config,
279 }
280 }
281
282 pub fn scibert_config() -> SpecializedTextConfig {
284 SpecializedTextConfig {
285 model_type: SpecializedTextModel::SciBERT,
286 base_config: ModelConfig::default().with_dimensions(768),
287 fine_tune_config: FineTuningConfig::default(),
288 preprocessing_enabled: true,
289 vocab_augmentation: true,
290 domain_pretraining: true,
291 }
292 }
293
294 pub fn codebert_config() -> SpecializedTextConfig {
296 SpecializedTextConfig {
297 model_type: SpecializedTextModel::CodeBERT,
298 base_config: ModelConfig::default().with_dimensions(768),
299 fine_tune_config: FineTuningConfig::default(),
300 preprocessing_enabled: true,
301 vocab_augmentation: false,
302 domain_pretraining: true,
303 }
304 }
305
306 pub fn biobert_config() -> SpecializedTextConfig {
308 SpecializedTextConfig {
309 model_type: SpecializedTextModel::BioBERT,
310 base_config: ModelConfig::default().with_dimensions(768),
311 fine_tune_config: FineTuningConfig {
312 learning_rate: 1e-5,
313 epochs: 5,
314 freeze_base_layers: true,
315 frozen_layers: 6,
316 gradual_unfreezing: true,
317 discriminative_rates: vec![1e-6, 5e-6, 1e-5, 2e-5],
318 },
319 preprocessing_enabled: true,
320 vocab_augmentation: true,
321 domain_pretraining: true,
322 }
323 }
324
325 pub fn preprocess_text(&self, text: &str) -> Result<String> {
327 if !self.config.preprocessing_enabled {
328 return Ok(text.to_string());
329 }
330
331 let mut processed = text.to_string();
332
333 for rule in &self.preprocessing_rules {
334 processed = self.apply_preprocessing_rule(&processed, rule)?;
335 }
336
337 Ok(processed)
338 }
339
340 fn apply_preprocessing_rule(&self, text: &str, rule: &PreprocessingRule) -> Result<String> {
342 match rule {
343 PreprocessingRule::NormalizeScientificNotation => {
344 Ok(text
346 .replace("E+", "e+")
347 .replace("E-", "e-")
348 .replace("E", "e"))
349 }
350 PreprocessingRule::HandleChemicalFormulas => {
351 Ok(text.replace("H2O", "[CHEM]H2O[/CHEM]"))
353 }
354 PreprocessingRule::HandleCamelCase => {
355 let mut result = String::new();
357 let mut chars = text.chars().peekable();
358 while let Some(c) = chars.next() {
359 result.push(c);
360 if c.is_lowercase() && chars.peek().is_some_and(|&next| next.is_uppercase()) {
361 result.push(' ');
362 }
363 }
364 Ok(result)
365 }
366 PreprocessingRule::NormalizeMedicalTerms => {
367 let mut result = text.to_string();
369 let replacements = [
370 ("mg/kg", "milligrams per kilogram"),
371 ("q.d.", "once daily"),
372 ("b.i.d.", "twice daily"),
373 ("t.i.d.", "three times daily"),
374 ("q.i.d.", "four times daily"),
375 ];
376
377 for (abbrev, expansion) in &replacements {
378 result = result.replace(abbrev, expansion);
379 }
380 Ok(result)
381 }
382 PreprocessingRule::HandleGeneNames => {
383 Ok(text
385 .replace("BRCA1", "[GENE]BRCA1[/GENE]")
386 .replace("TP53", "[GENE]TP53[/GENE]"))
387 }
388 PreprocessingRule::PreserveCodeTokens => {
389 Ok(text.replace("function", "[CODE]function[/CODE]"))
391 }
392 _ => {
393 Ok(text.to_string())
395 }
396 }
397 }
398
399 pub async fn encode_text(&mut self, text: &str) -> Result<Array1<f32>> {
401 let processed_text = self.preprocess_text(text)?;
403
404 if let Some(cached_embedding) = self.text_embeddings.get(&processed_text) {
406 return Ok(cached_embedding.clone());
407 }
408
409 let embedding = self.generate_specialized_embedding(&processed_text).await?;
411
412 self.text_embeddings
414 .insert(processed_text, embedding.clone());
415
416 Ok(embedding)
417 }
418
419 async fn generate_specialized_embedding(&self, text: &str) -> Result<Array1<f32>> {
421 let embedding_dim = self.config.model_type.embedding_dim();
425 let mut embedding = vec![0.0; embedding_dim];
426
427 match self.config.model_type {
429 SpecializedTextModel::SciBERT => {
430 embedding[0] = if text.contains("et al.") { 1.0 } else { 0.0 };
432 embedding[1] = if text.contains("figure") || text.contains("table") {
433 1.0
434 } else {
435 0.0
436 };
437 embedding[2] = text.matches(char::is_numeric).count() as f32 / text.len() as f32;
438 }
439 SpecializedTextModel::CodeBERT => {
440 embedding[0] = if text.contains("function") || text.contains("def") {
442 1.0
443 } else {
444 0.0
445 };
446 embedding[1] = if text.contains("class") || text.contains("struct") {
447 1.0
448 } else {
449 0.0
450 };
451 embedding[2] =
452 text.matches(|c: char| "{}()[]".contains(c)).count() as f32 / text.len() as f32;
453 }
454 SpecializedTextModel::BioBERT => {
455 embedding[0] = if text.contains("protein") || text.contains("gene") {
457 1.0
458 } else {
459 0.0
460 };
461 embedding[1] = if text.contains("disease") || text.contains("syndrome") {
462 1.0
463 } else {
464 0.0
465 };
466 embedding[2] = if text.contains("mg") || text.contains("dose") {
467 1.0
468 } else {
469 0.0
470 };
471 }
472 _ => {
473 embedding[0] = text.len() as f32 / 1000.0; embedding[1] = text.split_whitespace().count() as f32 / text.len() as f32;
476 }
478 }
479
480 for (i, item) in embedding.iter_mut().enumerate().take(embedding_dim).skip(3) {
482 let byte_val = text.as_bytes().get(i % text.len()).copied().unwrap_or(0) as f32;
483 *item = (byte_val / 255.0 - 0.5) * 2.0; }
485
486 if self.config.domain_pretraining {
488 for val in &mut embedding {
489 *val *= 1.2; }
491 }
492
493 Ok(Array1::from_vec(embedding))
494 }
495
496 pub async fn fine_tune(&mut self, training_texts: Vec<String>) -> Result<TrainingStats> {
498 let start_time = std::time::Instant::now();
499 let epochs = self.config.fine_tune_config.epochs;
500
501 let mut loss_history = Vec::new();
502
503 for epoch in 0..epochs {
504 let mut epoch_loss = 0.0;
505
506 for text in &training_texts {
507 let embedding = self.encode_text(text).await?;
509
510 let target_variance = 0.1; let actual_variance = embedding.var(0.0);
513 let loss = (actual_variance - target_variance).powi(2);
514 epoch_loss += loss;
515 }
516
517 epoch_loss /= training_texts.len() as f32;
518 loss_history.push(epoch_loss as f64);
519
520 if epoch % 10 == 0 {
521 println!("Fine-tuning epoch {epoch}: Loss = {epoch_loss:.6}");
522 }
523 }
524
525 let training_time = start_time.elapsed().as_secs_f64();
526
527 self.training_stats = TrainingStats {
528 epochs_completed: epochs,
529 final_loss: loss_history.last().copied().unwrap_or(0.0),
530 training_time_seconds: training_time,
531 convergence_achieved: loss_history.last().is_some_and(|&loss| loss < 0.01),
532 loss_history,
533 };
534
535 self.is_trained = true;
536 self.model_stats.is_trained = true;
537 self.model_stats.last_training_time = Some(Utc::now());
538
539 Ok(self.training_stats.clone())
540 }
541
542 pub fn get_stats(&self) -> ModelStats {
544 self.model_stats.clone()
545 }
546
547 pub fn clear_cache(&mut self) {
549 self.text_embeddings.clear();
550 }
551}
552
/// Minimal placeholder standing in for the `regex` crate API.
///
/// `Regex::new` merely stores the pattern string (it never compiles it) and
/// `replace_all` is a no-op that returns the input borrowed and unmodified.
/// Any caller relying on actual regular-expression substitution must swap in
/// the real `regex` crate.
#[allow(dead_code)]
mod regex {
    #[allow(dead_code)]
    pub struct Regex(String);

    impl Regex {
        /// Always succeeds; the pattern is stored but never validated.
        #[allow(dead_code)]
        pub fn new(pattern: &str) -> Result<Self, &'static str> {
            Ok(Regex(pattern.to_string()))
        }

        /// No-op substitution: ignores `_rep` and returns `text` unchanged.
        #[allow(dead_code)]
        pub fn replace_all<'a, F>(&self, text: &'a str, _rep: F) -> std::borrow::Cow<'a, str>
        where
            F: Fn(&str) -> String,
        {
            std::borrow::Cow::Borrowed(text)
        }
    }
}
575
#[cfg(test)]
mod tests {
    use super::*;
    use crate::biomedical_embeddings::types::{
        BiomedicalEmbedding, BiomedicalEmbeddingConfig, BiomedicalEntityType,
    };

    #[test]
    fn test_biomedical_entity_type_from_iri() {
        // Entity types are inferred from the IRI path segment.
        let cases = [
            ("http://example.org/gene/BRCA1", BiomedicalEntityType::Gene),
            ("http://example.org/disease/cancer", BiomedicalEntityType::Disease),
            ("http://example.org/drug/aspirin", BiomedicalEntityType::Drug),
        ];
        for (iri, expected) in cases {
            assert_eq!(BiomedicalEntityType::from_iri(iri), Some(expected));
        }
    }

    #[test]
    fn test_biomedical_config_default() {
        let config = BiomedicalEmbeddingConfig::default();
        assert_eq!(config.gene_disease_weight, 2.0);
        assert_eq!(config.drug_target_weight, 1.5);
        assert!(config.use_sequence_similarity);
        assert_eq!(config.species_filter, Some("Homo sapiens".to_string()));
    }

    #[test]
    fn test_biomedical_embedding_creation() {
        let model = BiomedicalEmbedding::new(BiomedicalEmbeddingConfig::default());

        assert_eq!(model.model_type(), "BiomedicalEmbedding");
        assert!(!model.is_trained());
        assert_eq!(model.gene_embeddings.len(), 0);
    }

    #[test]
    fn test_gene_disease_association() {
        let mut model = BiomedicalEmbedding::new(BiomedicalEmbeddingConfig::default());
        model.add_gene_disease_association("BRCA1", "breast_cancer", 0.8);

        let key = ("BRCA1".to_string(), "breast_cancer".to_string());
        assert_eq!(model.features.gene_disease_associations.get(&key), Some(&0.8));
    }

    #[test]
    fn test_drug_target_interaction() {
        let mut model = BiomedicalEmbedding::new(BiomedicalEmbeddingConfig::default());
        model.add_drug_target_interaction("aspirin", "COX1", 0.9);

        let key = ("aspirin".to_string(), "COX1".to_string());
        assert_eq!(model.features.drug_target_affinities.get(&key), Some(&0.9));
    }

    #[test]
    fn test_specialized_text_model_properties() {
        let scibert = SpecializedTextModel::SciBERT;
        assert_eq!(scibert.model_name(), "allenai/scibert_scivocab_uncased");
        assert_eq!(scibert.vocab_size(), 31090);
        assert_eq!(scibert.embedding_dim(), 768);
        assert_eq!(scibert.max_sequence_length(), 512);

        let codebert = SpecializedTextModel::CodeBERT;
        assert_eq!(codebert.model_name(), "microsoft/codebert-base");
        assert_eq!(codebert.vocab_size(), 50265);

        let biobert = SpecializedTextModel::BioBERT;
        assert_eq!(biobert.model_name(), "dmis-lab/biobert-base-cased-v1.2");
        assert_eq!(biobert.vocab_size(), 28996);
    }

    #[test]
    fn test_specialized_text_preprocessing_rules() {
        // Each domain model carries its characteristic preprocessing rules.
        let scibert_rules = SpecializedTextModel::SciBERT.get_preprocessing_rules();
        assert!(scibert_rules.contains(&PreprocessingRule::NormalizeScientificNotation));
        assert!(scibert_rules.contains(&PreprocessingRule::HandleChemicalFormulas));

        let codebert_rules = SpecializedTextModel::CodeBERT.get_preprocessing_rules();
        assert!(codebert_rules.contains(&PreprocessingRule::PreserveCodeTokens));
        assert!(codebert_rules.contains(&PreprocessingRule::HandleCamelCase));

        let biobert_rules = SpecializedTextModel::BioBERT.get_preprocessing_rules();
        assert!(biobert_rules.contains(&PreprocessingRule::NormalizeMedicalTerms));
        assert!(biobert_rules.contains(&PreprocessingRule::HandleGeneNames));
    }

    #[test]
    fn test_specialized_text_config_factory_methods() {
        let scibert_config = SpecializedTextEmbedding::scibert_config();
        assert_eq!(scibert_config.model_type, SpecializedTextModel::SciBERT);
        assert_eq!(scibert_config.base_config.dimensions, 768);
        assert!(scibert_config.preprocessing_enabled);
        assert!(scibert_config.vocab_augmentation);
        assert!(scibert_config.domain_pretraining);

        let codebert_config = SpecializedTextEmbedding::codebert_config();
        assert_eq!(codebert_config.model_type, SpecializedTextModel::CodeBERT);
        assert!(!codebert_config.vocab_augmentation);

        let biobert_config = SpecializedTextEmbedding::biobert_config();
        assert_eq!(biobert_config.model_type, SpecializedTextModel::BioBERT);
        assert!(biobert_config.fine_tune_config.freeze_base_layers);
        assert_eq!(biobert_config.fine_tune_config.frozen_layers, 6);
        assert!(biobert_config.fine_tune_config.gradual_unfreezing);
    }

    #[test]
    fn test_specialized_text_embedding_creation() {
        let model = SpecializedTextEmbedding::new(SpecializedTextEmbedding::scibert_config());

        assert!(model.model_stats.model_type.contains("SciBERT"));
        assert_eq!(model.model_stats.dimensions, 768);
        assert!(!model.is_trained);
        assert_eq!(model.text_embeddings.len(), 0);
        // SciBERT defines exactly four preprocessing rules.
        assert_eq!(model.preprocessing_rules.len(), 4);
    }

    #[test]
    fn test_preprocessing_medical_terms() {
        let model = SpecializedTextEmbedding::new(SpecializedTextEmbedding::biobert_config());

        let processed = model
            .preprocess_text("Patient takes 100 mg/kg b.i.d. for treatment")
            .unwrap();

        // Dosage and frequency abbreviations are expanded.
        assert!(processed.contains("milligrams per kilogram"));
        assert!(processed.contains("twice daily"));
    }

    #[test]
    fn test_preprocessing_disabled() {
        let mut config = SpecializedTextEmbedding::biobert_config();
        config.preprocessing_enabled = false;
        let model = SpecializedTextEmbedding::new(config);

        let text = "Patient takes 100 mg/kg b.i.d. for treatment";
        let processed = model.preprocess_text(text).unwrap();

        // With preprocessing off, the text must pass through untouched.
        assert_eq!(processed, text);
    }

    #[tokio::test]
    async fn test_specialized_text_encoding() {
        let mut model = SpecializedTextEmbedding::new(SpecializedTextEmbedding::scibert_config());

        let text = "The protein folding study shows significant results with p < 0.001";
        let first = model.encode_text(text).await.unwrap();
        assert_eq!(first.len(), 768);

        // A second encoding is served from the cache and must be identical.
        let second = model.encode_text(text).await.unwrap();
        assert_eq!(first.to_vec(), second.to_vec());
        assert_eq!(model.text_embeddings.len(), 1);
    }

    #[tokio::test]
    async fn test_domain_specific_features() {
        // SciBERT: citation and figure/table cues, amplified by the 1.2x
        // domain-pretraining factor.
        let mut model = SpecializedTextEmbedding::new(SpecializedTextEmbedding::scibert_config());
        let scientific_text = "The study by Smith et al. shows figure 1 demonstrates the results";
        let embedding = model.encode_text(scientific_text).await.unwrap();
        assert_eq!(embedding[0], 1.2);
        assert_eq!(embedding[1], 1.2);

        // CodeBERT: function keyword cue plus nonzero bracket density.
        let mut model = SpecializedTextEmbedding::new(SpecializedTextEmbedding::codebert_config());
        let code_text = "function calculateSum() { return a + b; }";
        let embedding = model.encode_text(code_text).await.unwrap();
        assert_eq!(embedding[0], 1.2);
        assert!(embedding[2] > 0.0);

        // BioBERT: protein, disease, and dosage cues all trigger.
        let mut model = SpecializedTextEmbedding::new(SpecializedTextEmbedding::biobert_config());
        let biomedical_text =
            "The protein expression correlates with cancer disease progression, dose 100mg";
        let embedding = model.encode_text(biomedical_text).await.unwrap();
        assert_eq!(embedding[0], 1.2);
        assert_eq!(embedding[1], 1.2);
        assert_eq!(embedding[2], 1.2);
    }

    #[tokio::test]
    async fn test_fine_tuning() {
        let mut model = SpecializedTextEmbedding::new(SpecializedTextEmbedding::biobert_config());

        let training_texts = vec![
            "Gene expression analysis in cancer cells".to_string(),
            "Protein folding mechanisms in disease".to_string(),
            "Drug interaction with target proteins".to_string(),
        ];

        let stats = model.fine_tune(training_texts).await.unwrap();

        assert!(model.is_trained);
        // The BioBERT preset schedules five epochs.
        assert_eq!(stats.epochs_completed, 5);
        assert!(stats.training_time_seconds > 0.0);
        assert!(!stats.loss_history.is_empty());
        assert!(model.model_stats.is_trained);
        assert!(model.model_stats.last_training_time.is_some());
    }
}