// rlx_flux2/forward.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Native CPU forward for the FLUX.2 transformer denoiser.

18use super::config::Flux2Config;
19use super::layers::{
20    ada_layer_norm_continuous, double_stream_mod, dual_attention, feed_forward, gate_mul,
21    layer_norm_no_affine, linear_no_bias, modulate, modulate_scale_shift, parallel_attention,
22    single_stream_mod, time_guidance_embed,
23};
24use super::rope::flux2_pos_embed;
25use super::weights::Flux2Weights;
26use anyhow::{Result, ensure};
27
/// Inputs for one transformer forward (noise prediction).
///
/// Every tensor is a flat, row-major `f32` slice; the bracketed shapes give the
/// expected logical layout.
#[derive(Debug, Clone, Copy)]
pub struct Flux2ForwardInput<'a> {
    /// Image latents `[batch, img_seq, in_channels]`.
    pub hidden_states: &'a [f32],
    /// Text encoder states `[batch, txt_seq, joint_attention_dim]`.
    pub encoder_hidden_states: &'a [f32],
    /// Per-batch timestep (sigma) `[batch]`; multiplied by 1000 inside the model.
    pub timestep: &'a [f32],
    /// Optional second time for flow-map embedding (average with `timestep`).
    pub timestep_target: Option<&'a [f32]>,
    /// Per-batch guidance scale; multiplied by 1000 before it is handed to the
    /// time/guidance embedder.
    pub guidance: Option<&'a [f32]>,
    /// Image-token position ids `[img_seq, 4]` (four coordinates per token).
    ///
    /// The forward pass concatenates `txt_ids` followed by `img_ids` along the
    /// token axis to build the joint `[txt_seq + img_seq, 4]` id table used for
    /// the rotary position embedding.
    pub img_ids: &'a [f32],
    /// Text-token position ids `[txt_seq, 4]`.
    pub txt_ids: &'a [f32],
    /// Batch size.
    pub batch: usize,
    /// Number of image tokens per batch item.
    pub img_seq: usize,
    /// Number of text tokens per batch item.
    pub txt_seq: usize,
}
47
48/// Run the FLUX.2 transformer and return noise prediction
49/// `[batch, img_seq, patch_size² * out_channels]`.
50pub fn flux2_transformer_forward(
51    weights: &Flux2Weights,
52    cfg: &Flux2Config,
53    input: Flux2ForwardInput<'_>,
54) -> Result<Vec<f32>> {
55    let dim = cfg.inner_dim();
56    let heads = cfg.num_attention_heads;
57    let head_dim = cfg.attention_head_dim;
58    let eps = cfg.eps as f32;
59    let rope_dim: usize = cfg.axes_dims_rope.iter().sum();
60    let b = input.batch;
61    let img_seq = input.img_seq;
62    let txt_seq = input.txt_seq;
63    ensure!(input.hidden_states.len() == b * img_seq * cfg.in_channels);
64    ensure!(input.encoder_hidden_states.len() == b * txt_seq * cfg.joint_attention_dim);
65    ensure!(input.timestep.len() == b);
66
67    let t_scaled: Vec<f32> = input.timestep.iter().map(|t| t * 1000.0).collect();
68    let g_scaled = input
69        .guidance
70        .map(|g| g.iter().map(|x| x * 1000.0).collect::<Vec<_>>());
71    let tg_tgt = weights
72        .time_guidance_target
73        .as_ref()
74        .unwrap_or(&weights.time_guidance);
75    let temb = if let Some(t_tgt) = input.timestep_target {
76        let tgt_scaled: Vec<f32> = t_tgt.iter().map(|t| t * 1000.0).collect();
77        super::layers::time_guidance_embed_dual(
78            &t_scaled,
79            &tgt_scaled,
80            g_scaled.as_deref(),
81            &weights.time_guidance,
82            tg_tgt,
83            dim,
84        )?
85    } else {
86        time_guidance_embed(&t_scaled, g_scaled.as_deref(), &weights.time_guidance, dim)?
87    };
88
89    let mod_img = double_stream_mod(&temb, b, dim, &weights.double_mod_img.linear)?;
90    let mod_txt = double_stream_mod(&temb, b, dim, &weights.double_mod_txt.linear)?;
91    let single_mod = single_stream_mod(&temb, b, dim, &weights.single_mod.linear)?;
92
93    let mut hidden = linear_no_bias(input.hidden_states, b * img_seq, &weights.x_embedder)?;
94    let mut encoder = linear_no_bias(
95        input.encoder_hidden_states,
96        b * txt_seq,
97        &weights.context_embedder,
98    )?;
99
100    let n_axes = 4usize;
101    let total_seq = txt_seq + img_seq;
102    let mut ids = vec![0.0f32; total_seq * n_axes];
103    for t in 0..txt_seq {
104        for a in 0..n_axes {
105            ids[t * n_axes + a] = input.txt_ids[t * n_axes + a];
106        }
107    }
108    for t in 0..img_seq {
109        for a in 0..n_axes {
110            ids[(txt_seq + t) * n_axes + a] = input.img_ids[t * n_axes + a];
111        }
112    }
113    let (cos, sin) = flux2_pos_embed(cfg, &ids, total_seq, n_axes);
114
115    for block in &weights.transformer_blocks {
116        let (img_msa, img_mlp) = &mod_img;
117        let (txt_msa, txt_mlp) = &mod_txt;
118
119        let n1 = layer_norm_no_affine(&hidden, dim, eps)?;
120        let n1 = modulate(&n1, &img_msa.0, &img_msa.1, dim, b, img_seq);
121        let nc = layer_norm_no_affine(&encoder, dim, eps)?;
122        let nc = modulate(&nc, &txt_msa.0, &txt_msa.1, dim, b, txt_seq);
123
124        let (enc_attn, img_attn) = dual_attention(
125            &block.attn,
126            &n1,
127            &nc,
128            b,
129            img_seq,
130            txt_seq,
131            heads,
132            head_dim,
133            dim,
134            &cos,
135            &sin,
136            rope_dim,
137        )?;
138        hidden = add_residual(&hidden, &gate_mul(&img_attn, &img_msa.2, dim, b, img_seq));
139        encoder = add_residual(&encoder, &gate_mul(&enc_attn, &txt_msa.2, dim, b, txt_seq));
140
141        let n2 = layer_norm_no_affine(&hidden, dim, eps)?;
142        let n2 = modulate_scale_shift(&n2, &img_mlp.1, &img_mlp.0, dim, b, img_seq);
143        let ff = feed_forward(&block.ff, &n2, b * img_seq, dim)?;
144        hidden = add_residual(&hidden, &gate_mul(&ff, &img_mlp.2, dim, b, img_seq));
145
146        let nc2 = layer_norm_no_affine(&encoder, dim, eps)?;
147        let nc2 = modulate_scale_shift(&nc2, &txt_mlp.1, &txt_mlp.0, dim, b, txt_seq);
148        let ffc = feed_forward(&block.ff_context, &nc2, b * txt_seq, dim)?;
149        encoder = add_residual(&encoder, &gate_mul(&ffc, &txt_mlp.2, dim, b, txt_seq));
150    }
151
152    let mut concat = vec![0.0f32; b * (txt_seq + img_seq) * dim];
153    for bi in 0..b {
154        concat[bi * (txt_seq + img_seq) * dim..bi * (txt_seq + img_seq) * dim + txt_seq * dim]
155            .copy_from_slice(&encoder[bi * txt_seq * dim..(bi + 1) * txt_seq * dim]);
156        concat
157            [bi * (txt_seq + img_seq) * dim + txt_seq * dim..(bi + 1) * (txt_seq + img_seq) * dim]
158            .copy_from_slice(&hidden[bi * img_seq * dim..(bi + 1) * img_seq * dim]);
159    }
160
161    let mlp_hidden = (dim as f64 * cfg.mlp_ratio) as usize;
162    let mut stream = concat;
163    for block in &weights.single_transformer_blocks {
164        let n = layer_norm_no_affine(&stream, dim, eps)?;
165        let n = modulate(&n, &single_mod.0, &single_mod.1, dim, b, txt_seq + img_seq);
166        let attn = parallel_attention(
167            &block.attn,
168            &n,
169            b,
170            txt_seq + img_seq,
171            heads,
172            head_dim,
173            dim,
174            mlp_hidden,
175            &cos,
176            &sin,
177            rope_dim,
178        )?;
179        stream = add_residual(
180            &stream,
181            &gate_mul(&attn, &single_mod.2, dim, b, txt_seq + img_seq),
182        );
183    }
184
185    let mut hidden = vec![0.0f32; b * img_seq * dim];
186    for bi in 0..b {
187        hidden[bi * img_seq * dim..(bi + 1) * img_seq * dim].copy_from_slice(
188            &stream[bi * (txt_seq + img_seq) * dim + txt_seq * dim
189                ..(bi + 1) * (txt_seq + img_seq) * dim],
190        );
191    }
192
193    let normed = ada_layer_norm_continuous(
194        &hidden,
195        &temb,
196        b,
197        img_seq,
198        dim,
199        &weights.norm_out.linear,
200        eps,
201    )?;
202    linear_no_bias(&normed, b * img_seq, &weights.proj_out)
203}
204
/// Element-wise residual add: returns a new vector holding `base[i] + delta[i]`.
///
/// # Panics
/// Panics if `base` and `delta` differ in length. A mismatch means a layer
/// produced a mis-sized output. Silently truncating to the shorter input, as a
/// bare `zip` would, would quietly corrupt the residual stream.
fn add_residual(base: &[f32], delta: &[f32]) -> Vec<f32> {
    assert_eq!(
        base.len(),
        delta.len(),
        "add_residual: residual stream and update differ in length"
    );
    base.iter().zip(delta).map(|(a, d)| a + d).collect()
}