use daachorse::DoubleArrayAhoCorasick;
use std::collections::HashMap;
use crate::medical_artifact::{
ArtifactHeader, PatternMeta, artifact_exists, load_umls_artifact, save_umls_artifact,
};
use crate::umls::{UmlsConcept, UmlsDataset};
use crate::umls_extractor::UmlsMatch;
/// Default upper bound on patterns per daachorse shard; used by
/// `from_dataset`. Keeps each automaton's size/build cost bounded.
const DEFAULT_MAX_PATTERNS_PER_SHARD: usize = 500_000;
/// UMLS term extractor whose pattern dictionary is split across multiple
/// Aho-Corasick automata ("shards"), so very large term sets can be built
/// and serialized in bounded-size pieces.
pub struct ShardedUmlsExtractor {
    /// One automaton per shard; each match value is an index into the
    /// corresponding `Vec` in `shard_metadata`.
    shards: Vec<DoubleArrayAhoCorasick<u32>>,
    /// Per-shard, per-pattern metadata, parallel to `shards`.
    shard_metadata: Vec<Vec<PatternMetadata>>,
    /// CUI -> concept lookup used to resolve canonical terms at extract time.
    concept_index: HashMap<String, UmlsConcept>,
    /// Total number of distinct (lowercased) patterns across all shards.
    total_patterns: usize,
}
/// Metadata for a single pattern: the lowercased term and every CUI that
/// maps to it (a term may belong to multiple concepts).
#[derive(Clone, serde::Serialize, serde::Deserialize)]
struct PatternMetadata {
    // All distinct CUIs this term maps to (order of first appearance).
    cuis: Vec<String>,
    // The lowercased surface form matched by the automaton.
    term: String,
}
impl ShardedUmlsExtractor {
    /// Builds an extractor from `dataset` using the default shard size
    /// (`DEFAULT_MAX_PATTERNS_PER_SHARD`).
    ///
    /// # Errors
    /// Propagates any shard-construction failure from
    /// [`Self::from_dataset_with_shard_size`].
    pub fn from_dataset(dataset: &UmlsDataset) -> anyhow::Result<Self> {
        Self::from_dataset_with_shard_size(dataset, DEFAULT_MAX_PATTERNS_PER_SHARD)
    }
pub fn from_dataset_with_shard_size(
dataset: &UmlsDataset,
max_patterns_per_shard: usize,
) -> anyhow::Result<Self> {
let start = std::time::Instant::now();
let mut all_patterns: Vec<(String, String)> = Vec::with_capacity(dataset.term_count);
for (cui, concept) in &dataset.concepts {
for term in &concept.terms {
all_patterns.push((term.to_lowercase(), cui.clone()));
}
}
all_patterns.sort_by(|a, b| a.0.cmp(&b.0));
let mut merged: Vec<(String, Vec<String>)> = Vec::new();
for (term, cui) in all_patterns {
if let Some(last) = merged.last_mut() {
if last.0 == term {
if !last.1.contains(&cui) {
last.1.push(cui);
}
continue;
}
}
merged.push((term, vec![cui]));
}
let multi_cui_count = merged.iter().filter(|(_, cuis)| cuis.len() > 1).count();
if multi_cui_count > 0 {
log::warn!(
"{} terms map to multiple CUIs; all CUIs preserved",
multi_cui_count
);
}
let total_patterns = merged.len();
log::info!(
"Building sharded extractor with {} patterns (max {} per shard)...",
total_patterns,
max_patterns_per_shard
);
let num_shards = total_patterns.div_ceil(max_patterns_per_shard);
let mut shards: Vec<DoubleArrayAhoCorasick<u32>> = Vec::with_capacity(num_shards);
let mut shard_metadata: Vec<Vec<PatternMetadata>> = Vec::with_capacity(num_shards);
for shard_idx in 0..num_shards {
let start_idx = shard_idx * max_patterns_per_shard;
let end_idx = ((shard_idx + 1) * max_patterns_per_shard).min(total_patterns);
let shard_patterns: Vec<String> = merged[start_idx..end_idx]
.iter()
.map(|(term, _)| term.clone())
.collect();
let metadata: Vec<PatternMetadata> = merged[start_idx..end_idx]
.iter()
.map(|(term, cuis)| PatternMetadata {
term: term.clone(),
cuis: cuis.clone(),
})
.collect();
log::debug!(
"Building shard {}/{} with {} patterns...",
shard_idx + 1,
num_shards,
shard_patterns.len()
);
let automaton = DoubleArrayAhoCorasick::<u32>::new(shard_patterns).map_err(|e| {
anyhow::anyhow!("Failed to build daachorse shard {}: {:?}", shard_idx, e)
})?;
shards.push(automaton);
shard_metadata.push(metadata);
}
let concept_index: HashMap<String, UmlsConcept> = dataset
.concepts
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
let build_time = start.elapsed();
log::info!(
"Sharded extractor built in {}ms ({} shards)",
build_time.as_millis(),
num_shards
);
Ok(Self {
shards,
shard_metadata,
concept_index,
total_patterns,
})
}
pub fn save_to_artifact(&self, path: &std::path::Path) -> anyhow::Result<()> {
let shard_bytes: Vec<Vec<u8>> = self.shards.iter().map(|s| s.serialize()).collect();
let shard_metadata: Vec<Vec<PatternMeta>> = self
.shard_metadata
.iter()
.map(|shard| {
shard
.iter()
.map(|m| PatternMeta {
cuis: m.cuis.clone(),
term: m.term.clone(),
})
.collect()
})
.collect();
let header = ArtifactHeader {
shard_metadata,
concept_index: self.concept_index.clone(),
total_patterns: self.total_patterns,
shard_byte_lengths: shard_bytes.iter().map(|b: &Vec<u8>| b.len()).collect(),
};
save_umls_artifact(&header, &shard_bytes, path)
}
    /// Loads an extractor previously written by `save_to_artifact`.
    ///
    /// # Errors
    /// Propagates any failure from `load_umls_artifact`.
    pub fn load_from_artifact(path: &std::path::Path) -> anyhow::Result<Self> {
        let (header, shard_bytes) = load_umls_artifact(path)?;
        let shards: Vec<DoubleArrayAhoCorasick<u32>> = shard_bytes
            .iter()
            .map(|bytes| {
                // SAFETY: assumes `bytes` is exactly the output of
                // `DoubleArrayAhoCorasick::serialize` as written by
                // `save_to_artifact`. `deserialize_unchecked` performs no
                // validation, so a corrupted or hand-crafted artifact is
                // undefined behavior here — NOTE(review): consider a checked
                // deserialization or an artifact checksum if inputs can be
                // untrusted.
                let (automaton, _remaining) =
                    unsafe { DoubleArrayAhoCorasick::<u32>::deserialize_unchecked(bytes) };
                automaton
            })
            .collect();
        // Convert the artifact's wire-format metadata back into the
        // in-memory representation (field-for-field move, no cloning).
        let shard_metadata: Vec<Vec<PatternMetadata>> = header
            .shard_metadata
            .into_iter()
            .map(|shard| {
                shard
                    .into_iter()
                    .map(|m| PatternMetadata {
                        cuis: m.cuis,
                        term: m.term,
                    })
                    .collect()
            })
            .collect();
        Ok(Self {
            shards,
            shard_metadata,
            concept_index: header.concept_index,
            total_patterns: header.total_patterns,
        })
    }
pub fn artifact_exists(path: &std::path::Path) -> bool {
artifact_exists(path)
}
pub fn extract(&self, text: &str) -> Vec<UmlsMatch> {
let text_lower = text.to_lowercase();
let mut all_matches: Vec<UmlsMatch> = Vec::new();
for (shard_idx, automaton) in self.shards.iter().enumerate() {
let metadata = &self.shard_metadata[shard_idx];
for mat in automaton.find_iter(&text_lower) {
let pattern_idx = mat.value() as usize;
let meta = &metadata[pattern_idx];
let start = mat.start();
let end = mat.end();
let matched_original = &text[start..end];
for cui in &meta.cuis {
let (canonical, confidence) = if let Some(concept) = self.concept_index.get(cui)
{
let conf = if concept.preferred_term.to_lowercase() == meta.term {
1.0
} else {
0.9
};
(concept.preferred_term.clone(), conf)
} else {
(meta.term.clone(), 0.8)
};
all_matches.push(UmlsMatch {
cui: cui.clone(),
matched_term: matched_original.to_string(),
canonical_term: canonical,
span: (start, end),
confidence,
});
}
}
}
all_matches.sort_by(|a, b| {
a.span
.0
.cmp(&b.span.0)
.then(a.span.1.cmp(&b.span.1))
.then(a.cui.cmp(&b.cui))
});
all_matches.dedup_by(|a, b| a.span == b.span && a.cui == b.cui);
all_matches
}
    /// Number of automaton shards.
    pub fn shard_count(&self) -> usize {
        self.shards.len()
    }
    /// Total number of distinct patterns across all shards.
    pub fn pattern_count(&self) -> usize {
        self.total_patterns
    }
    /// Number of concepts in the CUI lookup index.
    pub fn concept_count(&self) -> usize {
        self.concept_index.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::umls::UmlsDataset;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Builds a tiny dataset (3 concepts, 6 terms) from an in-memory TSV of
    /// `term\tCUI` lines.
    fn create_test_dataset() -> UmlsDataset {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(file, "non-small cell lung carcinoma\tC0000001").unwrap();
        writeln!(file, "nsclc\tC0000001").unwrap();
        writeln!(file, "lung cancer\tC0000001").unwrap();
        writeln!(file, "egfr\tC0000002").unwrap();
        writeln!(file, "epidermal growth factor receptor\tC0000002").unwrap();
        writeln!(file, "gefitinib\tC0000003").unwrap();
        UmlsDataset::from_tsv(file.path()).unwrap()
    }

    /// A shard size of 2 with 6 patterns must produce multiple shards while
    /// the total pattern count stays intact.
    #[test]
    fn test_sharded_extractor() {
        let dataset = create_test_dataset();
        let extractor = ShardedUmlsExtractor::from_dataset_with_shard_size(
            &dataset, 2, )
        .unwrap();
        assert!(extractor.shard_count() >= 2);
        assert_eq!(extractor.pattern_count(), 6);
    }

    #[test]
    fn test_extract_single_entity() {
        let dataset = create_test_dataset();
        let extractor = ShardedUmlsExtractor::from_dataset(&dataset).unwrap();
        let results = extractor.extract("Patient has lung cancer");
        assert!(!results.is_empty());
        assert_eq!(results[0].cui, "C0000001");
    }

    #[test]
    fn test_extract_multiple_entities() {
        let dataset = create_test_dataset();
        let extractor = ShardedUmlsExtractor::from_dataset(&dataset).unwrap();
        let results = extractor.extract("EGFR mutation in NSCLC patient");
        assert!(results.len() >= 2);
        let cuis: Vec<&str> = results.iter().map(|r| r.cui.as_str()).collect();
        assert!(cuis.contains(&"C0000001"));
        assert!(cuis.contains(&"C0000002"));
    }

    /// Patterns are stored lowercased and input is lowercased before
    /// matching, so upper-case input must still match.
    #[test]
    fn test_case_insensitive_matching() {
        let dataset = create_test_dataset();
        let extractor = ShardedUmlsExtractor::from_dataset(&dataset).unwrap();
        let results = extractor.extract("Patient has LUNG CANCER");
        assert!(!results.is_empty());
    }

    /// save -> load must preserve counts and extraction behavior.
    #[test]
    fn test_artifact_round_trip() {
        let dataset = create_test_dataset();
        let extractor = ShardedUmlsExtractor::from_dataset(&dataset).unwrap();
        let dir = tempfile::tempdir().unwrap();
        let artifact_path = dir.path().join("umls_test.bin.zst");
        extractor.save_to_artifact(&artifact_path).unwrap();
        assert!(artifact_path.exists());
        assert!(ShardedUmlsExtractor::artifact_exists(&artifact_path));
        let loaded = ShardedUmlsExtractor::load_from_artifact(&artifact_path).unwrap();
        assert_eq!(loaded.pattern_count(), extractor.pattern_count());
        assert_eq!(loaded.shard_count(), extractor.shard_count());
        assert_eq!(loaded.concept_count(), extractor.concept_count());
        let results = loaded.extract("Patient has lung cancer and EGFR mutation");
        assert!(!results.is_empty());
        let cuis: Vec<&str> = results.iter().map(|r| r.cui.as_str()).collect();
        assert!(cuis.contains(&"C0000001"));
        assert!(cuis.contains(&"C0000002"));
    }

    /// Dataset where the term "cold" maps to two different CUIs; 5 raw rows
    /// merge into 4 distinct patterns.
    fn create_multi_cui_dataset() -> UmlsDataset {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(file, "cold\tC0009264").unwrap();
        writeln!(file, "common cold\tC0009264").unwrap();
        writeln!(file, "cold\tC0234192").unwrap();
        writeln!(file, "cold temperature\tC0234192").unwrap();
        writeln!(file, "fever\tC0015967").unwrap();
        UmlsDataset::from_tsv(file.path()).unwrap()
    }

    /// A term shared by multiple concepts must yield one match per CUI.
    #[test]
    fn test_multi_cui_term_preserved() {
        let dataset = create_multi_cui_dataset();
        let extractor = ShardedUmlsExtractor::from_dataset(&dataset).unwrap();
        assert_eq!(extractor.pattern_count(), 4);
        let results = extractor.extract("Patient has cold and fever");
        let cold_matches: Vec<&str> = results
            .iter()
            .filter(|m| m.matched_term.to_lowercase() == "cold")
            .map(|m| m.cui.as_str())
            .collect();
        assert!(
            cold_matches.contains(&"C0009264"),
            "Missing CUI C0009264 (Common Cold) for term 'cold'; got: {:?}",
            cold_matches
        );
        assert!(
            cold_matches.contains(&"C0234192"),
            "Missing CUI C0234192 (Cold Temperature) for term 'cold'; got: {:?}",
            cold_matches
        );
        let fever_matches: Vec<&str> = results
            .iter()
            .filter(|m| m.matched_term.to_lowercase() == "fever")
            .map(|m| m.cui.as_str())
            .collect();
        assert_eq!(fever_matches, vec!["C0015967"]);
    }

    /// Multi-CUI terms must survive serialization and reload.
    #[test]
    fn test_multi_cui_artifact_round_trip() {
        let dataset = create_multi_cui_dataset();
        let extractor = ShardedUmlsExtractor::from_dataset(&dataset).unwrap();
        let dir = tempfile::tempdir().unwrap();
        let artifact_path = dir.path().join("multi_cui_test.bin.zst");
        extractor.save_to_artifact(&artifact_path).unwrap();
        let loaded = ShardedUmlsExtractor::load_from_artifact(&artifact_path).unwrap();
        assert_eq!(loaded.pattern_count(), extractor.pattern_count());
        let results = loaded.extract("Patient has cold");
        let cold_cuis: Vec<&str> = results
            .iter()
            .filter(|m| m.matched_term.to_lowercase() == "cold")
            .map(|m| m.cui.as_str())
            .collect();
        assert!(
            cold_cuis.contains(&"C0009264"),
            "After round-trip: missing C0009264; got: {:?}",
            cold_cuis
        );
        assert!(
            cold_cuis.contains(&"C0234192"),
            "After round-trip: missing C0234192; got: {:?}",
            cold_cuis
        );
    }
}