use crate::{ModelConfig, ModelStats, TrainingStats, Triple};
use scirs2_core::ndarray_ext::Array1;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// Kind of biomedical entity handled by the embedding model.
///
/// Each variant maps to a short lower-case namespace string via
/// `BiomedicalEntityType::namespace`.
// Keep the variant order stable: serde's derived impls encode a unit variant
// by name *and* index, and index-based formats (e.g. bincode) depend on it.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BiomedicalEntityType {
    /// Gene (namespace `gene`).
    Gene,
    /// Protein (namespace `protein`).
    Protein,
    /// Disease (namespace `disease`).
    Disease,
    /// Drug (namespace `drug`).
    Drug,
    /// Chemical compound (namespace `compound`).
    Compound,
    /// Biological pathway (namespace `pathway`).
    Pathway,
    /// Cell or cell type (namespace `cell`).
    Cell,
    /// Tissue (namespace `tissue`).
    Tissue,
    /// Organ (namespace `organ`).
    Organ,
    /// Phenotype (namespace `phenotype`).
    Phenotype,
    /// Gene Ontology term (namespace `go`).
    GoTerm,
    /// MeSH term (namespace `mesh`).
    MeshTerm,
    /// SNOMED CT concept (namespace `snomed`).
    SnomedCt,
    /// ICD code (namespace `icd`).
    IcdCode,
}
impl BiomedicalEntityType {
    /// Returns the short, lower-case namespace token for this entity type
    /// (e.g. `"gene"`, `"go"`, `"snomed"`).
    pub fn namespace(&self) -> &'static str {
        match self {
            BiomedicalEntityType::Gene => "gene",
            BiomedicalEntityType::Protein => "protein",
            BiomedicalEntityType::Disease => "disease",
            BiomedicalEntityType::Drug => "drug",
            BiomedicalEntityType::Compound => "compound",
            BiomedicalEntityType::Pathway => "pathway",
            BiomedicalEntityType::Cell => "cell",
            BiomedicalEntityType::Tissue => "tissue",
            BiomedicalEntityType::Organ => "organ",
            BiomedicalEntityType::Phenotype => "phenotype",
            BiomedicalEntityType::GoTerm => "go",
            BiomedicalEntityType::MeshTerm => "mesh",
            BiomedicalEntityType::SnomedCt => "snomed",
            BiomedicalEntityType::IcdCode => "icd",
        }
    }

    /// Infers an entity type from an IRI or CURIE using substring heuristics.
    ///
    /// Rules are tried in this order and the first match wins:
    ///
    /// 1. Gene Ontology identifiers: `GO:` CURIEs, OBO PURLs containing
    ///    `obo/GO_`, or anything mentioning `geneontology`. These come first
    ///    because `geneontology` itself contains `gene`, which would otherwise
    ///    classify GO terms as `Gene`.
    /// 2. A generic type word matched case-sensitively (`gene`, `protein`,
    ///    `disease`, `drug`, `compound`, `pathway`), or a source-database name
    ///    matched case-insensitively (`HGNC`, `UniProt`, `OMIM`, `DOID`,
    ///    `DrugBank`, `CHEBI`, `KEGG`, `Reactome`). Case-insensitive matching
    ///    matters because canonical IRIs are often lower-case, e.g.
    ///    `http://purl.uniprot.org/uniprot/P04637`.
    /// 3. Terminology names, case-insensitively: `mesh`, `snomed`, `icd`
    ///    (e.g. `http://snomed.info/id/22298006`).
    /// 4. A token delimited by non-alphanumeric characters that equals
    ///    (ignoring case) `cell`, `tissue`, `organ`, `phenotype` or `go`, or
    ///    the OBO prefixes `CL` (Cell Ontology) and `HP` (Human Phenotype
    ///    Ontology). Whole-token matching keeps e.g. `organism` from being
    ///    read as an organ.
    ///
    /// Consequently every `namespace()` value round-trips when used as a path
    /// segment, e.g. `http://example.org/<namespace>/...`.
    ///
    /// Returns `None` if no rule matches. Rule 2 is substring-based and can
    /// give false positives (e.g. `general` contains `gene`).
    pub fn from_iri(iri: &str) -> Option<Self> {
        let lower = iri.to_ascii_lowercase();
        // Generic word as written, or any listed database name in any case.
        let mentions = |word: &str, databases: &[&str]| {
            iri.contains(word) || databases.iter().any(|db| lower.contains(*db))
        };

        if iri.contains("GO:") || iri.contains("obo/GO_") || lower.contains("geneontology") {
            Some(Self::GoTerm)
        } else if mentions("gene", &["hgnc"]) {
            Some(Self::Gene)
        } else if mentions("protein", &["uniprot"]) {
            Some(Self::Protein)
        } else if mentions("disease", &["omim", "doid"]) {
            Some(Self::Disease)
        } else if mentions("drug", &["drugbank"]) {
            Some(Self::Drug)
        } else if mentions("compound", &["chebi"]) {
            Some(Self::Compound)
        } else if mentions("pathway", &["kegg", "reactome"]) {
            Some(Self::Pathway)
        } else if lower.contains("mesh") {
            Some(Self::MeshTerm)
        } else if lower.contains("snomed") {
            Some(Self::SnomedCt)
        } else if lower.contains("icd") {
            Some(Self::IcdCode)
        } else {
            Self::from_namespace_token(iri)
        }
    }

    /// Fallback for `from_iri`: returns the type whose namespace (or OBO
    /// prefix) equals one of the IRI's alphanumeric tokens; first token wins.
    fn from_namespace_token(iri: &str) -> Option<Self> {
        iri.split(|c: char| !c.is_ascii_alphanumeric())
            .find_map(|token| match token {
                // OBO ontology prefixes are upper-case by convention.
                "CL" => Some(Self::Cell),
                "HP" => Some(Self::Phenotype),
                _ if token.eq_ignore_ascii_case("cell") => Some(Self::Cell),
                _ if token.eq_ignore_ascii_case("tissue") => Some(Self::Tissue),
                _ if token.eq_ignore_ascii_case("organ") => Some(Self::Organ),
                _ if token.eq_ignore_ascii_case("phenotype") => Some(Self::Phenotype),
                _ if token.eq_ignore_ascii_case("go") => Some(Self::GoTerm),
                _ => None,
            })
    }
}
/// Kind of relation (predicate) between biomedical entities.
// Keep the variant order stable: serde's derived impls encode a unit variant
// by name *and* index, and index-based formats (e.g. bincode) depend on it.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BiomedicalRelationType {
    // Disease relations.
    CausesDisease,
    AssociatedWithDisease,
    PredisposesToDisease,
    // Protein-targeting relations.
    TargetsProtein,
    InhibitsProtein,
    ActivatesProtein,
    BindsToProtein,
    // Pathway relations.
    ParticipatesInPathway,
    RegulatesPathway,
    UpstreamOfPathway,
    DownstreamOfPathway,
    // Interactions; the generic variant plus two more specific kinds.
    InteractsWith,
    PhysicallyInteractsWith,
    FunctionallyInteractsWith,
    // Metabolism, transport and catalysis.
    MetabolizedBy,
    TransportedBy,
    Catalyzes,
    // Hierarchy / ontology structure.
    IsASubtypeOf,
    PartOf,
    // Phenotype and expression.
    HasPhenotype,
    ExpressedIn,
    Overexpressed,
    Underexpressed,
}
impl BiomedicalRelationType {
    /// Ordered `(keyword, relation)` rules used by `from_iri`.
    ///
    /// Keywords are lower-case with separators removed. A keyword that
    /// contains another (`overexpressed` ⊃ `expressed`,
    /// `physicallyinteracts` ⊃ `interacts`) must come first, otherwise the
    /// general rule would shadow the specific one.
    const IRI_RULES: &'static [(&'static str, BiomedicalRelationType)] = &[
        ("causes", BiomedicalRelationType::CausesDisease),
        ("associatedwith", BiomedicalRelationType::AssociatedWithDisease),
        ("targets", BiomedicalRelationType::TargetsProtein),
        ("inhibits", BiomedicalRelationType::InhibitsProtein),
        ("activates", BiomedicalRelationType::ActivatesProtein),
        ("binds", BiomedicalRelationType::BindsToProtein),
        ("participates", BiomedicalRelationType::ParticipatesInPathway),
        ("physicallyinteracts", BiomedicalRelationType::PhysicallyInteractsWith),
        ("functionallyinteracts", BiomedicalRelationType::FunctionallyInteractsWith),
        ("interacts", BiomedicalRelationType::InteractsWith),
        ("metabolized", BiomedicalRelationType::MetabolizedBy),
        ("overexpressed", BiomedicalRelationType::Overexpressed),
        ("underexpressed", BiomedicalRelationType::Underexpressed),
        ("expressed", BiomedicalRelationType::ExpressedIn),
        ("subtype", BiomedicalRelationType::IsASubtypeOf),
        ("partof", BiomedicalRelationType::PartOf),
        // Keywords below were previously unmapped; placing them last means
        // they can only affect inputs that used to yield `None`.
        ("predisposes", BiomedicalRelationType::PredisposesToDisease),
        ("regulates", BiomedicalRelationType::RegulatesPathway),
        ("upstream", BiomedicalRelationType::UpstreamOfPathway),
        ("downstream", BiomedicalRelationType::DownstreamOfPathway),
        ("transported", BiomedicalRelationType::TransportedBy),
        ("catalyzes", BiomedicalRelationType::Catalyzes),
        ("catalyses", BiomedicalRelationType::Catalyzes),
        ("hasphenotype", BiomedicalRelationType::HasPhenotype),
    ];

    /// Infers a relation type from a predicate IRI using keyword matching.
    ///
    /// The IRI is lower-cased and `_`, `-` and whitespace are removed, so
    /// `associated_with`, `associated-with` and `associatedWith` all become
    /// `associatedwith`. The result is tested against the ordered keyword
    /// table `IRI_RULES`, and the first keyword it contains wins.
    ///
    /// Returns `None` when no keyword occurs.
    pub fn from_iri(iri: &str) -> Option<Self> {
        let key: String = iri
            .to_lowercase()
            .chars()
            .filter(|c| !(*c == '_' || *c == '-' || c.is_whitespace()))
            .collect();
        Self::IRI_RULES
            .iter()
            .find(|(keyword, _)| key.contains(*keyword))
            .map(|(_, relation)| relation.clone())
    }
}
/// Configuration for the biomedical embedding model.
///
/// Wraps a generic `ModelConfig` and adds per-relation-family weights and
/// feature toggles. Default values are set in the `Default` impl.
// NOTE(review): how the weights and toggles are consumed is not visible in
// this section; the field notes below describe only what the names and
// defaults show.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BiomedicalEmbeddingConfig {
    /// Generic model settings (`ModelConfig::default()` by default).
    pub base_config: ModelConfig,
    /// Weight for gene–disease relations (default `2.0`).
    pub gene_disease_weight: f32,
    /// Weight for drug–target relations (default `1.5`).
    pub drug_target_weight: f32,
    /// Weight for pathway relations (default `1.2`).
    pub pathway_weight: f32,
    /// Weight for protein–protein interactions (default `1.0`).
    pub protein_interaction_weight: f32,
    /// Toggle for sequence-similarity features (default `true`).
    pub use_sequence_similarity: bool,
    /// Toggle for chemical-structure features (default `true`).
    pub use_chemical_structure: bool,
    /// Toggle for taxonomy features (default `true`).
    pub use_taxonomy: bool,
    /// Toggle for temporal features (default `false`).
    pub use_temporal_features: bool,
    /// Optional species name (default `Some("Homo sapiens")`); presumably
    /// used to restrict entities to one species — TODO confirm.
    pub species_filter: Option<String>,
}
impl Default for BiomedicalEmbeddingConfig {
fn default() -> Self {
Self {
base_config: ModelConfig::default(),
gene_disease_weight: 2.0,
drug_target_weight: 1.5,
pathway_weight: 1.2,
protein_interaction_weight: 1.0,
use_sequence_similarity: true,
use_chemical_structure: true,
use_taxonomy: true,
use_temporal_features: false,
species_filter: Some("Homo sapiens".to_string()),
}
}
}
/// State of a biomedical embedding model: per-entity-type embedding tables,
/// relation embeddings, the known triples, pairwise features and statistics.
///
/// NOTE(review): `features` holds maps keyed by `(String, String)` tuples.
/// Serializing this struct with a format that only accepts string map keys
/// (e.g. serde_json) will fail — confirm which format is used.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BiomedicalEmbedding {
    /// Configuration this model was created with.
    pub config: BiomedicalEmbeddingConfig,
    /// Identifier of this model instance.
    pub model_id: Uuid,
    /// Gene vectors; keys presumably entity IRIs — TODO confirm.
    pub gene_embeddings: HashMap<String, Array1<f32>>,
    /// Protein vectors; keys presumably entity IRIs — TODO confirm.
    pub protein_embeddings: HashMap<String, Array1<f32>>,
    /// Disease vectors; keys presumably entity IRIs — TODO confirm.
    pub disease_embeddings: HashMap<String, Array1<f32>>,
    /// Drug vectors; keys presumably entity IRIs — TODO confirm.
    pub drug_embeddings: HashMap<String, Array1<f32>>,
    /// Compound vectors; keys presumably entity IRIs — TODO confirm.
    pub compound_embeddings: HashMap<String, Array1<f32>>,
    /// Pathway vectors; keys presumably entity IRIs — TODO confirm.
    pub pathway_embeddings: HashMap<String, Array1<f32>>,
    /// Relation vectors; keys presumably predicate IRIs — TODO confirm.
    pub relation_embeddings: HashMap<String, Array1<f32>>,
    /// Entity identifier → entity type (presumably filled via
    /// `BiomedicalEntityType::from_iri` — verify against the caller).
    pub entity_types: HashMap<String, BiomedicalEntityType>,
    /// Relation identifier → relation type (presumably filled via
    /// `BiomedicalRelationType::from_iri` — verify against the caller).
    pub relation_types: HashMap<String, BiomedicalRelationType>,
    /// Triples held by the model (presumably the training data — confirm).
    pub triples: Vec<Triple>,
    /// Pairwise biomedical feature scores.
    pub features: BiomedicalFeatures,
    /// Statistics from training.
    pub training_stats: TrainingStats,
    /// General model statistics.
    pub model_stats: ModelStats,
    /// Whether the model has been trained.
    pub is_trained: bool,
}
/// Pairwise scores between biomedical entities, one map per feature kind.
///
/// Each map is keyed by an `(a, b)` pair of entity identifiers (presumably
/// IRIs — confirm). This section does not show whether both orientations of
/// a pair are stored.
///
/// NOTE(review): tuple keys cannot be serialized by formats that require
/// string map keys (serde_json reports "key must be a string"); confirm the
/// serialization format in use.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BiomedicalFeatures {
    /// (gene, disease) → association score.
    pub gene_disease_associations: HashMap<(String, String), f32>,
    /// (drug, target) → binding affinity.
    pub drug_target_affinities: HashMap<(String, String), f32>,
    /// (entity, pathway) → membership score.
    pub pathway_memberships: HashMap<(String, String), f32>,
    /// (protein, protein) → interaction score.
    pub protein_interactions: HashMap<(String, String), f32>,
    /// Pair → sequence-similarity score.
    pub sequence_similarities: HashMap<(String, String), f32>,
    /// Pair → structural-similarity score.
    pub structure_similarities: HashMap<(String, String), f32>,
    /// Pair → expression-correlation score.
    pub expression_correlations: HashMap<(String, String), f32>,
    /// Pair → tissue-expression value; pair order (e.g. gene, tissue) is
    /// assumed — TODO confirm.
    pub tissue_expression: HashMap<(String, String), f32>,
}