// parakeet_rs/decoder.rs

use crate::error::{Error, Result};
use ndarray::Array2;
use std::path::Path;

/// A decoded token together with its position in the audio.
///
/// `start` and `end` are expressed in seconds.
#[derive(Debug, Clone)]
pub struct TimedToken {
    /// Text fragment contributed by this token.
    pub text: String,
    /// Start time in seconds.
    pub start: f32,
    /// End time in seconds.
    pub end: f32,
}

14#[derive(Debug, Clone)]
15pub struct TranscriptionResult {
16    pub text: String,
17    pub tokens: Vec<TimedToken>,
18}
19
20// CTC decoder for parakeet-ctc-0.6b model with token-level timestamps
21pub struct ParakeetDecoder {
22    tokenizer: tokenizers::Tokenizer,
23    pad_token_id: usize,
24}
25
26impl ParakeetDecoder {
27    pub fn from_pretrained<P: AsRef<Path>>(tokenizer_path: P) -> Result<Self> {
28        let tokenizer_path = tokenizer_path.as_ref();
29
30        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
31            .map_err(|e| Error::Tokenizer(format!("Failed to load tokenizer: {e}")))?;
32
33        // Hardcoded pad_token_id for Parakeet-CTC-0.6b (constant across all models: please see def configs jsons: https://huggingface.co/onnx-community/parakeet-ctc-0.6b-ONNX/tree/main)
34        let pad_token_id = 1024;
35
36        Ok(Self {
37            tokenizer,
38            pad_token_id,
39        })
40    }
41
42    pub fn decode(&self, logits: &Array2<f32>) -> Result<String> {
43        let time_steps = logits.shape()[0];
44
45        let mut token_ids = Vec::new();
46        for t in 0..time_steps {
47            let logits_t = logits.row(t);
48            let max_idx = logits_t
49                .iter()
50                .enumerate()
51                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
52                .map(|(idx, _)| idx)
53                .unwrap_or(0);
54
55            token_ids.push(max_idx as u32);
56        }
57
58        let collapsed = self.ctc_collapse(&token_ids);
59
60        let text = self
61            .tokenizer
62            .decode(&collapsed, true)
63            .map_err(|e| Error::Tokenizer(format!("Failed to decode: {e}")))?;
64
65        Ok(text)
66    }
67
68    fn ctc_collapse(&self, token_ids: &[u32]) -> Vec<u32> {
69        let mut result = Vec::new();
70        let mut prev_token: Option<u32> = None;
71
72        for &token_id in token_ids {
73            if token_id == self.pad_token_id as u32 {
74                prev_token = Some(token_id);
75                continue;
76            }
77
78            if Some(token_id) != prev_token {
79                result.push(token_id);
80            }
81
82            prev_token = Some(token_id);
83        }
84
85        result
86    }
87
88    // CTC collapse with frame tracking for timestamps
89    fn ctc_collapse_with_frames(&self, token_ids: &[(u32, usize)]) -> Vec<(u32, usize, usize)> {
90        let mut result: Vec<(u32, usize, usize)> = Vec::new();
91        let mut prev_token: Option<u32> = None;
92
93        for &(token_id, frame) in token_ids.iter() {
94            if token_id == self.pad_token_id as u32 {
95                prev_token = Some(token_id);
96                continue;
97            }
98
99            if Some(token_id) != prev_token {
100                if let Some(prev) = prev_token {
101                    if prev != self.pad_token_id as u32 {
102                        // End previous token
103                        if let Some(last) = result.last_mut() {
104                            last.2 = frame;
105                        }
106                    }
107                }
108                // Start new token
109                result.push((token_id, frame, frame));
110            }
111
112            prev_token = Some(token_id);
113        }
114
115        // Close last token
116        if let Some(last) = result.last_mut() {
117            last.2 = token_ids.len();
118        }
119
120        result
121    }
122
123    // Decode with token-level timestamps
124    // hop_length and sample_rate are needed to convert frames to seconds
125    pub fn decode_with_timestamps(
126        &self,
127        logits: &Array2<f32>,
128        hop_length: usize,
129        sample_rate: usize,
130    ) -> Result<TranscriptionResult> {
131        let time_steps = logits.shape()[0];
132
133        let mut token_ids_with_frames = Vec::new();
134        for t in 0..time_steps {
135            let logits_t = logits.row(t);
136            let max_idx = logits_t
137                .iter()
138                .enumerate()
139                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
140                .map(|(idx, _)| idx)
141                .unwrap_or(0);
142
143            token_ids_with_frames.push((max_idx as u32, t));
144        }
145
146        // CTC collapse with frame tracking
147        let collapsed_with_frames = self.ctc_collapse_with_frames(&token_ids_with_frames);
148
149        // Extract just token IDs for decoding
150        let token_ids: Vec<u32> = collapsed_with_frames.iter().map(|(id, _, _)| *id).collect();
151
152        // Decode full text
153        let full_text = self
154            .tokenizer
155            .decode(&token_ids, true)
156            .map_err(|e| Error::Tokenizer(format!("Failed to decode: {e}")))?;
157
158        // Progressive decode to detect word boundaries
159        // BPE tokenizers only add spaces when decoding sequences, not individual tokens
160        let mut timed_tokens = Vec::new();
161        let mut prev_decode = String::new();
162
163        for (i, (_token_id, start_frame, end_frame)) in collapsed_with_frames.iter().enumerate() {
164            // Decode from start up to and including current token
165            let token_ids_so_far: Vec<u32> = collapsed_with_frames[0..=i]
166                .iter()
167                .map(|(id, _, _)| *id)
168                .collect();
169
170            if let Ok(curr_decode) = self.tokenizer.decode(&token_ids_so_far, true) {
171                // Find what this token added
172                let added_text = if curr_decode.len() > prev_decode.len() {
173                    &curr_decode[prev_decode.len()..]
174                } else {
175                    ""
176                };
177
178                if !added_text.is_empty() {
179                    let start_time = (*start_frame * hop_length) as f32 / sample_rate as f32;
180                    let end_time = (*end_frame * hop_length) as f32 / sample_rate as f32;
181
182                    timed_tokens.push(TimedToken {
183                        text: added_text.to_string(),
184                        start: start_time,
185                        end: end_time,
186                    });
187                }
188
189                prev_decode = curr_decode;
190            }
191        }
192
193        Ok(TranscriptionResult {
194            text: full_text,
195            tokens: timed_tokens,
196        })
197    }
198
199    // Stub - falls back to greedy decoding. Full beam search with language model is TODO.
200    pub fn decode_with_beam_search(
201        &self,
202        logits: &Array2<f32>,
203        _beam_width: usize,
204    ) -> Result<String> {
205        self.decode(logits)
206    }
207
208    pub fn pad_token_id(&self) -> usize {
209        self.pad_token_id
210    }
211}