Skip to main content

rust_canto/
lib.rs

1mod trie;
2mod token;
3mod yale;
4use yale::jyutping_to_yale;
5use yale::jyutping_to_yale_vec;
6
7use trie::Trie;
8use token::Token;
9use once_cell::sync::Lazy;
10use wasm_minimal_protocol::*;
11
12const CHAR_DATA: &str = include_str!("../data/chars.tsv");
13const WORD_DATA: &str = include_str!("../data/words.tsv");
14const FREQ_DATA: &str = include_str!("../data/freq.txt");
15const LETTERED_DATA: &str = include_str!("../data/lettered.tsv");
16
17initiate_protocol!();
18
19static TRIE: Lazy<Trie> = Lazy::new(|| build_trie());
20
21fn build_trie() -> Trie {
22    let mut trie = Trie::new();
23
24    for line in CHAR_DATA.lines() {
25        let parts: Vec<&str> = line.split('\t').collect();
26        if parts.len() >= 2 {
27            if let Some(ch) = parts[0].chars().next() {
28                // parse "5%" → 5, missing → 100 (highest priority)
29                let weight = parts.get(2)
30                    .map(|s| s.replace('%', "").trim().parse::<u32>().unwrap_or(0))
31                    .unwrap_or(100);
32                trie.insert_char(ch, parts[1], weight);
33            }
34        }
35    }
36
37    for line in WORD_DATA.lines() {
38        let Some((left, right)) = line.split_once('\t') else {
39            continue;
40        };
41        trie.insert_word(left, right);
42    }
43
44    for line in FREQ_DATA.lines() {
45        let parts: Vec<&str> = line.split('\t').collect();
46        if parts.len() >= 2 {
47            if let Ok(freq) = parts[1].parse::<i64>() {
48                trie.insert_freq(parts[0], freq);
49            }
50        }
51    }
52
53    for line in LETTERED_DATA.lines() {
54        let Some((left, right)) = line.split_once('\t') else {
55            continue;
56        };
57        trie.insert_lettered(left, right);
58    }
59
60    trie
61}
62
63#[wasm_func]
64pub fn annotate(input: &[u8]) -> Vec<u8> {
65    let text = std::str::from_utf8(input).unwrap_or("");
66    let tokens = TRIE.segment(text);
67
68    let output: Vec<Token> = tokens
69        .into_iter()
70        .map(|t| Token {
71            word: t.word,
72            yale: t.reading.as_deref().and_then(jyutping_to_yale_vec),
73            reading: t.reading,
74        })
75        .collect();
76
77    serde_json::to_string(&output)
78        .unwrap_or_else(|_| "[]".to_string())
79        .into_bytes()
80}
81
82/// Input: jyutping bytes, e.g. b"gwong2 dung1 waa2"
83/// Output: Yale with tone numbers, e.g. b"gwong2 dung1 waa2"
84#[wasm_func]
85pub fn to_yale_numeric(input: &[u8]) -> Vec<u8> {
86    let jp = std::str::from_utf8(input).unwrap_or("");
87    jyutping_to_yale(jp, false)
88        .unwrap_or_default()
89        .into_bytes()
90}
91
92/// Input: jyutping bytes
93/// Output: Yale with diacritics, e.g. b"gwóngdūngwá"
94#[wasm_func]
95pub fn to_yale_diacritics(input: &[u8]) -> Vec<u8> {
96    let jp = std::str::from_utf8(input).unwrap_or("");
97    jyutping_to_yale(jp, true)
98        .unwrap_or_default()
99        .into_bytes()
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// One expected token: its surface form and its reading, if any.
    type Expected = (&'static str, Option<&'static str>);

    /// Segments `input` and checks the token count, then each token's word
    /// and reading, against `expected`.
    fn assert_segments(trie: &Trie, input: &str, expected: &[Expected]) {
        let result = trie.segment(input);
        assert_eq!(
            result.len(),
            expected.len(),
            "token count mismatch for '{}': got {:?}",
            input,
            result.iter().map(|t| &t.word).collect::<Vec<_>>()
        );

        // Lengths are equal at this point, so zipping visits every pair.
        for (i, (token, (word, reading))) in result.iter().zip(expected.iter()).enumerate() {
            assert_eq!(
                token.word, *word,
                "word mismatch at index {} for '{}'",
                i, input
            );
            assert_eq!(
                token.reading.as_deref(),
                *reading,
                "reading mismatch at index {} for '{}'",
                i,
                input
            );
        }
    }

    /// Runs the segmenter, built from the embedded data, over three sentences
    /// and checks the exact tokens and readings it returns.
    #[test]
    fn test_segmentation() {
        let trie = build_trie();

        let cases: Vec<(&str, Vec<Expected>)> = vec![
            (
                "都會大學入面3%人識用AB膠",
                vec![
                    ("都會大學", Some("dou1 wui6 daai6 hok6")),
                    ("入面", Some("jap6 min6")),
                    ("3", None),
                    ("%", Some("pat6 sen1")),
                    ("人", Some("jan4")),
                    ("識", Some("sik1")),
                    ("用", Some("jung6")),
                    ("AB膠", Some("ei1 bi1 gaau1")),
                ],
            ),
            (
                "我會番教會",
                vec![
                    ("我", Some("ngo5")),
                    ("會", Some("wui5")),
                    ("番", Some("faan1")),
                    ("教會", Some("gaau3 wui2")),
                ],
            ),
            (
                "佢係好學生",
                vec![
                    ("佢", Some("keoi5")),
                    ("係", Some("hai6")),
                    ("好", Some("hou2")),
                    ("學生", Some("hok6 saang1")),
                ],
            ),
        ];

        for (input, expected) in cases {
            println!("Testing: {}", input);
            assert_segments(&trie, input, &expected);
        }
    }
}