Skip to main content

rlx_flux2/text_encoder/
forward.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Native CPU forward for the FLUX.2 Qwen3 text encoder.
17
18use super::prompt::DEFAULT_TEXT_ENCODER_LAYERS;
19use super::weights::{
20    Flux2TextEncoderAttnWeights, Flux2TextEncoderLayerWeights, Flux2TextEncoderMlpWeights,
21    Flux2TextEncoderWeights,
22};
23use anyhow::{Result, ensure};
24use rlx_qwen3::Qwen3Config;
25use rlx_tensor::{layer_norm, linear};
26
/// Result of encoding a prompt: per-token concatenated hidden states plus their shape.
#[derive(Clone, Debug)]
pub struct Flux2PromptOutput {
    /// Flat row-major buffer of `batch * seq_len * joint_dim` values. For each token, the
    /// selected layers' hidden-state vectors are laid out back to back.
    pub prompt_embeds: Vec<f32>,
    /// Number of token positions per batch item.
    pub seq_len: usize,
    /// Width of one token's embedding: `hidden_size` times the number of selected layers.
    pub joint_dim: usize,
}
33
34fn rms_norm(x: &[f32], scale: &[f32], dim: usize, eps: f32) -> Result<Vec<f32>> {
35    let beta = vec![0.0f32; dim];
36    layer_norm(x, scale, &beta, dim, eps)
37}
38
39fn rms_norm_heads(
40    x: &[f32],
41    scale: &[f32],
42    batch: usize,
43    seq: usize,
44    heads: usize,
45    head_dim: usize,
46    eps: f32,
47) -> Result<Vec<f32>> {
48    let mut out = vec![0.0f32; x.len()];
49    for b in 0..batch {
50        for t in 0..seq {
51            for h in 0..heads {
52                let off = ((b * seq + t) * heads + h) * head_dim;
53                let row = rms_norm(&x[off..off + head_dim], scale, head_dim, eps)?;
54                out[off..off + head_dim].copy_from_slice(&row);
55            }
56        }
57    }
58    Ok(out)
59}
60
61fn mlp_forward(
62    mlp: &Flux2TextEncoderMlpWeights,
63    x: &[f32],
64    rows: usize,
65    _dim: usize,
66) -> Result<Vec<f32>> {
67    let gate = linear(
68        x,
69        rows,
70        mlp.gate.in_dim,
71        &mlp.gate.w_t,
72        mlp.gate.out_dim,
73        &mlp.gate.bias,
74    )?;
75    let up = linear(
76        x,
77        rows,
78        mlp.up.in_dim,
79        &mlp.up.w_t,
80        mlp.up.out_dim,
81        &mlp.up.bias,
82    )?;
83    let half = mlp.gate.out_dim;
84    let mut h = vec![0.0f32; rows * half];
85    for r in 0..rows {
86        for c in 0..half {
87            let a = gate[r * half + c];
88            let b = up[r * half + c];
89            let s = a / (1.0 + (-a).exp());
90            h[r * half + c] = s * b;
91        }
92    }
93    linear(
94        &h,
95        rows,
96        mlp.down.in_dim,
97        &mlp.down.w_t,
98        mlp.down.out_dim,
99        &mlp.down.bias,
100    )
101}
102
103fn rope_cache(cfg: &Qwen3Config, seq: usize) -> (Vec<f32>, Vec<f32>) {
104    let dh = cfg.head_dim;
105    let half = dh / 2;
106    let mut cos = vec![0.0f32; seq * dh];
107    let mut sin = vec![0.0f32; seq * dh];
108    for pos in 0..seq {
109        for i in 0..half {
110            let freq = 1.0 / cfg.rope_theta.powf((2 * i) as f64 / dh as f64);
111            let angle = pos as f64 * freq;
112            let c = angle.cos() as f32;
113            let s = angle.sin() as f32;
114            cos[pos * dh + 2 * i] = c;
115            cos[pos * dh + 2 * i + 1] = c;
116            sin[pos * dh + 2 * i] = s;
117            sin[pos * dh + 2 * i + 1] = s;
118        }
119    }
120    (cos, sin)
121}
122
123fn apply_rope_row(x: &mut [f32], cos: &[f32], sin: &[f32], head_dim: usize) {
124    let mut rotated = vec![0.0f32; head_dim];
125    let pairs = head_dim / 2;
126    for i in 0..pairs {
127        let xr = x[2 * i];
128        let xi = x[2 * i + 1];
129        rotated[2 * i] = -xi;
130        rotated[2 * i + 1] = xr;
131    }
132    for d in 0..head_dim {
133        x[d] = x[d] * cos[d] + rotated[d] * sin[d];
134    }
135}
136
/// Expands grouped key/value heads so that every query head gets its own copy.
///
/// Inputs are row-major `(batch, seq, n_kv, head_dim)` and outputs are
/// `(batch, seq, n_heads, head_dim)`. Query head `h` reads KV head `h / (n_heads / n_kv)`.
fn repeat_kv(
    k: &[f32],
    v: &[f32],
    batch: usize,
    seq: usize,
    n_kv: usize,
    n_heads: usize,
    head_dim: usize,
) -> (Vec<f32>, Vec<f32>) {
    let heads_per_kv = n_heads / n_kv;
    let tokens = batch * seq;
    let mut k_rep = Vec::with_capacity(tokens * n_heads * head_dim);
    let mut v_rep = Vec::with_capacity(tokens * n_heads * head_dim);
    for token in 0..tokens {
        for head in 0..n_heads {
            let base = (token * n_kv + head / heads_per_kv) * head_dim;
            k_rep.extend_from_slice(&k[base..base + head_dim]);
            v_rep.extend_from_slice(&v[base..base + head_dim]);
        }
    }
    (k_rep, v_rep)
}
162
/// Causal (lower-triangular) scaled dot-product attention.
///
/// `q`, `k`, `v` and the result are row-major `(batch, seq, n_heads, head_dim)`. Query
/// position `i` attends to key positions `0..=i` with softmax weights over
/// `scale * (q · k)`. The row maximum is subtracted before `exp` for numerical stability.
fn causal_attention(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    batch: usize,
    seq: usize,
    n_heads: usize,
    head_dim: usize,
    scale: f32,
) -> Vec<f32> {
    // Start offset of the (b, t, h) head vector; identical in all four buffers.
    let head_start = |b: usize, t: usize, h: usize| ((b * seq + t) * n_heads + h) * head_dim;
    let mut out = vec![0.0f32; batch * seq * n_heads * head_dim];
    // One buffer reused for every query row; only its first `i + 1` entries are live.
    let mut weights = vec![0.0f32; seq];
    for b in 0..batch {
        for h in 0..n_heads {
            for i in 0..seq {
                let q_start = head_start(b, i, h);
                let query = &q[q_start..q_start + head_dim];
                let visible = &mut weights[..=i];

                let mut max_score = f32::NEG_INFINITY;
                for (j, w) in visible.iter_mut().enumerate() {
                    let k_start = head_start(b, j, h);
                    let dot = query
                        .iter()
                        .zip(&k[k_start..k_start + head_dim])
                        .fold(0.0f32, |acc, (&qa, &kb)| acc + qa * kb);
                    *w = dot * scale;
                    max_score = max_score.max(*w);
                }

                let mut total = 0.0f32;
                for w in visible.iter_mut() {
                    *w = (*w - max_score).exp();
                    total += *w;
                }
                for w in visible.iter_mut() {
                    *w /= total;
                }

                let o_start = head_start(b, i, h);
                for (d, slot) in out[o_start..o_start + head_dim].iter_mut().enumerate() {
                    *slot = visible
                        .iter()
                        .enumerate()
                        .fold(0.0f32, |acc, (j, &p)| acc + p * v[head_start(b, j, h) + d]);
                }
            }
        }
    }
    out
}
215
216fn attn_forward(
217    attn: &Flux2TextEncoderAttnWeights,
218    x: &[f32],
219    cos: &[f32],
220    sin: &[f32],
221    batch: usize,
222    seq: usize,
223    cfg: &Qwen3Config,
224) -> Result<Vec<f32>> {
225    let nh = cfg.num_attention_heads;
226    let nkv = cfg.num_key_value_heads;
227    let hd = cfg.head_dim;
228    let rows = batch * seq;
229
230    let mut q = linear(
231        x,
232        rows,
233        attn.q.in_dim,
234        &attn.q.w_t,
235        attn.q.out_dim,
236        &attn.q.bias,
237    )?;
238    let mut k = linear(
239        x,
240        rows,
241        attn.k.in_dim,
242        &attn.k.w_t,
243        attn.k.out_dim,
244        &attn.k.bias,
245    )?;
246    let v = linear(
247        x,
248        rows,
249        attn.v.in_dim,
250        &attn.v.w_t,
251        attn.v.out_dim,
252        &attn.v.bias,
253    )?;
254
255    q = rms_norm_heads(
256        &q,
257        &attn.q_norm.scale,
258        batch,
259        seq,
260        nh,
261        hd,
262        cfg.rms_norm_eps as f32,
263    )?;
264    k = rms_norm_heads(
265        &k,
266        &attn.k_norm.scale,
267        batch,
268        seq,
269        nkv,
270        hd,
271        cfg.rms_norm_eps as f32,
272    )?;
273
274    for t in 0..seq {
275        let c = &cos[t * hd..(t + 1) * hd];
276        let s = &sin[t * hd..(t + 1) * hd];
277        for b in 0..batch {
278            for h in 0..nh {
279                let off = ((b * seq + t) * nh + h) * hd;
280                apply_rope_row(&mut q[off..off + hd], c, s, hd);
281            }
282            for h in 0..nkv {
283                let off = ((b * seq + t) * nkv + h) * hd;
284                apply_rope_row(&mut k[off..off + hd], c, s, hd);
285            }
286        }
287    }
288
289    let (k_rep, v_rep) = repeat_kv(&k, &v, batch, seq, nkv, nh, hd);
290    let scale = 1.0 / (hd as f32).sqrt();
291    let attn_out = causal_attention(&q, &k_rep, &v_rep, batch, seq, nh, hd, scale);
292    linear(
293        &attn_out,
294        rows,
295        attn.o.in_dim,
296        &attn.o.w_t,
297        attn.o.out_dim,
298        &attn.o.bias,
299    )
300}
301
302fn layer_forward(
303    layer: &Flux2TextEncoderLayerWeights,
304    x: &[f32],
305    cos: &[f32],
306    sin: &[f32],
307    batch: usize,
308    seq: usize,
309    cfg: &Qwen3Config,
310) -> Result<Vec<f32>> {
311    let h = cfg.hidden_size;
312    let rows = batch * seq;
313    let eps = cfg.rms_norm_eps as f32;
314
315    let normed = rms_norm(x, &layer.input_layernorm.scale, h, eps)?;
316    let attn_out = attn_forward(&layer.attn, &normed, cos, sin, batch, seq, cfg)?;
317    let mut hidden = vec![0.0f32; x.len()];
318    for i in 0..hidden.len() {
319        hidden[i] = x[i] + attn_out[i];
320    }
321
322    let normed2 = rms_norm(&hidden, &layer.post_attention_layernorm.scale, h, eps)?;
323    let mlp_out = mlp_forward(&layer.mlp, &normed2, rows, h)?;
324    for i in 0..hidden.len() {
325        hidden[i] += mlp_out[i];
326    }
327    Ok(hidden)
328}
329
/// Gathers embedding rows for `input_ids` into a row-major `(batch * seq) × hidden` buffer.
///
/// `embed` is `(table, vocab_size, _)`. Row `id` is read from
/// `table[id * hidden..(id + 1) * hidden]`. Ids at or beyond `vocab_size` are clamped to
/// the last row (`vocab_size - 1`) rather than rejected.
fn embed_tokens(
    embed: &(Vec<f32>, usize, usize),
    input_ids: &[u32],
    batch: usize,
    seq: usize,
    hidden: usize,
) -> Vec<f32> {
    let (table, vocab_size, _) = embed;
    let last_row = vocab_size.saturating_sub(1);
    let count = batch * seq;
    let mut gathered = Vec::with_capacity(count * hidden);
    for &raw_id in &input_ids[..count] {
        let row = (raw_id as usize).min(last_row);
        let start = row * hidden;
        gathered.extend_from_slice(&table[start..start + hidden]);
    }
    gathered
}
350
351/// Encode token ids → FLUX.2 `encoder_hidden_states` + metadata.
352pub fn encode_prompt_embeds(
353    weights: &Flux2TextEncoderWeights,
354    cfg: &Qwen3Config,
355    input_ids: &[u32],
356    batch: usize,
357    seq: usize,
358    hidden_state_layers: &[usize],
359) -> Result<Flux2PromptOutput> {
360    ensure!(input_ids.len() == batch * seq, "input_ids length mismatch");
361    let (cos, sin) = rope_cache(cfg, seq);
362    let mut hidden = embed_tokens(
363        &weights.embed_tokens,
364        input_ids,
365        batch,
366        seq,
367        cfg.hidden_size,
368    );
369    let mut hidden_states: Vec<Vec<f32>> = vec![hidden.clone()];
370    for layer in &weights.layers {
371        hidden = layer_forward(layer, &hidden, &cos, &sin, batch, seq, cfg)?;
372        hidden_states.push(hidden.clone());
373    }
374    let eps = cfg.rms_norm_eps as f32;
375    let _ = rms_norm(&hidden, &weights.norm.scale, cfg.hidden_size, eps)?;
376
377    let h = cfg.hidden_size;
378    let joint_dim = h * hidden_state_layers.len();
379    let mut prompt_embeds = vec![0.0f32; batch * seq * joint_dim];
380    for b in 0..batch {
381        for t in 0..seq {
382            let mut off = 0usize;
383            for (li, &layer_idx) in hidden_state_layers.iter().enumerate() {
384                ensure!(
385                    layer_idx < hidden_states.len(),
386                    "hidden_state_layers[{li}]={layer_idx} out of range (len={})",
387                    hidden_states.len()
388                );
389                let src = (b * seq + t) * h;
390                let dst = (b * seq + t) * joint_dim + off;
391                prompt_embeds[dst..dst + h]
392                    .copy_from_slice(&hidden_states[layer_idx][src..src + h]);
393                off += h;
394            }
395        }
396    }
397    Ok(Flux2PromptOutput {
398        prompt_embeds,
399        seq_len: seq,
400        joint_dim,
401    })
402}
403
404/// Encode with default Klein layer indices (9, 18, 27).
405pub fn encode_prompt_embeds_default_layers(
406    weights: &Flux2TextEncoderWeights,
407    cfg: &Qwen3Config,
408    input_ids: &[u32],
409    batch: usize,
410    seq: usize,
411) -> Result<Flux2PromptOutput> {
412    encode_prompt_embeds(
413        weights,
414        cfg,
415        input_ids,
416        batch,
417        seq,
418        DEFAULT_TEXT_ENCODER_LAYERS,
419    )
420}