Skip to main content

parakeet_rs/
model.rs

1use crate::config::ModelConfig;
2use crate::error::{Error, Result};
3use crate::execution::ModelConfig as ExecutionConfig;
4use ndarray::Array2;
5use ort::session::Session;
6use std::path::Path;
7
8pub struct ParakeetModel {
9    session: Session,
10    config: ModelConfig,
11}
12
13impl ParakeetModel {
14    pub fn from_pretrained<P: AsRef<Path>>(model_path: P) -> Result<Self> {
15        Self::from_pretrained_with_config(model_path, ExecutionConfig::default())
16    }
17
18    pub fn from_pretrained_with_config<P: AsRef<Path>>(
19        model_path: P,
20        exec_config: ExecutionConfig,
21    ) -> Result<Self> {
22        let model_path = model_path.as_ref();
23
24        // Use default config (hardcoded constants for Parakeet-CTC-0.6b: please see: json files https://huggingface.co/onnx-community/parakeet-ctc-0.6b-ONNX/tree/main)
25        let config = ModelConfig::default();
26
27        let builder = Session::builder()?;
28        let builder = exec_config.apply_to_session_builder(builder)?;
29        let session = builder.commit_from_file(model_path)?;
30
31        Ok(Self { session, config })
32    }
33    pub fn forward(&mut self, features: Array2<f32>) -> Result<Array2<f32>> {
34        let batch_size = 1;
35        let time_steps = features.shape()[0];
36        let feature_size = features.shape()[1];
37
38        let input = features
39            .to_shape((batch_size, time_steps, feature_size))
40            .map_err(|e| Error::Model(format!("Failed to reshape input: {e}")))?
41            .to_owned();
42
43        use ndarray::Array2;
44        let attention_mask = Array2::<i64>::ones((batch_size, time_steps));
45
46        let input_value = ort::value::Value::from_array(input)?;
47        let attention_mask_value = ort::value::Value::from_array(attention_mask)?;
48
49        let outputs = self.session.run(ort::inputs!(
50            "input_features" => input_value,
51            "attention_mask" => attention_mask_value
52        ))?;
53
54        let logits_value = &outputs["logits"];
55        let (shape, data) = logits_value
56            .try_extract_tensor::<f32>()
57            .map_err(|e| Error::Model(format!("Failed to extract logits: {e}")))?;
58
59        let shape_dims = shape.as_ref();
60        if shape_dims.len() != 3 {
61            return Err(Error::Model(format!(
62                "Expected 3D logits, got shape: {shape_dims:?}"
63            )));
64        }
65
66        let batch_size = shape_dims[0] as usize;
67        let time_steps_out = shape_dims[1] as usize;
68        let vocab_size = shape_dims[2] as usize;
69
70        if batch_size != 1 {
71            return Err(Error::Model(format!(
72                "Expected batch size 1, got {batch_size}"
73            )));
74        }
75
76        let logits_2d = Array2::from_shape_vec((time_steps_out, vocab_size), data.to_vec())
77            .map_err(|e| Error::Model(format!("Failed to create array: {e}")))?;
78
79        Ok(logits_2d)
80    }
81
82    pub fn config(&self) -> &ModelConfig {
83        &self.config
84    }
85
86    pub fn vocab_size(&self) -> usize {
87        self.config.vocab_size
88    }
89
90    pub fn pad_token_id(&self) -> usize {
91        self.config.pad_token_id
92    }
93}