Skip to main content

ctxgraph_extract/
coref.rs

1// Rule-based coreference resolver: maps pronouns to preceding entity mentions.
2// Simple but effective for well-structured technical writing.
3
4use crate::ner::ExtractedEntity;
5
/// Coarse agreement class of a pronoun, used to filter candidate antecedents.
///
/// `Eq` is derived alongside `PartialEq` because equality on a field-less enum
/// is total.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PronounType {
    /// Singular personal pronouns such as `he` / `she` / `her`; compatible only
    /// with entities whose type is `Person` (compared ASCII case-insensitively).
    Person,
    /// Neuter and demonstrative pronouns such as `it` / `this` / `those`;
    /// compatible with every entity type except `Person`.
    Neuter,
    /// Plural and first-person-plural pronouns such as `they` / `we` / `us`;
    /// compatible with any entity type.
    Plural,
}
12
13pub struct CorefResolver;
14
15impl CorefResolver {
16    /// Given the original text and a list of extracted entities (from NER),
17    /// find pronoun spans and resolve them to the most recently mentioned
18    /// compatible entity. Returns additional entity mentions to add.
19    pub fn resolve(text: &str, entities: &[ExtractedEntity]) -> Vec<ExtractedEntity> {
20        if entities.is_empty() {
21            return Vec::new();
22        }
23
24        // Sort entities by span_start ascending
25        let mut sorted_entities: Vec<&ExtractedEntity> = entities.iter().collect();
26        sorted_entities.sort_by_key(|e| e.span_start);
27
28        let pronoun_spans = find_pronoun_spans(text);
29        let mut result = Vec::new();
30
31        for (pron_start, pron_end, pron_type) in &pronoun_spans {
32            // Find most recently preceding entity compatible with pronoun type.
33            // "preceding" means entity span_start < pronoun start (backward reference only).
34            let candidate = sorted_entities
35                .iter()
36                .rev()
37                .find(|e| {
38                    e.span_start < *pron_start && is_compatible(pron_type, &e.entity_type)
39                });
40
41            if let Some(entity) = candidate {
42                result.push(ExtractedEntity {
43                    text: entity.text.clone(),
44                    entity_type: entity.entity_type.clone(),
45                    span_start: *pron_start,
46                    span_end: *pron_end,
47                    confidence: 0.45,
48                });
49            }
50        }
51
52        result
53    }
54}
55
56fn is_compatible(pron_type: &PronounType, entity_type: &str) -> bool {
57    match pron_type {
58        PronounType::Person => entity_type.eq_ignore_ascii_case("Person"),
59        PronounType::Neuter => !entity_type.eq_ignore_ascii_case("Person"),
60        PronounType::Plural => true,
61    }
62}
63
64/// Classify a lowercase word as a pronoun type, if it is one.
65fn classify_pronoun(word: &str) -> Option<PronounType> {
66    match word {
67        "he" | "him" | "his" | "she" | "her" | "hers" => Some(PronounType::Person),
68        "it" | "its" | "this" | "that" | "these" | "those" => Some(PronounType::Neuter),
69        "they" | "them" | "their" | "theirs" | "we" | "our" | "us" => Some(PronounType::Plural),
70        _ => None,
71    }
72}
73
74/// Walk text and return (byte_start, byte_end, PronounType) for every pronoun found.
75fn find_pronoun_spans(text: &str) -> Vec<(usize, usize, PronounType)> {
76    let mut result = Vec::new();
77    let bytes = text.as_bytes();
78    let len = bytes.len();
79    let mut i = 0usize;
80
81    while i < len {
82        // Skip non-alphabetic characters
83        if !bytes[i].is_ascii_alphabetic() {
84            i += 1;
85            continue;
86        }
87
88        // Found start of a word — find its end
89        let word_start = i;
90        while i < len && bytes[i].is_ascii_alphabetic() {
91            i += 1;
92        }
93        let word_end = i;
94
95        // Lowercase for comparison
96        let word_lower: String = bytes[word_start..word_end]
97            .iter()
98            .map(|b| b.to_ascii_lowercase() as char)
99            .collect();
100
101        if let Some(pron_type) = classify_pronoun(&word_lower) {
102            result.push((word_start, word_end, pron_type));
103        }
104    }
105
106    result
107}