1mod trie;
2mod token;
3
4use trie::Trie;
5use once_cell::sync::Lazy;
6use wasm_minimal_protocol::*;
7
8const CHAR_DATA: &str = include_str!("../data/chars.tsv");
9const WORD_DATA: &str = include_str!("../data/words.tsv");
10const FREQ_DATA: &str = include_str!("../data/freq.txt");
11
12initiate_protocol!();
13
14static TRIE: Lazy<Trie> = Lazy::new(|| build_trie());
15
16fn build_trie() -> Trie {
17 let mut trie = Trie::new();
18
19 for line in CHAR_DATA.lines() {
20 let parts: Vec<&str> = line.split('\t').collect();
21 if parts.len() >= 2 {
22 if let Some(ch) = parts[0].chars().next() {
23 trie.insert_char(ch, parts[1]);
24 }
25 }
26 }
27
28 for line in WORD_DATA.lines() {
29 let parts: Vec<&str> = line.split('\t').collect();
30 if parts.len() >= 2 {
31 trie.insert_word(parts[0], parts[1]);
32 }
33 }
34
35 for line in FREQ_DATA.lines() {
36 let parts: Vec<&str> = line.split('\t').collect();
37 if parts.len() >= 2 {
38 if let Ok(freq) = parts[1].parse::<i64>() {
39 trie.insert_freq(parts[0], freq);
40 }
41 }
42 }
43
44 trie
45}
46
47#[wasm_func]
48pub fn annotate(input: &[u8]) -> Vec<u8> {
49 let text = std::str::from_utf8(input).unwrap_or("");
50 let tokens = TRIE.segment(text);
51 serde_json::to_string(&tokens)
52 .unwrap_or_else(|_| "[]".to_string())
53 .into_bytes()
54}
55
#[cfg(test)]
mod tests {
    use super::*;

    /// Segments a few sample phrases with a freshly built trie and checks
    /// each token's `word` and `reading` against the expected values.
    #[test]
    fn test_segmentation() {
        let dict = build_trie();

        // Each case pairs an input phrase with its expected (word, reading) tokens.
        let cases: [(&str, &[(&str, Option<&str>)]); 3] = [
            (
                "都會大學",
                &[("都會大學", Some("dou1 wui6 daai6 hok6"))],
            ),
            (
                "好學生",
                &[("好", Some("hou2")), ("學生", Some("hok6 saang1"))],
            ),
            (
                "我係好學生",
                &[
                    ("我", Some("ngo5")),
                    ("係", Some("hai6")),
                    ("好", Some("hou2")),
                    ("學生", Some("hok6 saang1")),
                ],
            ),
        ];

        for (input, expected) in cases {
            println!("Testing: {}", input);
            let result = dict.segment(input);

            // Check the count first so the pairwise comparison covers every token.
            assert_eq!(
                result.len(),
                expected.len(),
                "token count mismatch for '{}': got {:?}",
                input,
                result.iter().map(|t| &t.word).collect::<Vec<_>>()
            );

            for (i, (token, &(word, reading))) in result.iter().zip(expected).enumerate() {
                assert_eq!(
                    token.word, word,
                    "word mismatch at index {} for '{}'",
                    i, input
                );
                assert_eq!(
                    token.reading.as_deref(),
                    reading,
                    "reading mismatch at index {} for '{}'",
                    i,
                    input
                );
            }
        }
    }
}