Skip to main content

piper_plus_g2p/
encode.rs

1//! Phoneme token-to-ID conversion.
2//!
3//! Converts phoneme token strings to phoneme_id integers using the
4//! phoneme_id_map from config.json.
5
6use crate::error::G2pError;
7use crate::phonemizer::PhonemeIdMap;
8
9use crate::phonemizer::ProsodyFeature;
10use crate::phonemizer::ProsodyInfo;
11use crate::token_map::token_to_pua;
12
13/// Convert a sequence of phoneme token strings to phoneme IDs.
14///
15/// Each token is a single character (regular or PUA). The token is looked up
16/// in the phoneme_id_map to get the corresponding integer ID(s).
17pub fn tokens_to_ids(
18    tokens: &[String],
19    phoneme_id_map: &PhonemeIdMap,
20) -> Result<Vec<i64>, G2pError> {
21    let mut ids = Vec::with_capacity(tokens.len() * 2);
22    for token in tokens {
23        match phoneme_id_map.get(token) {
24            Some(id_list) => ids.extend(id_list.iter().copied()),
25            None => {
26                return Err(G2pError::PhonemeIdNotFound {
27                    phoneme: token.clone(),
28                });
29            }
30        }
31    }
32    Ok(ids)
33}
34
35/// Convert prosody info list to prosody features array (for ONNX input).
36/// Each `ProsodyInfo` becomes `[a1, a2, a3]`. `None` becomes `[0, 0, 0]`.
37pub fn prosody_to_features(prosody: &[Option<ProsodyInfo>]) -> Vec<ProsodyFeature> {
38    prosody
39        .iter()
40        .map(|p| match p {
41            Some(info) => [info.a1, info.a2, info.a3],
42            None => [0, 0, 0],
43        })
44        .collect()
45}
46
/// Policy for symbols that have no entry in the phoneme ID map.
///
/// The type is a plain field-less value: it is `Copy`, comparable and
/// hashable, so it can be stored in sets or used as a map key. The default
/// is [`UnknownTokenMode::Skip`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum UnknownTokenMode {
    /// Raise an error on unknown tokens (strict mode).
    Strict,
    /// Skip unknown tokens with a warning log (default).
    #[default]
    Skip,
}
56
57/// High-level encoder that converts IPA token sequences into
58/// Piper-compatible phoneme ID arrays with BOS/EOS/PAD insertion.
59pub struct PiperEncoder {
60    id_map: PhonemeIdMap,
61    mode: UnknownTokenMode,
62    bos_id: i64,
63    eos_id: i64,
64    pad_id: i64,
65}
66
67impl PiperEncoder {
68    /// Create a new encoder from a phoneme ID map.
69    pub fn new(id_map: PhonemeIdMap, mode: UnknownTokenMode) -> Result<Self, G2pError> {
70        let bos_id = id_map
71            .get("^")
72            .and_then(|ids| ids.first().copied())
73            .ok_or_else(|| G2pError::Phonemize("phoneme_id_map missing '^' (BOS)".into()))?;
74        let eos_id = id_map
75            .get("$")
76            .and_then(|ids| ids.first().copied())
77            .ok_or_else(|| G2pError::Phonemize("phoneme_id_map missing '$' (EOS)".into()))?;
78        let pad_id = id_map
79            .get("_")
80            .and_then(|ids| ids.first().copied())
81            .ok_or_else(|| G2pError::Phonemize("phoneme_id_map missing '_' (PAD)".into()))?;
82        Ok(Self {
83            id_map,
84            mode,
85            bos_id,
86            eos_id,
87            pad_id,
88        })
89    }
90
91    /// Resolve the EOS phoneme ID for the given token.
92    ///
93    /// - `None` → default EOS (`"$"`)
94    /// - `Some(token)` → look up the token in the id_map (direct, then PUA)
95    fn resolve_eos_id(&self, eos_token: Option<&str>) -> Result<i64, G2pError> {
96        match eos_token {
97            None => Ok(self.eos_id),
98            Some(token) => {
99                // Try direct lookup first
100                if let Some(&id) = self.id_map.get(token).and_then(|ids| ids.first()) {
101                    return Ok(id);
102                }
103                // Try PUA mapping
104                if let Some(pua_char) = token_to_pua(token) {
105                    let pua_str = pua_char.to_string();
106                    if let Some(&id) = self.id_map.get(&pua_str).and_then(|ids| ids.first()) {
107                        return Ok(id);
108                    }
109                }
110                Err(G2pError::PhonemeIdNotFound {
111                    phoneme: token.to_string(),
112                })
113            }
114        }
115    }
116
117    /// Encode IPA tokens to phoneme IDs with BOS/EOS/PAD insertion.
118    pub fn encode(&self, tokens: &[String]) -> Result<Vec<i64>, G2pError> {
119        self.encode_with_eos(tokens, None)
120    }
121
122    /// Encode IPA tokens with a custom EOS token.
123    ///
124    /// - `eos_token = None` → default EOS (`"$"`)
125    /// - `eos_token = Some("?")` → look up `"?"` in id_map
126    /// - `eos_token = Some("?!")` → look up PUA-mapped `"?!"` in id_map
127    pub fn encode_with_eos(
128        &self,
129        tokens: &[String],
130        eos_token: Option<&str>,
131    ) -> Result<Vec<i64>, G2pError> {
132        let (ids, _) = self.encode_with_prosody_and_eos(tokens, &[], eos_token)?;
133        Ok(ids)
134    }
135
136    /// Encode IPA tokens with prosody alignment.
137    pub fn encode_with_prosody(
138        &self,
139        tokens: &[String],
140        prosody: &[Option<ProsodyInfo>],
141    ) -> Result<(Vec<i64>, Vec<ProsodyFeature>), G2pError> {
142        self.encode_with_prosody_and_eos(tokens, prosody, None)
143    }
144
145    /// Encode IPA tokens with prosody alignment and a custom EOS token.
146    pub fn encode_with_prosody_and_eos(
147        &self,
148        tokens: &[String],
149        prosody: &[Option<ProsodyInfo>],
150        eos_token: Option<&str>,
151    ) -> Result<(Vec<i64>, Vec<ProsodyFeature>), G2pError> {
152        let resolved_eos = self.resolve_eos_id(eos_token)?;
153        let mut ids = Vec::with_capacity(tokens.len() * 3 + 3);
154        let mut pros = Vec::with_capacity(tokens.len() * 3 + 3);
155
156        // BOS + PAD
157        ids.push(self.bos_id);
158        pros.push([0, 0, 0]);
159        ids.push(self.pad_id);
160        pros.push([0, 0, 0]);
161
162        for (i, token) in tokens.iter().enumerate() {
163            // If the token has a PUA mapping, use the single PUA char;
164            // otherwise iterate the chars of the original token.
165            let mapped: String = match token_to_pua(token) {
166                Some(pua_char) => pua_char.to_string(),
167                None => token.clone(),
168            };
169            for ch in mapped.chars() {
170                let ch_str = ch.to_string();
171                match self.id_map.get(&ch_str) {
172                    Some(id_list) => {
173                        let p = prosody.get(i).and_then(|o| o.as_ref());
174                        let feat = match p {
175                            Some(info) => [info.a1, info.a2, info.a3],
176                            None => [0, 0, 0],
177                        };
178                        for &id in id_list {
179                            ids.push(id);
180                            pros.push(feat);
181                        }
182                    }
183                    None => match self.mode {
184                        UnknownTokenMode::Strict => {
185                            return Err(G2pError::PhonemeIdNotFound { phoneme: ch_str });
186                        }
187                        UnknownTokenMode::Skip => {
188                            tracing::warn!(phoneme = %ch_str, "unknown symbol dropped");
189                        }
190                    },
191                }
192            }
193            ids.push(self.pad_id);
194            pros.push([0, 0, 0]);
195        }
196
197        ids.push(resolved_eos);
198        pros.push([0, 0, 0]);
199        Ok((ids, pros))
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Helper: build a PhonemeIdMap from (key, ids) pairs.
    fn make_map(entries: &[(&str, &[i64])]) -> PhonemeIdMap {
        let mut map = HashMap::new();
        for (key, ids) in entries {
            map.insert(key.to_string(), ids.to_vec());
        }
        map
    }

    /// Helper: turn string literals into the owned token vector the API takes.
    fn toks(symbols: &[&str]) -> Vec<String> {
        symbols.iter().copied().map(String::from).collect()
    }

    /// Helper: encoder over a small fixed map (`^`=1, `_`=0, `$`=2, `?`=99,
    /// `a`=15, `k`=30) in the requested mode.
    fn encoder(mode: UnknownTokenMode) -> PiperEncoder {
        let map = make_map(&[
            ("^", &[1]),
            ("_", &[0]),
            ("$", &[2]),
            ("?", &[99]),
            ("a", &[15]),
            ("k", &[30]),
        ]);
        PiperEncoder::new(map, mode).unwrap()
    }

    // ---- tokens_to_ids ----------------------------------------------------

    #[test]
    fn test_basic_token_to_id() {
        let map = make_map(&[
            ("^", &[1]),
            ("_", &[0]),
            ("$", &[2]),
            ("a", &[15]),
            ("k", &[30]),
        ]);
        let tokens = toks(&["^", "a", "_", "k", "$"]);

        let ids = tokens_to_ids(&tokens, &map).unwrap();
        assert_eq!(ids, vec![1, 15, 0, 30, 2]);
    }

    #[test]
    fn test_pua_character_conversion() {
        // PUA char U+E000 represents "a:" (long vowel)
        let map = make_map(&[("^", &[1]), ("\u{E000}", &[45]), ("$", &[2])]);
        let tokens = toks(&["^", "\u{E000}", "$"]);

        let ids = tokens_to_ids(&tokens, &map).unwrap();
        assert_eq!(ids, vec![1, 45, 2]);
    }

    #[test]
    fn test_unknown_phoneme_error() {
        let map = make_map(&[("a", &[15])]);
        let tokens = toks(&["a", "Z"]);

        let result = tokens_to_ids(&tokens, &map);
        assert!(result.is_err());
        let err = result.unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("Z"),
            "error message should contain the unknown phoneme 'Z', got: {msg}"
        );
    }

    #[test]
    fn test_multi_id_mapping() {
        // Some phoneme_id_map entries map to multiple IDs
        let map = make_map(&[("a", &[10, 11]), ("b", &[20])]);
        let tokens = toks(&["a", "b"]);

        let ids = tokens_to_ids(&tokens, &map).unwrap();
        assert_eq!(ids, vec![10, 11, 20]);
    }

    #[test]
    fn test_empty_tokens() {
        let map = make_map(&[("a", &[1])]);
        let tokens: Vec<String> = vec![];

        let ids = tokens_to_ids(&tokens, &map).unwrap();
        assert!(ids.is_empty());
    }

    // ---- prosody_to_features ----------------------------------------------

    #[test]
    fn test_prosody_conversion() {
        let prosody = vec![
            Some(ProsodyInfo {
                a1: -2,
                a2: 1,
                a3: 5,
            }),
            None,
            Some(ProsodyInfo {
                a1: 0,
                a2: 3,
                a3: 4,
            }),
        ];

        let features = prosody_to_features(&prosody);
        assert_eq!(features.len(), 3);
        assert_eq!(features[0], [-2, 1, 5]);
        assert_eq!(features[1], [0, 0, 0]);
        assert_eq!(features[2], [0, 3, 4]);
    }

    #[test]
    fn test_prosody_conversion_empty() {
        assert!(prosody_to_features(&[]).is_empty());
    }

    // ---- PiperEncoder::new --------------------------------------------------

    #[test]
    fn test_piper_encoder_missing_bos() {
        let map = make_map(&[("_", &[0]), ("$", &[2])]);
        assert!(PiperEncoder::new(map, UnknownTokenMode::Skip).is_err());
    }

    #[test]
    fn test_piper_encoder_missing_eos() {
        let map = make_map(&[("^", &[1]), ("_", &[0])]);
        assert!(PiperEncoder::new(map, UnknownTokenMode::Skip).is_err());
    }

    #[test]
    fn test_piper_encoder_missing_pad() {
        let map = make_map(&[("^", &[1]), ("$", &[2])]);
        assert!(PiperEncoder::new(map, UnknownTokenMode::Skip).is_err());
    }

    #[test]
    fn test_piper_encoder_empty_special_id_list() {
        // An entry that exists but holds no IDs is as unusable as a missing one.
        let map = make_map(&[("^", &[]), ("_", &[0]), ("$", &[2])]);
        assert!(PiperEncoder::new(map, UnknownTokenMode::Skip).is_err());
    }

    // ---- PiperEncoder::encode* ---------------------------------------------

    #[test]
    fn test_piper_encoder_basic() {
        let ids = encoder(UnknownTokenMode::Skip)
            .encode(&toks(&["a", "k"]))
            .unwrap();
        // BOS, PAD, a, PAD, k, PAD, EOS
        assert_eq!(ids, vec![1, 0, 15, 0, 30, 0, 2]);
    }

    #[test]
    fn test_piper_encoder_empty_input() {
        let ids = encoder(UnknownTokenMode::Skip).encode(&[]).unwrap();
        // BOS, PAD, EOS
        assert_eq!(ids, vec![1, 0, 2]);
    }

    #[test]
    fn test_piper_encoder_strict_error() {
        let result = encoder(UnknownTokenMode::Strict).encode(&toks(&["a", "Z"]));
        assert!(result.is_err());
    }

    #[test]
    fn test_piper_encoder_skip_unknown() {
        let ids = encoder(UnknownTokenMode::Skip)
            .encode(&toks(&["a", "Z"]))
            .unwrap();
        // The unknown token contributes no ID, but its trailing PAD is still
        // emitted: BOS, PAD, a, PAD, PAD, EOS.
        assert_eq!(ids, vec![1, 0, 15, 0, 0, 2]);
    }

    #[test]
    fn test_encode_with_default_eos() {
        let encoder = encoder(UnknownTokenMode::Skip);
        let tokens = toks(&["a", "k"]);
        let ids_default = encoder.encode(&tokens).unwrap();
        let ids_none = encoder.encode_with_eos(&tokens, None).unwrap();
        assert_eq!(ids_default, ids_none);
        assert_eq!(*ids_default.last().unwrap(), 2);
    }

    #[test]
    fn test_encode_with_question_eos() {
        let ids = encoder(UnknownTokenMode::Skip)
            .encode_with_eos(&toks(&["a"]), Some("?"))
            .unwrap();
        // BOS, PAD, a, PAD, then the "?" ID (99) instead of the default EOS (2)
        assert_eq!(ids, vec![1, 0, 15, 0, 99]);
    }

    #[test]
    fn test_encode_with_pua_eos() {
        // "?!" maps to PUA U+E016 via token_to_pua
        let pua_char = crate::token_map::token_to_pua("?!").unwrap();
        let pua_str = pua_char.to_string();
        let map = make_map(&[
            ("^", &[1]),
            ("_", &[0]),
            ("$", &[2]),
            (pua_str.as_str(), &[88]),
            ("a", &[15]),
        ]);
        let encoder = PiperEncoder::new(map, UnknownTokenMode::Skip).unwrap();
        let ids = encoder.encode_with_eos(&toks(&["a"]), Some("?!")).unwrap();
        assert_eq!(ids, vec![1, 0, 15, 0, 88]);
    }

    #[test]
    fn test_resolve_eos_invalid() {
        let result =
            encoder(UnknownTokenMode::Skip).encode_with_eos(&toks(&["a"]), Some("NONEXISTENT"));
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(
            msg.contains("NONEXISTENT"),
            "error should mention the unknown token, got: {msg}"
        );
    }

    // ---- prosody alignment --------------------------------------------------

    #[test]
    fn test_encode_with_prosody_and_eos() {
        let prosody = vec![Some(ProsodyInfo {
            a1: -2,
            a2: 1,
            a3: 5,
        })];
        let (ids, pros) = encoder(UnknownTokenMode::Skip)
            .encode_with_prosody_and_eos(&toks(&["a"]), &prosody, Some("?"))
            .unwrap();

        assert_eq!(ids, vec![1, 0, 15, 0, 99]);
        // Only the phoneme itself carries prosody; BOS, PAD and EOS are zero.
        let expected: Vec<ProsodyFeature> = vec![
            [0, 0, 0],
            [0, 0, 0],
            [-2, 1, 5],
            [0, 0, 0],
            [0, 0, 0],
        ];
        assert_eq!(pros, expected);
    }

    #[test]
    fn test_prosody_repeated_for_multi_id_symbol() {
        let map = make_map(&[("^", &[1]), ("_", &[0]), ("$", &[2]), ("a", &[10, 11])]);
        let encoder = PiperEncoder::new(map, UnknownTokenMode::Skip).unwrap();
        let prosody = vec![Some(ProsodyInfo {
            a1: 1,
            a2: 2,
            a3: 3,
        })];
        let (ids, pros) = encoder
            .encode_with_prosody(&toks(&["a"]), &prosody)
            .unwrap();

        assert_eq!(ids, vec![1, 0, 10, 11, 0, 2]);
        // Every ID emitted for the symbol gets the same feature.
        let expected: Vec<ProsodyFeature> = vec![
            [0, 0, 0],
            [0, 0, 0],
            [1, 2, 3],
            [1, 2, 3],
            [0, 0, 0],
            [0, 0, 0],
        ];
        assert_eq!(pros, expected);
    }

    #[test]
    fn test_prosody_shorter_than_tokens() {
        let prosody = vec![Some(ProsodyInfo {
            a1: -2,
            a2: 1,
            a3: 5,
        })];
        let (ids, pros) = encoder(UnknownTokenMode::Skip)
            .encode_with_prosody(&toks(&["a", "k"]), &prosody)
            .unwrap();

        assert_eq!(ids, vec![1, 0, 15, 0, 30, 0, 2]);
        // Tokens without a prosody entry fall back to the zero feature.
        let expected: Vec<ProsodyFeature> = vec![
            [0, 0, 0],
            [0, 0, 0],
            [-2, 1, 5],
            [0, 0, 0],
            [0, 0, 0],
            [0, 0, 0],
            [0, 0, 0],
        ];
        assert_eq!(pros, expected);
    }
}