//! parakeet_rs/model_nemotron.rs — Nemotron streaming ASR model wrapper.

1use crate::error::{Error, Result};
2use crate::execution::ModelConfig as ExecutionConfig;
3use ndarray::{Array1, Array2, Array3, Array4};
4use ort::session::Session;
5use std::path::Path;
6
/// Encoder cache state for Nemotron cache-aware streaming inference.
///
/// Carried across successive `run_encoder` calls so each audio chunk can
/// reuse left context from the previous chunks.
#[derive(Clone)]
pub struct NemotronEncoderCache {
    /// Shape `[num_layers, batch, left_context, hidden_dim]` —
    /// default `[24, 1, 70, 1024]`: 24 layers, batch=1, 70-frame lookback,
    /// 1024 features.
    pub cache_last_channel: Array4<f32>,
    /// Shape `[num_layers, batch, hidden_dim, conv_context]` —
    /// default `[24, 1, 1024, 8]`: 24 layers, batch=1, 1024 features,
    /// 8-frame convolution context.
    pub cache_last_time: Array4<f32>,
    /// Shape `[1]` — number of currently valid cached frames (starts at 0).
    pub cache_last_channel_len: Array1<i64>,
}
17
18impl Default for NemotronEncoderCache {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl NemotronEncoderCache {
25    pub fn new() -> Self {
26        Self {
27            cache_last_channel: Array4::zeros((24, 1, 70, 1024)),
28            cache_last_time: Array4::zeros((24, 1, 1024, 8)),
29            cache_last_channel_len: Array1::from_vec(vec![0i64]),
30        }
31    }
32
33    pub fn with_dims(
34        num_layers: usize,
35        left_context: usize,
36        hidden_dim: usize,
37        conv_context: usize,
38    ) -> Self {
39        Self {
40            cache_last_channel: Array4::zeros((num_layers, 1, left_context, hidden_dim)),
41            cache_last_time: Array4::zeros((num_layers, 1, hidden_dim, conv_context)),
42            cache_last_channel_len: Array1::from_vec(vec![0i64]),
43        }
44    }
45}
46
/// Nemotron ONNX wrapper.
///
/// Holds two separate ORT sessions mirroring the two-file ONNX export:
/// the streaming encoder (`encoder.onnx`) and the fused decoder/joint
/// network (`decoder_joint.onnx`).
pub struct NemotronModel {
    /// Cache-aware streaming encoder session.
    encoder: Session,
    /// Fused decoder + joint session, stepped once per emitted token.
    decoder_joint: Session,
    /// Model dimensions; public so callers can size states and caches.
    pub config: NemotronModelConfig,
}
54
/// Dimension configuration for a Nemotron model.
#[derive(Debug, Clone)]
pub struct NemotronModelConfig {
    /// Number of encoder layers (first axis of the encoder caches).
    pub num_encoder_layers: usize,
    /// Encoder hidden feature dimension.
    pub hidden_dim: usize,
    /// Attention left-context (lookback) frames kept in the encoder cache.
    pub left_context: usize,
    /// Convolution context frames kept in the encoder cache.
    pub conv_context: usize,
    /// Decoder LSTM hidden size per layer.
    pub decoder_lstm_dim: usize,
    /// Number of decoder LSTM layers.
    pub decoder_lstm_layers: usize,
    /// Token vocabulary size. NOTE(review): the default equals `blank_id`,
    /// which suggests blank sits just past the vocabulary — confirm against
    /// the exported joint head.
    pub vocab_size: usize,
    /// Token id the RNNT joint emits for blank.
    pub blank_id: usize,
}
67
68impl Default for NemotronModelConfig {
69    fn default() -> Self {
70        Self {
71            num_encoder_layers: 24,
72            hidden_dim: 1024,
73            left_context: 70,
74            conv_context: 8,
75            decoder_lstm_dim: 640,
76            decoder_lstm_layers: 2,
77            vocab_size: 1024,
78            blank_id: 1024,
79        }
80    }
81}
82
impl NemotronModel {
    /// Load a Nemotron model from a directory containing the two ONNX
    /// exports: `encoder.onnx` and `decoder_joint.onnx`.
    ///
    /// # Errors
    /// Returns `Error::Config` when either file is missing; propagates ORT
    /// errors from session building and loading.
    pub fn from_pretrained<P: AsRef<Path>>(
        model_dir: P,
        exec_config: ExecutionConfig,
        config: NemotronModelConfig,
    ) -> Result<Self> {
        let model_dir = model_dir.as_ref();

        let encoder_path = model_dir.join("encoder.onnx");
        let decoder_path = model_dir.join("decoder_joint.onnx");

        // Fail fast with a readable message instead of an opaque ORT error.
        if !encoder_path.exists() {
            return Err(Error::Config(format!(
                "Missing encoder.onnx in {}",
                model_dir.display()
            )));
        }
        if !decoder_path.exists() {
            return Err(Error::Config(format!(
                "Missing decoder_joint.onnx in {}",
                model_dir.display()
            )));
        }

        // Each session gets its own builder; `apply_to_session_builder`
        // applies the caller's execution settings (providers/threads).
        let builder = Session::builder()?;
        let builder = exec_config.apply_to_session_builder(builder)?;
        let encoder = builder.commit_from_file(&encoder_path)?;

        let builder = Session::builder()?;
        let builder = exec_config.apply_to_session_builder(builder)?;
        let decoder_joint = builder.commit_from_file(&decoder_path)?;

        Ok(Self {
            encoder,
            decoder_joint,
            config,
        })
    }

    /// Run one encoder step with cache-aware streaming.
    ///
    /// Inputs: mel features `[1, n_mels, time]`, the valid feature length,
    /// and the cache state returned by the previous call (or a fresh
    /// `NemotronEncoderCache` for the first chunk).
    ///
    /// Returns `(encoded, encoded_len, new_cache)`: `encoded` is
    /// `[1, hidden_dim, time]`, `encoded_len` the number of valid output
    /// frames, and `new_cache` must be fed to the next call.
    pub fn run_encoder(
        &mut self,
        features: &Array3<f32>,
        length: i64,
        cache: &NemotronEncoderCache,
    ) -> Result<(Array3<f32>, i64, NemotronEncoderCache)> {
        let length_arr = Array1::from_vec(vec![length]);

        // Input names must match the ONNX graph exactly. Arrays are cloned
        // because `from_array` takes ownership.
        let outputs = self.encoder.run(ort::inputs![
            "processed_signal" => ort::value::Value::from_array(features.clone())?,
            "processed_signal_length" => ort::value::Value::from_array(length_arr)?,
            "cache_last_channel" => ort::value::Value::from_array(cache.cache_last_channel.clone())?,
            "cache_last_time" => ort::value::Value::from_array(cache.cache_last_time.clone())?,
            "cache_last_channel_len" => ort::value::Value::from_array(cache.cache_last_channel_len.clone())?
        ])?;

        // Main output: encoded frames, [1, hidden_dim, time].
        let (shape, data) = outputs["encoded"]
            .try_extract_tensor::<f32>()
            .map_err(|e| Error::Model(format!("Failed to extract encoder output: {e}")))?;

        let shape_dims = shape.as_ref();
        let b = shape_dims[0] as usize;
        let d = shape_dims[1] as usize;
        let t = shape_dims[2] as usize;

        let encoder_out = Array3::from_shape_vec((b, d, t), data.to_vec())
            .map_err(|e| Error::Model(format!("Failed to reshape encoder output: {e}")))?;

        // Valid output frame count, delivered as a 1-element i64 tensor.
        let (_, enc_len_data) = outputs["encoded_len"]
            .try_extract_tensor::<i64>()
            .map_err(|e| Error::Model(format!("Failed to extract encoded_len: {e}")))?;
        let encoded_len = enc_len_data[0];

        // Updated cache tensors for the next chunk; shapes are taken from
        // the outputs rather than assumed, so non-default dims also work.
        let (ch_shape, ch_data) = outputs["cache_last_channel_next"]
            .try_extract_tensor::<f32>()
            .map_err(|e| Error::Model(format!("Failed to extract cache_last_channel: {e}")))?;

        let (tm_shape, tm_data) = outputs["cache_last_time_next"]
            .try_extract_tensor::<f32>()
            .map_err(|e| Error::Model(format!("Failed to extract cache_last_time: {e}")))?;

        let (len_shape, len_data) = outputs["cache_last_channel_len_next"]
            .try_extract_tensor::<i64>()
            .map_err(|e| Error::Model(format!("Failed to extract cache_len: {e}")))?;

        let new_cache = NemotronEncoderCache {
            cache_last_channel: Array4::from_shape_vec(
                (
                    ch_shape[0] as usize,
                    ch_shape[1] as usize,
                    ch_shape[2] as usize,
                    ch_shape[3] as usize,
                ),
                ch_data.to_vec(),
            )
            .map_err(|e| Error::Model(format!("Failed to reshape cache_last_channel: {e}")))?,

            cache_last_time: Array4::from_shape_vec(
                (
                    tm_shape[0] as usize,
                    tm_shape[1] as usize,
                    tm_shape[2] as usize,
                    tm_shape[3] as usize,
                ),
                tm_data.to_vec(),
            )
            .map_err(|e| Error::Model(format!("Failed to reshape cache_last_time: {e}")))?,

            cache_last_channel_len: Array1::from_shape_vec(
                len_shape[0] as usize,
                len_data.to_vec(),
            )
            .map_err(|e| Error::Model(format!("Failed to reshape cache_len: {e}")))?,
        };

        Ok((encoder_out, encoded_len, new_cache))
    }

    /// Run one fused decoder + joint step.
    ///
    /// Inputs: a single encoder frame `[1, hidden_dim, 1]`, the previous
    /// token id, and the two LSTM state tensors (each
    /// `[decoder_lstm_layers, 1, decoder_lstm_dim]`, i.e. `[2, 1, 640]`
    /// with the default config).
    ///
    /// Returns `(logits, new_state_1, new_state_2)` where `logits` is the
    /// flattened joint output (length `vocab_size` per the struct docs —
    /// whether blank is included depends on the export; see
    /// `NemotronModelConfig::blank_id`).
    pub fn run_decoder(
        &mut self,
        encoder_frame: &Array3<f32>, // [1, hidden_dim, 1]
        target_token: i32,
        state_1: &Array3<f32>, // [2, 1, 640]
        state_2: &Array3<f32>, // [2, 1, 640]
    ) -> Result<(Array1<f32>, Array3<f32>, Array3<f32>)> {
        // Single previous token as a [1, 1] batch with length 1.
        let targets = Array2::from_shape_vec((1, 1), vec![target_token])
            .map_err(|e| Error::Model(format!("Failed to create targets: {e}")))?;
        let target_len = Array1::from_vec(vec![1i32]);

        let outputs = self.decoder_joint.run(ort::inputs![
            "encoder_outputs" => ort::value::Value::from_array(encoder_frame.clone())?,
            "targets" => ort::value::Value::from_array(targets)?,
            "target_length" => ort::value::Value::from_array(target_len)?,
            "input_states_1" => ort::value::Value::from_array(state_1.clone())?,
            "input_states_2" => ort::value::Value::from_array(state_2.clone())?
        ])?;

        // Joint logits, flattened to 1-D regardless of the graph's output
        // rank (batch and time axes are both 1 here).
        let (_l_shape, l_data) = outputs["outputs"]
            .try_extract_tensor::<f32>()
            .map_err(|e| Error::Model(format!("Failed to extract logits: {e}")))?;

        let logits = Array1::from_vec(l_data.to_vec());

        // Updated LSTM states, reshaped from the graph-reported dims.
        let (h_shape, h_data) = outputs["output_states_1"]
            .try_extract_tensor::<f32>()
            .map_err(|e| Error::Model(format!("Failed to extract state_1: {e}")))?;

        let (c_shape, c_data) = outputs["output_states_2"]
            .try_extract_tensor::<f32>()
            .map_err(|e| Error::Model(format!("Failed to extract state_2: {e}")))?;

        let new_state_1 = Array3::from_shape_vec(
            (
                h_shape[0] as usize,
                h_shape[1] as usize,
                h_shape[2] as usize,
            ),
            h_data.to_vec(),
        )
        .map_err(|e| Error::Model(format!("Failed to reshape state_1: {e}")))?;

        let new_state_2 = Array3::from_shape_vec(
            (
                c_shape[0] as usize,
                c_shape[1] as usize,
                c_shape[2] as usize,
            ),
            c_data.to_vec(),
        )
        .map_err(|e| Error::Model(format!("Failed to reshape state_2: {e}")))?;

        Ok((logits, new_state_1, new_state_2))
    }
}