// active_call/offline/sensevoice/encoder.rs

1use super::tokenizer::TokenDecoder;
2use anyhow::{Context, Result, anyhow, ensure};
3use ndarray::{Array3, Axis};
4use ort::{
5    session::{Session, builder::GraphOptimizationLevel},
6    value::{DynValue, Tensor},
7};
8use std::{fs, path::Path};
9use tracing::warn;
10
11pub struct SensevoiceEncoder {
12    session: Session,
13    input_names: Vec<String>,
14    decoder: TokenDecoder,
15}
16
17impl SensevoiceEncoder {
18    pub fn new<P: AsRef<Path>>(
19        model_path: P,
20        tokens_path: P,
21        intra_threads: usize,
22    ) -> Result<Self> {
23        let session = build_session_with_ort_cache(model_path.as_ref(), intra_threads)?;
24        // Use fixed input names for SenseVoice model
25        let input_names = vec![
26            "x".to_string(),
27            "x_length".to_string(),
28            "language".to_string(),
29            "text_norm".to_string(),
30        ];
31        let decoder = TokenDecoder::new(tokens_path)?;
32
33        Ok(Self {
34            session,
35            input_names,
36            decoder,
37        })
38    }
39
40    pub fn run_and_decode(
41        &mut self,
42        feats: ndarray::ArrayView3<'_, f32>, // [B=1, T, D]
43        language_id: i32,
44        use_itn: bool,
45    ) -> Result<String> {
46        let b = feats.len_of(Axis(0));
47        ensure!(b == 1, "batch=1 only");
48        let t = feats.len_of(Axis(1));
49        let d = feats.len_of(Axis(2));
50        ensure!(d == 560, "expect feature dim 560 but got {}", d);
51        ensure!(
52            language_id >= 0 && language_id < 16,
53            "invalid language id {language_id}"
54        );
55
56        let text_norm_idx = if use_itn { 14 } else { 15 };
57
58        // Create tensors using (shape, data) tuple format
59        let feats_owned = feats.to_owned();
60        let shape = feats_owned.shape().to_vec();
61        let (data, _offset) = feats_owned.into_raw_vec_and_offset();
62        let input_tensor = Tensor::from_array((shape.as_slice(), data))
63            .map_err(|e| anyhow!("ORT tensor error: {e}"))?;
64
65        let len_tensor = Tensor::from_array(([1], vec![t as i32]))
66            .map_err(|e| anyhow!("ORT tensor error: {e}"))?;
67
68        let lang_tensor = Tensor::from_array(([1], vec![language_id]))
69            .map_err(|e| anyhow!("ORT tensor error: {e}"))?;
70
71        let tn_tensor = Tensor::from_array(([1], vec![text_norm_idx as i32]))
72            .map_err(|e| anyhow!("ORT tensor error: {e}"))?;
73
74        let mut x_val = Some(input_tensor.into_dyn());
75        let mut len_val = Some(len_tensor.into_dyn());
76        let mut lang_val = Some(lang_tensor.into_dyn());
77        let mut tn_val = Some(tn_tensor.into_dyn());
78
79        let mut inputs: Vec<(String, DynValue)> = Vec::with_capacity(self.input_names.len());
80        for name in &self.input_names {
81            let value = match name.as_str() {
82                "x" => x_val
83                    .take()
84                    .ok_or_else(|| anyhow!("duplicate tensor binding for input 'x'"))?,
85                "x_length" => len_val
86                    .take()
87                    .ok_or_else(|| anyhow!("duplicate tensor binding for input 'x_length'"))?,
88                "language" => lang_val
89                    .take()
90                    .ok_or_else(|| anyhow!("duplicate tensor binding for input 'language'"))?,
91                "text_norm" => tn_val
92                    .take()
93                    .ok_or_else(|| anyhow!("duplicate tensor binding for input 'text_norm'"))?,
94                other => anyhow::bail!("unexpected encoder input '{other}'"),
95            };
96            inputs.push((name.clone(), value));
97        }
98
99        let outputs = self
100            .session
101            .run(inputs)
102            .map_err(|e| anyhow!("ORT run error: {e}"))?;
103        let logits_value = &outputs[0];
104        let (shape, data) = logits_value
105            .try_extract_tensor::<f32>()
106            .map_err(|e| anyhow!("ORT extract tensor error: {e}"))?;
107        let dims: Vec<usize> = shape.iter().map(|d| *d as usize).collect();
108        ensure!(dims.len() == 3, "unexpected logits rank: {:?}", dims);
109        ensure!(dims[0] == 1, "expect batch=1 but got {}", dims[0]);
110        let logits = Array3::from_shape_vec((dims[0], dims[1], dims[2]), data.to_vec())?;
111        let ids = argmax_and_unique(logits.index_axis(Axis(0), 0));
112        Ok(self.decoder.decode_ids(&ids))
113    }
114}
115
116fn build_session_with_ort_cache(model_path: &Path, intra_threads: usize) -> Result<Session> {
117    let ort_path = model_path.with_extension("ort");
118
119    if ort_path.exists() {
120        let session_attempt = Session::builder()
121            .map_err(|e| anyhow!("ORT session builder error: {e}"))?
122            .with_intra_threads(intra_threads)
123            .map_err(|e| anyhow!("ORT intra threads error: {e}"))?
124            .commit_from_file(&ort_path);
125
126        match session_attempt {
127            Ok(session) => return Ok(session),
128            Err(err) => {
129                warn!(
130                    ort = %ort_path.display(),
131                    model = %model_path.display(),
132                    error = %err,
133                    "failed to load cached ORT graph, regenerating"
134                );
135                let _ = fs::remove_file(&ort_path);
136            }
137        }
138    }
139
140    let builder = Session::builder()
141        .map_err(|e| anyhow!("ORT session builder error: {e}"))?
142        .with_optimization_level(GraphOptimizationLevel::Level2)
143        .map_err(|e| anyhow!("ORT optimization level error: {e}"))?
144        .with_intra_threads(intra_threads)
145        .map_err(|e| anyhow!("ORT intra threads error: {e}"))?;
146
147    if let Ok(builder_with_cache) = builder.with_optimized_model_path(&ort_path) {
148        match builder_with_cache.commit_from_file(model_path) {
149            Ok(session) => return Ok(session),
150            Err(err) => {
151                warn!(
152                    ort = %ort_path.display(),
153                    model = %model_path.display(),
154                    error = %err,
155                    "failed to build session with ORT cache, retrying without cache"
156                );
157                let _ = fs::remove_file(&ort_path);
158            }
159        }
160    }
161
162    let fallback_builder = Session::builder()
163        .map_err(|e| anyhow!("ORT session builder error: {e}"))?
164        .with_optimization_level(GraphOptimizationLevel::Level2)
165        .map_err(|e| anyhow!("ORT optimization level error: {e}"))?
166        .with_intra_threads(intra_threads)
167        .map_err(|e| anyhow!("ORT intra threads error: {e}"))?;
168
169    let model_bytes = fs::read(model_path)
170        .with_context(|| format!("read encoder model {}", model_path.display()))?;
171    fallback_builder
172        .commit_from_memory(&model_bytes)
173        .map_err(|e| anyhow!("ORT load model error: {e}"))
174}
175
176fn argmax_and_unique(logits: ndarray::ArrayView2<'_, f32>) -> Vec<i32> {
177    let blank_id = 0i32;
178    let mut prev: Option<i32> = None;
179    let mut out = Vec::new();
180    for t in 0..logits.len_of(Axis(0)) {
181        let row = logits.index_axis(Axis(0), t);
182        let mut maxv = f32::MIN;
183        let mut arg = 0i32;
184        for (i, v) in row.iter().enumerate() {
185            if *v > maxv {
186                maxv = *v;
187                arg = i as i32;
188            }
189        }
190        if Some(arg) != prev {
191            if arg != blank_id {
192                out.push(arg);
193            }
194            prev = Some(arg);
195        }
196    }
197    out
198}