1use crate::audio::N_FRAMES;
17use crate::config::WhisperConfig;
18use crate::weights::WhisperWeightPrefix;
19use anyhow::Result;
20use rlx_core::flow_util::WeightMapSource;
21use rlx_core::weight_map::WeightMap;
22use rlx_flow::{BuiltModel, CompileProfile, ModelFlow};
23use rlx_ir::{DType, Shape};
24
25#[derive(Debug, Clone)]
26pub struct WhisperEncoderFlow<'a> {
27 cfg: &'a WhisperConfig,
28 batch: usize,
29 mel_frames: usize,
30 pfx: WhisperWeightPrefix,
31}
32
33impl<'a> WhisperEncoderFlow<'a> {
34 pub fn new(
35 cfg: &'a WhisperConfig,
36 weights: &WeightMap,
37 batch: usize,
38 mel_frames: usize,
39 ) -> Self {
40 Self {
41 cfg,
42 batch,
43 mel_frames,
44 pfx: WhisperWeightPrefix::detect(weights),
45 }
46 }
47
48 pub fn build(self, weights: &mut WeightMap) -> Result<BuiltModel> {
49 build_whisper_encoder_built(self.cfg, weights, &self.pfx, self.batch, self.mel_frames)
50 }
51}
52
53pub fn build_whisper_encoder_built(
54 cfg: &WhisperConfig,
55 weights: &mut WeightMap,
56 pfx: &WhisperWeightPrefix,
57 batch: usize,
58 mel_frames: usize,
59) -> Result<BuiltModel> {
60 let enc_seq = cfg.encoder_seq_len(mel_frames);
61 let f = DType::F32;
62 let hidden_shape = Shape::new(&[batch, enc_seq, cfg.d_model], f);
63 let cfg = cfg.clone();
64 let pfx = pfx.clone();
65
66 ModelFlow::new("whisper_encoder")
67 .with_profile(CompileProfile::encoder())
68 .input("mel", Shape::new(&[batch, cfg.num_mel_bins, mel_frames], f))
69 .plugin_named("whisper.encoder", move |emit, _| {
70 let mel = emit.flow_input("mel")?.hir_id();
71 let hir = emit
72 .module
73 .as_hir_mut()
74 .expect("whisper encoder flow requires HIR stage");
75 let mut b = crate::builder::WhisperBuilder {
76 hir,
77 params: emit.params,
78 weights: emit.weights,
79 pfx: &pfx,
80 batch,
81 f: DType::F32,
82 };
83 let hidden = b.emit_encoder(&cfg, mel, mel_frames, enc_seq)?;
84 Ok(Some(emit.wrap(hidden, hidden_shape.clone())))
85 })
86 .output("encoder_hidden")
87 .build(&mut WeightMapSource(weights))
88}
89
90#[derive(Debug, Clone)]
91pub struct WhisperDecoderFlow<'a> {
92 cfg: &'a WhisperConfig,
93 batch: usize,
94 dec_seq: usize,
95 enc_seq: usize,
96 pfx: WhisperWeightPrefix,
97}
98
99impl<'a> WhisperDecoderFlow<'a> {
100 pub fn new(
101 cfg: &'a WhisperConfig,
102 weights: &WeightMap,
103 batch: usize,
104 dec_seq: usize,
105 enc_seq: usize,
106 ) -> Self {
107 Self {
108 cfg,
109 batch,
110 dec_seq,
111 enc_seq,
112 pfx: WhisperWeightPrefix::detect(weights),
113 }
114 }
115
116 pub fn build(self, weights: &mut WeightMap) -> Result<BuiltModel> {
117 build_whisper_decoder_built(
118 self.cfg,
119 weights,
120 &self.pfx,
121 self.batch,
122 self.dec_seq,
123 self.enc_seq,
124 )
125 }
126}
127
128pub fn build_whisper_decoder_built(
129 cfg: &WhisperConfig,
130 weights: &mut WeightMap,
131 pfx: &WhisperWeightPrefix,
132 batch: usize,
133 dec_seq: usize,
134 enc_seq: usize,
135) -> Result<BuiltModel> {
136 let f = DType::F32;
137 let h = cfg.d_model;
138 let logits_shape = Shape::new(&[batch, dec_seq, cfg.vocab_size], f);
139 let cfg = cfg.clone();
140 let pfx = pfx.clone();
141
142 ModelFlow::new("whisper_decoder")
143 .with_profile(CompileProfile::encoder())
144 .input("token_ids", Shape::new(&[batch, dec_seq], f))
145 .input("encoder_hidden", Shape::new(&[batch, enc_seq, h], f))
146 .plugin_named("whisper.decoder", move |emit, _| {
147 let tokens = emit.flow_input("token_ids")?.hir_id();
148 let enc = emit.flow_input("encoder_hidden")?.hir_id();
149 let hir = emit
150 .module
151 .as_hir_mut()
152 .expect("whisper decoder flow requires HIR stage");
153 let mut b = crate::builder::WhisperBuilder {
154 hir,
155 params: emit.params,
156 weights: emit.weights,
157 pfx: &pfx,
158 batch,
159 f: DType::F32,
160 };
161 let logits = b.emit_decoder(&cfg, tokens, enc, dec_seq, enc_seq)?;
162 Ok(Some(emit.wrap(logits, logits_shape.clone())))
163 })
164 .output("logits")
165 .build(&mut WeightMapSource(weights))
166}
167
168pub fn default_mel_frames() -> usize {
170 N_FRAMES
171}
172
173pub fn build_whisper_encoder_graph_sized(
174 cfg: &WhisperConfig,
175 weights: &mut WeightMap,
176 batch: usize,
177 mel_frames: usize,
178) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
179 let pfx = WhisperWeightPrefix::detect(weights);
180 rlx_core::flow_util::graph_from_built(build_whisper_encoder_built(
181 cfg, weights, &pfx, batch, mel_frames,
182 )?)
183}
184
185pub fn build_whisper_cross_kv_built(
186 cfg: &WhisperConfig,
187 weights: &mut WeightMap,
188 pfx: &WhisperWeightPrefix,
189 batch: usize,
190 enc_seq: usize,
191) -> Result<BuiltModel> {
192 let (hir, params) = crate::builder::build_whisper_cross_kv_hir(
193 cfg,
194 &mut WeightMapSource(weights),
195 pfx,
196 batch,
197 enc_seq,
198 )?;
199 rlx_core::flow_util::built_from_hir(hir, params)
200}
201
202pub fn build_whisper_decoder_prefill_built(
203 cfg: &WhisperConfig,
204 weights: &mut WeightMap,
205 pfx: &WhisperWeightPrefix,
206 batch: usize,
207 dec_seq: usize,
208 enc_seq: usize,
209) -> Result<BuiltModel> {
210 build_whisper_decoder_prefill_built_ext(cfg, weights, pfx, batch, dec_seq, enc_seq, true)
211}
212
213pub fn build_whisper_decoder_prefill_built_ext(
214 cfg: &WhisperConfig,
215 weights: &mut WeightMap,
216 pfx: &WhisperWeightPrefix,
217 batch: usize,
218 dec_seq: usize,
219 enc_seq: usize,
220 use_cross_cache: bool,
221) -> Result<BuiltModel> {
222 let (hir, params) = crate::builder::build_whisper_decoder_prefill_hir_ext(
223 cfg,
224 &mut WeightMapSource(weights),
225 pfx,
226 batch,
227 dec_seq,
228 enc_seq,
229 use_cross_cache,
230 )?;
231 rlx_core::flow_util::built_from_hir(hir, params)
232}
233
234pub fn build_whisper_decode_step_built(
235 cfg: &WhisperConfig,
236 weights: &mut WeightMap,
237 pfx: &WhisperWeightPrefix,
238 batch: usize,
239 past_seq: usize,
240 enc_seq: usize,
241) -> Result<BuiltModel> {
242 build_whisper_decode_step_built_ext(cfg, weights, pfx, batch, past_seq, enc_seq, false)
243}
244
245pub fn build_whisper_decode_step_built_ext(
246 cfg: &WhisperConfig,
247 weights: &mut WeightMap,
248 pfx: &WhisperWeightPrefix,
249 batch: usize,
250 past_seq: usize,
251 enc_seq: usize,
252 use_custom_mask: bool,
253) -> Result<BuiltModel> {
254 let (hir, params) = crate::builder::build_whisper_decode_step_hir_ext(
255 cfg,
256 &mut WeightMapSource(weights),
257 pfx,
258 batch,
259 past_seq,
260 enc_seq,
261 use_custom_mask,
262 )?;
263 rlx_core::flow_util::built_from_hir(hir, params)
264}
265
266pub fn build_whisper_decoder_graph_sized(
267 cfg: &WhisperConfig,
268 weights: &mut WeightMap,
269 batch: usize,
270 dec_seq: usize,
271 enc_seq: usize,
272) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
273 let pfx = WhisperWeightPrefix::detect(weights);
274 rlx_core::flow_util::graph_from_built(build_whisper_decoder_built(
275 cfg, weights, &pfx, batch, dec_seq, enc_seq,
276 )?)
277}