Skip to main content

rlx_whisper/
flow.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use crate::audio::N_FRAMES;
17use crate::config::WhisperConfig;
18use crate::weights::WhisperWeightPrefix;
19use anyhow::Result;
20use rlx_core::flow_util::WeightMapSource;
21use rlx_core::weight_map::WeightMap;
22use rlx_flow::{BuiltModel, CompileProfile, ModelFlow};
23use rlx_ir::{DType, Shape};
24
25#[derive(Debug, Clone)]
26pub struct WhisperEncoderFlow<'a> {
27    cfg: &'a WhisperConfig,
28    batch: usize,
29    mel_frames: usize,
30    pfx: WhisperWeightPrefix,
31}
32
33impl<'a> WhisperEncoderFlow<'a> {
34    pub fn new(
35        cfg: &'a WhisperConfig,
36        weights: &WeightMap,
37        batch: usize,
38        mel_frames: usize,
39    ) -> Self {
40        Self {
41            cfg,
42            batch,
43            mel_frames,
44            pfx: WhisperWeightPrefix::detect(weights),
45        }
46    }
47
48    pub fn build(self, weights: &mut WeightMap) -> Result<BuiltModel> {
49        build_whisper_encoder_built(self.cfg, weights, &self.pfx, self.batch, self.mel_frames)
50    }
51}
52
53pub fn build_whisper_encoder_built(
54    cfg: &WhisperConfig,
55    weights: &mut WeightMap,
56    pfx: &WhisperWeightPrefix,
57    batch: usize,
58    mel_frames: usize,
59) -> Result<BuiltModel> {
60    let enc_seq = cfg.encoder_seq_len(mel_frames);
61    let f = DType::F32;
62    let hidden_shape = Shape::new(&[batch, enc_seq, cfg.d_model], f);
63    let cfg = cfg.clone();
64    let pfx = pfx.clone();
65
66    ModelFlow::new("whisper_encoder")
67        .with_profile(CompileProfile::encoder())
68        .input("mel", Shape::new(&[batch, cfg.num_mel_bins, mel_frames], f))
69        .plugin_named("whisper.encoder", move |emit, _| {
70            let mel = emit.flow_input("mel")?.hir_id();
71            let hir = emit
72                .module
73                .as_hir_mut()
74                .expect("whisper encoder flow requires HIR stage");
75            let mut b = crate::builder::WhisperBuilder {
76                hir,
77                params: emit.params,
78                weights: emit.weights,
79                pfx: &pfx,
80                batch,
81                f: DType::F32,
82            };
83            let hidden = b.emit_encoder(&cfg, mel, mel_frames, enc_seq)?;
84            Ok(Some(emit.wrap(hidden, hidden_shape.clone())))
85        })
86        .output("encoder_hidden")
87        .build(&mut WeightMapSource(weights))
88}
89
90#[derive(Debug, Clone)]
91pub struct WhisperDecoderFlow<'a> {
92    cfg: &'a WhisperConfig,
93    batch: usize,
94    dec_seq: usize,
95    enc_seq: usize,
96    pfx: WhisperWeightPrefix,
97}
98
99impl<'a> WhisperDecoderFlow<'a> {
100    pub fn new(
101        cfg: &'a WhisperConfig,
102        weights: &WeightMap,
103        batch: usize,
104        dec_seq: usize,
105        enc_seq: usize,
106    ) -> Self {
107        Self {
108            cfg,
109            batch,
110            dec_seq,
111            enc_seq,
112            pfx: WhisperWeightPrefix::detect(weights),
113        }
114    }
115
116    pub fn build(self, weights: &mut WeightMap) -> Result<BuiltModel> {
117        build_whisper_decoder_built(
118            self.cfg,
119            weights,
120            &self.pfx,
121            self.batch,
122            self.dec_seq,
123            self.enc_seq,
124        )
125    }
126}
127
128pub fn build_whisper_decoder_built(
129    cfg: &WhisperConfig,
130    weights: &mut WeightMap,
131    pfx: &WhisperWeightPrefix,
132    batch: usize,
133    dec_seq: usize,
134    enc_seq: usize,
135) -> Result<BuiltModel> {
136    let f = DType::F32;
137    let h = cfg.d_model;
138    let logits_shape = Shape::new(&[batch, dec_seq, cfg.vocab_size], f);
139    let cfg = cfg.clone();
140    let pfx = pfx.clone();
141
142    ModelFlow::new("whisper_decoder")
143        .with_profile(CompileProfile::encoder())
144        .input("token_ids", Shape::new(&[batch, dec_seq], f))
145        .input("encoder_hidden", Shape::new(&[batch, enc_seq, h], f))
146        .plugin_named("whisper.decoder", move |emit, _| {
147            let tokens = emit.flow_input("token_ids")?.hir_id();
148            let enc = emit.flow_input("encoder_hidden")?.hir_id();
149            let hir = emit
150                .module
151                .as_hir_mut()
152                .expect("whisper decoder flow requires HIR stage");
153            let mut b = crate::builder::WhisperBuilder {
154                hir,
155                params: emit.params,
156                weights: emit.weights,
157                pfx: &pfx,
158                batch,
159                f: DType::F32,
160            };
161            let logits = b.emit_decoder(&cfg, tokens, enc, dec_seq, enc_seq)?;
162            Ok(Some(emit.wrap(logits, logits_shape.clone())))
163        })
164        .output("logits")
165        .build(&mut WeightMapSource(weights))
166}
167
168/// Default 30 s chunk mel width.
169pub fn default_mel_frames() -> usize {
170    N_FRAMES
171}
172
173pub fn build_whisper_encoder_graph_sized(
174    cfg: &WhisperConfig,
175    weights: &mut WeightMap,
176    batch: usize,
177    mel_frames: usize,
178) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
179    let pfx = WhisperWeightPrefix::detect(weights);
180    rlx_core::flow_util::graph_from_built(build_whisper_encoder_built(
181        cfg, weights, &pfx, batch, mel_frames,
182    )?)
183}
184
185pub fn build_whisper_cross_kv_built(
186    cfg: &WhisperConfig,
187    weights: &mut WeightMap,
188    pfx: &WhisperWeightPrefix,
189    batch: usize,
190    enc_seq: usize,
191) -> Result<BuiltModel> {
192    let (hir, params) = crate::builder::build_whisper_cross_kv_hir(
193        cfg,
194        &mut WeightMapSource(weights),
195        pfx,
196        batch,
197        enc_seq,
198    )?;
199    rlx_core::flow_util::built_from_hir(hir, params)
200}
201
202pub fn build_whisper_decoder_prefill_built(
203    cfg: &WhisperConfig,
204    weights: &mut WeightMap,
205    pfx: &WhisperWeightPrefix,
206    batch: usize,
207    dec_seq: usize,
208    enc_seq: usize,
209) -> Result<BuiltModel> {
210    build_whisper_decoder_prefill_built_ext(cfg, weights, pfx, batch, dec_seq, enc_seq, true)
211}
212
213pub fn build_whisper_decoder_prefill_built_ext(
214    cfg: &WhisperConfig,
215    weights: &mut WeightMap,
216    pfx: &WhisperWeightPrefix,
217    batch: usize,
218    dec_seq: usize,
219    enc_seq: usize,
220    use_cross_cache: bool,
221) -> Result<BuiltModel> {
222    let (hir, params) = crate::builder::build_whisper_decoder_prefill_hir_ext(
223        cfg,
224        &mut WeightMapSource(weights),
225        pfx,
226        batch,
227        dec_seq,
228        enc_seq,
229        use_cross_cache,
230    )?;
231    rlx_core::flow_util::built_from_hir(hir, params)
232}
233
234pub fn build_whisper_decode_step_built(
235    cfg: &WhisperConfig,
236    weights: &mut WeightMap,
237    pfx: &WhisperWeightPrefix,
238    batch: usize,
239    past_seq: usize,
240    enc_seq: usize,
241) -> Result<BuiltModel> {
242    build_whisper_decode_step_built_ext(cfg, weights, pfx, batch, past_seq, enc_seq, false)
243}
244
245pub fn build_whisper_decode_step_built_ext(
246    cfg: &WhisperConfig,
247    weights: &mut WeightMap,
248    pfx: &WhisperWeightPrefix,
249    batch: usize,
250    past_seq: usize,
251    enc_seq: usize,
252    use_custom_mask: bool,
253) -> Result<BuiltModel> {
254    let (hir, params) = crate::builder::build_whisper_decode_step_hir_ext(
255        cfg,
256        &mut WeightMapSource(weights),
257        pfx,
258        batch,
259        past_seq,
260        enc_seq,
261        use_custom_mask,
262    )?;
263    rlx_core::flow_util::built_from_hir(hir, params)
264}
265
266pub fn build_whisper_decoder_graph_sized(
267    cfg: &WhisperConfig,
268    weights: &mut WeightMap,
269    batch: usize,
270    dec_seq: usize,
271    enc_seq: usize,
272) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
273    let pfx = WhisperWeightPrefix::detect(weights);
274    rlx_core::flow_util::graph_from_built(build_whisper_decoder_built(
275        cfg, weights, &pfx, batch, dec_seq, enc_seq,
276    )?)
277}