use aho_corasick::AhoCorasick;
use std::collections::HashMap;
use crate::snomed::{SemanticType, SnomedConcept, SnomedMatch, SnomedSubset};
/// Dictionary-based SNOMED entity extractor.
///
/// All terms and synonyms are lowercased and compiled into a single
/// Aho-Corasick automaton; two parallel vectors map each automaton pattern
/// index back to its concept id and its (lowercased) pattern text.
pub struct EntityExtractor {
// Concept id -> full concept record (term, synonyms, semantic type, parents).
concept_index: HashMap<u64, SnomedConcept>,
// Concept id -> direct parent ids; only concepts with a non-empty parent list are present.
parent_map: HashMap<u64, Vec<u64>>,
// Multi-pattern matcher over every lowercased term/synonym.
automaton: AhoCorasick,
// Indexed by automaton pattern id: the owning concept id.
pattern_concept_ids: Vec<u64>,
// Indexed by automaton pattern id: the lowercased pattern text.
pattern_terms: Vec<String>,
}
impl EntityExtractor {
    /// Builds an extractor from a JSON-encoded SNOMED subset.
    ///
    /// Every concept's primary term and all of its synonyms are lowercased
    /// and registered as patterns for the same concept id, so matching is
    /// case-insensitive (input text is lowercased before searching).
    ///
    /// # Errors
    /// Returns an error if the JSON cannot be parsed or the Aho-Corasick
    /// automaton cannot be built.
    pub fn new(snomed_data: &[u8]) -> anyhow::Result<Self> {
        let subset = SnomedSubset::from_json(snomed_data)?;
        let mut concept_index: HashMap<u64, SnomedConcept> =
            HashMap::with_capacity(subset.concepts.len());
        let mut parent_map: HashMap<u64, Vec<u64>> = HashMap::new();
        let mut pattern_terms: Vec<String> = Vec::new();
        let mut pattern_concept_ids: Vec<u64> = Vec::new();
        for concept in &subset.concepts {
            // Register the primary term and each synonym under the same id.
            // Empty patterns are skipped: an empty Aho-Corasick pattern would
            // match at every position.
            let term_lower = concept.term.to_lowercase();
            if !term_lower.is_empty() {
                pattern_terms.push(term_lower);
                pattern_concept_ids.push(concept.id);
            }
            for syn in &concept.synonyms {
                let syn_lower = syn.to_lowercase();
                if !syn_lower.is_empty() {
                    pattern_terms.push(syn_lower);
                    pattern_concept_ids.push(concept.id);
                }
            }
            concept_index.insert(concept.id, concept.clone());
            if !concept.parents.is_empty() {
                parent_map.insert(concept.id, concept.parents.clone());
            }
        }
        let automaton = AhoCorasick::new(&pattern_terms)?;
        Ok(Self {
            concept_index,
            parent_map,
            automaton,
            pattern_concept_ids,
            pattern_terms,
        })
    }

    /// Builds a minimal extractor from `(term, concept_id)` pairs, without a
    /// concept index or hierarchy (useful for tests and ad-hoc dictionaries).
    ///
    /// Matches produced by such an extractor carry `SemanticType::Unknown`
    /// and use the lowercased pattern as the canonical term.
    ///
    /// # Panics
    /// Panics if the Aho-Corasick automaton cannot be built.
    pub fn from_terms(terms: Vec<(&str, u64)>) -> Self {
        let mut pattern_terms: Vec<String> = Vec::with_capacity(terms.len());
        let mut pattern_concept_ids: Vec<u64> = Vec::with_capacity(terms.len());
        for (term, id) in terms {
            let term_lower = term.to_lowercase();
            // Skip empty terms for consistency with `new`: an empty pattern
            // would otherwise match at every position in the input.
            if !term_lower.is_empty() {
                pattern_terms.push(term_lower);
                pattern_concept_ids.push(id);
            }
        }
        let automaton =
            AhoCorasick::new(&pattern_terms).expect("Failed to build Aho-Corasick automaton");
        Self {
            concept_index: HashMap::new(),
            parent_map: HashMap::new(),
            automaton,
            pattern_concept_ids,
            pattern_terms,
        }
    }

    /// Extracts entity mentions from `text` as `SnomedMatch` values.
    ///
    /// This is a one-to-one conversion of [`extract_with_confidence`] output.
    pub fn extract(&self, text: &str) -> Vec<SnomedMatch> {
        self.extract_with_confidence(text)
            .into_iter()
            .map(|m| SnomedMatch {
                concept_id: m.concept_id,
                term: m.term,
                canonical: m.canonical,
                semantic_type: m.semantic_type,
                span: m.span,
                confidence: m.confidence,
            })
            .collect()
    }

    /// Finds dictionary matches in `text` (case-insensitively) and returns
    /// them sorted by start offset. Overlapping matches are resolved
    /// first-come-first-kept, in automaton iteration order.
    ///
    /// NOTE(review): the automaton uses the default `MatchKind::Standard`
    /// semantics, not leftmost-longest — confirm that shorter embedded terms
    /// winning over longer ones is intended.
    pub fn extract_with_confidence(&self, text: &str) -> Vec<ExtractedEntity> {
        let text_lower = text.to_lowercase();
        let mut matches: Vec<ExtractedEntity> = Vec::new();
        for mat in self.automaton.find_iter(&text_lower) {
            let pattern_index = mat.pattern().as_usize();
            let concept_id = self.pattern_concept_ids[pattern_index];
            let term = &self.pattern_terms[pattern_index];
            // Resolve canonical term and semantic type from the concept index;
            // extractors built via `from_terms` have no index and fall back to
            // the pattern itself with an Unknown type.
            let (canonical, semantic_type) = match self.concept_index.get(&concept_id) {
                Some(concept) => (concept.term.clone(), concept.semantic_type),
                None => (term.clone(), SemanticType::Unknown),
            };
            let (start, end) = (mat.start(), mat.end());
            let overlaps = matches.iter().any(|m| start < m.span.1 && end > m.span.0);
            if overlaps {
                continue;
            }
            // `start..end` are byte offsets into `text_lower`. Lowercasing can
            // change byte lengths for some Unicode characters, so indexing the
            // original `text` could panic on a non-char-boundary; use checked
            // slicing and fall back to the lowered slice when offsets drift.
            let surface = text
                .get(start..end)
                .map(str::to_string)
                .unwrap_or_else(|| text_lower[start..end].to_string());
            matches.push(ExtractedEntity {
                concept_id,
                term: surface,
                canonical,
                semantic_type,
                span: (start, end),
                confidence: 1.0,
            });
        }
        matches.sort_by_key(|m| m.span.0);
        matches
    }

    /// Looks up the full concept record for `concept_id`, if indexed.
    pub fn get_concept(&self, concept_id: u64) -> Option<&SnomedConcept> {
        self.concept_index.get(&concept_id)
    }

    /// Returns true if `parent` is a (transitive) ancestor of `child`.
    pub fn is_descendant(&self, child: u64, parent: u64) -> bool {
        self.get_ancestors(child).contains(&parent)
    }

    /// Collects all transitive ancestors of `concept` via iterative DFS.
    ///
    /// The `visited` set guards against cycles in the parent graph; the
    /// returned vector contains each ancestor once, in discovery order.
    pub fn get_ancestors(&self, concept: u64) -> Vec<u64> {
        let mut ancestors = Vec::new();
        let mut visited = std::collections::HashSet::new();
        let mut stack = vec![concept];
        while let Some(current) = stack.pop() {
            // `insert` returns false when already present: node was expanded.
            if !visited.insert(current) {
                continue;
            }
            if let Some(parents) = self.parent_map.get(&current) {
                for &parent in parents {
                    if !ancestors.contains(&parent) {
                        ancestors.push(parent);
                        stack.push(parent);
                    }
                }
            }
        }
        ancestors
    }
}
/// One entity mention found in the input text.
#[derive(Debug, Clone)]
pub struct ExtractedEntity {
// SNOMED concept id the matched pattern belongs to.
pub concept_id: u64,
// Surface text of the match, sliced from the input.
pub term: String,
// Preferred term of the concept; falls back to the pattern when the
// concept is not in the index.
pub canonical: String,
// Semantic category (Unknown when no concept record is indexed).
pub semantic_type: SemanticType,
// (start, end) byte offsets of the match within the lowercased input.
pub span: (usize, usize),
// Currently always 1.0 for exact dictionary hits.
pub confidence: f32,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Five-concept fixture: a three-level disease hierarchy, one gene with
    /// a synonym, and one pharmaceutical.
    fn create_test_extractor() -> EntityExtractor {
        let data = r#"[
{"id": 254637007, "term": "Non-small cell lung carcinoma", "semantic": "Disease", "parents": [363358000]},
{"id": 363358000, "term": "Lung carcinoma", "semantic": "Disease", "parents": [63250001]},
{"id": 63250001, "term": "Lung cancer", "semantic": "Disease", "parents": []},
{"id": 363358001, "term": "EGFR", "semantic": "Gene", "synonyms": ["Epidermal growth factor receptor"]},
{"id": 86249004, "term": "Gefitinib", "semantic": "Pharmaceutical"}
]"#;
        EntityExtractor::new(data.as_bytes()).unwrap()
    }

    #[test]
    fn test_extract_single_entity() {
        let extractor = create_test_extractor();
        let hits = extractor.extract("Non-small cell lung carcinoma");
        assert!(!hits.is_empty(), "expected at least one match");
        assert_eq!(hits[0].semantic_type, SemanticType::Disease);
    }

    #[test]
    fn test_extract_multiple_entities() {
        let extractor = create_test_extractor();
        let hits = extractor.extract("Patient with EGFR mutation and NSCLC");
        assert!(!hits.is_empty(), "expected at least the EGFR mention");
    }

    #[test]
    fn test_extract_no_match() {
        let extractor = create_test_extractor();
        let hits = extractor.extract("Patient feeling well");
        assert!(hits.is_empty(), "no dictionary term occurs in this text");
    }

    #[test]
    fn test_confidence_scoring() {
        let extractor = create_test_extractor();
        let entities = extractor.extract_with_confidence("Patient has lung cancer");
        // Every reported confidence must be a valid probability.
        for entity in &entities {
            assert!((0.0..=1.0).contains(&entity.confidence));
        }
    }

    #[test]
    fn test_get_concept() {
        let extractor = create_test_extractor();
        let concept = extractor.get_concept(254637007);
        assert!(concept.is_some());
        assert_eq!(concept.unwrap().term, "Non-small cell lung carcinoma");
    }

    #[test]
    fn test_ancestor_query() {
        let extractor = create_test_extractor();
        let ancestors = extractor.get_ancestors(254637007);
        // NSCLC should reach its direct parent or the hierarchy root.
        assert!(ancestors.contains(&363358000) || ancestors.contains(&63250001));
    }
}