// gpt_sovits/text/en/g2p_en.rs
use {
2 crate::{error::GSVError, onnx_builder::create_onnx_cpu_session, text::dict},
3 arpabet::Arpabet,
4 log::debug,
5 ndarray::{Array, s},
6 ort::{inputs, session::Session, value::Tensor},
7 std::{path::Path, str::FromStr},
8 tokenizers::Tokenizer,
9};
10
/// Tokenizer definition (JSON) for the mini-bart-g2p seq2seq model,
/// embedded into the binary at compile time.
static MINI_BART_G2P_TOKENIZER: &str = include_str!("tokenizer.mini-bart-g2p.json");

/// First token id fed to the decoder when starting autoregressive generation.
static DECODER_START_TOKEN_ID: u32 = 2;

// Special-token literals kept for reference; the model works with ids only.
#[allow(unused)]
static BOS_TOKEN: &str = "<s>";
#[allow(unused)]
static EOS_TOKEN: &str = "</s>";

#[allow(unused)]
static BOS_TOKEN_ID: u32 = 0;
/// Generation stops once the decoder emits this id.
/// NOTE(review): same value as DECODER_START_TOKEN_ID — presumably matches the
/// model's config (BART-style shared start/eos id); confirm against the export.
static EOS_TOKEN_ID: u32 = 2;
23
/// Neural grapheme-to-phoneme model: a seq2seq (encoder/decoder) pair of ONNX
/// sessions plus the matching tokenizer for the mini-bart-g2p vocabulary.
pub struct G2PEnModel {
    encoder_model: Session,
    decoder_model: Session,
    tokenizer: Tokenizer,
}
29
30impl G2PEnModel {
31 pub fn new<P: AsRef<Path>>(encoder_path: P, decoder_path: P) -> Result<Self, GSVError> {
32 let encoder_model = create_onnx_cpu_session(encoder_path)?;
33 let decoder_model = create_onnx_cpu_session(decoder_path)?;
34 let tokenizer = Tokenizer::from_str(MINI_BART_G2P_TOKENIZER)?;
35
36 Ok(Self {
37 encoder_model,
38 decoder_model,
39 tokenizer,
40 })
41 }
42
43 pub fn get_phoneme(&mut self, text: &str) -> Result<Vec<String>, GSVError> {
44 debug!("processing {:?}", text);
45 let encoding = self.tokenizer.encode(text, true)?;
46 let input_ids = encoding
47 .get_ids()
48 .iter()
49 .map(|x| *x as i64)
50 .collect::<Vec<i64>>();
51 let mut decoder_input_ids = vec![DECODER_START_TOKEN_ID as i64];
52
53 let input_id_len = input_ids.len();
54 let input_ids_tensor =
55 Tensor::from_array(Array::from_shape_vec((1, input_id_len), input_ids.clone())?)?;
56 let attention_mask_tensor =
57 Tensor::from_array(Array::from_elem((1, input_id_len), 1 as i64))?;
58 let encoder_outputs = self.encoder_model.run(inputs![
59 "input_ids" => input_ids_tensor.clone(),
60 "attention_mask" => attention_mask_tensor.clone()
61 ])?;
62
63 for _ in 0..50 {
64 let encoder_output = encoder_outputs["last_hidden_state"].view();
68
69 let decoder_input_ids_tensor = Tensor::from_array(Array::from_shape_vec(
70 (1, decoder_input_ids.len()),
71 decoder_input_ids.clone(),
72 )?)?;
73
74 let outputs = self.decoder_model.run(inputs![
75 "input_ids" => decoder_input_ids_tensor,
76 "encoder_attention_mask" => attention_mask_tensor.clone(),
77 "encoder_hidden_states" => encoder_output,
78 ])?;
79
80 let output_array = outputs["logits"].try_extract_array::<f32>()?;
81
82 let last_token_logits = &output_array.slice(s![0, output_array.shape()[1] - 1, ..]);
84
85 let next_token_id = last_token_logits
87 .iter()
88 .enumerate()
89 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
90 .map(|(i, _)| i as i64)
91 .ok_or(GSVError::DecodeTokenFailed)?;
92
93 decoder_input_ids.push(next_token_id);
94 if next_token_id == EOS_TOKEN_ID as i64 {
95 break;
96 }
97 }
98
99 let decoder_input_ids = decoder_input_ids
100 .iter()
101 .map(|x| *x as u32)
102 .collect::<Vec<u32>>();
103 Ok(self
104 .tokenizer
105 .decode(&decoder_input_ids, true)?
106 .split(" ")
107 .map(|v| v.to_owned())
108 .collect::<Vec<String>>())
109 }
110}
111
/// English G2P front-end: the optional neural model (used when ONNX files are
/// provided) backed by the CMUdict/arpabet lexicon as the offline fallback.
pub struct G2pEn {
    // None when constructed without a model directory; see `G2pEn::new`.
    model: Option<G2PEnModel>,
    arpabet: Arpabet,
}
116
117impl G2pEn {
118 pub fn new<P: AsRef<Path>>(path: Option<P>) -> Result<Self, GSVError> {
119 let arpabet = arpabet::load_cmudict().clone();
120 if let Some(path) = path {
121 let path = path.as_ref();
122 Ok(G2pEn {
123 model: Some(G2PEnModel::new(
124 path.join("encoder_model.onnx"),
125 path.join("decoder_model.onnx"),
126 )?),
127 arpabet: arpabet,
128 })
129 } else {
130 Ok(G2pEn {
131 model: None,
132 arpabet: arpabet,
133 })
134 }
135 }
136
137 pub fn g2p(&mut self, text: &str) -> Result<Vec<String>, GSVError> {
138 if let Some(v) = dict::en_word_dict(text) {
139 return Ok(v.to_owned());
140 }
141 match &mut self.model {
142 Some(model) => {
143 let words = text.split_whitespace();
144 let mut phonemes = Vec::new();
145 for word in words {
146 let phones = model.get_phoneme(word)?;
147 phonemes.extend(phones.into_iter());
148 }
149 Ok(phonemes)
150 }
151 None => {
152 let words = text.split_whitespace();
154 let mut phonemes = Vec::new();
155 for word in words {
156 if let Some(phones) = self.arpabet.get_polyphone_str(word) {
157 phonemes.extend(phones.iter().map(|&p| p.to_string()));
158 } else {
159 for c in word.chars() {
161 let c_str = c.to_string();
162 if let Some(phones) = self.arpabet.get_polyphone_str(&c_str) {
163 phonemes.extend(phones.iter().map(|&p| p.to_string()));
164 } else {
165 phonemes.push(c_str);
166 }
167 }
168 }
169 }
170 Ok(phonemes)
171 }
172 }
173 }
174}