1use crate::error::G2pError;
7use crate::phonemizer::PhonemeIdMap;
8
9use crate::phonemizer::ProsodyFeature;
10use crate::phonemizer::ProsodyInfo;
11use crate::token_map::token_to_pua;
12
13pub fn tokens_to_ids(
18 tokens: &[String],
19 phoneme_id_map: &PhonemeIdMap,
20) -> Result<Vec<i64>, G2pError> {
21 let mut ids = Vec::with_capacity(tokens.len() * 2);
22 for token in tokens {
23 match phoneme_id_map.get(token) {
24 Some(id_list) => ids.extend(id_list.iter().copied()),
25 None => {
26 return Err(G2pError::PhonemeIdNotFound {
27 phoneme: token.clone(),
28 });
29 }
30 }
31 }
32 Ok(ids)
33}
34
35pub fn prosody_to_features(prosody: &[Option<ProsodyInfo>]) -> Vec<ProsodyFeature> {
38 prosody
39 .iter()
40 .map(|p| match p {
41 Some(info) => [info.a1, info.a2, info.a3],
42 None => [0, 0, 0],
43 })
44 .collect()
45}
46
/// Policy applied by [`PiperEncoder`] when a symbol has no entry in the
/// phoneme id map.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum UnknownTokenMode {
    /// Abort encoding with `G2pError::PhonemeIdNotFound` on the first
    /// unknown symbol.
    Strict,
    /// Log a warning and drop the unknown symbol. This is the default mode.
    #[default]
    Skip,
}
56
57pub struct PiperEncoder {
60 id_map: PhonemeIdMap,
61 mode: UnknownTokenMode,
62 bos_id: i64,
63 eos_id: i64,
64 pad_id: i64,
65}
66
67impl PiperEncoder {
68 pub fn new(id_map: PhonemeIdMap, mode: UnknownTokenMode) -> Result<Self, G2pError> {
70 let bos_id = id_map
71 .get("^")
72 .and_then(|ids| ids.first().copied())
73 .ok_or_else(|| G2pError::Phonemize("phoneme_id_map missing '^' (BOS)".into()))?;
74 let eos_id = id_map
75 .get("$")
76 .and_then(|ids| ids.first().copied())
77 .ok_or_else(|| G2pError::Phonemize("phoneme_id_map missing '$' (EOS)".into()))?;
78 let pad_id = id_map
79 .get("_")
80 .and_then(|ids| ids.first().copied())
81 .ok_or_else(|| G2pError::Phonemize("phoneme_id_map missing '_' (PAD)".into()))?;
82 Ok(Self {
83 id_map,
84 mode,
85 bos_id,
86 eos_id,
87 pad_id,
88 })
89 }
90
91 fn resolve_eos_id(&self, eos_token: Option<&str>) -> Result<i64, G2pError> {
96 match eos_token {
97 None => Ok(self.eos_id),
98 Some(token) => {
99 if let Some(&id) = self.id_map.get(token).and_then(|ids| ids.first()) {
101 return Ok(id);
102 }
103 if let Some(pua_char) = token_to_pua(token) {
105 let pua_str = pua_char.to_string();
106 if let Some(&id) = self.id_map.get(&pua_str).and_then(|ids| ids.first()) {
107 return Ok(id);
108 }
109 }
110 Err(G2pError::PhonemeIdNotFound {
111 phoneme: token.to_string(),
112 })
113 }
114 }
115 }
116
117 pub fn encode(&self, tokens: &[String]) -> Result<Vec<i64>, G2pError> {
119 self.encode_with_eos(tokens, None)
120 }
121
122 pub fn encode_with_eos(
128 &self,
129 tokens: &[String],
130 eos_token: Option<&str>,
131 ) -> Result<Vec<i64>, G2pError> {
132 let (ids, _) = self.encode_with_prosody_and_eos(tokens, &[], eos_token)?;
133 Ok(ids)
134 }
135
136 pub fn encode_with_prosody(
138 &self,
139 tokens: &[String],
140 prosody: &[Option<ProsodyInfo>],
141 ) -> Result<(Vec<i64>, Vec<ProsodyFeature>), G2pError> {
142 self.encode_with_prosody_and_eos(tokens, prosody, None)
143 }
144
145 pub fn encode_with_prosody_and_eos(
147 &self,
148 tokens: &[String],
149 prosody: &[Option<ProsodyInfo>],
150 eos_token: Option<&str>,
151 ) -> Result<(Vec<i64>, Vec<ProsodyFeature>), G2pError> {
152 let resolved_eos = self.resolve_eos_id(eos_token)?;
153 let mut ids = Vec::with_capacity(tokens.len() * 3 + 3);
154 let mut pros = Vec::with_capacity(tokens.len() * 3 + 3);
155
156 ids.push(self.bos_id);
158 pros.push([0, 0, 0]);
159 ids.push(self.pad_id);
160 pros.push([0, 0, 0]);
161
162 for (i, token) in tokens.iter().enumerate() {
163 let mapped: String = match token_to_pua(token) {
166 Some(pua_char) => pua_char.to_string(),
167 None => token.clone(),
168 };
169 for ch in mapped.chars() {
170 let ch_str = ch.to_string();
171 match self.id_map.get(&ch_str) {
172 Some(id_list) => {
173 let p = prosody.get(i).and_then(|o| o.as_ref());
174 let feat = match p {
175 Some(info) => [info.a1, info.a2, info.a3],
176 None => [0, 0, 0],
177 };
178 for &id in id_list {
179 ids.push(id);
180 pros.push(feat);
181 }
182 }
183 None => match self.mode {
184 UnknownTokenMode::Strict => {
185 return Err(G2pError::PhonemeIdNotFound { phoneme: ch_str });
186 }
187 UnknownTokenMode::Skip => {
188 tracing::warn!(phoneme = %ch_str, "unknown symbol dropped");
189 }
190 },
191 }
192 }
193 ids.push(self.pad_id);
194 pros.push([0, 0, 0]);
195 }
196
197 ids.push(resolved_eos);
198 pros.push([0, 0, 0]);
199 Ok((ids, pros))
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a `PhonemeIdMap` from `(symbol, ids)` pairs.
    fn make_map(entries: &[(&str, &[i64])]) -> PhonemeIdMap {
        entries
            .iter()
            .map(|(symbol, ids)| (symbol.to_string(), ids.to_vec()))
            .collect::<HashMap<_, _>>()
    }

    /// Builds a map holding the BOS/PAD/EOS symbols (`^`=1, `_`=0, `$`=2)
    /// plus the given extra entries.
    fn encoder_map(extra: &[(&str, &[i64])]) -> PhonemeIdMap {
        let mut map = make_map(&[("^", &[1]), ("_", &[0]), ("$", &[2])]);
        map.extend(make_map(extra));
        map
    }

    /// Converts string literals into the owned token list the API expects.
    fn toks(symbols: &[&str]) -> Vec<String> {
        symbols.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_basic_token_to_id() {
        let map = encoder_map(&[("a", &[15]), ("k", &[30])]);
        let tokens = toks(&["^", "a", "_", "k", "$"]);

        let ids = tokens_to_ids(&tokens, &map).unwrap();
        assert_eq!(ids, vec![1, 15, 0, 30, 2]);
    }

    #[test]
    fn test_pua_character_conversion() {
        let map = make_map(&[("^", &[1]), ("\u{E000}", &[45]), ("$", &[2])]);
        let tokens = toks(&["^", "\u{E000}", "$"]);

        let ids = tokens_to_ids(&tokens, &map).unwrap();
        assert_eq!(ids, vec![1, 45, 2]);
    }

    #[test]
    fn test_unknown_phoneme_error() {
        let map = make_map(&[("a", &[15])]);
        let tokens = toks(&["a", "Z"]);

        let err = tokens_to_ids(&tokens, &map).expect_err("unknown phoneme must be rejected");
        let msg = err.to_string();
        assert!(
            msg.contains("Z"),
            "error message should contain the unknown phoneme 'Z', got: {msg}"
        );
    }

    #[test]
    fn test_prosody_conversion() {
        let prosody = vec![
            Some(ProsodyInfo {
                a1: -2,
                a2: 1,
                a3: 5,
            }),
            None,
            Some(ProsodyInfo {
                a1: 0,
                a2: 3,
                a3: 4,
            }),
        ];

        let features = prosody_to_features(&prosody);
        assert_eq!(features.len(), 3);
        assert_eq!(features[0], [-2, 1, 5]);
        assert_eq!(features[1], [0, 0, 0]);
        assert_eq!(features[2], [0, 3, 4]);
    }

    #[test]
    fn test_multi_id_mapping() {
        let map = make_map(&[("a", &[10, 11]), ("b", &[20])]);
        let tokens = toks(&["a", "b"]);

        let ids = tokens_to_ids(&tokens, &map).unwrap();
        assert_eq!(ids, vec![10, 11, 20]);
    }

    #[test]
    fn test_empty_tokens() {
        let map = make_map(&[("a", &[1])]);
        let tokens: Vec<String> = Vec::new();

        let ids = tokens_to_ids(&tokens, &map).unwrap();
        assert!(ids.is_empty());
    }

    #[test]
    fn test_piper_encoder_basic() {
        let map = encoder_map(&[("a", &[15]), ("k", &[30])]);
        let encoder = PiperEncoder::new(map, UnknownTokenMode::Skip).unwrap();

        let ids = encoder.encode(&toks(&["a", "k"])).unwrap();
        // Sequence is framed by BOS and EOS and contains both phonemes.
        assert_eq!(ids[0], 1);
        assert_eq!(*ids.last().unwrap(), 2);
        assert!(ids.contains(&15));
        assert!(ids.contains(&30));
    }

    #[test]
    fn test_piper_encoder_strict_error() {
        let map = encoder_map(&[("a", &[15])]);
        let encoder = PiperEncoder::new(map, UnknownTokenMode::Strict).unwrap();

        assert!(encoder.encode(&toks(&["a", "Z"])).is_err());
    }

    #[test]
    fn test_piper_encoder_skip_unknown() {
        let map = encoder_map(&[("a", &[15])]);
        let encoder = PiperEncoder::new(map, UnknownTokenMode::Skip).unwrap();

        let ids = encoder.encode(&toks(&["a", "Z"])).unwrap();
        assert!(ids.contains(&15));
    }

    #[test]
    fn test_piper_encoder_missing_bos() {
        let map = make_map(&[("_", &[0]), ("$", &[2])]);
        assert!(PiperEncoder::new(map, UnknownTokenMode::Skip).is_err());
    }

    #[test]
    fn test_encode_with_default_eos() {
        let map = encoder_map(&[("a", &[15]), ("k", &[30])]);
        let encoder = PiperEncoder::new(map, UnknownTokenMode::Skip).unwrap();
        let tokens = toks(&["a", "k"]);

        let ids_default = encoder.encode(&tokens).unwrap();
        let ids_none = encoder.encode_with_eos(&tokens, None).unwrap();
        assert_eq!(ids_default, ids_none);
    }

    #[test]
    fn test_encode_with_question_eos() {
        let map = encoder_map(&[("?", &[99]), ("a", &[15])]);
        let encoder = PiperEncoder::new(map, UnknownTokenMode::Skip).unwrap();

        let ids = encoder.encode_with_eos(&toks(&["a"]), Some("?")).unwrap();
        // Custom EOS replaces the default one; BOS is untouched.
        assert_eq!(*ids.last().unwrap(), 99);
        assert_eq!(ids[0], 1);
    }

    #[test]
    fn test_encode_with_pua_eos() {
        let pua_str = crate::token_map::token_to_pua("?!").unwrap().to_string();
        let map = encoder_map(&[(pua_str.as_str(), &[88]), ("a", &[15])]);
        let encoder = PiperEncoder::new(map, UnknownTokenMode::Skip).unwrap();

        let ids = encoder.encode_with_eos(&toks(&["a"]), Some("?!")).unwrap();
        assert_eq!(*ids.last().unwrap(), 88);
    }

    #[test]
    fn test_encode_with_prosody_and_eos() {
        let map = encoder_map(&[("?", &[99]), ("a", &[15])]);
        let encoder = PiperEncoder::new(map, UnknownTokenMode::Skip).unwrap();
        let prosody = vec![Some(ProsodyInfo {
            a1: -2,
            a2: 1,
            a3: 5,
        })];

        let (ids, pros) = encoder
            .encode_with_prosody_and_eos(&toks(&["a"]), &prosody, Some("?"))
            .unwrap();
        assert_eq!(*ids.last().unwrap(), 99);
        assert_eq!(*pros.last().unwrap(), [0, 0, 0]);
        assert_eq!(ids.len(), pros.len());
    }

    #[test]
    fn test_resolve_eos_invalid() {
        let map = encoder_map(&[("a", &[15])]);
        let encoder = PiperEncoder::new(map, UnknownTokenMode::Skip).unwrap();

        let result = encoder.encode_with_eos(&toks(&["a"]), Some("NONEXISTENT"));
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("NONEXISTENT"),
            "error should mention the unknown token, got: {msg}"
        );
    }
}