use aho_corasick::AhoCorasick;
use std::collections::HashMap;
use crate::umls::{UmlsConcept, UmlsDataset};
/// A single UMLS term occurrence found in a piece of text by
/// `UmlsExtractor::extract`.
#[derive(Debug, Clone)]
pub struct UmlsMatch {
/// Concept unique identifier (CUI) associated with the matched pattern.
pub cui: String,
/// The matched text, sliced from the caller's input using `span`
/// (so the input's original casing is kept).
pub matched_term: String,
/// The concept's preferred term when the CUI is present in the extractor's
/// concept index; otherwise the lowercased pattern that matched.
pub canonical_term: String,
/// `(start, end)` byte range of the match; `matched_term` is the input
/// text sliced by this range.
pub span: (usize, usize),
/// Heuristic score assigned by the extractor: `1.0` when the matched
/// pattern equals the lowercased preferred term, `0.9` for any other term
/// of a known concept, `0.8` when the CUI is not in the concept index.
pub confidence: f32,
}
/// Dictionary-based UMLS entity extractor backed by an Aho-Corasick automaton
/// built over the lowercased terms of every concept in a dataset.
pub struct UmlsExtractor {
/// Concepts keyed by CUI (each concept is cloned out of the source dataset).
concept_index: HashMap<String, UmlsConcept>,
/// Automaton built from `pattern_terms`; the pattern id of a match is an
/// index into `pattern_terms` / `pattern_cuis`.
automaton: AhoCorasick,
/// CUI owning each pattern; parallel to `pattern_terms`.
pattern_cuis: Vec<String>,
/// Lowercased, non-empty terms, in the order they were given to the automaton.
pattern_terms: Vec<String>,
/// Number of patterns (`pattern_terms.len()` captured at construction time).
pattern_count: usize,
}
impl UmlsExtractor {
    /// Builds an extractor from every term of every concept in `dataset`.
    ///
    /// Each term is lowercased and, unless it is empty, registered as one
    /// Aho-Corasick pattern tied to the concept's CUI. Every concept is also
    /// cloned into an internal CUI -> concept index. Progress and build time
    /// are reported through `log::info!`.
    ///
    /// # Errors
    ///
    /// Returns an error if the Aho-Corasick automaton cannot be built from the
    /// collected patterns.
    pub fn from_dataset(dataset: &UmlsDataset) -> anyhow::Result<Self> {
        let start = std::time::Instant::now();

        let mut concept_index: HashMap<String, UmlsConcept> = HashMap::new();
        let mut pattern_terms: Vec<String> = Vec::new();
        let mut pattern_cuis: Vec<String> = Vec::new();

        for (cui, concept) in &dataset.concepts {
            for term in &concept.terms {
                let term_lower = term.to_lowercase();
                // Empty patterns would match at every position; skip them.
                if !term_lower.is_empty() {
                    pattern_terms.push(term_lower);
                    pattern_cuis.push(cui.clone());
                }
            }
            concept_index.insert(cui.clone(), concept.clone());
        }

        let pattern_count = pattern_terms.len();
        log::info!(
            "Building Aho-Corasick automaton with {} patterns...",
            pattern_count
        );

        let automaton = AhoCorasick::new(&pattern_terms)?;

        let build_time = start.elapsed();
        log::info!("Automaton built in {}ms", build_time.as_millis());

        Ok(Self {
            concept_index,
            automaton,
            pattern_cuis,
            pattern_terms,
            pattern_count,
        })
    }

    /// Finds all non-overlapping dictionary terms in `text`, case-insensitively.
    ///
    /// Matching runs on a lowercased copy of `text`, but the returned
    /// `span` and `matched_term` always refer to the *original* `text`.
    /// Lowercasing can change the byte length of some characters (for example
    /// `'İ'` grows and the Kelvin sign shrinks), so for non-ASCII input the
    /// match offsets are translated back to original byte offsets instead of
    /// being used directly; using them directly could slice the wrong text or
    /// panic on an invalid byte index.
    ///
    /// Confidence is `1.0` when the matched pattern equals the concept's
    /// lowercased preferred term, `0.9` for any other term of a known concept,
    /// and `0.8` when the CUI is missing from the concept index.
    ///
    /// Results are ordered by start offset.
    pub fn extract(&self, text: &str) -> Vec<UmlsMatch> {
        let text_lower = text.to_lowercase();

        // For ASCII input lowercasing is byte-for-byte, so offsets are shared
        // and no translation table is needed.
        let byte_origins = if text.is_ascii() {
            None
        } else {
            let origins = Self::lowercase_byte_origins(text);
            debug_assert_eq!(origins.len(), text_lower.len());
            Some(origins)
        };

        let mut matches: Vec<UmlsMatch> = Vec::new();
        // End (in original-text bytes) of the last accepted match. `find_iter`
        // yields matches in increasing, non-overlapping order, so comparing
        // against this marker is enough to reject overlaps and the output is
        // already sorted by start offset.
        let mut covered_until = 0usize;

        for mat in self.automaton.find_iter(&text_lower) {
            let span = match &byte_origins {
                None => Some((mat.start(), mat.end())),
                Some(origins) => Self::original_span(origins, mat.start(), mat.end()),
            };
            let (start, end) = match span {
                Some(span) => span,
                None => continue,
            };

            // Two distinct lowercase matches can map onto the same original
            // character when that character expands while lowercasing.
            if start < covered_until {
                continue;
            }

            let matched_original = match text.get(start..end) {
                Some(slice) => slice,
                None => continue,
            };

            let pattern_index = mat.pattern().as_usize();
            let cui = &self.pattern_cuis[pattern_index];
            let term = &self.pattern_terms[pattern_index];

            let (canonical_term, confidence): (String, f32) = match self.concept_index.get(cui) {
                Some(concept) => {
                    let confidence = if concept.preferred_term.to_lowercase() == *term {
                        1.0
                    } else {
                        0.9
                    };
                    (concept.preferred_term.clone(), confidence)
                }
                None => (term.clone(), 0.8),
            };

            matches.push(UmlsMatch {
                cui: cui.clone(),
                matched_term: matched_original.to_string(),
                canonical_term,
                span: (start, end),
                confidence,
            });
            covered_until = end;
        }

        matches
    }

    /// For every byte of `text.to_lowercase()`, records the byte range in
    /// `text` of the character that produced it.
    ///
    /// The table is built from per-character lowercase lengths. The one
    /// context-sensitive rule in `str::to_lowercase` (Greek final sigma) swaps
    /// between two characters of equal UTF-8 length, so the table lines up
    /// with the string returned by `str::to_lowercase`.
    fn lowercase_byte_origins(text: &str) -> Vec<(usize, usize)> {
        let mut origins: Vec<(usize, usize)> = Vec::with_capacity(text.len());
        for (orig_start, ch) in text.char_indices() {
            let orig_end = orig_start + ch.len_utf8();
            let lowered_len: usize = ch.to_lowercase().map(char::len_utf8).sum();
            origins.extend(std::iter::repeat((orig_start, orig_end)).take(lowered_len));
        }
        origins
    }

    /// Translates a non-empty byte range of the lowercased text into the
    /// smallest range of whole characters in the original text that covers it.
    ///
    /// Returns `None` for an empty range or a range outside the table.
    fn original_span(
        origins: &[(usize, usize)],
        lower_start: usize,
        lower_end: usize,
    ) -> Option<(usize, usize)> {
        let last_byte = lower_end.checked_sub(1)?;
        if last_byte < lower_start {
            return None;
        }
        let first = origins.get(lower_start)?;
        let last = origins.get(last_byte)?;
        Some((first.0, last.1))
    }

    /// Runs [`extract`](Self::extract) and keeps only the matches for which
    /// `filter` returns `true`.
    pub fn extract_filtered<F>(&self, text: &str, filter: F) -> Vec<UmlsMatch>
    where
        F: Fn(&UmlsMatch) -> bool,
    {
        self.extract(text).into_iter().filter(filter).collect()
    }

    /// Looks up a concept by its CUI; `None` if the CUI is not indexed.
    pub fn get_concept(&self, cui: &str) -> Option<&UmlsConcept> {
        self.concept_index.get(cui)
    }

    /// Number of patterns (lowercased, non-empty terms) in the automaton.
    pub fn pattern_count(&self) -> usize {
        self.pattern_count
    }

    /// Number of concepts in the CUI index.
    pub fn concept_count(&self) -> usize {
        self.concept_index.len()
    }
}
/// Summary figures describing a built `UmlsExtractor`.
#[derive(Debug, Clone)]
pub struct UmlsExtractorStats {
/// Number of concepts held in the extractor's CUI index.
pub concept_count: usize,
/// Number of term patterns loaded into the extractor's automaton.
pub pattern_count: usize,
/// Heuristic memory estimate in megabytes, derived only from the two
/// counts above (it is not a measurement).
pub memory_estimate_mb: f64,
}
impl UmlsExtractorStats {
    /// Collects the pattern and concept counts of `extractor` and derives a
    /// rough memory estimate from them.
    ///
    /// The estimate is a fixed-weight heuristic: 0.00005 MB per pattern plus
    /// 0.0001 MB per concept. Nothing is actually measured.
    pub fn from_extractor(extractor: &UmlsExtractor) -> Self {
        const MB_PER_PATTERN: f64 = 0.00005;
        const MB_PER_CONCEPT: f64 = 0.0001;

        let concepts = extractor.concept_count();
        let patterns = extractor.pattern_count();

        let pattern_mb = patterns as f64 * MB_PER_PATTERN;
        let concept_mb = concepts as f64 * MB_PER_CONCEPT;

        Self {
            concept_count: concepts,
            pattern_count: patterns,
            memory_estimate_mb: pattern_mb + concept_mb,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::umls::UmlsDataset;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Fixture rows written to the temporary TSV file as `term<TAB>cui`.
    const FIXTURE_ROWS: [(&str, &str); 6] = [
        ("non-small cell lung carcinoma", "C0000001"),
        ("nsclc", "C0000001"),
        ("lung cancer", "C0000001"),
        ("egfr", "C0000002"),
        ("epidermal growth factor receptor", "C0000002"),
        ("gefitinib", "C0000003"),
    ];

    /// Writes the fixture rows to a temporary TSV file and loads it as a dataset.
    fn create_test_dataset() -> UmlsDataset {
        let mut file = NamedTempFile::new().unwrap();
        for &(term, cui) in FIXTURE_ROWS.iter() {
            writeln!(file, "{}\t{}", term, cui).unwrap();
        }
        UmlsDataset::from_tsv(file.path()).unwrap()
    }

    /// Builds an extractor over the fixture dataset.
    fn build_extractor() -> UmlsExtractor {
        let dataset = create_test_dataset();
        UmlsExtractor::from_dataset(&dataset).unwrap()
    }

    #[test]
    fn test_extractor_from_dataset() {
        let extractor = build_extractor();

        assert_eq!(extractor.concept_count(), 3);
        assert_eq!(extractor.pattern_count(), 6);
    }

    #[test]
    fn test_extract_single_entity() {
        let found = build_extractor().extract("Patient has lung cancer");

        assert!(!found.is_empty());
        let first = &found[0];
        assert_eq!(first.cui, "C0000001");
        assert!(first.confidence > 0.0);
    }

    #[test]
    fn test_extract_multiple_entities() {
        let found = build_extractor().extract("EGFR mutation in NSCLC patient");

        assert!(found.len() >= 2);
        assert!(found.iter().any(|m| m.cui == "C0000001"));
        assert!(found.iter().any(|m| m.cui == "C0000002"));
    }

    #[test]
    fn test_case_insensitive_matching() {
        let extractor = build_extractor();

        // Same phrase in lower, upper and mixed case must all produce a match.
        let variants = [
            "patient has lung cancer",
            "Patient has LUNG CANCER",
            "Patient has Lung Cancer",
        ];
        for variant in variants.iter() {
            assert!(!extractor.extract(variant).is_empty());
        }
    }

    #[test]
    fn test_confidence_scoring() {
        let extractor = build_extractor();

        let preferred_hits = extractor.extract("nsclc");
        let synonym_hits = extractor.extract("lung cancer");

        assert!(!preferred_hits.is_empty());
        assert!(!synonym_hits.is_empty());
        assert_eq!(preferred_hits[0].confidence, 1.0);
        assert_eq!(synonym_hits[0].confidence, 0.9);
    }

    #[test]
    fn test_get_concept() {
        let extractor = build_extractor();

        let concept = extractor.get_concept("C0000001").unwrap();
        assert_eq!(concept.cui, "C0000001");
        assert_eq!(concept.terms.len(), 3);
    }

    #[test]
    fn test_extract_no_match() {
        let found = build_extractor().extract("Patient is feeling well today");

        assert!(found.is_empty());
    }
}