Skip to main content

rlx_flux2/text_encoder/
weights.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! FLUX.2 text encoder (Qwen3) weight loading.
17
18use super::super::weights::{LinearWeights, RmsNormWeight, load_linear, load_rms};
19use anyhow::{Context, Result, ensure};
20use rlx_core::weight_map::WeightMap;
21use rlx_qwen3::Qwen3Config;
22use std::collections::HashMap;
23use std::path::Path;
24
/// Feed-forward projections of one text-encoder layer.
///
/// Loaded from the `layers.{i}.mlp.{gate,up,down}_proj` tensors; no bias is
/// requested for any of them.
#[derive(Debug, Clone)]
pub struct Flux2TextEncoderMlpWeights {
    /// `mlp.gate_proj`.
    pub gate: LinearWeights,
    /// `mlp.up_proj`.
    pub up: LinearWeights,
    /// `mlp.down_proj`.
    pub down: LinearWeights,
}
31
/// Self-attention weights of one text-encoder layer.
///
/// Loaded from the `layers.{i}.self_attn.*` tensors. The projection biases are
/// only requested when `Qwen3Config::attention_bias` is set.
#[derive(Debug, Clone)]
pub struct Flux2TextEncoderAttnWeights {
    /// `self_attn.q_proj`.
    pub q: LinearWeights,
    /// `self_attn.k_proj`.
    pub k: LinearWeights,
    /// `self_attn.v_proj`.
    pub v: LinearWeights,
    /// `self_attn.o_proj`.
    pub o: LinearWeights,
    /// `self_attn.q_norm.weight` (the synthetic builder sizes it by `head_dim`).
    pub q_norm: RmsNormWeight,
    /// `self_attn.k_norm.weight` (the synthetic builder sizes it by `head_dim`).
    pub k_norm: RmsNormWeight,
}
41
/// All weights belonging to one `layers.{i}` block of the text encoder.
#[derive(Debug, Clone)]
pub struct Flux2TextEncoderLayerWeights {
    /// `layers.{i}.input_layernorm.weight`.
    pub input_layernorm: RmsNormWeight,
    /// `layers.{i}.post_attention_layernorm.weight`.
    pub post_attention_layernorm: RmsNormWeight,
    /// `layers.{i}.self_attn.*`.
    pub attn: Flux2TextEncoderAttnWeights,
    /// `layers.{i}.mlp.*`.
    pub mlp: Flux2TextEncoderMlpWeights,
}
49
/// Complete weight set for the FLUX.2 Qwen3 text encoder.
#[derive(Debug, Clone)]
pub struct Flux2TextEncoderWeights {
    /// `(data, vocab, hidden)` taken from `embed_tokens.weight`, whose shape is
    /// checked to be `[vocab, hidden]` with `hidden == cfg.hidden_size`.
    /// NOTE(review): `data` is presumably row-major (`vocab` rows of `hidden`
    /// floats) as stored by `WeightMap` — confirm.
    pub embed_tokens: (Vec<f32>, usize, usize),
    /// Final `norm.weight`.
    pub norm: RmsNormWeight,
    /// One entry per `layers.{i}`, for `i in 0..cfg.num_hidden_layers`.
    pub layers: Vec<Flux2TextEncoderLayerWeights>,
}
56
57fn normalize_text_encoder_keys(mut wm: WeightMap) -> WeightMap {
58    wm.remap_keys(|k| k.strip_prefix("model.").unwrap_or(&k).to_string());
59    wm
60}
61
62pub fn load_text_encoder_weights(
63    path: &Path,
64    cfg: &Qwen3Config,
65) -> Result<Flux2TextEncoderWeights> {
66    let wm = if path.is_dir() {
67        WeightMap::from_safetensors_dir(path)?
68    } else {
69        WeightMap::from_file(path.to_str().context("non-utf8 path")?)?
70    };
71    extract_text_encoder_weights(normalize_text_encoder_keys(wm), cfg)
72}
73
74pub fn extract_text_encoder_weights(
75    mut wm: WeightMap,
76    cfg: &Qwen3Config,
77) -> Result<Flux2TextEncoderWeights> {
78    let (embed_data, embed_shape) = wm.take("embed_tokens.weight")?;
79    ensure!(
80        embed_shape.len() == 2,
81        "embed_tokens.weight: expected [vocab, hidden]"
82    );
83    let vocab = embed_shape[0];
84    let hidden = embed_shape[1];
85    ensure!(
86        hidden == cfg.hidden_size,
87        "embed hidden {} != config {}",
88        hidden,
89        cfg.hidden_size
90    );
91
92    let norm = load_rms(&mut wm, "norm.weight")?;
93    let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
94    for i in 0..cfg.num_hidden_layers {
95        let lp = format!("layers.{i}");
96        layers.push(Flux2TextEncoderLayerWeights {
97            input_layernorm: load_rms(&mut wm, &format!("{lp}.input_layernorm.weight"))?,
98            post_attention_layernorm: load_rms(
99                &mut wm,
100                &format!("{lp}.post_attention_layernorm.weight"),
101            )?,
102            attn: Flux2TextEncoderAttnWeights {
103                q: load_linear(
104                    &mut wm,
105                    &format!("{lp}.self_attn.q_proj.weight"),
106                    &format!("{lp}.self_attn.q_proj.bias"),
107                    cfg.attention_bias,
108                )?,
109                k: load_linear(
110                    &mut wm,
111                    &format!("{lp}.self_attn.k_proj.weight"),
112                    &format!("{lp}.self_attn.k_proj.bias"),
113                    cfg.attention_bias,
114                )?,
115                v: load_linear(
116                    &mut wm,
117                    &format!("{lp}.self_attn.v_proj.weight"),
118                    &format!("{lp}.self_attn.v_proj.bias"),
119                    cfg.attention_bias,
120                )?,
121                o: load_linear(
122                    &mut wm,
123                    &format!("{lp}.self_attn.o_proj.weight"),
124                    &format!("{lp}.self_attn.o_proj.bias"),
125                    cfg.attention_bias,
126                )?,
127                q_norm: load_rms(&mut wm, &format!("{lp}.self_attn.q_norm.weight"))?,
128                k_norm: load_rms(&mut wm, &format!("{lp}.self_attn.k_norm.weight"))?,
129            },
130            mlp: Flux2TextEncoderMlpWeights {
131                gate: load_linear(
132                    &mut wm,
133                    &format!("{lp}.mlp.gate_proj.weight"),
134                    &format!("{lp}.mlp.gate_proj.bias"),
135                    false,
136                )?,
137                up: load_linear(
138                    &mut wm,
139                    &format!("{lp}.mlp.up_proj.weight"),
140                    &format!("{lp}.mlp.up_proj.bias"),
141                    false,
142                )?,
143                down: load_linear(
144                    &mut wm,
145                    &format!("{lp}.mlp.down_proj.weight"),
146                    &format!("{lp}.mlp.down_proj.bias"),
147                    false,
148                )?,
149            },
150        });
151    }
152
153    Ok(Flux2TextEncoderWeights {
154        embed_tokens: (embed_data, vocab, hidden),
155        norm,
156        layers,
157    })
158}
159
160/// Tiny zero weights for unit tests (2 layers, hidden 8).
161pub fn synthetic_text_encoder_weights(cfg: &Qwen3Config) -> Flux2TextEncoderWeights {
162    let mut t: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
163    let h = cfg.hidden_size;
164    let vocab = cfg.vocab_size;
165    let ff = cfg.intermediate_size;
166    let hd = cfg.head_dim;
167    let nh = cfg.num_attention_heads;
168    let nkv = cfg.num_key_value_heads;
169
170    t.insert(
171        "embed_tokens.weight".into(),
172        (vec![0.01f32; vocab * h], vec![vocab, h]),
173    );
174    t.insert("norm.weight".into(), (vec![1.0f32; h], vec![h]));
175
176    for i in 0..cfg.num_hidden_layers {
177        let lp = format!("layers.{i}");
178        for (name, out_d, in_d) in [
179            (format!("{lp}.self_attn.q_proj"), nh * hd, h),
180            (format!("{lp}.self_attn.k_proj"), nkv * hd, h),
181            (format!("{lp}.self_attn.v_proj"), nkv * hd, h),
182            (format!("{lp}.self_attn.o_proj"), h, nh * hd),
183            (format!("{lp}.mlp.gate_proj"), ff, h),
184            (format!("{lp}.mlp.up_proj"), ff, h),
185            (format!("{lp}.mlp.down_proj"), h, ff),
186        ] {
187            t.insert(
188                format!("{name}.weight"),
189                (vec![0.01f32; out_d * in_d], vec![out_d, in_d]),
190            );
191            if name.contains("self_attn") {
192                t.insert(format!("{name}.bias"), (vec![0.0f32; out_d], vec![out_d]));
193            }
194        }
195        for suffix in [
196            "input_layernorm",
197            "post_attention_layernorm",
198            "self_attn.q_norm",
199            "self_attn.k_norm",
200        ] {
201            let dim = if suffix.contains("norm") && suffix.contains("attn") {
202                hd
203            } else {
204                h
205            };
206            t.insert(
207                format!("{lp}.{suffix}.weight"),
208                (vec![1.0f32; dim], vec![dim]),
209            );
210        }
211    }
212
213    extract_text_encoder_weights(WeightMap::from_tensors(t), cfg).expect("synthetic text encoder")
214}