1use super::config::Flux2Config;
19use super::layers::{
20 ada_layer_norm_continuous, double_stream_mod, dual_attention, feed_forward, gate_mul,
21 layer_norm_no_affine, linear_no_bias, modulate, modulate_scale_shift, parallel_attention,
22 single_stream_mod, time_guidance_embed,
23};
24use super::rope::flux2_pos_embed;
25use super::weights::Flux2Weights;
26use anyhow::{Result, ensure};
27
/// Borrowed inputs for one [`flux2_transformer_forward`] call.
///
/// Every tensor is a flat `f32` buffer; the sizes noted below are the ones
/// `flux2_transformer_forward` checks or indexes.
#[derive(Clone, Copy)]
pub struct Flux2ForwardInput<'a> {
    /// Image tokens: `batch * img_seq * in_channels` values.
    pub hidden_states: &'a [f32],
    /// Text-encoder tokens: `batch * txt_seq * joint_attention_dim` values.
    pub encoder_hidden_states: &'a [f32],
    /// One timestep per batch item (`batch` values); multiplied by 1000 before embedding.
    pub timestep: &'a [f32],
    /// Optional second timestep; when present the dual time/guidance embedding
    /// (`time_guidance_embed_dual`) is used instead of `time_guidance_embed`.
    pub timestep_target: Option<&'a [f32]>,
    /// Optional guidance values, multiplied by 1000 before embedding.
    /// NOTE(review): length is not checked by the forward pass — presumably `batch`; confirm.
    pub guidance: Option<&'a [f32]>,
    /// Image position ids, 4 values per token. Only the first `img_seq * 4` values are read,
    /// with no batch offset, so all batch items share these positions.
    pub img_ids: &'a [f32],
    /// Text position ids, 4 values per token. Only the first `txt_seq * 4` values are read,
    /// with no batch offset, so all batch items share these positions.
    pub txt_ids: &'a [f32],
    /// Number of batch items.
    pub batch: usize,
    /// Number of image tokens per batch item.
    pub img_seq: usize,
    /// Number of text tokens per batch item.
    pub txt_seq: usize,
}

// Hand-written so `{:?}` reports buffer lengths instead of printing every float of
// potentially very large tensors.
impl std::fmt::Debug for Flux2ForwardInput<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Flux2ForwardInput")
            .field("hidden_states_len", &self.hidden_states.len())
            .field("encoder_hidden_states_len", &self.encoder_hidden_states.len())
            .field("timestep_len", &self.timestep.len())
            .field("timestep_target_len", &self.timestep_target.map(|s| s.len()))
            .field("guidance_len", &self.guidance.map(|s| s.len()))
            .field("img_ids_len", &self.img_ids.len())
            .field("txt_ids_len", &self.txt_ids.len())
            .field("batch", &self.batch)
            .field("img_seq", &self.img_seq)
            .field("txt_seq", &self.txt_seq)
            .finish()
    }
}
47
48pub fn flux2_transformer_forward(
51 weights: &Flux2Weights,
52 cfg: &Flux2Config,
53 input: Flux2ForwardInput<'_>,
54) -> Result<Vec<f32>> {
55 let dim = cfg.inner_dim();
56 let heads = cfg.num_attention_heads;
57 let head_dim = cfg.attention_head_dim;
58 let eps = cfg.eps as f32;
59 let rope_dim: usize = cfg.axes_dims_rope.iter().sum();
60 let b = input.batch;
61 let img_seq = input.img_seq;
62 let txt_seq = input.txt_seq;
63 ensure!(input.hidden_states.len() == b * img_seq * cfg.in_channels);
64 ensure!(input.encoder_hidden_states.len() == b * txt_seq * cfg.joint_attention_dim);
65 ensure!(input.timestep.len() == b);
66
67 let t_scaled: Vec<f32> = input.timestep.iter().map(|t| t * 1000.0).collect();
68 let g_scaled = input
69 .guidance
70 .map(|g| g.iter().map(|x| x * 1000.0).collect::<Vec<_>>());
71 let tg_tgt = weights
72 .time_guidance_target
73 .as_ref()
74 .unwrap_or(&weights.time_guidance);
75 let temb = if let Some(t_tgt) = input.timestep_target {
76 let tgt_scaled: Vec<f32> = t_tgt.iter().map(|t| t * 1000.0).collect();
77 super::layers::time_guidance_embed_dual(
78 &t_scaled,
79 &tgt_scaled,
80 g_scaled.as_deref(),
81 &weights.time_guidance,
82 tg_tgt,
83 dim,
84 )?
85 } else {
86 time_guidance_embed(&t_scaled, g_scaled.as_deref(), &weights.time_guidance, dim)?
87 };
88
89 let mod_img = double_stream_mod(&temb, b, dim, &weights.double_mod_img.linear)?;
90 let mod_txt = double_stream_mod(&temb, b, dim, &weights.double_mod_txt.linear)?;
91 let single_mod = single_stream_mod(&temb, b, dim, &weights.single_mod.linear)?;
92
93 let mut hidden = linear_no_bias(input.hidden_states, b * img_seq, &weights.x_embedder)?;
94 let mut encoder = linear_no_bias(
95 input.encoder_hidden_states,
96 b * txt_seq,
97 &weights.context_embedder,
98 )?;
99
100 let n_axes = 4usize;
101 let total_seq = txt_seq + img_seq;
102 let mut ids = vec![0.0f32; total_seq * n_axes];
103 for t in 0..txt_seq {
104 for a in 0..n_axes {
105 ids[t * n_axes + a] = input.txt_ids[t * n_axes + a];
106 }
107 }
108 for t in 0..img_seq {
109 for a in 0..n_axes {
110 ids[(txt_seq + t) * n_axes + a] = input.img_ids[t * n_axes + a];
111 }
112 }
113 let (cos, sin) = flux2_pos_embed(cfg, &ids, total_seq, n_axes);
114
115 for block in &weights.transformer_blocks {
116 let (img_msa, img_mlp) = &mod_img;
117 let (txt_msa, txt_mlp) = &mod_txt;
118
119 let n1 = layer_norm_no_affine(&hidden, dim, eps)?;
120 let n1 = modulate(&n1, &img_msa.0, &img_msa.1, dim, b, img_seq);
121 let nc = layer_norm_no_affine(&encoder, dim, eps)?;
122 let nc = modulate(&nc, &txt_msa.0, &txt_msa.1, dim, b, txt_seq);
123
124 let (enc_attn, img_attn) = dual_attention(
125 &block.attn,
126 &n1,
127 &nc,
128 b,
129 img_seq,
130 txt_seq,
131 heads,
132 head_dim,
133 dim,
134 &cos,
135 &sin,
136 rope_dim,
137 )?;
138 hidden = add_residual(&hidden, &gate_mul(&img_attn, &img_msa.2, dim, b, img_seq));
139 encoder = add_residual(&encoder, &gate_mul(&enc_attn, &txt_msa.2, dim, b, txt_seq));
140
141 let n2 = layer_norm_no_affine(&hidden, dim, eps)?;
142 let n2 = modulate_scale_shift(&n2, &img_mlp.1, &img_mlp.0, dim, b, img_seq);
143 let ff = feed_forward(&block.ff, &n2, b * img_seq, dim)?;
144 hidden = add_residual(&hidden, &gate_mul(&ff, &img_mlp.2, dim, b, img_seq));
145
146 let nc2 = layer_norm_no_affine(&encoder, dim, eps)?;
147 let nc2 = modulate_scale_shift(&nc2, &txt_mlp.1, &txt_mlp.0, dim, b, txt_seq);
148 let ffc = feed_forward(&block.ff_context, &nc2, b * txt_seq, dim)?;
149 encoder = add_residual(&encoder, &gate_mul(&ffc, &txt_mlp.2, dim, b, txt_seq));
150 }
151
152 let mut concat = vec![0.0f32; b * (txt_seq + img_seq) * dim];
153 for bi in 0..b {
154 concat[bi * (txt_seq + img_seq) * dim..bi * (txt_seq + img_seq) * dim + txt_seq * dim]
155 .copy_from_slice(&encoder[bi * txt_seq * dim..(bi + 1) * txt_seq * dim]);
156 concat
157 [bi * (txt_seq + img_seq) * dim + txt_seq * dim..(bi + 1) * (txt_seq + img_seq) * dim]
158 .copy_from_slice(&hidden[bi * img_seq * dim..(bi + 1) * img_seq * dim]);
159 }
160
161 let mlp_hidden = (dim as f64 * cfg.mlp_ratio) as usize;
162 let mut stream = concat;
163 for block in &weights.single_transformer_blocks {
164 let n = layer_norm_no_affine(&stream, dim, eps)?;
165 let n = modulate(&n, &single_mod.0, &single_mod.1, dim, b, txt_seq + img_seq);
166 let attn = parallel_attention(
167 &block.attn,
168 &n,
169 b,
170 txt_seq + img_seq,
171 heads,
172 head_dim,
173 dim,
174 mlp_hidden,
175 &cos,
176 &sin,
177 rope_dim,
178 )?;
179 stream = add_residual(
180 &stream,
181 &gate_mul(&attn, &single_mod.2, dim, b, txt_seq + img_seq),
182 );
183 }
184
185 let mut hidden = vec![0.0f32; b * img_seq * dim];
186 for bi in 0..b {
187 hidden[bi * img_seq * dim..(bi + 1) * img_seq * dim].copy_from_slice(
188 &stream[bi * (txt_seq + img_seq) * dim + txt_seq * dim
189 ..(bi + 1) * (txt_seq + img_seq) * dim],
190 );
191 }
192
193 let normed = ada_layer_norm_continuous(
194 &hidden,
195 &temb,
196 b,
197 img_seq,
198 dim,
199 &weights.norm_out.linear,
200 eps,
201 )?;
202 linear_no_bias(&normed, b * img_seq, &weights.proj_out)
203}
204
/// Returns `base + delta` element-wise as a new buffer.
///
/// # Panics
/// Panics if the slices differ in length. Without this check `zip` would silently truncate
/// the result, shrinking the residual stream and corrupting every later layer.
fn add_residual(base: &[f32], delta: &[f32]) -> Vec<f32> {
    assert_eq!(
        base.len(),
        delta.len(),
        "residual stream and delta must have the same length"
    );
    base.iter().zip(delta).map(|(a, d)| a + d).collect()
}