// gpt_sovits/text/en/g2p_en.rs

use {
    crate::{error::GSVError, onnx_builder::create_onnx_cpu_session, text::dict},
    arpabet::Arpabet,
    log::debug,
    ndarray::{Array, s},
    ort::{inputs, session::Session, value::Tensor},
    std::{path::Path, str::FromStr},
    tokenizers::Tokenizer,
};
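
// Tokenizer definition for the mini-bart-g2p model, embedded at compile time
// so callers only need to supply the two ONNX graphs.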
static MINI_BART_G2P_TOKENIZER: &str = include_str!("tokenizer.mini-bart-g2p.json");
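
// Special-token ids for the mini-bart-g2p vocabulary. As in BART, decoding
// starts from the EOS id (2) rather than from BOS.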
static DECODER_START_TOKEN_ID: u32 = 2;

#[allow(unused)]
static BOS_TOKEN: &str = "<s>";
#[allow(unused)]
static EOS_TOKEN: &str = "</s>";

#[allow(unused)]
static BOS_TOKEN_ID: u32 = 0;
static EOS_TOKEN_ID: u32 = 2;
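
/// Grapheme-to-phoneme model backed by a BART-style ONNX encoder/decoder
/// pair (mini-bart-g2p) and its matching tokenizer.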
pub struct G2PEnModel {
    encoder_model: Session,
    decoder_model: Session,
    tokenizer: Tokenizer,
}

impl G2PEnModel {
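    /// Creates CPU ONNX sessions for the encoder and decoder graphs and
    /// loads the embedded tokenizer.
    ///
    /// A minimal sketch, assuming the graphs were exported under these file
    /// names (paths are illustrative):
    ///
    /// ```ignore
    /// let mut model = G2PEnModel::new("g2p/encoder_model.onnx", "g2p/decoder_model.onnx")?;
    /// ```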
    pub fn new<P: AsRef<Path>>(encoder_path: P, decoder_path: P) -> Result<Self, GSVError> {
        let encoder_model = create_onnx_cpu_session(encoder_path)?;
        let decoder_model = create_onnx_cpu_session(decoder_path)?;
        let tokenizer = Tokenizer::from_str(MINI_BART_G2P_TOKENIZER)?;

        Ok(Self {
            encoder_model,
            decoder_model,
            tokenizer,
        })
    }
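
    /// Runs the seq2seq model on a single word and returns its phonemes.
    ///
    /// The encoder runs once; the decoder then generates tokens greedily
    /// (argmax over the final position's logits) until it emits EOS or hits
    /// a 50-token cap.
    ///
    /// A minimal sketch (the output shown is illustrative):
    ///
    /// ```ignore
    /// let phones = model.get_phoneme("hello")?; // e.g. ["HH", "AH0", "L", "OW1"]
    /// ```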
    pub fn get_phoneme(&mut self, text: &str) -> Result<Vec<String>, GSVError> {
        debug!("processing {:?}", text);
        let encoding = self.tokenizer.encode(text, true)?;
        let input_ids = encoding
            .get_ids()
            .iter()
            .map(|x| *x as i64)
            .collect::<Vec<i64>>();
        let mut decoder_input_ids = vec![DECODER_START_TOKEN_ID as i64];

        let input_id_len = input_ids.len();
        let input_ids_tensor =
            Tensor::from_array(Array::from_shape_vec((1, input_id_len), input_ids)?)?;
        let attention_mask_tensor =
            Tensor::from_array(Array::from_elem((1, input_id_len), 1i64))?;

        // Run the encoder once; its hidden states are reused on every decoder step.
        let encoder_outputs = self.encoder_model.run(inputs![
            "input_ids" => input_ids_tensor,
            "attention_mask" => attention_mask_tensor.clone()
        ])?;
        // Greedy autoregressive decoding, capped at 50 generated tokens.
        for _ in 0..50 {
            let encoder_output = encoder_outputs["last_hidden_state"].view();

            let decoder_input_ids_tensor = Tensor::from_array(Array::from_shape_vec(
                (1, decoder_input_ids.len()),
                decoder_input_ids.clone(),
            )?)?;

            let outputs = self.decoder_model.run(inputs![
                "input_ids" => decoder_input_ids_tensor,
                "encoder_attention_mask" => attention_mask_tensor.clone(),
                "encoder_hidden_states" => encoder_output,
            ])?;

            let output_array = outputs["logits"].try_extract_array::<f32>()?;

            // The logits at the last position predict the next token
            let last_token_logits = &output_array.slice(s![0, output_array.shape()[1] - 1, ..]);

            // Greedy selection: take the argmax over the vocabulary
            let next_token_id = last_token_logits
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(i, _)| i as i64)
                .ok_or(GSVError::DecodeTokenFailed)?;

            decoder_input_ids.push(next_token_id);
            if next_token_id == EOS_TOKEN_ID as i64 {
                break;
            }
        }
        // Decode the generated ids back to text; `skip_special_tokens = true`
        // drops BOS/EOS, leaving space-separated phoneme tokens.
        let decoder_input_ids = decoder_input_ids
            .iter()
            .map(|x| *x as u32)
            .collect::<Vec<u32>>();
        Ok(self
            .tokenizer
            .decode(&decoder_input_ids, true)?
            .split(' ')
            .map(|v| v.to_owned())
            .collect::<Vec<String>>())
    }
}
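
/// English grapheme-to-phoneme front end: consults a built-in word
/// dictionary first, then the neural model when one is loaded, and otherwise
/// falls back to CMUdict lookups via the `arpabet` crate.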
pub struct G2pEn {
    model: Option<G2PEnModel>,
    arpabet: Arpabet,
}

impl G2pEn {
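    /// Builds the front end. When `path` is given, it must point at a
    /// directory containing `encoder_model.onnx` and `decoder_model.onnx`;
    /// with `None`, only the CMUdict fallback is available.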
    pub fn new<P: AsRef<Path>>(path: Option<P>) -> Result<Self, GSVError> {
        let arpabet = arpabet::load_cmudict().clone();
        let model = match path {
            Some(path) => {
                let path = path.as_ref();
                Some(G2PEnModel::new(
                    path.join("encoder_model.onnx"),
                    path.join("decoder_model.onnx"),
                )?)
            }
            None => None,
        };
        Ok(G2pEn { model, arpabet })
    }
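
    /// Converts English text to a flat list of phoneme strings, processing
    /// one whitespace-separated word at a time.
    ///
    /// A minimal sketch of the dictionary-only path (no ONNX model):
    ///
    /// ```ignore
    /// // `None` needs a type hint because `P` cannot be inferred from it.
    /// let mut g2p = G2pEn::new(None::<&std::path::Path>)?;
    /// let phones = g2p.g2p("hello world")?;
    /// assert!(!phones.is_empty());
    /// ```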
    pub fn g2p(&mut self, text: &str) -> Result<Vec<String>, GSVError> {
        // Fast path: curated word-level overrides.
        if let Some(v) = dict::en_word_dict(text) {
            return Ok(v.to_owned());
        }
        match &mut self.model {
            Some(model) => {
                // Neural path: run the seq2seq model on each word.
                let words = text.split_whitespace();
                let mut phonemes = Vec::new();
                for word in words {
                    let phones = model.get_phoneme(word)?;
                    phonemes.extend(phones);
                }
                Ok(phonemes)
            }
            None => {
                // Dictionary path: split text into words and look each up in CMUdict.
                let words = text.split_whitespace();
                let mut phonemes = Vec::new();
                for word in words {
                    if let Some(phones) = self.arpabet.get_polyphone_str(word) {
                        phonemes.extend(phones.iter().map(|&p| p.to_string()));
                    } else {
                        // Fall back to character-level lookup for out-of-vocabulary words.
                        for c in word.chars() {
                            let c_str = c.to_string();
                            if let Some(phones) = self.arpabet.get_polyphone_str(&c_str) {
                                phonemes.extend(phones.iter().map(|&p| p.to_string()));
                            } else {
                                // No pronunciation found; pass the character through as-is.
                                phonemes.push(c_str);
                            }
                        }
                    }
                }
                Ok(phonemes)
            }
        }
    }
}