gpt_sovits/text/zh/g2pw.rs

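//! Polyphone-aware grapheme-to-phoneme (G2PW) frontend for Chinese text.
//!
//! Monophonic characters are resolved by static dictionary lookup; polyphonic
//! characters are disambiguated with an ONNX BERT-style model when one is
//! loaded, falling back to the first dictionary reading otherwise.
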
use {
    crate::{
        error::GSVError,
        onnx_builder::create_onnx_cpu_session,
        text::{BERT_TOKENIZER, DICT_MONO_CHARS, DICT_POLY_CHARS, argmax_2d},
    },
    ndarray::Array,
    ort::value::Tensor,
    std::{
        fmt::Debug,
        path::Path,
        str::FromStr,
        sync::{Arc, LazyLock},
    },
    tokenizers::Tokenizer,
};

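// Pinyin label table bundled at compile time: entry `i` corresponds to the
// model's output class `i` for polyphone disambiguation.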
pub static LABELS: &str = include_str!("dict_poly_index_list.json");

pub static POLY_LABELS: LazyLock<Vec<String>> =
    LazyLock::new(|| serde_json::from_str(LABELS).unwrap());

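/// Per-character g2p result: a Mandarin pinyin syllable, a Cantonese (Yue)
/// reading, or the raw character when no reading is known.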
#[derive(Clone)]
pub enum G2PWOut {
    Pinyin(String),
    Yue(String),
    RawChar(char),
}

impl Debug for G2PWOut {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Pinyin(s) => write!(f, "\"{}\"", s),
            Self::Yue(s) => write!(f, "\"{}\"", s),
            Self::RawChar(s) => write!(f, "\"{}\"", s),
        }
    }
}

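/// Polyphone-aware g2p frontend; both fields are `None` in dictionary-only
/// mode.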
#[derive(Debug)]
pub struct G2PW {
    model: Option<ort::session::Session>,
    tokenizers: Option<Arc<tokenizers::Tokenizer>>,
}

impl G2PW {
    /// Load the ONNX polyphone model from `g2pw_path`; pass `None` to run in
    /// dictionary-only mode.
    pub fn new<P: AsRef<Path>>(g2pw_path: Option<P>) -> Result<Self, GSVError> {
        if let Some(g2pw_path) = g2pw_path {
            log::info!("G2PW model is loading...");
            let model = create_onnx_cpu_session(g2pw_path)?;
            log::info!("G2PW model is loaded.");
            Ok(Self {
                model: Some(model),
                tokenizers: Some(Arc::new(Tokenizer::from_str(BERT_TOKENIZER)?)),
            })
        } else {
            Ok(Self {
                model: None,
                tokenizers: None,
            })
        }
    }

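    /// Convert `text` to per-character readings, preferring the neural model
    /// when it is loaded and falling back to the plain dictionary lookup on
    /// any inference error.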
    pub fn g2p(&mut self, text: &str) -> Vec<G2PWOut> {
        if self.model.is_some() && self.tokenizers.is_some() {
            self.get_pinyin_ml(text)
                .unwrap_or_else(|_| self.simple_get_pinyin(text))
        } else {
            self.simple_get_pinyin(text)
        }
    }

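    /// Dictionary-only fallback: monophones resolve directly, and polyphones
    /// take the first reading listed in the dictionary.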
    pub fn simple_get_pinyin(&self, text: &str) -> Vec<G2PWOut> {
        let mut pre_data = vec![];
        for c in text.chars() {
            if let Some(mono) = DICT_MONO_CHARS.get(&c) {
                pre_data.push(G2PWOut::Pinyin(mono.phone.clone()));
            } else if let Some(poly) = DICT_POLY_CHARS.get(&c) {
                pre_data.push(G2PWOut::Pinyin(poly.phones[0].0.clone()));
            } else {
                pre_data.push(G2PWOut::RawChar(c));
            }
        }
        pre_data
    }

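    /// Neural path: encode the whole sentence once with the BERT tokenizer,
    /// then query the model for each polyphonic character.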
    fn get_pinyin_ml(&mut self, text: &str) -> Result<Vec<G2PWOut>, GSVError> {
        let c = self.tokenizers.as_ref().unwrap().encode(text, true)?;
        let input_ids = c.get_ids().iter().map(|x| *x as i64).collect::<Vec<i64>>();
        let token_type_ids = vec![0i64; input_ids.len()];
        let attention_mask = vec![1i64; input_ids.len()];

        let mut phoneme_masks = vec![];
        let mut pre_data = vec![];
        let mut query_id = vec![];
        let mut chars_id = vec![];

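        // First pass over the characters: resolve monophones immediately and
        // record the token position, dictionary index and candidate-phoneme
        // mask of every polyphone for the model pass below.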
        for (i, c) in text.chars().enumerate() {
            if let Some(mono) = DICT_MONO_CHARS.get(&c) {
                pre_data.push(G2PWOut::Pinyin(mono.phone.clone()));
            } else if let Some(poly) = DICT_POLY_CHARS.get(&c) {
                pre_data.push(G2PWOut::Pinyin("".to_owned()));
                // Position within the token sequence: offset by one because the
                // tokenizer wraps the text with the special '[CLS]' and '[SEP]'
                // tokens.
                query_id.push(i + 1);
                chars_id.push(poly.index);
                let mut phoneme_mask = vec![0f32; POLY_LABELS.len()];
                for (_, i) in &poly.phones {
                    phoneme_mask[*i] = 1.0;
                }
                phoneme_masks.push(phoneme_mask);
            } else {
                pre_data.push(G2PWOut::RawChar(c));
            }
        }
        let input_ids =
            Tensor::from_array(Array::from_shape_vec((1, input_ids.len()), input_ids).unwrap())?;
        let token_type_ids = Tensor::from_array(
            Array::from_shape_vec((1, token_type_ids.len()), token_type_ids).unwrap(),
        )?;
        let attention_mask = Tensor::from_array(
            Array::from_shape_vec((1, attention_mask.len()), attention_mask).unwrap(),
        )?;

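        // Second pass: one model run per polyphonic character. The sentence
        // encoding is shared across runs, while position_ids, char_ids and
        // phoneme_mask select the polyphone to disambiguate and restrict the
        // output distribution to its candidate readings.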
        for ((position_id, phoneme_mask), char_id) in query_id
            .iter()
            .zip(phoneme_masks.iter())
            .zip(chars_id.iter())
        {
            let phoneme_mask = Tensor::from_array(
                Array::from_shape_vec((1, phoneme_mask.len()), phoneme_mask.to_vec()).unwrap(),
            )?;
            let position_id_t = Tensor::from_array(Array::from_vec(vec![*position_id as i64]))?;
            let char_id = Tensor::from_array(Array::from_vec(vec![*char_id as i64]))?;

            let model_output = self.model.as_mut().unwrap().run(ort::inputs![
                "input_ids" => input_ids.clone(),
                "token_type_ids" => token_type_ids.clone(),
                "attention_mask" => attention_mask.clone(),
                "phoneme_mask" => phoneme_mask,
                "char_ids" => char_id,
                "position_ids" => position_id_t,
            ])?;

            let probs = model_output["probs"].try_extract_array::<f32>()?;
            let i = argmax_2d(&probs.view());

            // `position_id` was offset by one for '[CLS]', so subtract one to
            // index back into `pre_data`.
            pre_data[*position_id - 1] = G2PWOut::Pinyin(POLY_LABELS[i.1 as usize].clone());
        }

        Ok(pre_data)
    }
}
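
// A minimal usage sketch (illustrative only: the model path and the printed
// readings are assumptions; the exact pinyin label format is defined by
// dict_poly_index_list.json):
//
//     let mut g2pw = G2PW::new(Some("path/to/g2pw.onnx"))?;
//     let readings = g2pw.g2p("银行行长");
//     // e.g. ["yin2", "hang2", "hang2", "zhang3"] - the polyphone 行 is
//     // resolved to hang2 here rather than xing2.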