oxirs_embed/biomedical_embeddings/
embedding.rs

1//! Module for biomedical embeddings
2
3use super::*;
4use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
5use anyhow::{anyhow, Result};
6use async_trait::async_trait;
7use chrono::Utc;
8use scirs2_core::ndarray_ext::Array1;
9use scirs2_core::random::{Random, Rng};
10use std::collections::HashMap;
11use uuid::Uuid;
12
13impl BiomedicalEmbedding {
14    /// Create new biomedical embedding model
15    pub fn new(config: BiomedicalEmbeddingConfig) -> Self {
16        let model_id = Uuid::new_v4();
17        let now = Utc::now();
18
19        Self {
20            model_id,
21            gene_embeddings: HashMap::new(),
22            protein_embeddings: HashMap::new(),
23            disease_embeddings: HashMap::new(),
24            drug_embeddings: HashMap::new(),
25            compound_embeddings: HashMap::new(),
26            pathway_embeddings: HashMap::new(),
27            relation_embeddings: HashMap::new(),
28            entity_types: HashMap::new(),
29            relation_types: HashMap::new(),
30            triples: Vec::new(),
31            features: BiomedicalFeatures::default(),
32            training_stats: TrainingStats::default(),
33            model_stats: ModelStats {
34                num_entities: 0,
35                num_relations: 0,
36                num_triples: 0,
37                dimensions: config.base_config.dimensions,
38                is_trained: false,
39                model_type: "BiomedicalEmbedding".to_string(),
40                creation_time: now,
41                last_training_time: None,
42            },
43            is_trained: false,
44            config,
45        }
46    }
47
48    /// Get the model type identifier
49    pub fn model_type(&self) -> &str {
50        "BiomedicalEmbedding"
51    }
52
53    /// Check if the model has been trained
54    pub fn is_trained(&self) -> bool {
55        self.is_trained
56    }
57
58    /// Add gene-disease association
59    pub fn add_gene_disease_association(&mut self, gene: &str, disease: &str, score: f32) {
60        self.features
61            .gene_disease_associations
62            .insert((gene.to_string(), disease.to_string()), score);
63
64        // Also add reverse mapping
65        self.features
66            .gene_disease_associations
67            .insert((disease.to_string(), gene.to_string()), score);
68    }
69
70    /// Add drug-target interaction
71    pub fn add_drug_target_interaction(&mut self, drug: &str, target: &str, affinity: f32) {
72        self.features
73            .drug_target_affinities
74            .insert((drug.to_string(), target.to_string()), affinity);
75    }
76
77    /// Add pathway membership
78    pub fn add_pathway_membership(&mut self, entity: &str, pathway: &str, score: f32) {
79        self.features
80            .pathway_memberships
81            .insert((entity.to_string(), pathway.to_string()), score);
82    }
83
84    /// Add protein-protein interaction
85    pub fn add_protein_interaction(&mut self, protein1: &str, protein2: &str, score: f32) {
86        self.features
87            .protein_interactions
88            .insert((protein1.to_string(), protein2.to_string()), score);
89
90        // Symmetric relationship
91        self.features
92            .protein_interactions
93            .insert((protein2.to_string(), protein1.to_string()), score);
94    }
95
96    /// Get entity embedding with biomedical type awareness
97    pub fn get_typed_entity_embedding(&self, entity: &str) -> Result<Vector> {
98        if let Some(entity_type) = self.entity_types.get(entity) {
99            let embedding = match entity_type {
100                BiomedicalEntityType::Gene => self.gene_embeddings.get(entity),
101                BiomedicalEntityType::Protein => self.protein_embeddings.get(entity),
102                BiomedicalEntityType::Disease => self.disease_embeddings.get(entity),
103                BiomedicalEntityType::Drug => self.drug_embeddings.get(entity),
104                BiomedicalEntityType::Compound => self.compound_embeddings.get(entity),
105                BiomedicalEntityType::Pathway => self.pathway_embeddings.get(entity),
106                _ => None,
107            };
108
109            if let Some(emb) = embedding {
110                Ok(Vector::from_array1(emb))
111            } else {
112                Err(anyhow!(
113                    "No embedding found for {} of type {:?}",
114                    entity,
115                    entity_type
116                ))
117            }
118        } else {
119            Err(anyhow!("Unknown entity type for {}", entity))
120        }
121    }
122
123    /// Predict gene-disease associations
124    pub fn predict_gene_disease_associations(
125        &self,
126        gene: &str,
127        k: usize,
128    ) -> Result<Vec<(String, f64)>> {
129        if !self.is_trained {
130            return Err(anyhow!("Model not trained"));
131        }
132
133        let gene_embedding = self
134            .gene_embeddings
135            .get(gene)
136            .ok_or_else(|| anyhow!("Gene {} not found", gene))?;
137
138        let mut scores = Vec::new();
139
140        for (disease, disease_embedding) in &self.disease_embeddings {
141            // Base similarity
142            let similarity = gene_embedding.dot(disease_embedding) as f64;
143
144            // Enhance with existing association data
145            let enhanced_score = if let Some(&assoc_score) = self
146                .features
147                .gene_disease_associations
148                .get(&(gene.to_string(), disease.clone()))
149            {
150                similarity * (1.0 + assoc_score as f64)
151            } else {
152                similarity
153            };
154
155            scores.push((disease.clone(), enhanced_score));
156        }
157
158        // Sort by score and return top k
159        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
160        scores.truncate(k);
161
162        Ok(scores)
163    }
164
165    /// Predict drug targets
166    pub fn predict_drug_targets(&self, drug: &str, k: usize) -> Result<Vec<(String, f64)>> {
167        if !self.is_trained {
168            return Err(anyhow!("Model not trained"));
169        }
170
171        let drug_embedding = self
172            .drug_embeddings
173            .get(drug)
174            .ok_or_else(|| anyhow!("Drug {} not found", drug))?;
175
176        let mut scores = Vec::new();
177
178        for (protein, protein_embedding) in &self.protein_embeddings {
179            // Base similarity
180            let similarity = drug_embedding.dot(protein_embedding) as f64;
181
182            // Enhance with binding affinity data
183            let enhanced_score = if let Some(&affinity) = self
184                .features
185                .drug_target_affinities
186                .get(&(drug.to_string(), protein.clone()))
187            {
188                similarity * (1.0 + affinity as f64)
189            } else {
190                similarity
191            };
192
193            scores.push((protein.clone(), enhanced_score));
194        }
195
196        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
197        scores.truncate(k);
198
199        Ok(scores)
200    }
201
202    /// Find pathway-related entities
203    pub fn find_pathway_entities(&self, pathway: &str, k: usize) -> Result<Vec<(String, f64)>> {
204        let pathway_embedding = self
205            .pathway_embeddings
206            .get(pathway)
207            .ok_or_else(|| anyhow!("Pathway {} not found", pathway))?;
208
209        let mut scores = Vec::new();
210
211        // Check genes
212        for (gene, gene_embedding) in &self.gene_embeddings {
213            let similarity = pathway_embedding.dot(gene_embedding) as f64;
214            scores.push((gene.clone(), similarity));
215        }
216
217        // Check proteins
218        for (protein, protein_embedding) in &self.protein_embeddings {
219            let similarity = pathway_embedding.dot(protein_embedding) as f64;
220            scores.push((protein.clone(), similarity));
221        }
222
223        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
224        scores.truncate(k);
225
226        Ok(scores)
227    }
228
229    /// Extract entity types from triples
230    fn extract_entity_types(&mut self) {
231        for triple in &self.triples {
232            // Extract entity types from IRIs
233            if let Some(subject_type) = BiomedicalEntityType::from_iri(&triple.subject.iri) {
234                self.entity_types
235                    .insert(triple.subject.iri.clone(), subject_type);
236            }
237
238            if let Some(object_type) = BiomedicalEntityType::from_iri(&triple.object.iri) {
239                self.entity_types
240                    .insert(triple.object.iri.clone(), object_type);
241            }
242
243            // Extract relation types
244            if let Some(relation_type) = BiomedicalRelationType::from_iri(&triple.predicate.iri) {
245                self.relation_types
246                    .insert(triple.predicate.iri.clone(), relation_type);
247            }
248        }
249    }
250
251    /// Initialize embeddings with biomedical-specific features
252    fn initialize_embeddings(&mut self) -> Result<()> {
253        let dimensions = self.config.base_config.dimensions;
254
255        // Initialize embeddings for each entity type
256        for (entity, entity_type) in &self.entity_types {
257            let embedding = Array1::from_vec(
258                (0..dimensions)
259                    .map(|_| {
260                        let mut random = Random::default();
261                        (random.random::<f32>() - 0.5) * 0.1
262                    })
263                    .collect(),
264            );
265
266            match entity_type {
267                BiomedicalEntityType::Gene => {
268                    self.gene_embeddings.insert(entity.clone(), embedding);
269                }
270                BiomedicalEntityType::Protein => {
271                    self.protein_embeddings.insert(entity.clone(), embedding);
272                }
273                BiomedicalEntityType::Disease => {
274                    self.disease_embeddings.insert(entity.clone(), embedding);
275                }
276                BiomedicalEntityType::Drug => {
277                    self.drug_embeddings.insert(entity.clone(), embedding);
278                }
279                BiomedicalEntityType::Compound => {
280                    self.compound_embeddings.insert(entity.clone(), embedding);
281                }
282                BiomedicalEntityType::Pathway => {
283                    self.pathway_embeddings.insert(entity.clone(), embedding);
284                }
285                _ => {
286                    // For other types, store in a general embedding map
287                    // This would be extended in a full implementation
288                }
289            }
290        }
291
292        // Initialize relation embeddings
293        for relation in self.relation_types.keys() {
294            let embedding = Array1::from_vec(
295                (0..dimensions)
296                    .map(|_| {
297                        let mut random = Random::default();
298                        (random.random::<f32>() - 0.5) * 0.1
299                    })
300                    .collect(),
301            );
302            self.relation_embeddings.insert(relation.clone(), embedding);
303        }
304
305        Ok(())
306    }
307
308    /// Compute biomedical-specific loss incorporating domain knowledge
309    fn compute_biomedical_loss(&self) -> f32 {
310        let mut total_loss = 0.0;
311        let mut count = 0;
312
313        // Gene-disease association loss
314        for ((gene, disease), &score) in &self.features.gene_disease_associations {
315            if let (Some(gene_emb), Some(disease_emb)) = (
316                self.gene_embeddings.get(gene),
317                self.disease_embeddings.get(disease),
318            ) {
319                let predicted_score = gene_emb.dot(disease_emb);
320                let loss = (predicted_score - score).powi(2);
321                total_loss += loss * self.config.gene_disease_weight;
322                count += 1;
323            }
324        }
325
326        // Drug-target interaction loss
327        for ((drug, target), &affinity) in &self.features.drug_target_affinities {
328            if let (Some(drug_emb), Some(target_emb)) = (
329                self.drug_embeddings.get(drug),
330                self.protein_embeddings.get(target),
331            ) {
332                let predicted_affinity = drug_emb.dot(target_emb);
333                let loss = (predicted_affinity - affinity).powi(2);
334                total_loss += loss * self.config.drug_target_weight;
335                count += 1;
336            }
337        }
338
339        // Pathway membership loss
340        for ((entity, pathway), &score) in &self.features.pathway_memberships {
341            if let Some(pathway_emb) = self.pathway_embeddings.get(pathway) {
342                let entity_emb = self.get_entity_embedding_any_type(entity);
343                if let Some(entity_emb) = entity_emb {
344                    let predicted_score = entity_emb.dot(pathway_emb);
345                    let loss = (predicted_score - score).powi(2);
346                    total_loss += loss * self.config.pathway_weight;
347                    count += 1;
348                }
349            }
350        }
351
352        if count > 0 {
353            total_loss / count as f32
354        } else {
355            0.0
356        }
357    }
358
359    /// Helper to get entity embedding from any type map
360    fn get_entity_embedding_any_type(&self, entity: &str) -> Option<&Array1<f32>> {
361        self.gene_embeddings
362            .get(entity)
363            .or_else(|| self.protein_embeddings.get(entity))
364            .or_else(|| self.disease_embeddings.get(entity))
365            .or_else(|| self.drug_embeddings.get(entity))
366            .or_else(|| self.compound_embeddings.get(entity))
367            .or_else(|| self.pathway_embeddings.get(entity))
368    }
369}
370
371#[async_trait]
372impl EmbeddingModel for BiomedicalEmbedding {
373    fn config(&self) -> &ModelConfig {
374        &self.config.base_config
375    }
376
377    fn model_id(&self) -> &Uuid {
378        &self.model_id
379    }
380
381    fn model_type(&self) -> &'static str {
382        "BiomedicalEmbedding"
383    }
384
385    fn add_triple(&mut self, triple: Triple) -> Result<()> {
386        self.triples.push(triple);
387        Ok(())
388    }
389
390    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
391        let epochs = epochs.unwrap_or(1000);
392        let start_time = std::time::Instant::now();
393
394        // Extract entity and relation types
395        self.extract_entity_types();
396
397        // Initialize embeddings
398        self.initialize_embeddings()?;
399
400        // Training loop
401        let mut loss_history = Vec::new();
402
403        for epoch in 0..epochs {
404            let epoch_loss = self.compute_biomedical_loss();
405            loss_history.push(epoch_loss as f64);
406
407            // Simple convergence check
408            if epoch > 10 && epoch_loss < 0.001 {
409                break;
410            }
411
412            if epoch % 100 == 0 {
413                println!("Epoch {epoch}: Loss = {epoch_loss:.6}");
414            }
415        }
416
417        let training_time = start_time.elapsed().as_secs_f64();
418
419        self.training_stats = TrainingStats {
420            epochs_completed: epochs,
421            final_loss: loss_history.last().copied().unwrap_or(0.0),
422            training_time_seconds: training_time,
423            convergence_achieved: loss_history.last().is_some_and(|&loss| loss < 0.001),
424            loss_history,
425        };
426
427        self.is_trained = true;
428        self.model_stats.is_trained = true;
429        self.model_stats.last_training_time = Some(Utc::now());
430
431        // Update entity counts
432        self.model_stats.num_entities = self.entity_types.len();
433        self.model_stats.num_relations = self.relation_types.len();
434        self.model_stats.num_triples = self.triples.len();
435
436        Ok(self.training_stats.clone())
437    }
438
439    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
440        self.get_typed_entity_embedding(entity)
441    }
442
443    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
444        if let Some(embedding) = self.relation_embeddings.get(relation) {
445            Ok(Vector::from_array1(embedding))
446        } else {
447            Err(anyhow!("Relation {} not found", relation))
448        }
449    }
450
451    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
452        let subject_emb = self.get_entity_embedding(subject)?;
453        let relation_emb = self.get_relation_embedding(predicate)?;
454        let object_emb = self.get_entity_embedding(object)?;
455
456        // TransE-style scoring with biomedical enhancements
457        let mut score = 0.0;
458        for i in 0..subject_emb.dimensions {
459            let diff = subject_emb.values[i] + relation_emb.values[i] - object_emb.values[i];
460            score += diff * diff;
461        }
462
463        // Convert to similarity score (higher is better)
464        Ok(1.0 / (1.0 + score as f64))
465    }
466
467    fn predict_objects(
468        &self,
469        subject: &str,
470        predicate: &str,
471        k: usize,
472    ) -> Result<Vec<(String, f64)>> {
473        // Use specialized prediction methods based on relation type
474        if let Some(relation_type) = self.relation_types.get(predicate) {
475            match relation_type {
476                BiomedicalRelationType::CausesDisease
477                | BiomedicalRelationType::AssociatedWithDisease => {
478                    return self.predict_gene_disease_associations(subject, k);
479                }
480                BiomedicalRelationType::TargetsProtein | BiomedicalRelationType::BindsToProtein => {
481                    return self.predict_drug_targets(subject, k);
482                }
483                _ => {
484                    // Fall back to generic prediction
485                }
486            }
487        }
488
489        // Generic prediction
490        let _subject_emb = self.get_entity_embedding(subject)?;
491        let _relation_emb = self.get_relation_embedding(predicate)?;
492
493        let mut scores = Vec::new();
494        for entity in self.entity_types.keys() {
495            if entity != subject {
496                if let Ok(score) = self.score_triple(subject, predicate, entity) {
497                    scores.push((entity.clone(), score));
498                }
499            }
500        }
501
502        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
503        scores.truncate(k);
504
505        Ok(scores)
506    }
507
508    fn predict_subjects(
509        &self,
510        predicate: &str,
511        object: &str,
512        k: usize,
513    ) -> Result<Vec<(String, f64)>> {
514        let _object_emb = self.get_entity_embedding(object)?;
515        let _relation_emb = self.get_relation_embedding(predicate)?;
516
517        let mut scores = Vec::new();
518        for entity in self.entity_types.keys() {
519            if entity != object {
520                if let Ok(score) = self.score_triple(entity, predicate, object) {
521                    scores.push((entity.clone(), score));
522                }
523            }
524        }
525
526        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
527        scores.truncate(k);
528
529        Ok(scores)
530    }
531
532    fn predict_relations(
533        &self,
534        subject: &str,
535        object: &str,
536        k: usize,
537    ) -> Result<Vec<(String, f64)>> {
538        let _subject_emb = self.get_entity_embedding(subject)?;
539        let _object_emb = self.get_entity_embedding(object)?;
540
541        let mut scores = Vec::new();
542        for relation in self.relation_types.keys() {
543            if let Ok(score) = self.score_triple(subject, relation, object) {
544                scores.push((relation.clone(), score));
545            }
546        }
547
548        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
549        scores.truncate(k);
550
551        Ok(scores)
552    }
553
554    fn get_entities(&self) -> Vec<String> {
555        self.entity_types.keys().cloned().collect()
556    }
557
558    fn get_relations(&self) -> Vec<String> {
559        self.relation_types.keys().cloned().collect()
560    }
561
562    fn get_stats(&self) -> ModelStats {
563        self.model_stats.clone()
564    }
565
566    fn save(&self, _path: &str) -> Result<()> {
567        // Implementation would serialize the model
568        Ok(())
569    }
570
571    fn load(&mut self, _path: &str) -> Result<()> {
572        // Implementation would deserialize the model
573        Ok(())
574    }
575
576    fn clear(&mut self) {
577        self.gene_embeddings.clear();
578        self.protein_embeddings.clear();
579        self.disease_embeddings.clear();
580        self.drug_embeddings.clear();
581        self.compound_embeddings.clear();
582        self.pathway_embeddings.clear();
583        self.relation_embeddings.clear();
584        self.entity_types.clear();
585        self.relation_types.clear();
586        self.triples.clear();
587        self.features = BiomedicalFeatures::default();
588        self.is_trained = false;
589    }
590
591    fn is_trained(&self) -> bool {
592        self.is_trained
593    }
594
595    async fn encode(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
596        let mut embeddings = Vec::new();
597
598        for text in texts {
599            match self.get_entity_embedding(text) {
600                Ok(embedding) => {
601                    embeddings.push(embedding.values);
602                }
603                _ => {
604                    // Return zero embedding for unknown entities
605                    embeddings.push(vec![0.0; self.config.base_config.dimensions]);
606                }
607            }
608        }
609
610        Ok(embeddings)
611    }
612}