1use crate::error::{Error, Result};
2use crate::execution::ModelConfig as ExecutionConfig;
3use ndarray::{Array1, Array2, Array3, Array4};
4use ort::session::Session;
5use std::path::Path;
6
7#[derive(Clone)]
9pub struct NemotronEncoderCache {
10 pub cache_last_channel: Array4<f32>,
12 pub cache_last_time: Array4<f32>,
14 pub cache_last_channel_len: Array1<i64>,
16}
17
18impl Default for NemotronEncoderCache {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
24impl NemotronEncoderCache {
25 pub fn new() -> Self {
26 Self {
27 cache_last_channel: Array4::zeros((24, 1, 70, 1024)),
28 cache_last_time: Array4::zeros((24, 1, 1024, 8)),
29 cache_last_channel_len: Array1::from_vec(vec![0i64]),
30 }
31 }
32
33 pub fn with_dims(
34 num_layers: usize,
35 left_context: usize,
36 hidden_dim: usize,
37 conv_context: usize,
38 ) -> Self {
39 Self {
40 cache_last_channel: Array4::zeros((num_layers, 1, left_context, hidden_dim)),
41 cache_last_time: Array4::zeros((num_layers, 1, hidden_dim, conv_context)),
42 cache_last_channel_len: Array1::from_vec(vec![0i64]),
43 }
44 }
45}
46
/// The two ONNX Runtime sessions of the model plus the configuration it was
/// loaded with. Construct it with `NemotronModel::from_pretrained`.
pub struct NemotronModel {
    // Session loaded from `encoder.onnx`; driven by `run_encoder`.
    encoder: Session,
    // Session loaded from `decoder_joint.onnx`; driven by `run_decoder`.
    decoder_joint: Session,
    /// Configuration passed to `from_pretrained`, stored as-is. `run_encoder`
    /// and `run_decoder` do not read it.
    pub config: NemotronModelConfig,
}
54
/// Size parameters describing a Nemotron encoder and decoder/joint model.
///
/// `PartialEq`/`Eq` are derived so configurations can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NemotronModelConfig {
    /// Number of encoder layers; corresponds to `num_layers` in
    /// `NemotronEncoderCache::with_dims`.
    pub num_encoder_layers: usize,
    /// Encoder hidden width; corresponds to `hidden_dim` in
    /// `NemotronEncoderCache::with_dims`.
    pub hidden_dim: usize,
    /// Corresponds to `left_context` in `NemotronEncoderCache::with_dims`
    /// (presumably the number of cached left-context frames — confirm).
    pub left_context: usize,
    /// Corresponds to `conv_context` in `NemotronEncoderCache::with_dims`
    /// (presumably the cached convolution context length — confirm).
    pub conv_context: usize,
    /// Decoder LSTM width (presumably the last axis of the
    /// `input_states_1`/`input_states_2` tensors — confirm).
    pub decoder_lstm_dim: usize,
    /// Number of decoder LSTM layers.
    pub decoder_lstm_layers: usize,
    /// Number of vocabulary tokens.
    pub vocab_size: usize,
    /// Token id treated as blank. The default sets it equal to `vocab_size`
    /// (presumably one slot past the regular vocabulary — confirm).
    pub blank_id: usize,
}
67
68impl Default for NemotronModelConfig {
69 fn default() -> Self {
70 Self {
71 num_encoder_layers: 24,
72 hidden_dim: 1024,
73 left_context: 70,
74 conv_context: 8,
75 decoder_lstm_dim: 640,
76 decoder_lstm_layers: 2,
77 vocab_size: 1024,
78 blank_id: 1024,
79 }
80 }
81}
82
83impl NemotronModel {
84 pub fn from_pretrained<P: AsRef<Path>>(
85 model_dir: P,
86 exec_config: ExecutionConfig,
87 config: NemotronModelConfig,
88 ) -> Result<Self> {
89 let model_dir = model_dir.as_ref();
90
91 let encoder_path = model_dir.join("encoder.onnx");
92 let decoder_path = model_dir.join("decoder_joint.onnx");
93
94 if !encoder_path.exists() {
95 return Err(Error::Config(format!(
96 "Missing encoder.onnx in {}",
97 model_dir.display()
98 )));
99 }
100 if !decoder_path.exists() {
101 return Err(Error::Config(format!(
102 "Missing decoder_joint.onnx in {}",
103 model_dir.display()
104 )));
105 }
106
107 let builder = Session::builder()?;
108 let builder = exec_config.apply_to_session_builder(builder)?;
109 let encoder = builder.commit_from_file(&encoder_path)?;
110
111 let builder = Session::builder()?;
112 let builder = exec_config.apply_to_session_builder(builder)?;
113 let decoder_joint = builder.commit_from_file(&decoder_path)?;
114
115 Ok(Self {
116 encoder,
117 decoder_joint,
118 config,
119 })
120 }
121
122 pub fn run_encoder(
126 &mut self,
127 features: &Array3<f32>,
128 length: i64,
129 cache: &NemotronEncoderCache,
130 ) -> Result<(Array3<f32>, i64, NemotronEncoderCache)> {
131 let length_arr = Array1::from_vec(vec![length]);
132
133 let outputs = self.encoder.run(ort::inputs![
134 "processed_signal" => ort::value::Value::from_array(features.clone())?,
135 "processed_signal_length" => ort::value::Value::from_array(length_arr)?,
136 "cache_last_channel" => ort::value::Value::from_array(cache.cache_last_channel.clone())?,
137 "cache_last_time" => ort::value::Value::from_array(cache.cache_last_time.clone())?,
138 "cache_last_channel_len" => ort::value::Value::from_array(cache.cache_last_channel_len.clone())?
139 ])?;
140
141 let (shape, data) = outputs["encoded"]
143 .try_extract_tensor::<f32>()
144 .map_err(|e| Error::Model(format!("Failed to extract encoder output: {e}")))?;
145
146 let shape_dims = shape.as_ref();
147 let b = shape_dims[0] as usize;
148 let d = shape_dims[1] as usize;
149 let t = shape_dims[2] as usize;
150
151 let encoder_out = Array3::from_shape_vec((b, d, t), data.to_vec())
152 .map_err(|e| Error::Model(format!("Failed to reshape encoder output: {e}")))?;
153
154 let (_, enc_len_data) = outputs["encoded_len"]
156 .try_extract_tensor::<i64>()
157 .map_err(|e| Error::Model(format!("Failed to extract encoded_len: {e}")))?;
158 let encoded_len = enc_len_data[0];
159
160 let (ch_shape, ch_data) = outputs["cache_last_channel_next"]
161 .try_extract_tensor::<f32>()
162 .map_err(|e| Error::Model(format!("Failed to extract cache_last_channel: {e}")))?;
163
164 let (tm_shape, tm_data) = outputs["cache_last_time_next"]
165 .try_extract_tensor::<f32>()
166 .map_err(|e| Error::Model(format!("Failed to extract cache_last_time: {e}")))?;
167
168 let (len_shape, len_data) = outputs["cache_last_channel_len_next"]
169 .try_extract_tensor::<i64>()
170 .map_err(|e| Error::Model(format!("Failed to extract cache_len: {e}")))?;
171
172 let new_cache = NemotronEncoderCache {
173 cache_last_channel: Array4::from_shape_vec(
174 (
175 ch_shape[0] as usize,
176 ch_shape[1] as usize,
177 ch_shape[2] as usize,
178 ch_shape[3] as usize,
179 ),
180 ch_data.to_vec(),
181 )
182 .map_err(|e| Error::Model(format!("Failed to reshape cache_last_channel: {e}")))?,
183
184 cache_last_time: Array4::from_shape_vec(
185 (
186 tm_shape[0] as usize,
187 tm_shape[1] as usize,
188 tm_shape[2] as usize,
189 tm_shape[3] as usize,
190 ),
191 tm_data.to_vec(),
192 )
193 .map_err(|e| Error::Model(format!("Failed to reshape cache_last_time: {e}")))?,
194
195 cache_last_channel_len: Array1::from_shape_vec(
196 len_shape[0] as usize,
197 len_data.to_vec(),
198 )
199 .map_err(|e| Error::Model(format!("Failed to reshape cache_len: {e}")))?,
200 };
201
202 Ok((encoder_out, encoded_len, new_cache))
203 }
204
205 pub fn run_decoder(
208 &mut self,
209 encoder_frame: &Array3<f32>, target_token: i32,
211 state_1: &Array3<f32>, state_2: &Array3<f32>, ) -> Result<(Array1<f32>, Array3<f32>, Array3<f32>)> {
214 let targets = Array2::from_shape_vec((1, 1), vec![target_token])
215 .map_err(|e| Error::Model(format!("Failed to create targets: {e}")))?;
216 let target_len = Array1::from_vec(vec![1i32]);
217
218 let outputs = self.decoder_joint.run(ort::inputs![
219 "encoder_outputs" => ort::value::Value::from_array(encoder_frame.clone())?,
220 "targets" => ort::value::Value::from_array(targets)?,
221 "target_length" => ort::value::Value::from_array(target_len)?,
222 "input_states_1" => ort::value::Value::from_array(state_1.clone())?,
223 "input_states_2" => ort::value::Value::from_array(state_2.clone())?
224 ])?;
225
226 let (_l_shape, l_data) = outputs["outputs"]
228 .try_extract_tensor::<f32>()
229 .map_err(|e| Error::Model(format!("Failed to extract logits: {e}")))?;
230
231 let logits = Array1::from_vec(l_data.to_vec());
232
233 let (h_shape, h_data) = outputs["output_states_1"]
234 .try_extract_tensor::<f32>()
235 .map_err(|e| Error::Model(format!("Failed to extract state_1: {e}")))?;
236
237 let (c_shape, c_data) = outputs["output_states_2"]
238 .try_extract_tensor::<f32>()
239 .map_err(|e| Error::Model(format!("Failed to extract state_2: {e}")))?;
240
241 let new_state_1 = Array3::from_shape_vec(
242 (
243 h_shape[0] as usize,
244 h_shape[1] as usize,
245 h_shape[2] as usize,
246 ),
247 h_data.to_vec(),
248 )
249 .map_err(|e| Error::Model(format!("Failed to reshape state_1: {e}")))?;
250
251 let new_state_2 = Array3::from_shape_vec(
252 (
253 c_shape[0] as usize,
254 c_shape[1] as usize,
255 c_shape[2] as usize,
256 ),
257 c_data.to_vec(),
258 )
259 .map_err(|e| Error::Model(format!("Failed to reshape state_2: {e}")))?;
260
261 Ok((logits, new_state_1, new_state_2))
262 }
263}