Skip to main content

golia_pinyin/
segmenter.rs

//! Pinyin syllable segmentation via dynamic programming.
//!
//! Splits a continuous pinyin string like `zhonghuarenmin` into
//! `[zhong, hua, ren, min]`. Ambiguous inputs (e.g., `xian` could be
//! `[xi, an]` or `[xian]`) get all valid segmentations enumerated; the
//! caller picks one based on dict hits.
//!
//! Algorithm: standard DP. `dp[i]` = list of `(prev_index, syllable_str)`
//! edges that end at position `i`, i.e. a valid syllable spans
//! `prev_index..i`. The final segmentations are reconstructed by
//! backtracking from `dp[len]`.
//!
//! Performance: the input length is bounded by user input (pinyin
//! sub-string typed before commit, typically < 20 chars). Building the DP
//! table is O(n × max_syl) validity checks, where max_syl is the longest
//! valid syllable byte-length (6: `zhuang`). Enumerating the results costs
//! time proportional to the number of segmentations, which grows with
//! ambiguity (each extra `xian` in the input doubles it).
15
16use crate::syllable;
17
/// A single way of splitting an input string into pinyin syllables.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Segmentation {
    /// The syllables, ordered left to right as they occur in the input.
    pub syllables: Vec<String>,
}

impl Segmentation {
    /// Returns how many syllables this segmentation contains.
    pub fn len(&self) -> usize {
        let Self { syllables } = self;
        syllables.len()
    }

    /// Returns `true` when the segmentation holds no syllables at all.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
36
/// Upper bound, in bytes, on the length of a candidate syllable; it caps the
/// DP inner loop. Spelled as the length of `zhuang` (`shuang` and `chuang`
/// are equally long), which evaluates to 6 at compile time.
///
/// NOTE(review): this is a fixed value, not derived from the syllable table;
/// raise it if the table ever gains a longer entry.
const MAX_SYL_LEN: usize = "zhuang".len();
40
41/// Returns all valid segmentations of `input`. Empty input yields a single
42/// empty segmentation. Inputs that don't fully split return an empty `Vec`.
43///
44/// Results are ordered by syllable count ascending (fewer syllables first =
45/// longer-match preference, the default IME convention).
46pub fn segment(input: &str) -> Vec<Segmentation> {
47    let s = input.to_ascii_lowercase();
48    if s.is_empty() {
49        return vec![Segmentation {
50            syllables: Vec::new(),
51        }];
52    }
53
54    let bytes = s.as_bytes();
55    let n = bytes.len();
56
57    // dp[i] = list of (prev_pos, syllable_starting_at_prev) that reach
58    // position i. dp[0] is implicitly the start.
59    let mut dp: Vec<Vec<(usize, &str)>> = vec![Vec::new(); n + 1];
60    let mut reachable = vec![false; n + 1];
61    reachable[0] = true;
62
63    for i in 0..n {
64        if !reachable[i] {
65            continue;
66        }
67        let max_end = (i + MAX_SYL_LEN).min(n);
68        for j in (i + 1)..=max_end {
69            // SAFETY: we built `s` from a valid str; ASCII slicing is safe.
70            let candidate = &s[i..j];
71            if syllable::is_valid(candidate) {
72                dp[j].push((i, candidate));
73                reachable[j] = true;
74            }
75        }
76    }
77
78    if !reachable[n] {
79        return Vec::new();
80    }
81
82    // Backtrack from n. Collect all paths.
83    let mut results: Vec<Segmentation> = Vec::new();
84    let mut path: Vec<&str> = Vec::new();
85    backtrack(&dp, n, &mut path, &mut results);
86
87    results.sort_by_key(|seg| seg.syllables.len());
88    results
89}
90
91fn backtrack<'a>(
92    dp: &[Vec<(usize, &'a str)>],
93    pos: usize,
94    path: &mut Vec<&'a str>,
95    out: &mut Vec<Segmentation>,
96) {
97    if pos == 0 {
98        // path is built end→start as we recurse from pos=n down to pos=0;
99        // reverse to get start→end order.
100        let mut syllables: Vec<String> = path.iter().map(|s| s.to_string()).collect();
101        syllables.reverse();
102        out.push(Segmentation { syllables });
103        return;
104    }
105    for (prev, syl) in &dp[pos] {
106        path.push(*syl);
107        backtrack(dp, *prev, path, out);
108        path.pop();
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// `true` when `segs` holds a segmentation whose syllables are exactly
    /// `want`, in order.
    fn contains(segs: &[Segmentation], want: &[&str]) -> bool {
        segs.iter().any(|seg| {
            seg.syllables
                .iter()
                .map(String::as_str)
                .eq(want.iter().copied())
        })
    }

    #[test]
    fn empty_input_yields_empty_segmentation() {
        // Exactly one result, and that result has no syllables.
        let expected = vec![Segmentation {
            syllables: Vec::new(),
        }];
        assert_eq!(segment(""), expected);
    }

    #[test]
    fn single_syllable() {
        let out = segment("zhong");
        assert!(!out.is_empty());
        assert_eq!(out[0].syllables, ["zhong"]);
    }

    #[test]
    fn two_syllable_unambiguous() {
        let out = segment("zhongguo");
        // Other splits may exist; [zhong, guo] must be among them.
        assert!(
            contains(&out, &["zhong", "guo"]),
            "expected [zhong, guo] in {out:?}"
        );
    }

    #[test]
    fn long_unambiguous() {
        let out = segment("zhonghuarenmin");
        assert!(
            contains(&out, &["zhong", "hua", "ren", "min"]),
            "expected [zhong, hua, ren, min] in {out:?}"
        );
    }

    #[test]
    fn ambiguous_xian_returns_multiple() {
        // Both the one-syllable reading and the `xi` + `an` split must appear.
        let out = segment("xian");
        assert!(contains(&out, &["xian"]), "missing [xian] in {out:?}");
        assert!(contains(&out, &["xi", "an"]), "missing [xi, an] in {out:?}");
    }

    #[test]
    fn fewer_syllables_first() {
        // [xian] (1 syllable) must sort ahead of [xi, an] (2 syllables).
        let out = segment("xian");
        assert_eq!(out[0].len(), 1);
    }

    #[test]
    fn invalid_input_returns_empty() {
        // "zhongq" ends in a trailing partial syllable.
        for bad in ["xxqz", "zhongq"] {
            assert!(segment(bad).is_empty());
        }
    }

    #[test]
    fn case_insensitive() {
        let out = segment("ZhongGuo");
        assert!(contains(&out, &["zhong", "guo"]));
    }
}