Skip to main content

svod_model/gigaam/
transcribe.rs

1//! High-level `Transcriber` over a unified [`GigaAm`].
2//!
//! Replaces the per-example pipeline (mel extract → Silero VAD → chunker →
//! sized JIT prepare → batched pack → encoder execute → per-head decode) with
//! one stateful wrapper. Construction builds the front-ends and prepares every
//! JIT eagerly, sized from the splitter's advertised maximum chunk length
//! (clamped to encoder capacity); each [`Transcriber::transcribe`] call then
//! only splits, packs and executes against that fixed `(max_batch, max_t_mel)`
//! plan, rejecting chunks that exceed it.
9//!
10//! Hides the CTC vs RN-T asymmetry behind one return type:
11//! [`TranscribeResult`] with optional per-word timestamps for both heads.
//! SentencePiece `▁ → space` post-processing, the encoder-output transpose
//! that RN-T needs (`[d_model, T_sub] → [T_sub, d_model]`), and the CTC
//! word grouping (`ctc_frames_to_words`) all happen in the per-head decode
//! step of [`Transcriber::transcribe_chunks`].
15
16use bon::bon;
17use snafu::{ResultExt, Snafu};
18use svod_arch::ctc::CtcDecoder;
19use svod_arch::rnnt::{RnntDecoder, RnntOpts};
20use svod_tensor::PrepareConfig;
21
22pub use svod_arch::rnnt::Word;
23
24use crate::audio::{AudioChunk, EncoderBounds, MelConfig, MelSpectrogram, Splitter};
25use crate::gigaam::SubsamplingMode;
26use crate::gigaam::ctc::CtcHeadJit;
27use crate::gigaam::jit::GigaAmEncoderJit;
28use crate::gigaam::model::{GigaAm, Head};
29use crate::gigaam::rnnt::RnntStepBackend;
30use crate::jit::InputSpec;
31
32/// User-facing knobs for [`Transcriber::transcribe`].
33///
34/// Construct with [`TranscribeOpts::builder`] (per-field overrides) or
35/// [`TranscribeOpts::from_env`] (read `SVOD_*` env vars with sensible
36/// fallbacks). The two agree — `from_env()` is just `builder().build()` —
37/// so `builder().word_timestamps(true).build()` still consults env for the
38/// rest of the fields.
39///
40/// Field defaults consult these env vars:
41///
42/// | Field             | Env var                | Fallback |
43/// |-------------------|------------------------|----------|
44/// | `word_timestamps` | `SVOD_TIMESTAMPS=1`   | `false`  |
45/// | `beam_decode`     | `SVOD_BEAM_DECODE=1`  | `false`  |
46/// | `max_scores_mib`  | `SVOD_MAX_SCORES_MIB` | `256`    |
47///
48/// VAD-specific knobs (`threshold`, `min_duration`, …) live on
49/// [`SileroVadSplitter`](super::SileroVadSplitter), not here.
50#[derive(Clone, Debug)]
51pub struct TranscribeOpts {
52    /// Emit per-word `Word { text, start, end }` entries on
53    /// [`ChunkResult::words`]. Both heads support this.
54    pub word_timestamps: bool,
55    /// Promote the model's config-default CTC decoder to a beam decoder
56    /// (no-op for RN-T).
57    pub beam_decode: bool,
58    /// Per-allocation budget for the SDPA scores buffer. Caps `max_batch`
59    /// so two simultaneously live `[B, H, T_sub², dtype]` scores tensors
60    /// stay under `2 × max_scores_mib` MiB.
61    pub max_scores_mib: usize,
62}
63
64impl Default for TranscribeOpts {
65    fn default() -> Self {
66        Self::builder().build()
67    }
68}
69
70#[bon]
71impl TranscribeOpts {
72    /// Build via the [`bon`] builder. Each field default consults its
73    /// `SVOD_*` env var (see the struct docs for the full table) before
74    /// falling back to a literal — so `builder().build()` produces the same
75    /// values as [`from_env`](Self::from_env), and partial overrides
76    /// (`.word_timestamps(true).build()`) still env-read the rest.
77    #[builder]
78    pub fn builder(
79        #[builder(default = std::env::var("SVOD_TIMESTAMPS").as_deref() == Ok("1"))] word_timestamps: bool,
80        #[builder(default = std::env::var("SVOD_BEAM_DECODE").as_deref() == Ok("1"))] beam_decode: bool,
81        #[builder(default = std::env::var("SVOD_MAX_SCORES_MIB").ok().and_then(|s| s.parse().ok()).unwrap_or(256))]
82        max_scores_mib: usize,
83    ) -> Self {
84        Self { word_timestamps, beam_decode, max_scores_mib }
85    }
86
87    /// Build from `SVOD_*` env vars with the same fallbacks as the
88    /// builder. Equivalent to `Self::builder().build()`.
89    pub fn from_env() -> Self {
90        Self::builder().build()
91    }
92}
93
94/// Aggregated transcription output. `text` is the chunk texts joined by a
95/// single space (empty chunks dropped); [`words`](Self::words) flattens word
96/// timestamps across chunks (shifted by each chunk's `start_sec`).
97#[derive(Clone, Debug)]
98pub struct TranscribeResult {
99    pub text: String,
100    pub chunks: Vec<ChunkResult>,
101}
102
103impl TranscribeResult {
104    /// Iterate per-word timestamps across all chunks, shifted into the
105    /// original audio's timeline. Empty if `opts.word_timestamps` was false.
106    pub fn words(&self) -> impl Iterator<Item = Word> + '_ {
107        self.chunks.iter().flat_map(|c| {
108            let offset = c.start_sec;
109            c.words.iter().flatten().map(move |w| Word {
110                text: w.text.clone(),
111                start: w.start + offset,
112                end: w.end + offset,
113            })
114        })
115    }
116}
117
118/// One VAD-bound speech region's transcript. `start_sec`/`end_sec` reference
119/// the original audio. `words` is `Some` iff [`TranscribeOpts::word_timestamps`]
120/// was set; each entry's `start`/`end` is **chunk-relative** (add `start_sec`
121/// to get audio-absolute, or use [`TranscribeResult::words`]).
122#[derive(Clone, Debug)]
123pub struct ChunkResult {
124    pub start_sec: f32,
125    pub end_sec: f32,
126    pub text: String,
127    pub words: Option<Vec<Word>>,
128}
129
130/// Per-head decoder + JIT state. CTC needs a bounds-tied head JIT (Conv1d
131/// projection); RN-T's predictor/joint JITs ride with [`RnntStepBackend`].
132/// One instance per `Transcriber`, so the variant-size disparity is
133/// irrelevant — boxing would just add an allocation.
134#[allow(clippy::large_enum_variant)]
135pub(crate) enum HeadDecoder {
136    Ctc { jit: CtcHeadJit, decoder: CtcDecoder },
137    Rnnt { backend: RnntStepBackend, decoder: RnntDecoder, sentencepiece: bool },
138}
139
140/// CTC equivalent of [`RnntDecoder::frames_to_words`].
141///
142/// Walks the decoded `text` in lockstep with `frames` (CTC's
143/// `decode_with_timestamps` returns one frame index per emitted *token*; for
144/// GigaAM's char-level vocab one token == one Unicode scalar, so `text.chars()`
145/// is the right zip target). Splits on ASCII space — no SentencePiece on the
146/// CTC side. Returns chunk-relative `[start, end)` in seconds.
147pub(crate) fn ctc_frames_to_words(text: &str, frames: &[usize], frame_shift: f32) -> Vec<Word> {
148    let mut words: Vec<Word> = Vec::new();
149    let mut current = String::new();
150    let mut first_frame = 0usize;
151    let mut last_frame = 0usize;
152
153    let commit = |words: &mut Vec<Word>, current: &mut String, first: usize, last: usize| {
154        if !current.is_empty() {
155            words.push(Word {
156                text: std::mem::take(current),
157                start: first as f32 * frame_shift,
158                end: (last + 1) as f32 * frame_shift,
159            });
160        }
161    };
162
163    for (ch, &frame) in text.chars().zip(frames.iter()) {
164        if ch == ' ' {
165            commit(&mut words, &mut current, first_frame, last_frame);
166            continue;
167        }
168        if current.is_empty() {
169            first_frame = frame;
170        }
171        current.push(ch);
172        last_frame = frame;
173    }
174    commit(&mut words, &mut current, first_frame, last_frame);
175    words
176}
177
178/// Transpose `[d_model, t_exec_sub]` row-major → `[actual_sub, d_model]`.
179/// `actual_sub <= t_exec_sub` (the JIT pads frames beyond `actual_sub`); only
180/// the first `actual_sub` frames are read.
181fn transpose_dt_to_td(src: &[f32], d_model: usize, t_exec_sub: usize, actual_sub: usize) -> Vec<f32> {
182    let mut out = vec![0.0_f32; actual_sub * d_model];
183    for t in 0..actual_sub {
184        for d in 0..d_model {
185            out[t * d_model + d] = src[d * t_exec_sub + t];
186        }
187    }
188    out
189}
190
191fn rnnt_decode_err<E: std::error::Error + 'static>(
192    e: svod_arch::rnnt::RnntDecodeError<crate::jit::JitError>,
193) -> TranscribeError<E> {
194    TranscribeError::RnntDecode { source: Box::new(e) }
195}
196
197// ─── Errors ───────────────────────────────────────────────────────────────
198
199/// Generic over the splitter error type so per-impl errors stay
200/// pattern-matchable rather than being type-erased into `Box<dyn Error>`.
201/// Mirrors the `svod_arch::rnnt::RnntDecodeError<JitError>` shape.
202#[derive(Debug, Snafu)]
203#[snafu(visibility(pub(crate)))]
204pub enum TranscribeError<E: std::error::Error + 'static> {
205    #[snafu(display("splitter: {source}"))]
206    Splitter { source: E },
207    #[snafu(display("{source}"))]
208    Jit {
209        #[snafu(source(from(crate::jit::JitError, Box::new)))]
210        source: Box<crate::jit::JitError>,
211    },
212    #[snafu(display("{source}"))]
213    CtcDecode { source: svod_arch::ctc::DecodeError },
214    #[snafu(display("{source}"))]
215    RnntDecode { source: Box<svod_arch::rnnt::RnntDecodeError<crate::jit::JitError>> },
216    #[snafu(display("{source}"))]
217    Model {
218        #[snafu(source(from(crate::gigaam::error::Error, Box::new)))]
219        source: Box<crate::gigaam::error::Error>,
220    },
221    #[snafu(display("{source}"))]
222    Tensor {
223        #[snafu(source(from(svod_tensor::error::Error, Box::new)))]
224        source: Box<svod_tensor::error::Error>,
225    },
226    #[snafu(display("{source}"))]
227    Device {
228        #[snafu(source(from(svod_device::error::Error, Box::new)))]
229        source: Box<svod_device::error::Error>,
230    },
231    #[snafu(display("WAV is {wav_sr} Hz, model expects {model_sr} Hz (resample first)"))]
232    SampleRateMismatch { wav_sr: u32, model_sr: u32 },
233    #[snafu(display("chunk {idx} length {samples} samples exceeds encoder capacity {max_samples} samples"))]
234    ChunkExceedsCapacity { idx: usize, samples: usize, max_samples: usize },
235    #[snafu(display("chunk {idx} end {end_sample} exceeds waveform length {waveform_len}"))]
236    ChunkOutOfRange { idx: usize, end_sample: usize, waveform_len: usize },
237}
238
239// ─── Transcriber ──────────────────────────────────────────────────────────
240
241/// High-level transcription wrapper, generic over the chunking strategy.
242/// JITs are prepared eagerly at construction; the splitter advertises its
243/// max chunk length so JIT buffers can be sized tighter than the encoder's
244/// hard ceiling. Use [`transcribe_chunks`](Self::transcribe_chunks) to
245/// bypass the splitter for pre-segmented audio.
246pub struct Transcriber<S: Splitter> {
247    model: GigaAm,
248    opts: TranscribeOpts,
249    splitter: S,
250    mel: MelSpectrogram,
251    head_decoder: HeadDecoder,
252    encoder_jit: GigaAmEncoderJit,
253    max_batch: usize,
254    max_t_mel: usize,
255}
256
257impl<S: Splitter> Transcriber<S> {
258    /// Build the transcriber and prepare every JIT eagerly — subsequent
259    /// `transcribe` calls just execute. `model` is cloned into each JIT
260    /// (cheap: weights are shared via `Tensor` handle Arcs).
261    pub fn new(model: GigaAm, splitter: S, opts: TranscribeOpts) -> Result<Self, TranscribeError<S::Error>> {
262        let mel = MelSpectrogram::new(&MelConfig {
263            sample_rate: model.config.sample_rate,
264            n_fft: model.config.n_fft,
265            hop_length: model.config.hop_length,
266            win_length: model.config.win_length,
267            n_mels: model.config.n_mels,
268            center: model.config.mel_center,
269        });
270
271        let subsampling_factor = model.config.subsampling_factor;
272        let hop_length = model.config.hop_length;
273        let model_bounds = EncoderBounds {
274            sample_rate: model.config.sample_rate as u32,
275            hop_length,
276            subsampling_factor,
277            max_mel_frames: model.config.max_mel_frames,
278        };
279        // Splitter advertises its emission ceiling; clamp to encoder
280        // capacity, then round up to the next power of two so the JIT
281        // codegen sees a clean factorisation.
282        let chunk_samples_cap = splitter.max_chunk_samples(&model_bounds).min(model_bounds.max_samples());
283        let chunk_mel = (chunk_samples_cap / hop_length).saturating_add(2 * subsampling_factor);
284        let max_t_mel = chunk_mel.max(1).next_power_of_two().min(model.config.max_mel_frames).max(subsampling_factor);
285
286        // SDPA scores `[B, H, T_sub², dtype]` are live twice during attention;
287        // budget `max_batch` so they stay under `2 * max_scores_mib`.
288        let t_sub_max = (max_t_mel / subsampling_factor).max(1);
289        let scores_dtype_bytes = model.encoder.input_dtype().bytes();
290        let bytes_per_batch = model.config.n_heads * t_sub_max * t_sub_max * scores_dtype_bytes;
291        let target_scores_bytes = opts.max_scores_mib * 1024 * 1024;
292        let max_batch_by_memory = (target_scores_bytes / bytes_per_batch.max(1)).max(1);
293        let max_batch = max_batch_by_memory.min(model.config.max_batch_size);
294
295        let prepare_config = PrepareConfig::from_env();
296        let mut encoder_jit = GigaAmEncoderJit::new(model.clone()).with_b_bound(max_batch).with_t_bound(max_t_mel);
297        encoder_jit
298            .prepare_with_config(
299                InputSpec::f32(&[max_batch, model.config.n_mels, max_t_mel]),
300                InputSpec::i32(&[max_batch]),
301                &prepare_config,
302            )
303            .context(JitSnafu)?;
304
305        let head_decoder = match &model.head {
306            Head::Ctc(_) => {
307                let decoder = if opts.beam_decode {
308                    match &model.config.decoder {
309                        CtcDecoder::Greedy(g) => CtcDecoder::Beam(Box::new(svod_arch::ctc::BeamDecoder::new(
310                            g.vocabulary().to_vec(),
311                            svod_arch::ctc::BeamOpts::default(),
312                        ))),
313                        other => other.clone(),
314                    }
315                } else {
316                    model.config.decoder.clone()
317                };
318                let subs_kernel_size = match model.config.subsampling_mode {
319                    SubsamplingMode::Conv1d => model.config.subs_kernel_size,
320                    SubsamplingMode::Conv2d => 3,
321                };
322                let max_t_sub = subs_output_length(subs_kernel_size, max_t_mel);
323                let mut jit = CtcHeadJit::new(model.clone()).with_b_bound(max_batch).with_t_sub_bound(max_t_sub);
324                jit.prepare_with_config(InputSpec::f32(&[max_batch, model.config.d_model, max_t_sub]), &prepare_config)
325                    .context(JitSnafu)?;
326                HeadDecoder::Ctc { jit, decoder }
327            }
328            Head::Rnnt { runtime, .. } => {
329                let backend = RnntStepBackend::from_model(model.clone()).context(JitSnafu)?;
330                let decoder = RnntDecoder::new(
331                    runtime.vocabulary.clone(),
332                    RnntOpts { max_symbols_per_step: runtime.max_symbols_per_step },
333                );
334                HeadDecoder::Rnnt { backend, decoder, sentencepiece: runtime.sentencepiece }
335            }
336        };
337
338        Ok(Self { model, opts, splitter, mel, head_decoder, encoder_jit, max_batch, max_t_mel })
339    }
340
341    /// Encoder bounds at the model's full capacity. Passed to splitters
342    /// at split time so they can clamp chunks to the encoder's ceiling.
343    pub fn encoder_bounds(&self, sample_rate: u32) -> Result<EncoderBounds, TranscribeError<S::Error>> {
344        self.bounds_with(sample_rate, self.model.config.max_mel_frames)
345    }
346
347    /// Encoder bounds tightened to this transcriber's prepared JIT capacity.
348    fn prepared_bounds(&self, sample_rate: u32) -> Result<EncoderBounds, TranscribeError<S::Error>> {
349        self.bounds_with(sample_rate, self.max_t_mel)
350    }
351
352    fn bounds_with(&self, sample_rate: u32, max_mel_frames: usize) -> Result<EncoderBounds, TranscribeError<S::Error>> {
353        if sample_rate as usize != self.model.config.sample_rate {
354            return Err(TranscribeError::SampleRateMismatch {
355                wav_sr: sample_rate,
356                model_sr: self.model.config.sample_rate as u32,
357            });
358        }
359        Ok(EncoderBounds {
360            sample_rate,
361            hop_length: self.model.config.hop_length,
362            subsampling_factor: self.model.config.subsampling_factor,
363            max_mel_frames,
364        })
365    }
366
367    /// Transcribe a waveform end-to-end: bounds → splitter → mel → batched
368    /// encoder → per-chunk head decode. `waveform` is fp32 PCM in `[-1, 1]`;
369    /// the model expects `model.config.sample_rate` (returns
370    /// [`TranscribeError::SampleRateMismatch`] otherwise).
371    pub fn transcribe(
372        &mut self,
373        waveform: &[f32],
374        sample_rate: u32,
375    ) -> Result<TranscribeResult, TranscribeError<S::Error>> {
376        let bounds = self.encoder_bounds(sample_rate)?;
377        let chunks = self.splitter.split(waveform, &bounds).context(SplitterSnafu)?;
378        self.transcribe_chunks(waveform, sample_rate, &chunks)
379    }
380
381    /// Escape hatch: caller-supplied chunks. Validates each chunk against
382    /// encoder capacity (`ChunkExceedsCapacity`) and the waveform's bounds
383    /// (`ChunkOutOfRange`) rather than silently truncating. Misaligned
384    /// boundaries are accepted — the mel/JIT pipeline pads the trailing
385    /// fractional frame.
386    pub fn transcribe_chunks(
387        &mut self,
388        waveform: &[f32],
389        sample_rate: u32,
390        chunks: &[AudioChunk],
391    ) -> Result<TranscribeResult, TranscribeError<S::Error>> {
392        // Validate against the prepared JIT capacity, not the model's
393        // worst case — oversized chunks must error here, not inside the JIT.
394        let max_samples = self.prepared_bounds(sample_rate)?.max_samples();
395        for (idx, chunk) in chunks.iter().enumerate() {
396            if chunk.end_sample > waveform.len() {
397                return Err(TranscribeError::ChunkOutOfRange {
398                    idx,
399                    end_sample: chunk.end_sample,
400                    waveform_len: waveform.len(),
401                });
402            }
403            let samples = chunk.end_sample.saturating_sub(chunk.start_sample);
404            if samples > max_samples {
405                return Err(TranscribeError::ChunkExceedsCapacity { idx, samples, max_samples });
406            }
407        }
408
409        let n_mels = self.mel.n_mels();
410        if chunks.is_empty() {
411            return Ok(TranscribeResult { text: String::new(), chunks: Vec::new() });
412        }
413
414        let sample_rate_hz = self.model.config.sample_rate;
415        let d_model = self.model.config.d_model;
416        let subs_kernel_size = match self.model.config.subsampling_mode {
417            SubsamplingMode::Conv1d => self.model.config.subs_kernel_size,
418            SubsamplingMode::Conv2d => 3,
419        };
420        let max_t_mel = self.max_t_mel;
421        let max_t_sub = subs_output_length(subs_kernel_size, max_t_mel);
422        let max_batch = self.max_batch;
423        let want_words = self.opts.word_timestamps;
424
425        // (start_sample, end_sample, mel_len, start_sec, end_sec) per chunk.
426        let chunks_meta: Vec<(usize, usize, usize, f32, f32)> = chunks
427            .iter()
428            .filter_map(|c| {
429                let mel_len = self.mel.num_frames(c.end_sample.saturating_sub(c.start_sample));
430                if mel_len == 0 {
431                    return None;
432                }
433                let start_sec = c.start_sample as f32 / sample_rate_hz as f32;
434                let end_sec = c.end_sample as f32 / sample_rate_hz as f32;
435                Some((c.start_sample, c.end_sample, mel_len, start_sec, end_sec))
436            })
437            .collect();
438        if chunks_meta.is_empty() {
439            return Ok(TranscribeResult { text: String::new(), chunks: Vec::new() });
440        }
441
442        let num_chunks = chunks_meta.len();
443        let mut chunk_results: Vec<ChunkResult> = Vec::with_capacity(num_chunks);
444        for chunk_batch_start in (0..num_chunks).step_by(max_batch) {
445            let b = (num_chunks - chunk_batch_start).min(max_batch);
446            let mut chunk_lengths = vec![0usize; b];
447
448            let batch_mels: Vec<Vec<f32>> = (0..b)
449                .map(|bi| {
450                    let &(start_sample, end_sample, valid, _, _) = &chunks_meta[chunk_batch_start + bi];
451                    let mut chunk_mel = ndarray::Array3::<f32>::zeros((1, n_mels, valid));
452                    {
453                        let mut view = chunk_mel.view_mut().into_dyn();
454                        self.mel.forward_into(&waveform[start_sample..end_sample], &mut view);
455                    }
456                    chunk_mel.as_slice().expect("contiguous chunk mel").to_vec()
457                })
458                .collect();
459
460            // Pack mel into encoder JIT input buffer.
461            {
462                let buf = self.encoder_jit.mel_mut().context(JitSnafu)?;
463                let mut view = buf.as_array_mut::<f32>().context(DeviceSnafu)?;
464                let slice = view.as_slice_mut().expect("contiguous mel buffer");
465                slice.fill(0.0);
466                for (bi, chunk_len) in chunk_lengths.iter_mut().enumerate() {
467                    let &(_, _, valid, _, _) = &chunks_meta[chunk_batch_start + bi];
468                    *chunk_len = valid;
469                    let chunk_mel = &batch_mels[bi];
470                    for mel_bin in 0..n_mels {
471                        let src = mel_bin * valid;
472                        let dst = ((bi * n_mels) + mel_bin) * max_t_mel;
473                        slice[dst..dst + valid].copy_from_slice(&chunk_mel[src..src + valid]);
474                    }
475                }
476            }
477            // Pack lengths into encoder JIT.
478            {
479                let buf = self.encoder_jit.lengths_mut().context(JitSnafu)?;
480                let mut view = buf.as_array_mut::<i32>().context(DeviceSnafu)?;
481                let slice = view.as_slice_mut().expect("contiguous lengths buffer");
482                slice.fill(0);
483                for (i, len) in chunk_lengths.iter().enumerate() {
484                    slice[i] = *len as i32;
485                }
486            }
487
488            let t_exec = chunk_lengths.iter().copied().max().unwrap_or(1).max(1);
489            let t_exec_sub = subs_output_length(subs_kernel_size, t_exec);
490            self.encoder_jit.execute_with_vars(&[("b", b as i64), ("t", t_exec as i64)]).context(JitSnafu)?;
491
492            // CTC chains the encoder output into the head JIT once per batch
493            // then decodes per item; RN-T decodes the encoder output
494            // directly per item (its JITs ride with the backend).
495            match &mut self.head_decoder {
496                HeadDecoder::Ctc { jit, decoder } => {
497                    // Chain encoder output [b, d_model, t_exec_sub] into the
498                    // head input slab [max_batch, d_model, max_t_sub].
499                    {
500                        let n = b * d_model * t_exec_sub;
501                        let src_flat =
502                            self.encoder_jit.output().context(JitSnafu)?.as_array::<f32>().context(DeviceSnafu)?;
503                        let src_3d = src_flat
504                            .slice(ndarray::s![0..n])
505                            .into_shape_with_order((b, d_model, t_exec_sub))
506                            .expect("encoder output reshape");
507                        let dst_flat =
508                            jit.encoded_mut().context(JitSnafu)?.as_array_mut::<f32>().context(DeviceSnafu)?;
509                        let mut dst_3d = dst_flat
510                            .into_shape_with_order((max_batch, d_model, max_t_sub))
511                            .expect("head input reshape");
512                        dst_3d.slice_mut(ndarray::s![0..b, 0..d_model, 0..t_exec_sub]).assign(&src_3d);
513                    }
514                    jit.execute_with_vars(&[("b", b as i64), ("t_sub", t_exec_sub as i64)]).context(JitSnafu)?;
515
516                    let total_vocab = decoder.total_vocab();
517                    let item_stride = t_exec_sub * total_vocab;
518                    let logits_buf = jit.output().context(JitSnafu)?;
519                    let logits = logits_buf.as_array::<f32>().context(DeviceSnafu)?;
520                    let flat = logits.as_slice().expect("contiguous head logits");
521                    for (bi, mel_len) in chunk_lengths.iter().enumerate() {
522                        let actual_sub = subs_output_length(subs_kernel_size, *mel_len);
523                        let &(start_sample, end_sample, _, start_sec, end_sec) = &chunks_meta[chunk_batch_start + bi];
524                        let chunk_duration_sec = (end_sample - start_sample) as f32 / sample_rate_hz as f32;
525                        let frame_shift = chunk_duration_sec / (actual_sub.max(1) as f32);
526
527                        let item_slice = &flat[bi * item_stride..bi * item_stride + item_stride];
528
529                        let (text, frames) = if want_words {
530                            let (text, frames) = decoder
531                                .decode_with_timestamps(item_slice, t_exec_sub, actual_sub)
532                                .context(CtcDecodeSnafu)?;
533                            (text, Some(frames))
534                        } else {
535                            let text = decoder.decode(item_slice, t_exec_sub, actual_sub).context(CtcDecodeSnafu)?;
536                            (text, None)
537                        };
538                        let words = want_words.then(|| {
539                            let frames = frames.as_deref().unwrap_or(&[]);
540                            ctc_frames_to_words(&text, frames, frame_shift)
541                        });
542                        chunk_results.push(ChunkResult { start_sec, end_sec, text, words });
543                    }
544                }
545                HeadDecoder::Rnnt { backend, decoder, sentencepiece } => {
546                    let item_stride = d_model * t_exec_sub;
547                    let enc_buf = self.encoder_jit.output().context(JitSnafu)?;
548                    let enc = enc_buf.as_array::<f32>().context(DeviceSnafu)?;
549                    let flat = enc.as_slice().expect("contiguous encoder output");
550                    for (bi, mel_len) in chunk_lengths.iter().enumerate() {
551                        let actual_sub = subs_output_length(subs_kernel_size, *mel_len);
552                        let &(start_sample, end_sample, _, start_sec, end_sec) = &chunks_meta[chunk_batch_start + bi];
553                        let chunk_duration_sec = (end_sample - start_sample) as f32 / sample_rate_hz as f32;
554                        let frame_shift = chunk_duration_sec / (actual_sub.max(1) as f32);
555
556                        let item_slice = &flat[bi * item_stride..bi * item_stride + item_stride];
557                        // Encoder output is [d_model, t_exec_sub] row-major;
558                        // the arch decoder wants frame-major [actual_sub, d_model].
559                        let frames = transpose_dt_to_td(item_slice, d_model, t_exec_sub, actual_sub);
560
561                        let backend: &mut RnntStepBackend = backend;
562                        let (raw, emissions) = if want_words {
563                            let (s, e) = decoder
564                                .decode_with_timestamps(&frames, actual_sub, actual_sub, d_model, backend)
565                                .map_err(rnnt_decode_err)?;
566                            (s, e)
567                        } else {
568                            let s = decoder
569                                .decode(&frames, actual_sub, actual_sub, d_model, backend)
570                                .map_err(rnnt_decode_err)?;
571                            (s, Vec::new())
572                        };
573                        let words = want_words.then(|| decoder.frames_to_words(&emissions, frame_shift));
574                        // SP pieces carry `▁` (U+2581) as word-initial markers;
575                        // after concatenation we restore them as spaces.
576                        let text = if *sentencepiece { raw.replace('\u{2581}', " ").trim().to_string() } else { raw };
577                        chunk_results.push(ChunkResult { start_sec, end_sec, text, words });
578                    }
579                }
580            }
581        }
582
583        let text =
584            chunk_results.iter().map(|c| c.text.as_str()).filter(|s| !s.is_empty()).collect::<Vec<_>>().join(" ");
585        Ok(TranscribeResult { text, chunks: chunk_results })
586    }
587}
588
589/// Compute the encoder's sub-sampled output frame count from the input
590/// mel-frame count. Mirrors the two-stage 2× stride conv stack used by
591/// GigaAM's subsampling (kernel `subs_kernel_size`, stride 2, applied twice).
592fn subs_output_length(kernel_size: usize, mel_frames: usize) -> usize {
593    let pad = (kernel_size - 1) / 2;
594    let mut len = mel_frames;
595    for _ in 0..2 {
596        len = (len + 2 * pad - kernel_size) / 2 + 1;
597    }
598    len
599}