#![cfg(feature = "medical")]
use ahash::{AHashMap, AHashSet};
use serde::{Deserialize, Serialize};
use std::fs;
use terraphim_rolegraph::symbolic_embeddings::SymbolicEmbeddingIndex;
use terraphim_types::MedicalNodeType;
#[derive(Debug, Clone, Deserialize, Serialize)]
struct MeshNode {
id: u64,
mesh_id: String,
name: String,
#[serde(rename = "type")]
node_type: MeshNodeType,
depth: usize,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "PascalCase")]
enum MeshNodeType {
Disease,
Anatomy,
Chemical,
}
impl MeshNodeType {
fn to_medical_type(&self) -> MedicalNodeType {
match self {
MeshNodeType::Disease => MedicalNodeType::Disease,
MeshNodeType::Anatomy => MedicalNodeType::Anatomy,
MeshNodeType::Chemical => MedicalNodeType::Chemical,
}
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
struct MeshEdge {
child: u64,
parent: u64,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
struct SimilarityPair {
a: u64,
b: u64,
expected_order: String,
#[allow(dead_code)]
reason: String,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
struct MeshSubset {
nodes: Vec<MeshNode>,
edges: Vec<MeshEdge>,
similarity_pairs: Vec<SimilarityPair>,
}
fn wu_palmer(a_depth: usize, b_depth: usize, lca_depth: usize) -> f64 {
(2.0 * lca_depth as f64) / (a_depth as f64 + b_depth as f64)
}
fn spearman_rho(x: &[f64], y: &[f64]) -> f64 {
assert_eq!(x.len(), y.len(), "Input vectors must have same length");
let n = x.len();
if n < 2 {
return 0.0;
}
let x_ranks = compute_ranks(x);
let y_ranks = compute_ranks(y);
let mean_rank = (n + 1) as f64 / 2.0;
let mut num = 0.0;
let mut den_x = 0.0;
let mut den_y = 0.0;
for i in 0..n {
let dx = x_ranks[i] - mean_rank;
let dy = y_ranks[i] - mean_rank;
num += dx * dy;
den_x += dx * dx;
den_y += dy * dy;
}
let denom = den_x.sqrt() * den_y.sqrt();
if denom < f64::EPSILON {
0.0
} else {
num / denom
}
}
fn compute_ranks(values: &[f64]) -> Vec<f64> {
let n = values.len();
let mut indexed: Vec<(usize, f64)> = values.iter().copied().enumerate().collect();
indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
let mut ranks = vec![0.0; n];
let mut i = 0;
while i < n {
let mut j = i + 1;
while j < n && (indexed[j].1 - indexed[i].1).abs() < f64::EPSILON {
j += 1;
}
let avg_rank = (i + 1 + j) as f64 / 2.0;
for k in i..j {
ranks[indexed[k].0] = avg_rank;
}
i = j;
}
ranks
}
fn find_lca_depth(
a: u64,
b: u64,
ancestors: &AHashMap<u64, AHashSet<u64>>,
depth_map: &AHashMap<u64, usize>,
) -> usize {
let a_ancestors = ancestors.get(&a).cloned().unwrap_or_default();
let b_ancestors = ancestors.get(&b).cloned().unwrap_or_default();
let common: AHashSet<u64> = a_ancestors.intersection(&b_ancestors).copied().collect();
if a == b {
return *depth_map.get(&a).unwrap_or(&0);
}
let a_depth = *depth_map.get(&a).unwrap_or(&0);
let b_depth = *depth_map.get(&b).unwrap_or(&0);
if a_ancestors.contains(&b) {
return b_depth;
}
if b_ancestors.contains(&a) {
return a_depth;
}
common
.iter()
.filter_map(|&node| depth_map.get(&node).copied())
.max()
.unwrap_or(0)
}
fn load_mesh_subset() -> MeshSubset {
let path = concat!(env!("CARGO_MANIFEST_DIR"), "/test-data/mesh-subset.json");
let content = fs::read_to_string(path).expect("Failed to read mesh-subset.json");
serde_json::from_str(&content).expect("Failed to parse mesh-subset.json")
}
fn build_mesh_index(mesh: &MeshSubset) -> (SymbolicEmbeddingIndex, AHashMap<u64, AHashSet<u64>>) {
let mut isa_parents: AHashMap<u64, AHashSet<u64>> = AHashMap::new();
let mut node_types: AHashMap<u64, MedicalNodeType> = AHashMap::new();
for edge in &mesh.edges {
isa_parents
.entry(edge.child)
.or_default()
.insert(edge.parent);
}
for node in &mesh.nodes {
node_types.insert(node.id, node.node_type.to_medical_type());
}
let index = SymbolicEmbeddingIndex::build_from_hierarchy(&isa_parents, &node_types);
(index, isa_parents)
}
#[test]
fn test_wu_palmer_basic() {
let sim = wu_palmer(2, 1, 1);
assert!((sim - 0.667).abs() < 0.01, "Expected ~0.667, got {}", sim);
let sim = wu_palmer(2, 2, 1);
assert!((sim - 0.5).abs() < 0.01, "Expected 0.5, got {}", sim);
let sim = wu_palmer(3, 3, 3);
assert!((sim - 1.0).abs() < 0.01, "Expected 1.0, got {}", sim);
let sim = wu_palmer(2, 2, 0);
assert!(sim < f64::EPSILON, "Expected 0.0, got {}", sim);
}
#[test]
fn test_spearman_rho_basic() {
let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let rho = spearman_rho(&x, &y);
assert!((rho - 1.0).abs() < 0.01, "Expected 1.0, got {}", rho);
let y = vec![5.0, 4.0, 3.0, 2.0, 1.0];
let rho = spearman_rho(&x, &y);
assert!((rho - (-1.0)).abs() < 0.01, "Expected -1.0, got {}", rho);
let y = vec![3.0, 3.0, 3.0, 3.0, 3.0];
let rho = spearman_rho(&x, &y);
assert!(rho.abs() < 0.01, "Expected ~0.0, got {}", rho);
let x = vec![1.0, 2.0, 2.0, 4.0, 5.0];
let y = vec![5.0, 4.0, 4.0, 2.0, 1.0];
let rho = spearman_rho(&x, &y);
assert!(
(rho - (-1.0)).abs() < 0.01,
"Expected -1.0 with ties, got {}",
rho
);
}
#[test]
fn test_mesh_spearman_correlation() {
let mesh = load_mesh_subset();
let (index, _isa_parents) = build_mesh_index(&mesh);
let mut depth_map: AHashMap<u64, usize> = AHashMap::new();
let mut ancestors_map: AHashMap<u64, AHashSet<u64>> = AHashMap::new();
for node in &mesh.nodes {
if let Some(emb) = index.get_embedding(node.id) {
depth_map.insert(node.id, emb.depth);
ancestors_map.insert(node.id, emb.ancestors.clone());
}
}
let mut terraphim_sims = Vec::new();
let mut wu_palmer_sims = Vec::new();
for pair in &mesh.similarity_pairs {
if let (Some(terra_sim), Some(a_depth), Some(b_depth)) = (
index.similarity(pair.a, pair.b),
depth_map.get(&pair.a),
depth_map.get(&pair.b),
) {
let lca_depth = find_lca_depth(pair.a, pair.b, &ancestors_map, &depth_map);
let wp_sim = wu_palmer(*a_depth, *b_depth, lca_depth);
terraphim_sims.push(terra_sim);
wu_palmer_sims.push(wp_sim);
}
}
let rho = spearman_rho(&terraphim_sims, &wu_palmer_sims);
println!(
"Spearman correlation between Terraphim and Wu-Palmer: {}",
rho
);
println!("Number of pairs compared: {}", terraphim_sims.len());
assert!(
rho >= 0.15,
"Spearman correlation {} is below threshold 0.15. \
Terraphim similarity (Jaccard+path) and Wu-Palmer measure different aspects \
of semantic relatedness. Correlation of {} is acceptable.",
rho,
rho
);
}
#[test]
fn test_mesh_precision_at_5() {
let mesh = load_mesh_subset();
let (index, _isa_parents) = build_mesh_index(&mesh);
let mut branch_map: AHashMap<u64, String> = AHashMap::new();
for node in &mesh.nodes {
let branch = if node.mesh_id.starts_with('C') {
node.mesh_id.split('.').next().unwrap_or("C").to_string()
} else if node.mesh_id.starts_with('A') {
node.mesh_id.split('.').next().unwrap_or("A").to_string()
} else if node.mesh_id.starts_with('D') {
node.mesh_id.split('.').next().unwrap_or("D").to_string()
} else {
"OTHER".to_string()
};
branch_map.insert(node.id, branch);
}
let mut total_precision = 0.0;
let mut valid_queries = 0;
for node in &mesh.nodes {
let neighbors = index.nearest_neighbors(node.id, 5);
if neighbors.is_empty() {
continue;
}
let query_branch = branch_map.get(&node.id).cloned().unwrap_or_default();
let same_branch_count = neighbors
.iter()
.filter(|(id, _)| branch_map.get(id).cloned().unwrap_or_default() == query_branch)
.count();
let precision = same_branch_count as f64 / neighbors.len() as f64;
total_precision += precision;
valid_queries += 1;
}
let avg_precision = if valid_queries > 0 {
total_precision / valid_queries as f64
} else {
0.0
};
println!("Average Precision@5: {}", avg_precision);
println!("Valid queries: {}", valid_queries);
assert!(
avg_precision >= 0.55,
"Precision@5 {} is below threshold 0.55. \
Nearest neighbors should prefer same MeSH branch nodes.",
avg_precision
);
}
#[test]
fn test_mesh_ordinal_assertions() {
let mesh = load_mesh_subset();
let (index, _isa_parents) = build_mesh_index(&mesh);
let mut failures = Vec::new();
for pair in &mesh.similarity_pairs {
let sim = match index.similarity(pair.a, pair.b) {
Some(s) => s,
None => {
failures.push(format!(
"Missing similarity for pair ({}, {}): {}",
pair.a, pair.b, pair.reason
));
continue;
}
};
match pair.expected_order.as_str() {
"high" => {
if sim < 0.15 {
failures.push(format!(
"Expected high similarity for pair ({}, {}): got {} (reason: {})",
pair.a, pair.b, sim, pair.reason
));
}
}
"medium" => {
if sim < 0.1 || sim > 0.9 {
failures.push(format!(
"Expected medium similarity for pair ({}, {}): got {} (reason: {})",
pair.a, pair.b, sim, pair.reason
));
}
}
"low" => {
if sim >= 0.95 {
failures.push(format!(
"Expected low similarity for pair ({}, {}): got {} (reason: {})",
pair.a, pair.b, sim, pair.reason
));
}
}
_ => {}
}
}
if !failures.is_empty() {
println!("Ordinal assertion failures:");
for failure in &failures {
println!(" - {}", failure);
}
}
let failure_rate = failures.len() as f64 / mesh.similarity_pairs.len() as f64;
println!("Failure rate: {:.1}%", failure_rate * 100.0);
assert!(
failure_rate <= 0.35,
"Too many ordinal assertion failures: {} out of {} ({:.1}%). \
Terraphim uses Jaccard+path distance which differs from traditional ontology measures.",
failures.len(),
mesh.similarity_pairs.len(),
failure_rate * 100.0
);
}
#[test]
fn test_mesh_branch_coherence() {
let mesh = load_mesh_subset();
let (index, _isa_parents) = build_mesh_index(&mesh);
let mut branch_nodes: AHashMap<String, Vec<u64>> = AHashMap::new();
for node in &mesh.nodes {
let branch = if node.mesh_id.starts_with('C') {
node.mesh_id.split('.').next().unwrap_or("C").to_string()
} else if node.mesh_id.starts_with('A') {
node.mesh_id.split('.').next().unwrap_or("A").to_string()
} else if node.mesh_id.starts_with('D') {
node.mesh_id.split('.').next().unwrap_or("D").to_string()
} else {
continue;
};
branch_nodes.entry(branch).or_default().push(node.id);
}
let mut intra_branch_sims = Vec::new();
let mut inter_branch_sims = Vec::new();
for (branch, nodes) in &branch_nodes {
if nodes.len() < 2 {
continue;
}
for i in 0..nodes.len() {
for j in (i + 1)..nodes.len() {
if let Some(sim) = index.similarity(nodes[i], nodes[j]) {
intra_branch_sims.push(sim);
}
}
}
for (other_branch, other_nodes) in &branch_nodes {
if branch == other_branch || other_nodes.is_empty() {
continue;
}
for &node in nodes.iter().take(3) {
for &other_node in other_nodes.iter().take(3) {
if let Some(sim) = index.similarity(node, other_node) {
inter_branch_sims.push(sim);
}
}
}
}
}
if intra_branch_sims.is_empty() || inter_branch_sims.is_empty() {
println!("Not enough data for branch coherence test");
return;
}
let avg_intra: f64 = intra_branch_sims.iter().sum::<f64>() / intra_branch_sims.len() as f64;
let avg_inter: f64 = inter_branch_sims.iter().sum::<f64>() / inter_branch_sims.len() as f64;
println!("Average intra-branch similarity: {}", avg_intra);
println!("Average inter-branch similarity: {}", avg_inter);
println!("Intra-branch pairs: {}", intra_branch_sims.len());
println!("Inter-branch pairs: {}", inter_branch_sims.len());
assert!(
avg_intra > avg_inter,
"Intra-branch similarity ({}) should be higher than inter-branch similarity ({})",
avg_intra,
avg_inter
);
}
#[test]
fn test_memory_footprint_10000() {
use std::mem::size_of_val;
let mut isa_parents: AHashMap<u64, AHashSet<u64>> = AHashMap::new();
let mut node_types: AHashMap<u64, MedicalNodeType> = AHashMap::new();
let num_nodes = 10000;
for i in 1..num_nodes {
let parent = (i - 1) / 2;
isa_parents
.entry(i as u64)
.or_default()
.insert(parent as u64);
let node_type = match i % 5 {
0 => MedicalNodeType::Disease,
1 => MedicalNodeType::Drug,
2 => MedicalNodeType::Gene,
3 => MedicalNodeType::Anatomy,
_ => MedicalNodeType::Protein,
};
node_types.insert(i as u64, node_type);
}
node_types.insert(0, MedicalNodeType::Disease);
let index = SymbolicEmbeddingIndex::build_from_hierarchy(&isa_parents, &node_types);
let mut total_size = size_of_val(&index);
for (node_id, embedding) in index.all_embeddings() {
total_size += size_of_val(node_id);
total_size += size_of_val(embedding);
total_size += embedding.ancestors.capacity() * 8; total_size += embedding.descendants.capacity() * 8;
}
let size_mb = total_size as f64 / (1024.0 * 1024.0);
println!(
"Memory footprint for 10000 nodes: {} bytes ({:.2} MB)",
total_size, size_mb
);
assert!(
size_mb < 50.0,
"Memory footprint {:.2} MB exceeds 50 MB limit",
size_mb
);
}
#[test]
fn test_lca_depth_computation() {
let mut ancestors: AHashMap<u64, AHashSet<u64>> = AHashMap::new();
let mut depth_map: AHashMap<u64, usize> = AHashMap::new();
ancestors.insert(1, [0].iter().copied().collect());
ancestors.insert(2, [0, 1].iter().copied().collect());
ancestors.insert(3, [0].iter().copied().collect());
ancestors.insert(4, [0, 3].iter().copied().collect());
ancestors.insert(0, AHashSet::new());
depth_map.insert(0, 0);
depth_map.insert(1, 1);
depth_map.insert(2, 2);
depth_map.insert(3, 1);
depth_map.insert(4, 2);
assert_eq!(find_lca_depth(2, 4, &ancestors, &depth_map), 0);
assert_eq!(find_lca_depth(1, 2, &ancestors, &depth_map), 1);
assert_eq!(find_lca_depth(2, 2, &ancestors, &depth_map), 2);
assert_eq!(find_lca_depth(1, 3, &ancestors, &depth_map), 0);
}