llmwiki_tooling/
mention.rs1use std::collections::HashSet;
2use std::ops::Range;
3
4use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind};
5
6use crate::page::PageId;
7use crate::parse::{ClassifiedRange, RangeKind};
8use crate::splice;
9
10#[derive(Debug)]
12pub struct BareMention {
13 pub concept: PageId,
14 pub byte_range: Range<usize>,
15 pub line: usize,
16 pub col: usize,
17}
18
19pub struct ConceptMatcher {
21 automaton: AhoCorasick,
22 concepts: Vec<PageId>,
23}
24
25impl ConceptMatcher {
26 pub fn new(pages: &HashSet<PageId>) -> Self {
27 let concept_list: Vec<PageId> = pages.iter().cloned().collect();
28 let patterns: Vec<&str> = concept_list.iter().map(|c| c.as_str()).collect();
29 let automaton = AhoCorasickBuilder::new()
31 .match_kind(MatchKind::LeftmostLongest)
32 .ascii_case_insensitive(true)
33 .build(&patterns)
34 .expect("concept patterns are valid");
35 Self {
36 automaton,
37 concepts: concept_list,
38 }
39 }
40
41 pub fn find_bare_mentions(
43 &self,
44 source: &str,
45 classified_ranges: &[ClassifiedRange],
46 self_page: &PageId,
47 ) -> Vec<BareMention> {
48 let line_offsets = splice::compute_line_offsets(source);
49 let mut mentions = Vec::new();
50
51 for cr in classified_ranges {
52 if cr.kind != RangeKind::Prose {
53 continue;
54 }
55
56 let slice = &source[cr.byte_range.clone()];
57
58 for mat in self.automaton.find_iter(slice) {
59 let concept = &self.concepts[mat.pattern().as_usize()];
60
61 if concept == self_page {
62 continue;
63 }
64
65 let abs_start = cr.byte_range.start + mat.start();
66 let abs_end = cr.byte_range.start + mat.end();
67
68 if abs_start > 0 {
72 let prev = source.as_bytes()[abs_start - 1];
73 if prev.is_ascii_alphanumeric() || prev == b'_' {
74 continue;
75 }
76 if prev == b'-' && abs_start >= 2 {
77 let before_dash = source.as_bytes()[abs_start - 2];
78 if before_dash.is_ascii_alphanumeric() {
79 continue;
80 }
81 }
82 }
83
84 if abs_end < source.len() {
85 let next = source.as_bytes()[abs_end];
86 if next.is_ascii_alphanumeric() || next == b'_' {
87 continue;
88 }
89 if next == b'-' && abs_end + 1 < source.len() {
90 let after_dash = source.as_bytes()[abs_end + 1];
91 if after_dash.is_ascii_alphanumeric() {
92 continue;
93 }
94 }
95 }
96
97 let line_0 = splice::offset_to_line(&line_offsets, abs_start);
98 let col = abs_start - line_offsets[line_0];
99 mentions.push(BareMention {
100 concept: concept.clone(),
101 byte_range: abs_start..abs_end,
102 line: line_0 + 1,
103 col: col + 1,
104 });
105 }
106 }
107
108 mentions
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Page id used as the "current page" by tests that do not exercise the
    /// self-page rule.
    const OTHER_PAGE: &str = "other";

    /// Builds a matcher over the given concept names.
    fn make_matcher(names: &[&str]) -> ConceptMatcher {
        let concepts: HashSet<PageId> = names.iter().map(|name| PageId::from(*name)).collect();
        ConceptMatcher::new(&concepts)
    }

    /// Builds a prose range covering `start..end`.
    fn prose_range(start: usize, end: usize) -> ClassifiedRange {
        ClassifiedRange {
            kind: RangeKind::Prose,
            byte_range: start..end,
        }
    }

    /// Scans all of `source` as a single prose range for the given concepts.
    fn scan(source: &str, names: &[&str], self_page: &str) -> Vec<BareMention> {
        let whole = [prose_range(0, source.len())];
        make_matcher(names).find_bare_mentions(source, &whole, &PageId::from(self_page))
    }

    #[test]
    fn finds_bare_mention() {
        let found = scan("Use GRPO for training.", &["GRPO"], OTHER_PAGE);
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].concept.as_str(), "grpo");
    }

    #[test]
    fn skips_self_page() {
        let found = scan("GRPO is great.", &["GRPO"], "GRPO");
        assert!(found.is_empty());
    }

    #[test]
    fn skips_compound_terms_suffix() {
        let found = scan("GRPO-based approach", &["GRPO"], OTHER_PAGE);
        assert!(found.is_empty());
    }

    #[test]
    fn skips_compound_terms_prefix() {
        let found = scan("SA-SFT and Mix-CPT are methods", &["SFT", "CPT"], OTHER_PAGE);
        assert!(found.is_empty());
    }

    #[test]
    fn skips_word_boundary_violations() {
        let found = scan("xGRPO and GRPOx", &["GRPO"], OTHER_PAGE);
        assert!(found.is_empty());
    }

    #[test]
    fn finds_multiple_concepts() {
        let found = scan("DPO and GRPO are methods.", &["DPO", "GRPO"], OTHER_PAGE);
        assert_eq!(found.len(), 2);
    }

    #[test]
    fn skips_non_prose_ranges() {
        let source = "GRPO in heading";
        let heading_only = vec![ClassifiedRange {
            kind: RangeKind::Heading,
            byte_range: 0..source.len(),
        }];
        let found = make_matcher(&["GRPO"]).find_bare_mentions(
            source,
            &heading_only,
            &PageId::from(OTHER_PAGE),
        );
        assert!(found.is_empty());
    }

    #[test]
    fn reports_correct_line_col() {
        let source = "line one\nGRPO here";
        // The prose range starts right after the newline, on the second line.
        let second_line = vec![prose_range(9, source.len())];
        let found = make_matcher(&["GRPO"]).find_bare_mentions(
            source,
            &second_line,
            &PageId::from(OTHER_PAGE),
        );
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].line, 2);
        assert_eq!(found[0].col, 1);
    }

    #[test]
    fn allows_concept_followed_by_punctuation() {
        let found = scan("Use GRPO, DPO.", &["GRPO", "DPO"], OTHER_PAGE);
        assert_eq!(found.len(), 2);
    }

    #[test]
    fn case_insensitive_matching() {
        let found = scan("Use grpo for training.", &["GRPO"], OTHER_PAGE);
        assert_eq!(found.len(), 1);
    }
}