1use super::adapt::prepare_weight_map;
19use super::config::Flux2Config;
20use anyhow::{Context, Result, ensure};
21use rlx_core::weight_map::WeightMap;
22use std::path::Path;
23
/// Weights of a single linear layer, as produced by `load_linear_with_opts`.
#[derive(Debug, Clone)]
pub struct LinearWeights {
    /// Weight data as returned by `WeightMap::take_transposed`. Left empty when the layer was
    /// resolved from a typed or packed linear store instead of the weight map.
    pub w_t: Vec<f32>,
    /// Element 0 of the shape reported by `take_transposed`, or the `in_dim` recorded by the
    /// typed/packed store entry.
    pub in_dim: usize,
    /// Element 1 of the shape reported by `take_transposed`, or the `out_dim` recorded by the
    /// typed/packed store entry.
    pub out_dim: usize,
    /// Bias values. Zero-filled with `out_dim` entries when the weight map has no bias tensor.
    pub bias: Vec<f32>,
}
31
/// Scale vector of an RMS-norm layer, as produced by `load_rms`.
#[derive(Debug, Clone)]
pub struct RmsNormWeight {
    /// Scale values; `load_rms` only accepts tensors whose shape has exactly one dimension.
    pub scale: Vec<f32>,
}
36
/// Pair of linear layers making up one feed-forward module, as produced by `load_ff`.
#[derive(Debug, Clone)]
pub struct Flux2FeedForwardWeights {
    /// Loaded from `{prefix}.linear_in.{weight,bias}`.
    pub linear_in: LinearWeights,
    /// Loaded from `{prefix}.linear_out.{weight,bias}`.
    pub linear_out: LinearWeights,
}
42
/// Attention weights of a double block, as produced by `load_dual_attn`.
///
/// Every field is read from a key under `{prefix}.attn`; the key stem is noted per field.
#[derive(Debug, Clone)]
pub struct Flux2DualAttnWeights {
    /// `to_q`
    pub to_q: LinearWeights,
    /// `to_k`
    pub to_k: LinearWeights,
    /// `to_v`
    pub to_v: LinearWeights,
    /// `norm_q.weight`
    pub norm_q: RmsNormWeight,
    /// `norm_k.weight`
    pub norm_k: RmsNormWeight,
    /// `add_q_proj`
    pub add_q: LinearWeights,
    /// `add_k_proj`
    pub add_k: LinearWeights,
    /// `add_v_proj`
    pub add_v: LinearWeights,
    /// `norm_added_q.weight`
    pub norm_added_q: RmsNormWeight,
    /// `norm_added_k.weight`
    pub norm_added_k: RmsNormWeight,
    /// `to_out.0`
    pub to_out: LinearWeights,
    /// `to_add_out`
    pub to_add_out: LinearWeights,
}
58
/// Attention weights of a single block, as produced by `load_parallel_attn`.
///
/// Every field is read from a key under `{prefix}.attn`.
#[derive(Debug, Clone)]
pub struct Flux2ParallelAttnWeights {
    /// Loaded from `to_qkv_mlp_proj.{weight,bias}`.
    pub to_qkv_mlp: LinearWeights,
    /// Loaded from `norm_q.weight`.
    pub norm_q: RmsNormWeight,
    /// Loaded from `norm_k.weight`.
    pub norm_k: RmsNormWeight,
    /// Loaded from `to_out.{weight,bias}`.
    pub to_out: LinearWeights,
}
66
/// All weights of one `transformer_blocks.{i}` entry.
#[derive(Debug, Clone)]
pub struct Flux2DoubleBlockWeights {
    /// Attention weights read from `transformer_blocks.{i}.attn`.
    pub attn: Flux2DualAttnWeights,
    /// Feed-forward weights read from `transformer_blocks.{i}.ff`.
    pub ff: Flux2FeedForwardWeights,
    /// Feed-forward weights read from `transformer_blocks.{i}.ff_context`.
    pub ff_context: Flux2FeedForwardWeights,
}
73
/// All weights of one `single_transformer_blocks.{i}` entry.
#[derive(Debug, Clone)]
pub struct Flux2SingleBlockWeights {
    /// Attention weights read from `single_transformer_blocks.{i}.attn`.
    pub attn: Flux2ParallelAttnWeights,
}
78
/// Timestep (and optional guidance) embedder weights, as produced by
/// `load_time_guidance_block`.
#[derive(Debug, Clone)]
pub struct Flux2TimestepGuidanceWeights {
    /// Loaded from `{prefix}.timestep_embedder.linear_1`.
    pub timestep_linear1: LinearWeights,
    /// Loaded from `{prefix}.timestep_embedder.linear_2`.
    pub timestep_linear2: LinearWeights,
    /// Loaded from `{prefix}.guidance_embedder.linear_1`; `None` when the block was loaded with
    /// `guidance_embeds == false`.
    pub guidance_linear1: Option<LinearWeights>,
    /// Loaded from `{prefix}.guidance_embedder.linear_2`; `None` when the block was loaded with
    /// `guidance_embeds == false`.
    pub guidance_linear2: Option<LinearWeights>,
}
86
/// Single linear layer of a modulation module (keys of the form `<name>.linear.*`).
#[derive(Debug, Clone)]
pub struct Flux2ModulationWeights {
    /// The modulation projection.
    pub linear: LinearWeights,
}
91
/// Linear layer read from `norm_out.linear.{weight,bias}`.
#[derive(Debug, Clone)]
pub struct Flux2NormOutWeights {
    /// The `norm_out.linear` projection.
    pub linear: LinearWeights,
}
96
/// Complete set of weights extracted by `extract_flux2_weights_with_opts`.
#[derive(Debug, Clone)]
pub struct Flux2Weights {
    /// Loaded from `x_embedder.{weight,bias}`.
    pub x_embedder: LinearWeights,
    /// Loaded from `context_embedder.{weight,bias}`.
    pub context_embedder: LinearWeights,
    /// Loaded from the `time_guidance_embed` prefix.
    pub time_guidance: Flux2TimestepGuidanceWeights,
    /// Loaded from the `time_guidance_embed_target` prefix when that block is present in the
    /// weight map. Otherwise a clone of `time_guidance` if `ExtractFlux2Opts::dual_time_embedder`
    /// is set, else `None`.
    pub time_guidance_target: Option<Flux2TimestepGuidanceWeights>,
    /// Loaded from `double_stream_modulation_img.linear`.
    pub double_mod_img: Flux2ModulationWeights,
    /// Loaded from `double_stream_modulation_txt.linear`.
    pub double_mod_txt: Flux2ModulationWeights,
    /// Loaded from `single_stream_modulation.linear`.
    pub single_mod: Flux2ModulationWeights,
    /// One entry per `transformer_blocks.{i}`, for `i` in `0..cfg.num_layers`.
    pub transformer_blocks: Vec<Flux2DoubleBlockWeights>,
    /// One entry per `single_transformer_blocks.{i}`, for `i` in `0..cfg.num_single_layers`.
    pub single_transformer_blocks: Vec<Flux2SingleBlockWeights>,
    /// Loaded from `norm_out.linear`.
    pub norm_out: Flux2NormOutWeights,
    /// Loaded from `proj_out.{weight,bias}`.
    pub proj_out: LinearWeights,
}
113
114pub fn load_flux2_weight_map(path: &Path) -> Result<WeightMap> {
116 rlx_core::load_weight_map(path, rlx_core::FLUX_GGUF_ARCHES)
117}
118
119pub fn load_flux2_weights(path: &str, cfg: &Flux2Config) -> Result<Flux2Weights> {
120 let wm = load_flux2_weight_map(Path::new(path))?;
121 extract_flux2_weights(prepare_weight_map(wm), cfg)
122}
123
124pub fn extract_flux2_weights(wm: WeightMap, cfg: &Flux2Config) -> Result<Flux2Weights> {
125 extract_flux2_weights_with_opts(wm, cfg, ExtractFlux2Opts::default())
126}
127
128pub fn extract_flux2_weights_with_opts(
129 mut wm: WeightMap,
130 cfg: &Flux2Config,
131 opts: ExtractFlux2Opts<'_>,
132) -> Result<Flux2Weights> {
133 let guidance_embeds = cfg.guidance_embeds
134 && (wm.has("time_guidance_embed.guidance_embedder.linear_1.weight")
135 || wm.has("guidance_in.in_layer.weight"));
136
137 let x_embedder =
138 load_linear_with_opts(&mut wm, "x_embedder.weight", "x_embedder.bias", false, opts)?;
139 let context_embedder = load_linear_with_opts(
140 &mut wm,
141 "context_embedder.weight",
142 "context_embedder.bias",
143 false,
144 opts,
145 )?;
146 let time_guidance =
147 load_time_guidance_block(&mut wm, "time_guidance_embed", guidance_embeds, opts)?;
148 let time_guidance_target =
149 try_load_time_guidance_block(&mut wm, "time_guidance_embed_target", guidance_embeds, opts)?
150 .or_else(|| {
151 if opts.dual_time_embedder {
152 Some(time_guidance.clone())
153 } else {
154 None
155 }
156 });
157 let double_mod_img = Flux2ModulationWeights {
158 linear: load_linear_with_opts(
159 &mut wm,
160 "double_stream_modulation_img.linear.weight",
161 "double_stream_modulation_img.linear.bias",
162 false,
163 opts,
164 )?,
165 };
166 let double_mod_txt = Flux2ModulationWeights {
167 linear: load_linear_with_opts(
168 &mut wm,
169 "double_stream_modulation_txt.linear.weight",
170 "double_stream_modulation_txt.linear.bias",
171 false,
172 opts,
173 )?,
174 };
175 let single_mod = Flux2ModulationWeights {
176 linear: load_linear_with_opts(
177 &mut wm,
178 "single_stream_modulation.linear.weight",
179 "single_stream_modulation.linear.bias",
180 false,
181 opts,
182 )?,
183 };
184
185 let mut transformer_blocks = Vec::with_capacity(cfg.num_layers);
186 for i in 0..cfg.num_layers {
187 let p = format!("transformer_blocks.{i}");
188 transformer_blocks.push(Flux2DoubleBlockWeights {
189 attn: load_dual_attn(&mut wm, &p, opts)?,
190 ff: load_ff(&mut wm, &format!("{p}.ff"), opts)?,
191 ff_context: load_ff(&mut wm, &format!("{p}.ff_context"), opts)?,
192 });
193 }
194
195 let mut single_transformer_blocks = Vec::with_capacity(cfg.num_single_layers);
196 for i in 0..cfg.num_single_layers {
197 let p = format!("single_transformer_blocks.{i}");
198 single_transformer_blocks.push(Flux2SingleBlockWeights {
199 attn: load_parallel_attn(&mut wm, &p, opts)?,
200 });
201 }
202
203 let norm_out = Flux2NormOutWeights {
204 linear: load_linear_with_opts(
205 &mut wm,
206 "norm_out.linear.weight",
207 "norm_out.linear.bias",
208 true,
209 opts,
210 )?,
211 };
212 let proj_out = load_linear_with_opts(&mut wm, "proj_out.weight", "proj_out.bias", false, opts)?;
213
214 Ok(Flux2Weights {
215 x_embedder,
216 context_embedder,
217 time_guidance,
218 time_guidance_target,
219 double_mod_img,
220 double_mod_txt,
221 single_mod,
222 transformer_blocks,
223 single_transformer_blocks,
224 norm_out,
225 proj_out,
226 })
227}
228
229fn load_ff(
230 wm: &mut WeightMap,
231 prefix: &str,
232 opts: ExtractFlux2Opts<'_>,
233) -> Result<Flux2FeedForwardWeights> {
234 Ok(Flux2FeedForwardWeights {
235 linear_in: load_linear_with_opts(
236 wm,
237 &format!("{prefix}.linear_in.weight"),
238 &format!("{prefix}.linear_in.bias"),
239 true,
240 opts,
241 )?,
242 linear_out: load_linear_with_opts(
243 wm,
244 &format!("{prefix}.linear_out.weight"),
245 &format!("{prefix}.linear_out.bias"),
246 true,
247 opts,
248 )?,
249 })
250}
251
252fn load_dual_attn(
253 wm: &mut WeightMap,
254 prefix: &str,
255 opts: ExtractFlux2Opts<'_>,
256) -> Result<Flux2DualAttnWeights> {
257 let ap = format!("{prefix}.attn");
258 Ok(Flux2DualAttnWeights {
259 to_q: load_linear_with_opts(
260 wm,
261 &format!("{ap}.to_q.weight"),
262 &format!("{ap}.to_q.bias"),
263 true,
264 opts,
265 )?,
266 to_k: load_linear_with_opts(
267 wm,
268 &format!("{ap}.to_k.weight"),
269 &format!("{ap}.to_k.bias"),
270 true,
271 opts,
272 )?,
273 to_v: load_linear_with_opts(
274 wm,
275 &format!("{ap}.to_v.weight"),
276 &format!("{ap}.to_v.bias"),
277 true,
278 opts,
279 )?,
280 norm_q: load_rms(wm, &format!("{ap}.norm_q.weight"))?,
281 norm_k: load_rms(wm, &format!("{ap}.norm_k.weight"))?,
282 add_q: load_linear_with_opts(
283 wm,
284 &format!("{ap}.add_q_proj.weight"),
285 &format!("{ap}.add_q_proj.bias"),
286 true,
287 opts,
288 )?,
289 add_k: load_linear_with_opts(
290 wm,
291 &format!("{ap}.add_k_proj.weight"),
292 &format!("{ap}.add_k_proj.bias"),
293 true,
294 opts,
295 )?,
296 add_v: load_linear_with_opts(
297 wm,
298 &format!("{ap}.add_v_proj.weight"),
299 &format!("{ap}.add_v_proj.bias"),
300 true,
301 opts,
302 )?,
303 norm_added_q: load_rms(wm, &format!("{ap}.norm_added_q.weight"))?,
304 norm_added_k: load_rms(wm, &format!("{ap}.norm_added_k.weight"))?,
305 to_out: load_linear_with_opts(
306 wm,
307 &format!("{ap}.to_out.0.weight"),
308 &format!("{ap}.to_out.0.bias"),
309 true,
310 opts,
311 )?,
312 to_add_out: load_linear_with_opts(
313 wm,
314 &format!("{ap}.to_add_out.weight"),
315 &format!("{ap}.to_add_out.bias"),
316 true,
317 opts,
318 )?,
319 })
320}
321
322fn load_parallel_attn(
323 wm: &mut WeightMap,
324 prefix: &str,
325 opts: ExtractFlux2Opts<'_>,
326) -> Result<Flux2ParallelAttnWeights> {
327 let ap = format!("{prefix}.attn");
328 Ok(Flux2ParallelAttnWeights {
329 to_qkv_mlp: load_linear_with_opts(
330 wm,
331 &format!("{ap}.to_qkv_mlp_proj.weight"),
332 &format!("{ap}.to_qkv_mlp_proj.bias"),
333 true,
334 opts,
335 )?,
336 norm_q: load_rms(wm, &format!("{ap}.norm_q.weight"))?,
337 norm_k: load_rms(wm, &format!("{ap}.norm_k.weight"))?,
338 to_out: load_linear_with_opts(
339 wm,
340 &format!("{ap}.to_out.weight"),
341 &format!("{ap}.to_out.bias"),
342 true,
343 opts,
344 )?,
345 })
346}
347
348pub(crate) fn load_rms(wm: &mut WeightMap, key: &str) -> Result<RmsNormWeight> {
349 let (scale, shape) = wm.take(key).with_context(|| format!("missing {key}"))?;
350 ensure!(shape.len() == 1, "{key}: expected 1D scale");
351 Ok(RmsNormWeight { scale })
352}
353
/// Options controlling how `extract_flux2_weights_with_opts` resolves tensors.
///
/// The default value has no external stores and `dual_time_embedder == false`.
#[derive(Copy, Clone, Default)]
pub struct ExtractFlux2Opts<'a> {
    /// Store consulted first (by key with the `.weight` suffix stripped) when a linear layer's
    /// weight is absent from the weight map.
    pub typed_linears: Option<&'a crate::typed_linear::TypedLinearStore>,
    /// Store consulted next, via `get_nvfp4` and then `get_gguf`, when a linear layer's weight is
    /// absent from the weight map.
    pub packed_linears: Option<&'a crate::packed::Flux2PackedParams>,
    /// When set and the weight map has no `time_guidance_embed_target` block, the primary
    /// time/guidance weights are cloned into `Flux2Weights::time_guidance_target`.
    pub dual_time_embedder: bool,
}
361
362fn load_time_guidance_block(
363 wm: &mut WeightMap,
364 prefix: &str,
365 guidance_embeds: bool,
366 opts: ExtractFlux2Opts<'_>,
367) -> Result<Flux2TimestepGuidanceWeights> {
368 Ok(Flux2TimestepGuidanceWeights {
369 timestep_linear1: load_linear_with_opts(
370 wm,
371 &format!("{prefix}.timestep_embedder.linear_1.weight"),
372 &format!("{prefix}.timestep_embedder.linear_1.bias"),
373 true,
374 opts,
375 )?,
376 timestep_linear2: load_linear_with_opts(
377 wm,
378 &format!("{prefix}.timestep_embedder.linear_2.weight"),
379 &format!("{prefix}.timestep_embedder.linear_2.bias"),
380 true,
381 opts,
382 )?,
383 guidance_linear1: if guidance_embeds {
384 Some(load_linear_with_opts(
385 wm,
386 &format!("{prefix}.guidance_embedder.linear_1.weight"),
387 &format!("{prefix}.guidance_embedder.linear_1.bias"),
388 true,
389 opts,
390 )?)
391 } else {
392 None
393 },
394 guidance_linear2: if guidance_embeds {
395 Some(load_linear_with_opts(
396 wm,
397 &format!("{prefix}.guidance_embedder.linear_2.weight"),
398 &format!("{prefix}.guidance_embedder.linear_2.bias"),
399 true,
400 opts,
401 )?)
402 } else {
403 None
404 },
405 })
406}
407
408fn try_load_time_guidance_block(
409 wm: &mut WeightMap,
410 prefix: &str,
411 guidance_embeds: bool,
412 opts: ExtractFlux2Opts<'_>,
413) -> Result<Option<Flux2TimestepGuidanceWeights>> {
414 let w1 = format!("{prefix}.timestep_embedder.linear_1.weight");
415 if !wm.has(&w1) {
416 return Ok(None);
417 }
418 Ok(Some(load_time_guidance_block(
419 wm,
420 prefix,
421 guidance_embeds,
422 opts,
423 )?))
424}
425
426pub(crate) fn load_linear(
427 wm: &mut WeightMap,
428 w_key: &str,
429 b_key: &str,
430 expect_bias: bool,
431) -> Result<LinearWeights> {
432 load_linear_with_opts(wm, w_key, b_key, expect_bias, ExtractFlux2Opts::default())
433}
434
435pub(crate) fn load_linear_with_opts(
436 wm: &mut WeightMap,
437 w_key: &str,
438 b_key: &str,
439 _expect_bias: bool,
440 opts: ExtractFlux2Opts<'_>,
441) -> Result<LinearWeights> {
442 let prefix = w_key.strip_suffix(".weight").unwrap_or(w_key);
443 if !wm.has(w_key) {
444 if let Some(tl) = opts.typed_linears.and_then(|t| t.get(prefix)) {
445 return Ok(LinearWeights {
446 w_t: Vec::new(),
447 in_dim: tl.in_dim,
448 out_dim: tl.out_dim,
449 bias: tl.bias.clone(),
450 });
451 }
452 if let Some(p) = opts.packed_linears.and_then(|m| m.get_nvfp4(prefix)) {
453 return Ok(LinearWeights {
454 w_t: Vec::new(),
455 in_dim: p.in_dim,
456 out_dim: p.out_dim,
457 bias: p.bias.clone(),
458 });
459 }
460 if let Some(p) = opts.packed_linears.and_then(|m| m.get_gguf(prefix)) {
461 return Ok(LinearWeights {
462 w_t: Vec::new(),
463 in_dim: p.in_dim,
464 out_dim: p.out_dim,
465 bias: p.bias.clone(),
466 });
467 }
468 }
469 let (w_t, shape) = wm
470 .take_transposed(w_key)
471 .with_context(|| format!("missing {w_key}"))?;
472 ensure!(shape.len() == 2, "{w_key}: expected 2D");
473 let out_dim = shape[1];
474 let in_dim = shape[0];
475 let bias = if wm.has(b_key) {
476 let (b, bshape) = wm.take(b_key)?;
477 ensure!(bshape == vec![out_dim], "{b_key}: bias shape");
478 b
479 } else {
480 vec![0.0f32; out_dim]
481 };
482 Ok(LinearWeights {
483 w_t,
484 in_dim,
485 out_dim,
486 bias,
487 })
488}