1use crate::audio::{MelSpectrogram, N_FRAMES, pcm_to_mel};
17use crate::backend::{
18 WhisperCompileOpts, WhisperGraphCtx, decode_bucket_ladder, decode_cache_key,
19 metal_compile_guard, whisper_decoder_device, whisper_use_gpu_kv,
20};
21use crate::batch::{batched_prompt_f32, replicate_encoder_for_beams};
22use crate::builder::WhisperGraphOpts;
23use crate::cache::{
24 WhisperCrossCache, WhisperKvCache, apply_bucketed_decode_step, cross_from_outputs,
25 kv_from_prefill_outputs,
26};
27use crate::config::WhisperConfig;
28use crate::decode::{
29 EOT_TOKEN, SuppressionMask, batched_logits_row_owned, beam_search_decode_kv,
30 beam_search_decode_kv_batched, initial_prompt_opts, last_logits_row,
31};
32use crate::fused::{FusedDecoderWeights, FusedEncoderWeights};
33use crate::mel::stack_mels;
34use crate::vad::{VadConfig, segments_by_vad};
35use crate::weights::WhisperWeightPrefix;
36use anyhow::{Context, Result, bail, ensure};
37use rlx_core::flow_util::{
38 bucket_cache_ensure_built, compile_cache_ensure_built_with_options, graph_from_built,
39};
40use rlx_core::validate_standard_device;
41use rlx_core::weight_map::WeightMap;
42use rlx_core::{
43 GpuKvBinding, cross_attn_gpu_handles_ready, install_cross_attn_gpu_handles,
44 run_bucketed_kv_decode_gpu, run_bucketed_kv_decode_keyed, sync_gpu_kv_to_host,
45};
46use rlx_ir::DType;
47use rlx_runtime::attn_mask::bucket_decode_mask;
48use rlx_runtime::compile_cache::{BucketedCompileCache, CacheRunInput, CompileCache};
49use rlx_runtime::{CompiledGraph, Device};
50use std::path::{Path, PathBuf};
51use std::sync::Arc;
52
/// Builder for [`WhisperRunner`].
///
/// Collects paths, the target device and decoding options; `build` then loads
/// the checkpoint and compiles the batch-1 encoder and cross-attention graphs.
#[derive(Debug, Clone)]
pub struct WhisperRunnerBuilder {
    /// Path to the weights checkpoint. `build` fails when unset or missing on disk.
    weights: Option<PathBuf>,
    /// Config file path; `build` falls back to `config.json` next to the weights.
    config_path: Option<PathBuf>,
    /// Tokenizer file path; `build` falls back to `tokenizer.json` next to the weights.
    tokenizer_path: Option<PathBuf>,
    /// Already-parsed config. When set, `build` does not read the config file.
    config: Option<WhisperConfig>,
    /// Requested device; `build` uses `Device::Cpu` when unset.
    device: Option<Device>,
    /// Mel frame count; `0` makes `build` use `N_FRAMES`.
    mel_frames: usize,
    /// Decode step limit; `0` makes `build` use `max_target_positions - 8` (saturating).
    max_decode_steps: usize,
    /// Beam width; `0` lets each transcription entry point pick its own default.
    beam_size: usize,
    /// Language setting, copied onto the runner unchanged.
    language: Option<String>,
    /// Translate flag, copied onto the runner unchanged.
    translate: bool,
    /// Timestamps flag, copied onto the runner unchanged.
    timestamps: bool,
    /// Activation dtype; `DType::F16` selects the f16 mixed graph options in `build`.
    activation_dtype: DType,
    /// Also selects the f16 mixed graph options in `build` when `true`.
    use_f16_compute: bool,
    /// Voice-activity-detection settings, copied onto the runner unchanged.
    vad_config: Option<VadConfig>,
    /// Upper bound on regions per batch; the setter clamps it to at least 1.
    max_region_batch: usize,
    /// Attention chunk size; when it differs from the crate default, `build`
    /// applies it to both the encoder and the cross-attention chunk options.
    encoder_attn_chunk: usize,
}
72
73impl Default for WhisperRunnerBuilder {
74 fn default() -> Self {
75 Self {
76 weights: None,
77 config_path: None,
78 tokenizer_path: None,
79 config: None,
80 device: None,
81 mel_frames: 0,
82 max_decode_steps: 0,
83 beam_size: 0,
84 language: None,
85 translate: false,
86 timestamps: false,
87 activation_dtype: DType::F32,
88 use_f16_compute: false,
89 vad_config: None,
90 max_region_batch: 10,
91 encoder_attn_chunk: crate::builder::DEFAULT_ENCODER_ATTN_CHUNK,
92 }
93 }
94}
95
96impl WhisperRunnerBuilder {
97 pub fn weights<P: Into<PathBuf>>(mut self, path: P) -> Self {
98 self.weights = Some(path.into());
99 self
100 }
101 pub fn config_path<P: Into<PathBuf>>(mut self, path: P) -> Self {
102 self.config_path = Some(path.into());
103 self
104 }
105 pub fn tokenizer_path<P: Into<PathBuf>>(mut self, path: P) -> Self {
106 self.tokenizer_path = Some(path.into());
107 self
108 }
109 pub fn config(mut self, cfg: WhisperConfig) -> Self {
110 self.config = Some(cfg);
111 self
112 }
113 pub fn device(mut self, d: Device) -> Self {
114 self.device = Some(d);
115 self
116 }
117 pub fn language(mut self, lang: impl Into<String>) -> Self {
118 self.language = Some(lang.into());
119 self
120 }
121 pub fn translate(mut self, on: bool) -> Self {
122 self.translate = on;
123 self
124 }
125 pub fn timestamps(mut self, on: bool) -> Self {
126 self.timestamps = on;
127 self
128 }
129 pub fn activation_dtype(mut self, dt: DType) -> Self {
130 self.activation_dtype = dt;
131 self
132 }
133 pub fn use_f16_compute(mut self, on: bool) -> Self {
134 self.use_f16_compute = on;
135 self
136 }
137 pub fn vad_config(mut self, cfg: VadConfig) -> Self {
138 self.vad_config = Some(cfg);
139 self
140 }
141 pub fn max_region_batch(mut self, n: usize) -> Self {
142 self.max_region_batch = n.max(1);
143 self
144 }
145 pub fn encoder_attn_chunk(mut self, n: usize) -> Self {
146 self.encoder_attn_chunk = n;
147 self
148 }
149 pub fn max_decode_steps(mut self, n: usize) -> Self {
150 self.max_decode_steps = n;
151 self
152 }
153 pub fn beam_size(mut self, n: usize) -> Self {
154 self.beam_size = n;
155 self
156 }
157
158 pub fn build(self) -> Result<WhisperRunner> {
159 let weights_path = self
160 .weights
161 .ok_or_else(|| anyhow::anyhow!("weights path required"))?;
162 if !weights_path.exists() {
163 bail!("weights file not found: {weights_path:?}");
164 }
165 let weights_dir = weights_path
166 .parent()
167 .ok_or_else(|| anyhow::anyhow!("weights path has no parent"))?;
168 let cfg_path = self
169 .config_path
170 .clone()
171 .unwrap_or_else(|| weights_dir.join("config.json"));
172 let cfg = match self.config {
173 Some(c) => c,
174 None => WhisperConfig::from_file(&cfg_path)
175 .with_context(|| format!("reading config {cfg_path:?}"))?,
176 };
177 let tok_path = self
178 .tokenizer_path
179 .clone()
180 .unwrap_or_else(|| weights_dir.join("tokenizer.json"));
181 let device = self.device.unwrap_or(Device::Cpu);
182 validate_standard_device("whisper", device)?;
183 let mel_frames = if self.mel_frames == 0 {
184 N_FRAMES
185 } else {
186 self.mel_frames
187 };
188 let max_decode_steps = if self.max_decode_steps == 0 {
189 cfg.max_target_positions.saturating_sub(8)
190 } else {
191 self.max_decode_steps
192 };
193 let wt = weights_path
194 .to_str()
195 .ok_or_else(|| anyhow::anyhow!("non-utf8 weights path"))?;
196 let mut weights_cache = WeightMap::snapshot_from_path(wt)?;
197 let pfx = {
198 let wm = WeightMap::from_tensors(weights_cache.clone());
199 WhisperWeightPrefix::detect(&wm)
200 };
201 let fused = FusedDecoderWeights::from_checkpoint(&weights_cache, &cfg, &pfx)?;
202 let fused_enc = FusedEncoderWeights::from_checkpoint(&weights_cache, &cfg, &pfx)?;
203 fused.merge_into_tensors(&mut weights_cache);
204 fused_enc.merge_into_tensors(&mut weights_cache);
205 let mut graph_opts = if self.use_f16_compute || self.activation_dtype == DType::F16 {
206 WhisperGraphOpts::f16_mixed()
207 } else {
208 WhisperGraphOpts::default()
209 };
210 if self.encoder_attn_chunk != crate::builder::DEFAULT_ENCODER_ATTN_CHUNK {
211 graph_opts.encoder_attn_chunk = self.encoder_attn_chunk;
212 graph_opts.cross_attn_chunk = self.encoder_attn_chunk;
213 }
214 let suppression = SuppressionMask::from_config(&cfg);
215
216 let f16 = self.use_f16_compute || self.activation_dtype == DType::F16;
217 let mut compile_opts = WhisperCompileOpts::new(device, f16, &weights_path);
218 let decode_device = whisper_decoder_device(device);
220 let prefill_device = decode_device;
221 if decode_device != device {
222 let cpu_opts = WhisperCompileOpts::new(decode_device, f16, &weights_path);
223 compile_opts.encoder = cpu_opts.encoder.clone();
224 compile_opts.cross = cpu_opts.cross.clone();
225 compile_opts.decode = cpu_opts.decode.clone();
226 compile_opts.prefill = cpu_opts.prefill;
227 }
228 let use_gpu_kv = whisper_use_gpu_kv(device, decode_device);
229
230 let enc_seq = cfg.encoder_seq_len(mel_frames);
231 let weights_cache = Arc::new(weights_cache);
232 let graph_ctx = WhisperGraphCtx {
233 cfg: cfg.clone(),
234 pfx: pfx.clone(),
235 weights: Arc::clone(&weights_cache),
236 enc_seq,
237 mel_frames,
238 graph_opts,
239 fused: Some(fused.clone()),
240 fused_enc: Some(fused_enc.clone()),
241 };
242
243 let mut enc_compile_cache = CompileCache::new(decode_device, 8);
244 let mut cross_compile_cache = CompileCache::new(decode_device, 8);
245 metal_compile_guard(decode_device, || -> Result<()> {
246 compile_cache_ensure_built_with_options(
247 &mut enc_compile_cache,
248 1,
249 graph_ctx.build_encoder(1)?,
250 &compile_opts.encoder,
251 )?;
252 compile_cache_ensure_built_with_options(
253 &mut cross_compile_cache,
254 1,
255 graph_ctx.build_cross(1)?,
256 &compile_opts.cross,
257 )?;
258 Ok(())
259 })?;
260
261 let max_past = cfg.max_target_positions.max(1);
262 let decode_compile_cache = decode_bucket_ladder(decode_device, max_past as u64);
263
264 #[cfg(feature = "tokenizer")]
265 let tokenizer = {
266 ensure!(tok_path.exists(), "tokenizer not found: {tok_path:?}");
267 Some(
268 tokenizers::Tokenizer::from_file(&tok_path)
269 .map_err(|e| anyhow::anyhow!("load tokenizer {tok_path:?}: {e}"))?,
270 )
271 };
272
273 let cross_input_names: Vec<String> = (0..cfg.decoder_layers)
274 .flat_map(|i| [format!("cross_k_{i}"), format!("cross_v_{i}")])
275 .collect();
276
277 Ok(WhisperRunner {
278 graph_ctx,
279 device,
280 decode_device,
281 prefill_device,
282 activation_dtype: self.activation_dtype,
283 suppression,
284 max_decode_steps,
285 beam_size: self.beam_size,
286 max_region_batch: self.max_region_batch,
287 vad_config: self.vad_config,
288 compile_opts,
289 use_gpu_kv,
290 gpu_kv_binding: GpuKvBinding::default(),
291 cross_gpu_epoch: 0,
292 cross_gpu_bound_epoch: u64::MAX,
293 decode_batch_tag: u64::MAX,
294 enc_compile_cache,
295 cross_compile_cache,
296 prefill_compile_cache: CompileCache::new(prefill_device, 8),
297 decode_compile_cache,
298 decode_token_f32: Vec::new(),
299 decode_pos_ix: Vec::new(),
300 decode_mask: Vec::new(),
301 cross_input_names,
302 language: self.language,
303 translate: self.translate,
304 timestamps: self.timestamps,
305 #[cfg(feature = "tokenizer")]
306 tokenizer,
307 })
308 }
309}
310
/// Timing breakdown of one greedy transcription benchmark run.
///
/// All `*_ms` fields are wall-clock milliseconds. `Default` yields an all-zero
/// report with empty logits.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct WhisperBenchReport {
    /// Time spent in the encoder.
    pub encode_ms: f64,
    /// Time spent computing the cross-attention cache.
    pub cross_ms: f64,
    /// Time spent in the prompt prefill.
    pub prefill_ms: f64,
    /// Time spent in the token-by-token decode loop.
    pub decode_ms: f64,
    /// Number of tokens actually generated in the decode loop.
    pub decode_steps: usize,
    /// Sum of `encode_ms`, `cross_ms`, `prefill_ms` and `decode_ms`.
    pub greedy_ms: f64,
    /// Logits returned by the prefill pass.
    pub last_prefill_logits: Vec<f32>,
}
323
/// Whisper inference runner: owns the graph-building context, the per-stage
/// compile caches and the scratch state reused across decode steps.
pub struct WhisperRunner {
    /// Config, weights and graph options used to build every stage graph.
    graph_ctx: WhisperGraphCtx,
    /// Device requested through the builder.
    pub device: Device,
    /// Device the compile caches are created for (`whisper_decoder_device(device)`).
    decode_device: Device,
    /// Device used for the prefill cache; set equal to `decode_device` by the builder.
    prefill_device: Device,
    /// Activation dtype requested through the builder.
    pub activation_dtype: DType,
    /// Token suppression rules derived from the config.
    suppression: SuppressionMask,
    /// Upper bound on generated tokens per transcription.
    max_decode_steps: usize,
    /// Configured beam width (`0` = use the entry point's default).
    beam_size: usize,
    /// Maximum number of regions encoded/decoded in one batch.
    max_region_batch: usize,
    /// Optional voice-activity-detection settings.
    vad_config: Option<VadConfig>,
    /// Compile options for the encoder, cross, prefill and decode graphs.
    compile_opts: WhisperCompileOpts,
    /// Whether decode steps take the GPU-resident KV path.
    use_gpu_kv: bool,
    /// GPU KV binding state; reset whenever the decode cache is rebuilt or swapped.
    gpu_kv_binding: GpuKvBinding,
    /// Incremented (saturating) each time a cross-attention cache is computed.
    cross_gpu_epoch: u64,
    /// Epoch for which cross-attention GPU handles were last installed
    /// (`u64::MAX` until the first install).
    cross_gpu_bound_epoch: u64,
    /// Batch size the decode cache was built for (`u64::MAX` until first use).
    decode_batch_tag: u64,
    /// Compiled encoder graphs, keyed by batch size.
    enc_compile_cache: CompileCache,
    /// Compiled cross-attention graphs, keyed by batch size.
    cross_compile_cache: CompileCache,
    /// Compiled prefill graphs, keyed by `decode_cache_key(batch, prompt_len)`.
    prefill_compile_cache: CompileCache,
    /// Bucketed decode-step graphs, keyed by past sequence length.
    decode_compile_cache: BucketedCompileCache,
    /// Scratch: current step's token ids as f32.
    decode_token_f32: Vec<f32>,
    /// Scratch: current step's position, repeated once per token.
    decode_pos_ix: Vec<f32>,
    /// Scratch: attention mask for the current decode bucket.
    decode_mask: Vec<f32>,
    /// Graph input names `cross_k_{i}`, `cross_v_{i}`, interleaved per decoder layer.
    cross_input_names: Vec<String>,
    /// Language setting from the builder.
    language: Option<String>,
    /// Translate flag from the builder.
    translate: bool,
    /// Timestamps flag from the builder.
    timestamps: bool,
    /// Tokenizer loaded by the builder.
    #[cfg(feature = "tokenizer")]
    tokenizer: Option<tokenizers::Tokenizer>,
}
358
359impl WhisperRunner {
360 pub fn builder() -> WhisperRunnerBuilder {
361 WhisperRunnerBuilder::default()
362 }
363
364 pub fn config(&self) -> &WhisperConfig {
365 &self.graph_ctx.cfg
366 }
367
368 pub fn decode_buckets_compiled(&self) -> usize {
370 self.decode_compile_cache.compiled_count()
371 }
372
373 fn prepare_decode_step_inputs(&mut self, tokens: &[u32], past_seq: usize, upper: usize) {
374 self.decode_token_f32.clear();
375 self.decode_token_f32
376 .extend(tokens.iter().map(|&t| t as f32));
377 self.decode_pos_ix.clear();
378 self.decode_pos_ix.resize(tokens.len(), past_seq as f32);
379 let mask = bucket_decode_mask(past_seq, upper);
380 if self.decode_mask.len() != mask.len() {
381 self.decode_mask = mask;
382 } else {
383 self.decode_mask.copy_from_slice(&mask);
384 }
385 }
386
387 pub fn mel_frames(&self) -> usize {
388 self.graph_ctx.mel_frames
389 }
390
391 pub fn enc_seq(&self) -> usize {
392 self.graph_ctx.enc_seq
393 }
394
395 pub fn decode_device(&self) -> Device {
397 self.decode_device
398 }
399
400 pub fn stage_device(&self) -> Device {
402 self.decode_device
403 }
404
405 pub fn uses_gpu_kv(&self) -> bool {
406 self.use_gpu_kv
407 }
408
409 fn ensure_encoder(&mut self, batch: usize) -> Result<()> {
410 let key = batch as u64;
411 if self.enc_compile_cache.contains(key) {
412 return Ok(());
413 }
414 let built = self.graph_ctx.build_encoder(batch)?;
415 let opts = self.compile_opts.encoder.clone();
416 metal_compile_guard(self.decode_device, || -> Result<()> {
417 compile_cache_ensure_built_with_options(
418 &mut self.enc_compile_cache,
419 key,
420 built,
421 &opts,
422 )?;
423 Ok(())
424 })
425 }
426
427 fn bind_cross_gpu_if_needed(
428 compiled: &mut CompiledGraph,
429 cross: &WhisperCrossCache,
430 enc_seq: usize,
431 d_model: usize,
432 n_layers: usize,
433 epoch: u64,
434 bound_epoch: u64,
435 use_gpu: bool,
436 ) -> Result<bool> {
437 if !use_gpu {
438 return Ok(false);
439 }
440 if epoch == bound_epoch && cross_attn_gpu_handles_ready(compiled) {
441 return Ok(true);
442 }
443 install_cross_attn_gpu_handles(compiled, cross, enc_seq, d_model, n_layers)?;
444 Ok(true)
445 }
446
447 fn ensure_cross(&mut self, batch: usize) -> Result<()> {
448 let key = batch as u64;
449 if self.cross_compile_cache.contains(key) {
450 return Ok(());
451 }
452 let built = self.graph_ctx.build_cross(batch)?;
453 let opts = self.compile_opts.cross.clone();
454 metal_compile_guard(self.decode_device, || -> Result<()> {
455 compile_cache_ensure_built_with_options(
456 &mut self.cross_compile_cache,
457 key,
458 built,
459 &opts,
460 )?;
461 Ok(())
462 })
463 }
464
465 pub fn encode_mel(&mut self, mel: &MelSpectrogram) -> Result<Vec<f32>> {
466 ensure!(
467 mel.n_frames == self.graph_ctx.mel_frames,
468 "mel frame count mismatch"
469 );
470 self.ensure_encoder(1)?;
471 let key = 1u64;
472 metal_compile_guard(self.decode_device, || {
473 self.enc_compile_cache
474 .get_or_compile(key, || panic!("encoder cache missing"))
475 .run(&[("mel", &mel.data)])
476 })
477 .into_iter()
478 .next()
479 .ok_or_else(|| anyhow::anyhow!("encoder produced no output"))
480 }
481
482 pub fn encode_pcm(&mut self, samples: &[f32]) -> Result<Vec<f32>> {
483 let mel = pcm_to_mel(&self.graph_ctx.cfg, samples);
484 self.encode_mel(&mel)
485 }
486
487 pub fn encode_wav(&mut self, path: &Path) -> Result<Vec<f32>> {
488 let samples = crate::audio::load_wav_mono_f32(path)?;
489 self.encode_pcm(&samples)
490 }
491
492 fn cross_cache(&mut self, enc: &[f32]) -> Result<WhisperCrossCache> {
493 self.ensure_cross(1)?;
494 let outs = metal_compile_guard(self.decode_device, || {
495 self.cross_compile_cache
496 .get_or_compile(1, || panic!("cross cache missing"))
497 .run(&[("encoder_hidden", enc)])
498 });
499 let cross = cross_from_outputs(
500 self.graph_ctx.cfg.decoder_layers,
501 1,
502 self.graph_ctx.enc_seq,
503 self.graph_ctx.cfg.d_model,
504 &outs,
505 )
506 .map_err(|e| anyhow::anyhow!(e))?;
507 self.cross_gpu_epoch = self.cross_gpu_epoch.saturating_add(1);
508 Ok(cross)
509 }
510
511 pub fn prefill_prompt(
512 &mut self,
513 cross: &WhisperCrossCache,
514 prompt_tokens: &[u32],
515 batch: usize,
516 ) -> Result<(Vec<f32>, WhisperKvCache)> {
517 let dec_seq = prompt_tokens.len();
518 let key = decode_cache_key(batch, dec_seq);
519
520 metal_compile_guard(self.prefill_device, || {
521 compile_cache_ensure_built_with_options(
522 &mut self.prefill_compile_cache,
523 key,
524 self.graph_ctx.build_prefill(batch, dec_seq)?,
525 &self.compile_opts.prefill,
526 )
527 })?;
528 let token_f32 = if batch == 1 {
529 prompt_tokens.iter().map(|&t| t as f32).collect()
530 } else {
531 batched_prompt_f32(prompt_tokens, batch)
532 };
533 let enc_seq = self.graph_ctx.enc_seq;
534 let d_model = self.graph_ctx.cfg.d_model;
535 let n_layers = self.graph_ctx.cfg.decoder_layers;
536 let epoch = self.cross_gpu_epoch;
537 let bound_epoch = self.cross_gpu_bound_epoch;
538 let use_gpu = self.use_gpu_kv;
539 let mut cross_on_gpu = use_gpu && bound_epoch == epoch;
540 let cross_bound = {
541 let prefill = self
542 .prefill_compile_cache
543 .get_or_compile(key, || panic!("prefill cache missing"));
544 Self::bind_cross_gpu_if_needed(
545 prefill,
546 cross,
547 enc_seq,
548 d_model,
549 n_layers,
550 epoch,
551 bound_epoch,
552 use_gpu,
553 )?
554 };
555 if cross_bound {
556 self.cross_gpu_bound_epoch = epoch;
557 cross_on_gpu = true;
558 }
559 let prefill = self
560 .prefill_compile_cache
561 .get_or_compile(key, || panic!("prefill cache missing"));
562 let mut inputs: Vec<(&str, &[f32])> = vec![("token_ids", &token_f32)];
563 if !cross_on_gpu {
564 for i in 0..self.graph_ctx.cfg.decoder_layers {
565 inputs.push((
566 self.cross_input_names[2 * i].as_str(),
567 cross.layers_k[i].as_slice(),
568 ));
569 inputs.push((
570 self.cross_input_names[2 * i + 1].as_str(),
571 cross.layers_v[i].as_slice(),
572 ));
573 }
574 }
575 let outputs = metal_compile_guard(self.prefill_device, || prefill.run(&inputs));
576 ensure!(!outputs.is_empty(), "prefill returned no outputs");
577 let logits = outputs[0].clone();
578 let kv = kv_from_prefill_outputs(
579 self.graph_ctx.cfg.decoder_layers,
580 batch,
581 dec_seq,
582 self.graph_ctx.cfg.d_model,
583 &outputs[1..],
584 )
585 .map_err(|e| anyhow::anyhow!(e))?;
586 Ok((logits, kv))
587 }
588
589 fn decode_step_bucketed(
590 &mut self,
591 cross: &WhisperCrossCache,
592 token: u32,
593 cache: &mut WhisperKvCache,
594 batch: usize,
595 ) -> Result<Vec<f32>> {
596 self.decode_step_batch(cross, std::slice::from_ref(&token), cache, batch, false)
597 }
598
599 fn decode_step_batch(
600 &mut self,
601 cross: &WhisperCrossCache,
602 tokens: &[u32],
603 cache: &mut WhisperKvCache,
604 batch: usize,
605 sync_kv_to_host: bool,
606 ) -> Result<Vec<f32>> {
607 ensure!(
608 tokens.len() == batch,
609 "decode_step_batch: expected {batch} tokens, got {}",
610 tokens.len()
611 );
612 self.ensure_decode_batch(batch)?;
613 let past_seq = cache.past_len;
614 let bucket_key = past_seq as u64;
615 if self.use_gpu_kv {
616 return self.decode_step_batch_gpu(
617 cross,
618 tokens,
619 cache,
620 batch,
621 bucket_key,
622 past_seq,
623 sync_kv_to_host,
624 );
625 }
626 self.decode_step_batch_host(cross, tokens, cache, batch, bucket_key, past_seq)
627 }
628
/// One decode step on the GPU-KV path; returns the step logits.
///
/// Ensures the decode graph for the bucket containing `key` is built, fills
/// the scratch inputs, binds (or passes on the host) the cross-attention K/V,
/// runs the step through `run_bucketed_kv_decode_gpu`, and finally syncs the
/// KV cache back to the host when required.
///
/// # Errors
///
/// Fails with "past_seq … outside decode buckets" when `key` has no bucket,
/// and propagates errors from handle installation, the decode run and the
/// KV sync.
fn decode_step_batch_gpu(
    &mut self,
    cross: &WhisperCrossCache,
    tokens: &[u32],
    cache: &mut WhisperKvCache,
    batch: usize,
    key: u64,
    past_seq: usize,
    sync_kv_to_host: bool,
) -> Result<Vec<f32>> {
    // Owned copies so the closures below do not borrow `self` as a whole.
    let graph_ctx = self.graph_ctx.clone();
    let decode_opts = self.compile_opts.decode.clone();
    let d_model = self.graph_ctx.cfg.d_model;
    let n_layers = self.graph_ctx.cfg.decoder_layers;

    // Build the bucket's graph up front; `None` means `key` is past the last bucket.
    metal_compile_guard(self.decode_device, || {
        bucket_cache_ensure_built(
            &mut self.decode_compile_cache,
            key,
            |upper| graph_ctx.build_decode_step(batch, upper as usize),
            &decode_opts,
        )
    })
    .ok_or_else(|| anyhow::anyhow!("past_seq {past_seq} outside decode buckets"))?;

    let upper = self
        .decode_upper_for_key(key)
        .ok_or_else(|| anyhow::anyhow!("past_seq {past_seq} outside decode buckets"))?;
    self.prepare_decode_step_inputs(tokens, past_seq, upper);
    let token_f32 = &self.decode_token_f32;
    let pos_ix = &self.decode_pos_ix;
    let mask = &self.decode_mask;
    let mut specs: Vec<CacheRunInput<'_>> = vec![
        CacheRunInput {
            name: "token_id",
            data: token_f32,
            row_inner: None,
        },
        CacheRunInput {
            name: "pos_ix",
            data: pos_ix,
            row_inner: None,
        },
        CacheRunInput {
            name: "mask",
            data: mask,
            row_inner: None,
        },
    ];
    let epoch = self.cross_gpu_epoch;
    let bound_epoch = self.cross_gpu_bound_epoch;
    let use_gpu = self.use_gpu_kv;
    let enc_seq = self.graph_ctx.enc_seq;
    let mut cross_on_gpu = use_gpu && bound_epoch == epoch;
    // Bind the cross K/V handles on this bucket's graph if it is compiled.
    if let Some(compiled) = self.decode_compile_cache.compiled_for_key_mut(key) {
        if Self::bind_cross_gpu_if_needed(
            compiled,
            cross,
            enc_seq,
            d_model,
            n_layers,
            epoch,
            bound_epoch,
            use_gpu,
        )? {
            self.cross_gpu_bound_epoch = epoch;
            cross_on_gpu = true;
        }
    }
    // Fallback: pass the cross K/V as ordinary host inputs.
    if !cross_on_gpu {
        for i in 0..n_layers {
            specs.push(CacheRunInput {
                name: self.cross_input_names[2 * i].as_str(),
                data: cross.layers_k[i].as_slice(),
                row_inner: None,
            });
            specs.push(CacheRunInput {
                name: self.cross_input_names[2 * i + 1].as_str(),
                data: cross.layers_v[i].as_slice(),
                row_inner: None,
            });
        }
    }

    // Decide whether the KV handles must be refreshed before running:
    // always on `Device::Gpu`, otherwise only after a bucket change or when
    // the graph has no `past_k_0` handle yet.
    let upper_u = upper as u64;
    let prev_upper = self.gpu_kv_binding.upper;
    let bucket_changed = prev_upper != 0 && prev_upper != upper_u;
    let handles_live = self
        .decode_compile_cache
        .compiled_for_key_mut(key)
        .map(|c| c.has_gpu_handle("past_k_0"))
        .unwrap_or(false);
    let refresh_kv = if self.decode_device == Device::Gpu {
        true
    } else {
        bucket_changed || !handles_live
    };

    let logits = metal_compile_guard(self.decode_device, || {
        run_bucketed_kv_decode_gpu(
            &mut self.decode_compile_cache,
            key,
            past_seq,
            cache,
            &mut self.gpu_kv_binding,
            d_model,
            n_layers,
            &specs,
            // Graph factory handed to the runner; it panics if the build fails.
            |upper| {
                let built = graph_ctx
                    .build_decode_step(batch, upper as usize)
                    .expect("whisper decode step built");
                graph_from_built(built).expect("whisper decode step graph")
            },
            &decode_opts,
            refresh_kv,
        )
    })?;

    // Sync KV to the host when asked, when the next position lands in a
    // different bucket, or unconditionally on `Device::Gpu`.
    let force_host_kv = self.decode_device == Device::Gpu;
    let next_upper = self
        .decode_upper_for_key((past_seq + 1) as u64)
        .unwrap_or(upper);
    let leaves_bucket = next_upper != upper;

    if sync_kv_to_host || leaves_bucket || force_host_kv {
        if let Some(compiled) = self.decode_compile_cache.compiled_for_key_mut(key) {
            sync_gpu_kv_to_host(compiled, cache, d_model, n_layers)?;
        }
    }
    Ok(logits)
}
762
763 fn ensure_decode_batch(&mut self, batch: usize) -> Result<()> {
764 let batch_tag = batch as u64;
765 if self.decode_batch_tag == batch_tag {
766 return Ok(());
767 }
768 self.gpu_kv_binding = GpuKvBinding::default();
769 self.decode_batch_tag = batch_tag;
770 let max_past = self.graph_ctx.cfg.max_target_positions.max(1) as u64;
771 self.decode_compile_cache = decode_bucket_ladder(self.decode_device, max_past);
772 Ok(())
773 }
774
775 fn decode_upper_for_key(&self, key: u64) -> Option<usize> {
776 self.decode_compile_cache.bucket_for(key).and_then(|idx| {
777 self.decode_compile_cache
778 .buckets()
779 .nth(idx)
780 .map(|r| (r.end - 1) as usize)
781 })
782 }
783
/// One decode step on the host-KV path; returns the step logits.
///
/// Fills the scratch inputs for the bucket containing `key`, binds (or passes
/// on the host) the cross-attention K/V, runs the step through
/// `run_bucketed_kv_decode_keyed`, then applies the returned K/V to `cache`.
///
/// # Errors
///
/// Fails with "past_seq … outside decode buckets" when `key` has no bucket,
/// and propagates errors from handle installation, the decode run and the
/// cache update.
fn decode_step_batch_host(
    &mut self,
    cross: &WhisperCrossCache,
    tokens: &[u32],
    cache: &mut WhisperKvCache,
    batch: usize,
    key: u64,
    past_seq: usize,
) -> Result<Vec<f32>> {
    // Owned copy so the graph-factory closure does not borrow `self`.
    let graph_ctx = self.graph_ctx.clone();
    let d_model = self.graph_ctx.cfg.d_model;
    let n_layers = self.graph_ctx.cfg.decoder_layers;
    let upper = self
        .decode_upper_for_key(key)
        .ok_or_else(|| anyhow::anyhow!("past_seq {past_seq} outside decode buckets"))?;
    self.prepare_decode_step_inputs(tokens, past_seq, upper);
    let token_f32 = &self.decode_token_f32;
    let pos_ix = &self.decode_pos_ix;
    let mask = &self.decode_mask;
    let mut specs: Vec<CacheRunInput<'_>> = vec![
        CacheRunInput {
            name: "token_id",
            data: token_f32,
            row_inner: None,
        },
        CacheRunInput {
            name: "pos_ix",
            data: pos_ix,
            row_inner: None,
        },
        CacheRunInput {
            name: "mask",
            data: mask,
            row_inner: None,
        },
    ];
    let epoch = self.cross_gpu_epoch;
    let bound_epoch = self.cross_gpu_bound_epoch;
    let use_gpu = self.use_gpu_kv;
    let enc_seq = self.graph_ctx.enc_seq;
    let mut cross_on_gpu = use_gpu && bound_epoch == epoch;
    // Only possible when this bucket's graph has already been compiled.
    if let Some(compiled) = self.decode_compile_cache.compiled_for_key_mut(key) {
        if Self::bind_cross_gpu_if_needed(
            compiled,
            cross,
            enc_seq,
            d_model,
            n_layers,
            epoch,
            bound_epoch,
            use_gpu,
        )? {
            self.cross_gpu_bound_epoch = epoch;
            cross_on_gpu = true;
        }
    }
    // Fallback: pass the cross K/V as ordinary host inputs.
    if !cross_on_gpu {
        for i in 0..n_layers {
            specs.push(CacheRunInput {
                name: self.cross_input_names[2 * i].as_str(),
                data: cross.layers_k[i].as_slice(),
                row_inner: None,
            });
            specs.push(CacheRunInput {
                name: self.cross_input_names[2 * i + 1].as_str(),
                data: cross.layers_v[i].as_slice(),
                row_inner: None,
            });
        }
    }

    let (logits, new_k, new_v) = metal_compile_guard(self.decode_device, || {
        run_bucketed_kv_decode_keyed(
            &mut self.decode_compile_cache,
            key,
            past_seq,
            cache,
            d_model,
            n_layers,
            &specs,
            // Graph factory handed to the runner; it panics if the build fails.
            |upper| {
                let built = graph_ctx
                    .build_decode_step(batch, upper as usize)
                    .expect("whisper decode step built");
                graph_from_built(built).expect("whisper decode step graph")
            },
            &self.compile_opts.decode,
        )
    })?;

    // Fold this step's new K/V into the host cache.
    apply_bucketed_decode_step(cache, new_k, new_v, batch, d_model)
        .map_err(|e| anyhow::anyhow!(e))?;
    Ok(logits)
}
878
879 pub fn swap_decode_cache(&mut self, other: &mut Self) {
881 std::mem::swap(
882 &mut self.decode_compile_cache,
883 &mut other.decode_compile_cache,
884 );
885 std::mem::swap(&mut self.decode_batch_tag, &mut other.decode_batch_tag);
886 self.gpu_kv_binding = GpuKvBinding::default();
887 other.gpu_kv_binding = GpuKvBinding::default();
888 }
889
890 pub fn decode_one_step(
892 &mut self,
893 cross: &WhisperCrossCache,
894 token: u32,
895 cache: &mut WhisperKvCache,
896 ) -> Result<Vec<f32>> {
897 self.decode_step_bucketed(cross, token, cache, 1)
898 }
899
900 fn decode_step(
901 &mut self,
902 cross: &WhisperCrossCache,
903 token: u32,
904 cache: &mut WhisperKvCache,
905 batch: usize,
906 ) -> Result<Vec<f32>> {
907 self.decode_step_bucketed(cross, token, cache, batch)
908 }
909
910 pub fn encode_mel_batch(&mut self, mels: &[MelSpectrogram]) -> Result<Vec<f32>> {
911 if mels.is_empty() {
912 return Ok(Vec::new());
913 }
914 let batch = mels.len();
915 let mel_input: Vec<f32> = if batch == 1 {
916 mels[0].data.clone()
917 } else {
918 stack_mels(mels)
919 };
920 self.ensure_encoder(batch)?;
921 metal_compile_guard(self.decode_device, || {
922 self.enc_compile_cache
923 .get_or_compile(batch as u64, || panic!("encoder cache missing"))
924 .run(&[("mel", &mel_input)])
925 })
926 .into_iter()
927 .next()
928 .ok_or_else(|| anyhow::anyhow!("encoder produced no output"))
929 }
930
/// Benchmarks the whole greedy pipeline (encode, cross, prefill, decode) on
/// `pcm` and returns the timing report with the transcript.
///
/// `warmup` untimed rounds run first, each with at most two decode steps.
#[cfg(feature = "tokenizer")]
pub fn bench_greedy_pipeline(
    &mut self,
    pcm: &[f32],
    decode_steps: usize,
    warmup: usize,
) -> Result<(WhisperBenchReport, String)> {
    use std::time::Instant;

    let mel = pcm_to_mel(&self.graph_ctx.cfg, pcm);

    // Warm-up rounds: results are discarded.
    let warmup_steps = decode_steps.min(2);
    for _round in 0..warmup {
        let hidden = self.encode_mel(&mel)?;
        self.bench_greedy_from_encoder(&hidden, warmup_steps)?;
    }

    let encode_started = Instant::now();
    let hidden = self.encode_mel(&mel)?;
    let encode_ms = encode_started.elapsed().as_secs_f64() * 1000.0;

    let (mut report, transcript) = self.bench_greedy_from_encoder(&hidden, decode_steps)?;
    report.encode_ms = encode_ms;
    // Recompute the total now that the encode time is known.
    report.greedy_ms =
        report.encode_ms + report.cross_ms + report.prefill_ms + report.decode_ms;
    Ok((report, transcript))
}
954
/// Benchmarks cross-cache computation, prefill and greedy decode starting
/// from an encoder output. `encode_ms` is left as reported by the inner stages.
#[cfg(feature = "tokenizer")]
pub fn bench_greedy_from_encoder(
    &mut self,
    enc: &[f32],
    decode_steps: usize,
) -> Result<(WhisperBenchReport, String)> {
    use std::time::Instant;

    let cross_started = Instant::now();
    let cross = self.cross_cache_batch(enc, 1)?;
    let cross_ms = cross_started.elapsed().as_secs_f64() * 1000.0;

    let (mut report, transcript) = self.bench_greedy_from_cross(&cross, decode_steps)?;
    report.cross_ms = cross_ms;
    // Recompute the total now that the cross time is known.
    report.greedy_ms =
        report.encode_ms + report.cross_ms + report.prefill_ms + report.decode_ms;
    Ok((report, transcript))
}
972
/// Benchmarks prefill and greedy decode starting from a cross-attention cache.
#[cfg(feature = "tokenizer")]
pub fn bench_greedy_from_cross(
    &mut self,
    cross: &WhisperCrossCache,
    decode_steps: usize,
) -> Result<(WhisperBenchReport, String)> {
    use std::time::Instant;

    let prompt = self.build_prompt()?;

    let prefill_started = Instant::now();
    let (prefill_logits, kv_cache) = self.prefill_prompt(cross, &prompt, 1)?;
    let prefill_ms = prefill_started.elapsed().as_secs_f64() * 1000.0;

    let (mut report, transcript) = self.bench_greedy_decode_from_state(
        cross,
        &prompt,
        prefill_logits,
        kv_cache,
        decode_steps,
    )?;
    report.prefill_ms = prefill_ms;
    // Recompute the total now that the prefill time is known.
    report.greedy_ms =
        report.encode_ms + report.cross_ms + report.prefill_ms + report.decode_ms;
    Ok((report, transcript))
}
998
/// Benchmarks the greedy decode loop starting from a finished prefill.
///
/// Generates at most `min(decode_steps, max_decode_steps)` tokens and stops
/// early once the end-of-text token is produced. The report carries only
/// `decode_ms`, `decode_steps` and the prefill logits; the other timings are
/// zero for the caller to fill in.
///
/// # Errors
///
/// Propagates failures from the EOT lookup, the decode steps and detokenization.
#[cfg(feature = "tokenizer")]
pub fn bench_greedy_decode_from_state(
    &mut self,
    cross: &WhisperCrossCache,
    prompt: &[u32],
    prefill_logits: Vec<f32>,
    mut cache: WhisperKvCache,
    decode_steps: usize,
) -> Result<(WhisperBenchReport, String)> {
    use std::time::Instant;

    let steps = decode_steps.min(self.max_decode_steps);
    let vocab = self.graph_ctx.cfg.vocab_size;
    let eot = self.eot_id()?;

    let t_dec = Instant::now();
    let mut tokens = prompt.to_vec();
    // Only the last prompt position's logits seed the loop.
    let mut next_logits = last_logits_row(&prefill_logits, prompt.len(), vocab);
    let mut done_steps = 0usize;
    for n_gen in 0..steps {
        let mut row = next_logits;
        let next = self.suppression.argmax_next(&mut row, n_gen == 0);
        tokens.push(next);
        done_steps += 1;
        if next == eot {
            break;
        }
        let step_logits = self.decode_step(cross, next, &mut cache, 1)?;
        // A step may return exactly one row, or a buffer whose last row is wanted.
        next_logits = if step_logits.len() == vocab {
            step_logits
        } else {
            last_logits_row(&step_logits, 1, vocab)
        };
    }
    let decode_ms = t_dec.elapsed().as_secs_f64() * 1000.0;
    let transcript = self.decode_tokens(&tokens)?;

    let report = WhisperBenchReport {
        encode_ms: 0.0,
        cross_ms: 0.0,
        prefill_ms: 0.0,
        decode_ms,
        decode_steps: done_steps,
        greedy_ms: 0.0,
        // The logits were only borrowed above, so move them instead of cloning.
        last_prefill_logits: prefill_logits,
    };
    Ok((report, transcript))
}
1050
1051 pub fn cross_cache_batch(&mut self, enc: &[f32], batch: usize) -> Result<WhisperCrossCache> {
1052 self.ensure_cross(batch)?;
1053 let outs = metal_compile_guard(self.decode_device, || {
1054 self.cross_compile_cache
1055 .get_or_compile(batch as u64, || panic!("cross cache missing"))
1056 .run(&[("encoder_hidden", enc)])
1057 });
1058 let cross = cross_from_outputs(
1059 self.graph_ctx.cfg.decoder_layers,
1060 batch,
1061 self.graph_ctx.enc_seq,
1062 self.graph_ctx.cfg.d_model,
1063 &outs,
1064 )
1065 .map_err(|e| anyhow::anyhow!(e))?;
1066 self.cross_gpu_epoch = self.cross_gpu_epoch.saturating_add(1);
1067 Ok(cross)
1068 }
1069
/// Transcribes `pcm` with a beam width of 1.
#[cfg(feature = "tokenizer")]
pub fn transcribe_greedy(&mut self, pcm: &[f32]) -> Result<String> {
    let greedy_beam = 1;
    self.transcribe_cached(pcm, greedy_beam)
}
1074
/// Transcribes `pcm` with beam search, using the configured beam width or 5
/// when none was configured.
#[cfg(feature = "tokenizer")]
pub fn transcribe_beam(&mut self, pcm: &[f32]) -> Result<String> {
    let beam = match self.beam_size {
        0 => 5,
        configured => configured,
    };
    self.transcribe_cached(pcm, beam)
}
1084
/// Transcribes `pcm` after splitting it into speech regions with VAD.
///
/// With at most one detected region the whole input is transcribed with a
/// beam width of 1. Otherwise the regions are transcribed in batches with the
/// configured beam width (1 when unset) and joined with single spaces.
///
/// # Errors
///
/// Propagates failures from the underlying transcription calls.
#[cfg(feature = "tokenizer")]
pub fn transcribe_with_vad(&mut self, pcm: &[f32]) -> Result<String> {
    let vad = self.vad_config.clone().unwrap_or_default();
    let regions = segments_by_vad(&vad, pcm);
    if regions.len() <= 1 {
        return self.transcribe_cached(pcm, 1);
    }
    // A beam size of 0 means "unset"; region decoding then runs greedily.
    let beam = self.beam_size.max(1);
    let texts = self.transcribe_regions_batched(pcm, &regions, beam)?;
    Ok(texts.join(" "))
}
1100
/// Transcribes each speech region of `pcm` and returns one string per region,
/// in input order.
///
/// Regions are processed in chunks of at most `max_region_batch`. Each chunk
/// is encoded in one batch, then decoded greedily when `beam_size <= 1` and
/// with beam search otherwise.
///
/// # Errors
///
/// Fails when a region's `start..end` range does not lie inside `pcm`, and
/// propagates failures from prompt building, encoding and decoding.
#[cfg(feature = "tokenizer")]
pub fn transcribe_regions_batched(
    &mut self,
    pcm: &[f32],
    regions: &[crate::audio::SpeechSegment],
    beam_size: usize,
) -> Result<Vec<String>> {
    if regions.is_empty() {
        return Ok(Vec::new());
    }
    let mut out = Vec::with_capacity(regions.len());
    let prompt = self.build_prompt()?;
    for chunk in regions.chunks(self.max_region_batch) {
        let n = chunk.len();
        // Regions are caller-supplied: reject bad ranges instead of panicking.
        let mels = chunk
            .iter()
            .map(|seg| -> Result<MelSpectrogram> {
                let samples = pcm.get(seg.start..seg.end).ok_or_else(|| {
                    anyhow::anyhow!(
                        "speech segment {}..{} is out of range for {} samples",
                        seg.start,
                        seg.end,
                        pcm.len()
                    )
                })?;
                Ok(pcm_to_mel(&self.graph_ctx.cfg, samples))
            })
            .collect::<Result<Vec<MelSpectrogram>>>()?;
        let enc_n = self.encode_mel_batch(&mels)?;
        let texts = if beam_size <= 1 {
            self.greedy_decode_batch(&enc_n, n, &prompt)?
        } else {
            self.beam_decode_batch(&enc_n, n, beam_size, &prompt)?
        };
        out.extend(texts);
    }
    Ok(out)
}
1129
/// Greedily decodes `n_regions` encoded regions in lock-step and returns the
/// decoded text of each region.
///
/// Every region starts from `prompt`. On each step the next token of every
/// unfinished region is picked from the latest logits, one batched decode
/// step is run, and the picked tokens are appended. A region is finished once
/// it emits the end-of-text token; finished regions are fed that token as a
/// placeholder. Decoding stops when all regions are finished or after
/// `max_decode_steps` steps.
///
/// # Errors
/// Propagates errors from cross-cache construction, prefill, the end-of-text
/// lookup, the decode steps and the final token-to-text conversion.
#[cfg(feature = "tokenizer")]
fn greedy_decode_batch(
    &mut self,
    enc: &[f32],
    n_regions: usize,
    prompt: &[u32],
) -> Result<Vec<String>> {
    let cross = self.cross_cache_batch(enc, n_regions)?;
    let (prefill_logits, mut cache) = self.prefill_prompt(&cross, prompt, n_regions)?;
    let mut tokens: Vec<Vec<u32>> = vec![prompt.to_vec(); n_regions];
    let mut done = vec![false; n_regions];
    let vocab = self.graph_ctx.cfg.vocab_size;
    let eot = self.eot_id()?;
    let mut last_logits = prefill_logits;

    for _ in 0..self.max_decode_steps {
        if !done.contains(&false) {
            break;
        }
        // Finished regions keep the end-of-text placeholder for this step.
        let mut picks = vec![eot; n_regions];
        for (b, pick) in picks.iter_mut().enumerate() {
            if done[b] {
                continue;
            }
            let seq_len = tokens[b].len();
            let mut row = batched_logits_row_owned(&last_logits, b, n_regions, seq_len, vocab);
            // The first generated token is flagged so the suppression mask
            // can treat it specially.
            let at_begin = seq_len == prompt.len();
            *pick = self.suppression.argmax_next(&mut row, at_begin);
        }
        last_logits = self.decode_step_batch(&cross, &picks, &mut cache, n_regions, false)?;
        for ((seq, finished), &tok) in tokens.iter_mut().zip(done.iter_mut()).zip(&picks) {
            if *finished {
                continue;
            }
            seq.push(tok);
            *finished = tok == eot;
        }
    }
    tokens.iter().map(|seq| self.decode_tokens(seq)).collect()
}
1174
/// Decodes `n_regions` encoded regions with batched beam search and returns
/// the decoded text of each region.
///
/// The encoder output is replicated `beam_size` times per region, so the
/// decoder runs with an effective batch of `n_regions * beam_size`. The
/// search returns one token suffix per region, which is appended to `prompt`
/// before conversion to text.
///
/// # Errors
/// Propagates errors from cross-cache construction, prefill, the end-of-text
/// lookup, the beam search (including its decode steps) and the final
/// token-to-text conversion.
#[cfg(feature = "tokenizer")]
fn beam_decode_batch(
    &mut self,
    enc: &[f32],
    n_regions: usize,
    beam_size: usize,
    prompt: &[u32],
) -> Result<Vec<String>> {
    // Element count of one region's encoder output.
    let plane = self.graph_ctx.enc_seq * self.graph_ctx.cfg.d_model;
    let enc_rep = replicate_encoder_for_beams(enc, n_regions, beam_size, plane);
    let batch = n_regions * beam_size;
    let cross = self.cross_cache_batch(&enc_rep, batch)?;
    let (prefill_logits, cache) = self.prefill_prompt(&cross, prompt, batch)?;
    let eot = self.eot_id()?;
    // Shared borrow handed to the step closure, which itself needs `&mut self`.
    let cross_ref = &cross;
    let suffixes = beam_search_decode_kv_batched(
        &prefill_logits,
        prompt.len(),
        cache,
        n_regions,
        beam_size,
        self.max_decode_steps,
        self.graph_ctx.cfg.vocab_size,
        eot,
        |tokens, cache| self.decode_step_batch(cross_ref, tokens, cache, batch, true),
    )?;
    suffixes
        .into_iter()
        .map(|suffix| {
            let mut t = prompt.to_vec();
            t.extend(suffix);
            self.decode_tokens(&t)
        })
        .collect()
}
1210
/// Greedily extends `prompt` after a prefill pass and returns the prompt
/// followed by the generated tokens.
///
/// Starting from the last row of `prefill_logits`, up to `max_steps` tokens
/// are chosen one at a time. Generation stops early once the end-of-text
/// token is produced; that token is included in the returned sequence.
///
/// # Errors
/// Propagates errors from the end-of-text lookup and from `decode_step`.
#[cfg(feature = "tokenizer")]
fn greedy_extend_after_prefill(
    &mut self,
    cross: &WhisperCrossCache,
    prompt: &[u32],
    mut cache: WhisperKvCache,
    prefill_logits: &[f32],
    max_steps: usize,
) -> Result<Vec<u32>> {
    let vocab = self.graph_ctx.cfg.vocab_size;
    let eot = self.eot_id()?;
    let mut tokens = prompt.to_vec();
    let mut logits = last_logits_row(prefill_logits, prompt.len(), vocab);
    for step in 0..max_steps {
        // Only the first generated token is flagged as "at begin".
        let chosen = self.suppression.argmax_next(&mut logits, step == 0);
        tokens.push(chosen);
        if chosen == eot {
            break;
        }
        let raw = self.decode_step(cross, chosen, &mut cache, 1)?;
        // A step may return exactly one vocabulary row; otherwise take the
        // last row of what came back.
        logits = if raw.len() == vocab {
            raw
        } else {
            last_logits_row(&raw, 1, vocab)
        };
    }
    Ok(tokens)
}
1241
1242 fn transcribe_cross(&mut self, cross: WhisperCrossCache, beam_size: usize) -> Result<String> {
1243 let prompt = self.build_prompt()?;
1244 let cross_ref = ✗
1245 if beam_size <= 1 {
1246 let (prefill_logits, cache) = self.prefill_prompt(cross_ref, &prompt, 1)?;
1247 let tokens = self.greedy_extend_after_prefill(
1248 cross_ref,
1249 &prompt,
1250 cache,
1251 &prefill_logits,
1252 self.max_decode_steps,
1253 )?;
1254 return self.decode_tokens(&tokens);
1255 }
1256 let (prefill_logits, base_cache) = self.prefill_prompt(cross_ref, &prompt, 1)?;
1257 let extra = beam_search_decode_kv(
1258 &prefill_logits,
1259 prompt.len(),
1260 base_cache,
1261 self.eot_id()?,
1262 beam_size,
1263 self.max_decode_steps,
1264 self.graph_ctx.cfg.vocab_size,
1265 |token, cache| {
1266 let mut branch = cache.clone();
1267 let logits = self.decode_step(cross_ref, token, &mut branch, 1)?;
1268 let mut row = last_logits_row(&logits, 1, self.graph_ctx.cfg.vocab_size);
1269 self.suppression.apply(&mut row);
1270 Ok((row, branch))
1271 },
1272 )?;
1273 let mut tokens = prompt;
1274 tokens.extend(extra);
1275 self.decode_tokens(&tokens)
1276 }
1277
/// Builds the initial decoder prompt from the runner's language, translate
/// and timestamp settings.
///
/// # Errors
/// Fails with "tokenizer not loaded" when the runner has no tokenizer, and
/// otherwise returns whatever `initial_prompt_opts` reports.
#[cfg(feature = "tokenizer")]
pub fn build_prompt(&self) -> Result<Vec<u32>> {
    match self.tokenizer.as_ref() {
        Some(tok) => initial_prompt_opts(
            tok,
            self.language.as_deref(),
            self.translate,
            self.timestamps,
        ),
        None => Err(anyhow::anyhow!("tokenizer not loaded")),
    }
}
1291
/// Looks up the id of the end-of-text token in the loaded tokenizer.
///
/// # Errors
/// Fails when no tokenizer is loaded or when the tokenizer has no entry for
/// `EOT_TOKEN`; both cases produce the same error message.
#[cfg(feature = "tokenizer")]
fn eot_id(&self) -> Result<u32> {
    let looked_up = match self.tokenizer.as_ref() {
        Some(tok) => tok.token_to_id(EOT_TOKEN),
        None => None,
    };
    match looked_up {
        Some(id) => Ok(id),
        None => Err(anyhow::anyhow!("tokenizer missing {EOT_TOKEN}")),
    }
}
1299
/// Converts token ids back to text with the loaded tokenizer.
///
/// The tokenizer's `decode` is called with its second argument set to
/// `true`.
///
/// # Errors
/// Fails with "tokenizer not loaded" when the runner has no tokenizer, or
/// with a "decode tokens: …" error when the tokenizer rejects the sequence.
#[cfg(feature = "tokenizer")]
fn decode_tokens(&self, tokens: &[u32]) -> Result<String> {
    let Some(tok) = self.tokenizer.as_ref() else {
        return Err(anyhow::anyhow!("tokenizer not loaded"));
    };
    match tok.decode(tokens, true) {
        Ok(text) => Ok(text),
        Err(e) => Err(anyhow::anyhow!("decode tokens: {e}")),
    }
}
1309
1310 fn transcribe_cached(&mut self, pcm: &[f32], beam_size: usize) -> Result<String> {
1311 if self.vad_config.is_some() {
1312 return self.transcribe_with_vad(pcm);
1313 }
1314 let enc = self.encode_pcm(pcm)?;
1315 let cross = self.cross_cache(&enc)?;
1316 self.transcribe_cross(cross, beam_size)
1317 }
1318}