use crate::symbolic_embeddings::SymbolicEmbeddingIndex;
use crate::{RoleGraph, magic_pair};
use ahash::{AHashMap, AHashSet};
use terraphim_types::{MedicalEdgeType, MedicalNodeType, RoleName, Thesaurus};
#[derive(thiserror::Error, Debug)]
pub enum MedicalRoleGraphError {
#[error("RoleGraph error: {0}")]
RoleGraphError(#[from] crate::Error),
}
type Result<T> = std::result::Result<T, MedicalRoleGraphError>;
/// A [`RoleGraph`] paired with typed medical nodes and edges, an IS-A
/// hierarchy, a SNOMED-id lookup table and an optional symbolic embedding
/// index built from that hierarchy.
pub struct MedicalRoleGraph {
    /// The wrapped role graph. The medical indexes below are maintained
    /// separately from it by this type's methods.
    pub role_graph: RoleGraph,

    // ---- node data, keyed by node id ----
    /// Medical category of each registered node.
    node_types: AHashMap<u64, MedicalNodeType>,
    /// Human-readable term of each registered node.
    node_terms: AHashMap<u64, String>,
    /// SNOMED concept id -> node id, filled when a node is added with a SNOMED id.
    snomed_to_id: AHashMap<u64, u64>,

    // ---- edge data ----
    /// Edge type keyed by `magic_pair(source, target)`.
    edge_types: AHashMap<u64, MedicalEdgeType>,
    /// source -> list of (target, edge type).
    outgoing_edges: AHashMap<u64, Vec<(u64, MedicalEdgeType)>>,
    /// target -> list of (source, edge type).
    incoming_edges: AHashMap<u64, Vec<(u64, MedicalEdgeType)>>,

    // ---- IS-A hierarchy ----
    /// child -> set of direct IS-A parents.
    isa_parents: AHashMap<u64, AHashSet<u64>>,
    /// parent -> set of direct IS-A children.
    isa_children: AHashMap<u64, AHashSet<u64>>,

    // ---- derived cache ----
    /// Embedding index built from the hierarchy; reset to `None` when it may be stale.
    embedding_index: Option<SymbolicEmbeddingIndex>,
}
impl MedicalRoleGraph {
pub async fn new(role: RoleName, thesaurus: Thesaurus) -> Result<Self> {
let role_graph = RoleGraph::new(role, thesaurus).await?;
Ok(Self {
role_graph,
node_types: AHashMap::new(),
node_terms: AHashMap::new(),
edge_types: AHashMap::new(),
outgoing_edges: AHashMap::new(),
incoming_edges: AHashMap::new(),
isa_parents: AHashMap::new(),
isa_children: AHashMap::new(),
snomed_to_id: AHashMap::new(),
embedding_index: None,
})
}
pub fn new_empty() -> Result<Self> {
let empty_thesaurus = Thesaurus::new("empty".to_string());
let role_graph = RoleGraph::new_sync("empty".into(), empty_thesaurus)?;
Ok(Self {
role_graph,
node_types: AHashMap::new(),
node_terms: AHashMap::new(),
edge_types: AHashMap::new(),
outgoing_edges: AHashMap::new(),
incoming_edges: AHashMap::new(),
isa_parents: AHashMap::new(),
isa_children: AHashMap::new(),
snomed_to_id: AHashMap::new(),
embedding_index: None,
})
}
pub fn add_medical_node(
&mut self,
id: u64,
term: String,
node_type: MedicalNodeType,
snomed_id: Option<u64>,
) {
self.node_types.insert(id, node_type);
self.node_terms.insert(id, term);
if let Some(sid) = snomed_id {
self.snomed_to_id.insert(sid, id);
}
self.embedding_index = None;
}
pub fn add_medical_edge(&mut self, source: u64, target: u64, edge_type: MedicalEdgeType) {
let edge_id = magic_pair(source, target);
self.edge_types.insert(edge_id, edge_type);
self.outgoing_edges
.entry(source)
.or_default()
.push((target, edge_type));
self.incoming_edges
.entry(target)
.or_default()
.push((source, edge_type));
if edge_type == MedicalEdgeType::IsA {
self.isa_parents.entry(source).or_default().insert(target);
self.isa_children.entry(target).or_default().insert(source);
self.embedding_index = None;
}
}
pub fn get_node_type(&self, id: u64) -> Option<MedicalNodeType> {
self.node_types.get(&id).copied()
}
pub fn get_node_term(&self, id: u64) -> Option<&str> {
self.node_terms.get(&id).map(|s| s.as_str())
}
pub fn get_edge_type(&self, source: u64, target: u64) -> Option<MedicalEdgeType> {
let edge_id = magic_pair(source, target);
self.edge_types.get(&edge_id).copied()
}
pub fn get_ancestors(&self, node_id: u64) -> Vec<u64> {
let mut ancestors = AHashSet::new();
let mut stack = Vec::new();
if let Some(parents) = self.isa_parents.get(&node_id) {
for &parent in parents {
stack.push(parent);
}
}
while let Some(current) = stack.pop() {
if ancestors.insert(current) {
if let Some(parents) = self.isa_parents.get(¤t) {
for &parent in parents {
if !ancestors.contains(&parent) {
stack.push(parent);
}
}
}
}
}
ancestors.into_iter().collect()
}
pub fn get_descendants(&self, node_id: u64) -> Vec<u64> {
let mut descendants = AHashSet::new();
let mut stack = Vec::new();
if let Some(children) = self.isa_children.get(&node_id) {
for &child in children {
stack.push(child);
}
}
while let Some(current) = stack.pop() {
if descendants.insert(current) {
if let Some(children) = self.isa_children.get(¤t) {
for &child in children {
if !descendants.contains(&child) {
stack.push(child);
}
}
}
}
}
descendants.into_iter().collect()
}
pub fn get_treatments(&self, condition_id: u64) -> Vec<u64> {
let mut treatments = Vec::new();
if let Some(edges) = self.outgoing_edges.get(&condition_id) {
for &(target, edge_type) in edges {
if edge_type == MedicalEdgeType::Treats {
treatments.push(target);
}
}
}
if let Some(edges) = self.incoming_edges.get(&condition_id) {
for &(source, edge_type) in edges {
if edge_type == MedicalEdgeType::Treats {
treatments.push(source);
}
}
}
treatments
}
pub fn check_contraindication(&self, drug_id: u64, conditions: &[u64]) -> Vec<(u64, u64)> {
let mut contraindications = Vec::new();
for &condition_id in conditions {
let is_contraindicated =
self.outgoing_edges.get(&drug_id).is_some_and(|edges| {
edges
.iter()
.any(|&(t, et)| t == condition_id && et == MedicalEdgeType::Contraindicates)
})
||
self.outgoing_edges.get(&condition_id).is_some_and(|edges| {
edges
.iter()
.any(|&(t, et)| t == drug_id && et == MedicalEdgeType::Contraindicates)
});
if is_contraindicated {
contraindications.push((drug_id, condition_id));
}
}
contraindications
}
pub fn build_embeddings(&mut self) {
let index =
SymbolicEmbeddingIndex::build_from_hierarchy(&self.isa_parents, &self.node_types);
self.embedding_index = Some(index);
}
pub fn symbolic_similarity(&self, a: u64, b: u64) -> Option<f64> {
self.embedding_index.as_ref()?.similarity(a, b)
}
pub fn find_similar(&self, node_id: u64, k: usize) -> Vec<(u64, f64)> {
match &self.embedding_index {
Some(index) => index.nearest_neighbors(node_id, k),
None => Vec::new(),
}
}
pub fn node_count(&self) -> usize {
self.node_types.len()
}
pub fn medical_edge_count(&self) -> usize {
self.edge_types.len()
}
pub fn isa_edge_count(&self) -> usize {
self.isa_parents.values().map(|s| s.len()).sum()
}
pub fn snomed_to_node_id(&self, snomed_id: u64) -> Option<u64> {
self.snomed_to_id.get(&snomed_id).copied()
}
pub fn embedding_index(&self) -> Option<&SymbolicEmbeddingIndex> {
self.embedding_index.as_ref()
}
pub fn iter_node_terms(&self) -> impl Iterator<Item = (u64, &str)> {
self.node_terms
.iter()
.map(|(&id, term)| (id, term.as_str()))
}
}
impl std::fmt::Debug for MedicalRoleGraph {
    /// Summarises the graph by its sizes rather than dumping every index.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut summary = f.debug_struct("MedicalRoleGraph");
        summary.field("node_count", &self.node_count());
        summary.field("edge_count", &self.medical_edge_count());
        summary.field("isa_edge_count", &self.isa_edge_count());
        summary.field("snomed_mappings", &self.snomed_to_id.len());
        summary.field("embeddings_built", &self.embedding_index.is_some());
        summary.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::magic_pair;
    use terraphim_types::Thesaurus;

    /// Builds an empty graph with role "medical test" and an empty thesaurus.
    async fn create_test_medical_rolegraph() -> MedicalRoleGraph {
        let role = "medical test".to_string();
        let thesaurus = Thesaurus::new("empty".to_string());
        MedicalRoleGraph::new(role.into(), thesaurus)
            .await
            .expect("Failed to create MedicalRoleGraph")
    }

    /// Builds a small fixture with 5 diseases, 3 drugs and 7 edges
    /// (4 IS-A, 2 Treats, 1 Contraindicates).
    async fn create_populated_medical_rolegraph() -> MedicalRoleGraph {
        let mut mrg = create_test_medical_rolegraph().await;
        mrg.add_medical_node(
            100,
            "Disease".to_string(),
            MedicalNodeType::Disease,
            Some(64572001),
        );
        mrg.add_medical_node(
            101,
            "Cancer".to_string(),
            MedicalNodeType::Disease,
            Some(363346000),
        );
        mrg.add_medical_node(
            102,
            "Infection".to_string(),
            MedicalNodeType::Disease,
            Some(40733004),
        );
        mrg.add_medical_node(
            103,
            "Lung Cancer".to_string(),
            MedicalNodeType::Disease,
            Some(93880001),
        );
        mrg.add_medical_node(
            104,
            "Breast Cancer".to_string(),
            MedicalNodeType::Disease,
            Some(254837009),
        );
        mrg.add_medical_node(200, "Cisplatin".to_string(), MedicalNodeType::Drug, None);
        mrg.add_medical_node(201, "Tamoxifen".to_string(), MedicalNodeType::Drug, None);
        mrg.add_medical_node(202, "Aspirin".to_string(), MedicalNodeType::Drug, None);

        // IS-A hierarchy: Cancer and Infection under Disease;
        // Lung Cancer and Breast Cancer under Cancer.
        mrg.add_medical_edge(101, 100, MedicalEdgeType::IsA);
        mrg.add_medical_edge(102, 100, MedicalEdgeType::IsA);
        mrg.add_medical_edge(103, 101, MedicalEdgeType::IsA);
        mrg.add_medical_edge(104, 101, MedicalEdgeType::IsA);

        // Treatments.
        mrg.add_medical_edge(200, 103, MedicalEdgeType::Treats);
        mrg.add_medical_edge(201, 104, MedicalEdgeType::Treats);

        // Contraindications.
        mrg.add_medical_edge(202, 103, MedicalEdgeType::Contraindicates);
        mrg
    }

    #[tokio::test]
    async fn test_create_medical_rolegraph() {
        let mrg = create_test_medical_rolegraph().await;
        assert_eq!(mrg.node_count(), 0);
        assert_eq!(mrg.medical_edge_count(), 0);
        assert_eq!(mrg.isa_edge_count(), 0);
    }

    #[tokio::test]
    async fn test_add_medical_node() {
        let mut mrg = create_test_medical_rolegraph().await;
        mrg.add_medical_node(
            1,
            "Diabetes".to_string(),
            MedicalNodeType::Disease,
            Some(73211009),
        );
        assert_eq!(mrg.node_count(), 1);
        assert_eq!(mrg.get_node_type(1), Some(MedicalNodeType::Disease));
        assert_eq!(mrg.get_node_term(1), Some("Diabetes"));
        assert_eq!(mrg.snomed_to_node_id(73211009), Some(1));
    }

    #[tokio::test]
    async fn test_add_medical_node_no_snomed() {
        let mut mrg = create_test_medical_rolegraph().await;
        mrg.add_medical_node(1, "Unknown Drug".to_string(), MedicalNodeType::Drug, None);
        assert_eq!(mrg.get_node_type(1), Some(MedicalNodeType::Drug));
        assert_eq!(mrg.get_node_term(1), Some("Unknown Drug"));
    }

    #[tokio::test]
    async fn test_add_medical_edge() {
        let mut mrg = create_test_medical_rolegraph().await;
        mrg.add_medical_node(1, "Drug A".to_string(), MedicalNodeType::Drug, None);
        mrg.add_medical_node(2, "Disease B".to_string(), MedicalNodeType::Disease, None);
        mrg.add_medical_edge(1, 2, MedicalEdgeType::Treats);
        assert_eq!(mrg.medical_edge_count(), 1);
        assert_eq!(mrg.get_edge_type(1, 2), Some(MedicalEdgeType::Treats));
    }

    #[tokio::test]
    async fn test_magic_pair_edge_encoding() {
        let edge_id = magic_pair(100, 200);
        let mut mrg = create_test_medical_rolegraph().await;
        mrg.add_medical_node(100, "Node A".to_string(), MedicalNodeType::Concept, None);
        mrg.add_medical_node(200, "Node B".to_string(), MedicalNodeType::Concept, None);
        mrg.add_medical_edge(100, 200, MedicalEdgeType::RelatedTo);
        assert_eq!(
            mrg.get_edge_type(100, 200),
            Some(MedicalEdgeType::RelatedTo)
        );
        assert!(mrg.edge_types.contains_key(&edge_id));
    }

    #[tokio::test]
    async fn test_isa_hierarchy() {
        let mrg = create_populated_medical_rolegraph().await;
        assert_eq!(mrg.isa_edge_count(), 4);
        let lung_ancestors = mrg.get_ancestors(103);
        assert!(
            lung_ancestors.contains(&101),
            "Lung Cancer should have Cancer as ancestor"
        );
        assert!(
            lung_ancestors.contains(&100),
            "Lung Cancer should have Disease as ancestor"
        );
        assert_eq!(lung_ancestors.len(), 2);
        let disease_descendants = mrg.get_descendants(100);
        assert!(disease_descendants.contains(&101));
        assert!(disease_descendants.contains(&102));
        assert!(disease_descendants.contains(&103));
        assert!(disease_descendants.contains(&104));
        assert_eq!(disease_descendants.len(), 4);
    }

    #[tokio::test]
    async fn test_ancestors_empty_for_root() {
        let mrg = create_populated_medical_rolegraph().await;
        let root_ancestors = mrg.get_ancestors(100);
        assert!(
            root_ancestors.is_empty(),
            "Root node (Disease) should have no ancestors"
        );
    }

    #[tokio::test]
    async fn test_descendants_empty_for_leaf() {
        let mrg = create_populated_medical_rolegraph().await;
        let leaf_descendants = mrg.get_descendants(103);
        assert!(
            leaf_descendants.is_empty(),
            "Leaf node (Lung Cancer) should have no descendants"
        );
    }

    #[tokio::test]
    async fn test_get_treatments() {
        let mrg = create_populated_medical_rolegraph().await;
        let lung_treatments = mrg.get_treatments(103);
        assert_eq!(lung_treatments.len(), 1);
        assert!(
            lung_treatments.contains(&200),
            "Cisplatin (200) should treat Lung Cancer (103)"
        );
        let breast_treatments = mrg.get_treatments(104);
        assert_eq!(breast_treatments.len(), 1);
        assert!(
            breast_treatments.contains(&201),
            "Tamoxifen (201) should treat Breast Cancer (104)"
        );
        let disease_treatments = mrg.get_treatments(100);
        assert!(disease_treatments.is_empty());
    }

    #[tokio::test]
    async fn test_check_contraindication() {
        let mrg = create_populated_medical_rolegraph().await;
        let contras = mrg.check_contraindication(202, &[103, 104]);
        assert_eq!(contras.len(), 1);
        assert_eq!(contras[0], (202, 103));
        let no_contras = mrg.check_contraindication(200, &[103, 104]);
        assert!(no_contras.is_empty());
    }

    #[tokio::test]
    async fn test_check_contraindication_reverse_direction() {
        let mut mrg = create_populated_medical_rolegraph().await;
        // A contraindication stored condition -> drug must also be detected.
        mrg.add_medical_edge(104, 200, MedicalEdgeType::Contraindicates);
        let contras = mrg.check_contraindication(200, &[103, 104]);
        assert_eq!(contras, vec![(200, 104)]);
    }

    #[tokio::test]
    async fn test_build_embeddings() {
        let mut mrg = create_populated_medical_rolegraph().await;
        assert!(
            mrg.embedding_index().is_none(),
            "Embeddings should not be built initially"
        );
        mrg.build_embeddings();
        assert!(
            mrg.embedding_index().is_some(),
            "Embeddings should be built after build_embeddings()"
        );
    }

    #[tokio::test]
    async fn test_symbolic_similarity() {
        let mut mrg = create_populated_medical_rolegraph().await;
        mrg.build_embeddings();
        let self_sim = mrg.symbolic_similarity(103, 103);
        assert_eq!(self_sim, Some(1.0));
        let sibling_sim = mrg.symbolic_similarity(103, 104).unwrap();
        let distant_sim = mrg.symbolic_similarity(103, 102).unwrap();
        assert!(
            sibling_sim > distant_sim,
            "Siblings (Lung Cancer/Breast Cancer) should be more similar ({sibling_sim}) than distant nodes ({distant_sim})"
        );
    }

    #[tokio::test]
    async fn test_symbolic_similarity_without_embeddings() {
        let mrg = create_populated_medical_rolegraph().await;
        assert!(
            mrg.symbolic_similarity(103, 104).is_none(),
            "Similarity should return None when embeddings are not built"
        );
    }

    #[tokio::test]
    async fn test_find_similar() {
        let mut mrg = create_populated_medical_rolegraph().await;
        mrg.build_embeddings();
        let similar = mrg.find_similar(103, 3);
        assert!(!similar.is_empty(), "Should find similar nodes");
        // Results are expected in non-increasing similarity order.
        for window in similar.windows(2) {
            assert!(window[0].1 >= window[1].1);
        }
    }

    #[tokio::test]
    async fn test_find_similar_without_embeddings() {
        let mrg = create_populated_medical_rolegraph().await;
        let similar = mrg.find_similar(103, 3);
        assert!(
            similar.is_empty(),
            "find_similar should return empty vec without embeddings"
        );
    }

    #[tokio::test]
    async fn test_embedding_invalidation_on_node_add() {
        let mut mrg = create_populated_medical_rolegraph().await;
        mrg.build_embeddings();
        assert!(mrg.embedding_index().is_some());
        mrg.add_medical_node(300, "New Node".to_string(), MedicalNodeType::Concept, None);
        assert!(
            mrg.embedding_index().is_none(),
            "Embeddings should be invalidated after adding a node"
        );
    }

    #[tokio::test]
    async fn test_embedding_invalidation_on_isa_edge_add() {
        let mut mrg = create_populated_medical_rolegraph().await;
        mrg.build_embeddings();
        assert!(mrg.embedding_index().is_some());
        mrg.add_medical_edge(102, 101, MedicalEdgeType::IsA);
        assert!(
            mrg.embedding_index().is_none(),
            "Embeddings should be invalidated after adding IS-A edge"
        );
    }

    #[tokio::test]
    async fn test_non_isa_edge_does_not_invalidate_embeddings() {
        let mut mrg = create_populated_medical_rolegraph().await;
        mrg.build_embeddings();
        assert!(mrg.embedding_index().is_some());
        mrg.add_medical_edge(200, 102, MedicalEdgeType::Treats);
        assert!(
            mrg.embedding_index().is_some(),
            "Non-IS-A edges should not invalidate embeddings"
        );
    }

    #[tokio::test]
    async fn test_node_counts() {
        let mrg = create_populated_medical_rolegraph().await;
        assert_eq!(mrg.node_count(), 8);
        assert_eq!(mrg.medical_edge_count(), 7);
        assert_eq!(mrg.isa_edge_count(), 4);
    }

    #[tokio::test]
    async fn test_debug_output() {
        let mrg = create_populated_medical_rolegraph().await;
        let debug = format!("{:?}", mrg);
        assert!(debug.contains("MedicalRoleGraph"));
        assert!(debug.contains("node_count"));
        assert!(debug.contains("edge_count"));
    }

    #[tokio::test]
    async fn test_get_nonexistent_node() {
        let mrg = create_test_medical_rolegraph().await;
        assert!(mrg.get_node_type(999).is_none());
        assert!(mrg.get_node_term(999).is_none());
        assert!(mrg.snomed_to_node_id(999).is_none());
    }

    #[tokio::test]
    async fn test_get_nonexistent_edge() {
        let mrg = create_test_medical_rolegraph().await;
        assert!(mrg.get_edge_type(1, 2).is_none());
    }

    #[tokio::test]
    async fn test_ancestors_nonexistent_node() {
        let mrg = create_populated_medical_rolegraph().await;
        let ancestors = mrg.get_ancestors(999);
        assert!(ancestors.is_empty());
    }

    #[tokio::test]
    async fn test_descendants_nonexistent_node() {
        let mrg = create_populated_medical_rolegraph().await;
        let descendants = mrg.get_descendants(999);
        assert!(descendants.is_empty());
    }

    #[tokio::test]
    async fn test_multiple_snomed_mappings() {
        let mut mrg = create_test_medical_rolegraph().await;
        mrg.add_medical_node(
            1,
            "Disease A".to_string(),
            MedicalNodeType::Disease,
            Some(100001),
        );
        mrg.add_medical_node(
            2,
            "Disease B".to_string(),
            MedicalNodeType::Disease,
            Some(100002),
        );
        mrg.add_medical_node(
            3,
            "Disease C".to_string(),
            MedicalNodeType::Disease,
            Some(100003),
        );
        assert_eq!(mrg.snomed_to_node_id(100001), Some(1));
        assert_eq!(mrg.snomed_to_node_id(100002), Some(2));
        assert_eq!(mrg.snomed_to_node_id(100003), Some(3));
        assert_eq!(mrg.snomed_to_node_id(999999), None);
    }

    #[tokio::test]
    async fn test_role_graph_accessible() {
        let mrg = create_test_medical_rolegraph().await;
        assert_eq!(mrg.role_graph.get_node_count(), 0);
        assert_eq!(mrg.role_graph.get_edge_count(), 0);
    }
}