Skip to main content

parakeet_rs/
model_eou.rs

1use crate::error::{Error, Result};
2use crate::execution::ModelConfig as ExecutionConfig;
3use ndarray::{Array1, Array2, Array3, Array4};
4use ort::session::Session;
5use std::path::Path;
6
7/// Encoder cache state for streaming inference
8/// The cache maintains temporal context across chunks
9#[derive(Default)]
10pub struct EncoderCache {
11    /// channel cache: [1, 1, 70, 512] - batch=1, 70 frame lookback
12    pub cache_last_channel: Array4<f32>,
13    /// time cache: [1, 1, 512, 8] - batch=1, fixed 8 time steps
14    pub cache_last_time: Array4<f32>,
15    /// cache length: [1] with value 0 initially
16    pub cache_last_channel_len: Array1<i64>,
17}
18
19impl EncoderCache {
20    /// 17 layers, batch=1, 70 frame lookback, 512 features
21    pub fn new() -> Self {
22        Self {
23            cache_last_channel: Array4::zeros((17, 1, 70, 512)),
24            cache_last_time: Array4::zeros((17, 1, 512, 8)),
25            cache_last_channel_len: Array1::from_vec(vec![0i64]),
26        }
27    }
28}
29
30pub struct ParakeetEOUModel {
31    encoder: Session,
32    decoder_joint: Session,
33}
34
35impl ParakeetEOUModel {
36    pub fn from_pretrained<P: AsRef<Path>>(
37        model_dir: P,
38        exec_config: ExecutionConfig,
39    ) -> Result<Self> {
40        let model_dir = model_dir.as_ref();
41
42        let encoder_path = model_dir.join("encoder.onnx");
43        let decoder_path = model_dir.join("decoder_joint.onnx");
44
45        if !encoder_path.exists() || !decoder_path.exists() {
46            return Err(Error::Config(format!(
47                "Missing ONNX files in {}. Expected encoder.onnx and decoder_joint.onnx",
48                model_dir.display()
49            )));
50        }
51
52        // Load encoder
53        let builder = Session::builder()?;
54        let builder = exec_config.apply_to_session_builder(builder)?;
55        let encoder = builder.commit_from_file(&encoder_path)?;
56
57        // Load decoder
58        let builder = Session::builder()?;
59        let builder = exec_config.apply_to_session_builder(builder)?;
60        let decoder_joint = builder.commit_from_file(&decoder_path)?;
61
62        Ok(Self {
63            encoder,
64            decoder_joint,
65        })
66    }
67
68    /// Run the stateful encoder with cache
69    /// Input: features [1, 128, T], cache state
70    /// Output: (encoded [1, 512, T], new_cache)
71    pub fn run_encoder(
72        &mut self,
73        features: &Array3<f32>,
74        length: i64,
75        cache: &EncoderCache,
76    ) -> Result<(Array3<f32>, EncoderCache)> {
77        let length_arr = Array1::from_vec(vec![length]);
78
79        let outputs = self.encoder.run(ort::inputs![
80            "audio_signal" => ort::value::Value::from_array(features.clone())?,
81            "length" => ort::value::Value::from_array(length_arr)?,
82            "cache_last_channel" => ort::value::Value::from_array(cache.cache_last_channel.clone())?,
83            "cache_last_time" => ort::value::Value::from_array(cache.cache_last_time.clone())?,
84            "cache_last_channel_len" => ort::value::Value::from_array(cache.cache_last_channel_len.clone())?
85        ])?;
86
87        // Extract encoder output [1, 512, T]
88        let (shape, data) = outputs["outputs"]
89            .try_extract_tensor::<f32>()
90            .map_err(|e| Error::Model(format!("Failed to extract encoder output: {e}")))?;
91
92        let shape_dims = shape.as_ref();
93        let b = shape_dims[0] as usize;
94        let d = shape_dims[1] as usize;
95        let t = shape_dims[2] as usize;
96
97        let encoder_out = Array3::from_shape_vec((b, d, t), data.to_vec())
98            .map_err(|e| Error::Model(format!("Failed to reshape encoder output: {e}")))?;
99
100        // Extract new cache states
101        let (ch_shape, ch_data) = outputs["new_cache_last_channel"]
102            .try_extract_tensor::<f32>()
103            .map_err(|e| Error::Model(format!("Failed to extract cache_last_channel: {e}")))?;
104
105        let (tm_shape, tm_data) = outputs["new_cache_last_time"]
106            .try_extract_tensor::<f32>()
107            .map_err(|e| Error::Model(format!("Failed to extract cache_last_time: {e}")))?;
108
109        let (len_shape, len_data) = outputs["new_cache_last_channel_len"]
110            .try_extract_tensor::<i64>()
111            .map_err(|e| Error::Model(format!("Failed to extract cache_len: {e}")))?;
112
113        // Build new cache with extracted shapes
114        let new_cache = EncoderCache {
115            cache_last_channel: Array4::from_shape_vec(
116                (
117                    ch_shape[0] as usize,
118                    ch_shape[1] as usize,
119                    ch_shape[2] as usize,
120                    ch_shape[3] as usize,
121                ),
122                ch_data.to_vec(),
123            )
124            .map_err(|e| Error::Model(format!("Failed to reshape cache_last_channel: {e}")))?,
125
126            cache_last_time: Array4::from_shape_vec(
127                (
128                    tm_shape[0] as usize,
129                    tm_shape[1] as usize,
130                    tm_shape[2] as usize,
131                    tm_shape[3] as usize,
132                ),
133                tm_data.to_vec(),
134            )
135            .map_err(|e| Error::Model(format!("Failed to reshape cache_last_time: {e}")))?,
136
137            cache_last_channel_len: Array1::from_shape_vec(
138                len_shape[0] as usize,
139                len_data.to_vec(),
140            )
141            .map_err(|e| Error::Model(format!("Failed to reshape cache_len: {e}")))?,
142        };
143
144        Ok((encoder_out, new_cache))
145    }
146
147    /// Run the stateful decoder
148    /// Returns: (logits [1, 1, 1, vocab], new_state_h, new_state_c)
149    pub fn run_decoder(
150        &mut self,
151        encoder_frame: &Array3<f32>, // [1, 512, 1]
152        last_token: &Array2<i32>,    // [1, 1]
153        state_h: &Array3<f32>,       // [1, 1, 640]
154        state_c: &Array3<f32>,       // [1, 1, 640]
155    ) -> Result<(Array3<f32>, Array3<f32>, Array3<f32>)> {
156        // Target length is always 1 for single step
157        let target_len = Array1::from_vec(vec![1i32]);
158
159        let outputs = self.decoder_joint.run(ort::inputs![
160            "encoder_outputs" => ort::value::Value::from_array(encoder_frame.clone())?,
161            "targets" => ort::value::Value::from_array(last_token.clone())?,
162            "target_length" => ort::value::Value::from_array(target_len)?,
163            "input_states_1" => ort::value::Value::from_array(state_h.clone())?,
164            "input_states_2" => ort::value::Value::from_array(state_c.clone())?
165        ])?;
166
167        // 1. Extract Logits
168        let (l_shape, l_data) = outputs["outputs"]
169            .try_extract_tensor::<f32>()
170            .map_err(|e| Error::Model(format!("Failed to extract logits: {e}")))?;
171
172        // 2. Extract States (output_states_1, output_states_2)
173        let (_h_shape, h_data) = outputs["output_states_1"]
174            .try_extract_tensor::<f32>()
175            .map_err(|e| Error::Model(format!("Failed to extract state h: {e}")))?;
176
177        let (_c_shape, c_data) = outputs["output_states_2"]
178            .try_extract_tensor::<f32>()
179            .map_err(|e| Error::Model(format!("Failed to extract state c: {e}")))?;
180
181        // Reconstruct Arrays
182        // Logits: I simplify to [1, 1, vocab]
183        let vocab_size = l_shape[3] as usize;
184        let logits = Array3::from_shape_vec((1, 1, vocab_size), l_data.to_vec())
185            .map_err(|e| Error::Model(format!("Reshape logits failed: {e}")))?;
186
187        // States: [1, 1, 640]
188        let new_h = Array3::from_shape_vec((1, 1, 640), h_data.to_vec())
189            .map_err(|e| Error::Model(format!("Reshape state h failed: {e}")))?;
190
191        let new_c = Array3::from_shape_vec((1, 1, 640), c_data.to_vec())
192            .map_err(|e| Error::Model(format!("Reshape state c failed: {e}")))?;
193
194        Ok((logits, new_h, new_c))
195    }
196}