Skip to main content

parakeet_rs/
parakeet_unified.rs

1use crate::audio::{self, load_audio};
2use crate::config::PreprocessorConfig;
3use crate::decoder::{TimedToken, TranscriptionResult};
4use crate::error::{Error, Result};
5use crate::execution::ModelConfig as ExecutionConfig;
6use crate::model_unified::{ParakeetUnifiedModel, UnifiedModelConfig};
7use crate::nemotron::SentencePieceVocab;
8use crate::timestamps::{process_timestamps, TimestampMode};
9use crate::transcriber::Transcriber;
10use ndarray::Array3;
11use std::path::Path;
12use std::sync::{Arc, Mutex};
13
/// Audio sample rate the model operates at, in Hz.
const SAMPLE_RATE: usize = 16_000;
/// Number of feature channels (`feature_size`) requested from the preprocessor.
const FEATURE_SIZE: usize = 128;
/// Feature hop in samples: 10 ms at [`SAMPLE_RATE`] (= 160).
const HOP_LENGTH: usize = SAMPLE_RATE / 100;
/// FFT size passed to the preprocessor.
const N_FFT: usize = 512;
/// Analysis window length in samples: 25 ms at [`SAMPLE_RATE`] (= 400).
const WIN_LENGTH: usize = SAMPLE_RATE * 25 / 1000;
/// Pre-emphasis coefficient passed to the preprocessor.
const PREEMPHASIS: f32 = 0.97;
/// Hidden size of the decoder LSTM; last dimension of the decoder state tensors.
const DECODER_LSTM_DIM: usize = 640;
/// Number of decoder LSTM layers; first dimension of the decoder state tensors.
const DECODER_LSTM_LAYERS: usize = 2;
/// Number of feature (mel) frames folded into one encoder output frame.
const SUBSAMPLING_FACTOR: usize = 8;
/// Upper bound on non-blank tokens emitted for a single encoder frame.
const MAX_SYMBOLS_PER_STEP: usize = 10;
24
/// Chunking profile for unified (streaming) inference.
///
/// Each decoding step encodes a window made of `left_context_secs` of
/// preceding audio, `chunk_secs` of new audio and `right_context_secs` of
/// look-ahead, but only emits tokens for the chunk itself. Every duration must
/// map to a mel-frame count divisible by the encoder subsampling factor; see
/// [`UnifiedStreamingConfig::validate`].
#[derive(Debug, Clone, Copy)]
pub struct UnifiedStreamingConfig {
    /// Preceding audio, in seconds, given to the encoder as context.
    pub left_context_secs: f32,
    /// New audio, in seconds, decoded per step.
    pub chunk_secs: f32,
    /// Future audio, in seconds, the encoder may look ahead into.
    pub right_context_secs: f32,
}

impl Default for UnifiedStreamingConfig {
    /// 5.6 s of left context, 0.56 s chunks and 0.56 s of look-ahead
    /// (560 / 56 / 56 mel frames).
    fn default() -> Self {
        let chunk_secs = 0.56;
        Self {
            chunk_secs,
            right_context_secs: chunk_secs,
            left_context_secs: 5.6,
        }
    }
}
41
42impl UnifiedStreamingConfig {
43    fn frames_from_secs(secs: f32) -> usize {
44        ((secs * SAMPLE_RATE as f32) / HOP_LENGTH as f32).round() as usize
45    }
46
47    pub fn validate(self) -> Result<Self> {
48        let left_frames = self.left_context_frames();
49        let chunk_frames = self.chunk_frames();
50        let right_frames = self.right_context_frames();
51
52        if chunk_frames == 0 {
53            return Err(Error::Config(
54                "Unified streaming chunk size must be greater than zero".to_string(),
55            ));
56        }
57
58        for (name, frames) in [
59            ("left_context_secs", left_frames),
60            ("chunk_secs", chunk_frames),
61            ("right_context_secs", right_frames),
62        ] {
63            if frames % SUBSAMPLING_FACTOR != 0 {
64                return Err(Error::Config(format!(
65                    "{name} must map to a mel-frame count divisible by {SUBSAMPLING_FACTOR}"
66                )));
67            }
68        }
69
70        Ok(self)
71    }
72
73    pub fn left_context_frames(self) -> usize {
74        Self::frames_from_secs(self.left_context_secs)
75    }
76
77    pub fn chunk_frames(self) -> usize {
78        Self::frames_from_secs(self.chunk_secs)
79    }
80
81    pub fn right_context_frames(self) -> usize {
82        Self::frames_from_secs(self.right_context_secs)
83    }
84
85    pub fn total_window_frames(self) -> usize {
86        self.left_context_frames() + self.chunk_frames() + self.right_context_frames()
87    }
88
89    pub fn left_context_samples(self) -> usize {
90        self.left_context_frames() * HOP_LENGTH
91    }
92
93    pub fn chunk_samples(self) -> usize {
94        self.chunk_frames() * HOP_LENGTH
95    }
96
97    pub fn right_context_samples(self) -> usize {
98        self.right_context_frames() * HOP_LENGTH
99    }
100
101    pub fn total_window_samples(self) -> usize {
102        self.total_window_frames() * HOP_LENGTH
103    }
104
105    pub fn chunk_encoder_frames(self) -> usize {
106        self.chunk_frames() / SUBSAMPLING_FACTOR
107    }
108
109    pub fn left_context_encoder_frames(self) -> usize {
110        self.left_context_frames() / SUBSAMPLING_FACTOR
111    }
112}
113
114/// Shared handle to a loaded ParakeetUnified model.
115/// The ONNX session is loaded once and reference-counted.
116///
117/// Use [`ParakeetUnifiedHandle::load`] to load from disk, then
118/// [`ParakeetUnified::from_shared`] to spawn each stream with its own state.
119#[derive(Clone)]
120pub struct ParakeetUnifiedHandle {
121    model: Arc<Mutex<ParakeetUnifiedModel>>,
122    vocab: Arc<SentencePieceVocab>,
123    preprocessor_config: Arc<PreprocessorConfig>,
124    blank_id: usize,
125}
126
127pub struct ParakeetUnified {
128    model: Arc<Mutex<ParakeetUnifiedModel>>,
129    vocab: Arc<SentencePieceVocab>,
130    preprocessor_config: Arc<PreprocessorConfig>,
131    state_1: Array3<f32>,
132    state_2: Array3<f32>,
133    last_token: i32,
134    blank_id: usize,
135    streaming_config: UnifiedStreamingConfig,
136    audio_buffer: Vec<f32>,
137    buffer_start_sample: usize,
138    next_chunk_start_sample: usize,
139    accumulated_tokens: Vec<usize>,
140    accumulated_timed_tokens: Vec<TimedToken>,
141}
142
143impl ParakeetUnifiedHandle {
144    /// Load the ParakeetUnified model, vocabulary, and preprocessor config
145    /// from a directory.
146    pub fn load<P: AsRef<Path>>(
147        path: P,
148        exec_config: Option<ExecutionConfig>,
149    ) -> Result<Self> {
150        let path = path.as_ref();
151        let vocab = SentencePieceVocab::from_file(path.join("tokenizer.model"))?;
152        let blank_id = vocab.size();
153
154        let model_config = UnifiedModelConfig {
155            vocab_size: vocab.size() + 1,
156            blank_id,
157            decoder_lstm_dim: DECODER_LSTM_DIM,
158            decoder_lstm_layers: DECODER_LSTM_LAYERS,
159            subsampling_factor: SUBSAMPLING_FACTOR,
160        };
161
162        let model = ParakeetUnifiedModel::from_pretrained(
163            path,
164            exec_config.unwrap_or_default(),
165            model_config,
166        )?;
167
168        let preprocessor_config = PreprocessorConfig {
169            feature_extractor_type: "ParakeetFeatureExtractor".to_string(),
170            feature_size: FEATURE_SIZE,
171            hop_length: HOP_LENGTH,
172            n_fft: N_FFT,
173            padding_side: "right".to_string(),
174            padding_value: 0.0,
175            preemphasis: PREEMPHASIS,
176            processor_class: "ParakeetProcessor".to_string(),
177            return_attention_mask: true,
178            sampling_rate: SAMPLE_RATE,
179            win_length: WIN_LENGTH,
180        };
181
182        Ok(Self {
183            model: Arc::new(Mutex::new(model)),
184            vocab: Arc::new(vocab),
185            preprocessor_config: Arc::new(preprocessor_config),
186            blank_id,
187        })
188    }
189}
190
191impl ParakeetUnified {
192    pub fn from_pretrained<P: AsRef<Path>>(
193        path: P,
194        exec_config: Option<ExecutionConfig>,
195    ) -> Result<Self> {
196        Self::from_pretrained_with_streaming_config(
197            path,
198            exec_config,
199            UnifiedStreamingConfig::default(),
200        )
201    }
202
203    pub fn from_pretrained_with_streaming_config<P: AsRef<Path>>(
204        path: P,
205        exec_config: Option<ExecutionConfig>,
206        streaming_config: UnifiedStreamingConfig,
207    ) -> Result<Self> {
208        let handle = ParakeetUnifiedHandle::load(path, exec_config)?;
209        Self::from_shared_with_streaming_config(&handle, streaming_config)
210    }
211
212    /// Spawn a new ParakeetUnified instance bound to a shared model, using the
213    /// default streaming profile.
214    pub fn from_shared(handle: &ParakeetUnifiedHandle) -> Self {
215        // default config is pre-validated, so unwrap is safe
216        Self::from_shared_with_streaming_config(handle, UnifiedStreamingConfig::default())
217            .expect("default UnifiedStreamingConfig is always valid")
218    }
219
220    /// Spawn a new ParakeetUnified instance bound to a shared model with a
221    /// custom streaming profile. Each instance owns independent decoder and
222    /// audio-buffer state; the ONNX session is shared through the handle.
223    pub fn from_shared_with_streaming_config(
224        handle: &ParakeetUnifiedHandle,
225        streaming_config: UnifiedStreamingConfig,
226    ) -> Result<Self> {
227        let streaming_config = streaming_config.validate()?;
228        let blank_id = handle.blank_id;
229
230        Ok(Self {
231            model: Arc::clone(&handle.model),
232            vocab: Arc::clone(&handle.vocab),
233            preprocessor_config: Arc::clone(&handle.preprocessor_config),
234            state_1: Array3::zeros((DECODER_LSTM_LAYERS, 1, DECODER_LSTM_DIM)),
235            state_2: Array3::zeros((DECODER_LSTM_LAYERS, 1, DECODER_LSTM_DIM)),
236            last_token: blank_id as i32,
237            blank_id,
238            streaming_config,
239            audio_buffer: Vec::new(),
240            buffer_start_sample: 0,
241            next_chunk_start_sample: 0,
242            accumulated_tokens: Vec::new(),
243            accumulated_timed_tokens: Vec::new(),
244        })
245    }
246
247    pub fn streaming_config(&self) -> UnifiedStreamingConfig {
248        self.streaming_config
249    }
250
251    pub fn preprocessor_config(&self) -> &PreprocessorConfig {
252        &self.preprocessor_config
253    }
254
255    pub fn reset(&mut self) {
256        self.state_1.fill(0.0);
257        self.state_2.fill(0.0);
258        self.last_token = self.blank_id as i32;
259        self.audio_buffer.clear();
260        self.buffer_start_sample = 0;
261        self.next_chunk_start_sample = 0;
262        self.accumulated_tokens.clear();
263        self.accumulated_timed_tokens.clear();
264    }
265
266    pub fn get_timed_transcript(&self, mode: TimestampMode) -> TranscriptionResult {
267        let text = self.get_transcript();
268        let tokens = process_timestamps(&self.accumulated_timed_tokens, mode);
269        TranscriptionResult { text, tokens }
270    }
271
272    pub fn get_transcript(&self) -> String {
273        let valid: Vec<usize> = self
274            .accumulated_tokens
275            .iter()
276            .copied()
277            .filter(|&token| token < self.blank_id)
278            .collect();
279        self.vocab.decode(&valid)
280    }
281
282    pub fn transcribe_audio(
283        &mut self,
284        audio: Vec<f32>,
285        sample_rate: u32,
286        channels: u16,
287    ) -> Result<String> {
288        self.transcribe_offline(audio, sample_rate, channels, None)
289            .map(|result| result.text)
290    }
291
292    pub fn transcribe_file<P: AsRef<Path>>(&mut self, audio_path: P) -> Result<String> {
293        let (audio, spec) = load_audio(audio_path)?;
294        self.transcribe_audio(audio, spec.sample_rate, spec.channels)
295    }
296
297    pub fn transcribe_chunk(&mut self, audio_chunk: &[f32]) -> Result<String> {
298        self.audio_buffer.extend_from_slice(audio_chunk);
299        self.process_ready_chunks(false)
300    }
301
302    pub fn flush(&mut self) -> Result<String> {
303        self.process_ready_chunks(true)
304    }
305
306    fn process_ready_chunks(&mut self, flush: bool) -> Result<String> {
307        let mut emitted = String::new();
308        let chunk_samples = self.streaming_config.chunk_samples();
309        let right_context_samples = self.streaming_config.right_context_samples();
310
311        loop {
312            let total_received = self.buffer_start_sample + self.audio_buffer.len();
313            let ready = if flush {
314                total_received > self.next_chunk_start_sample
315            } else {
316                total_received
317                    >= self.next_chunk_start_sample + chunk_samples + right_context_samples
318            };
319
320            if !ready {
321                break;
322            }
323
324            let (window_audio, left_encoder_frames, chunk_encoder_frames) =
325                self.build_window_audio(self.next_chunk_start_sample, total_received, flush);
326            if chunk_encoder_frames == 0 {
327                break;
328            }
329
330            let features = audio::extract_features_raw(
331                window_audio,
332                SAMPLE_RATE as u32,
333                1,
334                &self.preprocessor_config,
335            )?;
336            let (encoded, encoded_len) = {
337                let mut model = self.model.lock().map_err(|e| {
338                    Error::Model(format!("Failed to acquire model lock: {e}"))
339                })?;
340                model.run_encoder(&features)?
341            };
342
343            let available_frames = (encoded_len as usize).min(encoded.shape()[2]);
344            let start_frame = left_encoder_frames.min(available_frames);
345            let end_frame = (start_frame + chunk_encoder_frames).min(available_frames);
346
347            let absolute_frame_offset =
348                self.next_chunk_start_sample / (HOP_LENGTH * SUBSAMPLING_FACTOR);
349            let tokens =
350                self.decode_encoder_frames(&encoded, start_frame, end_frame, absolute_frame_offset)?;
351            self.accumulated_tokens
352                .extend(tokens.iter().map(|(id, _)| *id));
353            self.accumulated_timed_tokens
354                .extend(self.tokens_to_timed(&tokens));
355            emitted.push_str(&self.decode_incremental_tokens(&tokens));
356
357            self.next_chunk_start_sample += chunk_samples;
358            self.trim_audio_buffer();
359
360            if flush && total_received <= self.next_chunk_start_sample {
361                break;
362            }
363        }
364
365        Ok(emitted)
366    }
367
368    fn build_window_audio(
369        &self,
370        chunk_start_sample: usize,
371        total_received: usize,
372        flush: bool,
373    ) -> (Vec<f32>, usize, usize) {
374        let left_context_samples = self.streaming_config.left_context_samples();
375        let chunk_samples = self.streaming_config.chunk_samples();
376        let right_context_samples = self.streaming_config.right_context_samples();
377
378        let available_left = chunk_start_sample.saturating_sub(self.buffer_start_sample);
379        let available_left = available_left.min(left_context_samples);
380        let available_main = total_received.saturating_sub(chunk_start_sample).min(chunk_samples);
381        let available_right = if flush {
382            total_received
383                .saturating_sub(chunk_start_sample + available_main)
384                .min(right_context_samples)
385        } else {
386            right_context_samples
387        };
388
389        let window_start = chunk_start_sample.saturating_sub(available_left);
390        let window_end = chunk_start_sample + available_main + available_right;
391        let total_window_samples = window_end.saturating_sub(window_start);
392
393        let left_encoder_frames = (available_left / HOP_LENGTH) / SUBSAMPLING_FACTOR;
394        let chunk_encoder_frames = (available_main / HOP_LENGTH) / SUBSAMPLING_FACTOR;
395
396        let mut window = vec![0.0f32; total_window_samples];
397        let buffer_end = self.buffer_start_sample + self.audio_buffer.len();
398        let copy_start = window_start.max(self.buffer_start_sample);
399        let copy_end = window_end.min(buffer_end);
400
401        if copy_end > copy_start {
402            let src_start = copy_start - self.buffer_start_sample;
403            let dst_start = copy_start - window_start;
404            let len = copy_end - copy_start;
405            window[dst_start..dst_start + len]
406                .copy_from_slice(&self.audio_buffer[src_start..src_start + len]);
407        }
408
409        (window, left_encoder_frames, chunk_encoder_frames)
410    }
411
412    fn trim_audio_buffer(&mut self) {
413        let keep_from = self
414            .next_chunk_start_sample
415            .saturating_sub(self.streaming_config.left_context_samples());
416        if keep_from <= self.buffer_start_sample {
417            return;
418        }
419
420        let drop = keep_from - self.buffer_start_sample;
421        if drop == 0 {
422            return;
423        }
424
425        if drop >= self.audio_buffer.len() {
426            self.audio_buffer.clear();
427            self.buffer_start_sample = keep_from;
428            return;
429        }
430
431        self.audio_buffer.drain(0..drop);
432        self.buffer_start_sample = keep_from;
433    }
434
435    fn decode_encoder_frames(
436        &mut self,
437        encoder_out: &Array3<f32>,
438        start_frame: usize,
439        end_frame: usize,
440        absolute_frame_offset: usize,
441    ) -> Result<Vec<(usize, usize)>> {
442        let mut tokens = Vec::new();
443        let hidden_dim = encoder_out.shape()[1];
444        let end_frame = end_frame.min(encoder_out.shape()[2]);
445
446        // Hold the lock once across the decoder loop to avoid per-step acquire/release.
447        let mut model = self
448            .model
449            .lock()
450            .map_err(|e| Error::Model(format!("Failed to acquire model lock: {e}")))?;
451
452        for frame_idx in start_frame..end_frame {
453            let frame = encoder_out
454                .slice(ndarray::s![0, .., frame_idx])
455                .to_owned()
456                .to_shape((1, hidden_dim, 1))
457                .map_err(|e| Error::Model(format!("Failed to reshape encoder frame: {e}")))?
458                .to_owned();
459
460            let absolute_frame = absolute_frame_offset + (frame_idx - start_frame);
461
462            for _ in 0..MAX_SYMBOLS_PER_STEP {
463                let (logits, new_state_1, new_state_2) = model.run_decoder(
464                    &frame,
465                    self.last_token,
466                    &self.state_1,
467                    &self.state_2,
468                )?;
469
470                let token_id = logits
471                    .iter()
472                    .enumerate()
473                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
474                    .map(|(idx, _)| idx)
475                    .unwrap_or(self.blank_id);
476
477                if token_id == self.blank_id {
478                    break;
479                }
480
481                tokens.push((token_id, absolute_frame));
482                self.last_token = token_id as i32;
483                self.state_1 = new_state_1;
484                self.state_2 = new_state_2;
485            }
486        }
487
488        Ok(tokens)
489    }
490
491    fn encoder_frame_to_seconds(frame: usize) -> f32 {
492        (frame * SUBSAMPLING_FACTOR * HOP_LENGTH) as f32 / SAMPLE_RATE as f32
493    }
494
495    fn tokens_to_timed(&self, tokens: &[(usize, usize)]) -> Vec<TimedToken> {
496        tokens
497            .iter()
498            .filter(|(id, _)| *id < self.blank_id)
499            .map(|&(id, frame)| TimedToken {
500                text: self.vocab.decode_single(id),
501                start: Self::encoder_frame_to_seconds(frame),
502                end: Self::encoder_frame_to_seconds(frame + 1),
503            })
504            .collect()
505    }
506
507    fn decode_incremental_tokens(&self, tokens: &[(usize, usize)]) -> String {
508        let mut text = String::new();
509        for &(token, _) in tokens {
510            if token < self.blank_id {
511                text.push_str(&self.vocab.decode_single(token));
512            }
513        }
514        text
515    }
516
517    fn transcribe_offline(
518        &mut self,
519        audio: Vec<f32>,
520        sample_rate: u32,
521        channels: u16,
522        mode: Option<TimestampMode>,
523    ) -> Result<TranscriptionResult> {
524        self.reset();
525
526        let features = audio::extract_features_raw(audio, sample_rate, channels, &self.preprocessor_config)?;
527        let (encoded, encoded_len) = {
528            let mut model = self
529                .model
530                .lock()
531                .map_err(|e| Error::Model(format!("Failed to acquire model lock: {e}")))?;
532            model.run_encoder(&features)?
533        };
534        let frame_count = (encoded_len as usize).min(encoded.shape()[2]);
535        let tokens = self.decode_encoder_frames(&encoded, 0, frame_count, 0)?;
536        self.accumulated_tokens = tokens.iter().map(|(id, _)| *id).collect();
537        self.accumulated_timed_tokens = self.tokens_to_timed(&tokens);
538
539        let text = self.get_transcript();
540        let timed = match mode {
541            Some(m) => process_timestamps(&self.accumulated_timed_tokens, m),
542            None => self.accumulated_timed_tokens.clone(),
543        };
544
545        Ok(TranscriptionResult {
546            text,
547            tokens: timed,
548        })
549    }
550}
551
552impl Transcriber for ParakeetUnified {
553    fn transcribe_samples(
554        &mut self,
555        audio: Vec<f32>,
556        sample_rate: u32,
557        channels: u16,
558        mode: Option<TimestampMode>,
559    ) -> Result<TranscriptionResult> {
560        self.transcribe_offline(audio, sample_rate, channels, mode)
561    }
562}
563
#[cfg(test)]
mod tests {
    use super::UnifiedStreamingConfig;

    #[test]
    fn default_streaming_profile_aligns_to_subsampling() {
        let config = UnifiedStreamingConfig::default().validate().unwrap();
        assert_eq!(config.left_context_frames(), 560);
        assert_eq!(config.chunk_frames(), 56);
        assert_eq!(config.right_context_frames(), 56);
        assert_eq!(config.left_context_encoder_frames(), 70);
        assert_eq!(config.chunk_encoder_frames(), 7);
    }

    #[test]
    fn default_streaming_profile_sample_counts() {
        let config = UnifiedStreamingConfig::default();
        assert_eq!(config.total_window_frames(), 672);
        assert_eq!(config.left_context_samples(), 89_600);
        assert_eq!(config.chunk_samples(), 8_960);
        assert_eq!(config.right_context_samples(), 8_960);
        assert_eq!(config.total_window_samples(), 107_520);
    }

    #[test]
    fn zero_length_chunk_is_rejected() {
        let config = UnifiedStreamingConfig {
            chunk_secs: 0.0,
            ..UnifiedStreamingConfig::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn misaligned_chunk_is_rejected() {
        // 0.5 s maps to 50 mel frames, which is not a multiple of 8.
        let config = UnifiedStreamingConfig {
            chunk_secs: 0.5,
            ..UnifiedStreamingConfig::default()
        };
        assert!(config.validate().is_err());
    }
}