Skip to main content

rlx_flux2/
weights.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! FLUX.2 transformer weight loading (`.safetensors` or single-file `.gguf`) and extraction
//! into typed per-layer weight structs.
17
18use super::adapt::prepare_weight_map;
19use super::config::Flux2Config;
20use anyhow::{Context, Result, ensure};
21use rlx_core::weight_map::WeightMap;
22use std::path::Path;
23
/// A linear layer as extracted from the checkpoint by `load_linear_with_opts`.
///
/// When the layer is served from a typed or packed store instead of the raw weight map,
/// `w_t` is left empty and only the dimensions and bias are filled in.
#[derive(Debug, Clone)]
pub struct LinearWeights {
    /// Weight as returned by `WeightMap::take_transposed`, whose reported shape is
    /// `[in_dim, out_dim]`. Empty when the layer lives in a typed/packed store.
    pub w_t: Vec<f32>,
    /// Input feature count (`shape[0]` of the transposed weight, or the store's value).
    pub in_dim: usize,
    /// Output feature count (`shape[1]` of the transposed weight, or the store's value).
    pub out_dim: usize,
    /// Bias. From the raw map it has shape `[out_dim]`, or is all zeros when no bias tensor
    /// exists; from a store it is copied verbatim.
    pub bias: Vec<f32>,
}
31
/// Scale vector of an RMS norm, loaded by `load_rms` (which requires a 1-D tensor).
#[derive(Debug, Clone)]
pub struct RmsNormWeight {
    /// The 1-D scale tensor, stored as-is.
    pub scale: Vec<f32>,
}
36
/// Feed-forward block: the two linears `{prefix}.linear_in` and `{prefix}.linear_out`.
#[derive(Debug, Clone)]
pub struct Flux2FeedForwardWeights {
    /// From `{prefix}.linear_in.{weight,bias}`.
    pub linear_in: LinearWeights,
    /// From `{prefix}.linear_out.{weight,bias}`.
    pub linear_out: LinearWeights,
}
42
/// Dual-stream attention of a double block, loaded from `{block}.attn.*`.
///
/// `to_*` / `norm_{q,k}` / `to_out` form one projection set and `add_*` / `norm_added_*` /
/// `to_add_out` the other (NOTE(review): by diffusers naming the `add_*` set is the
/// text/context stream — confirm against the forward pass).
#[derive(Debug, Clone)]
pub struct Flux2DualAttnWeights {
    /// From `attn.to_q`.
    pub to_q: LinearWeights,
    /// From `attn.to_k`.
    pub to_k: LinearWeights,
    /// From `attn.to_v`.
    pub to_v: LinearWeights,
    /// From `attn.norm_q.weight`.
    pub norm_q: RmsNormWeight,
    /// From `attn.norm_k.weight`.
    pub norm_k: RmsNormWeight,
    /// From `attn.add_q_proj`.
    pub add_q: LinearWeights,
    /// From `attn.add_k_proj`.
    pub add_k: LinearWeights,
    /// From `attn.add_v_proj`.
    pub add_v: LinearWeights,
    /// From `attn.norm_added_q.weight`.
    pub norm_added_q: RmsNormWeight,
    /// From `attn.norm_added_k.weight`.
    pub norm_added_k: RmsNormWeight,
    /// From `attn.to_out.0`.
    pub to_out: LinearWeights,
    /// From `attn.to_add_out`.
    pub to_add_out: LinearWeights,
}
58
/// Attention of a single-stream block, loaded from `{block}.attn.*`.
#[derive(Debug, Clone)]
pub struct Flux2ParallelAttnWeights {
    /// From `attn.to_qkv_mlp_proj` (the name indicates a fused QKV + MLP input projection).
    pub to_qkv_mlp: LinearWeights,
    /// From `attn.norm_q.weight`.
    pub norm_q: RmsNormWeight,
    /// From `attn.norm_k.weight`.
    pub norm_k: RmsNormWeight,
    /// From `attn.to_out`.
    pub to_out: LinearWeights,
}
66
/// One `transformer_blocks.{i}` entry (double-stream block).
#[derive(Debug, Clone)]
pub struct Flux2DoubleBlockWeights {
    /// From `transformer_blocks.{i}.attn.*`.
    pub attn: Flux2DualAttnWeights,
    /// From `transformer_blocks.{i}.ff.*`.
    pub ff: Flux2FeedForwardWeights,
    /// From `transformer_blocks.{i}.ff_context.*`.
    pub ff_context: Flux2FeedForwardWeights,
}
73
/// One `single_transformer_blocks.{i}` entry (single-stream block).
#[derive(Debug, Clone)]
pub struct Flux2SingleBlockWeights {
    /// From `single_transformer_blocks.{i}.attn.*`.
    pub attn: Flux2ParallelAttnWeights,
}
78
/// Timestep embedder plus optional guidance embedder, loaded from `{prefix}.*`.
#[derive(Debug, Clone)]
pub struct Flux2TimestepGuidanceWeights {
    /// From `{prefix}.timestep_embedder.linear_1`.
    pub timestep_linear1: LinearWeights,
    /// From `{prefix}.timestep_embedder.linear_2`.
    pub timestep_linear2: LinearWeights,
    /// From `{prefix}.guidance_embedder.linear_1`; `None` when guidance embeds are disabled.
    pub guidance_linear1: Option<LinearWeights>,
    /// From `{prefix}.guidance_embedder.linear_2`; `None` when guidance embeds are disabled.
    pub guidance_linear2: Option<LinearWeights>,
}
86
/// A modulation projection: the single linear `{name}.linear`.
#[derive(Debug, Clone)]
pub struct Flux2ModulationWeights {
    /// From `{name}.linear.{weight,bias}`.
    pub linear: LinearWeights,
}
91
/// Output-norm projection, loaded from `norm_out.linear`.
#[derive(Debug, Clone)]
pub struct Flux2NormOutWeights {
    /// From `norm_out.linear.{weight,bias}`.
    pub linear: LinearWeights,
}
96
/// All FLUX.2 transformer weights, as produced by [`extract_flux2_weights_with_opts`].
#[derive(Debug, Clone)]
pub struct Flux2Weights {
    /// From `x_embedder.{weight,bias}`.
    pub x_embedder: LinearWeights,
    /// From `context_embedder.{weight,bias}`.
    pub context_embedder: LinearWeights,
    /// Primary timestep/guidance embedder, from `time_guidance_embed.*`.
    pub time_guidance: Flux2TimestepGuidanceWeights,
    /// Second timestep embedder for flow-map dual-time (`t` vs `t′`). When `None`, dual-time
    /// forwards reuse [`Self::time_guidance`] for both (averaged embedding).
    ///
    /// Extraction fills this from `time_guidance_embed_target.*` when present, otherwise with
    /// a clone of `time_guidance` if [`ExtractFlux2Opts::dual_time_embedder`] is set.
    pub time_guidance_target: Option<Flux2TimestepGuidanceWeights>,
    /// From `double_stream_modulation_img.linear`.
    pub double_mod_img: Flux2ModulationWeights,
    /// From `double_stream_modulation_txt.linear`.
    pub double_mod_txt: Flux2ModulationWeights,
    /// From `single_stream_modulation.linear`.
    pub single_mod: Flux2ModulationWeights,
    /// `cfg.num_layers` entries, from `transformer_blocks.{i}`.
    pub transformer_blocks: Vec<Flux2DoubleBlockWeights>,
    /// `cfg.num_single_layers` entries, from `single_transformer_blocks.{i}`.
    pub single_transformer_blocks: Vec<Flux2SingleBlockWeights>,
    /// From `norm_out.linear`.
    pub norm_out: Flux2NormOutWeights,
    /// From `proj_out.{weight,bias}`.
    pub proj_out: LinearWeights,
}
113
114/// Load denoiser weights from `.safetensors` or a single-file `.gguf` (BFL / ComfyUI naming).
115pub fn load_flux2_weight_map(path: &Path) -> Result<WeightMap> {
116    rlx_core::load_weight_map(path, rlx_core::FLUX_GGUF_ARCHES)
117}
118
119pub fn load_flux2_weights(path: &str, cfg: &Flux2Config) -> Result<Flux2Weights> {
120    let wm = load_flux2_weight_map(Path::new(path))?;
121    extract_flux2_weights(prepare_weight_map(wm), cfg)
122}
123
124pub fn extract_flux2_weights(wm: WeightMap, cfg: &Flux2Config) -> Result<Flux2Weights> {
125    extract_flux2_weights_with_opts(wm, cfg, ExtractFlux2Opts::default())
126}
127
128pub fn extract_flux2_weights_with_opts(
129    mut wm: WeightMap,
130    cfg: &Flux2Config,
131    opts: ExtractFlux2Opts<'_>,
132) -> Result<Flux2Weights> {
133    let guidance_embeds = cfg.guidance_embeds
134        && (wm.has("time_guidance_embed.guidance_embedder.linear_1.weight")
135            || wm.has("guidance_in.in_layer.weight"));
136
137    let x_embedder =
138        load_linear_with_opts(&mut wm, "x_embedder.weight", "x_embedder.bias", false, opts)?;
139    let context_embedder = load_linear_with_opts(
140        &mut wm,
141        "context_embedder.weight",
142        "context_embedder.bias",
143        false,
144        opts,
145    )?;
146    let time_guidance =
147        load_time_guidance_block(&mut wm, "time_guidance_embed", guidance_embeds, opts)?;
148    let time_guidance_target =
149        try_load_time_guidance_block(&mut wm, "time_guidance_embed_target", guidance_embeds, opts)?
150            .or_else(|| {
151                if opts.dual_time_embedder {
152                    Some(time_guidance.clone())
153                } else {
154                    None
155                }
156            });
157    let double_mod_img = Flux2ModulationWeights {
158        linear: load_linear_with_opts(
159            &mut wm,
160            "double_stream_modulation_img.linear.weight",
161            "double_stream_modulation_img.linear.bias",
162            false,
163            opts,
164        )?,
165    };
166    let double_mod_txt = Flux2ModulationWeights {
167        linear: load_linear_with_opts(
168            &mut wm,
169            "double_stream_modulation_txt.linear.weight",
170            "double_stream_modulation_txt.linear.bias",
171            false,
172            opts,
173        )?,
174    };
175    let single_mod = Flux2ModulationWeights {
176        linear: load_linear_with_opts(
177            &mut wm,
178            "single_stream_modulation.linear.weight",
179            "single_stream_modulation.linear.bias",
180            false,
181            opts,
182        )?,
183    };
184
185    let mut transformer_blocks = Vec::with_capacity(cfg.num_layers);
186    for i in 0..cfg.num_layers {
187        let p = format!("transformer_blocks.{i}");
188        transformer_blocks.push(Flux2DoubleBlockWeights {
189            attn: load_dual_attn(&mut wm, &p, opts)?,
190            ff: load_ff(&mut wm, &format!("{p}.ff"), opts)?,
191            ff_context: load_ff(&mut wm, &format!("{p}.ff_context"), opts)?,
192        });
193    }
194
195    let mut single_transformer_blocks = Vec::with_capacity(cfg.num_single_layers);
196    for i in 0..cfg.num_single_layers {
197        let p = format!("single_transformer_blocks.{i}");
198        single_transformer_blocks.push(Flux2SingleBlockWeights {
199            attn: load_parallel_attn(&mut wm, &p, opts)?,
200        });
201    }
202
203    let norm_out = Flux2NormOutWeights {
204        linear: load_linear_with_opts(
205            &mut wm,
206            "norm_out.linear.weight",
207            "norm_out.linear.bias",
208            true,
209            opts,
210        )?,
211    };
212    let proj_out = load_linear_with_opts(&mut wm, "proj_out.weight", "proj_out.bias", false, opts)?;
213
214    Ok(Flux2Weights {
215        x_embedder,
216        context_embedder,
217        time_guidance,
218        time_guidance_target,
219        double_mod_img,
220        double_mod_txt,
221        single_mod,
222        transformer_blocks,
223        single_transformer_blocks,
224        norm_out,
225        proj_out,
226    })
227}
228
229fn load_ff(
230    wm: &mut WeightMap,
231    prefix: &str,
232    opts: ExtractFlux2Opts<'_>,
233) -> Result<Flux2FeedForwardWeights> {
234    Ok(Flux2FeedForwardWeights {
235        linear_in: load_linear_with_opts(
236            wm,
237            &format!("{prefix}.linear_in.weight"),
238            &format!("{prefix}.linear_in.bias"),
239            true,
240            opts,
241        )?,
242        linear_out: load_linear_with_opts(
243            wm,
244            &format!("{prefix}.linear_out.weight"),
245            &format!("{prefix}.linear_out.bias"),
246            true,
247            opts,
248        )?,
249    })
250}
251
252fn load_dual_attn(
253    wm: &mut WeightMap,
254    prefix: &str,
255    opts: ExtractFlux2Opts<'_>,
256) -> Result<Flux2DualAttnWeights> {
257    let ap = format!("{prefix}.attn");
258    Ok(Flux2DualAttnWeights {
259        to_q: load_linear_with_opts(
260            wm,
261            &format!("{ap}.to_q.weight"),
262            &format!("{ap}.to_q.bias"),
263            true,
264            opts,
265        )?,
266        to_k: load_linear_with_opts(
267            wm,
268            &format!("{ap}.to_k.weight"),
269            &format!("{ap}.to_k.bias"),
270            true,
271            opts,
272        )?,
273        to_v: load_linear_with_opts(
274            wm,
275            &format!("{ap}.to_v.weight"),
276            &format!("{ap}.to_v.bias"),
277            true,
278            opts,
279        )?,
280        norm_q: load_rms(wm, &format!("{ap}.norm_q.weight"))?,
281        norm_k: load_rms(wm, &format!("{ap}.norm_k.weight"))?,
282        add_q: load_linear_with_opts(
283            wm,
284            &format!("{ap}.add_q_proj.weight"),
285            &format!("{ap}.add_q_proj.bias"),
286            true,
287            opts,
288        )?,
289        add_k: load_linear_with_opts(
290            wm,
291            &format!("{ap}.add_k_proj.weight"),
292            &format!("{ap}.add_k_proj.bias"),
293            true,
294            opts,
295        )?,
296        add_v: load_linear_with_opts(
297            wm,
298            &format!("{ap}.add_v_proj.weight"),
299            &format!("{ap}.add_v_proj.bias"),
300            true,
301            opts,
302        )?,
303        norm_added_q: load_rms(wm, &format!("{ap}.norm_added_q.weight"))?,
304        norm_added_k: load_rms(wm, &format!("{ap}.norm_added_k.weight"))?,
305        to_out: load_linear_with_opts(
306            wm,
307            &format!("{ap}.to_out.0.weight"),
308            &format!("{ap}.to_out.0.bias"),
309            true,
310            opts,
311        )?,
312        to_add_out: load_linear_with_opts(
313            wm,
314            &format!("{ap}.to_add_out.weight"),
315            &format!("{ap}.to_add_out.bias"),
316            true,
317            opts,
318        )?,
319    })
320}
321
322fn load_parallel_attn(
323    wm: &mut WeightMap,
324    prefix: &str,
325    opts: ExtractFlux2Opts<'_>,
326) -> Result<Flux2ParallelAttnWeights> {
327    let ap = format!("{prefix}.attn");
328    Ok(Flux2ParallelAttnWeights {
329        to_qkv_mlp: load_linear_with_opts(
330            wm,
331            &format!("{ap}.to_qkv_mlp_proj.weight"),
332            &format!("{ap}.to_qkv_mlp_proj.bias"),
333            true,
334            opts,
335        )?,
336        norm_q: load_rms(wm, &format!("{ap}.norm_q.weight"))?,
337        norm_k: load_rms(wm, &format!("{ap}.norm_k.weight"))?,
338        to_out: load_linear_with_opts(
339            wm,
340            &format!("{ap}.to_out.weight"),
341            &format!("{ap}.to_out.bias"),
342            true,
343            opts,
344        )?,
345    })
346}
347
348pub(crate) fn load_rms(wm: &mut WeightMap, key: &str) -> Result<RmsNormWeight> {
349    let (scale, shape) = wm.take(key).with_context(|| format!("missing {key}"))?;
350    ensure!(shape.len() == 1, "{key}: expected 1D scale");
351    Ok(RmsNormWeight { scale })
352}
353
/// Options for [`extract_flux2_weights_with_opts`].
///
/// `Default` means no typed/packed stores (all weights must be in the weight map) and no
/// dual-time fallback clone.
#[derive(Copy, Clone, Default)]
pub struct ExtractFlux2Opts<'a> {
    /// First fallback consulted for a linear whose `.weight` key is absent from the weight map,
    /// looked up by the key with `.weight` stripped.
    pub typed_linears: Option<&'a crate::typed_linear::TypedLinearStore>,
    /// Next fallbacks after `typed_linears`: NVFP4-packed linears, then GGUF-packed linears.
    pub packed_linears: Option<&'a crate::packed::Flux2PackedParams>,
    /// Clone [`Flux2TimestepGuidanceWeights`] for `t′` when no `time_guidance_embed_target` tensors.
    pub dual_time_embedder: bool,
}
361
362fn load_time_guidance_block(
363    wm: &mut WeightMap,
364    prefix: &str,
365    guidance_embeds: bool,
366    opts: ExtractFlux2Opts<'_>,
367) -> Result<Flux2TimestepGuidanceWeights> {
368    Ok(Flux2TimestepGuidanceWeights {
369        timestep_linear1: load_linear_with_opts(
370            wm,
371            &format!("{prefix}.timestep_embedder.linear_1.weight"),
372            &format!("{prefix}.timestep_embedder.linear_1.bias"),
373            true,
374            opts,
375        )?,
376        timestep_linear2: load_linear_with_opts(
377            wm,
378            &format!("{prefix}.timestep_embedder.linear_2.weight"),
379            &format!("{prefix}.timestep_embedder.linear_2.bias"),
380            true,
381            opts,
382        )?,
383        guidance_linear1: if guidance_embeds {
384            Some(load_linear_with_opts(
385                wm,
386                &format!("{prefix}.guidance_embedder.linear_1.weight"),
387                &format!("{prefix}.guidance_embedder.linear_1.bias"),
388                true,
389                opts,
390            )?)
391        } else {
392            None
393        },
394        guidance_linear2: if guidance_embeds {
395            Some(load_linear_with_opts(
396                wm,
397                &format!("{prefix}.guidance_embedder.linear_2.weight"),
398                &format!("{prefix}.guidance_embedder.linear_2.bias"),
399                true,
400                opts,
401            )?)
402        } else {
403            None
404        },
405    })
406}
407
408fn try_load_time_guidance_block(
409    wm: &mut WeightMap,
410    prefix: &str,
411    guidance_embeds: bool,
412    opts: ExtractFlux2Opts<'_>,
413) -> Result<Option<Flux2TimestepGuidanceWeights>> {
414    let w1 = format!("{prefix}.timestep_embedder.linear_1.weight");
415    if !wm.has(&w1) {
416        return Ok(None);
417    }
418    Ok(Some(load_time_guidance_block(
419        wm,
420        prefix,
421        guidance_embeds,
422        opts,
423    )?))
424}
425
426pub(crate) fn load_linear(
427    wm: &mut WeightMap,
428    w_key: &str,
429    b_key: &str,
430    expect_bias: bool,
431) -> Result<LinearWeights> {
432    load_linear_with_opts(wm, w_key, b_key, expect_bias, ExtractFlux2Opts::default())
433}
434
435pub(crate) fn load_linear_with_opts(
436    wm: &mut WeightMap,
437    w_key: &str,
438    b_key: &str,
439    _expect_bias: bool,
440    opts: ExtractFlux2Opts<'_>,
441) -> Result<LinearWeights> {
442    let prefix = w_key.strip_suffix(".weight").unwrap_or(w_key);
443    if !wm.has(w_key) {
444        if let Some(tl) = opts.typed_linears.and_then(|t| t.get(prefix)) {
445            return Ok(LinearWeights {
446                w_t: Vec::new(),
447                in_dim: tl.in_dim,
448                out_dim: tl.out_dim,
449                bias: tl.bias.clone(),
450            });
451        }
452        if let Some(p) = opts.packed_linears.and_then(|m| m.get_nvfp4(prefix)) {
453            return Ok(LinearWeights {
454                w_t: Vec::new(),
455                in_dim: p.in_dim,
456                out_dim: p.out_dim,
457                bias: p.bias.clone(),
458            });
459        }
460        if let Some(p) = opts.packed_linears.and_then(|m| m.get_gguf(prefix)) {
461            return Ok(LinearWeights {
462                w_t: Vec::new(),
463                in_dim: p.in_dim,
464                out_dim: p.out_dim,
465                bias: p.bias.clone(),
466            });
467        }
468    }
469    let (w_t, shape) = wm
470        .take_transposed(w_key)
471        .with_context(|| format!("missing {w_key}"))?;
472    ensure!(shape.len() == 2, "{w_key}: expected 2D");
473    let out_dim = shape[1];
474    let in_dim = shape[0];
475    let bias = if wm.has(b_key) {
476        let (b, bshape) = wm.take(b_key)?;
477        ensure!(bshape == vec![out_dim], "{b_key}: bias shape");
478        b
479    } else {
480        vec![0.0f32; out_dim]
481    };
482    Ok(LinearWeights {
483        w_t,
484        in_dim,
485        out_dim,
486        bias,
487    })
488}