1use super::prompt::DEFAULT_TEXT_ENCODER_LAYERS;
19use super::weights::{
20 Flux2TextEncoderAttnWeights, Flux2TextEncoderLayerWeights, Flux2TextEncoderMlpWeights,
21 Flux2TextEncoderWeights,
22};
23use anyhow::{Result, ensure};
24use rlx_qwen3::Qwen3Config;
25use rlx_tensor::{layer_norm, linear};
26
/// Prompt conditioning produced by the Flux2 text encoder.
///
/// The batch size is not stored; it equals
/// `prompt_embeds.len() / (seq_len * joint_dim)`.
#[derive(Clone, Debug)]
pub struct Flux2PromptOutput {
    /// Flattened embeddings laid out as `[batch][seq_len][joint_dim]`.
    pub prompt_embeds: Vec<f32>,
    /// Number of token positions per prompt.
    pub seq_len: usize,
    /// Per-token width: `hidden_size` times the number of selected hidden states.
    pub joint_dim: usize,
}
33
34fn rms_norm(x: &[f32], scale: &[f32], dim: usize, eps: f32) -> Result<Vec<f32>> {
35 let beta = vec![0.0f32; dim];
36 layer_norm(x, scale, &beta, dim, eps)
37}
38
39fn rms_norm_heads(
40 x: &[f32],
41 scale: &[f32],
42 batch: usize,
43 seq: usize,
44 heads: usize,
45 head_dim: usize,
46 eps: f32,
47) -> Result<Vec<f32>> {
48 let mut out = vec![0.0f32; x.len()];
49 for b in 0..batch {
50 for t in 0..seq {
51 for h in 0..heads {
52 let off = ((b * seq + t) * heads + h) * head_dim;
53 let row = rms_norm(&x[off..off + head_dim], scale, head_dim, eps)?;
54 out[off..off + head_dim].copy_from_slice(&row);
55 }
56 }
57 }
58 Ok(out)
59}
60
61fn mlp_forward(
62 mlp: &Flux2TextEncoderMlpWeights,
63 x: &[f32],
64 rows: usize,
65 _dim: usize,
66) -> Result<Vec<f32>> {
67 let gate = linear(
68 x,
69 rows,
70 mlp.gate.in_dim,
71 &mlp.gate.w_t,
72 mlp.gate.out_dim,
73 &mlp.gate.bias,
74 )?;
75 let up = linear(
76 x,
77 rows,
78 mlp.up.in_dim,
79 &mlp.up.w_t,
80 mlp.up.out_dim,
81 &mlp.up.bias,
82 )?;
83 let half = mlp.gate.out_dim;
84 let mut h = vec![0.0f32; rows * half];
85 for r in 0..rows {
86 for c in 0..half {
87 let a = gate[r * half + c];
88 let b = up[r * half + c];
89 let s = a / (1.0 + (-a).exp());
90 h[r * half + c] = s * b;
91 }
92 }
93 linear(
94 &h,
95 rows,
96 mlp.down.in_dim,
97 &mlp.down.w_t,
98 mlp.down.out_dim,
99 &mlp.down.bias,
100 )
101}
102
103fn rope_cache(cfg: &Qwen3Config, seq: usize) -> (Vec<f32>, Vec<f32>) {
104 let dh = cfg.head_dim;
105 let half = dh / 2;
106 let mut cos = vec![0.0f32; seq * dh];
107 let mut sin = vec![0.0f32; seq * dh];
108 for pos in 0..seq {
109 for i in 0..half {
110 let freq = 1.0 / cfg.rope_theta.powf((2 * i) as f64 / dh as f64);
111 let angle = pos as f64 * freq;
112 let c = angle.cos() as f32;
113 let s = angle.sin() as f32;
114 cos[pos * dh + 2 * i] = c;
115 cos[pos * dh + 2 * i + 1] = c;
116 sin[pos * dh + 2 * i] = s;
117 sin[pos * dh + 2 * i + 1] = s;
118 }
119 }
120 (cos, sin)
121}
122
123fn apply_rope_row(x: &mut [f32], cos: &[f32], sin: &[f32], head_dim: usize) {
124 let mut rotated = vec![0.0f32; head_dim];
125 let pairs = head_dim / 2;
126 for i in 0..pairs {
127 let xr = x[2 * i];
128 let xi = x[2 * i + 1];
129 rotated[2 * i] = -xi;
130 rotated[2 * i + 1] = xr;
131 }
132 for d in 0..head_dim {
133 x[d] = x[d] * cos[d] + rotated[d] * sin[d];
134 }
135}
136
/// Expands grouped-query K/V heads so that every query head has its own copy.
///
/// `k` and `v` are laid out `[batch][seq][n_kv][head_dim]`. The outputs are
/// `[batch][seq][n_heads][head_dim]`, where query head `h` reads KV head
/// `h / (n_heads / n_kv)`.
///
/// # Panics
/// Panics if `n_kv` is zero or does not divide `n_heads`. Without this check the
/// grouping would divide by zero or silently read another token's KV data. Also
/// panics if `k`/`v` are shorter than `batch * seq * n_kv * head_dim`.
fn repeat_kv(
    k: &[f32],
    v: &[f32],
    batch: usize,
    seq: usize,
    n_kv: usize,
    n_heads: usize,
    head_dim: usize,
) -> (Vec<f32>, Vec<f32>) {
    assert!(
        n_kv > 0 && n_heads % n_kv == 0,
        "repeat_kv: n_heads ({n_heads}) must be a multiple of n_kv ({n_kv})"
    );
    let group = n_heads / n_kv;
    let tokens = batch * seq;
    let mut k_out = Vec::with_capacity(tokens * n_heads * head_dim);
    let mut v_out = Vec::with_capacity(tokens * n_heads * head_dim);
    for tok in 0..tokens {
        for h in 0..n_heads {
            let src = (tok * n_kv + h / group) * head_dim;
            k_out.extend_from_slice(&k[src..src + head_dim]);
            v_out.extend_from_slice(&v[src..src + head_dim]);
        }
    }
    (k_out, v_out)
}
162
/// Softmax attention with a causal mask: query position `i` attends only to key
/// positions `0..=i`.
///
/// `q`, `k` and `v` are flat buffers laid out `[batch][seq][n_heads][head_dim]`,
/// with K/V already expanded to `n_heads`. Scores are `dot(q, k) * scale`. The
/// softmax subtracts the row maximum for numerical stability. The result has the
/// same layout as `q`.
fn causal_attention(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    batch: usize,
    seq: usize,
    n_heads: usize,
    head_dim: usize,
    scale: f32,
) -> Vec<f32> {
    let head_at = |b: usize, t: usize, h: usize| ((b * seq + t) * n_heads + h) * head_dim;
    let mut out = vec![0.0f32; batch * seq * n_heads * head_dim];

    for b in 0..batch {
        for h in 0..n_heads {
            for i in 0..seq {
                let q_start = head_at(b, i, h);
                let query = &q[q_start..q_start + head_dim];

                let logits: Vec<f32> = (0..=i)
                    .map(|j| {
                        let k_start = head_at(b, j, h);
                        let key = &k[k_start..k_start + head_dim];
                        let dot = query
                            .iter()
                            .zip(key)
                            .fold(0.0f32, |acc, (&qa, &kb)| acc + qa * kb);
                        dot * scale
                    })
                    .collect();
                let peak = logits.iter().fold(f32::NEG_INFINITY, |m, &s| m.max(s));

                let mut weights: Vec<f32> = logits.iter().map(|&s| (s - peak).exp()).collect();
                let total = weights.iter().fold(0.0f32, |acc, &w| acc + w);
                weights.iter_mut().for_each(|w| *w /= total);

                // Each (b, i, h) output row is written exactly once, so accumulating
                // into the zero-initialised buffer preserves the summation order.
                let o_start = head_at(b, i, h);
                let dest = &mut out[o_start..o_start + head_dim];
                for (j, &w) in weights.iter().enumerate() {
                    let v_start = head_at(b, j, h);
                    let value = &v[v_start..v_start + head_dim];
                    for (slot, &val) in dest.iter_mut().zip(value) {
                        *slot += w * val;
                    }
                }
            }
        }
    }
    out
}
215
216fn attn_forward(
217 attn: &Flux2TextEncoderAttnWeights,
218 x: &[f32],
219 cos: &[f32],
220 sin: &[f32],
221 batch: usize,
222 seq: usize,
223 cfg: &Qwen3Config,
224) -> Result<Vec<f32>> {
225 let nh = cfg.num_attention_heads;
226 let nkv = cfg.num_key_value_heads;
227 let hd = cfg.head_dim;
228 let rows = batch * seq;
229
230 let mut q = linear(
231 x,
232 rows,
233 attn.q.in_dim,
234 &attn.q.w_t,
235 attn.q.out_dim,
236 &attn.q.bias,
237 )?;
238 let mut k = linear(
239 x,
240 rows,
241 attn.k.in_dim,
242 &attn.k.w_t,
243 attn.k.out_dim,
244 &attn.k.bias,
245 )?;
246 let v = linear(
247 x,
248 rows,
249 attn.v.in_dim,
250 &attn.v.w_t,
251 attn.v.out_dim,
252 &attn.v.bias,
253 )?;
254
255 q = rms_norm_heads(
256 &q,
257 &attn.q_norm.scale,
258 batch,
259 seq,
260 nh,
261 hd,
262 cfg.rms_norm_eps as f32,
263 )?;
264 k = rms_norm_heads(
265 &k,
266 &attn.k_norm.scale,
267 batch,
268 seq,
269 nkv,
270 hd,
271 cfg.rms_norm_eps as f32,
272 )?;
273
274 for t in 0..seq {
275 let c = &cos[t * hd..(t + 1) * hd];
276 let s = &sin[t * hd..(t + 1) * hd];
277 for b in 0..batch {
278 for h in 0..nh {
279 let off = ((b * seq + t) * nh + h) * hd;
280 apply_rope_row(&mut q[off..off + hd], c, s, hd);
281 }
282 for h in 0..nkv {
283 let off = ((b * seq + t) * nkv + h) * hd;
284 apply_rope_row(&mut k[off..off + hd], c, s, hd);
285 }
286 }
287 }
288
289 let (k_rep, v_rep) = repeat_kv(&k, &v, batch, seq, nkv, nh, hd);
290 let scale = 1.0 / (hd as f32).sqrt();
291 let attn_out = causal_attention(&q, &k_rep, &v_rep, batch, seq, nh, hd, scale);
292 linear(
293 &attn_out,
294 rows,
295 attn.o.in_dim,
296 &attn.o.w_t,
297 attn.o.out_dim,
298 &attn.o.bias,
299 )
300}
301
302fn layer_forward(
303 layer: &Flux2TextEncoderLayerWeights,
304 x: &[f32],
305 cos: &[f32],
306 sin: &[f32],
307 batch: usize,
308 seq: usize,
309 cfg: &Qwen3Config,
310) -> Result<Vec<f32>> {
311 let h = cfg.hidden_size;
312 let rows = batch * seq;
313 let eps = cfg.rms_norm_eps as f32;
314
315 let normed = rms_norm(x, &layer.input_layernorm.scale, h, eps)?;
316 let attn_out = attn_forward(&layer.attn, &normed, cos, sin, batch, seq, cfg)?;
317 let mut hidden = vec![0.0f32; x.len()];
318 for i in 0..hidden.len() {
319 hidden[i] = x[i] + attn_out[i];
320 }
321
322 let normed2 = rms_norm(&hidden, &layer.post_attention_layernorm.scale, h, eps)?;
323 let mlp_out = mlp_forward(&layer.mlp, &normed2, rows, h)?;
324 for i in 0..hidden.len() {
325 hidden[i] += mlp_out[i];
326 }
327 Ok(hidden)
328}
329
/// Looks up one `hidden`-wide embedding row per token.
///
/// `embed` is `(table, vocab, _)`; its third element is not read here. Token ids
/// at or beyond `vocab` are clamped to the last row (`vocab - 1`). The result is
/// laid out `[batch][seq][hidden]`.
fn embed_tokens(
    embed: &(Vec<f32>, usize, usize),
    input_ids: &[u32],
    batch: usize,
    seq: usize,
    hidden: usize,
) -> Vec<f32> {
    let (table, vocab_size, _) = embed;
    let last_id = vocab_size.saturating_sub(1);
    let tokens = batch * seq;

    let mut rows = Vec::with_capacity(tokens * hidden);
    for &raw_id in &input_ids[..tokens] {
        let id = (raw_id as usize).min(last_id);
        let start = id * hidden;
        rows.extend_from_slice(&table[start..start + hidden]);
    }
    rows
}
350
351pub fn encode_prompt_embeds(
353 weights: &Flux2TextEncoderWeights,
354 cfg: &Qwen3Config,
355 input_ids: &[u32],
356 batch: usize,
357 seq: usize,
358 hidden_state_layers: &[usize],
359) -> Result<Flux2PromptOutput> {
360 ensure!(input_ids.len() == batch * seq, "input_ids length mismatch");
361 let (cos, sin) = rope_cache(cfg, seq);
362 let mut hidden = embed_tokens(
363 &weights.embed_tokens,
364 input_ids,
365 batch,
366 seq,
367 cfg.hidden_size,
368 );
369 let mut hidden_states: Vec<Vec<f32>> = vec![hidden.clone()];
370 for layer in &weights.layers {
371 hidden = layer_forward(layer, &hidden, &cos, &sin, batch, seq, cfg)?;
372 hidden_states.push(hidden.clone());
373 }
374 let eps = cfg.rms_norm_eps as f32;
375 let _ = rms_norm(&hidden, &weights.norm.scale, cfg.hidden_size, eps)?;
376
377 let h = cfg.hidden_size;
378 let joint_dim = h * hidden_state_layers.len();
379 let mut prompt_embeds = vec![0.0f32; batch * seq * joint_dim];
380 for b in 0..batch {
381 for t in 0..seq {
382 let mut off = 0usize;
383 for (li, &layer_idx) in hidden_state_layers.iter().enumerate() {
384 ensure!(
385 layer_idx < hidden_states.len(),
386 "hidden_state_layers[{li}]={layer_idx} out of range (len={})",
387 hidden_states.len()
388 );
389 let src = (b * seq + t) * h;
390 let dst = (b * seq + t) * joint_dim + off;
391 prompt_embeds[dst..dst + h]
392 .copy_from_slice(&hidden_states[layer_idx][src..src + h]);
393 off += h;
394 }
395 }
396 }
397 Ok(Flux2PromptOutput {
398 prompt_embeds,
399 seq_len: seq,
400 joint_dim,
401 })
402}
403
404pub fn encode_prompt_embeds_default_layers(
406 weights: &Flux2TextEncoderWeights,
407 cfg: &Qwen3Config,
408 input_ids: &[u32],
409 batch: usize,
410 seq: usize,
411) -> Result<Flux2PromptOutput> {
412 encode_prompt_embeds(
413 weights,
414 cfg,
415 input_ids,
416 batch,
417 seq,
418 DEFAULT_TEXT_ENCODER_LAYERS,
419 )
420}