Skip to main content

llmwiki_tooling/
mention.rs

1use std::collections::HashSet;
2use std::ops::Range;
3
4use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind};
5
6use crate::page::PageId;
7use crate::parse::{ClassifiedRange, RangeKind};
8use crate::splice;
9
10/// A bare concept mention found in prose that should be a wikilink.
11#[derive(Debug)]
12pub struct BareMention {
13    pub concept: PageId,
14    pub byte_range: Range<usize>,
15    pub line: usize,
16    pub col: usize,
17}
18
19/// Efficient multi-pattern matcher for auto-linkable page names.
20pub struct ConceptMatcher {
21    automaton: AhoCorasick,
22    concepts: Vec<PageId>,
23}
24
25impl ConceptMatcher {
26    pub fn new(pages: &HashSet<PageId>) -> Self {
27        let concept_list: Vec<PageId> = pages.iter().cloned().collect();
28        let patterns: Vec<&str> = concept_list.iter().map(|c| c.as_str()).collect();
29        // Always case-insensitive since PageIds are normalized to lowercase
30        let automaton = AhoCorasickBuilder::new()
31            .match_kind(MatchKind::LeftmostLongest)
32            .ascii_case_insensitive(true)
33            .build(&patterns)
34            .expect("concept patterns are valid");
35        Self {
36            automaton,
37            concepts: concept_list,
38        }
39    }
40
41    /// Find all bare mentions in a page's prose ranges.
42    pub fn find_bare_mentions(
43        &self,
44        source: &str,
45        classified_ranges: &[ClassifiedRange],
46        self_page: &PageId,
47    ) -> Vec<BareMention> {
48        let line_offsets = splice::compute_line_offsets(source);
49        let mut mentions = Vec::new();
50
51        for cr in classified_ranges {
52            if cr.kind != RangeKind::Prose {
53                continue;
54            }
55
56            let slice = &source[cr.byte_range.clone()];
57
58            for mat in self.automaton.find_iter(slice) {
59                let concept = &self.concepts[mat.pattern().as_usize()];
60
61                if concept == self_page {
62                    continue;
63                }
64
65                let abs_start = cr.byte_range.start + mat.start();
66                let abs_end = cr.byte_range.start + mat.end();
67
68                // Word boundary checks use byte indexing on ASCII-only characters.
69                // Safe because aho-corasick returns byte-aligned positions and we
70                // only inspect ASCII punctuation/alphanumeric at those boundaries.
71                if abs_start > 0 {
72                    let prev = source.as_bytes()[abs_start - 1];
73                    if prev.is_ascii_alphanumeric() || prev == b'_' {
74                        continue;
75                    }
76                    if prev == b'-' && abs_start >= 2 {
77                        let before_dash = source.as_bytes()[abs_start - 2];
78                        if before_dash.is_ascii_alphanumeric() {
79                            continue;
80                        }
81                    }
82                }
83
84                if abs_end < source.len() {
85                    let next = source.as_bytes()[abs_end];
86                    if next.is_ascii_alphanumeric() || next == b'_' {
87                        continue;
88                    }
89                    if next == b'-' && abs_end + 1 < source.len() {
90                        let after_dash = source.as_bytes()[abs_end + 1];
91                        if after_dash.is_ascii_alphanumeric() {
92                            continue;
93                        }
94                    }
95                }
96
97                let line_0 = splice::offset_to_line(&line_offsets, abs_start);
98                let col = abs_start - line_offsets[line_0];
99                mentions.push(BareMention {
100                    concept: concept.clone(),
101                    byte_range: abs_start..abs_end,
102                    line: line_0 + 1,
103                    col: col + 1,
104                });
105            }
106        }
107
108        mentions
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a matcher over the given page names.
    fn make_matcher(names: &[&str]) -> ConceptMatcher {
        let concepts: HashSet<PageId> = names.iter().map(|&n| PageId::from(n)).collect();
        ConceptMatcher::new(&concepts)
    }

    /// A single prose range covering `start..end`.
    fn prose_range(start: usize, end: usize) -> ClassifiedRange {
        ClassifiedRange {
            kind: RangeKind::Prose,
            byte_range: start..end,
        }
    }

    /// Scans all of `source` as one prose range from an unrelated page.
    fn scan_all(matcher: &ConceptMatcher, source: &str) -> Vec<BareMention> {
        let ranges = vec![prose_range(0, source.len())];
        matcher.find_bare_mentions(source, &ranges, &PageId::from("other"))
    }

    #[test]
    fn finds_bare_mention() {
        let source = "Use GRPO for training.";
        let mentions = scan_all(&make_matcher(&["GRPO"]), source);
        assert_eq!(mentions.len(), 1);
        assert_eq!(mentions[0].concept.as_str(), "grpo");
        assert_eq!(mentions[0].byte_range, 4..8);
        assert_eq!(mentions[0].line, 1);
        assert_eq!(mentions[0].col, 5);
    }

    #[test]
    fn skips_self_page() {
        let source = "GRPO is great.";
        let matcher = make_matcher(&["GRPO"]);
        let ranges = vec![prose_range(0, source.len())];
        let mentions = matcher.find_bare_mentions(source, &ranges, &PageId::from("GRPO"));
        assert!(mentions.is_empty());
    }

    #[test]
    fn skips_compound_terms_suffix() {
        let mentions = scan_all(&make_matcher(&["GRPO"]), "GRPO-based approach");
        assert!(mentions.is_empty());
    }

    #[test]
    fn skips_compound_terms_prefix() {
        let mentions = scan_all(&make_matcher(&["SFT", "CPT"]), "SA-SFT and Mix-CPT are methods");
        assert!(mentions.is_empty());
    }

    #[test]
    fn skips_word_boundary_violations() {
        let mentions = scan_all(&make_matcher(&["GRPO"]), "xGRPO and GRPOx");
        assert!(mentions.is_empty());
    }

    #[test]
    fn finds_multiple_concepts() {
        let mentions = scan_all(&make_matcher(&["DPO", "GRPO"]), "DPO and GRPO are methods.");
        assert_eq!(mentions.len(), 2);
        assert_eq!(mentions[0].concept.as_str(), "dpo");
        assert_eq!(mentions[1].concept.as_str(), "grpo");
    }

    #[test]
    fn skips_non_prose_ranges() {
        let source = "GRPO in heading";
        let ranges = vec![ClassifiedRange {
            kind: RangeKind::Heading,
            byte_range: 0..source.len(),
        }];
        let matcher = make_matcher(&["GRPO"]);
        let mentions = matcher.find_bare_mentions(source, &ranges, &PageId::from("other"));
        assert!(mentions.is_empty());
    }

    #[test]
    fn reports_correct_line_col() {
        let source = "line one\nGRPO here";
        let matcher = make_matcher(&["GRPO"]);
        let ranges = vec![prose_range(9, source.len())];
        let mentions = matcher.find_bare_mentions(source, &ranges, &PageId::from("other"));
        assert_eq!(mentions.len(), 1);
        assert_eq!(mentions[0].line, 2);
        assert_eq!(mentions[0].col, 1);
    }

    #[test]
    fn allows_concept_followed_by_punctuation() {
        let mentions = scan_all(&make_matcher(&["GRPO", "DPO"]), "Use GRPO, DPO.");
        assert_eq!(mentions.len(), 2);
    }

    #[test]
    fn allows_spaced_dash() {
        let mentions = scan_all(&make_matcher(&["GRPO", "DPO"]), "GRPO - DPO");
        assert_eq!(mentions.len(), 2);
    }

    #[test]
    fn case_insensitive_matching() {
        let mentions = scan_all(&make_matcher(&["GRPO"]), "Use grpo for training.");
        assert_eq!(mentions.len(), 1);
        assert_eq!(mentions[0].concept.as_str(), "grpo");
    }

    #[test]
    fn prefers_longest_match() {
        let mentions = scan_all(&make_matcher(&["RL", "RLHF"]), "Use RLHF here");
        assert_eq!(mentions.len(), 1);
        assert_eq!(mentions[0].concept.as_str(), "rlhf");
        assert_eq!(mentions[0].byte_range, 4..8);
    }

    #[test]
    fn checks_boundary_bytes_outside_prose_range() {
        // The prose range starts at the `G`, but the preceding `x` still glues it to a word.
        let source = "xGRPO";
        let matcher = make_matcher(&["GRPO"]);
        let ranges = vec![prose_range(1, source.len())];
        let mentions = matcher.find_bare_mentions(source, &ranges, &PageId::from("other"));
        assert!(mentions.is_empty());
    }
}