gpt_sovits/text/
bert.rs

1use {
2    crate::{error::GSVError, onnx_builder::create_onnx_cpu_session, text::utils::BERT_TOKENIZER},
3    log::{debug, warn},
4    ndarray::{Array1, Array2, Axis, concatenate},
5    ort::{inputs, value::Tensor},
6    std::{path::Path, str::FromStr, sync::Arc},
7    tokenizers::Tokenizer,
8};
9
10#[derive(Debug)]
11pub struct BertModel {
12    model: Option<ort::session::Session>,
13    tokenizers: Option<Arc<tokenizers::Tokenizer>>,
14}
15
16impl BertModel {
17    pub fn new<P: AsRef<Path>>(path: Option<P>) -> Result<Self, GSVError> {
18        let mut model = None;
19        if let Some(path) = path {
20            model = Some(create_onnx_cpu_session(path)?);
21        }
22        Ok(Self {
23            model: model,
24            tokenizers: Some(Arc::new(Tokenizer::from_str(BERT_TOKENIZER).unwrap())),
25        })
26    }
27
28    pub fn get_bert(
29        &mut self,
30        text: &str,
31        word2ph: &[i32],
32        total_phones: usize,
33    ) -> Result<Array2<f32>, GSVError> {
34        if self.model.is_some() && self.tokenizers.is_some() {
35            let tmp = self.get_real_bert(text, word2ph)?;
36            debug!("use real bert, {}", text);
37            if tmp.shape()[0] != total_phones {
38                warn!(
39                    "tmp.shape()[0]: {} != total_phones: {}, use empty",
40                    tmp.shape()[0],
41                    total_phones
42                );
43                return Ok(self.get_fake_bert(total_phones));
44            }
45            Ok(tmp)
46        } else {
47            debug!("use empty bert, {}", text);
48            Ok(self.get_fake_bert(total_phones))
49        }
50    }
51
52    fn get_real_bert(&mut self, text: &str, word2ph: &[i32]) -> Result<Array2<f32>, GSVError> {
53        let tokenizer = self.tokenizers.as_ref().unwrap();
54        let session = self.model.as_mut().unwrap();
55
56        let encoding = tokenizer.encode(text, true).unwrap();
57        let (input_ids, attention_mask, token_type_ids): (Vec<i64>, Vec<i64>, Vec<i64>) = (
58            encoding.get_ids().iter().map(|&id| id as i64).collect(),
59            encoding
60                .get_attention_mask()
61                .iter()
62                .map(|&m| m as i64)
63                .collect(),
64            encoding.get_type_ids().iter().map(|&t| t as i64).collect(),
65        );
66
67        let inputs = inputs![
68            "input_ids" => Tensor::from_array(Array2::from_shape_vec((1, input_ids.len()), input_ids).unwrap()).unwrap(),
69            "attention_mask" => Tensor::from_array(Array2::from_shape_vec((1, attention_mask.len()), attention_mask).unwrap()).unwrap(),
70            "token_type_ids" => Tensor::from_array(Array2::from_shape_vec((1, token_type_ids.len()), token_type_ids).unwrap()).unwrap()
71        ];
72
73        let bert_out = session.run(inputs)?;
74        let bert_feature = bert_out["bert_feature"]
75            .try_extract_array::<f32>()?
76            .to_owned();
77
78        let bert_feature_2d: Array2<f32> = bert_feature.into_dimensionality()?;
79
80        Ok(build_phone_level_feature(
81            bert_feature_2d,
82            Array1::from_vec(word2ph.to_vec()),
83        ))
84    }
85
86    fn get_fake_bert(&self, total_phones: usize) -> Array2<f32> {
87        // The BERT model outputs features of size 1024
88        Array2::<f32>::zeros((total_phones, 1024))
89    }
90}
91
92// Helper function to expand word-level features to phone-level features.
93// This function is required by get_real_bert.
94fn build_phone_level_feature(res: Array2<f32>, word2ph: Array1<i32>) -> Array2<f32> {
95    let phone_level_features = word2ph
96        .into_iter()
97        .enumerate()
98        .map(|(i, count)| {
99            if i < res.dim().0 {
100                let row = res.row(i);
101                Array2::from_shape_fn((count as usize, res.ncols()), |(_j, k)| row[k])
102            } else {
103                // If word2ph has more elements than res rows, duplicate the last feature.
104                let last_row = res.row(res.dim().0 - 1);
105                Array2::from_shape_fn((count as usize, res.ncols()), |(_j, k)| last_row[k])
106            }
107        })
108        .collect::<Vec<_>>();
109
110    concatenate(
111        Axis(0),
112        &phone_level_features
113            .iter()
114            .map(|x| x.view())
115            .collect::<Vec<_>>(),
116    )
117    .unwrap()
118}