Skip to main content

rlx_whisper/
backend.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Tier-1 compile profiles and bucket keys for Whisper graphs.
17
18use crate::builder::WhisperGraphOpts;
19use crate::config::WhisperConfig;
20use crate::flow::{
21    WhisperEncoderFlow, build_whisper_cross_kv_built, build_whisper_decode_step_built_ext_opts,
22    build_whisper_decoder_prefill_built_ext_opts,
23};
24use crate::fused::{FusedDecoderWeights, FusedEncoderWeights};
25use crate::weights::WhisperWeightPrefix;
26use anyhow::Result;
27use rlx_core::device_supports_gpu_kv;
28use rlx_core::flow_bridge::{compile_options_for_profile, profile_near_weights};
29use rlx_core::weight_map::WeightMap;
30use rlx_flow::BuiltModel;
31use rlx_flow::CompileProfile;
32use rlx_opt::PrecisionPolicy;
33use rlx_runtime::compile_cache::BucketedCompileCache;
34use rlx_runtime::{CompileOptions, Device, Precision};
35use std::collections::HashMap;
36use std::ops::Range;
37use std::path::Path;
38use std::sync::Arc;
39
/// Checkpoint tensors keyed by their weight name.
type WhisperWeightMap = HashMap<String, WeightTensor>;
/// One checkpoint tensor: flat `f32` data plus a `usize` vector
/// (presumably the tensor shape — confirm against the loader).
type WeightTensor = (Vec<f32>, Vec<usize>);
42
/// Non-bucketed compile cache key (prefill graphs indexed by `(batch, dec_seq)`).
///
/// Packs `batch` into the upper 32 bits and `seq` into the lower bits of a `u64`.
/// Keys stay distinct only while `seq` fits in 32 bits.
pub fn decode_cache_key(batch: usize, seq: usize) -> u64 {
    let high = batch as u64;
    let low = seq as u64;
    (high << 32) | low
}
47
48/// Power-of-two decode buckets; graphs take runtime `pos_ix` + bucket self-attn mask.
49pub fn decode_bucket_ladder(device: Device, max_past: u64) -> BucketedCompileCache {
50    let max_past = max_past.max(1);
51    #[allow(clippy::single_range_in_vec_init)]
52    let mut ranges: Vec<Range<u64>> = vec![0..2];
53    let mut start = 2u64;
54    let mut extent = 4u64;
55    loop {
56        ranges.push(start..(extent + 1));
57        if extent >= max_past {
58            break;
59        }
60        start = extent + 1;
61        extent = extent.saturating_mul(2).max(start + 1);
62    }
63    BucketedCompileCache::with_policy(device, ranges, None)
64}
65
66/// MPSGraph rejects some fused attention reshapes on Metal decode graphs.
67pub(crate) fn metal_safe_decode_profile(
68    device: Device,
69    mut profile: CompileProfile,
70) -> CompileProfile {
71    if device == Device::Metal {
72        profile.fusion.skip = true;
73        profile.backend.metal.skip_fusion = true;
74        profile.backend.metal.unfuse_regions = true;
75    }
76    profile
77}
78
79fn apply_f16_compute(opts: &mut CompileOptions, f16: bool) {
80    if f16 {
81        opts.precision = Precision::F16;
82        opts.policy = Some(PrecisionPolicy::AutoMixed);
83    }
84}
85
/// Resolved compile options for encoder, cross, prefill, and bucketed decode.
///
/// Produced by `WhisperCompileOpts::new`. Each field holds the options for one
/// Whisper graph stage.
#[derive(Debug, Clone)]
pub struct WhisperCompileOpts {
    /// Options for the encoder stage.
    pub encoder: CompileOptions,
    /// Options for the cross stage.
    pub cross: CompileOptions,
    /// Options for the decoder prefill stage.
    pub prefill: CompileOptions,
    /// Options for the bucketed single-step decode stage.
    pub decode: CompileOptions,
}
94
95/// GPU-resident KV when decode runs on the same device (not CPU-decode fallback).
96pub fn whisper_use_gpu_kv(lm_device: Device, decode_device: Device) -> bool {
97    device_supports_gpu_kv(decode_device) && decode_device == lm_device
98}
99
100/// Device for encoder, cross, prefill, and bucketed decode.
101///
102/// Metal / MLX / Vulkan LM devices run these stages on CPU until graph parity
103/// matches the CPU reference (decoder attention + encoder conv/attn).
104pub fn whisper_decoder_device(lm_device: Device) -> Device {
105    match lm_device {
106        Device::Metal | Device::Mlx | Device::Vulkan => Device::Cpu,
107        other => other,
108    }
109}
110
111/// Disable MPSGraph for LM-shaped compiles on Metal (attention reshape coverage).
112pub fn metal_compile_guard<R, F>(device: Device, f: F) -> R
113where
114    F: FnOnce() -> R,
115{
116    if device == Device::Metal {
117        rlx_ir::env::set("RLX_DISABLE_MPSGRAPH", "1");
118        let out = f();
119        rlx_ir::env::unset("RLX_DISABLE_MPSGRAPH");
120        out
121    } else {
122        f()
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_runtime::Device;

    #[test]
    fn decode_bucket_ladder_includes_past_zero() {
        let cache = decode_bucket_ladder(Device::Cpu, 448);
        assert!(cache.bucket_for(0).is_some());
        assert!(cache.bucket_for(1).is_some());
        assert!(cache.bucket_for(200).is_some());
    }

    #[test]
    fn decode_bucket_ladder_is_logarithmic() {
        let cache = decode_bucket_ladder(Device::Cpu, 448);
        assert!(
            cache.total_buckets() <= 12,
            "expected O(log n) buckets, got {}",
            cache.total_buckets()
        );
        assert_eq!(cache.bucket_for(0), cache.bucket_for(1));
        assert_eq!(cache.bucket_for(5), cache.bucket_for(8));
    }

    /// Every `past` from 0 through `max_past` inclusive must land in some
    /// bucket, including `max_past` itself.
    #[test]
    fn decode_bucket_ladder_has_no_gaps_up_to_max_past() {
        let cache = decode_bucket_ladder(Device::Cpu, 448);
        for past in 0..=448 {
            assert!(
                cache.bucket_for(past).is_some(),
                "no decode bucket for past = {past}"
            );
        }
    }

    /// Small `max_past` values, including 0 (clamped to 1), still cover themselves.
    #[test]
    fn decode_bucket_ladder_small_max_past() {
        assert!(decode_bucket_ladder(Device::Cpu, 0).bucket_for(0).is_some());
        assert!(decode_bucket_ladder(Device::Cpu, 1).bucket_for(1).is_some());
        assert!(decode_bucket_ladder(Device::Cpu, 5).bucket_for(5).is_some());
    }
}
151
152impl WhisperCompileOpts {
153    pub fn new(device: Device, f16_compute: bool, weights: &Path) -> Self {
154        let encoder_profile =
155            profile_near_weights(weights, "whisper.rlx.toml", CompileProfile::encoder());
156        let decode_profile = metal_safe_decode_profile(device, CompileProfile::gemma_decode());
157        let prefill_profile = metal_safe_decode_profile(device, CompileProfile::llama32_prefill());
158
159        let mut encoder = compile_options_for_profile(&encoder_profile, device);
160        let mut cross = compile_options_for_profile(&CompileProfile::encoder(), device);
161        let mut prefill = compile_options_for_profile(&prefill_profile, device);
162        let mut decode = compile_options_for_profile(&decode_profile, device);
163
164        apply_f16_compute(&mut encoder, f16_compute);
165        apply_f16_compute(&mut cross, f16_compute);
166        apply_f16_compute(&mut prefill, f16_compute);
167        apply_f16_compute(&mut decode, f16_compute);
168
169        Self {
170            encoder,
171            cross,
172            prefill,
173            decode,
174        }
175    }
176}
177
/// Shared checkpoint + graph options (cheap `Clone` via `Arc` weights).
///
/// Cloning shares the checkpoint tensors through the `Arc`. All other fields,
/// including `fused` / `fused_enc`, are cloned by value.
/// NOTE(review): whether cloning the fused weight sets is cheap depends on
/// their definitions, which are not visible here.
#[derive(Clone)]
pub struct WhisperGraphCtx {
    /// Model configuration passed to every graph builder.
    pub cfg: WhisperConfig,
    /// Weight-name prefix passed to the cross / prefill / decode builders.
    pub pfx: WhisperWeightPrefix,
    /// Checkpoint tensors keyed by name; copied into a fresh `WeightMap` per build.
    pub weights: Arc<WhisperWeightMap>,
    /// Encoder sequence length passed to the cross / prefill / decode builders.
    pub enc_seq: usize,
    /// Mel frame count passed to the encoder flow.
    pub mel_frames: usize,
    /// Graph-construction options forwarded to every builder.
    pub graph_opts: WhisperGraphOpts,
    /// Optional fused decoder weights handed to the prefill / decode builders.
    pub fused: Option<FusedDecoderWeights>,
    /// Optional fused encoder weights handed to the encoder flow.
    pub fused_enc: Option<FusedEncoderWeights>,
}
190
191impl WhisperGraphCtx {
192    pub fn weight_map(&self) -> WeightMap {
193        WeightMap::from_tensors((*self.weights).clone())
194    }
195
196    pub fn build_encoder(&self, batch: usize) -> Result<BuiltModel> {
197        let mut wm = self.weight_map();
198        WhisperEncoderFlow::new_opts(
199            &self.cfg,
200            &wm,
201            batch,
202            self.mel_frames,
203            self.graph_opts,
204            self.fused_enc.as_ref(),
205        )
206        .build(&mut wm)
207    }
208
209    pub fn build_cross(&self, batch: usize) -> Result<BuiltModel> {
210        let mut wm = self.weight_map();
211        build_whisper_cross_kv_built(&self.cfg, &mut wm, &self.pfx, batch, self.enc_seq)
212    }
213
214    pub fn build_prefill(&self, batch: usize, dec_seq: usize) -> Result<BuiltModel> {
215        let mut wm = self.weight_map();
216        build_whisper_decoder_prefill_built_ext_opts(
217            &self.cfg,
218            &mut wm,
219            &self.pfx,
220            batch,
221            dec_seq,
222            self.enc_seq,
223            true,
224            self.graph_opts,
225            self.fused.as_ref(),
226        )
227    }
228
229    pub fn build_decode_step(&self, batch: usize, bucket_upper: usize) -> Result<BuiltModel> {
230        let mut wm = self.weight_map();
231        build_whisper_decode_step_built_ext_opts(
232            &self.cfg,
233            &mut wm,
234            &self.pfx,
235            batch,
236            bucket_upper,
237            self.enc_seq,
238            true,
239            self.graph_opts,
240            self.fused.as_ref(),
241        )
242    }
243}