1use {
2 crate::{error::GSVError, onnx_builder::create_onnx_cpu_session, text::utils::BERT_TOKENIZER},
3 log::{debug, warn},
4 ndarray::{Array1, Array2, Axis, concatenate},
5 ort::{inputs, value::Tensor},
6 std::{path::Path, str::FromStr, sync::Arc},
7 tokenizers::Tokenizer,
8};
9
10#[derive(Debug)]
11pub struct BertModel {
12 model: Option<ort::session::Session>,
13 tokenizers: Option<Arc<tokenizers::Tokenizer>>,
14}
15
16impl BertModel {
17 pub fn new<P: AsRef<Path>>(path: Option<P>) -> Result<Self, GSVError> {
18 let mut model = None;
19 if let Some(path) = path {
20 model = Some(create_onnx_cpu_session(path)?);
21 }
22 Ok(Self {
23 model: model,
24 tokenizers: Some(Arc::new(Tokenizer::from_str(BERT_TOKENIZER).unwrap())),
25 })
26 }
27
28 pub fn get_bert(
29 &mut self,
30 text: &str,
31 word2ph: &[i32],
32 total_phones: usize,
33 ) -> Result<Array2<f32>, GSVError> {
34 if self.model.is_some() && self.tokenizers.is_some() {
35 let tmp = self.get_real_bert(text, word2ph)?;
36 debug!("use real bert, {}", text);
37 if tmp.shape()[0] != total_phones {
38 warn!(
39 "tmp.shape()[0]: {} != total_phones: {}, use empty",
40 tmp.shape()[0],
41 total_phones
42 );
43 return Ok(self.get_fake_bert(total_phones));
44 }
45 Ok(tmp)
46 } else {
47 debug!("use empty bert, {}", text);
48 Ok(self.get_fake_bert(total_phones))
49 }
50 }
51
52 fn get_real_bert(&mut self, text: &str, word2ph: &[i32]) -> Result<Array2<f32>, GSVError> {
53 let tokenizer = self.tokenizers.as_ref().unwrap();
54 let session = self.model.as_mut().unwrap();
55
56 let encoding = tokenizer.encode(text, true).unwrap();
57 let (input_ids, attention_mask, token_type_ids): (Vec<i64>, Vec<i64>, Vec<i64>) = (
58 encoding.get_ids().iter().map(|&id| id as i64).collect(),
59 encoding
60 .get_attention_mask()
61 .iter()
62 .map(|&m| m as i64)
63 .collect(),
64 encoding.get_type_ids().iter().map(|&t| t as i64).collect(),
65 );
66
67 let inputs = inputs![
68 "input_ids" => Tensor::from_array(Array2::from_shape_vec((1, input_ids.len()), input_ids).unwrap()).unwrap(),
69 "attention_mask" => Tensor::from_array(Array2::from_shape_vec((1, attention_mask.len()), attention_mask).unwrap()).unwrap(),
70 "token_type_ids" => Tensor::from_array(Array2::from_shape_vec((1, token_type_ids.len()), token_type_ids).unwrap()).unwrap()
71 ];
72
73 let bert_out = session.run(inputs)?;
74 let bert_feature = bert_out["bert_feature"]
75 .try_extract_array::<f32>()?
76 .to_owned();
77
78 let bert_feature_2d: Array2<f32> = bert_feature.into_dimensionality()?;
79
80 Ok(build_phone_level_feature(
81 bert_feature_2d,
82 Array1::from_vec(word2ph.to_vec()),
83 ))
84 }
85
86 fn get_fake_bert(&self, total_phones: usize) -> Array2<f32> {
87 Array2::<f32>::zeros((total_phones, 1024))
89 }
90}
91
92fn build_phone_level_feature(res: Array2<f32>, word2ph: Array1<i32>) -> Array2<f32> {
95 let phone_level_features = word2ph
96 .into_iter()
97 .enumerate()
98 .map(|(i, count)| {
99 if i < res.dim().0 {
100 let row = res.row(i);
101 Array2::from_shape_fn((count as usize, res.ncols()), |(_j, k)| row[k])
102 } else {
103 let last_row = res.row(res.dim().0 - 1);
105 Array2::from_shape_fn((count as usize, res.ncols()), |(_j, k)| last_row[k])
106 }
107 })
108 .collect::<Vec<_>>();
109
110 concatenate(
111 Axis(0),
112 &phone_level_features
113 .iter()
114 .map(|x| x.view())
115 .collect::<Vec<_>>(),
116 )
117 .unwrap()
118}