1use crate::error::{Error, Result};
2use crate::execution::ModelConfig as ExecutionConfig;
3use ndarray::{Array1, Array2, Array3, Array4};
4use ort::session::Session;
5use std::path::Path;
6
7#[derive(Default)]
10pub struct EncoderCache {
11 pub cache_last_channel: Array4<f32>,
13 pub cache_last_time: Array4<f32>,
15 pub cache_last_channel_len: Array1<i64>,
17}
18
19impl EncoderCache {
20 pub fn new() -> Self {
22 Self {
23 cache_last_channel: Array4::zeros((17, 1, 70, 512)),
24 cache_last_time: Array4::zeros((17, 1, 512, 8)),
25 cache_last_channel_len: Array1::from_vec(vec![0i64]),
26 }
27 }
28}
29
30pub struct ParakeetEOUModel {
31 encoder: Session,
32 decoder_joint: Session,
33}
34
35impl ParakeetEOUModel {
36 pub fn from_pretrained<P: AsRef<Path>>(
37 model_dir: P,
38 exec_config: ExecutionConfig,
39 ) -> Result<Self> {
40 let model_dir = model_dir.as_ref();
41
42 let encoder_path = model_dir.join("encoder.onnx");
43 let decoder_path = model_dir.join("decoder_joint.onnx");
44
45 if !encoder_path.exists() || !decoder_path.exists() {
46 return Err(Error::Config(format!(
47 "Missing ONNX files in {}. Expected encoder.onnx and decoder_joint.onnx",
48 model_dir.display()
49 )));
50 }
51
52 let builder = Session::builder()?;
54 let builder = exec_config.apply_to_session_builder(builder)?;
55 let encoder = builder.commit_from_file(&encoder_path)?;
56
57 let builder = Session::builder()?;
59 let builder = exec_config.apply_to_session_builder(builder)?;
60 let decoder_joint = builder.commit_from_file(&decoder_path)?;
61
62 Ok(Self {
63 encoder,
64 decoder_joint,
65 })
66 }
67
68 pub fn run_encoder(
72 &mut self,
73 features: &Array3<f32>,
74 length: i64,
75 cache: &EncoderCache,
76 ) -> Result<(Array3<f32>, EncoderCache)> {
77 let length_arr = Array1::from_vec(vec![length]);
78
79 let outputs = self.encoder.run(ort::inputs![
80 "audio_signal" => ort::value::Value::from_array(features.clone())?,
81 "length" => ort::value::Value::from_array(length_arr)?,
82 "cache_last_channel" => ort::value::Value::from_array(cache.cache_last_channel.clone())?,
83 "cache_last_time" => ort::value::Value::from_array(cache.cache_last_time.clone())?,
84 "cache_last_channel_len" => ort::value::Value::from_array(cache.cache_last_channel_len.clone())?
85 ])?;
86
87 let (shape, data) = outputs["outputs"]
89 .try_extract_tensor::<f32>()
90 .map_err(|e| Error::Model(format!("Failed to extract encoder output: {e}")))?;
91
92 let shape_dims = shape.as_ref();
93 let b = shape_dims[0] as usize;
94 let d = shape_dims[1] as usize;
95 let t = shape_dims[2] as usize;
96
97 let encoder_out = Array3::from_shape_vec((b, d, t), data.to_vec())
98 .map_err(|e| Error::Model(format!("Failed to reshape encoder output: {e}")))?;
99
100 let (ch_shape, ch_data) = outputs["new_cache_last_channel"]
102 .try_extract_tensor::<f32>()
103 .map_err(|e| Error::Model(format!("Failed to extract cache_last_channel: {e}")))?;
104
105 let (tm_shape, tm_data) = outputs["new_cache_last_time"]
106 .try_extract_tensor::<f32>()
107 .map_err(|e| Error::Model(format!("Failed to extract cache_last_time: {e}")))?;
108
109 let (len_shape, len_data) = outputs["new_cache_last_channel_len"]
110 .try_extract_tensor::<i64>()
111 .map_err(|e| Error::Model(format!("Failed to extract cache_len: {e}")))?;
112
113 let new_cache = EncoderCache {
115 cache_last_channel: Array4::from_shape_vec(
116 (
117 ch_shape[0] as usize,
118 ch_shape[1] as usize,
119 ch_shape[2] as usize,
120 ch_shape[3] as usize,
121 ),
122 ch_data.to_vec(),
123 )
124 .map_err(|e| Error::Model(format!("Failed to reshape cache_last_channel: {e}")))?,
125
126 cache_last_time: Array4::from_shape_vec(
127 (
128 tm_shape[0] as usize,
129 tm_shape[1] as usize,
130 tm_shape[2] as usize,
131 tm_shape[3] as usize,
132 ),
133 tm_data.to_vec(),
134 )
135 .map_err(|e| Error::Model(format!("Failed to reshape cache_last_time: {e}")))?,
136
137 cache_last_channel_len: Array1::from_shape_vec(
138 len_shape[0] as usize,
139 len_data.to_vec(),
140 )
141 .map_err(|e| Error::Model(format!("Failed to reshape cache_len: {e}")))?,
142 };
143
144 Ok((encoder_out, new_cache))
145 }
146
147 pub fn run_decoder(
150 &mut self,
151 encoder_frame: &Array3<f32>, last_token: &Array2<i32>, state_h: &Array3<f32>, state_c: &Array3<f32>, ) -> Result<(Array3<f32>, Array3<f32>, Array3<f32>)> {
156 let target_len = Array1::from_vec(vec![1i32]);
158
159 let outputs = self.decoder_joint.run(ort::inputs![
160 "encoder_outputs" => ort::value::Value::from_array(encoder_frame.clone())?,
161 "targets" => ort::value::Value::from_array(last_token.clone())?,
162 "target_length" => ort::value::Value::from_array(target_len)?,
163 "input_states_1" => ort::value::Value::from_array(state_h.clone())?,
164 "input_states_2" => ort::value::Value::from_array(state_c.clone())?
165 ])?;
166
167 let (l_shape, l_data) = outputs["outputs"]
169 .try_extract_tensor::<f32>()
170 .map_err(|e| Error::Model(format!("Failed to extract logits: {e}")))?;
171
172 let (_h_shape, h_data) = outputs["output_states_1"]
174 .try_extract_tensor::<f32>()
175 .map_err(|e| Error::Model(format!("Failed to extract state h: {e}")))?;
176
177 let (_c_shape, c_data) = outputs["output_states_2"]
178 .try_extract_tensor::<f32>()
179 .map_err(|e| Error::Model(format!("Failed to extract state c: {e}")))?;
180
181 let vocab_size = l_shape[3] as usize;
184 let logits = Array3::from_shape_vec((1, 1, vocab_size), l_data.to_vec())
185 .map_err(|e| Error::Model(format!("Reshape logits failed: {e}")))?;
186
187 let new_h = Array3::from_shape_vec((1, 1, 640), h_data.to_vec())
189 .map_err(|e| Error::Model(format!("Reshape state h failed: {e}")))?;
190
191 let new_c = Array3::from_shape_vec((1, 1, 640), c_data.to_vec())
192 .map_err(|e| Error::Model(format!("Reshape state c failed: {e}")))?;
193
194 Ok((logits, new_h, new_c))
195 }
196}