active_call/offline/sensevoice/
encoder.rs1use super::tokenizer::TokenDecoder;
2use anyhow::{Context, Result, anyhow, ensure};
3use ndarray::{Array3, Axis};
4use ort::{
5 session::{Session, builder::GraphOptimizationLevel},
6 value::{DynValue, Tensor},
7};
8use std::{fs, path::Path};
9use tracing::warn;
10
11pub struct SensevoiceEncoder {
12 session: Session,
13 input_names: Vec<String>,
14 decoder: TokenDecoder,
15}
16
17impl SensevoiceEncoder {
18 pub fn new<P: AsRef<Path>>(
19 model_path: P,
20 tokens_path: P,
21 intra_threads: usize,
22 ) -> Result<Self> {
23 let session = build_session_with_ort_cache(model_path.as_ref(), intra_threads)?;
24 let input_names = vec![
26 "x".to_string(),
27 "x_length".to_string(),
28 "language".to_string(),
29 "text_norm".to_string(),
30 ];
31 let decoder = TokenDecoder::new(tokens_path)?;
32
33 Ok(Self {
34 session,
35 input_names,
36 decoder,
37 })
38 }
39
40 pub fn run_and_decode(
41 &mut self,
42 feats: ndarray::ArrayView3<'_, f32>, language_id: i32,
44 use_itn: bool,
45 ) -> Result<String> {
46 let b = feats.len_of(Axis(0));
47 ensure!(b == 1, "batch=1 only");
48 let t = feats.len_of(Axis(1));
49 let d = feats.len_of(Axis(2));
50 ensure!(d == 560, "expect feature dim 560 but got {}", d);
51 ensure!(
52 language_id >= 0 && language_id < 16,
53 "invalid language id {language_id}"
54 );
55
56 let text_norm_idx = if use_itn { 14 } else { 15 };
57
58 let feats_owned = feats.to_owned();
60 let shape = feats_owned.shape().to_vec();
61 let (data, _offset) = feats_owned.into_raw_vec_and_offset();
62 let input_tensor = Tensor::from_array((shape.as_slice(), data))
63 .map_err(|e| anyhow!("ORT tensor error: {e}"))?;
64
65 let len_tensor = Tensor::from_array(([1], vec![t as i32]))
66 .map_err(|e| anyhow!("ORT tensor error: {e}"))?;
67
68 let lang_tensor = Tensor::from_array(([1], vec![language_id]))
69 .map_err(|e| anyhow!("ORT tensor error: {e}"))?;
70
71 let tn_tensor = Tensor::from_array(([1], vec![text_norm_idx as i32]))
72 .map_err(|e| anyhow!("ORT tensor error: {e}"))?;
73
74 let mut x_val = Some(input_tensor.into_dyn());
75 let mut len_val = Some(len_tensor.into_dyn());
76 let mut lang_val = Some(lang_tensor.into_dyn());
77 let mut tn_val = Some(tn_tensor.into_dyn());
78
79 let mut inputs: Vec<(String, DynValue)> = Vec::with_capacity(self.input_names.len());
80 for name in &self.input_names {
81 let value = match name.as_str() {
82 "x" => x_val
83 .take()
84 .ok_or_else(|| anyhow!("duplicate tensor binding for input 'x'"))?,
85 "x_length" => len_val
86 .take()
87 .ok_or_else(|| anyhow!("duplicate tensor binding for input 'x_length'"))?,
88 "language" => lang_val
89 .take()
90 .ok_or_else(|| anyhow!("duplicate tensor binding for input 'language'"))?,
91 "text_norm" => tn_val
92 .take()
93 .ok_or_else(|| anyhow!("duplicate tensor binding for input 'text_norm'"))?,
94 other => anyhow::bail!("unexpected encoder input '{other}'"),
95 };
96 inputs.push((name.clone(), value));
97 }
98
99 let outputs = self
100 .session
101 .run(inputs)
102 .map_err(|e| anyhow!("ORT run error: {e}"))?;
103 let logits_value = &outputs[0];
104 let (shape, data) = logits_value
105 .try_extract_tensor::<f32>()
106 .map_err(|e| anyhow!("ORT extract tensor error: {e}"))?;
107 let dims: Vec<usize> = shape.iter().map(|d| *d as usize).collect();
108 ensure!(dims.len() == 3, "unexpected logits rank: {:?}", dims);
109 ensure!(dims[0] == 1, "expect batch=1 but got {}", dims[0]);
110 let logits = Array3::from_shape_vec((dims[0], dims[1], dims[2]), data.to_vec())?;
111 let ids = argmax_and_unique(logits.index_axis(Axis(0), 0));
112 Ok(self.decoder.decode_ids(&ids))
113 }
114}
115
116fn build_session_with_ort_cache(model_path: &Path, intra_threads: usize) -> Result<Session> {
117 let ort_path = model_path.with_extension("ort");
118
119 if ort_path.exists() {
120 let session_attempt = Session::builder()
121 .map_err(|e| anyhow!("ORT session builder error: {e}"))?
122 .with_intra_threads(intra_threads)
123 .map_err(|e| anyhow!("ORT intra threads error: {e}"))?
124 .commit_from_file(&ort_path);
125
126 match session_attempt {
127 Ok(session) => return Ok(session),
128 Err(err) => {
129 warn!(
130 ort = %ort_path.display(),
131 model = %model_path.display(),
132 error = %err,
133 "failed to load cached ORT graph, regenerating"
134 );
135 let _ = fs::remove_file(&ort_path);
136 }
137 }
138 }
139
140 let builder = Session::builder()
141 .map_err(|e| anyhow!("ORT session builder error: {e}"))?
142 .with_optimization_level(GraphOptimizationLevel::Level2)
143 .map_err(|e| anyhow!("ORT optimization level error: {e}"))?
144 .with_intra_threads(intra_threads)
145 .map_err(|e| anyhow!("ORT intra threads error: {e}"))?;
146
147 if let Ok(builder_with_cache) = builder.with_optimized_model_path(&ort_path) {
148 match builder_with_cache.commit_from_file(model_path) {
149 Ok(session) => return Ok(session),
150 Err(err) => {
151 warn!(
152 ort = %ort_path.display(),
153 model = %model_path.display(),
154 error = %err,
155 "failed to build session with ORT cache, retrying without cache"
156 );
157 let _ = fs::remove_file(&ort_path);
158 }
159 }
160 }
161
162 let fallback_builder = Session::builder()
163 .map_err(|e| anyhow!("ORT session builder error: {e}"))?
164 .with_optimization_level(GraphOptimizationLevel::Level2)
165 .map_err(|e| anyhow!("ORT optimization level error: {e}"))?
166 .with_intra_threads(intra_threads)
167 .map_err(|e| anyhow!("ORT intra threads error: {e}"))?;
168
169 let model_bytes = fs::read(model_path)
170 .with_context(|| format!("read encoder model {}", model_path.display()))?;
171 fallback_builder
172 .commit_from_memory(&model_bytes)
173 .map_err(|e| anyhow!("ORT load model error: {e}"))
174}
175
176fn argmax_and_unique(logits: ndarray::ArrayView2<'_, f32>) -> Vec<i32> {
177 let blank_id = 0i32;
178 let mut prev: Option<i32> = None;
179 let mut out = Vec::new();
180 for t in 0..logits.len_of(Axis(0)) {
181 let row = logits.index_axis(Axis(0), t);
182 let mut maxv = f32::MIN;
183 let mut arg = 0i32;
184 for (i, v) in row.iter().enumerate() {
185 if *v > maxv {
186 maxv = *v;
187 arg = i as i32;
188 }
189 }
190 if Some(arg) != prev {
191 if arg != blank_id {
192 out.push(arg);
193 }
194 prev = Some(arg);
195 }
196 }
197 out
198}