gpt_sovits/text/zh/
g2pw.rs1use {
2 crate::{
3 error::GSVError,
4 onnx_builder::create_onnx_cpu_session,
5 text::{BERT_TOKENIZER, DICT_MONO_CHARS, DICT_POLY_CHARS, argmax_2d},
6 },
7 ndarray::Array,
8 ort::value::Tensor,
9 std::{
10 fmt::Debug,
11 path::Path,
12 str::FromStr,
13 sync::{Arc, LazyLock},
14 },
15 tokenizers::Tokenizer,
16};
17
18pub static LABELS: &str = include_str!("dict_poly_index_list.json");
19
20pub static POLY_LABLES: LazyLock<Vec<String>> =
21 LazyLock::new(|| serde_json::from_str(LABELS).unwrap());
22
/// One converted unit of input text; the converter emits exactly one per input `char`.
#[derive(Clone)]
pub enum G2PWOut {
    /// A pinyin reading taken from the dictionaries or predicted by the model.
    Pinyin(String),
    /// A reading tagged as Yue (presumably Cantonese — not produced in this module; confirm
    /// against callers).
    Yue(String),
    /// A character with no dictionary entry, passed through unchanged.
    RawChar(char),
}
29
30impl Debug for G2PWOut {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 match self {
33 Self::Pinyin(s) => write!(f, "\"{}\"", s),
34 Self::Yue(s) => write!(f, "\"{}\"", s),
35 Self::RawChar(s) => write!(f, "\"{}\"", s),
36 }
37 }
38}
39
40#[derive(Debug)]
41pub struct G2PW {
42 model: Option<ort::session::Session>,
43 tokenizers: Option<Arc<tokenizers::Tokenizer>>,
44}
45
46impl G2PW {
47 pub fn new<P: AsRef<Path>>(g2pw_path: Option<P>) -> Result<Self, GSVError> {
48 if let Some(g2pw_path) = g2pw_path {
49 log::info!("G2PW model is loading...");
50 let model = create_onnx_cpu_session(g2pw_path)?;
51 log::info!("G2PW model is loaded.");
52 Ok(Self {
53 model: Some(model),
54 tokenizers: Some(Arc::new(Tokenizer::from_str(BERT_TOKENIZER).unwrap())),
55 })
56 } else {
57 Ok(Self {
58 model: None,
59 tokenizers: None,
60 })
61 }
62 }
63
64 pub fn g2p<'s>(&mut self, text: &'s str) -> Vec<G2PWOut> {
65 if self.model.is_some() && self.tokenizers.is_some() {
66 self.get_pinyin_ml(text)
67 .unwrap_or(self.simple_get_pinyin(text))
68 } else {
69 self.simple_get_pinyin(text)
70 }
71 }
72
73 pub fn simple_get_pinyin(&self, text: &str) -> Vec<G2PWOut> {
74 let mut pre_data = vec![];
75 for (_, c) in text.chars().enumerate() {
76 if let Some(mono) = DICT_MONO_CHARS.get(&c) {
77 pre_data.push(G2PWOut::Pinyin(mono.phone.clone()));
78 } else if let Some(poly) = DICT_POLY_CHARS.get(&c) {
79 pre_data.push(G2PWOut::Pinyin(poly.phones[0].0.clone()));
80 } else {
81 pre_data.push(G2PWOut::RawChar(c));
82 }
83 }
84 pre_data
85 }
86
87 fn get_pinyin_ml<'s>(&mut self, text: &'s str) -> Result<Vec<G2PWOut>, GSVError> {
88 let c = self.tokenizers.as_ref().unwrap().encode(text, true)?;
89 let input_ids = c.get_ids().iter().map(|x| *x as i64).collect::<Vec<i64>>();
90 let token_type_ids = vec![0i64; input_ids.len()];
91 let attention_mask = vec![1i64; input_ids.len()];
92
93 let mut phoneme_masks = vec![];
94 let mut pre_data = vec![];
95 let mut query_id = vec![];
96 let mut chars_id = vec![];
97
98 for (i, c) in text.chars().enumerate() {
99 if let Some(mono) = DICT_MONO_CHARS.get(&c) {
100 pre_data.push(G2PWOut::Pinyin(mono.phone.clone()));
101 } else if let Some(poly) = DICT_POLY_CHARS.get(&c) {
102 pre_data.push(G2PWOut::Pinyin("".to_owned()));
103 query_id.push(i + 1);
105 chars_id.push(poly.index);
106 let mut phoneme_mask = vec![0f32; POLY_LABLES.len()];
107 for (_, i) in &poly.phones {
108 phoneme_mask[*i] = 1.0;
109 }
110 phoneme_masks.push(phoneme_mask);
111 } else {
112 pre_data.push(G2PWOut::RawChar(c));
113 }
114 }
115 let input_ids =
116 Tensor::from_array(Array::from_shape_vec((1, input_ids.len()), input_ids).unwrap())
117 .unwrap();
118 let token_type_ids = Tensor::from_array(
119 Array::from_shape_vec((1, token_type_ids.len()), token_type_ids).unwrap(),
120 )
121 .unwrap();
122 let attention_mask = Tensor::from_array(
123 Array::from_shape_vec((1, attention_mask.len()), attention_mask).unwrap(),
124 )
125 .unwrap();
126
127 for ((position_id, phoneme_mask), char_id) in query_id
128 .iter()
129 .zip(phoneme_masks.iter())
130 .zip(chars_id.iter())
131 {
132 let phoneme_mask = Tensor::from_array(
133 Array::from_shape_vec((1, phoneme_mask.len()), phoneme_mask.to_vec()).unwrap(),
134 )
135 .unwrap();
136 let position_id_t =
137 Tensor::from_array(Array::from_vec([*position_id as i64].to_vec())).unwrap();
138 let char_id = Tensor::from_array(Array::from_vec([*char_id as i64].to_vec())).unwrap();
139
140 let model_ouput = self.model.as_mut().unwrap().run(ort::inputs![
141 "input_ids" => input_ids.clone(),
142 "token_type_ids" => token_type_ids.clone(),
143 "attention_mask" => attention_mask.clone(),
144 "phoneme_mask"=> phoneme_mask,
145 "char_ids" => char_id,
146 "position_ids"=> position_id_t,
147 ])?;
148
149 let probs = model_ouput["probs"].try_extract_array::<f32>().unwrap();
150
151 let probs_view = probs.view();
152
153 let i = argmax_2d(&probs_view);
154
155 pre_data[*position_id - 1] = G2PWOut::Pinyin(POLY_LABLES[i.1 as usize].clone());
156 }
157
158 Ok(pre_data)
159 }
160}