1use crate::audio::N_FRAMES;
17use crate::config::WhisperConfig;
18use crate::weights::WhisperWeightPrefix;
19use anyhow::Result;
20use rlx_core::flow_util::WeightMapSource;
21use rlx_core::weight_map::WeightMap;
22use rlx_flow::{BuiltModel, CompileProfile, ModelFlow};
23use rlx_ir::{DType, Shape};
24
/// Parameters for compiling the Whisper encoder into a [`BuiltModel`].
///
/// Construct with `WhisperEncoderFlow::new` / `WhisperEncoderFlow::new_opts` and
/// compile with `WhisperEncoderFlow::build`, which forwards every field to
/// [`build_whisper_encoder_built_opts`].
#[derive(Debug, Clone)]
pub struct WhisperEncoderFlow<'a> {
    /// Model hyper-parameters, borrowed for the lifetime of the flow.
    cfg: &'a WhisperConfig,
    /// Batch dimension of the `mel` input.
    batch: usize,
    /// Number of mel frames (last axis of the `mel` input).
    mel_frames: usize,
    /// Weight-name prefix detected from the weight map at construction time.
    pfx: WhisperWeightPrefix,
    /// Graph-construction options forwarded to the encoder builder.
    graph_opts: crate::builder::WhisperGraphOpts,
    /// Optional fused encoder weights; when `None`, `with_fused_enc` is not applied.
    fused_enc: Option<&'a crate::fused::FusedEncoderWeights>,
}
34
35impl<'a> WhisperEncoderFlow<'a> {
36 pub fn new(
37 cfg: &'a WhisperConfig,
38 weights: &WeightMap,
39 batch: usize,
40 mel_frames: usize,
41 ) -> Self {
42 Self::new_opts(
43 cfg,
44 weights,
45 batch,
46 mel_frames,
47 crate::builder::WhisperGraphOpts::default(),
48 None,
49 )
50 }
51
52 pub fn new_opts(
53 cfg: &'a WhisperConfig,
54 weights: &WeightMap,
55 batch: usize,
56 mel_frames: usize,
57 graph_opts: crate::builder::WhisperGraphOpts,
58 fused_enc: Option<&'a crate::fused::FusedEncoderWeights>,
59 ) -> Self {
60 Self {
61 cfg,
62 batch,
63 mel_frames,
64 pfx: WhisperWeightPrefix::detect(weights),
65 graph_opts,
66 fused_enc,
67 }
68 }
69
70 pub fn build(self, weights: &mut WeightMap) -> Result<BuiltModel> {
71 build_whisper_encoder_built_opts(
72 self.cfg,
73 weights,
74 &self.pfx,
75 self.batch,
76 self.mel_frames,
77 self.graph_opts,
78 self.fused_enc,
79 )
80 }
81}
82
83pub fn build_whisper_encoder_built(
84 cfg: &WhisperConfig,
85 weights: &mut WeightMap,
86 pfx: &WhisperWeightPrefix,
87 batch: usize,
88 mel_frames: usize,
89) -> Result<BuiltModel> {
90 build_whisper_encoder_built_opts(
91 cfg,
92 weights,
93 pfx,
94 batch,
95 mel_frames,
96 crate::builder::WhisperGraphOpts::default(),
97 None,
98 )
99}
100
101pub fn build_whisper_encoder_built_opts(
102 cfg: &WhisperConfig,
103 weights: &mut WeightMap,
104 pfx: &WhisperWeightPrefix,
105 batch: usize,
106 mel_frames: usize,
107 graph_opts: crate::builder::WhisperGraphOpts,
108 fused_enc: Option<&crate::fused::FusedEncoderWeights>,
109) -> Result<BuiltModel> {
110 let enc_seq = cfg.encoder_seq_len(mel_frames);
111 let f = DType::F32;
112 let hidden_shape = Shape::new(&[batch, enc_seq, cfg.d_model], f);
113 let cfg = cfg.clone();
114 let pfx = pfx.clone();
115 let fused_enc = fused_enc.cloned();
116
117 ModelFlow::new("whisper_encoder")
118 .with_profile(CompileProfile::encoder())
119 .input("mel", Shape::new(&[batch, cfg.num_mel_bins, mel_frames], f))
120 .plugin_named("whisper.encoder", move |emit, _| {
121 let mel = emit.flow_input("mel")?.hir_id();
122 let hir = emit
123 .module
124 .as_hir_mut()
125 .expect("whisper encoder flow requires HIR stage");
126 let mut b = crate::builder::WhisperBuilder::new(
127 hir,
128 emit.params,
129 emit.weights,
130 &pfx,
131 batch,
132 graph_opts,
133 );
134 if let Some(ref fused_enc) = fused_enc {
135 b = b.with_fused_enc(fused_enc);
136 }
137 let hidden = b.emit_encoder(&cfg, mel, mel_frames, enc_seq)?;
138 Ok(Some(emit.wrap(hidden, hidden_shape.clone())))
139 })
140 .output("encoder_hidden")
141 .build(&mut WeightMapSource(weights))
142}
143
/// Parameters for compiling the Whisper decoder into a [`BuiltModel`].
///
/// Construct with `WhisperDecoderFlow::new` and compile with
/// `WhisperDecoderFlow::build`, which forwards every field to
/// [`build_whisper_decoder_built`].
#[derive(Debug, Clone)]
pub struct WhisperDecoderFlow<'a> {
    /// Model hyper-parameters, borrowed for the lifetime of the flow.
    cfg: &'a WhisperConfig,
    /// Batch dimension of both inputs.
    batch: usize,
    /// Length of the `token_ids` input (second axis).
    dec_seq: usize,
    /// Sequence length of the `encoder_hidden` input (second axis).
    enc_seq: usize,
    /// Weight-name prefix detected from the weight map at construction time.
    pfx: WhisperWeightPrefix,
}
152
153impl<'a> WhisperDecoderFlow<'a> {
154 pub fn new(
155 cfg: &'a WhisperConfig,
156 weights: &WeightMap,
157 batch: usize,
158 dec_seq: usize,
159 enc_seq: usize,
160 ) -> Self {
161 Self {
162 cfg,
163 batch,
164 dec_seq,
165 enc_seq,
166 pfx: WhisperWeightPrefix::detect(weights),
167 }
168 }
169
170 pub fn build(self, weights: &mut WeightMap) -> Result<BuiltModel> {
171 build_whisper_decoder_built(
172 self.cfg,
173 weights,
174 &self.pfx,
175 self.batch,
176 self.dec_seq,
177 self.enc_seq,
178 )
179 }
180}
181
182pub fn build_whisper_decoder_built(
183 cfg: &WhisperConfig,
184 weights: &mut WeightMap,
185 pfx: &WhisperWeightPrefix,
186 batch: usize,
187 dec_seq: usize,
188 enc_seq: usize,
189) -> Result<BuiltModel> {
190 let f = DType::F32;
191 let h = cfg.d_model;
192 let logits_shape = Shape::new(&[batch, dec_seq, cfg.vocab_size], f);
193 let cfg = cfg.clone();
194 let pfx = pfx.clone();
195
196 ModelFlow::new("whisper_decoder")
197 .with_profile(CompileProfile::encoder())
198 .input("token_ids", Shape::new(&[batch, dec_seq], f))
199 .input("encoder_hidden", Shape::new(&[batch, enc_seq, h], f))
200 .plugin_named("whisper.decoder", move |emit, _| {
201 let tokens = emit.flow_input("token_ids")?.hir_id();
202 let enc = emit.flow_input("encoder_hidden")?.hir_id();
203 let hir = emit
204 .module
205 .as_hir_mut()
206 .expect("whisper decoder flow requires HIR stage");
207 let mut b = crate::builder::WhisperBuilder::new(
208 hir,
209 emit.params,
210 emit.weights,
211 &pfx,
212 batch,
213 crate::builder::WhisperGraphOpts::default(),
214 );
215 let logits = b.emit_decoder(&cfg, tokens, enc, dec_seq, enc_seq)?;
216 Ok(Some(emit.wrap(logits, logits_shape.clone())))
217 })
218 .output("logits")
219 .build(&mut WeightMapSource(weights))
220}
221
222pub fn default_mel_frames() -> usize {
224 N_FRAMES
225}
226
227pub fn build_whisper_encoder_graph_sized(
228 cfg: &WhisperConfig,
229 weights: &mut WeightMap,
230 batch: usize,
231 mel_frames: usize,
232) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
233 let pfx = WhisperWeightPrefix::detect(weights);
234 rlx_core::flow_util::graph_from_built(build_whisper_encoder_built(
235 cfg, weights, &pfx, batch, mel_frames,
236 )?)
237}
238
239pub fn build_whisper_cross_kv_built(
240 cfg: &WhisperConfig,
241 weights: &mut WeightMap,
242 pfx: &WhisperWeightPrefix,
243 batch: usize,
244 enc_seq: usize,
245) -> Result<BuiltModel> {
246 let (hir, params) = crate::builder::build_whisper_cross_kv_hir(
247 cfg,
248 &mut WeightMapSource(weights),
249 pfx,
250 batch,
251 enc_seq,
252 )?;
253 rlx_core::flow_util::built_from_hir_with_profile(hir, params, CompileProfile::encoder())
254}
255
256pub fn build_whisper_decoder_prefill_built(
257 cfg: &WhisperConfig,
258 weights: &mut WeightMap,
259 pfx: &WhisperWeightPrefix,
260 batch: usize,
261 dec_seq: usize,
262 enc_seq: usize,
263) -> Result<BuiltModel> {
264 build_whisper_decoder_prefill_built_ext(cfg, weights, pfx, batch, dec_seq, enc_seq, true)
265}
266
267pub fn build_whisper_decoder_prefill_built_ext(
268 cfg: &WhisperConfig,
269 weights: &mut WeightMap,
270 pfx: &WhisperWeightPrefix,
271 batch: usize,
272 dec_seq: usize,
273 enc_seq: usize,
274 use_cross_cache: bool,
275) -> Result<BuiltModel> {
276 build_whisper_decoder_prefill_built_ext_opts(
277 cfg,
278 weights,
279 pfx,
280 batch,
281 dec_seq,
282 enc_seq,
283 use_cross_cache,
284 crate::builder::WhisperGraphOpts::default(),
285 None,
286 )
287}
288
289pub fn build_whisper_decoder_prefill_built_ext_opts(
290 cfg: &WhisperConfig,
291 weights: &mut WeightMap,
292 pfx: &WhisperWeightPrefix,
293 batch: usize,
294 dec_seq: usize,
295 enc_seq: usize,
296 use_cross_cache: bool,
297 graph_opts: crate::builder::WhisperGraphOpts,
298 fused: Option<&crate::fused::FusedDecoderWeights>,
299) -> Result<BuiltModel> {
300 let (hir, params) = crate::builder::build_whisper_decoder_prefill_hir_ext_opts(
301 cfg,
302 &mut WeightMapSource(weights),
303 pfx,
304 batch,
305 dec_seq,
306 enc_seq,
307 use_cross_cache,
308 graph_opts,
309 fused,
310 )?;
311 rlx_core::flow_util::built_from_hir_with_profile(hir, params, CompileProfile::llama32_prefill())
312}
313
314pub fn build_whisper_decode_step_built(
315 cfg: &WhisperConfig,
316 weights: &mut WeightMap,
317 pfx: &WhisperWeightPrefix,
318 batch: usize,
319 past_seq: usize,
320 enc_seq: usize,
321) -> Result<BuiltModel> {
322 build_whisper_decode_step_built_ext(cfg, weights, pfx, batch, past_seq, enc_seq, false)
323}
324
325pub fn build_whisper_decode_step_built_ext(
326 cfg: &WhisperConfig,
327 weights: &mut WeightMap,
328 pfx: &WhisperWeightPrefix,
329 batch: usize,
330 past_seq: usize,
331 enc_seq: usize,
332 use_custom_mask: bool,
333) -> Result<BuiltModel> {
334 build_whisper_decode_step_built_ext_opts(
335 cfg,
336 weights,
337 pfx,
338 batch,
339 past_seq,
340 enc_seq,
341 use_custom_mask,
342 crate::builder::WhisperGraphOpts::default(),
343 None,
344 )
345}
346
347pub fn build_whisper_decode_step_built_ext_opts(
348 cfg: &WhisperConfig,
349 weights: &mut WeightMap,
350 pfx: &WhisperWeightPrefix,
351 batch: usize,
352 past_seq: usize,
353 enc_seq: usize,
354 use_custom_mask: bool,
355 graph_opts: crate::builder::WhisperGraphOpts,
356 fused: Option<&crate::fused::FusedDecoderWeights>,
357) -> Result<BuiltModel> {
358 let (hir, params) = crate::builder::build_whisper_decode_step_hir_ext_opts(
359 cfg,
360 &mut WeightMapSource(weights),
361 pfx,
362 batch,
363 past_seq,
364 enc_seq,
365 use_custom_mask,
366 graph_opts,
367 fused,
368 )?;
369 rlx_core::flow_util::built_from_hir_with_profile(hir, params, CompileProfile::gemma_decode())
370}
371
372pub fn build_whisper_decoder_graph_sized(
373 cfg: &WhisperConfig,
374 weights: &mut WeightMap,
375 batch: usize,
376 dec_seq: usize,
377 enc_seq: usize,
378) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
379 let pfx = WhisperWeightPrefix::detect(weights);
380 rlx_core::flow_util::graph_from_built(build_whisper_decoder_built(
381 cfg, weights, &pfx, batch, dec_seq, enc_seq,
382 )?)
383}