Skip to main content

parakeet_rs/
model_unified.rs

1use crate::error::{Error, Result};
2use crate::execution::ModelConfig as ExecutionConfig;
3use ndarray::{Array1, Array2, Array3};
4use ort::session::Session;
5use std::path::{Path, PathBuf};
6
/// Model-shape parameters for a unified Parakeet model.
///
/// Values are supplied by the caller when loading the model; see the `Default`
/// implementation for the stock configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct UnifiedModelConfig {
    /// Vocabulary size — presumably the number of logits produced per decoder
    /// step; TODO confirm against the exported model.
    pub vocab_size: usize,
    /// Token id of the blank symbol.
    pub blank_id: usize,
    /// Width of the decoder LSTM — presumably the last axis of the recurrent
    /// state tensors; confirm against the model.
    pub decoder_lstm_dim: usize,
    /// Number of stacked decoder LSTM layers.
    pub decoder_lstm_layers: usize,
    /// Encoder time-subsampling factor (input frames per encoder frame,
    /// presumably — verify against the feature pipeline).
    pub subsampling_factor: usize,
}
15
16impl Default for UnifiedModelConfig {
17    fn default() -> Self {
18        Self {
19            vocab_size: 1025,
20            blank_id: 1024,
21            decoder_lstm_dim: 640,
22            decoder_lstm_layers: 2,
23            subsampling_factor: 8,
24        }
25    }
26}
27
/// A unified Parakeet model made of two ONNX Runtime sessions: an encoder and a
/// `decoder_joint` model, both created by
/// [`ParakeetUnifiedModel::from_pretrained`].
pub struct ParakeetUnifiedModel {
    // Encoder session; fed `audio_signal` and `length` by `run_encoder`.
    encoder: Session,
    // `decoder_joint` session; fed targets and recurrent states by `run_decoder`.
    decoder_joint: Session,
    /// Model-shape parameters passed in by the caller at load time.
    pub config: UnifiedModelConfig,
}
33
34impl ParakeetUnifiedModel {
35    pub fn from_pretrained<P: AsRef<Path>>(
36        model_dir: P,
37        exec_config: ExecutionConfig,
38        config: UnifiedModelConfig,
39    ) -> Result<Self> {
40        let model_dir = model_dir.as_ref();
41        let encoder_path = Self::find_encoder(model_dir)?;
42        let decoder_joint_path = Self::find_decoder_joint(model_dir)?;
43
44        let builder = Session::builder()?;
45        let mut builder = exec_config.apply_to_session_builder(builder)?;
46        let encoder = builder.commit_from_file(&encoder_path)?;
47
48        let builder = Session::builder()?;
49        let mut builder = exec_config.apply_to_session_builder(builder)?;
50        let decoder_joint = builder.commit_from_file(&decoder_joint_path)?;
51
52        Ok(Self {
53            encoder,
54            decoder_joint,
55            config,
56        })
57    }
58
59    fn find_encoder(dir: &Path) -> Result<PathBuf> {
60        let candidates = ["encoder.onnx", "encoder.int8.onnx", "encoder-model.onnx"];
61        for candidate in &candidates {
62            let path = dir.join(candidate);
63            if path.exists() {
64                return Ok(path);
65            }
66        }
67
68        Err(Error::Config(format!(
69            "No unified encoder model found in {}",
70            dir.display()
71        )))
72    }
73
74    fn find_decoder_joint(dir: &Path) -> Result<PathBuf> {
75        let candidates = [
76            "decoder_joint.onnx",
77            "decoder_joint.int8.onnx",
78            "decoder_joint-model.onnx",
79        ];
80        for candidate in &candidates {
81            let path = dir.join(candidate);
82            if path.exists() {
83                return Ok(path);
84            }
85        }
86
87        Err(Error::Config(format!(
88            "No unified decoder_joint model found in {}",
89            dir.display()
90        )))
91    }
92
93    pub fn run_encoder(&mut self, features: &Array2<f32>) -> Result<(Array3<f32>, i64)> {
94        let time_steps = features.shape()[0];
95        let feature_size = features.shape()[1];
96
97        let input = features
98            .t()
99            .to_shape((1, feature_size, time_steps))
100            .map_err(|e| Error::Model(format!("Failed to build encoder input: {e}")))?
101            .to_owned();
102
103        let input_length = Array1::from_vec(vec![time_steps as i64]);
104
105        let outputs = self.encoder.run(ort::inputs!(
106            "audio_signal" => ort::value::Value::from_array(input)?,
107            "length" => ort::value::Value::from_array(input_length)?
108        ))?;
109
110        let (shape, data) = outputs["outputs"]
111            .try_extract_tensor::<f32>()
112            .map_err(|e| Error::Model(format!("Failed to extract encoder output: {e}")))?;
113
114        let (_, lens_data) = outputs["encoded_lengths"]
115            .try_extract_tensor::<i64>()
116            .map_err(|e| Error::Model(format!("Failed to extract encoder lengths: {e}")))?;
117
118        let dims = shape.as_ref();
119        if dims.len() != 3 {
120            return Err(Error::Model(format!(
121                "Expected 3D encoder output, got shape: {dims:?}"
122            )));
123        }
124
125        let encoder_out = Array3::from_shape_vec(
126            (dims[0] as usize, dims[1] as usize, dims[2] as usize),
127            data.to_vec(),
128        )
129        .map_err(|e| Error::Model(format!("Failed to create encoder array: {e}")))?;
130
131        Ok((encoder_out, lens_data[0]))
132    }
133
134    pub fn run_decoder(
135        &mut self,
136        encoder_frame: &Array3<f32>,
137        target_token: i32,
138        state_1: &Array3<f32>,
139        state_2: &Array3<f32>,
140    ) -> Result<(Array1<f32>, Array3<f32>, Array3<f32>)> {
141        let targets = Array2::from_elem((1, 1), target_token);
142        let target_length = Array1::from_elem(1, 1i32);
143
144        let outputs = self.decoder_joint.run(ort::inputs![
145            "encoder_outputs" => ort::value::Value::from_array(encoder_frame.clone())?,
146            "targets" => ort::value::Value::from_array(targets)?,
147            "target_length" => ort::value::Value::from_array(target_length)?,
148            "input_states_1" => ort::value::Value::from_array(state_1.clone())?,
149            "input_states_2" => ort::value::Value::from_array(state_2.clone())?
150        ])?;
151
152        let (_, logits_data) = outputs["outputs"]
153            .try_extract_tensor::<f32>()
154            .map_err(|e| Error::Model(format!("Failed to extract logits: {e}")))?;
155
156        let logits = Array1::from_vec(logits_data.to_vec());
157
158        let (h_shape, h_data) = outputs["output_states_1"]
159            .try_extract_tensor::<f32>()
160            .map_err(|e| Error::Model(format!("Failed to extract state_1: {e}")))?;
161        let (c_shape, c_data) = outputs["output_states_2"]
162            .try_extract_tensor::<f32>()
163            .map_err(|e| Error::Model(format!("Failed to extract state_2: {e}")))?;
164
165        let new_state_1 = Array3::from_shape_vec(
166            (
167                h_shape[0] as usize,
168                h_shape[1] as usize,
169                h_shape[2] as usize,
170            ),
171            h_data.to_vec(),
172        )
173        .map_err(|e| Error::Model(format!("Failed to reshape state_1: {e}")))?;
174
175        let new_state_2 = Array3::from_shape_vec(
176            (
177                c_shape[0] as usize,
178                c_shape[1] as usize,
179                c_shape[2] as usize,
180            ),
181            c_data.to_vec(),
182        )
183        .map_err(|e| Error::Model(format!("Failed to reshape state_2: {e}")))?;
184
185        Ok((logits, new_state_1, new_state_2))
186    }
187}