1use crate::builder::WhisperGraphOpts;
19use crate::config::WhisperConfig;
20use crate::flow::{
21 WhisperEncoderFlow, build_whisper_cross_kv_built, build_whisper_decode_step_built_ext_opts,
22 build_whisper_decoder_prefill_built_ext_opts,
23};
24use crate::fused::{FusedDecoderWeights, FusedEncoderWeights};
25use crate::weights::WhisperWeightPrefix;
26use anyhow::Result;
27use rlx_core::device_supports_gpu_kv;
28use rlx_core::flow_bridge::{compile_options_for_profile, profile_near_weights};
29use rlx_core::weight_map::WeightMap;
30use rlx_flow::BuiltModel;
31use rlx_flow::CompileProfile;
32use rlx_opt::PrecisionPolicy;
33use rlx_runtime::compile_cache::BucketedCompileCache;
34use rlx_runtime::{CompileOptions, Device, Precision};
35use std::collections::HashMap;
36use std::ops::Range;
37use std::path::Path;
38use std::sync::Arc;
39
/// One weight entry: an `f32` buffer paired with a `usize` vector.
// NOTE(review): presumably (flattened values, shape dims) — confirm against the loader.
type WeightTensor = (Vec<f32>, Vec<usize>);
/// Weight table keyed by `String`, holding one [`WeightTensor`] per key.
type WhisperWeightMap = HashMap<String, WeightTensor>;
42
/// Packs a `(batch, seq)` pair into a single `u64` cache key.
///
/// `batch` is shifted into the upper 32 bits and `seq` is OR-ed into the
/// result unmasked, so the key is unique only while `seq` fits in 32 bits;
/// bits of `batch` above the low 32 are shifted out.
pub fn decode_cache_key(batch: usize, seq: usize) -> u64 {
    let high = (batch as u64) << 32;
    let low = seq as u64;
    high | low
}
47
48pub fn decode_bucket_ladder(device: Device, max_past: u64) -> BucketedCompileCache {
50 let max_past = max_past.max(1);
51 #[allow(clippy::single_range_in_vec_init)]
52 let mut ranges: Vec<Range<u64>> = vec![0..2];
53 let mut start = 2u64;
54 let mut extent = 4u64;
55 loop {
56 ranges.push(start..(extent + 1));
57 if extent >= max_past {
58 break;
59 }
60 start = extent + 1;
61 extent = extent.saturating_mul(2).max(start + 1);
62 }
63 BucketedCompileCache::with_policy(device, ranges, None)
64}
65
66pub(crate) fn metal_safe_decode_profile(
68 device: Device,
69 mut profile: CompileProfile,
70) -> CompileProfile {
71 if device == Device::Metal {
72 profile.fusion.skip = true;
73 profile.backend.metal.skip_fusion = true;
74 profile.backend.metal.unfuse_regions = true;
75 }
76 profile
77}
78
79fn apply_f16_compute(opts: &mut CompileOptions, f16: bool) {
80 if f16 {
81 opts.precision = Precision::F16;
82 opts.policy = Some(PrecisionPolicy::AutoMixed);
83 }
84}
85
/// Compile options for the four Whisper graphs, one `CompileOptions` each.
///
/// `WhisperCompileOpts::new` fills all four from compile profiles for a given
/// device.
#[derive(Debug, Clone)]
pub struct WhisperCompileOpts {
    /// Options for the encoder graph.
    pub encoder: CompileOptions,
    /// Options for the cross graph.
    pub cross: CompileOptions,
    /// Options for the prefill graph.
    pub prefill: CompileOptions,
    /// Options for the decode graph.
    pub decode: CompileOptions,
}
94
95pub fn whisper_use_gpu_kv(lm_device: Device, decode_device: Device) -> bool {
97 device_supports_gpu_kv(decode_device) && decode_device == lm_device
98}
99
100pub fn whisper_decoder_device(lm_device: Device) -> Device {
105 match lm_device {
106 Device::Metal | Device::Mlx | Device::Vulkan => Device::Cpu,
107 other => other,
108 }
109}
110
111pub fn metal_compile_guard<R, F>(device: Device, f: F) -> R
113where
114 F: FnOnce() -> R,
115{
116 if device == Device::Metal {
117 rlx_ir::env::set("RLX_DISABLE_MPSGRAPH", "1");
118 let out = f();
119 rlx_ir::env::unset("RLX_DISABLE_MPSGRAPH");
120 out
121 } else {
122 f()
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_runtime::Device;

    /// Past lengths 0, 1 and 200 must each land in some bucket.
    #[test]
    fn decode_bucket_ladder_includes_past_zero() {
        let ladder = decode_bucket_ladder(Device::Cpu, 448);
        for past in [0, 1, 200] {
            assert!(ladder.bucket_for(past).is_some());
        }
    }

    /// The ladder stays small, and neighbouring past lengths share a bucket.
    #[test]
    fn decode_bucket_ladder_is_logarithmic() {
        let ladder = decode_bucket_ladder(Device::Cpu, 448);
        let bucket_count = ladder.total_buckets();
        assert!(
            bucket_count <= 12,
            "expected O(log n) buckets, got {}",
            bucket_count
        );

        let zero_bucket = ladder.bucket_for(0);
        let one_bucket = ladder.bucket_for(1);
        assert_eq!(zero_bucket, one_bucket);

        let five_bucket = ladder.bucket_for(5);
        let eight_bucket = ladder.bucket_for(8);
        assert_eq!(five_bucket, eight_bucket);
    }
}
151
152impl WhisperCompileOpts {
153 pub fn new(device: Device, f16_compute: bool, weights: &Path) -> Self {
154 let encoder_profile =
155 profile_near_weights(weights, "whisper.rlx.toml", CompileProfile::encoder());
156 let decode_profile = metal_safe_decode_profile(device, CompileProfile::gemma_decode());
157 let prefill_profile = metal_safe_decode_profile(device, CompileProfile::llama32_prefill());
158
159 let mut encoder = compile_options_for_profile(&encoder_profile, device);
160 let mut cross = compile_options_for_profile(&CompileProfile::encoder(), device);
161 let mut prefill = compile_options_for_profile(&prefill_profile, device);
162 let mut decode = compile_options_for_profile(&decode_profile, device);
163
164 apply_f16_compute(&mut encoder, f16_compute);
165 apply_f16_compute(&mut cross, f16_compute);
166 apply_f16_compute(&mut prefill, f16_compute);
167 apply_f16_compute(&mut decode, f16_compute);
168
169 Self {
170 encoder,
171 cross,
172 prefill,
173 decode,
174 }
175 }
176}
177
/// Everything the `build_*` methods need to construct a Whisper graph:
/// configuration, weight prefix, the shared weight table, sequence sizes and
/// graph options.
///
/// Cloning is cheap for `weights` (an `Arc`); the other fields are cloned by
/// their own `Clone` impls.
#[derive(Clone)]
pub struct WhisperGraphCtx {
    /// Model configuration, passed to every graph builder.
    pub cfg: WhisperConfig,
    /// Weight prefix, passed to the cross, prefill and decode-step builders.
    pub pfx: WhisperWeightPrefix,
    /// Shared weight table; `weight_map` clones its contents for each build.
    pub weights: Arc<WhisperWeightMap>,
    /// Passed as the encoder sequence size to the cross, prefill and
    /// decode-step builders.
    pub enc_seq: usize,
    /// Passed to the encoder flow.
    // NOTE(review): presumably the number of mel-spectrogram frames — confirm.
    pub mel_frames: usize,
    /// Graph options, passed to the encoder, prefill and decode-step builders.
    pub graph_opts: WhisperGraphOpts,
    /// Optional fused decoder weights, passed to prefill and decode-step builders.
    pub fused: Option<FusedDecoderWeights>,
    /// Optional fused encoder weights, passed to the encoder flow.
    pub fused_enc: Option<FusedEncoderWeights>,
}
190
191impl WhisperGraphCtx {
192 pub fn weight_map(&self) -> WeightMap {
193 WeightMap::from_tensors((*self.weights).clone())
194 }
195
196 pub fn build_encoder(&self, batch: usize) -> Result<BuiltModel> {
197 let mut wm = self.weight_map();
198 WhisperEncoderFlow::new_opts(
199 &self.cfg,
200 &wm,
201 batch,
202 self.mel_frames,
203 self.graph_opts,
204 self.fused_enc.as_ref(),
205 )
206 .build(&mut wm)
207 }
208
209 pub fn build_cross(&self, batch: usize) -> Result<BuiltModel> {
210 let mut wm = self.weight_map();
211 build_whisper_cross_kv_built(&self.cfg, &mut wm, &self.pfx, batch, self.enc_seq)
212 }
213
214 pub fn build_prefill(&self, batch: usize, dec_seq: usize) -> Result<BuiltModel> {
215 let mut wm = self.weight_map();
216 build_whisper_decoder_prefill_built_ext_opts(
217 &self.cfg,
218 &mut wm,
219 &self.pfx,
220 batch,
221 dec_seq,
222 self.enc_seq,
223 true,
224 self.graph_opts,
225 self.fused.as_ref(),
226 )
227 }
228
229 pub fn build_decode_step(&self, batch: usize, bucket_upper: usize) -> Result<BuiltModel> {
230 let mut wm = self.weight_map();
231 build_whisper_decode_step_built_ext_opts(
232 &self.cfg,
233 &mut wm,
234 &self.pfx,
235 batch,
236 bucket_upper,
237 self.enc_seq,
238 true,
239 self.graph_opts,
240 self.fused.as_ref(),
241 )
242 }
243}