Skip to main content

rlx_whisper/
runner.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use crate::audio::{MelSpectrogram, N_FRAMES, pcm_to_mel};
17use crate::backend::{
18    WhisperCompileOpts, WhisperGraphCtx, decode_bucket_ladder, decode_cache_key,
19    metal_compile_guard, whisper_decoder_device, whisper_use_gpu_kv,
20};
21use crate::batch::{batched_prompt_f32, replicate_encoder_for_beams};
22use crate::builder::WhisperGraphOpts;
23use crate::cache::{
24    WhisperCrossCache, WhisperKvCache, apply_bucketed_decode_step, cross_from_outputs,
25    kv_from_prefill_outputs,
26};
27use crate::config::WhisperConfig;
28use crate::decode::{
29    EOT_TOKEN, SuppressionMask, batched_logits_row_owned, beam_search_decode_kv,
30    beam_search_decode_kv_batched, initial_prompt_opts, last_logits_row,
31};
32use crate::fused::{FusedDecoderWeights, FusedEncoderWeights};
33use crate::mel::stack_mels;
34use crate::vad::{VadConfig, segments_by_vad};
35use crate::weights::WhisperWeightPrefix;
36use anyhow::{Context, Result, bail, ensure};
37use rlx_core::flow_util::{
38    bucket_cache_ensure_built, compile_cache_ensure_built_with_options, graph_from_built,
39};
40use rlx_core::validate_standard_device;
41use rlx_core::weight_map::WeightMap;
42use rlx_core::{
43    GpuKvBinding, cross_attn_gpu_handles_ready, install_cross_attn_gpu_handles,
44    run_bucketed_kv_decode_gpu, run_bucketed_kv_decode_keyed, sync_gpu_kv_to_host,
45};
46use rlx_ir::DType;
47use rlx_runtime::attn_mask::bucket_decode_mask;
48use rlx_runtime::compile_cache::{BucketedCompileCache, CacheRunInput, CompileCache};
49use rlx_runtime::{CompiledGraph, Device};
50use std::path::{Path, PathBuf};
51use std::sync::Arc;
52
/// Builder for [`WhisperRunner`] (obtain one via [`WhisperRunner::builder`]).
///
/// Only the weights path is required. `mel_frames` and `max_decode_steps` use `0` as
/// "derive a default from the model" inside [`WhisperRunnerBuilder::build`].
#[derive(Debug, Clone)]
pub struct WhisperRunnerBuilder {
    /// Checkpoint file; `build` fails when unset or missing on disk.
    weights: Option<PathBuf>,
    /// `config.json` location; defaults to a sibling of the weights file.
    config_path: Option<PathBuf>,
    /// `tokenizer.json` location; defaults to a sibling of the weights file.
    tokenizer_path: Option<PathBuf>,
    /// Explicit model config; when set, `config_path` is not read.
    config: Option<WhisperConfig>,
    /// Requested device; `Device::Cpu` when unset.
    device: Option<Device>,
    /// Encoder input frames; `0` resolves to `N_FRAMES` in `build`.
    mel_frames: usize,
    /// Decode step cap; `0` resolves to `max_target_positions - 8` (saturating) in `build`.
    max_decode_steps: usize,
    /// Stored on the runner unchanged.
    beam_size: usize,
    /// Stored on the runner unchanged.
    language: Option<String>,
    /// Stored on the runner unchanged.
    translate: bool,
    /// Stored on the runner unchanged.
    timestamps: bool,
    /// Activation dtype; `DType::F16` also selects f16-mixed graph/compile options.
    activation_dtype: DType,
    /// Selects f16-mixed graph/compile options regardless of `activation_dtype`.
    use_f16_compute: bool,
    /// Voice-activity-detection settings, stored on the runner.
    vad_config: Option<VadConfig>,
    /// Stored on the runner; the setter clamps it to at least 1.
    /// NOTE(review): presumably caps how many VAD regions are batched together — confirm.
    max_region_batch: usize,
    /// Encoder/cross attention chunk; only applied when it differs from
    /// `crate::builder::DEFAULT_ENCODER_ATTN_CHUNK`.
    encoder_attn_chunk: usize,
}
72
73impl Default for WhisperRunnerBuilder {
74    fn default() -> Self {
75        Self {
76            weights: None,
77            config_path: None,
78            tokenizer_path: None,
79            config: None,
80            device: None,
81            mel_frames: 0,
82            max_decode_steps: 0,
83            beam_size: 0,
84            language: None,
85            translate: false,
86            timestamps: false,
87            activation_dtype: DType::F32,
88            use_f16_compute: false,
89            vad_config: None,
90            max_region_batch: 10,
91            encoder_attn_chunk: crate::builder::DEFAULT_ENCODER_ATTN_CHUNK,
92        }
93    }
94}
95
96impl WhisperRunnerBuilder {
97    pub fn weights<P: Into<PathBuf>>(mut self, path: P) -> Self {
98        self.weights = Some(path.into());
99        self
100    }
101    pub fn config_path<P: Into<PathBuf>>(mut self, path: P) -> Self {
102        self.config_path = Some(path.into());
103        self
104    }
105    pub fn tokenizer_path<P: Into<PathBuf>>(mut self, path: P) -> Self {
106        self.tokenizer_path = Some(path.into());
107        self
108    }
109    pub fn config(mut self, cfg: WhisperConfig) -> Self {
110        self.config = Some(cfg);
111        self
112    }
113    pub fn device(mut self, d: Device) -> Self {
114        self.device = Some(d);
115        self
116    }
117    pub fn language(mut self, lang: impl Into<String>) -> Self {
118        self.language = Some(lang.into());
119        self
120    }
121    pub fn translate(mut self, on: bool) -> Self {
122        self.translate = on;
123        self
124    }
125    pub fn timestamps(mut self, on: bool) -> Self {
126        self.timestamps = on;
127        self
128    }
129    pub fn activation_dtype(mut self, dt: DType) -> Self {
130        self.activation_dtype = dt;
131        self
132    }
133    pub fn use_f16_compute(mut self, on: bool) -> Self {
134        self.use_f16_compute = on;
135        self
136    }
137    pub fn vad_config(mut self, cfg: VadConfig) -> Self {
138        self.vad_config = Some(cfg);
139        self
140    }
141    pub fn max_region_batch(mut self, n: usize) -> Self {
142        self.max_region_batch = n.max(1);
143        self
144    }
145    pub fn encoder_attn_chunk(mut self, n: usize) -> Self {
146        self.encoder_attn_chunk = n;
147        self
148    }
149    pub fn max_decode_steps(mut self, n: usize) -> Self {
150        self.max_decode_steps = n;
151        self
152    }
153    pub fn beam_size(mut self, n: usize) -> Self {
154        self.beam_size = n;
155        self
156    }
157
158    pub fn build(self) -> Result<WhisperRunner> {
159        let weights_path = self
160            .weights
161            .ok_or_else(|| anyhow::anyhow!("weights path required"))?;
162        if !weights_path.exists() {
163            bail!("weights file not found: {weights_path:?}");
164        }
165        let weights_dir = weights_path
166            .parent()
167            .ok_or_else(|| anyhow::anyhow!("weights path has no parent"))?;
168        let cfg_path = self
169            .config_path
170            .clone()
171            .unwrap_or_else(|| weights_dir.join("config.json"));
172        let cfg = match self.config {
173            Some(c) => c,
174            None => WhisperConfig::from_file(&cfg_path)
175                .with_context(|| format!("reading config {cfg_path:?}"))?,
176        };
177        let tok_path = self
178            .tokenizer_path
179            .clone()
180            .unwrap_or_else(|| weights_dir.join("tokenizer.json"));
181        let device = self.device.unwrap_or(Device::Cpu);
182        validate_standard_device("whisper", device)?;
183        let mel_frames = if self.mel_frames == 0 {
184            N_FRAMES
185        } else {
186            self.mel_frames
187        };
188        let max_decode_steps = if self.max_decode_steps == 0 {
189            cfg.max_target_positions.saturating_sub(8)
190        } else {
191            self.max_decode_steps
192        };
193        let wt = weights_path
194            .to_str()
195            .ok_or_else(|| anyhow::anyhow!("non-utf8 weights path"))?;
196        let mut weights_cache = WeightMap::snapshot_from_path(wt)?;
197        let pfx = {
198            let wm = WeightMap::from_tensors(weights_cache.clone());
199            WhisperWeightPrefix::detect(&wm)
200        };
201        let fused = FusedDecoderWeights::from_checkpoint(&weights_cache, &cfg, &pfx)?;
202        let fused_enc = FusedEncoderWeights::from_checkpoint(&weights_cache, &cfg, &pfx)?;
203        fused.merge_into_tensors(&mut weights_cache);
204        fused_enc.merge_into_tensors(&mut weights_cache);
205        let mut graph_opts = if self.use_f16_compute || self.activation_dtype == DType::F16 {
206            WhisperGraphOpts::f16_mixed()
207        } else {
208            WhisperGraphOpts::default()
209        };
210        if self.encoder_attn_chunk != crate::builder::DEFAULT_ENCODER_ATTN_CHUNK {
211            graph_opts.encoder_attn_chunk = self.encoder_attn_chunk;
212            graph_opts.cross_attn_chunk = self.encoder_attn_chunk;
213        }
214        let suppression = SuppressionMask::from_config(&cfg);
215
216        let f16 = self.use_f16_compute || self.activation_dtype == DType::F16;
217        let mut compile_opts = WhisperCompileOpts::new(device, f16, &weights_path);
218        // Metal / MLX / Vulkan: run encoder + cross + prefill + decode on CPU.
219        let decode_device = whisper_decoder_device(device);
220        let prefill_device = decode_device;
221        if decode_device != device {
222            let cpu_opts = WhisperCompileOpts::new(decode_device, f16, &weights_path);
223            compile_opts.encoder = cpu_opts.encoder.clone();
224            compile_opts.cross = cpu_opts.cross.clone();
225            compile_opts.decode = cpu_opts.decode.clone();
226            compile_opts.prefill = cpu_opts.prefill;
227        }
228        let use_gpu_kv = whisper_use_gpu_kv(device, decode_device);
229
230        let enc_seq = cfg.encoder_seq_len(mel_frames);
231        let weights_cache = Arc::new(weights_cache);
232        let graph_ctx = WhisperGraphCtx {
233            cfg: cfg.clone(),
234            pfx: pfx.clone(),
235            weights: Arc::clone(&weights_cache),
236            enc_seq,
237            mel_frames,
238            graph_opts,
239            fused: Some(fused.clone()),
240            fused_enc: Some(fused_enc.clone()),
241        };
242
243        let mut enc_compile_cache = CompileCache::new(decode_device, 8);
244        let mut cross_compile_cache = CompileCache::new(decode_device, 8);
245        metal_compile_guard(decode_device, || -> Result<()> {
246            compile_cache_ensure_built_with_options(
247                &mut enc_compile_cache,
248                1,
249                graph_ctx.build_encoder(1)?,
250                &compile_opts.encoder,
251            )?;
252            compile_cache_ensure_built_with_options(
253                &mut cross_compile_cache,
254                1,
255                graph_ctx.build_cross(1)?,
256                &compile_opts.cross,
257            )?;
258            Ok(())
259        })?;
260
261        let max_past = cfg.max_target_positions.max(1);
262        let decode_compile_cache = decode_bucket_ladder(decode_device, max_past as u64);
263
264        #[cfg(feature = "tokenizer")]
265        let tokenizer = {
266            ensure!(tok_path.exists(), "tokenizer not found: {tok_path:?}");
267            Some(
268                tokenizers::Tokenizer::from_file(&tok_path)
269                    .map_err(|e| anyhow::anyhow!("load tokenizer {tok_path:?}: {e}"))?,
270            )
271        };
272
273        let cross_input_names: Vec<String> = (0..cfg.decoder_layers)
274            .flat_map(|i| [format!("cross_k_{i}"), format!("cross_v_{i}")])
275            .collect();
276
277        Ok(WhisperRunner {
278            graph_ctx,
279            device,
280            decode_device,
281            prefill_device,
282            activation_dtype: self.activation_dtype,
283            suppression,
284            max_decode_steps,
285            beam_size: self.beam_size,
286            max_region_batch: self.max_region_batch,
287            vad_config: self.vad_config,
288            compile_opts,
289            use_gpu_kv,
290            gpu_kv_binding: GpuKvBinding::default(),
291            cross_gpu_epoch: 0,
292            cross_gpu_bound_epoch: u64::MAX,
293            decode_batch_tag: u64::MAX,
294            enc_compile_cache,
295            cross_compile_cache,
296            prefill_compile_cache: CompileCache::new(prefill_device, 8),
297            decode_compile_cache,
298            decode_token_f32: Vec::new(),
299            decode_pos_ix: Vec::new(),
300            decode_mask: Vec::new(),
301            cross_input_names,
302            language: self.language,
303            translate: self.translate,
304            timestamps: self.timestamps,
305            #[cfg(feature = "tokenizer")]
306            tokenizer,
307        })
308    }
309}
310
/// Stage timings from [`WhisperRunner::bench_greedy_pipeline`].
///
/// Timing fields use the `_ms` suffix, i.e. milliseconds.
#[derive(Debug, Clone)]
pub struct WhisperBenchReport {
    /// Encoder stage timing.
    pub encode_ms: f64,
    /// Cross-attention K/V projection stage timing.
    pub cross_ms: f64,
    /// Prompt prefill stage timing.
    pub prefill_ms: f64,
    /// Decode stage timing.
    pub decode_ms: f64,
    /// Number of decode steps covered by `decode_ms`.
    /// NOTE(review): confirm against `bench_greedy_pipeline`.
    pub decode_steps: usize,
    /// NOTE(review): presumably the end-to-end greedy transcription timing — confirm.
    pub greedy_ms: f64,
    /// Logits after prompt prefill (`[1, prompt_len, vocab]` layout).
    pub last_prefill_logits: Vec<f32>,
}
323
/// Compiled Whisper pipeline: encoder, cross-attention projection, prompt prefill and
/// bucketed incremental decode, each with its own compile cache.
pub struct WhisperRunner {
    /// Config, weights, fused weights and graph options used to build every graph.
    graph_ctx: WhisperGraphCtx,
    /// Device requested at build time.
    pub device: Device,
    /// Device used for bucketed decode graphs (CPU when `device` needs host decoder).
    decode_device: Device,
    /// Device used for prompt prefill (same as [`Self::decode_device`]).
    prefill_device: Device,
    /// Activation dtype requested at build time.
    pub activation_dtype: DType,
    /// Built from the config via `SuppressionMask::from_config`.
    suppression: SuppressionMask,
    /// Resolved decode step cap (see the builder for the `0` default).
    max_decode_steps: usize,
    /// Builder setting, stored unchanged.
    beam_size: usize,
    /// Builder setting (at least 1).
    max_region_batch: usize,
    /// Optional voice-activity-detection settings.
    vad_config: Option<VadConfig>,
    /// Per-stage compile options (encoder / cross / prefill / decode).
    compile_opts: WhisperCompileOpts,
    /// When true, decode steps take the GPU-resident KV path.
    use_gpu_kv: bool,
    /// GPU KV handle state for the decode graphs; reset on batch change or cache swap.
    gpu_kv_binding: GpuKvBinding,
    /// Bumped on each new cross cache; GPU cross handles rebind when epochs differ.
    cross_gpu_epoch: u64,
    /// Epoch of the cross cache last bound to a GPU graph; `u64::MAX` = none yet.
    cross_gpu_bound_epoch: u64,
    /// Batch size the decode bucket ladder was built for; `u64::MAX` before first decode.
    decode_batch_tag: u64,
    /// Encoder graphs keyed by batch size.
    enc_compile_cache: CompileCache,
    /// Cross-attention projection graphs keyed by batch size.
    cross_compile_cache: CompileCache,
    /// Prefill graphs keyed by `decode_cache_key(batch, prompt_len)`.
    prefill_compile_cache: CompileCache,
    /// Bucketed decode graphs keyed by past sequence length.
    decode_compile_cache: BucketedCompileCache,
    /// Reusable decode-step input: token ids as `f32`.
    decode_token_f32: Vec<f32>,
    /// Reusable decode-step input: position index per row.
    decode_pos_ix: Vec<f32>,
    /// Reusable decode-step input: bucket attention mask.
    decode_mask: Vec<f32>,
    /// Precomputed `cross_k_{i}` / `cross_v_{i}` graph input names, interleaved per layer.
    cross_input_names: Vec<String>,
    /// Builder setting, stored unchanged.
    language: Option<String>,
    /// Builder setting, stored unchanged.
    translate: bool,
    /// Builder setting, stored unchanged.
    timestamps: bool,
    /// Loaded from `tokenizer.json` at build time.
    #[cfg(feature = "tokenizer")]
    tokenizer: Option<tokenizers::Tokenizer>,
}
358
359impl WhisperRunner {
360    pub fn builder() -> WhisperRunnerBuilder {
361        WhisperRunnerBuilder::default()
362    }
363
364    pub fn config(&self) -> &WhisperConfig {
365        &self.graph_ctx.cfg
366    }
367
368    /// Number of bucketed decode graphs compiled so far (bench / tuning).
369    pub fn decode_buckets_compiled(&self) -> usize {
370        self.decode_compile_cache.compiled_count()
371    }
372
373    fn prepare_decode_step_inputs(&mut self, tokens: &[u32], past_seq: usize, upper: usize) {
374        self.decode_token_f32.clear();
375        self.decode_token_f32
376            .extend(tokens.iter().map(|&t| t as f32));
377        self.decode_pos_ix.clear();
378        self.decode_pos_ix.resize(tokens.len(), past_seq as f32);
379        let mask = bucket_decode_mask(past_seq, upper);
380        if self.decode_mask.len() != mask.len() {
381            self.decode_mask = mask;
382        } else {
383            self.decode_mask.copy_from_slice(&mask);
384        }
385    }
386
387    pub fn mel_frames(&self) -> usize {
388        self.graph_ctx.mel_frames
389    }
390
391    pub fn enc_seq(&self) -> usize {
392        self.graph_ctx.enc_seq
393    }
394
395    /// Device that runs bucketed decode graphs (may differ from [`Self::device`] on Metal/MLX).
396    pub fn decode_device(&self) -> Device {
397        self.decode_device
398    }
399
400    /// Device that runs encoder, cross, prefill, and decode graphs.
401    pub fn stage_device(&self) -> Device {
402        self.decode_device
403    }
404
405    pub fn uses_gpu_kv(&self) -> bool {
406        self.use_gpu_kv
407    }
408
409    fn ensure_encoder(&mut self, batch: usize) -> Result<()> {
410        let key = batch as u64;
411        if self.enc_compile_cache.contains(key) {
412            return Ok(());
413        }
414        let built = self.graph_ctx.build_encoder(batch)?;
415        let opts = self.compile_opts.encoder.clone();
416        metal_compile_guard(self.decode_device, || -> Result<()> {
417            compile_cache_ensure_built_with_options(
418                &mut self.enc_compile_cache,
419                key,
420                built,
421                &opts,
422            )?;
423            Ok(())
424        })
425    }
426
427    fn bind_cross_gpu_if_needed(
428        compiled: &mut CompiledGraph,
429        cross: &WhisperCrossCache,
430        enc_seq: usize,
431        d_model: usize,
432        n_layers: usize,
433        epoch: u64,
434        bound_epoch: u64,
435        use_gpu: bool,
436    ) -> Result<bool> {
437        if !use_gpu {
438            return Ok(false);
439        }
440        if epoch == bound_epoch && cross_attn_gpu_handles_ready(compiled) {
441            return Ok(true);
442        }
443        install_cross_attn_gpu_handles(compiled, cross, enc_seq, d_model, n_layers)?;
444        Ok(true)
445    }
446
447    fn ensure_cross(&mut self, batch: usize) -> Result<()> {
448        let key = batch as u64;
449        if self.cross_compile_cache.contains(key) {
450            return Ok(());
451        }
452        let built = self.graph_ctx.build_cross(batch)?;
453        let opts = self.compile_opts.cross.clone();
454        metal_compile_guard(self.decode_device, || -> Result<()> {
455            compile_cache_ensure_built_with_options(
456                &mut self.cross_compile_cache,
457                key,
458                built,
459                &opts,
460            )?;
461            Ok(())
462        })
463    }
464
465    pub fn encode_mel(&mut self, mel: &MelSpectrogram) -> Result<Vec<f32>> {
466        ensure!(
467            mel.n_frames == self.graph_ctx.mel_frames,
468            "mel frame count mismatch"
469        );
470        self.ensure_encoder(1)?;
471        let key = 1u64;
472        metal_compile_guard(self.decode_device, || {
473            self.enc_compile_cache
474                .get_or_compile(key, || panic!("encoder cache missing"))
475                .run(&[("mel", &mel.data)])
476        })
477        .into_iter()
478        .next()
479        .ok_or_else(|| anyhow::anyhow!("encoder produced no output"))
480    }
481
482    pub fn encode_pcm(&mut self, samples: &[f32]) -> Result<Vec<f32>> {
483        let mel = pcm_to_mel(&self.graph_ctx.cfg, samples);
484        self.encode_mel(&mel)
485    }
486
487    pub fn encode_wav(&mut self, path: &Path) -> Result<Vec<f32>> {
488        let samples = crate::audio::load_wav_mono_f32(path)?;
489        self.encode_pcm(&samples)
490    }
491
492    fn cross_cache(&mut self, enc: &[f32]) -> Result<WhisperCrossCache> {
493        self.ensure_cross(1)?;
494        let outs = metal_compile_guard(self.decode_device, || {
495            self.cross_compile_cache
496                .get_or_compile(1, || panic!("cross cache missing"))
497                .run(&[("encoder_hidden", enc)])
498        });
499        let cross = cross_from_outputs(
500            self.graph_ctx.cfg.decoder_layers,
501            1,
502            self.graph_ctx.enc_seq,
503            self.graph_ctx.cfg.d_model,
504            &outs,
505        )
506        .map_err(|e| anyhow::anyhow!(e))?;
507        self.cross_gpu_epoch = self.cross_gpu_epoch.saturating_add(1);
508        Ok(cross)
509    }
510
511    pub fn prefill_prompt(
512        &mut self,
513        cross: &WhisperCrossCache,
514        prompt_tokens: &[u32],
515        batch: usize,
516    ) -> Result<(Vec<f32>, WhisperKvCache)> {
517        let dec_seq = prompt_tokens.len();
518        let key = decode_cache_key(batch, dec_seq);
519
520        metal_compile_guard(self.prefill_device, || {
521            compile_cache_ensure_built_with_options(
522                &mut self.prefill_compile_cache,
523                key,
524                self.graph_ctx.build_prefill(batch, dec_seq)?,
525                &self.compile_opts.prefill,
526            )
527        })?;
528        let token_f32 = if batch == 1 {
529            prompt_tokens.iter().map(|&t| t as f32).collect()
530        } else {
531            batched_prompt_f32(prompt_tokens, batch)
532        };
533        let enc_seq = self.graph_ctx.enc_seq;
534        let d_model = self.graph_ctx.cfg.d_model;
535        let n_layers = self.graph_ctx.cfg.decoder_layers;
536        let epoch = self.cross_gpu_epoch;
537        let bound_epoch = self.cross_gpu_bound_epoch;
538        let use_gpu = self.use_gpu_kv;
539        let mut cross_on_gpu = use_gpu && bound_epoch == epoch;
540        let cross_bound = {
541            let prefill = self
542                .prefill_compile_cache
543                .get_or_compile(key, || panic!("prefill cache missing"));
544            Self::bind_cross_gpu_if_needed(
545                prefill,
546                cross,
547                enc_seq,
548                d_model,
549                n_layers,
550                epoch,
551                bound_epoch,
552                use_gpu,
553            )?
554        };
555        if cross_bound {
556            self.cross_gpu_bound_epoch = epoch;
557            cross_on_gpu = true;
558        }
559        let prefill = self
560            .prefill_compile_cache
561            .get_or_compile(key, || panic!("prefill cache missing"));
562        let mut inputs: Vec<(&str, &[f32])> = vec![("token_ids", &token_f32)];
563        if !cross_on_gpu {
564            for i in 0..self.graph_ctx.cfg.decoder_layers {
565                inputs.push((
566                    self.cross_input_names[2 * i].as_str(),
567                    cross.layers_k[i].as_slice(),
568                ));
569                inputs.push((
570                    self.cross_input_names[2 * i + 1].as_str(),
571                    cross.layers_v[i].as_slice(),
572                ));
573            }
574        }
575        let outputs = metal_compile_guard(self.prefill_device, || prefill.run(&inputs));
576        ensure!(!outputs.is_empty(), "prefill returned no outputs");
577        let logits = outputs[0].clone();
578        let kv = kv_from_prefill_outputs(
579            self.graph_ctx.cfg.decoder_layers,
580            batch,
581            dec_seq,
582            self.graph_ctx.cfg.d_model,
583            &outputs[1..],
584        )
585        .map_err(|e| anyhow::anyhow!(e))?;
586        Ok((logits, kv))
587    }
588
589    fn decode_step_bucketed(
590        &mut self,
591        cross: &WhisperCrossCache,
592        token: u32,
593        cache: &mut WhisperKvCache,
594        batch: usize,
595    ) -> Result<Vec<f32>> {
596        self.decode_step_batch(cross, std::slice::from_ref(&token), cache, batch, false)
597    }
598
599    fn decode_step_batch(
600        &mut self,
601        cross: &WhisperCrossCache,
602        tokens: &[u32],
603        cache: &mut WhisperKvCache,
604        batch: usize,
605        sync_kv_to_host: bool,
606    ) -> Result<Vec<f32>> {
607        ensure!(
608            tokens.len() == batch,
609            "decode_step_batch: expected {batch} tokens, got {}",
610            tokens.len()
611        );
612        self.ensure_decode_batch(batch)?;
613        let past_seq = cache.past_len;
614        let bucket_key = past_seq as u64;
615        if self.use_gpu_kv {
616            return self.decode_step_batch_gpu(
617                cross,
618                tokens,
619                cache,
620                batch,
621                bucket_key,
622                past_seq,
623                sync_kv_to_host,
624            );
625        }
626        self.decode_step_batch_host(cross, tokens, cache, batch, bucket_key, past_seq)
627    }
628
    /// GPU-resident KV decode step for `batch` rows at past length `past_seq` (`key` is
    /// `past_seq as u64`). Only reached when `use_gpu_kv` is set (see `decode_step_batch`).
    ///
    /// Steps:
    /// 1. Make sure the decode graph for `key`'s bucket is compiled.
    /// 2. Fill the reusable token / position / mask buffers for that bucket's upper bound.
    /// 3. Bind cross K/V as GPU handles on the compiled graph; if not bound, pass them as
    ///    ordinary inputs.
    /// 4. Run `run_bucketed_kv_decode_gpu`, re-uploading the KV prefix when `refresh_kv`.
    /// 5. Copy GPU KV back into `cache` when `sync_kv_to_host` is set, when the next step
    ///    lands in a different bucket, or always on `Device::Gpu`.
    ///
    /// Returns this step's logits.
    ///
    /// # Errors
    /// Fails when `past_seq` is outside the decode buckets, or when binding, decoding or the
    /// host sync fails.
    ///
    /// # Panics
    /// The graph-builder closure passed to `run_bucketed_kv_decode_gpu` panics (`expect`) if
    /// building a decode step graph fails there.
    ///
    /// NOTE(review): `graph_ctx` and the decode options are cloned on every step so the
    /// closures below do not borrow `self`; confirm these clones are cheap.
    fn decode_step_batch_gpu(
        &mut self,
        cross: &WhisperCrossCache,
        tokens: &[u32],
        cache: &mut WhisperKvCache,
        batch: usize,
        key: u64,
        past_seq: usize,
        sync_kv_to_host: bool,
    ) -> Result<Vec<f32>> {
        let graph_ctx = self.graph_ctx.clone();
        let decode_opts = self.compile_opts.decode.clone();
        let d_model = self.graph_ctx.cfg.d_model;
        let n_layers = self.graph_ctx.cfg.decoder_layers;

        // Compile the bucket's graph up front so cross handles can be bound to it below.
        metal_compile_guard(self.decode_device, || {
            bucket_cache_ensure_built(
                &mut self.decode_compile_cache,
                key,
                |upper| graph_ctx.build_decode_step(batch, upper as usize),
                &decode_opts,
            )
        })
        .ok_or_else(|| anyhow::anyhow!("past_seq {past_seq} outside decode buckets"))?;

        let upper = self
            .decode_upper_for_key(key)
            .ok_or_else(|| anyhow::anyhow!("past_seq {past_seq} outside decode buckets"))?;
        self.prepare_decode_step_inputs(tokens, past_seq, upper);
        let token_f32 = &self.decode_token_f32;
        let pos_ix = &self.decode_pos_ix;
        let mask = &self.decode_mask;
        let mut specs: Vec<CacheRunInput<'_>> = vec![
            CacheRunInput {
                name: "token_id",
                data: token_f32,
                row_inner: None,
            },
            CacheRunInput {
                name: "pos_ix",
                data: pos_ix,
                row_inner: None,
            },
            CacheRunInput {
                name: "mask",
                data: mask,
                row_inner: None,
            },
        ];
        let epoch = self.cross_gpu_epoch;
        let bound_epoch = self.cross_gpu_bound_epoch;
        let use_gpu = self.use_gpu_kv;
        let enc_seq = self.graph_ctx.enc_seq;
        let mut cross_on_gpu = use_gpu && bound_epoch == epoch;
        if let Some(compiled) = self.decode_compile_cache.compiled_for_key_mut(key) {
            if Self::bind_cross_gpu_if_needed(
                compiled,
                cross,
                enc_seq,
                d_model,
                n_layers,
                epoch,
                bound_epoch,
                use_gpu,
            )? {
                self.cross_gpu_bound_epoch = epoch;
                cross_on_gpu = true;
            }
        }
        // Fallback: feed cross K/V as plain inputs when they are not bound on the GPU.
        if !cross_on_gpu {
            for i in 0..n_layers {
                specs.push(CacheRunInput {
                    name: self.cross_input_names[2 * i].as_str(),
                    data: cross.layers_k[i].as_slice(),
                    row_inner: None,
                });
                specs.push(CacheRunInput {
                    name: self.cross_input_names[2 * i + 1].as_str(),
                    data: cross.layers_v[i].as_slice(),
                    row_inner: None,
                });
            }
        }

        // Decide whether the KV prefix must be re-uploaded: always on wgpu; elsewhere only
        // when the bucket changed since the last binding or the graph lacks KV handles.
        let upper_u = upper as u64;
        let prev_upper = self.gpu_kv_binding.upper;
        let bucket_changed = prev_upper != 0 && prev_upper != upper_u;
        let handles_live = self
            .decode_compile_cache
            .compiled_for_key_mut(key)
            .map(|c| c.has_gpu_handle("past_k_0"))
            .unwrap_or(false);
        let refresh_kv = if self.decode_device == Device::Gpu {
            // wgpu handle feeds drift within a bucket; re-upload prefix each step.
            true
        } else {
            bucket_changed || !handles_live
        };

        let logits = metal_compile_guard(self.decode_device, || {
            run_bucketed_kv_decode_gpu(
                &mut self.decode_compile_cache,
                key,
                past_seq,
                cache,
                &mut self.gpu_kv_binding,
                d_model,
                n_layers,
                &specs,
                |upper| {
                    let built = graph_ctx
                        .build_decode_step(batch, upper as usize)
                        .expect("whisper decode step built");
                    graph_from_built(built).expect("whisper decode step graph")
                },
                &decode_opts,
                refresh_kv,
            )
        })?;

        // Keep the host copy of KV current whenever the GPU copy is about to become
        // unusable (next step switches bucket), is always refreshed (wgpu), or on request.
        let force_host_kv = self.decode_device == Device::Gpu;
        let next_upper = self
            .decode_upper_for_key((past_seq + 1) as u64)
            .unwrap_or(upper);
        let leaves_bucket = next_upper != upper;

        if sync_kv_to_host || leaves_bucket || force_host_kv {
            if let Some(compiled) = self.decode_compile_cache.compiled_for_key_mut(key) {
                sync_gpu_kv_to_host(compiled, cache, d_model, n_layers)?;
            }
        }
        Ok(logits)
    }
762
763    fn ensure_decode_batch(&mut self, batch: usize) -> Result<()> {
764        let batch_tag = batch as u64;
765        if self.decode_batch_tag == batch_tag {
766            return Ok(());
767        }
768        self.gpu_kv_binding = GpuKvBinding::default();
769        self.decode_batch_tag = batch_tag;
770        let max_past = self.graph_ctx.cfg.max_target_positions.max(1) as u64;
771        self.decode_compile_cache = decode_bucket_ladder(self.decode_device, max_past);
772        Ok(())
773    }
774
775    fn decode_upper_for_key(&self, key: u64) -> Option<usize> {
776        self.decode_compile_cache.bucket_for(key).and_then(|idx| {
777            self.decode_compile_cache
778                .buckets()
779                .nth(idx)
780                .map(|r| (r.end - 1) as usize)
781        })
782    }
783
    /// Host-KV decode step for `batch` rows at past length `past_seq` (`key` is
    /// `past_seq as u64`).
    ///
    /// Only reached when `use_gpu_kv` is false (see `decode_step_batch`). In that case
    /// `bind_cross_gpu_if_needed` returns `Ok(false)`, so cross K/V are always passed as
    /// inputs. The new K/V rows returned by `run_bucketed_kv_decode_keyed` are appended to
    /// `cache` with `apply_bucketed_decode_step`. Returns this step's logits.
    ///
    /// # Errors
    /// Fails when `past_seq` is outside the decode buckets, or when decoding or appending KV
    /// fails.
    ///
    /// # Panics
    /// The graph-builder closure panics (`expect`) if building a decode step graph fails.
    ///
    /// NOTE(review): unlike the GPU path, the bucket graph is not compiled up front here;
    /// presumably `run_bucketed_kv_decode_keyed` compiles it on demand via the closure —
    /// confirm.
    fn decode_step_batch_host(
        &mut self,
        cross: &WhisperCrossCache,
        tokens: &[u32],
        cache: &mut WhisperKvCache,
        batch: usize,
        key: u64,
        past_seq: usize,
    ) -> Result<Vec<f32>> {
        // Cloned so the builder closure below does not borrow `self`.
        let graph_ctx = self.graph_ctx.clone();
        let d_model = self.graph_ctx.cfg.d_model;
        let n_layers = self.graph_ctx.cfg.decoder_layers;
        let upper = self
            .decode_upper_for_key(key)
            .ok_or_else(|| anyhow::anyhow!("past_seq {past_seq} outside decode buckets"))?;
        self.prepare_decode_step_inputs(tokens, past_seq, upper);
        let token_f32 = &self.decode_token_f32;
        let pos_ix = &self.decode_pos_ix;
        let mask = &self.decode_mask;
        let mut specs: Vec<CacheRunInput<'_>> = vec![
            CacheRunInput {
                name: "token_id",
                data: token_f32,
                row_inner: None,
            },
            CacheRunInput {
                name: "pos_ix",
                data: pos_ix,
                row_inner: None,
            },
            CacheRunInput {
                name: "mask",
                data: mask,
                row_inner: None,
            },
        ];
        let epoch = self.cross_gpu_epoch;
        let bound_epoch = self.cross_gpu_bound_epoch;
        let use_gpu = self.use_gpu_kv;
        let enc_seq = self.graph_ctx.enc_seq;
        let mut cross_on_gpu = use_gpu && bound_epoch == epoch;
        if let Some(compiled) = self.decode_compile_cache.compiled_for_key_mut(key) {
            if Self::bind_cross_gpu_if_needed(
                compiled,
                cross,
                enc_seq,
                d_model,
                n_layers,
                epoch,
                bound_epoch,
                use_gpu,
            )? {
                self.cross_gpu_bound_epoch = epoch;
                cross_on_gpu = true;
            }
        }
        // Cross K/V are fed as ordinary graph inputs when not bound on the GPU.
        if !cross_on_gpu {
            for i in 0..n_layers {
                specs.push(CacheRunInput {
                    name: self.cross_input_names[2 * i].as_str(),
                    data: cross.layers_k[i].as_slice(),
                    row_inner: None,
                });
                specs.push(CacheRunInput {
                    name: self.cross_input_names[2 * i + 1].as_str(),
                    data: cross.layers_v[i].as_slice(),
                    row_inner: None,
                });
            }
        }

        let (logits, new_k, new_v) = metal_compile_guard(self.decode_device, || {
            run_bucketed_kv_decode_keyed(
                &mut self.decode_compile_cache,
                key,
                past_seq,
                cache,
                d_model,
                n_layers,
                &specs,
                |upper| {
                    let built = graph_ctx
                        .build_decode_step(batch, upper as usize)
                        .expect("whisper decode step built");
                    graph_from_built(built).expect("whisper decode step graph")
                },
                &self.compile_opts.decode,
            )
        })?;

        // Append this step's K/V to the host cache.
        apply_bucketed_decode_step(cache, new_k, new_v, batch, d_model)
            .map_err(|e| anyhow::anyhow!(e))?;
        Ok(logits)
    }
878
879    /// Exchange bucketed decode compile caches (precision: share CPU-compiled graphs).
880    pub fn swap_decode_cache(&mut self, other: &mut Self) {
881        std::mem::swap(
882            &mut self.decode_compile_cache,
883            &mut other.decode_compile_cache,
884        );
885        std::mem::swap(&mut self.decode_batch_tag, &mut other.decode_batch_tag);
886        self.gpu_kv_binding = GpuKvBinding::default();
887        other.gpu_kv_binding = GpuKvBinding::default();
888    }
889
890    /// Single greedy decode step (for cross-backend parity checks).
891    pub fn decode_one_step(
892        &mut self,
893        cross: &WhisperCrossCache,
894        token: u32,
895        cache: &mut WhisperKvCache,
896    ) -> Result<Vec<f32>> {
897        self.decode_step_bucketed(cross, token, cache, 1)
898    }
899
900    fn decode_step(
901        &mut self,
902        cross: &WhisperCrossCache,
903        token: u32,
904        cache: &mut WhisperKvCache,
905        batch: usize,
906    ) -> Result<Vec<f32>> {
907        self.decode_step_bucketed(cross, token, cache, batch)
908    }
909
910    pub fn encode_mel_batch(&mut self, mels: &[MelSpectrogram]) -> Result<Vec<f32>> {
911        if mels.is_empty() {
912            return Ok(Vec::new());
913        }
914        let batch = mels.len();
915        let mel_input: Vec<f32> = if batch == 1 {
916            mels[0].data.clone()
917        } else {
918            stack_mels(mels)
919        };
920        self.ensure_encoder(batch)?;
921        metal_compile_guard(self.decode_device, || {
922            self.enc_compile_cache
923                .get_or_compile(batch as u64, || panic!("encoder cache missing"))
924                .run(&[("mel", &mel_input)])
925        })
926        .into_iter()
927        .next()
928        .ok_or_else(|| anyhow::anyhow!("encoder produced no output"))
929    }
930
931    /// Per-stage greedy-decode timings (after optional warmup). Includes last prefill logits for parity checks.
932    #[cfg(feature = "tokenizer")]
933    pub fn bench_greedy_pipeline(
934        &mut self,
935        pcm: &[f32],
936        decode_steps: usize,
937        warmup: usize,
938    ) -> Result<(WhisperBenchReport, String)> {
939        use std::time::Instant;
940        let mel = pcm_to_mel(&self.graph_ctx.cfg, pcm);
941        for _ in 0..warmup {
942            let enc = self.encode_mel(&mel)?;
943            self.bench_greedy_from_encoder(&enc, decode_steps.min(2))?;
944        }
945        let t_enc = Instant::now();
946        let enc = self.encode_mel(&mel)?;
947        let encode_ms = t_enc.elapsed().as_secs_f64() * 1000.0;
948        let (mut report, transcript) = self.bench_greedy_from_encoder(&enc, decode_steps)?;
949        report.encode_ms = encode_ms;
950        report.greedy_ms =
951            report.encode_ms + report.cross_ms + report.prefill_ms + report.decode_ms;
952        Ok((report, transcript))
953    }
954
955    /// Greedy decode benchmark from a fixed encoder output (cross-backend precision: share CPU `enc`).
956    #[cfg(feature = "tokenizer")]
957    pub fn bench_greedy_from_encoder(
958        &mut self,
959        enc: &[f32],
960        decode_steps: usize,
961    ) -> Result<(WhisperBenchReport, String)> {
962        use std::time::Instant;
963        let t_cross = Instant::now();
964        let cross = self.cross_cache_batch(enc, 1)?;
965        let cross_ms = t_cross.elapsed().as_secs_f64() * 1000.0;
966        let (mut report, transcript) = self.bench_greedy_from_cross(&cross, decode_steps)?;
967        report.cross_ms = cross_ms;
968        report.greedy_ms =
969            report.encode_ms + report.cross_ms + report.prefill_ms + report.decode_ms;
970        Ok((report, transcript))
971    }
972
973    /// Greedy decode from a fixed cross-attention cache (share CPU cross for precision).
974    #[cfg(feature = "tokenizer")]
975    pub fn bench_greedy_from_cross(
976        &mut self,
977        cross: &WhisperCrossCache,
978        decode_steps: usize,
979    ) -> Result<(WhisperBenchReport, String)> {
980        use std::time::Instant;
981
982        let prompt = self.build_prompt()?;
983        let t_pre = Instant::now();
984        let (prefill_logits, cache) = self.prefill_prompt(cross, &prompt, 1)?;
985        let prefill_ms = t_pre.elapsed().as_secs_f64() * 1000.0;
986        let (mut report, transcript) = self.bench_greedy_decode_from_state(
987            cross,
988            &prompt,
989            prefill_logits,
990            cache,
991            decode_steps,
992        )?;
993        report.prefill_ms = prefill_ms;
994        report.greedy_ms =
995            report.encode_ms + report.cross_ms + report.prefill_ms + report.decode_ms;
996        Ok((report, transcript))
997    }
998
999    /// Greedy decode from CPU prefill logits + KV (cross-backend decode parity).
1000    #[cfg(feature = "tokenizer")]
1001    pub fn bench_greedy_decode_from_state(
1002        &mut self,
1003        cross: &WhisperCrossCache,
1004        prompt: &[u32],
1005        prefill_logits: Vec<f32>,
1006        mut cache: WhisperKvCache,
1007        decode_steps: usize,
1008    ) -> Result<(WhisperBenchReport, String)> {
1009        use std::time::Instant;
1010
1011        let steps = decode_steps.min(self.max_decode_steps);
1012        let vocab = self.graph_ctx.cfg.vocab_size;
1013        let eot = self.eot_id()?;
1014        let last_prefill_logits = prefill_logits.clone();
1015
1016        let t_dec = Instant::now();
1017        let mut tokens = prompt.to_vec();
1018        let mut next_logits = last_logits_row(&prefill_logits, prompt.len(), vocab);
1019        let mut done_steps = 0usize;
1020        for (n_gen, _) in (0..steps).enumerate() {
1021            let mut row = next_logits;
1022            let next = self.suppression.argmax_next(&mut row, n_gen == 0);
1023            tokens.push(next);
1024            done_steps += 1;
1025            if next == eot {
1026                break;
1027            }
1028            let step_logits = self.decode_step(cross, next, &mut cache, 1)?;
1029            next_logits = if step_logits.len() == vocab {
1030                step_logits
1031            } else {
1032                // Bucketed decode graphs emit a single new-token row; not `past_len` rows.
1033                last_logits_row(&step_logits, 1, vocab)
1034            };
1035        }
1036        let decode_ms = t_dec.elapsed().as_secs_f64() * 1000.0;
1037        let transcript = self.decode_tokens(&tokens)?;
1038
1039        let report = WhisperBenchReport {
1040            encode_ms: 0.0,
1041            cross_ms: 0.0,
1042            prefill_ms: 0.0,
1043            decode_ms,
1044            decode_steps: done_steps,
1045            greedy_ms: 0.0,
1046            last_prefill_logits,
1047        };
1048        Ok((report, transcript))
1049    }
1050
1051    pub fn cross_cache_batch(&mut self, enc: &[f32], batch: usize) -> Result<WhisperCrossCache> {
1052        self.ensure_cross(batch)?;
1053        let outs = metal_compile_guard(self.decode_device, || {
1054            self.cross_compile_cache
1055                .get_or_compile(batch as u64, || panic!("cross cache missing"))
1056                .run(&[("encoder_hidden", enc)])
1057        });
1058        let cross = cross_from_outputs(
1059            self.graph_ctx.cfg.decoder_layers,
1060            batch,
1061            self.graph_ctx.enc_seq,
1062            self.graph_ctx.cfg.d_model,
1063            &outs,
1064        )
1065        .map_err(|e| anyhow::anyhow!(e))?;
1066        self.cross_gpu_epoch = self.cross_gpu_epoch.saturating_add(1);
1067        Ok(cross)
1068    }
1069
1070    #[cfg(feature = "tokenizer")]
1071    pub fn transcribe_greedy(&mut self, pcm: &[f32]) -> Result<String> {
1072        self.transcribe_cached(pcm, 1)
1073    }
1074
1075    #[cfg(feature = "tokenizer")]
1076    pub fn transcribe_beam(&mut self, pcm: &[f32]) -> Result<String> {
1077        let beam = if self.beam_size == 0 {
1078            5
1079        } else {
1080            self.beam_size
1081        };
1082        self.transcribe_cached(pcm, beam)
1083    }
1084
1085    #[cfg(feature = "tokenizer")]
1086    pub fn transcribe_with_vad(&mut self, pcm: &[f32]) -> Result<String> {
1087        let vad = self.vad_config.clone().unwrap_or_default();
1088        let regions = segments_by_vad(&vad, pcm);
1089        if regions.len() <= 1 {
1090            return self.transcribe_cached(pcm, 1);
1091        }
1092        let beam = if self.beam_size == 0 {
1093            1
1094        } else {
1095            self.beam_size
1096        };
1097        let texts = self.transcribe_regions_batched(pcm, &regions, beam)?;
1098        Ok(texts.join(" "))
1099    }
1100
1101    #[cfg(feature = "tokenizer")]
1102    pub fn transcribe_regions_batched(
1103        &mut self,
1104        pcm: &[f32],
1105        regions: &[crate::audio::SpeechSegment],
1106        beam_size: usize,
1107    ) -> Result<Vec<String>> {
1108        if regions.is_empty() {
1109            return Ok(Vec::new());
1110        }
1111        let mut out = Vec::with_capacity(regions.len());
1112        let prompt = self.build_prompt()?;
1113        for chunk in regions.chunks(self.max_region_batch) {
1114            let n = chunk.len();
1115            let mels: Vec<MelSpectrogram> = chunk
1116                .iter()
1117                .map(|seg| pcm_to_mel(&self.graph_ctx.cfg, &pcm[seg.start..seg.end]))
1118                .collect();
1119            let enc_n = self.encode_mel_batch(&mels)?;
1120            let texts = if beam_size <= 1 {
1121                self.greedy_decode_batch(&enc_n, n, &prompt)?
1122            } else {
1123                self.beam_decode_batch(&enc_n, n, beam_size, &prompt)?
1124            };
1125            out.extend(texts);
1126        }
1127        Ok(out)
1128    }
1129
1130    #[cfg(feature = "tokenizer")]
1131    fn greedy_decode_batch(
1132        &mut self,
1133        enc: &[f32],
1134        n_regions: usize,
1135        prompt: &[u32],
1136    ) -> Result<Vec<String>> {
1137        let cross = self.cross_cache_batch(enc, n_regions)?;
1138        let (prefill_logits, mut cache) = self.prefill_prompt(&cross, prompt, n_regions)?;
1139        let mut tokens: Vec<Vec<u32>> = (0..n_regions).map(|_| prompt.to_vec()).collect();
1140        let mut done = vec![false; n_regions];
1141        let vocab = self.graph_ctx.cfg.vocab_size;
1142        let eot = self.eot_id()?;
1143        let mut last_logits = prefill_logits;
1144
1145        for _ in 0..self.max_decode_steps {
1146            if done.iter().all(|&d| d) {
1147                break;
1148            }
1149            let mut step_tokens = vec![eot; n_regions];
1150            for b in 0..n_regions {
1151                if done[b] {
1152                    continue;
1153                }
1154                let mut row =
1155                    batched_logits_row_owned(&last_logits, b, n_regions, tokens[b].len(), vocab);
1156                let at_begin = tokens[b].len() == prompt.len();
1157                step_tokens[b] = self.suppression.argmax_next(&mut row, at_begin);
1158            }
1159            let new_logits =
1160                self.decode_step_batch(&cross, &step_tokens, &mut cache, n_regions, false)?;
1161            last_logits = new_logits;
1162            for b in 0..n_regions {
1163                if done[b] {
1164                    continue;
1165                }
1166                tokens[b].push(step_tokens[b]);
1167                if step_tokens[b] == eot {
1168                    done[b] = true;
1169                }
1170            }
1171        }
1172        tokens.into_iter().map(|t| self.decode_tokens(&t)).collect()
1173    }
1174
1175    #[cfg(feature = "tokenizer")]
1176    fn beam_decode_batch(
1177        &mut self,
1178        enc: &[f32],
1179        n_regions: usize,
1180        beam_size: usize,
1181        prompt: &[u32],
1182    ) -> Result<Vec<String>> {
1183        let plane = self.graph_ctx.enc_seq * self.graph_ctx.cfg.d_model;
1184        let enc_rep = replicate_encoder_for_beams(enc, n_regions, beam_size, plane);
1185        let batch = n_regions * beam_size;
1186        let cross = self.cross_cache_batch(&enc_rep, batch)?;
1187        let (prefill_logits, cache) = self.prefill_prompt(&cross, prompt, batch)?;
1188        let eot = self.eot_id()?;
1189        let cross_ref = &cross;
1190        let suffixes = beam_search_decode_kv_batched(
1191            &prefill_logits,
1192            prompt.len(),
1193            cache,
1194            n_regions,
1195            beam_size,
1196            self.max_decode_steps,
1197            self.graph_ctx.cfg.vocab_size,
1198            eot,
1199            |tokens, cache| self.decode_step_batch(cross_ref, tokens, cache, batch, true),
1200        )?;
1201        suffixes
1202            .into_iter()
1203            .map(|suffix| {
1204                let mut t = prompt.to_vec();
1205                t.extend(suffix);
1206                self.decode_tokens(&t)
1207            })
1208            .collect()
1209    }
1210
1211    #[cfg(feature = "tokenizer")]
1212    fn greedy_extend_after_prefill(
1213        &mut self,
1214        cross: &WhisperCrossCache,
1215        prompt: &[u32],
1216        mut cache: WhisperKvCache,
1217        prefill_logits: &[f32],
1218        max_steps: usize,
1219    ) -> Result<Vec<u32>> {
1220        let vocab = self.graph_ctx.cfg.vocab_size;
1221        let eot = self.eot_id()?;
1222        let prompt_len = prompt.len();
1223        let mut tokens = prompt.to_vec();
1224        let mut next_logits = last_logits_row(prefill_logits, prompt_len, vocab);
1225        for (n_gen, _) in (0..max_steps).enumerate() {
1226            let mut row = next_logits;
1227            let next = self.suppression.argmax_next(&mut row, n_gen == 0);
1228            tokens.push(next);
1229            if next == eot {
1230                break;
1231            }
1232            let step_logits = self.decode_step(cross, next, &mut cache, 1)?;
1233            next_logits = if step_logits.len() == vocab {
1234                step_logits
1235            } else {
1236                last_logits_row(&step_logits, 1, vocab)
1237            };
1238        }
1239        Ok(tokens)
1240    }
1241
1242    fn transcribe_cross(&mut self, cross: WhisperCrossCache, beam_size: usize) -> Result<String> {
1243        let prompt = self.build_prompt()?;
1244        let cross_ref = &cross;
1245        if beam_size <= 1 {
1246            let (prefill_logits, cache) = self.prefill_prompt(cross_ref, &prompt, 1)?;
1247            let tokens = self.greedy_extend_after_prefill(
1248                cross_ref,
1249                &prompt,
1250                cache,
1251                &prefill_logits,
1252                self.max_decode_steps,
1253            )?;
1254            return self.decode_tokens(&tokens);
1255        }
1256        let (prefill_logits, base_cache) = self.prefill_prompt(cross_ref, &prompt, 1)?;
1257        let extra = beam_search_decode_kv(
1258            &prefill_logits,
1259            prompt.len(),
1260            base_cache,
1261            self.eot_id()?,
1262            beam_size,
1263            self.max_decode_steps,
1264            self.graph_ctx.cfg.vocab_size,
1265            |token, cache| {
1266                let mut branch = cache.clone();
1267                let logits = self.decode_step(cross_ref, token, &mut branch, 1)?;
1268                let mut row = last_logits_row(&logits, 1, self.graph_ctx.cfg.vocab_size);
1269                self.suppression.apply(&mut row);
1270                Ok((row, branch))
1271            },
1272        )?;
1273        let mut tokens = prompt;
1274        tokens.extend(extra);
1275        self.decode_tokens(&tokens)
1276    }
1277
1278    #[cfg(feature = "tokenizer")]
1279    pub fn build_prompt(&self) -> Result<Vec<u32>> {
1280        let tok = self
1281            .tokenizer
1282            .as_ref()
1283            .ok_or_else(|| anyhow::anyhow!("tokenizer not loaded"))?;
1284        initial_prompt_opts(
1285            tok,
1286            self.language.as_deref(),
1287            self.translate,
1288            self.timestamps,
1289        )
1290    }
1291
1292    #[cfg(feature = "tokenizer")]
1293    fn eot_id(&self) -> Result<u32> {
1294        self.tokenizer
1295            .as_ref()
1296            .and_then(|t| t.token_to_id(EOT_TOKEN))
1297            .ok_or_else(|| anyhow::anyhow!("tokenizer missing {EOT_TOKEN}"))
1298    }
1299
1300    #[cfg(feature = "tokenizer")]
1301    fn decode_tokens(&self, tokens: &[u32]) -> Result<String> {
1302        let tok = self
1303            .tokenizer
1304            .as_ref()
1305            .ok_or_else(|| anyhow::anyhow!("tokenizer not loaded"))?;
1306        tok.decode(tokens, true)
1307            .map_err(|e| anyhow::anyhow!("decode tokens: {e}"))
1308    }
1309
1310    fn transcribe_cached(&mut self, pcm: &[f32], beam_size: usize) -> Result<String> {
1311        if self.vad_config.is_some() {
1312            return self.transcribe_with_vad(pcm);
1313        }
1314        let enc = self.encode_pcm(pcm)?;
1315        let cross = self.cross_cache(&enc)?;
1316        self.transcribe_cross(cross, beam_size)
1317    }
1318}