1use crate::error::{Error, Result};
2use ndarray::Array2;
3use std::path::Path;
4
/// A decoded text fragment together with its time span in the audio.
#[derive(Debug, Clone)]
pub struct TimedToken {
    // Text attributed to this token (diff of successive tokenizer decodes).
    pub text: String,
    // Start time in seconds (frame index * hop_length / sample_rate).
    pub start: f32,
    // End time in seconds.
    pub end: f32,
}
13
/// Full transcript plus per-token timing information.
#[derive(Debug, Clone)]
pub struct TranscriptionResult {
    // The complete decoded transcript.
    pub text: String,
    // Per-fragment timings; concatenating `text` fields approximates `text` above.
    pub tokens: Vec<TimedToken>,
}
19
/// Greedy CTC decoder for Parakeet models, backed by a HuggingFace tokenizer.
pub struct ParakeetDecoder {
    // Maps token ids back to text.
    tokenizer: tokenizers::Tokenizer,
    // CTC blank/padding token id (hard-coded to 1024 in `from_pretrained`).
    pad_token_id: usize,
}
25
26impl ParakeetDecoder {
27 pub fn from_pretrained<P: AsRef<Path>>(tokenizer_path: P) -> Result<Self> {
28 let tokenizer_path = tokenizer_path.as_ref();
29
30 let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
31 .map_err(|e| Error::Tokenizer(format!("Failed to load tokenizer: {e}")))?;
32
33 let pad_token_id = 1024;
35
36 Ok(Self {
37 tokenizer,
38 pad_token_id,
39 })
40 }
41
42 pub fn decode(&self, logits: &Array2<f32>) -> Result<String> {
43 let time_steps = logits.shape()[0];
44
45 let mut token_ids = Vec::new();
46 for t in 0..time_steps {
47 let logits_t = logits.row(t);
48 let max_idx = logits_t
49 .iter()
50 .enumerate()
51 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
52 .map(|(idx, _)| idx)
53 .unwrap_or(0);
54
55 token_ids.push(max_idx as u32);
56 }
57
58 let collapsed = self.ctc_collapse(&token_ids);
59
60 let text = self
61 .tokenizer
62 .decode(&collapsed, true)
63 .map_err(|e| Error::Tokenizer(format!("Failed to decode: {e}")))?;
64
65 Ok(text)
66 }
67
68 fn ctc_collapse(&self, token_ids: &[u32]) -> Vec<u32> {
69 let mut result = Vec::new();
70 let mut prev_token: Option<u32> = None;
71
72 for &token_id in token_ids {
73 if token_id == self.pad_token_id as u32 {
74 prev_token = Some(token_id);
75 continue;
76 }
77
78 if Some(token_id) != prev_token {
79 result.push(token_id);
80 }
81
82 prev_token = Some(token_id);
83 }
84
85 result
86 }
87
88 fn ctc_collapse_with_frames(&self, token_ids: &[(u32, usize)]) -> Vec<(u32, usize, usize)> {
90 let mut result: Vec<(u32, usize, usize)> = Vec::new();
91 let mut prev_token: Option<u32> = None;
92
93 for &(token_id, frame) in token_ids.iter() {
94 if token_id == self.pad_token_id as u32 {
95 prev_token = Some(token_id);
96 continue;
97 }
98
99 if Some(token_id) != prev_token {
100 if let Some(prev) = prev_token {
101 if prev != self.pad_token_id as u32 {
102 if let Some(last) = result.last_mut() {
104 last.2 = frame;
105 }
106 }
107 }
108 result.push((token_id, frame, frame));
110 }
111
112 prev_token = Some(token_id);
113 }
114
115 if let Some(last) = result.last_mut() {
117 last.2 = token_ids.len();
118 }
119
120 result
121 }
122
123 pub fn decode_with_timestamps(
126 &self,
127 logits: &Array2<f32>,
128 hop_length: usize,
129 sample_rate: usize,
130 ) -> Result<TranscriptionResult> {
131 let time_steps = logits.shape()[0];
132
133 let mut token_ids_with_frames = Vec::new();
134 for t in 0..time_steps {
135 let logits_t = logits.row(t);
136 let max_idx = logits_t
137 .iter()
138 .enumerate()
139 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
140 .map(|(idx, _)| idx)
141 .unwrap_or(0);
142
143 token_ids_with_frames.push((max_idx as u32, t));
144 }
145
146 let collapsed_with_frames = self.ctc_collapse_with_frames(&token_ids_with_frames);
148
149 let token_ids: Vec<u32> = collapsed_with_frames.iter().map(|(id, _, _)| *id).collect();
151
152 let full_text = self
154 .tokenizer
155 .decode(&token_ids, true)
156 .map_err(|e| Error::Tokenizer(format!("Failed to decode: {e}")))?;
157
158 let mut timed_tokens = Vec::new();
161 let mut prev_decode = String::new();
162
163 for (i, (_token_id, start_frame, end_frame)) in collapsed_with_frames.iter().enumerate() {
164 let token_ids_so_far: Vec<u32> = collapsed_with_frames[0..=i]
166 .iter()
167 .map(|(id, _, _)| *id)
168 .collect();
169
170 if let Ok(curr_decode) = self.tokenizer.decode(&token_ids_so_far, true) {
171 let added_text = if curr_decode.len() > prev_decode.len() {
173 &curr_decode[prev_decode.len()..]
174 } else {
175 ""
176 };
177
178 if !added_text.is_empty() {
179 let start_time = (*start_frame * hop_length) as f32 / sample_rate as f32;
180 let end_time = (*end_frame * hop_length) as f32 / sample_rate as f32;
181
182 timed_tokens.push(TimedToken {
183 text: added_text.to_string(),
184 start: start_time,
185 end: end_time,
186 });
187 }
188
189 prev_decode = curr_decode;
190 }
191 }
192
193 Ok(TranscriptionResult {
194 text: full_text,
195 tokens: timed_tokens,
196 })
197 }
198
199 pub fn decode_with_beam_search(
201 &self,
202 logits: &Array2<f32>,
203 _beam_width: usize,
204 ) -> Result<String> {
205 self.decode(logits)
206 }
207
208 pub fn pad_token_id(&self) -> usize {
209 self.pad_token_id
210 }
211}