//! gpt_sovits/text.rs — text front-end: cleanup, sentence splitting, G2P and BERT features.

1mod bert;
2mod dict;
3mod en;
4mod num;
5mod phone_symbol;
6mod utils;
7mod zh;
8
9use {
10    crate::error::GSVError,
11    jieba_rs::Jieba,
12    log::{debug, warn},
13    ndarray::Array2,
14    regex::Regex,
15    std::sync::LazyLock,
16    unicode_segmentation::UnicodeSegmentation,
17};
18pub use {
19    bert::BertModel,
20    en::{EnSentence, EnWord, G2pEn},
21    num::{NumSentence, is_numeric},
22    phone_symbol::get_phone_symbol,
23    utils::{BERT_TOKENIZER, DICT_MONO_CHARS, DICT_POLY_CHARS, argmax_2d, str_is_chinese},
24    zh::{G2PW, G2PWOut, ZhMode, ZhSentence},
25};
26
// Regex to handle emojis and symbols.
// Matched ranges: emoticons (1F600–1F64F), misc symbols & pictographs
// (1F300–1F5FF), transport (1F680–1F6FF), supplemental symbols (1F900–1F9FF),
// misc symbols/dingbats (2600–27BF), general punctuation (2000–206F), and
// misc technical (2300–23FF). Runs of these are collapsed to a single
// replacement by `cleanup_text` below.
static CLEANUP_REGEX: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(
        r"[\u{1F600}-\u{1F64F}\u{1F300}-\u{1F5FF}\u{1F680}-\u{1F6FF}\u{1F900}-\u{1F9FF}\u{2600}-\u{27BF}\u{2000}-\u{206F}\u{2300}-\u{23FF}]+",
    )
    .unwrap()
});
34
// Simplified regex for tokenization of non-pure-Chinese text.
// Alternatives are tried in order: runs of Han characters, English words
// (with internal apostrophes/hyphens), numbers (optionally decimal), a set of
// ASCII + CJK punctuation (FF01 ！, FF1F ？, FF1B ；, FF1A ：, FF0C ，, 3001 、,
// 3002 。, plus curly quotes), and finally whitespace runs.
static TOKEN_REGEX: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(
        r#"(?x)
        \p{Han}+ |              # Chinese characters
        [a-zA-Z]+(?:['-][a-zA-Z]+)* | # English words with optional apostrophes/hyphens
        \d+(?:\.\d+)? |          # Numbers (including decimals)
        [.,!?;:()\[\]<>\-"$/\u{3001}\u{3002}\u{FF01}\u{FF1F}\u{FF1B}\u{FF1A}\u{FF0C}\u{2018}\u{2019}\u{201C}\u{201D}] | # Punctuation
        \s+                      # Whitespace
        "#,
    )
    .unwrap()
});
48
49/// Filters out emojis and other non-essential symbols from the text.
50fn cleanup_text(text: &str) -> String {
51    CLEANUP_REGEX.replace_all(text, " ").into_owned()
52}
53
/// Splits raw input text into sentence-sized strings.
///
/// Sentences end at newlines or at sentence-final punctuation (both
/// full-width 。！？； and ASCII .!?; forms). A period is NOT treated as a
/// terminator when it looks like an abbreviation ("Dr. Smith", "e.g"), or a
/// decimal point ("1.0版本"). ASCII `!?;` followed by a lowercase letter are
/// treated as mid-sentence. Each returned sentence is trimmed; empty
/// fragments are dropped.
pub fn split_text(text: &str) -> Vec<String> {
    // Pushes the trimmed contents of `current` (if non-empty) and clears it.
    fn flush(items: &mut Vec<String>, current: &mut String) {
        let trimmed = current.trim();
        if !trimmed.is_empty() {
            items.push(trimmed.to_string());
        }
        current.clear();
    }

    let mut items = Vec::with_capacity(text.len() / 20);
    let mut current = String::with_capacity(64);
    let mut chars = text.chars().peekable();

    while let Some(c) = chars.next() {
        // Newlines terminate the current sentence and are not kept.
        if c == '\n' || c == '\r' {
            flush(&mut items, &mut current);
            continue;
        }

        current.push(c);

        // Sentence-ending punctuation: full-width CJK and ASCII variants.
        // (The full-width forms had been mojibake-collapsed into duplicate
        // ASCII arms; restored here to match TOKEN_REGEX's FF01/FF1F/FF1B.)
        if !matches!(c, '。' | '！' | '？' | '；' | '.' | '!' | '?' | ';') {
            continue;
        }

        if c == '.' {
            if let Some(&next_char) = chars.peek() {
                // Case 1: "Dr. Smith" — space followed by an uppercase letter
                // is treated as an abbreviation; keep the sentence going.
                if next_char == ' ' {
                    let mut peek_iter = chars.clone();
                    peek_iter.next(); // skip the space
                    if let Some(after_space) = peek_iter.next() {
                        if after_space.is_uppercase() {
                            continue;
                        }
                    }
                }

                // Case 2: decimal number like "1.0版本" — digit follows.
                if next_char.is_ascii_digit() {
                    continue;
                }

                // Case 3: abbreviation with a lowercase letter following.
                if next_char.is_lowercase() {
                    continue;
                }
            }
        } else if matches!(c, '!' | '?' | ';') {
            // ASCII terminators followed by a lowercase letter are treated as
            // mid-sentence emphasis (e.g. "wait! don't").
            if let Some(&next_char) = chars.peek() {
                if next_char.is_lowercase() {
                    continue;
                }
            }
        }

        flush(&mut items, &mut current);
    }

    // Handle any remaining text with no trailing terminator.
    flush(&mut items, &mut current);

    items
}
127
/// Language tag used to route tokens (and number verbalization) between the
/// Chinese and English pipelines.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Lang {
    Zh, // Chinese
    En, // English
}
133
/// Requested synthesis language mode; selects the Chinese G2P variant.
#[derive(Debug, Clone, Copy)]
pub enum LangId {
    Auto,    // Mandarin
    AutoYue, // Cantonese
}
139
/// Common read accessors over processed sentences (Chinese or English),
/// consumed when feeding a sentence to the BERT feature extractor.
pub trait SentenceProcessor {
    /// Plain-text rendering of the sentence for the BERT tokenizer.
    fn get_text_for_bert(&self) -> String;
    /// Per-unit phoneme counts (the word2ph mapping).
    fn get_word2ph(&self) -> &[i32];
    /// Phoneme symbol ids for the whole sentence.
    fn get_phone_ids(&self) -> &[i64];
}
145
146impl SentenceProcessor for EnSentence {
147    fn get_text_for_bert(&self) -> String {
148        let mut result = String::with_capacity(self.text.len() * 10);
149        for word in &self.text {
150            match word {
151                EnWord::Word(w) => {
152                    if !result.is_empty() && !result.ends_with(' ') {
153                        result.push(' ');
154                    }
155                    result.push_str(w);
156                }
157                EnWord::Punctuation(p) => {
158                    result.push_str(p);
159                }
160            }
161        }
162        debug!("English BERT text: {}", result);
163        result
164    }
165
166    fn get_word2ph(&self) -> &[i32] {
167        &self.word2ph
168    }
169
170    fn get_phone_ids(&self) -> &[i64] {
171        &self.phone_ids
172    }
173}
174
175impl SentenceProcessor for ZhSentence {
176    fn get_text_for_bert(&self) -> String {
177        debug!("Chinese BERT text: {}", self.text);
178        self.text.clone()
179    }
180
181    fn get_word2ph(&self) -> &[i32] {
182        &self.word2ph
183    }
184
185    fn get_phone_ids(&self) -> &[i64] {
186        &self.phone_ids
187    }
188}
189
/// End-to-end text front-end bundling segmentation, G2P and BERT models.
pub struct TextProcessor {
    pub jieba: Jieba,          // word segmenter for Chinese text
    pub g2pw: G2PW,            // Chinese grapheme-to-phoneme model
    pub g2p_en: G2pEn,         // English grapheme-to-phoneme model
    pub bert_model: BertModel, // BERT feature extractor
}
196
197impl TextProcessor {
198    pub fn new(g2pw: G2PW, g2p_en: G2pEn, bert_model: BertModel) -> Result<Self, GSVError> {
199        Ok(Self {
200            jieba: Jieba::new(),
201            g2pw,
202            g2p_en,
203            bert_model,
204        })
205    }
206
207    pub fn get_phone_and_bert(
208        &mut self,
209        text: &str,
210        lang_id: LangId,
211    ) -> Result<Vec<(String, Vec<i64>, Array2<f32>)>, GSVError> {
212        if text.trim().is_empty() {
213            return Err(GSVError::InputEmpty);
214        }
215
216        let cleaned_text = cleanup_text(text);
217        let chunks = split_text(&cleaned_text);
218        let mut result = Vec::with_capacity(chunks.len());
219
220        for chunk in chunks.iter() {
221            debug!("Processing chunk: {}", chunk);
222            let mut phone_builder = PhoneBuilder::new(chunk);
223            phone_builder.extend_text(&self.jieba, chunk);
224
225            if !chunk
226                .trim_end()
227                .ends_with(['。', '.', '?', '?', '!', '!', ';', ';', '\n'])
228            {
229                phone_builder.push_punctuation(".");
230            }
231
232            // --- A. Collect data for all sub-sentences in the chunk ---
233            #[derive(Debug)]
234            struct SubSentenceData {
235                bert_text: String,
236                word2ph: Vec<i32>,
237                phone_ids: Vec<i64>,
238            }
239            let mut sub_sentences_data: Vec<SubSentenceData> = Vec::new();
240
241            for mut sentence in phone_builder.sentences {
242                let g2p_result = match &mut sentence {
243                    Sentence::Zh(zh) => {
244                        let mode = if matches!(lang_id, LangId::AutoYue) {
245                            ZhMode::Cantonese
246                        } else {
247                            ZhMode::Mandarin
248                        };
249                        zh.g2p(&mut self.g2pw, mode);
250                        zh.build_phone()
251                    }
252                    Sentence::En(en) => en.g2p(&mut self.g2p_en).and_then(|_| en.build_phone()),
253                };
254
255                match g2p_result {
256                    Ok(phone_seq) => {
257                        if phone_seq.is_empty() {
258                            continue; // Skip parts that produce no phonemes
259                        }
260                        sub_sentences_data.push(SubSentenceData {
261                            bert_text: sentence.get_text_for_bert(),
262                            word2ph: sentence.get_word2ph().to_vec(),
263                            phone_ids: sentence.get_phone_ids().to_vec(),
264                        });
265                    }
266                    Err(e) => {
267                        warn!("G2P failed for a sentence part in chunk '{}': {}", chunk, e);
268
269                        // Continue processing other parts of the chunk
270                    }
271                }
272            }
273
274            // --- B. Group sub-sentences into logically complete sentences ---
275            #[derive(Default, Debug)]
276            struct GroupedSentence {
277                text: String,
278                word2ph: Vec<i32>,
279                phone_ids: Vec<i64>,
280            }
281            let mut grouped_sentences: Vec<GroupedSentence> = Vec::new();
282            let mut current_group = GroupedSentence::default();
283
284            for data in sub_sentences_data {
285                let ends_sentence = data
286                    .bert_text
287                    .find(['。', '.', '?', '?', '!', '!', ';', ';']);
288
289                current_group.text.push_str(&data.bert_text);
290                current_group.word2ph.extend(data.word2ph);
291                current_group.phone_ids.extend(data.phone_ids);
292                if ends_sentence.is_some() {
293                    grouped_sentences.push(current_group);
294                    current_group = GroupedSentence::default()
295                }
296            }
297            // Add any remaining part that didn't end with punctuation
298            if !current_group.text.is_empty() {
299                grouped_sentences.push(current_group);
300            }
301
302            // --- C. Process each complete sentence with BERT ---
303            for group in grouped_sentences {
304                debug!("Processing grouped sentence: '{}'", group.text);
305                let total_expected_bert_len = group.phone_ids.len();
306
307                match self
308                    .bert_model
309                    .get_bert(&group.text, &group.word2ph, total_expected_bert_len)
310                {
311                    Ok(bert_features) => {
312                        if bert_features.shape()[0] != total_expected_bert_len {
313                            let error_msg = format!(
314                                "BERT output length mismatch for text '{}': expected {}, got {}",
315                                group.text,
316                                total_expected_bert_len,
317                                bert_features.shape()[0]
318                            );
319                            warn!("{}", error_msg);
320
321                            continue;
322                        }
323                        result.push((group.text, group.phone_ids, bert_features));
324                    }
325                    Err(e) => {
326                        warn!(
327                            "Failed to get BERT features for text '{}': {}",
328                            group.text, e
329                        );
330                    }
331                }
332            }
333        }
334
335        debug!("RESULT (total sentences: {})", result.len());
336        if result.is_empty() {
337            return Err(GSVError::GeneratePhonemesOrBertFeaturesFailed(
338                text.to_owned(),
339            ));
340        }
341        Ok(result)
342    }
343}
344
/// Normalizes a punctuation token to a canonical ASCII-ish form, mapping
/// full-width/CJK variants onto their ASCII equivalents.
///
/// Returns `None` when the token is not a recognized punctuation mark.
/// The full-width characters in the left-hand patterns (，！？；：（）～＇＂)
/// had been mojibake-collapsed into duplicate ASCII arms — one of which
/// (`＂`) was left as a broken string literal; all are restored here.
fn parse_punctuation(p: &str) -> Option<&'static str> {
    match p {
        "，" | "," => Some(","),
        "。" | "." => Some("."),
        "！" | "!" => Some("!"),
        "？" | "?" => Some("?"),
        "；" | ";" => Some(";"),
        "：" | ":" => Some(":"),
        "‘" | "’" | "'" => Some("'"),
        "＇" => Some("'"),
        "“" | "”" | "\"" => Some("\""),
        "＂" => Some("\""),
        "（" | "(" => Some("("),
        "）" | ")" => Some(")"),
        "【" | "[" => Some("["),
        "】" | "]" => Some("]"),
        "《" | "<" => Some("<"),
        "》" | ">" => Some(">"),
        "—" | "–" => Some("-"),
        "～" | "~" => Some("~"),
        "…" | "..." => Some("..."),
        "·" => Some("·"),
        "、" => Some("、"),
        "$" => Some("$"),
        "/" => Some("/"),
        "\n" => Some("\n"),
        " " => Some(" "),
        _ => None,
    }
}
375
/// A language-tagged sentence fragment accumulated by `PhoneBuilder`.
#[derive(Debug)]
enum Sentence {
    Zh(ZhSentence),
    En(EnSentence),
}
381
382impl SentenceProcessor for Sentence {
383    fn get_text_for_bert(&self) -> String {
384        match self {
385            Sentence::Zh(zh) => zh.get_text_for_bert(),
386            Sentence::En(en) => en.get_text_for_bert(),
387        }
388    }
389
390    fn get_word2ph(&self) -> &[i32] {
391        match self {
392            Sentence::Zh(zh) => zh.get_word2ph(),
393            Sentence::En(en) => en.get_word2ph(),
394        }
395    }
396
397    fn get_phone_ids(&self) -> &[i64] {
398        match self {
399            Sentence::Zh(s) => s.get_phone_ids(),
400            Sentence::En(s) => s.get_phone_ids(),
401        }
402    }
403}
404
/// Incrementally assembles language-tagged `Sentence`s from a token stream.
struct PhoneBuilder {
    sentences: Vec<Sentence>,
    // Dominant language of the whole input; decides how numbers are read out.
    sentence_lang: Lang,
}
409
410impl PhoneBuilder {
411    fn new(text: &str) -> Self {
412        let sentence_lang = detect_sentence_language(text);
413        Self {
414            sentences: Vec::with_capacity(16),
415            sentence_lang,
416        }
417    }
418
419    fn extend_text(&mut self, jieba: &Jieba, text: &str) {
420        let tokens: Vec<&str> = if str_is_chinese(text) {
421            jieba.cut(text, true).into_iter().collect()
422        } else {
423            TOKEN_REGEX.find_iter(text).map(|m| m.as_str()).collect()
424        };
425
426        for t in tokens {
427            if let Some(p) = parse_punctuation(t) {
428                self.push_punctuation(p);
429                continue;
430            }
431
432            if is_numeric(t) {
433                let ns = NumSentence {
434                    text: t.to_owned(),
435                    lang: self.sentence_lang,
436                };
437                let txt = match ns.to_lang_text() {
438                    Ok(txt) => txt,
439                    Err(e) => {
440                        warn!("Failed to process numeric token '{}': {}", t, e);
441                        t.to_string()
442                    }
443                };
444                match self.sentence_lang {
445                    Lang::Zh => self.push_zh_word(&txt),
446                    Lang::En => self.push_en_word(&txt),
447                }
448            } else if str_is_chinese(t) {
449                self.push_zh_word(t);
450            } else if t
451                .chars()
452                .all(|c| c.is_ascii_alphabetic() || c == '\'' || c == '-')
453            {
454                self.push_en_word(t);
455            } else {
456                // Handle mixed-language tokens by re-tokenizing the mixed token
457                for sub_token in TOKEN_REGEX.find_iter(t) {
458                    let sub_token_str = sub_token.as_str();
459                    if let Some(p) = parse_punctuation(sub_token_str) {
460                        self.push_punctuation(p);
461                    } else if is_numeric(sub_token_str) {
462                        let ns = NumSentence {
463                            text: sub_token_str.to_owned(),
464                            lang: self.sentence_lang,
465                        };
466                        let txt = match ns.to_lang_text() {
467                            Ok(txt) => txt,
468                            Err(e) => {
469                                warn!("Failed to process numeric token '{}': {}", sub_token_str, e);
470                                sub_token_str.to_string()
471                            }
472                        };
473                        match self.sentence_lang {
474                            Lang::Zh => self.push_zh_word(&txt),
475                            Lang::En => self.push_en_word(&txt),
476                        }
477                    } else if str_is_chinese(sub_token_str) {
478                        self.push_zh_word(sub_token_str);
479                    } else if sub_token_str
480                        .chars()
481                        .all(|c| c.is_ascii_alphabetic() || c == '\'' || c == '-')
482                    {
483                        self.push_en_word(sub_token_str);
484                    }
485                }
486            }
487        }
488    }
489
490    fn push_punctuation(&mut self, p: &'static str) {
491        match self.sentences.last_mut() {
492            Some(Sentence::Zh(zh)) => {
493                zh.text.push_str(p);
494                zh.phones.push(G2PWOut::RawChar(p.chars().next().unwrap()));
495            }
496            Some(Sentence::En(en)) => {
497                // Simplified condition check
498                if p == " " && matches!(en.text.last(), Some(EnWord::Word(w)) if w == "a") {
499                    return;
500                }
501                en.text.push(EnWord::Punctuation(p));
502            }
503            None => {
504                let en = EnSentence {
505                    phone_ids: Vec::with_capacity(16),
506                    phones: Vec::with_capacity(16),
507                    text: vec![EnWord::Punctuation(p)],
508                    word2ph: Vec::with_capacity(16),
509                };
510                self.sentences.push(Sentence::En(en));
511            }
512        }
513    }
514
515    fn push_en_word(&mut self, word: &str) {
516        if word.ends_with(['。', '.', '?', '?', '!', '!', ';', ';', '\n']) {
517            let en = EnSentence {
518                phone_ids: Vec::with_capacity(16),
519                phones: Vec::with_capacity(16),
520                text: vec![EnWord::Word(word.to_string())],
521                word2ph: Vec::with_capacity(16),
522            };
523            self.sentences.push(Sentence::En(en));
524        }
525        match self.sentences.last_mut() {
526            Some(Sentence::En(en)) => {
527                // Simplified condition check using matches! macro
528                if matches!(en.text.last(), Some(EnWord::Punctuation(p)) if *p == "'" || *p == "-")
529                {
530                    let p = en.text.pop().unwrap();
531                    if let Some(EnWord::Word(last_word)) = en.text.last_mut() {
532                        if let EnWord::Punctuation(p_str) = p {
533                            last_word.push_str(p_str);
534                            last_word.push_str(word);
535                            return;
536                        }
537                    }
538                    en.text.push(p); // Push back if not applicable
539                }
540                en.text.push(EnWord::Word(word.to_string()));
541            }
542            _ => {
543                let en = EnSentence {
544                    phone_ids: Vec::with_capacity(16),
545                    phones: Vec::with_capacity(16),
546                    text: vec![EnWord::Word(word.to_string())],
547                    word2ph: Vec::with_capacity(16),
548                };
549                self.sentences.push(Sentence::En(en));
550            }
551        }
552    }
553
554    fn push_zh_word(&mut self, word: &str) {
555        fn add_zh_word(zh: &mut ZhSentence, word: &str) {
556            zh.text.push_str(word);
557            match dict::zh_word_dict(word) {
558                Some(phones) => {
559                    zh.phones.extend(
560                        phones
561                            .into_iter()
562                            .map(|p: &String| G2PWOut::Pinyin(p.clone())),
563                    );
564                }
565                None => {
566                    zh.phones
567                        .extend(word.chars().map(|_| G2PWOut::Pinyin(String::new())));
568                }
569            }
570        }
571
572        if word.ends_with(['。', '.', '?', '?', '!', '!', ';', ';', '\n']) {
573            let zh = ZhSentence {
574                phone_ids: Vec::with_capacity(16),
575                phones: Vec::with_capacity(16),
576                word2ph: Vec::with_capacity(16),
577                text: String::with_capacity(32),
578            };
579            self.sentences.push(Sentence::Zh(zh));
580        }
581
582        match self.sentences.last_mut() {
583            Some(Sentence::Zh(zh)) => add_zh_word(zh, word),
584            _ => {
585                let mut zh = ZhSentence {
586                    phone_ids: Vec::with_capacity(16),
587                    phones: Vec::with_capacity(16),
588                    word2ph: Vec::with_capacity(16),
589                    text: String::with_capacity(32),
590                };
591                add_zh_word(&mut zh, word);
592                self.sentences.push(Sentence::Zh(zh));
593            }
594        }
595    }
596}
597
598/// Detects the dominant language of a sentence based on character distribution.
599fn detect_sentence_language(text: &str) -> Lang {
600    let graphemes = text.graphemes(true).collect::<Vec<&str>>();
601    let total_chars = graphemes.len();
602    if total_chars == 0 {
603        return Lang::Zh; // Default to Chinese for empty input
604    }
605
606    let zh_count = graphemes.iter().filter(|&&g| str_is_chinese(g)).count();
607    let zh_percent = zh_count as f32 / total_chars as f32;
608
609    debug!("chinese percent {}", zh_percent);
610    if zh_percent > 0.3 { Lang::Zh } else { Lang::En }
611}
612
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_split_text() {
        // Abbreviation and decimal periods must not split the sentence.
        assert_eq!(split_text("Dr. Smith"), ["Dr. Smith"]);
        assert_eq!(split_text("1.0版本"), ["1.0版本"]);
        // Ideographic full stop splits.
        assert_eq!(split_text("你好。世界。"), ["你好。", "世界。"]);
        // Newlines split and are not kept in the output.
        assert_eq!(split_text("line1\nline2"), ["line1", "line2"]);
        // Whitespace-only input yields no sentences.
        assert_eq!(split_text("   "), Vec::<String>::new());
    }
}
622}