1use crate::audio::{self, load_audio};
2use crate::config::PreprocessorConfig;
3use crate::decoder::{TimedToken, TranscriptionResult};
4use crate::error::{Error, Result};
5use crate::execution::ModelConfig as ExecutionConfig;
6use crate::model_unified::{ParakeetUnifiedModel, UnifiedModelConfig};
7use crate::nemotron::SentencePieceVocab;
8use crate::timestamps::{process_timestamps, TimestampMode};
9use crate::transcriber::Transcriber;
10use ndarray::Array3;
11use std::path::Path;
12use std::sync::{Arc, Mutex};
13
// --- Feature extraction parameters (copied into `PreprocessorConfig`) ---

/// Audio sampling rate in samples per second; also used to convert seconds to frames.
const SAMPLE_RATE: usize = 16_000;
/// Value passed as `feature_size` to the preprocessor.
const FEATURE_SIZE: usize = 128;
/// Number of audio samples per mel frame.
const HOP_LENGTH: usize = 160;
/// Value passed as `n_fft` to the preprocessor.
const N_FFT: usize = 512;
/// Value passed as `win_length` to the preprocessor.
const WIN_LENGTH: usize = 400;
/// Value passed as `preemphasis` to the preprocessor.
const PREEMPHASIS: f32 = 0.97;

// --- Decoder / encoder geometry ---

/// Last dimension of each decoder state tensor.
const DECODER_LSTM_DIM: usize = 640;
/// First dimension of each decoder state tensor.
const DECODER_LSTM_LAYERS: usize = 2;
/// Number of mel frames that map to one encoder frame.
const SUBSAMPLING_FACTOR: usize = 8;
/// Upper bound on decoder iterations spent on a single encoder frame.
const MAX_SYMBOLS_PER_STEP: usize = 10;
24
/// Timing profile for chunked (streaming) inference, expressed in seconds.
///
/// Each decoding step runs the encoder over a window made of left context,
/// the chunk itself and right context; only the chunk portion is decoded.
#[derive(Clone, Copy, Debug)]
pub struct UnifiedStreamingConfig {
    /// Duration of previously received audio kept in front of each chunk.
    pub left_context_secs: f32,
    /// Duration of audio decoded per streaming step.
    pub chunk_secs: f32,
    /// Duration of look-ahead audio placed after each chunk.
    pub right_context_secs: f32,
}
31
32impl Default for UnifiedStreamingConfig {
33 fn default() -> Self {
34 Self {
35 left_context_secs: 5.6,
36 chunk_secs: 0.56,
37 right_context_secs: 0.56,
38 }
39 }
40}
41
42impl UnifiedStreamingConfig {
43 fn frames_from_secs(secs: f32) -> usize {
44 ((secs * SAMPLE_RATE as f32) / HOP_LENGTH as f32).round() as usize
45 }
46
47 pub fn validate(self) -> Result<Self> {
48 let left_frames = self.left_context_frames();
49 let chunk_frames = self.chunk_frames();
50 let right_frames = self.right_context_frames();
51
52 if chunk_frames == 0 {
53 return Err(Error::Config(
54 "Unified streaming chunk size must be greater than zero".to_string(),
55 ));
56 }
57
58 for (name, frames) in [
59 ("left_context_secs", left_frames),
60 ("chunk_secs", chunk_frames),
61 ("right_context_secs", right_frames),
62 ] {
63 if frames % SUBSAMPLING_FACTOR != 0 {
64 return Err(Error::Config(format!(
65 "{name} must map to a mel-frame count divisible by {SUBSAMPLING_FACTOR}"
66 )));
67 }
68 }
69
70 Ok(self)
71 }
72
73 pub fn left_context_frames(self) -> usize {
74 Self::frames_from_secs(self.left_context_secs)
75 }
76
77 pub fn chunk_frames(self) -> usize {
78 Self::frames_from_secs(self.chunk_secs)
79 }
80
81 pub fn right_context_frames(self) -> usize {
82 Self::frames_from_secs(self.right_context_secs)
83 }
84
85 pub fn total_window_frames(self) -> usize {
86 self.left_context_frames() + self.chunk_frames() + self.right_context_frames()
87 }
88
89 pub fn left_context_samples(self) -> usize {
90 self.left_context_frames() * HOP_LENGTH
91 }
92
93 pub fn chunk_samples(self) -> usize {
94 self.chunk_frames() * HOP_LENGTH
95 }
96
97 pub fn right_context_samples(self) -> usize {
98 self.right_context_frames() * HOP_LENGTH
99 }
100
101 pub fn total_window_samples(self) -> usize {
102 self.total_window_frames() * HOP_LENGTH
103 }
104
105 pub fn chunk_encoder_frames(self) -> usize {
106 self.chunk_frames() / SUBSAMPLING_FACTOR
107 }
108
109 pub fn left_context_encoder_frames(self) -> usize {
110 self.left_context_frames() / SUBSAMPLING_FACTOR
111 }
112}
113
114#[derive(Clone)]
120pub struct ParakeetUnifiedHandle {
121 model: Arc<Mutex<ParakeetUnifiedModel>>,
122 vocab: Arc<SentencePieceVocab>,
123 preprocessor_config: Arc<PreprocessorConfig>,
124 blank_id: usize,
125}
126
127pub struct ParakeetUnified {
128 model: Arc<Mutex<ParakeetUnifiedModel>>,
129 vocab: Arc<SentencePieceVocab>,
130 preprocessor_config: Arc<PreprocessorConfig>,
131 state_1: Array3<f32>,
132 state_2: Array3<f32>,
133 last_token: i32,
134 blank_id: usize,
135 streaming_config: UnifiedStreamingConfig,
136 audio_buffer: Vec<f32>,
137 buffer_start_sample: usize,
138 next_chunk_start_sample: usize,
139 accumulated_tokens: Vec<usize>,
140 accumulated_timed_tokens: Vec<TimedToken>,
141}
142
143impl ParakeetUnifiedHandle {
144 pub fn load<P: AsRef<Path>>(
147 path: P,
148 exec_config: Option<ExecutionConfig>,
149 ) -> Result<Self> {
150 let path = path.as_ref();
151 let vocab = SentencePieceVocab::from_file(path.join("tokenizer.model"))?;
152 let blank_id = vocab.size();
153
154 let model_config = UnifiedModelConfig {
155 vocab_size: vocab.size() + 1,
156 blank_id,
157 decoder_lstm_dim: DECODER_LSTM_DIM,
158 decoder_lstm_layers: DECODER_LSTM_LAYERS,
159 subsampling_factor: SUBSAMPLING_FACTOR,
160 };
161
162 let model = ParakeetUnifiedModel::from_pretrained(
163 path,
164 exec_config.unwrap_or_default(),
165 model_config,
166 )?;
167
168 let preprocessor_config = PreprocessorConfig {
169 feature_extractor_type: "ParakeetFeatureExtractor".to_string(),
170 feature_size: FEATURE_SIZE,
171 hop_length: HOP_LENGTH,
172 n_fft: N_FFT,
173 padding_side: "right".to_string(),
174 padding_value: 0.0,
175 preemphasis: PREEMPHASIS,
176 processor_class: "ParakeetProcessor".to_string(),
177 return_attention_mask: true,
178 sampling_rate: SAMPLE_RATE,
179 win_length: WIN_LENGTH,
180 };
181
182 Ok(Self {
183 model: Arc::new(Mutex::new(model)),
184 vocab: Arc::new(vocab),
185 preprocessor_config: Arc::new(preprocessor_config),
186 blank_id,
187 })
188 }
189}
190
191impl ParakeetUnified {
192 pub fn from_pretrained<P: AsRef<Path>>(
193 path: P,
194 exec_config: Option<ExecutionConfig>,
195 ) -> Result<Self> {
196 Self::from_pretrained_with_streaming_config(
197 path,
198 exec_config,
199 UnifiedStreamingConfig::default(),
200 )
201 }
202
203 pub fn from_pretrained_with_streaming_config<P: AsRef<Path>>(
204 path: P,
205 exec_config: Option<ExecutionConfig>,
206 streaming_config: UnifiedStreamingConfig,
207 ) -> Result<Self> {
208 let handle = ParakeetUnifiedHandle::load(path, exec_config)?;
209 Self::from_shared_with_streaming_config(&handle, streaming_config)
210 }
211
212 pub fn from_shared(handle: &ParakeetUnifiedHandle) -> Self {
215 Self::from_shared_with_streaming_config(handle, UnifiedStreamingConfig::default())
217 .expect("default UnifiedStreamingConfig is always valid")
218 }
219
220 pub fn from_shared_with_streaming_config(
224 handle: &ParakeetUnifiedHandle,
225 streaming_config: UnifiedStreamingConfig,
226 ) -> Result<Self> {
227 let streaming_config = streaming_config.validate()?;
228 let blank_id = handle.blank_id;
229
230 Ok(Self {
231 model: Arc::clone(&handle.model),
232 vocab: Arc::clone(&handle.vocab),
233 preprocessor_config: Arc::clone(&handle.preprocessor_config),
234 state_1: Array3::zeros((DECODER_LSTM_LAYERS, 1, DECODER_LSTM_DIM)),
235 state_2: Array3::zeros((DECODER_LSTM_LAYERS, 1, DECODER_LSTM_DIM)),
236 last_token: blank_id as i32,
237 blank_id,
238 streaming_config,
239 audio_buffer: Vec::new(),
240 buffer_start_sample: 0,
241 next_chunk_start_sample: 0,
242 accumulated_tokens: Vec::new(),
243 accumulated_timed_tokens: Vec::new(),
244 })
245 }
246
247 pub fn streaming_config(&self) -> UnifiedStreamingConfig {
248 self.streaming_config
249 }
250
251 pub fn preprocessor_config(&self) -> &PreprocessorConfig {
252 &self.preprocessor_config
253 }
254
255 pub fn reset(&mut self) {
256 self.state_1.fill(0.0);
257 self.state_2.fill(0.0);
258 self.last_token = self.blank_id as i32;
259 self.audio_buffer.clear();
260 self.buffer_start_sample = 0;
261 self.next_chunk_start_sample = 0;
262 self.accumulated_tokens.clear();
263 self.accumulated_timed_tokens.clear();
264 }
265
266 pub fn get_timed_transcript(&self, mode: TimestampMode) -> TranscriptionResult {
267 let text = self.get_transcript();
268 let tokens = process_timestamps(&self.accumulated_timed_tokens, mode);
269 TranscriptionResult { text, tokens }
270 }
271
272 pub fn get_transcript(&self) -> String {
273 let valid: Vec<usize> = self
274 .accumulated_tokens
275 .iter()
276 .copied()
277 .filter(|&token| token < self.blank_id)
278 .collect();
279 self.vocab.decode(&valid)
280 }
281
282 pub fn transcribe_audio(
283 &mut self,
284 audio: Vec<f32>,
285 sample_rate: u32,
286 channels: u16,
287 ) -> Result<String> {
288 self.transcribe_offline(audio, sample_rate, channels, None)
289 .map(|result| result.text)
290 }
291
292 pub fn transcribe_file<P: AsRef<Path>>(&mut self, audio_path: P) -> Result<String> {
293 let (audio, spec) = load_audio(audio_path)?;
294 self.transcribe_audio(audio, spec.sample_rate, spec.channels)
295 }
296
297 pub fn transcribe_chunk(&mut self, audio_chunk: &[f32]) -> Result<String> {
298 self.audio_buffer.extend_from_slice(audio_chunk);
299 self.process_ready_chunks(false)
300 }
301
302 pub fn flush(&mut self) -> Result<String> {
303 self.process_ready_chunks(true)
304 }
305
306 fn process_ready_chunks(&mut self, flush: bool) -> Result<String> {
307 let mut emitted = String::new();
308 let chunk_samples = self.streaming_config.chunk_samples();
309 let right_context_samples = self.streaming_config.right_context_samples();
310
311 loop {
312 let total_received = self.buffer_start_sample + self.audio_buffer.len();
313 let ready = if flush {
314 total_received > self.next_chunk_start_sample
315 } else {
316 total_received
317 >= self.next_chunk_start_sample + chunk_samples + right_context_samples
318 };
319
320 if !ready {
321 break;
322 }
323
324 let (window_audio, left_encoder_frames, chunk_encoder_frames) =
325 self.build_window_audio(self.next_chunk_start_sample, total_received, flush);
326 if chunk_encoder_frames == 0 {
327 break;
328 }
329
330 let features = audio::extract_features_raw(
331 window_audio,
332 SAMPLE_RATE as u32,
333 1,
334 &self.preprocessor_config,
335 )?;
336 let (encoded, encoded_len) = {
337 let mut model = self.model.lock().map_err(|e| {
338 Error::Model(format!("Failed to acquire model lock: {e}"))
339 })?;
340 model.run_encoder(&features)?
341 };
342
343 let available_frames = (encoded_len as usize).min(encoded.shape()[2]);
344 let start_frame = left_encoder_frames.min(available_frames);
345 let end_frame = (start_frame + chunk_encoder_frames).min(available_frames);
346
347 let absolute_frame_offset =
348 self.next_chunk_start_sample / (HOP_LENGTH * SUBSAMPLING_FACTOR);
349 let tokens =
350 self.decode_encoder_frames(&encoded, start_frame, end_frame, absolute_frame_offset)?;
351 self.accumulated_tokens
352 .extend(tokens.iter().map(|(id, _)| *id));
353 self.accumulated_timed_tokens
354 .extend(self.tokens_to_timed(&tokens));
355 emitted.push_str(&self.decode_incremental_tokens(&tokens));
356
357 self.next_chunk_start_sample += chunk_samples;
358 self.trim_audio_buffer();
359
360 if flush && total_received <= self.next_chunk_start_sample {
361 break;
362 }
363 }
364
365 Ok(emitted)
366 }
367
368 fn build_window_audio(
369 &self,
370 chunk_start_sample: usize,
371 total_received: usize,
372 flush: bool,
373 ) -> (Vec<f32>, usize, usize) {
374 let left_context_samples = self.streaming_config.left_context_samples();
375 let chunk_samples = self.streaming_config.chunk_samples();
376 let right_context_samples = self.streaming_config.right_context_samples();
377
378 let available_left = chunk_start_sample.saturating_sub(self.buffer_start_sample);
379 let available_left = available_left.min(left_context_samples);
380 let available_main = total_received.saturating_sub(chunk_start_sample).min(chunk_samples);
381 let available_right = if flush {
382 total_received
383 .saturating_sub(chunk_start_sample + available_main)
384 .min(right_context_samples)
385 } else {
386 right_context_samples
387 };
388
389 let window_start = chunk_start_sample.saturating_sub(available_left);
390 let window_end = chunk_start_sample + available_main + available_right;
391 let total_window_samples = window_end.saturating_sub(window_start);
392
393 let left_encoder_frames = (available_left / HOP_LENGTH) / SUBSAMPLING_FACTOR;
394 let chunk_encoder_frames = (available_main / HOP_LENGTH) / SUBSAMPLING_FACTOR;
395
396 let mut window = vec![0.0f32; total_window_samples];
397 let buffer_end = self.buffer_start_sample + self.audio_buffer.len();
398 let copy_start = window_start.max(self.buffer_start_sample);
399 let copy_end = window_end.min(buffer_end);
400
401 if copy_end > copy_start {
402 let src_start = copy_start - self.buffer_start_sample;
403 let dst_start = copy_start - window_start;
404 let len = copy_end - copy_start;
405 window[dst_start..dst_start + len]
406 .copy_from_slice(&self.audio_buffer[src_start..src_start + len]);
407 }
408
409 (window, left_encoder_frames, chunk_encoder_frames)
410 }
411
412 fn trim_audio_buffer(&mut self) {
413 let keep_from = self
414 .next_chunk_start_sample
415 .saturating_sub(self.streaming_config.left_context_samples());
416 if keep_from <= self.buffer_start_sample {
417 return;
418 }
419
420 let drop = keep_from - self.buffer_start_sample;
421 if drop == 0 {
422 return;
423 }
424
425 if drop >= self.audio_buffer.len() {
426 self.audio_buffer.clear();
427 self.buffer_start_sample = keep_from;
428 return;
429 }
430
431 self.audio_buffer.drain(0..drop);
432 self.buffer_start_sample = keep_from;
433 }
434
435 fn decode_encoder_frames(
436 &mut self,
437 encoder_out: &Array3<f32>,
438 start_frame: usize,
439 end_frame: usize,
440 absolute_frame_offset: usize,
441 ) -> Result<Vec<(usize, usize)>> {
442 let mut tokens = Vec::new();
443 let hidden_dim = encoder_out.shape()[1];
444 let end_frame = end_frame.min(encoder_out.shape()[2]);
445
446 let mut model = self
448 .model
449 .lock()
450 .map_err(|e| Error::Model(format!("Failed to acquire model lock: {e}")))?;
451
452 for frame_idx in start_frame..end_frame {
453 let frame = encoder_out
454 .slice(ndarray::s![0, .., frame_idx])
455 .to_owned()
456 .to_shape((1, hidden_dim, 1))
457 .map_err(|e| Error::Model(format!("Failed to reshape encoder frame: {e}")))?
458 .to_owned();
459
460 let absolute_frame = absolute_frame_offset + (frame_idx - start_frame);
461
462 for _ in 0..MAX_SYMBOLS_PER_STEP {
463 let (logits, new_state_1, new_state_2) = model.run_decoder(
464 &frame,
465 self.last_token,
466 &self.state_1,
467 &self.state_2,
468 )?;
469
470 let token_id = logits
471 .iter()
472 .enumerate()
473 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
474 .map(|(idx, _)| idx)
475 .unwrap_or(self.blank_id);
476
477 if token_id == self.blank_id {
478 break;
479 }
480
481 tokens.push((token_id, absolute_frame));
482 self.last_token = token_id as i32;
483 self.state_1 = new_state_1;
484 self.state_2 = new_state_2;
485 }
486 }
487
488 Ok(tokens)
489 }
490
491 fn encoder_frame_to_seconds(frame: usize) -> f32 {
492 (frame * SUBSAMPLING_FACTOR * HOP_LENGTH) as f32 / SAMPLE_RATE as f32
493 }
494
495 fn tokens_to_timed(&self, tokens: &[(usize, usize)]) -> Vec<TimedToken> {
496 tokens
497 .iter()
498 .filter(|(id, _)| *id < self.blank_id)
499 .map(|&(id, frame)| TimedToken {
500 text: self.vocab.decode_single(id),
501 start: Self::encoder_frame_to_seconds(frame),
502 end: Self::encoder_frame_to_seconds(frame + 1),
503 })
504 .collect()
505 }
506
507 fn decode_incremental_tokens(&self, tokens: &[(usize, usize)]) -> String {
508 let mut text = String::new();
509 for &(token, _) in tokens {
510 if token < self.blank_id {
511 text.push_str(&self.vocab.decode_single(token));
512 }
513 }
514 text
515 }
516
517 fn transcribe_offline(
518 &mut self,
519 audio: Vec<f32>,
520 sample_rate: u32,
521 channels: u16,
522 mode: Option<TimestampMode>,
523 ) -> Result<TranscriptionResult> {
524 self.reset();
525
526 let features = audio::extract_features_raw(audio, sample_rate, channels, &self.preprocessor_config)?;
527 let (encoded, encoded_len) = {
528 let mut model = self
529 .model
530 .lock()
531 .map_err(|e| Error::Model(format!("Failed to acquire model lock: {e}")))?;
532 model.run_encoder(&features)?
533 };
534 let frame_count = (encoded_len as usize).min(encoded.shape()[2]);
535 let tokens = self.decode_encoder_frames(&encoded, 0, frame_count, 0)?;
536 self.accumulated_tokens = tokens.iter().map(|(id, _)| *id).collect();
537 self.accumulated_timed_tokens = self.tokens_to_timed(&tokens);
538
539 let text = self.get_transcript();
540 let timed = match mode {
541 Some(m) => process_timestamps(&self.accumulated_timed_tokens, m),
542 None => self.accumulated_timed_tokens.clone(),
543 };
544
545 Ok(TranscriptionResult {
546 text,
547 tokens: timed,
548 })
549 }
550}
551
552impl Transcriber for ParakeetUnified {
553 fn transcribe_samples(
554 &mut self,
555 audio: Vec<f32>,
556 sample_rate: u32,
557 channels: u16,
558 mode: Option<TimestampMode>,
559 ) -> Result<TranscriptionResult> {
560 self.transcribe_offline(audio, sample_rate, channels, mode)
561 }
562}
563
#[cfg(test)]
mod tests {
    use super::UnifiedStreamingConfig;

    /// The default profile must validate and map to whole encoder frames.
    #[test]
    fn default_streaming_profile_aligns_to_subsampling() {
        let config = UnifiedStreamingConfig::default().validate().unwrap();
        assert_eq!(config.left_context_frames(), 560);
        assert_eq!(config.chunk_frames(), 56);
        assert_eq!(config.right_context_frames(), 56);
        assert_eq!(config.left_context_encoder_frames(), 70);
        assert_eq!(config.chunk_encoder_frames(), 7);
    }

    /// Sample counts are the frame counts scaled by the hop length (160).
    #[test]
    fn default_streaming_profile_sample_counts() {
        let config = UnifiedStreamingConfig::default();
        assert_eq!(config.total_window_frames(), 672);
        assert_eq!(config.left_context_samples(), 89_600);
        assert_eq!(config.chunk_samples(), 8_960);
        assert_eq!(config.right_context_samples(), 8_960);
        assert_eq!(config.total_window_samples(), 107_520);
    }

    /// A chunk that rounds to zero mel frames cannot drive streaming.
    #[test]
    fn zero_length_chunk_is_rejected() {
        let config = UnifiedStreamingConfig {
            chunk_secs: 0.0,
            ..UnifiedStreamingConfig::default()
        };
        assert!(config.validate().is_err());
    }

    /// 0.5 s maps to 50 mel frames, which is not a multiple of 8.
    #[test]
    fn misaligned_chunk_is_rejected() {
        let config = UnifiedStreamingConfig {
            chunk_secs: 0.5,
            ..UnifiedStreamingConfig::default()
        };
        assert_eq!(config.chunk_frames(), 50);
        assert!(config.validate().is_err());
    }
}