1use alloc::collections::BTreeMap;
42use alloc::string::String;
43use alloc::vec::Vec;
44
45use crate::token::{NamedEntityKind, Token, TokenKind};
46
47static BUILTIN_NE: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/ne_th.bin"));
48
49pub struct NeTagger(BTreeMap<String, NamedEntityKind>);
53
54impl NeTagger {
55 pub fn builtin() -> Self {
57 Self::from_tsv(&crate::decompress_builtin(BUILTIN_NE))
58 }
59
60 pub fn from_tsv(data: &str) -> Self {
67 let mut map: BTreeMap<String, NamedEntityKind> = BTreeMap::new();
68 for line in data.lines() {
69 let line = line.trim();
70 if line.is_empty() || line.starts_with('#') {
71 continue;
72 }
73 let mut parts = line.splitn(2, '\t');
74 let word = match parts.next() {
75 Some(w) if !w.is_empty() => String::from(w),
76 _ => continue,
77 };
78 let tag_str = match parts.next() {
79 Some(t) if !t.is_empty() => t.trim(),
80 _ => continue,
81 };
82 if let Some(kind) = NamedEntityKind::from_tag(tag_str) {
83 map.insert(word, kind);
84 }
85 }
86 NeTagger(map)
87 }
88
89 pub fn tag(&self, word: &str) -> Option<NamedEntityKind> {
104 self.0.get(word).copied()
105 }
106
107 pub fn tag_tokens<'a>(&self, tokens: Vec<Token<'a>>, source: &'a str) -> Vec<Token<'a>> {
141 const MAX_SPAN: usize = 5;
143
144 let mut out: Vec<Token<'a>> = Vec::with_capacity(tokens.len());
145 let mut i = 0;
146
147 while i < tokens.len() {
148 if tokens[i].kind != TokenKind::Thai {
149 out.push(tokens[i].clone());
150 i += 1;
151 continue;
152 }
153
154 let run_end = tokens[i..]
156 .iter()
157 .position(|t| t.kind != TokenKind::Thai)
158 .map_or(tokens.len(), |pos| i + pos);
159 let max_end = run_end.min(i + MAX_SPAN);
160
161 let mut matched = false;
163 for end in (i + 1..=max_end).rev() {
164 let span_start = tokens[i].span.start;
165 let span_end = tokens[end - 1].span.end;
166 let candidate = &source[span_start..span_end];
167 if let Some(ne_kind) = self.tag(candidate) {
168 let char_start = tokens[i].char_span.start;
169 let char_end = tokens[end - 1].char_span.end;
170 out.push(Token::new(
171 candidate,
172 span_start..span_end,
173 char_start..char_end,
174 TokenKind::Named(ne_kind),
175 ));
176 i = end;
177 matched = true;
178 break;
179 }
180 }
181
182 if !matched {
183 out.push(tokens[i].clone());
184 i += 1;
185 }
186 }
187
188 out
189 }
190
191 #[inline]
193 pub fn len(&self) -> usize {
194 self.0.len()
195 }
196
197 #[inline]
199 pub fn is_empty(&self) -> bool {
200 self.0.is_empty()
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use crate::token::Token;

    /// Token kind expected for a gazetteer hit tagged `PLACE`.
    fn place() -> TokenKind {
        TokenKind::Named(NamedEntityKind::Place)
    }

    /// Tagger whose single entry maps "กรุงเทพ" to `PLACE`.
    fn bangkok_only() -> NeTagger {
        NeTagger::from_tsv("กรุงเทพ\tPLACE\n")
    }

    // ---- built-in gazetteer lookups -------------------------------------

    #[test]
    fn builtin_gazetteer_non_empty() {
        let entries = NeTagger::builtin().len();
        assert!(entries > 50);
    }

    #[test]
    fn place_lookup() {
        let tagger = NeTagger::builtin();
        for word in ["กรุงเทพ", "ไทย", "ญี่ปุ่น"] {
            assert_eq!(
                tagger.tag(word),
                Some(NamedEntityKind::Place),
                "word: {}",
                word
            );
        }
    }

    #[test]
    fn org_lookup() {
        let tagger = NeTagger::builtin();
        for word in ["ปตท", "ธนาคารแห่งประเทศไทย"] {
            assert_eq!(
                tagger.tag(word),
                Some(NamedEntityKind::Org),
                "word: {}",
                word
            );
        }
    }

    #[test]
    fn person_lookup() {
        let tagger = NeTagger::builtin();
        assert_eq!(tagger.tag("ทักษิณ"), Some(NamedEntityKind::Person));
    }

    #[test]
    fn oov_returns_none() {
        let tagger = NeTagger::builtin();
        assert!(tagger.tag("กิน").is_none());
        assert!(tagger.tag("").is_none());
    }

    // ---- TSV parsing -----------------------------------------------------

    #[test]
    fn from_tsv_last_duplicate_wins() {
        let tsv = "กรุงเทพ\tPLACE\nกรุงเทพ\tORG\n";
        let tagger = NeTagger::from_tsv(tsv);
        assert_eq!(tagger.tag("กรุงเทพ"), Some(NamedEntityKind::Org));
    }

    #[test]
    fn from_tsv_unknown_tag_skipped() {
        let tagger = NeTagger::from_tsv("กรุงเทพ\tCITY\n");
        assert_eq!(tagger.tag("กรุงเทพ"), None);
    }

    #[test]
    fn from_tsv_empty() {
        let tagger = NeTagger::from_tsv("");
        assert!(tagger.is_empty());
    }

    // ---- single-token tagging ---------------------------------------------

    #[test]
    fn tag_tokens_relabels_thai() {
        let source = "กรุงเทพ";
        let tagger = bangkok_only();
        let input = alloc::vec![Token::new("กรุงเทพ", 0..21, 0..7, TokenKind::Thai)];
        let tagged = tagger.tag_tokens(input, source);
        assert_eq!(tagged[0].kind, place());
    }

    #[test]
    fn tag_tokens_passes_through_non_thai() {
        let source = "hello";
        let tagger = NeTagger::from_tsv("hello\tPERSON\n");
        let input = alloc::vec![Token::new("hello", 0..5, 0..5, TokenKind::Latin)];
        let tagged = tagger.tag_tokens(input, source);
        // A gazetteer entry must not re-label a token that is not Thai.
        assert_eq!(tagged[0].kind, TokenKind::Latin);
    }

    #[test]
    fn tag_tokens_oov_unchanged() {
        let source = "กิน";
        let tagger = bangkok_only();
        let input = alloc::vec![Token::new("กิน", 0..9, 0..3, TokenKind::Thai)];
        let tagged = tagger.tag_tokens(input, source);
        assert_eq!(tagged[0].kind, TokenKind::Thai);
    }

    // ---- multi-token merging ----------------------------------------------

    #[test]
    fn tag_tokens_multi_merges_two_tokens() {
        let source = "กรุงเทพ";
        let tagger = bangkok_only();
        let input = alloc::vec![
            Token::new("กรุง", 0..12, 0..4, TokenKind::Thai),
            Token::new("เทพ", 12..21, 4..7, TokenKind::Thai),
        ];
        let tagged = tagger.tag_tokens(input, source);
        assert_eq!(tagged.len(), 1, "two tokens should merge into one");
        let merged = &tagged[0];
        assert_eq!(merged.text, "กรุงเทพ");
        assert_eq!(merged.kind, place());
        assert_eq!(merged.span, 0..21);
        assert_eq!(merged.char_span, 0..7);
    }

    #[test]
    fn tag_tokens_multi_greedy_prefers_longer() {
        let source = "กรุงเทพ";
        // Both the prefix and the full word are entries; the longer must win.
        let tagger = NeTagger::from_tsv("กรุง\tPLACE\nกรุงเทพ\tPLACE\n");
        let input = alloc::vec![
            Token::new("กรุง", 0..12, 0..4, TokenKind::Thai),
            Token::new("เทพ", 12..21, 4..7, TokenKind::Thai),
        ];
        let tagged = tagger.tag_tokens(input, source);
        assert_eq!(tagged.len(), 1, "longer match should be preferred");
        assert_eq!(tagged[0].text, "กรุงเทพ");
    }

    #[test]
    fn tag_tokens_multi_does_not_cross_non_thai() {
        let source = "กรุง100เทพ";
        let tagger = bangkok_only();
        let input = alloc::vec![
            Token::new("กรุง", 0..12, 0..4, TokenKind::Thai),
            Token::new("100", 12..15, 4..7, TokenKind::Number),
            Token::new("เทพ", 15..24, 7..10, TokenKind::Thai),
        ];
        let tagged = tagger.tag_tokens(input, source);
        assert!(
            !tagged.iter().any(|tok| tok.kind == place()),
            "no token should become Named when non-Thai sits between them"
        );
        assert_eq!(
            tagged.len(),
            3,
            "tokens should not merge across Number boundary"
        );
    }

    #[test]
    fn tag_tokens_multi_prefix_context() {
        let source = "ไปกรุงเทพ";
        let tagger = bangkok_only();
        let input = alloc::vec![
            Token::new("ไป", 0..6, 0..2, TokenKind::Thai),
            Token::new("กรุง", 6..18, 2..6, TokenKind::Thai),
            Token::new("เทพ", 18..27, 6..9, TokenKind::Thai),
        ];
        let tagged = tagger.tag_tokens(input, source);
        assert_eq!(tagged.len(), 2);
        let (head, tail) = (&tagged[0], &tagged[1]);
        assert_eq!(head.kind, TokenKind::Thai);
        assert_eq!(head.text, "ไป");
        assert_eq!(tail.kind, place());
        assert_eq!(tail.text, "กรุงเทพ");
    }

    // ---- tag <-> kind mapping ---------------------------------------------

    #[test]
    fn named_entity_kind_roundtrip() {
        let kinds = [
            NamedEntityKind::Person,
            NamedEntityKind::Place,
            NamedEntityKind::Org,
        ];
        for kind in kinds {
            let parsed = NamedEntityKind::from_tag(kind.as_tag());
            assert_eq!(parsed, Some(kind));
        }
    }
}