Skip to main content

rlx_whisper/
flow.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use crate::audio::N_FRAMES;
17use crate::config::WhisperConfig;
18use crate::weights::WhisperWeightPrefix;
19use anyhow::Result;
20use rlx_core::flow_util::WeightMapSource;
21use rlx_core::weight_map::WeightMap;
22use rlx_flow::{BuiltModel, CompileProfile, ModelFlow};
23use rlx_ir::{DType, Shape};
24
25#[derive(Debug, Clone)]
26pub struct WhisperEncoderFlow<'a> {
27    cfg: &'a WhisperConfig,
28    batch: usize,
29    mel_frames: usize,
30    pfx: WhisperWeightPrefix,
31    graph_opts: crate::builder::WhisperGraphOpts,
32    fused_enc: Option<&'a crate::fused::FusedEncoderWeights>,
33}
34
35impl<'a> WhisperEncoderFlow<'a> {
36    pub fn new(
37        cfg: &'a WhisperConfig,
38        weights: &WeightMap,
39        batch: usize,
40        mel_frames: usize,
41    ) -> Self {
42        Self::new_opts(
43            cfg,
44            weights,
45            batch,
46            mel_frames,
47            crate::builder::WhisperGraphOpts::default(),
48            None,
49        )
50    }
51
52    pub fn new_opts(
53        cfg: &'a WhisperConfig,
54        weights: &WeightMap,
55        batch: usize,
56        mel_frames: usize,
57        graph_opts: crate::builder::WhisperGraphOpts,
58        fused_enc: Option<&'a crate::fused::FusedEncoderWeights>,
59    ) -> Self {
60        Self {
61            cfg,
62            batch,
63            mel_frames,
64            pfx: WhisperWeightPrefix::detect(weights),
65            graph_opts,
66            fused_enc,
67        }
68    }
69
70    pub fn build(self, weights: &mut WeightMap) -> Result<BuiltModel> {
71        build_whisper_encoder_built_opts(
72            self.cfg,
73            weights,
74            &self.pfx,
75            self.batch,
76            self.mel_frames,
77            self.graph_opts,
78            self.fused_enc,
79        )
80    }
81}
82
83pub fn build_whisper_encoder_built(
84    cfg: &WhisperConfig,
85    weights: &mut WeightMap,
86    pfx: &WhisperWeightPrefix,
87    batch: usize,
88    mel_frames: usize,
89) -> Result<BuiltModel> {
90    build_whisper_encoder_built_opts(
91        cfg,
92        weights,
93        pfx,
94        batch,
95        mel_frames,
96        crate::builder::WhisperGraphOpts::default(),
97        None,
98    )
99}
100
101pub fn build_whisper_encoder_built_opts(
102    cfg: &WhisperConfig,
103    weights: &mut WeightMap,
104    pfx: &WhisperWeightPrefix,
105    batch: usize,
106    mel_frames: usize,
107    graph_opts: crate::builder::WhisperGraphOpts,
108    fused_enc: Option<&crate::fused::FusedEncoderWeights>,
109) -> Result<BuiltModel> {
110    let enc_seq = cfg.encoder_seq_len(mel_frames);
111    let f = DType::F32;
112    let hidden_shape = Shape::new(&[batch, enc_seq, cfg.d_model], f);
113    let cfg = cfg.clone();
114    let pfx = pfx.clone();
115    let fused_enc = fused_enc.cloned();
116
117    ModelFlow::new("whisper_encoder")
118        .with_profile(CompileProfile::encoder())
119        .input("mel", Shape::new(&[batch, cfg.num_mel_bins, mel_frames], f))
120        .plugin_named("whisper.encoder", move |emit, _| {
121            let mel = emit.flow_input("mel")?.hir_id();
122            let hir = emit
123                .module
124                .as_hir_mut()
125                .expect("whisper encoder flow requires HIR stage");
126            let mut b = crate::builder::WhisperBuilder::new(
127                hir,
128                emit.params,
129                emit.weights,
130                &pfx,
131                batch,
132                graph_opts,
133            );
134            if let Some(ref fused_enc) = fused_enc {
135                b = b.with_fused_enc(fused_enc);
136            }
137            let hidden = b.emit_encoder(&cfg, mel, mel_frames, enc_seq)?;
138            Ok(Some(emit.wrap(hidden, hidden_shape.clone())))
139        })
140        .output("encoder_hidden")
141        .build(&mut WeightMapSource(weights))
142}
143
144#[derive(Debug, Clone)]
145pub struct WhisperDecoderFlow<'a> {
146    cfg: &'a WhisperConfig,
147    batch: usize,
148    dec_seq: usize,
149    enc_seq: usize,
150    pfx: WhisperWeightPrefix,
151}
152
153impl<'a> WhisperDecoderFlow<'a> {
154    pub fn new(
155        cfg: &'a WhisperConfig,
156        weights: &WeightMap,
157        batch: usize,
158        dec_seq: usize,
159        enc_seq: usize,
160    ) -> Self {
161        Self {
162            cfg,
163            batch,
164            dec_seq,
165            enc_seq,
166            pfx: WhisperWeightPrefix::detect(weights),
167        }
168    }
169
170    pub fn build(self, weights: &mut WeightMap) -> Result<BuiltModel> {
171        build_whisper_decoder_built(
172            self.cfg,
173            weights,
174            &self.pfx,
175            self.batch,
176            self.dec_seq,
177            self.enc_seq,
178        )
179    }
180}
181
182pub fn build_whisper_decoder_built(
183    cfg: &WhisperConfig,
184    weights: &mut WeightMap,
185    pfx: &WhisperWeightPrefix,
186    batch: usize,
187    dec_seq: usize,
188    enc_seq: usize,
189) -> Result<BuiltModel> {
190    let f = DType::F32;
191    let h = cfg.d_model;
192    let logits_shape = Shape::new(&[batch, dec_seq, cfg.vocab_size], f);
193    let cfg = cfg.clone();
194    let pfx = pfx.clone();
195
196    ModelFlow::new("whisper_decoder")
197        .with_profile(CompileProfile::encoder())
198        .input("token_ids", Shape::new(&[batch, dec_seq], f))
199        .input("encoder_hidden", Shape::new(&[batch, enc_seq, h], f))
200        .plugin_named("whisper.decoder", move |emit, _| {
201            let tokens = emit.flow_input("token_ids")?.hir_id();
202            let enc = emit.flow_input("encoder_hidden")?.hir_id();
203            let hir = emit
204                .module
205                .as_hir_mut()
206                .expect("whisper decoder flow requires HIR stage");
207            let mut b = crate::builder::WhisperBuilder::new(
208                hir,
209                emit.params,
210                emit.weights,
211                &pfx,
212                batch,
213                crate::builder::WhisperGraphOpts::default(),
214            );
215            let logits = b.emit_decoder(&cfg, tokens, enc, dec_seq, enc_seq)?;
216            Ok(Some(emit.wrap(logits, logits_shape.clone())))
217        })
218        .output("logits")
219        .build(&mut WeightMapSource(weights))
220}
221
222/// Default 30 s chunk mel width.
223pub fn default_mel_frames() -> usize {
224    N_FRAMES
225}
226
227pub fn build_whisper_encoder_graph_sized(
228    cfg: &WhisperConfig,
229    weights: &mut WeightMap,
230    batch: usize,
231    mel_frames: usize,
232) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
233    let pfx = WhisperWeightPrefix::detect(weights);
234    rlx_core::flow_util::graph_from_built(build_whisper_encoder_built(
235        cfg, weights, &pfx, batch, mel_frames,
236    )?)
237}
238
239pub fn build_whisper_cross_kv_built(
240    cfg: &WhisperConfig,
241    weights: &mut WeightMap,
242    pfx: &WhisperWeightPrefix,
243    batch: usize,
244    enc_seq: usize,
245) -> Result<BuiltModel> {
246    let (hir, params) = crate::builder::build_whisper_cross_kv_hir(
247        cfg,
248        &mut WeightMapSource(weights),
249        pfx,
250        batch,
251        enc_seq,
252    )?;
253    rlx_core::flow_util::built_from_hir_with_profile(hir, params, CompileProfile::encoder())
254}
255
256pub fn build_whisper_decoder_prefill_built(
257    cfg: &WhisperConfig,
258    weights: &mut WeightMap,
259    pfx: &WhisperWeightPrefix,
260    batch: usize,
261    dec_seq: usize,
262    enc_seq: usize,
263) -> Result<BuiltModel> {
264    build_whisper_decoder_prefill_built_ext(cfg, weights, pfx, batch, dec_seq, enc_seq, true)
265}
266
267pub fn build_whisper_decoder_prefill_built_ext(
268    cfg: &WhisperConfig,
269    weights: &mut WeightMap,
270    pfx: &WhisperWeightPrefix,
271    batch: usize,
272    dec_seq: usize,
273    enc_seq: usize,
274    use_cross_cache: bool,
275) -> Result<BuiltModel> {
276    build_whisper_decoder_prefill_built_ext_opts(
277        cfg,
278        weights,
279        pfx,
280        batch,
281        dec_seq,
282        enc_seq,
283        use_cross_cache,
284        crate::builder::WhisperGraphOpts::default(),
285        None,
286    )
287}
288
289pub fn build_whisper_decoder_prefill_built_ext_opts(
290    cfg: &WhisperConfig,
291    weights: &mut WeightMap,
292    pfx: &WhisperWeightPrefix,
293    batch: usize,
294    dec_seq: usize,
295    enc_seq: usize,
296    use_cross_cache: bool,
297    graph_opts: crate::builder::WhisperGraphOpts,
298    fused: Option<&crate::fused::FusedDecoderWeights>,
299) -> Result<BuiltModel> {
300    let (hir, params) = crate::builder::build_whisper_decoder_prefill_hir_ext_opts(
301        cfg,
302        &mut WeightMapSource(weights),
303        pfx,
304        batch,
305        dec_seq,
306        enc_seq,
307        use_cross_cache,
308        graph_opts,
309        fused,
310    )?;
311    rlx_core::flow_util::built_from_hir_with_profile(hir, params, CompileProfile::llama32_prefill())
312}
313
314pub fn build_whisper_decode_step_built(
315    cfg: &WhisperConfig,
316    weights: &mut WeightMap,
317    pfx: &WhisperWeightPrefix,
318    batch: usize,
319    past_seq: usize,
320    enc_seq: usize,
321) -> Result<BuiltModel> {
322    build_whisper_decode_step_built_ext(cfg, weights, pfx, batch, past_seq, enc_seq, false)
323}
324
325pub fn build_whisper_decode_step_built_ext(
326    cfg: &WhisperConfig,
327    weights: &mut WeightMap,
328    pfx: &WhisperWeightPrefix,
329    batch: usize,
330    past_seq: usize,
331    enc_seq: usize,
332    use_custom_mask: bool,
333) -> Result<BuiltModel> {
334    build_whisper_decode_step_built_ext_opts(
335        cfg,
336        weights,
337        pfx,
338        batch,
339        past_seq,
340        enc_seq,
341        use_custom_mask,
342        crate::builder::WhisperGraphOpts::default(),
343        None,
344    )
345}
346
347pub fn build_whisper_decode_step_built_ext_opts(
348    cfg: &WhisperConfig,
349    weights: &mut WeightMap,
350    pfx: &WhisperWeightPrefix,
351    batch: usize,
352    past_seq: usize,
353    enc_seq: usize,
354    use_custom_mask: bool,
355    graph_opts: crate::builder::WhisperGraphOpts,
356    fused: Option<&crate::fused::FusedDecoderWeights>,
357) -> Result<BuiltModel> {
358    let (hir, params) = crate::builder::build_whisper_decode_step_hir_ext_opts(
359        cfg,
360        &mut WeightMapSource(weights),
361        pfx,
362        batch,
363        past_seq,
364        enc_seq,
365        use_custom_mask,
366        graph_opts,
367        fused,
368    )?;
369    rlx_core::flow_util::built_from_hir_with_profile(hir, params, CompileProfile::gemma_decode())
370}
371
372pub fn build_whisper_decoder_graph_sized(
373    cfg: &WhisperConfig,
374    weights: &mut WeightMap,
375    batch: usize,
376    dec_seq: usize,
377    enc_seq: usize,
378) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
379    let pfx = WhisperWeightPrefix::detect(weights);
380    rlx_core::flow_util::graph_from_built(build_whisper_decoder_built(
381        cfg, weights, &pfx, batch, dec_seq, enc_seq,
382    )?)
383}