1mod bert;
2mod dict;
3mod en;
4mod num;
5mod phone_symbol;
6mod utils;
7mod zh;
8
9use {
10 crate::error::GSVError,
11 jieba_rs::Jieba,
12 log::{debug, warn},
13 ndarray::Array2,
14 regex::Regex,
15 std::sync::LazyLock,
16 unicode_segmentation::UnicodeSegmentation,
17};
18pub use {
19 bert::BertModel,
20 en::{EnSentence, EnWord, G2pEn},
21 num::{NumSentence, is_numeric},
22 phone_symbol::get_phone_symbol,
23 utils::{BERT_TOKENIZER, DICT_MONO_CHARS, DICT_POLY_CHARS, argmax_2d, str_is_chinese},
24 zh::{G2PW, G2PWOut, ZhMode, ZhSentence},
25};
26
// Matches runs of characters to strip before phonemization:
// U+1F600–1F64F emoticons, U+1F300–1F5FF misc symbols & pictographs,
// U+1F680–1F6FF transport, U+1F900–1F9FF supplemental symbols,
// U+2600–27BF misc symbols/dingbats, U+2000–206F general punctuation,
// U+2300–23FF misc technical. Each run is replaced by a single space.
//
// NOTE(review): U+2000–206F also covers curly quotes, en/em dashes and the
// ellipsis that `parse_punctuation` maps — those are stripped here before
// tokenization ever sees them; confirm this is intended.
static CLEANUP_REGEX: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(
        r"[\u{1F600}-\u{1F64F}\u{1F300}-\u{1F5FF}\u{1F680}-\u{1F6FF}\u{1F900}-\u{1F9FF}\u{2600}-\u{27BF}\u{2000}-\u{206F}\u{2300}-\u{23FF}]+",
    )
    .unwrap()
});
34
// Tokenizer for non-Chinese (or mixed) text: splits input into Han runs,
// English words (with internal apostrophes/hyphens), numbers, a fixed set of
// punctuation, and whitespace. Anything outside these classes is simply not
// matched and is dropped by callers iterating `find_iter`.
static TOKEN_REGEX: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(
        r#"(?x)
        \p{Han}+ | # Chinese characters
        [a-zA-Z]+(?:['-][a-zA-Z]+)* | # English words with optional apostrophes/hyphens
        \d+(?:\.\d+)? | # Numbers (including decimals)
        [.,!?;:()\[\]<>\-"$/\u{3001}\u{3002}\u{FF01}\u{FF1F}\u{FF1B}\u{FF1A}\u{FF0C}\u{2018}\u{2019}\u{201C}\u{201D}] | # Punctuation
        \s+ # Whitespace
        "#,
    )
    .unwrap()
});
48
49fn cleanup_text(text: &str) -> String {
51 CLEANUP_REGEX.replace_all(text, " ").into_owned()
52}
53
/// Splits `text` into sentence-sized pieces for downstream processing.
///
/// Pieces end at newlines or at sentence-final punctuation (`。！？；` plus
/// ASCII `.!?;`). Heuristics keep a '.' inside a piece when it looks like an
/// abbreviation ("Dr. Smith": dot, space, uppercase), a decimal ("1.0"), or a
/// mid-word dot ("e.g"); ASCII `!?;` followed by a lowercase letter also do
/// not split. Whitespace-only pieces are dropped; pieces are trimmed.
///
/// Fix: the previous punctuation set repeated the ASCII literals
/// (`'!' | '!'` …) — dead patterns left over from transcoded full-width
/// `！？；`, which are restored here.
pub fn split_text(text: &str) -> Vec<String> {
    let chars: Vec<char> = text.chars().collect();
    let mut items: Vec<String> = Vec::with_capacity(text.len() / 20);
    let mut current = String::with_capacity(64);

    // Push the trimmed buffer (if non-empty) and clear it for the next piece.
    fn flush(current: &mut String, items: &mut Vec<String>) {
        let trimmed = current.trim();
        if !trimmed.is_empty() {
            items.push(trimmed.to_string());
        }
        current.clear();
    }

    let mut i = 0;
    while i < chars.len() {
        let c = chars[i];
        i += 1;

        // Newlines are hard sentence breaks and are not kept in the output.
        if c == '\n' || c == '\r' {
            flush(&mut current, &mut items);
            continue;
        }

        current.push(c);

        if !matches!(c, '。' | '！' | '？' | '；' | '.' | '!' | '?' | ';') {
            continue;
        }

        let next = chars.get(i).copied();
        if c == '.' {
            if let Some(nc) = next {
                // ". Word" with an uppercase continuation is treated as an
                // abbreviation (e.g. "Dr. Smith") — do not split.
                if nc == ' ' {
                    if let Some(&after_space) = chars.get(i + 1) {
                        if after_space.is_uppercase() {
                            continue;
                        }
                    }
                }
                // Decimal point ("1.0") or mid-word dot ("e.g") stays inside.
                if nc.is_ascii_digit() || nc.is_lowercase() {
                    continue;
                }
            }
        } else if matches!(c, '!' | '?' | ';') {
            // ASCII terminators followed by lowercase are sentence-internal.
            if let Some(nc) = next {
                if nc.is_lowercase() {
                    continue;
                }
            }
        }

        flush(&mut current, &mut items);
    }

    flush(&mut current, &mut items);
    items
}
127
/// Language tag for a sub-sentence; selects the G2P path and how numeric
/// tokens are verbalized.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Lang {
    /// Chinese (Mandarin vs. Cantonese is decided later via `LangId`).
    Zh,
    /// English.
    En,
}
133
/// Top-level language mode requested by the caller.
#[derive(Debug, Clone, Copy)]
pub enum LangId {
    /// Automatic detection; Chinese text is read as Mandarin.
    Auto,
    /// Automatic detection; Chinese text is read as Cantonese.
    AutoYue,
}
139
/// Read-only view over a processed sentence (Chinese or English), exposing
/// exactly what the BERT feature stage consumes.
pub trait SentenceProcessor {
    /// Text to feed the BERT tokenizer for this sentence.
    fn get_text_for_bert(&self) -> String;
    /// Per-unit phone counts aligned with the BERT text — presumably one
    /// entry per word/character; TODO confirm against `BertModel::get_bert`.
    fn get_word2ph(&self) -> &[i32];
    /// Phoneme ids for the whole sentence.
    fn get_phone_ids(&self) -> &[i64];
}
145
146impl SentenceProcessor for EnSentence {
147 fn get_text_for_bert(&self) -> String {
148 let mut result = String::with_capacity(self.text.len() * 10);
149 for word in &self.text {
150 match word {
151 EnWord::Word(w) => {
152 if !result.is_empty() && !result.ends_with(' ') {
153 result.push(' ');
154 }
155 result.push_str(w);
156 }
157 EnWord::Punctuation(p) => {
158 result.push_str(p);
159 }
160 }
161 }
162 debug!("English BERT text: {}", result);
163 result
164 }
165
166 fn get_word2ph(&self) -> &[i32] {
167 &self.word2ph
168 }
169
170 fn get_phone_ids(&self) -> &[i64] {
171 &self.phone_ids
172 }
173}
174
impl SentenceProcessor for ZhSentence {
    /// Chinese text is already a contiguous string; it is used verbatim.
    fn get_text_for_bert(&self) -> String {
        debug!("Chinese BERT text: {}", self.text);
        self.text.clone()
    }

    fn get_word2ph(&self) -> &[i32] {
        &self.word2ph
    }

    fn get_phone_ids(&self) -> &[i64] {
        &self.phone_ids
    }
}
189
/// Text front-end: segmentation, grapheme-to-phoneme conversion and BERT
/// feature extraction, bundled for the synthesis pipeline.
pub struct TextProcessor {
    // Chinese word segmenter.
    pub jieba: Jieba,
    // Chinese G2P model (Mandarin/Cantonese).
    pub g2pw: G2PW,
    // English G2P converter.
    pub g2p_en: G2pEn,
    // BERT model producing per-phone features.
    pub bert_model: BertModel,
}
196
impl TextProcessor {
    /// Builds a processor around pre-loaded G2P and BERT models.
    ///
    /// Currently infallible; the `Result` return is kept so initialization
    /// can become fallible without breaking callers.
    pub fn new(g2pw: G2PW, g2p_en: G2pEn, bert_model: BertModel) -> Result<Self, GSVError> {
        Ok(Self {
            jieba: Jieba::new(),
            g2pw,
            g2p_en,
            bert_model,
        })
    }

    /// Converts `text` into `(bert_text, phone_ids, bert_features)` triples,
    /// one per regrouped sentence.
    ///
    /// Per chunk (chunks come from `split_text` after `cleanup_text`):
    /// 1. tokenize into language-tagged sub-sentences (`PhoneBuilder`),
    /// 2. run G2P on each sub-sentence,
    /// 3. merge sub-sentences up to sentence-final punctuation,
    /// 4. fetch BERT features per group and validate their length.
    /// Failures of individual parts/groups are logged and skipped.
    ///
    /// # Errors
    /// `GSVError::InputEmpty` for blank input;
    /// `GSVError::GeneratePhonemesOrBertFeaturesFailed` when no chunk
    /// produced usable output.
    pub fn get_phone_and_bert(
        &mut self,
        text: &str,
        lang_id: LangId,
    ) -> Result<Vec<(String, Vec<i64>, Array2<f32>)>, GSVError> {
        if text.trim().is_empty() {
            return Err(GSVError::InputEmpty);
        }

        let cleaned_text = cleanup_text(text);
        let chunks = split_text(&cleaned_text);
        let mut result = Vec::with_capacity(chunks.len());

        for chunk in chunks.iter() {
            debug!("Processing chunk: {}", chunk);
            let mut phone_builder = PhoneBuilder::new(chunk);
            phone_builder.extend_text(&self.jieba, chunk);

            // Guarantee a sentence terminator so the last sub-sentence of the
            // chunk still forms a complete group below.
            // NOTE(review): the repeated ASCII '?', '!', ';' literals here
            // look like full-width '？！；' lost in transcoding — confirm.
            if !chunk
                .trim_end()
                .ends_with(['。', '.', '?', '?', '!', '!', ';', ';', '\n'])
            {
                phone_builder.push_punctuation(".");
            }

            // Per-sub-sentence G2P output, retained so sub-sentences can be
            // regrouped before the BERT pass.
            #[derive(Debug)]
            struct SubSentenceData {
                bert_text: String,
                word2ph: Vec<i32>,
                phone_ids: Vec<i64>,
            }
            let mut sub_sentences_data: Vec<SubSentenceData> = Vec::new();

            for mut sentence in phone_builder.sentences {
                // Run the language-appropriate G2P, then build phone ids.
                let g2p_result = match &mut sentence {
                    Sentence::Zh(zh) => {
                        let mode = if matches!(lang_id, LangId::AutoYue) {
                            ZhMode::Cantonese
                        } else {
                            ZhMode::Mandarin
                        };
                        zh.g2p(&mut self.g2pw, mode);
                        zh.build_phone()
                    }
                    Sentence::En(en) => en.g2p(&mut self.g2p_en).and_then(|_| en.build_phone()),
                };

                match g2p_result {
                    Ok(phone_seq) => {
                        if phone_seq.is_empty() {
                            // Nothing to synthesize for this part.
                            continue;
                        }
                        sub_sentences_data.push(SubSentenceData {
                            bert_text: sentence.get_text_for_bert(),
                            word2ph: sentence.get_word2ph().to_vec(),
                            phone_ids: sentence.get_phone_ids().to_vec(),
                        });
                    }
                    Err(e) => {
                        // Best-effort: drop the failing part, keep the chunk.
                        warn!("G2P failed for a sentence part in chunk '{}': {}", chunk, e);
                    }
                }
            }

            // Merge consecutive sub-sentences until one contains
            // sentence-final punctuation, so BERT sees whole sentences.
            #[derive(Default, Debug)]
            struct GroupedSentence {
                text: String,
                word2ph: Vec<i32>,
                phone_ids: Vec<i64>,
            }
            let mut grouped_sentences: Vec<GroupedSentence> = Vec::new();
            let mut current_group = GroupedSentence::default();

            for data in sub_sentences_data {
                // NOTE(review): duplicated ASCII literals — see note above.
                let ends_sentence = data
                    .bert_text
                    .find(['。', '.', '?', '?', '!', '!', ';', ';']);

                current_group.text.push_str(&data.bert_text);
                current_group.word2ph.extend(data.word2ph);
                current_group.phone_ids.extend(data.phone_ids);
                if ends_sentence.is_some() {
                    grouped_sentences.push(current_group);
                    current_group = GroupedSentence::default()
                }
            }
            // Flush a trailing group lacking terminal punctuation.
            if !current_group.text.is_empty() {
                grouped_sentences.push(current_group);
            }

            for group in grouped_sentences {
                debug!("Processing grouped sentence: '{}'", group.text);
                // BERT must yield exactly one feature row per phone id.
                let total_expected_bert_len = group.phone_ids.len();

                match self
                    .bert_model
                    .get_bert(&group.text, &group.word2ph, total_expected_bert_len)
                {
                    Ok(bert_features) => {
                        if bert_features.shape()[0] != total_expected_bert_len {
                            let error_msg = format!(
                                "BERT output length mismatch for text '{}': expected {}, got {}",
                                group.text,
                                total_expected_bert_len,
                                bert_features.shape()[0]
                            );
                            warn!("{}", error_msg);
                            // Skip misaligned groups rather than abort.
                            continue;
                        }
                        result.push((group.text, group.phone_ids, bert_features));
                    }
                    Err(e) => {
                        warn!(
                            "Failed to get BERT features for text '{}': {}",
                            group.text, e
                        );
                    }
                }
            }
        }

        debug!("RESULT (total sentences: {})", result.len());
        if result.is_empty() {
            return Err(GSVError::GeneratePhonemesOrBertFeaturesFailed(
                text.to_owned(),
            ));
        }
        Ok(result)
    }
}
344
/// Normalizes a punctuation token (full-width CJK or ASCII) to its canonical
/// ASCII/retained form; returns `None` for non-punctuation tokens.
///
/// Fix: several arms repeated the same ASCII literal (e.g. `"," | ","`),
/// which made the second alternative unreachable — artifacts of full-width
/// characters lost in transcoding. The full-width forms (`，！？；：（）～＇＂`)
/// are restored here.
fn parse_punctuation(p: &str) -> Option<&'static str> {
    match p {
        // Sentence punctuation: full-width | ASCII.
        "，" | "," => Some(","),
        "。" | "." => Some("."),
        "！" | "!" => Some("!"),
        "？" | "?" => Some("?"),
        "；" | ";" => Some(";"),
        "：" | ":" => Some(":"),
        // Quotes: curly and full-width map to ASCII.
        "‘" | "’" | "'" => Some("'"),
        "＇" => Some("'"),
        "“" | "”" | "\"" => Some("\""),
        "＂" => Some("\""),
        // Brackets.
        "（" | "(" => Some("("),
        "）" | ")" => Some(")"),
        "【" | "[" => Some("["),
        "】" | "]" => Some("]"),
        "《" | "<" => Some("<"),
        "》" | ">" => Some(">"),
        // Dashes, tilde, ellipsis.
        "—" | "–" => Some("-"),
        "～" | "~" => Some("~"),
        "…" | "..." => Some("..."),
        // Kept as-is.
        "·" => Some("·"),
        "、" => Some("、"),
        "$" => Some("$"),
        "/" => Some("/"),
        "\n" => Some("\n"),
        " " => Some(" "),
        _ => None,
    }
}
375
/// A sub-sentence tagged with the language it will be converted with.
#[derive(Debug)]
enum Sentence {
    Zh(ZhSentence),
    En(EnSentence),
}
381
382impl SentenceProcessor for Sentence {
383 fn get_text_for_bert(&self) -> String {
384 match self {
385 Sentence::Zh(zh) => zh.get_text_for_bert(),
386 Sentence::En(en) => en.get_text_for_bert(),
387 }
388 }
389
390 fn get_word2ph(&self) -> &[i32] {
391 match self {
392 Sentence::Zh(zh) => zh.get_word2ph(),
393 Sentence::En(en) => en.get_word2ph(),
394 }
395 }
396
397 fn get_phone_ids(&self) -> &[i64] {
398 match self {
399 Sentence::Zh(s) => s.get_phone_ids(),
400 Sentence::En(s) => s.get_phone_ids(),
401 }
402 }
403}
404
/// Incrementally partitions a text chunk into language-tagged sub-sentences.
struct PhoneBuilder {
    // Sub-sentences in input order; the last one is the one being extended.
    sentences: Vec<Sentence>,
    // Dominant language of the whole chunk; decides how numbers are read.
    sentence_lang: Lang,
}
409
410impl PhoneBuilder {
411 fn new(text: &str) -> Self {
412 let sentence_lang = detect_sentence_language(text);
413 Self {
414 sentences: Vec::with_capacity(16),
415 sentence_lang,
416 }
417 }
418
419 fn extend_text(&mut self, jieba: &Jieba, text: &str) {
420 let tokens: Vec<&str> = if str_is_chinese(text) {
421 jieba.cut(text, true).into_iter().collect()
422 } else {
423 TOKEN_REGEX.find_iter(text).map(|m| m.as_str()).collect()
424 };
425
426 for t in tokens {
427 if let Some(p) = parse_punctuation(t) {
428 self.push_punctuation(p);
429 continue;
430 }
431
432 if is_numeric(t) {
433 let ns = NumSentence {
434 text: t.to_owned(),
435 lang: self.sentence_lang,
436 };
437 let txt = match ns.to_lang_text() {
438 Ok(txt) => txt,
439 Err(e) => {
440 warn!("Failed to process numeric token '{}': {}", t, e);
441 t.to_string()
442 }
443 };
444 match self.sentence_lang {
445 Lang::Zh => self.push_zh_word(&txt),
446 Lang::En => self.push_en_word(&txt),
447 }
448 } else if str_is_chinese(t) {
449 self.push_zh_word(t);
450 } else if t
451 .chars()
452 .all(|c| c.is_ascii_alphabetic() || c == '\'' || c == '-')
453 {
454 self.push_en_word(t);
455 } else {
456 for sub_token in TOKEN_REGEX.find_iter(t) {
458 let sub_token_str = sub_token.as_str();
459 if let Some(p) = parse_punctuation(sub_token_str) {
460 self.push_punctuation(p);
461 } else if is_numeric(sub_token_str) {
462 let ns = NumSentence {
463 text: sub_token_str.to_owned(),
464 lang: self.sentence_lang,
465 };
466 let txt = match ns.to_lang_text() {
467 Ok(txt) => txt,
468 Err(e) => {
469 warn!("Failed to process numeric token '{}': {}", sub_token_str, e);
470 sub_token_str.to_string()
471 }
472 };
473 match self.sentence_lang {
474 Lang::Zh => self.push_zh_word(&txt),
475 Lang::En => self.push_en_word(&txt),
476 }
477 } else if str_is_chinese(sub_token_str) {
478 self.push_zh_word(sub_token_str);
479 } else if sub_token_str
480 .chars()
481 .all(|c| c.is_ascii_alphabetic() || c == '\'' || c == '-')
482 {
483 self.push_en_word(sub_token_str);
484 }
485 }
486 }
487 }
488 }
489
490 fn push_punctuation(&mut self, p: &'static str) {
491 match self.sentences.last_mut() {
492 Some(Sentence::Zh(zh)) => {
493 zh.text.push_str(p);
494 zh.phones.push(G2PWOut::RawChar(p.chars().next().unwrap()));
495 }
496 Some(Sentence::En(en)) => {
497 if p == " " && matches!(en.text.last(), Some(EnWord::Word(w)) if w == "a") {
499 return;
500 }
501 en.text.push(EnWord::Punctuation(p));
502 }
503 None => {
504 let en = EnSentence {
505 phone_ids: Vec::with_capacity(16),
506 phones: Vec::with_capacity(16),
507 text: vec![EnWord::Punctuation(p)],
508 word2ph: Vec::with_capacity(16),
509 };
510 self.sentences.push(Sentence::En(en));
511 }
512 }
513 }
514
515 fn push_en_word(&mut self, word: &str) {
516 if word.ends_with(['。', '.', '?', '?', '!', '!', ';', ';', '\n']) {
517 let en = EnSentence {
518 phone_ids: Vec::with_capacity(16),
519 phones: Vec::with_capacity(16),
520 text: vec![EnWord::Word(word.to_string())],
521 word2ph: Vec::with_capacity(16),
522 };
523 self.sentences.push(Sentence::En(en));
524 }
525 match self.sentences.last_mut() {
526 Some(Sentence::En(en)) => {
527 if matches!(en.text.last(), Some(EnWord::Punctuation(p)) if *p == "'" || *p == "-")
529 {
530 let p = en.text.pop().unwrap();
531 if let Some(EnWord::Word(last_word)) = en.text.last_mut() {
532 if let EnWord::Punctuation(p_str) = p {
533 last_word.push_str(p_str);
534 last_word.push_str(word);
535 return;
536 }
537 }
538 en.text.push(p); }
540 en.text.push(EnWord::Word(word.to_string()));
541 }
542 _ => {
543 let en = EnSentence {
544 phone_ids: Vec::with_capacity(16),
545 phones: Vec::with_capacity(16),
546 text: vec![EnWord::Word(word.to_string())],
547 word2ph: Vec::with_capacity(16),
548 };
549 self.sentences.push(Sentence::En(en));
550 }
551 }
552 }
553
554 fn push_zh_word(&mut self, word: &str) {
555 fn add_zh_word(zh: &mut ZhSentence, word: &str) {
556 zh.text.push_str(word);
557 match dict::zh_word_dict(word) {
558 Some(phones) => {
559 zh.phones.extend(
560 phones
561 .into_iter()
562 .map(|p: &String| G2PWOut::Pinyin(p.clone())),
563 );
564 }
565 None => {
566 zh.phones
567 .extend(word.chars().map(|_| G2PWOut::Pinyin(String::new())));
568 }
569 }
570 }
571
572 if word.ends_with(['。', '.', '?', '?', '!', '!', ';', ';', '\n']) {
573 let zh = ZhSentence {
574 phone_ids: Vec::with_capacity(16),
575 phones: Vec::with_capacity(16),
576 word2ph: Vec::with_capacity(16),
577 text: String::with_capacity(32),
578 };
579 self.sentences.push(Sentence::Zh(zh));
580 }
581
582 match self.sentences.last_mut() {
583 Some(Sentence::Zh(zh)) => add_zh_word(zh, word),
584 _ => {
585 let mut zh = ZhSentence {
586 phone_ids: Vec::with_capacity(16),
587 phones: Vec::with_capacity(16),
588 word2ph: Vec::with_capacity(16),
589 text: String::with_capacity(32),
590 };
591 add_zh_word(&mut zh, word);
592 self.sentences.push(Sentence::Zh(zh));
593 }
594 }
595 }
596}
597
598fn detect_sentence_language(text: &str) -> Lang {
600 let graphemes = text.graphemes(true).collect::<Vec<&str>>();
601 let total_chars = graphemes.len();
602 if total_chars == 0 {
603 return Lang::Zh; }
605
606 let zh_count = graphemes.iter().filter(|&&g| str_is_chinese(g)).count();
607 let zh_percent = zh_count as f32 / total_chars as f32;
608
609 debug!("chinese percent {}", zh_percent);
610 if zh_percent > 0.3 { Lang::Zh } else { Lang::En }
611}
612
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_split_text() {
        // Abbreviation-like "X. Upper" and decimal points must not split.
        assert_eq!(split_text("Dr. Smith"), ["Dr. Smith"]);
        assert_eq!(split_text("1.0版本"), ["1.0版本"]);
    }

    #[test]
    fn test_split_text_newlines_and_empty() {
        // Newlines are hard sentence breaks; blank pieces are dropped.
        assert_eq!(split_text("line1\nline2"), ["line1", "line2"]);
        assert!(split_text("").is_empty());
        assert!(split_text("   \n  ").is_empty());
    }
}
622}