oxirs_embed/biomedical_embeddings/
embedding.rs1use super::*;
4use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
5use anyhow::{anyhow, Result};
6use async_trait::async_trait;
7use chrono::Utc;
8use scirs2_core::ndarray_ext::Array1;
9use scirs2_core::random::{Random, Rng};
10use std::collections::HashMap;
11use uuid::Uuid;
12
13impl BiomedicalEmbedding {
14 pub fn new(config: BiomedicalEmbeddingConfig) -> Self {
16 let model_id = Uuid::new_v4();
17 let now = Utc::now();
18
19 Self {
20 model_id,
21 gene_embeddings: HashMap::new(),
22 protein_embeddings: HashMap::new(),
23 disease_embeddings: HashMap::new(),
24 drug_embeddings: HashMap::new(),
25 compound_embeddings: HashMap::new(),
26 pathway_embeddings: HashMap::new(),
27 relation_embeddings: HashMap::new(),
28 entity_types: HashMap::new(),
29 relation_types: HashMap::new(),
30 triples: Vec::new(),
31 features: BiomedicalFeatures::default(),
32 training_stats: TrainingStats::default(),
33 model_stats: ModelStats {
34 num_entities: 0,
35 num_relations: 0,
36 num_triples: 0,
37 dimensions: config.base_config.dimensions,
38 is_trained: false,
39 model_type: "BiomedicalEmbedding".to_string(),
40 creation_time: now,
41 last_training_time: None,
42 },
43 is_trained: false,
44 config,
45 }
46 }
47
48 pub fn model_type(&self) -> &str {
50 "BiomedicalEmbedding"
51 }
52
53 pub fn is_trained(&self) -> bool {
55 self.is_trained
56 }
57
58 pub fn add_gene_disease_association(&mut self, gene: &str, disease: &str, score: f32) {
60 self.features
61 .gene_disease_associations
62 .insert((gene.to_string(), disease.to_string()), score);
63
64 self.features
66 .gene_disease_associations
67 .insert((disease.to_string(), gene.to_string()), score);
68 }
69
70 pub fn add_drug_target_interaction(&mut self, drug: &str, target: &str, affinity: f32) {
72 self.features
73 .drug_target_affinities
74 .insert((drug.to_string(), target.to_string()), affinity);
75 }
76
77 pub fn add_pathway_membership(&mut self, entity: &str, pathway: &str, score: f32) {
79 self.features
80 .pathway_memberships
81 .insert((entity.to_string(), pathway.to_string()), score);
82 }
83
84 pub fn add_protein_interaction(&mut self, protein1: &str, protein2: &str, score: f32) {
86 self.features
87 .protein_interactions
88 .insert((protein1.to_string(), protein2.to_string()), score);
89
90 self.features
92 .protein_interactions
93 .insert((protein2.to_string(), protein1.to_string()), score);
94 }
95
96 pub fn get_typed_entity_embedding(&self, entity: &str) -> Result<Vector> {
98 if let Some(entity_type) = self.entity_types.get(entity) {
99 let embedding = match entity_type {
100 BiomedicalEntityType::Gene => self.gene_embeddings.get(entity),
101 BiomedicalEntityType::Protein => self.protein_embeddings.get(entity),
102 BiomedicalEntityType::Disease => self.disease_embeddings.get(entity),
103 BiomedicalEntityType::Drug => self.drug_embeddings.get(entity),
104 BiomedicalEntityType::Compound => self.compound_embeddings.get(entity),
105 BiomedicalEntityType::Pathway => self.pathway_embeddings.get(entity),
106 _ => None,
107 };
108
109 if let Some(emb) = embedding {
110 Ok(Vector::from_array1(emb))
111 } else {
112 Err(anyhow!(
113 "No embedding found for {} of type {:?}",
114 entity,
115 entity_type
116 ))
117 }
118 } else {
119 Err(anyhow!("Unknown entity type for {}", entity))
120 }
121 }
122
123 pub fn predict_gene_disease_associations(
125 &self,
126 gene: &str,
127 k: usize,
128 ) -> Result<Vec<(String, f64)>> {
129 if !self.is_trained {
130 return Err(anyhow!("Model not trained"));
131 }
132
133 let gene_embedding = self
134 .gene_embeddings
135 .get(gene)
136 .ok_or_else(|| anyhow!("Gene {} not found", gene))?;
137
138 let mut scores = Vec::new();
139
140 for (disease, disease_embedding) in &self.disease_embeddings {
141 let similarity = gene_embedding.dot(disease_embedding) as f64;
143
144 let enhanced_score = if let Some(&assoc_score) = self
146 .features
147 .gene_disease_associations
148 .get(&(gene.to_string(), disease.clone()))
149 {
150 similarity * (1.0 + assoc_score as f64)
151 } else {
152 similarity
153 };
154
155 scores.push((disease.clone(), enhanced_score));
156 }
157
158 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
160 scores.truncate(k);
161
162 Ok(scores)
163 }
164
165 pub fn predict_drug_targets(&self, drug: &str, k: usize) -> Result<Vec<(String, f64)>> {
167 if !self.is_trained {
168 return Err(anyhow!("Model not trained"));
169 }
170
171 let drug_embedding = self
172 .drug_embeddings
173 .get(drug)
174 .ok_or_else(|| anyhow!("Drug {} not found", drug))?;
175
176 let mut scores = Vec::new();
177
178 for (protein, protein_embedding) in &self.protein_embeddings {
179 let similarity = drug_embedding.dot(protein_embedding) as f64;
181
182 let enhanced_score = if let Some(&affinity) = self
184 .features
185 .drug_target_affinities
186 .get(&(drug.to_string(), protein.clone()))
187 {
188 similarity * (1.0 + affinity as f64)
189 } else {
190 similarity
191 };
192
193 scores.push((protein.clone(), enhanced_score));
194 }
195
196 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
197 scores.truncate(k);
198
199 Ok(scores)
200 }
201
202 pub fn find_pathway_entities(&self, pathway: &str, k: usize) -> Result<Vec<(String, f64)>> {
204 let pathway_embedding = self
205 .pathway_embeddings
206 .get(pathway)
207 .ok_or_else(|| anyhow!("Pathway {} not found", pathway))?;
208
209 let mut scores = Vec::new();
210
211 for (gene, gene_embedding) in &self.gene_embeddings {
213 let similarity = pathway_embedding.dot(gene_embedding) as f64;
214 scores.push((gene.clone(), similarity));
215 }
216
217 for (protein, protein_embedding) in &self.protein_embeddings {
219 let similarity = pathway_embedding.dot(protein_embedding) as f64;
220 scores.push((protein.clone(), similarity));
221 }
222
223 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
224 scores.truncate(k);
225
226 Ok(scores)
227 }
228
229 fn extract_entity_types(&mut self) {
231 for triple in &self.triples {
232 if let Some(subject_type) = BiomedicalEntityType::from_iri(&triple.subject.iri) {
234 self.entity_types
235 .insert(triple.subject.iri.clone(), subject_type);
236 }
237
238 if let Some(object_type) = BiomedicalEntityType::from_iri(&triple.object.iri) {
239 self.entity_types
240 .insert(triple.object.iri.clone(), object_type);
241 }
242
243 if let Some(relation_type) = BiomedicalRelationType::from_iri(&triple.predicate.iri) {
245 self.relation_types
246 .insert(triple.predicate.iri.clone(), relation_type);
247 }
248 }
249 }
250
251 fn initialize_embeddings(&mut self) -> Result<()> {
253 let dimensions = self.config.base_config.dimensions;
254
255 for (entity, entity_type) in &self.entity_types {
257 let embedding = Array1::from_vec(
258 (0..dimensions)
259 .map(|_| {
260 let mut random = Random::default();
261 (random.random::<f32>() - 0.5) * 0.1
262 })
263 .collect(),
264 );
265
266 match entity_type {
267 BiomedicalEntityType::Gene => {
268 self.gene_embeddings.insert(entity.clone(), embedding);
269 }
270 BiomedicalEntityType::Protein => {
271 self.protein_embeddings.insert(entity.clone(), embedding);
272 }
273 BiomedicalEntityType::Disease => {
274 self.disease_embeddings.insert(entity.clone(), embedding);
275 }
276 BiomedicalEntityType::Drug => {
277 self.drug_embeddings.insert(entity.clone(), embedding);
278 }
279 BiomedicalEntityType::Compound => {
280 self.compound_embeddings.insert(entity.clone(), embedding);
281 }
282 BiomedicalEntityType::Pathway => {
283 self.pathway_embeddings.insert(entity.clone(), embedding);
284 }
285 _ => {
286 }
289 }
290 }
291
292 for relation in self.relation_types.keys() {
294 let embedding = Array1::from_vec(
295 (0..dimensions)
296 .map(|_| {
297 let mut random = Random::default();
298 (random.random::<f32>() - 0.5) * 0.1
299 })
300 .collect(),
301 );
302 self.relation_embeddings.insert(relation.clone(), embedding);
303 }
304
305 Ok(())
306 }
307
308 fn compute_biomedical_loss(&self) -> f32 {
310 let mut total_loss = 0.0;
311 let mut count = 0;
312
313 for ((gene, disease), &score) in &self.features.gene_disease_associations {
315 if let (Some(gene_emb), Some(disease_emb)) = (
316 self.gene_embeddings.get(gene),
317 self.disease_embeddings.get(disease),
318 ) {
319 let predicted_score = gene_emb.dot(disease_emb);
320 let loss = (predicted_score - score).powi(2);
321 total_loss += loss * self.config.gene_disease_weight;
322 count += 1;
323 }
324 }
325
326 for ((drug, target), &affinity) in &self.features.drug_target_affinities {
328 if let (Some(drug_emb), Some(target_emb)) = (
329 self.drug_embeddings.get(drug),
330 self.protein_embeddings.get(target),
331 ) {
332 let predicted_affinity = drug_emb.dot(target_emb);
333 let loss = (predicted_affinity - affinity).powi(2);
334 total_loss += loss * self.config.drug_target_weight;
335 count += 1;
336 }
337 }
338
339 for ((entity, pathway), &score) in &self.features.pathway_memberships {
341 if let Some(pathway_emb) = self.pathway_embeddings.get(pathway) {
342 let entity_emb = self.get_entity_embedding_any_type(entity);
343 if let Some(entity_emb) = entity_emb {
344 let predicted_score = entity_emb.dot(pathway_emb);
345 let loss = (predicted_score - score).powi(2);
346 total_loss += loss * self.config.pathway_weight;
347 count += 1;
348 }
349 }
350 }
351
352 if count > 0 {
353 total_loss / count as f32
354 } else {
355 0.0
356 }
357 }
358
359 fn get_entity_embedding_any_type(&self, entity: &str) -> Option<&Array1<f32>> {
361 self.gene_embeddings
362 .get(entity)
363 .or_else(|| self.protein_embeddings.get(entity))
364 .or_else(|| self.disease_embeddings.get(entity))
365 .or_else(|| self.drug_embeddings.get(entity))
366 .or_else(|| self.compound_embeddings.get(entity))
367 .or_else(|| self.pathway_embeddings.get(entity))
368 }
369}
370
371#[async_trait]
372impl EmbeddingModel for BiomedicalEmbedding {
373 fn config(&self) -> &ModelConfig {
374 &self.config.base_config
375 }
376
377 fn model_id(&self) -> &Uuid {
378 &self.model_id
379 }
380
381 fn model_type(&self) -> &'static str {
382 "BiomedicalEmbedding"
383 }
384
385 fn add_triple(&mut self, triple: Triple) -> Result<()> {
386 self.triples.push(triple);
387 Ok(())
388 }
389
390 async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
391 let epochs = epochs.unwrap_or(1000);
392 let start_time = std::time::Instant::now();
393
394 self.extract_entity_types();
396
397 self.initialize_embeddings()?;
399
400 let mut loss_history = Vec::new();
402
403 for epoch in 0..epochs {
404 let epoch_loss = self.compute_biomedical_loss();
405 loss_history.push(epoch_loss as f64);
406
407 if epoch > 10 && epoch_loss < 0.001 {
409 break;
410 }
411
412 if epoch % 100 == 0 {
413 println!("Epoch {epoch}: Loss = {epoch_loss:.6}");
414 }
415 }
416
417 let training_time = start_time.elapsed().as_secs_f64();
418
419 self.training_stats = TrainingStats {
420 epochs_completed: epochs,
421 final_loss: loss_history.last().copied().unwrap_or(0.0),
422 training_time_seconds: training_time,
423 convergence_achieved: loss_history.last().is_some_and(|&loss| loss < 0.001),
424 loss_history,
425 };
426
427 self.is_trained = true;
428 self.model_stats.is_trained = true;
429 self.model_stats.last_training_time = Some(Utc::now());
430
431 self.model_stats.num_entities = self.entity_types.len();
433 self.model_stats.num_relations = self.relation_types.len();
434 self.model_stats.num_triples = self.triples.len();
435
436 Ok(self.training_stats.clone())
437 }
438
439 fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
440 self.get_typed_entity_embedding(entity)
441 }
442
443 fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
444 if let Some(embedding) = self.relation_embeddings.get(relation) {
445 Ok(Vector::from_array1(embedding))
446 } else {
447 Err(anyhow!("Relation {} not found", relation))
448 }
449 }
450
451 fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
452 let subject_emb = self.get_entity_embedding(subject)?;
453 let relation_emb = self.get_relation_embedding(predicate)?;
454 let object_emb = self.get_entity_embedding(object)?;
455
456 let mut score = 0.0;
458 for i in 0..subject_emb.dimensions {
459 let diff = subject_emb.values[i] + relation_emb.values[i] - object_emb.values[i];
460 score += diff * diff;
461 }
462
463 Ok(1.0 / (1.0 + score as f64))
465 }
466
467 fn predict_objects(
468 &self,
469 subject: &str,
470 predicate: &str,
471 k: usize,
472 ) -> Result<Vec<(String, f64)>> {
473 if let Some(relation_type) = self.relation_types.get(predicate) {
475 match relation_type {
476 BiomedicalRelationType::CausesDisease
477 | BiomedicalRelationType::AssociatedWithDisease => {
478 return self.predict_gene_disease_associations(subject, k);
479 }
480 BiomedicalRelationType::TargetsProtein | BiomedicalRelationType::BindsToProtein => {
481 return self.predict_drug_targets(subject, k);
482 }
483 _ => {
484 }
486 }
487 }
488
489 let _subject_emb = self.get_entity_embedding(subject)?;
491 let _relation_emb = self.get_relation_embedding(predicate)?;
492
493 let mut scores = Vec::new();
494 for entity in self.entity_types.keys() {
495 if entity != subject {
496 if let Ok(score) = self.score_triple(subject, predicate, entity) {
497 scores.push((entity.clone(), score));
498 }
499 }
500 }
501
502 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
503 scores.truncate(k);
504
505 Ok(scores)
506 }
507
508 fn predict_subjects(
509 &self,
510 predicate: &str,
511 object: &str,
512 k: usize,
513 ) -> Result<Vec<(String, f64)>> {
514 let _object_emb = self.get_entity_embedding(object)?;
515 let _relation_emb = self.get_relation_embedding(predicate)?;
516
517 let mut scores = Vec::new();
518 for entity in self.entity_types.keys() {
519 if entity != object {
520 if let Ok(score) = self.score_triple(entity, predicate, object) {
521 scores.push((entity.clone(), score));
522 }
523 }
524 }
525
526 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
527 scores.truncate(k);
528
529 Ok(scores)
530 }
531
532 fn predict_relations(
533 &self,
534 subject: &str,
535 object: &str,
536 k: usize,
537 ) -> Result<Vec<(String, f64)>> {
538 let _subject_emb = self.get_entity_embedding(subject)?;
539 let _object_emb = self.get_entity_embedding(object)?;
540
541 let mut scores = Vec::new();
542 for relation in self.relation_types.keys() {
543 if let Ok(score) = self.score_triple(subject, relation, object) {
544 scores.push((relation.clone(), score));
545 }
546 }
547
548 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
549 scores.truncate(k);
550
551 Ok(scores)
552 }
553
554 fn get_entities(&self) -> Vec<String> {
555 self.entity_types.keys().cloned().collect()
556 }
557
558 fn get_relations(&self) -> Vec<String> {
559 self.relation_types.keys().cloned().collect()
560 }
561
562 fn get_stats(&self) -> ModelStats {
563 self.model_stats.clone()
564 }
565
566 fn save(&self, _path: &str) -> Result<()> {
567 Ok(())
569 }
570
571 fn load(&mut self, _path: &str) -> Result<()> {
572 Ok(())
574 }
575
576 fn clear(&mut self) {
577 self.gene_embeddings.clear();
578 self.protein_embeddings.clear();
579 self.disease_embeddings.clear();
580 self.drug_embeddings.clear();
581 self.compound_embeddings.clear();
582 self.pathway_embeddings.clear();
583 self.relation_embeddings.clear();
584 self.entity_types.clear();
585 self.relation_types.clear();
586 self.triples.clear();
587 self.features = BiomedicalFeatures::default();
588 self.is_trained = false;
589 }
590
591 fn is_trained(&self) -> bool {
592 self.is_trained
593 }
594
595 async fn encode(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
596 let mut embeddings = Vec::new();
597
598 for text in texts {
599 match self.get_entity_embedding(text) {
600 Ok(embedding) => {
601 embeddings.push(embedding.values);
602 }
603 _ => {
604 embeddings.push(vec![0.0; self.config.base_config.dimensions]);
606 }
607 }
608 }
609
610 Ok(embeddings)
611 }
612}