Skip to main content

kham_core/
ne.rs

1//! Named entity tagging via a gazetteer (word-list approach).
2//!
3//! [`NeTagger`] relabels pre-segmented Thai tokens that appear in the
4//! gazetteer from [`TokenKind::Thai`] to [`TokenKind::Named`]`(kind)`.
5//! The tagger runs as a **post-processing pass** after segmentation — it
6//! does not change the segmentation boundaries, only the token kind.
7//!
8//! **Multi-token matching:** [`NeTagger::tag_tokens`] uses greedy
9//! longest-match over consecutive Thai tokens, so compound names split
10//! by the segmenter (e.g. `กรุง`+`เทพ` → `กรุงเทพ`) are correctly
11//! identified and merged into a single [`TokenKind::Named`] token.
12//!
13//! Three entity categories are supported: [`NamedEntityKind::Person`],
14//! [`NamedEntityKind::Place`], and [`NamedEntityKind::Org`].
15//!
16//! # Data format
17//!
18//! Tab-separated text file, one entry per line:
19//!
20//! ```text
21//! # Thai word<TAB>NE_TAG
22//! กรุงเทพ<TAB>PLACE
23//! ทักษิณ<TAB>PERSON
24//! ปตท<TAB>ORG
25//! ```
26//!
27//! Lines beginning with `#` and blank lines are ignored.
28//! Duplicate keys: last entry wins.
29//!
30//! # Example
31//!
32//! ```rust
33//! use kham_core::ne::NeTagger;
34//! use kham_core::token::NamedEntityKind;
35//!
36//! let tagger = NeTagger::from_tsv("กรุงเทพ\tPLACE\nทักษิณ\tPERSON\n");
37//! assert_eq!(tagger.tag("กรุงเทพ"), Some(NamedEntityKind::Place));
38//! assert_eq!(tagger.tag("xyz"), None);
39//! ```
40
41use alloc::collections::BTreeMap;
42use alloc::string::String;
43use alloc::vec::Vec;
44
45use crate::token::{NamedEntityKind, Token, TokenKind};
46
/// Raw bytes of the `ne_th.bin` file found in the build output directory
/// (`OUT_DIR`), embedded into the binary at compile time.
///
/// The bytes are not used directly: [`NeTagger::builtin`] hands them to
/// `crate::decompress_builtin` and parses the result as a TSV gazetteer.
static BUILTIN_NE: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/ne_th.bin"));
48
49/// Gazetteer-based named entity tagger.
50///
51/// Construct once with [`NeTagger::builtin`] and reuse across calls.
52pub struct NeTagger(BTreeMap<String, NamedEntityKind>);
53
54impl NeTagger {
55    /// Load the built-in NE gazetteer (hand-curated Thai NEs).
56    pub fn builtin() -> Self {
57        Self::from_tsv(&crate::decompress_builtin(BUILTIN_NE))
58    }
59
60    /// Parse a tab-separated NE gazetteer.
61    ///
62    /// Format: `thai_word\tNE_TAG` — one entry per line.
63    /// Lines beginning with `#` and blank lines are skipped.
64    /// Unknown tag strings are skipped silently.
65    /// For duplicate keys, the last entry wins.
66    pub fn from_tsv(data: &str) -> Self {
67        let mut map: BTreeMap<String, NamedEntityKind> = BTreeMap::new();
68        for line in data.lines() {
69            let line = line.trim();
70            if line.is_empty() || line.starts_with('#') {
71                continue;
72            }
73            let mut parts = line.splitn(2, '\t');
74            let word = match parts.next() {
75                Some(w) if !w.is_empty() => String::from(w),
76                _ => continue,
77            };
78            let tag_str = match parts.next() {
79                Some(t) if !t.is_empty() => t.trim(),
80                _ => continue,
81            };
82            if let Some(kind) = NamedEntityKind::from_tag(tag_str) {
83                map.insert(word, kind);
84            }
85        }
86        NeTagger(map)
87    }
88
89    /// Look up the NE category for a pre-segmented word.
90    ///
91    /// Returns `None` if the word is not in the gazetteer.
92    ///
93    /// # Example
94    ///
95    /// ```rust
96    /// use kham_core::ne::NeTagger;
97    /// use kham_core::token::NamedEntityKind;
98    ///
99    /// let tagger = NeTagger::from_tsv("กรุงเทพ\tPLACE\n");
100    /// assert_eq!(tagger.tag("กรุงเทพ"), Some(NamedEntityKind::Place));
101    /// assert_eq!(tagger.tag("xyz"), None);
102    /// ```
103    pub fn tag(&self, word: &str) -> Option<NamedEntityKind> {
104        self.0.get(word).copied()
105    }
106
107    /// Relabel sequences of consecutive [`TokenKind::Thai`] tokens to
108    /// [`TokenKind::Named`]`(kind)` using **greedy longest-match**.
109    ///
110    /// For each Thai token, the longest consecutive span (up to 5 tokens)
111    /// whose concatenated text hits the gazetteer is chosen. This handles
112    /// compound names that the segmenter splits across multiple tokens —
113    /// for example `กรุง`+`เทพ` → `กรุงเทพ` (PLACE).
114    ///
115    /// Merged tokens borrow their `text` as a zero-copy slice of `source`.
116    /// Non-Thai tokens always pass through unchanged.
117    ///
118    /// `source` must be the normalised string from which `tokens` were
119    /// produced (i.e. the same string passed to `Tokenizer::segment`).
120    ///
121    /// # Example
122    ///
123    /// ```rust
124    /// use kham_core::ne::NeTagger;
125    /// use kham_core::token::{Token, TokenKind, NamedEntityKind};
126    ///
127    /// let source = "กรุงเทพ";
128    /// let tagger = NeTagger::from_tsv("กรุงเทพ\tPLACE\n");
129    /// // Simulate segmenter splitting กรุงเทพ into กรุง + เทพ
130    /// // Each Thai char is 3 bytes: กรุง = 12 bytes, เทพ = 9 bytes
131    /// let tokens = vec![
132    ///     Token::new("กรุง", 0..12,  0..4, TokenKind::Thai),
133    ///     Token::new("เทพ",  12..21, 4..7, TokenKind::Thai),
134    /// ];
135    /// let tagged = tagger.tag_tokens(tokens, source);
136    /// assert_eq!(tagged.len(), 1);
137    /// assert_eq!(tagged[0].text, "กรุงเทพ");
138    /// assert_eq!(tagged[0].kind, TokenKind::Named(NamedEntityKind::Place));
139    /// ```
140    pub fn tag_tokens<'a>(&self, tokens: Vec<Token<'a>>, source: &'a str) -> Vec<Token<'a>> {
141        // Maximum number of consecutive Thai tokens to try merging.
142        const MAX_SPAN: usize = 5;
143
144        let mut out: Vec<Token<'a>> = Vec::with_capacity(tokens.len());
145        let mut i = 0;
146
147        while i < tokens.len() {
148            if tokens[i].kind != TokenKind::Thai {
149                out.push(tokens[i].clone());
150                i += 1;
151                continue;
152            }
153
154            // Find the end of the consecutive Thai run starting at i.
155            let run_end = tokens[i..]
156                .iter()
157                .position(|t| t.kind != TokenKind::Thai)
158                .map_or(tokens.len(), |pos| i + pos);
159            let max_end = run_end.min(i + MAX_SPAN);
160
161            // Greedy longest-match: try longest span first, shrink until hit.
162            let mut matched = false;
163            for end in (i + 1..=max_end).rev() {
164                let span_start = tokens[i].span.start;
165                let span_end = tokens[end - 1].span.end;
166                let candidate = &source[span_start..span_end];
167                if let Some(ne_kind) = self.tag(candidate) {
168                    let char_start = tokens[i].char_span.start;
169                    let char_end = tokens[end - 1].char_span.end;
170                    out.push(Token::new(
171                        candidate,
172                        span_start..span_end,
173                        char_start..char_end,
174                        TokenKind::Named(ne_kind),
175                    ));
176                    i = end;
177                    matched = true;
178                    break;
179                }
180            }
181
182            if !matched {
183                out.push(tokens[i].clone());
184                i += 1;
185            }
186        }
187
188        out
189    }
190
191    /// Number of entries in the gazetteer.
192    #[inline]
193    pub fn len(&self) -> usize {
194        self.0.len()
195    }
196
197    /// Return `true` if the gazetteer has no entries.
198    #[inline]
199    pub fn is_empty(&self) -> bool {
200        self.0.is_empty()
201    }
202}
203
204// ---------------------------------------------------------------------------
205// Tests
206// ---------------------------------------------------------------------------
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Source text shared by the กรุงเทพ fixtures below.
    const BANGKOK: &str = "กรุงเทพ";

    /// Gazetteer that knows only กรุงเทพ, tagged as a place.
    fn bangkok_tagger() -> NeTagger {
        NeTagger::from_tsv("กรุงเทพ\tPLACE\n")
    }

    /// กรุงเทพ as the segmenter might split it: กรุง + เทพ.
    /// Each Thai char is 3 bytes: กรุง=12 bytes (4 chars), เทพ=9 bytes (3 chars).
    fn bangkok_split() -> Vec<Token<'static>> {
        alloc::vec![
            Token::new("กรุง", 0..12, 0..4, TokenKind::Thai),
            Token::new("เทพ", 12..21, 4..7, TokenKind::Thai),
        ]
    }

    // ── gazetteer loading and single-word lookup ──────────────────────────────

    #[test]
    fn builtin_gazetteer_non_empty() {
        assert!(NeTagger::builtin().len() > 50);
    }

    #[test]
    fn place_lookup() {
        let builtin = NeTagger::builtin();
        assert_eq!(builtin.tag("กรุงเทพ"), Some(NamedEntityKind::Place));
        assert_eq!(builtin.tag("ไทย"), Some(NamedEntityKind::Place));
        assert_eq!(builtin.tag("ญี่ปุ่น"), Some(NamedEntityKind::Place));
    }

    #[test]
    fn org_lookup() {
        let builtin = NeTagger::builtin();
        assert_eq!(builtin.tag("ปตท"), Some(NamedEntityKind::Org));
        assert_eq!(
            builtin.tag("ธนาคารแห่งประเทศไทย"),
            Some(NamedEntityKind::Org)
        );
    }

    #[test]
    fn person_lookup() {
        let builtin = NeTagger::builtin();
        assert_eq!(builtin.tag("ทักษิณ"), Some(NamedEntityKind::Person));
    }

    #[test]
    fn oov_returns_none() {
        let builtin = NeTagger::builtin();
        assert_eq!(builtin.tag("กิน"), None);
        assert_eq!(builtin.tag(""), None);
    }

    #[test]
    fn from_tsv_last_duplicate_wins() {
        let tagger = NeTagger::from_tsv("กรุงเทพ\tPLACE\nกรุงเทพ\tORG\n");
        assert_eq!(tagger.tag("กรุงเทพ"), Some(NamedEntityKind::Org));
    }

    #[test]
    fn from_tsv_unknown_tag_skipped() {
        let tagger = NeTagger::from_tsv("กรุงเทพ\tCITY\n");
        assert_eq!(tagger.tag("กรุงเทพ"), None);
    }

    #[test]
    fn from_tsv_empty() {
        let tagger = NeTagger::from_tsv("");
        assert!(tagger.is_empty());
    }

    // ── single-token relabelling ──────────────────────────────────────────────

    #[test]
    fn tag_tokens_relabels_thai() {
        let whole = Token::new("กรุงเทพ", 0..21, 0..7, TokenKind::Thai);
        let tagged = bangkok_tagger().tag_tokens(alloc::vec![whole], BANGKOK);
        assert_eq!(tagged[0].kind, TokenKind::Named(NamedEntityKind::Place));
    }

    #[test]
    fn tag_tokens_passes_through_non_thai() {
        // The word is in the gazetteer, but Latin tokens are never relabeled.
        let tagger = NeTagger::from_tsv("hello\tPERSON\n");
        let latin = Token::new("hello", 0..5, 0..5, TokenKind::Latin);
        let tagged = tagger.tag_tokens(alloc::vec![latin], "hello");
        assert_eq!(tagged[0].kind, TokenKind::Latin);
    }

    #[test]
    fn tag_tokens_oov_unchanged() {
        let unknown = Token::new("กิน", 0..9, 0..3, TokenKind::Thai);
        let tagged = bangkok_tagger().tag_tokens(alloc::vec![unknown], "กิน");
        assert_eq!(tagged[0].kind, TokenKind::Thai);
    }

    // ── multi-token NE tests ──────────────────────────────────────────────────

    #[test]
    fn tag_tokens_multi_merges_two_tokens() {
        let tagged = bangkok_tagger().tag_tokens(bangkok_split(), BANGKOK);
        assert_eq!(tagged.len(), 1, "two tokens should merge into one");

        let merged = &tagged[0];
        assert_eq!(merged.text, "กรุงเทพ");
        assert_eq!(merged.kind, TokenKind::Named(NamedEntityKind::Place));
        assert_eq!(merged.span, 0..21);
        assert_eq!(merged.char_span, 0..7);
    }

    #[test]
    fn tag_tokens_multi_greedy_prefers_longer() {
        // Both "กรุงเทพ" (2-token) and "กรุง" (1-token) are in the gazetteer;
        // the longest match (กรุงเทพ) must win.
        let tagger = NeTagger::from_tsv("กรุง\tPLACE\nกรุงเทพ\tPLACE\n");
        let tagged = tagger.tag_tokens(bangkok_split(), BANGKOK);
        assert_eq!(tagged.len(), 1, "longer match should be preferred");
        assert_eq!(tagged[0].text, "กรุงเทพ");
    }

    #[test]
    fn tag_tokens_multi_does_not_cross_non_thai() {
        // "กรุง100เทพ": a Number token sits between the Thai tokens, so they
        // must NOT merge. กรุง=12 bytes, 100=3 bytes (ASCII), เทพ=9 bytes.
        let input = alloc::vec![
            Token::new("กรุง", 0..12, 0..4, TokenKind::Thai),
            Token::new("100", 12..15, 4..7, TokenKind::Number),
            Token::new("เทพ", 15..24, 7..10, TokenKind::Thai),
        ];
        let tagged = bangkok_tagger().tag_tokens(input, "กรุง100เทพ");

        let place = TokenKind::Named(NamedEntityKind::Place);
        assert!(
            !tagged.iter().any(|tok| tok.kind == place),
            "no token should become Named when non-Thai sits between them"
        );
        assert_eq!(
            tagged.len(),
            3,
            "tokens should not merge across Number boundary"
        );
    }

    #[test]
    fn tag_tokens_multi_prefix_context() {
        // "ไปกรุงเทพ" → [ไป, กรุง, เทพ]; only กรุงเทพ is in the gazetteer.
        // ไป=6 bytes (2 chars), กรุง=12 bytes (4 chars), เทพ=9 bytes (3 chars)
        let input = alloc::vec![
            Token::new("ไป", 0..6, 0..2, TokenKind::Thai),
            Token::new("กรุง", 6..18, 2..6, TokenKind::Thai),
            Token::new("เทพ", 18..27, 6..9, TokenKind::Thai),
        ];
        let tagged = bangkok_tagger().tag_tokens(input, "ไปกรุงเทพ");
        assert_eq!(tagged.len(), 2);

        let (prefix, entity) = (&tagged[0], &tagged[1]);
        assert_eq!(prefix.kind, TokenKind::Thai);
        assert_eq!(prefix.text, "ไป");
        assert_eq!(entity.kind, TokenKind::Named(NamedEntityKind::Place));
        assert_eq!(entity.text, "กรุงเทพ");
    }

    #[test]
    fn named_entity_kind_roundtrip() {
        let every_kind = [
            NamedEntityKind::Person,
            NamedEntityKind::Place,
            NamedEntityKind::Org,
        ];
        for original in every_kind {
            let reparsed = NamedEntityKind::from_tag(original.as_tag());
            assert_eq!(reparsed, Some(original));
        }
    }
}