1use super::super::weights::{LinearWeights, RmsNormWeight, load_linear, load_rms};
19use anyhow::{Context, Result, ensure};
20use rlx_core::weight_map::WeightMap;
21use rlx_qwen3::Qwen3Config;
22use std::collections::HashMap;
23use std::path::Path;
24
25#[derive(Debug, Clone)]
26pub struct Flux2TextEncoderMlpWeights {
27 pub gate: LinearWeights,
28 pub up: LinearWeights,
29 pub down: LinearWeights,
30}
31
32#[derive(Debug, Clone)]
33pub struct Flux2TextEncoderAttnWeights {
34 pub q: LinearWeights,
35 pub k: LinearWeights,
36 pub v: LinearWeights,
37 pub o: LinearWeights,
38 pub q_norm: RmsNormWeight,
39 pub k_norm: RmsNormWeight,
40}
41
/// All weights of a single text-encoder layer (`layers.{i}.*`).
#[derive(Debug, Clone)]
pub struct Flux2TextEncoderLayerWeights {
    /// `input_layernorm.weight`.
    pub input_layernorm: RmsNormWeight,
    /// `post_attention_layernorm.weight`.
    pub post_attention_layernorm: RmsNormWeight,
    /// Self-attention projections and q/k norms.
    pub attn: Flux2TextEncoderAttnWeights,
    /// Feed-forward projections.
    pub mlp: Flux2TextEncoderMlpWeights,
}
49
/// Complete set of text-encoder weights, as produced by `extract_text_encoder_weights`.
#[derive(Debug, Clone)]
pub struct Flux2TextEncoderWeights {
    /// `embed_tokens.weight` stored as `(data, vocab, hidden)`.
    /// The source tensor is checked to be rank 2 with shape `[vocab, hidden]`.
    /// NOTE(review): `data` is presumably row-major over that shape — confirm.
    pub embed_tokens: (Vec<f32>, usize, usize),
    /// Top-level `norm.weight`, presumably the norm applied after the last layer.
    pub norm: RmsNormWeight,
    /// One entry per `cfg.num_hidden_layers`, in layer-index order.
    pub layers: Vec<Flux2TextEncoderLayerWeights>,
}
56
57fn normalize_text_encoder_keys(mut wm: WeightMap) -> WeightMap {
58 wm.remap_keys(|k| k.strip_prefix("model.").unwrap_or(&k).to_string());
59 wm
60}
61
62pub fn load_text_encoder_weights(
63 path: &Path,
64 cfg: &Qwen3Config,
65) -> Result<Flux2TextEncoderWeights> {
66 let wm = if path.is_dir() {
67 WeightMap::from_safetensors_dir(path)?
68 } else {
69 WeightMap::from_file(path.to_str().context("non-utf8 path")?)?
70 };
71 extract_text_encoder_weights(normalize_text_encoder_keys(wm), cfg)
72}
73
74pub fn extract_text_encoder_weights(
75 mut wm: WeightMap,
76 cfg: &Qwen3Config,
77) -> Result<Flux2TextEncoderWeights> {
78 let (embed_data, embed_shape) = wm.take("embed_tokens.weight")?;
79 ensure!(
80 embed_shape.len() == 2,
81 "embed_tokens.weight: expected [vocab, hidden]"
82 );
83 let vocab = embed_shape[0];
84 let hidden = embed_shape[1];
85 ensure!(
86 hidden == cfg.hidden_size,
87 "embed hidden {} != config {}",
88 hidden,
89 cfg.hidden_size
90 );
91
92 let norm = load_rms(&mut wm, "norm.weight")?;
93 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
94 for i in 0..cfg.num_hidden_layers {
95 let lp = format!("layers.{i}");
96 layers.push(Flux2TextEncoderLayerWeights {
97 input_layernorm: load_rms(&mut wm, &format!("{lp}.input_layernorm.weight"))?,
98 post_attention_layernorm: load_rms(
99 &mut wm,
100 &format!("{lp}.post_attention_layernorm.weight"),
101 )?,
102 attn: Flux2TextEncoderAttnWeights {
103 q: load_linear(
104 &mut wm,
105 &format!("{lp}.self_attn.q_proj.weight"),
106 &format!("{lp}.self_attn.q_proj.bias"),
107 cfg.attention_bias,
108 )?,
109 k: load_linear(
110 &mut wm,
111 &format!("{lp}.self_attn.k_proj.weight"),
112 &format!("{lp}.self_attn.k_proj.bias"),
113 cfg.attention_bias,
114 )?,
115 v: load_linear(
116 &mut wm,
117 &format!("{lp}.self_attn.v_proj.weight"),
118 &format!("{lp}.self_attn.v_proj.bias"),
119 cfg.attention_bias,
120 )?,
121 o: load_linear(
122 &mut wm,
123 &format!("{lp}.self_attn.o_proj.weight"),
124 &format!("{lp}.self_attn.o_proj.bias"),
125 cfg.attention_bias,
126 )?,
127 q_norm: load_rms(&mut wm, &format!("{lp}.self_attn.q_norm.weight"))?,
128 k_norm: load_rms(&mut wm, &format!("{lp}.self_attn.k_norm.weight"))?,
129 },
130 mlp: Flux2TextEncoderMlpWeights {
131 gate: load_linear(
132 &mut wm,
133 &format!("{lp}.mlp.gate_proj.weight"),
134 &format!("{lp}.mlp.gate_proj.bias"),
135 false,
136 )?,
137 up: load_linear(
138 &mut wm,
139 &format!("{lp}.mlp.up_proj.weight"),
140 &format!("{lp}.mlp.up_proj.bias"),
141 false,
142 )?,
143 down: load_linear(
144 &mut wm,
145 &format!("{lp}.mlp.down_proj.weight"),
146 &format!("{lp}.mlp.down_proj.bias"),
147 false,
148 )?,
149 },
150 });
151 }
152
153 Ok(Flux2TextEncoderWeights {
154 embed_tokens: (embed_data, vocab, hidden),
155 norm,
156 layers,
157 })
158}
159
160pub fn synthetic_text_encoder_weights(cfg: &Qwen3Config) -> Flux2TextEncoderWeights {
162 let mut t: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
163 let h = cfg.hidden_size;
164 let vocab = cfg.vocab_size;
165 let ff = cfg.intermediate_size;
166 let hd = cfg.head_dim;
167 let nh = cfg.num_attention_heads;
168 let nkv = cfg.num_key_value_heads;
169
170 t.insert(
171 "embed_tokens.weight".into(),
172 (vec![0.01f32; vocab * h], vec![vocab, h]),
173 );
174 t.insert("norm.weight".into(), (vec![1.0f32; h], vec![h]));
175
176 for i in 0..cfg.num_hidden_layers {
177 let lp = format!("layers.{i}");
178 for (name, out_d, in_d) in [
179 (format!("{lp}.self_attn.q_proj"), nh * hd, h),
180 (format!("{lp}.self_attn.k_proj"), nkv * hd, h),
181 (format!("{lp}.self_attn.v_proj"), nkv * hd, h),
182 (format!("{lp}.self_attn.o_proj"), h, nh * hd),
183 (format!("{lp}.mlp.gate_proj"), ff, h),
184 (format!("{lp}.mlp.up_proj"), ff, h),
185 (format!("{lp}.mlp.down_proj"), h, ff),
186 ] {
187 t.insert(
188 format!("{name}.weight"),
189 (vec![0.01f32; out_d * in_d], vec![out_d, in_d]),
190 );
191 if name.contains("self_attn") {
192 t.insert(format!("{name}.bias"), (vec![0.0f32; out_d], vec![out_d]));
193 }
194 }
195 for suffix in [
196 "input_layernorm",
197 "post_attention_layernorm",
198 "self_attn.q_norm",
199 "self_attn.k_norm",
200 ] {
201 let dim = if suffix.contains("norm") && suffix.contains("attn") {
202 hd
203 } else {
204 h
205 };
206 t.insert(
207 format!("{lp}.{suffix}.weight"),
208 (vec![1.0f32; dim], vec![dim]),
209 );
210 }
211 }
212
213 extract_text_encoder_weights(WeightMap::from_tensors(t), cfg).expect("synthetic text encoder")
214}