Skip to main content

rlx_flux2/
conditioning.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! FLUX.2 img2img / edit latent conditioning (matches mflux).
17
18use super::latent_ops::{
19    bn_normalize_patchified_latents, concat_latent_ids, concat_packed_latents, pack_latents,
20    patchify_latents, prepare_latent_ids, prepare_latent_ids_with_t,
21};
22use super::scheduler::{flow_match_init_timestep, flow_match_sigmas};
23use super::vae::{Flux2VaeConfig, Flux2VaeWeights, flux2_vae_encode};
24use anyhow::{Result, ensure};
25use std::path::Path;
26
/// Edit-mode conditioning built from one or more reference images.
///
/// Produced by [`prepare_reference_conditioning`], which accumulates the
/// per-image results into the three fields below.
#[derive(Clone, Debug)]
pub struct Flux2ReferenceConditioning {
    /// Packed latents of every reference image, joined image after image
    /// with `concat_packed_latents`.
    pub packed: Vec<f32>,
    /// Latent ids of every reference image, joined with `concat_latent_ids`.
    pub img_ids: Vec<f32>,
    /// Sum over all reference images of `packed.len() / (batch * channels)`.
    pub seq: usize,
}
34
/// Trim the spatial dims of a row-major `[batch, C, H, W]` buffer to even sizes.
///
/// An odd `h` loses its last row and an odd `w` loses its last column; even
/// dimensions are kept as they are. Returns the resulting data together with
/// the resulting height and width.
///
/// When both dimensions are already even the input is returned as a plain
/// copy, without looking at its length.
///
/// # Panics
///
/// Panics if a crop is needed and `latents` holds fewer than
/// `batch * channels * h * w` elements.
pub fn crop_latents_to_even(
    latents: &[f32],
    batch: usize,
    channels: usize,
    h: usize,
    w: usize,
) -> (Vec<f32>, usize, usize) {
    // Round each spatial dim down to the nearest even value.
    let even_h = h - h % 2;
    let even_w = w - w % 2;
    if (even_h, even_w) == (h, w) {
        return (latents.to_vec(), h, w);
    }

    let src_plane = h * w;
    let mut cropped = Vec::with_capacity(batch * channels * even_h * even_w);
    // Walk (batch, channel) planes and rows in output order, so appending
    // yields exactly the destination layout.
    for plane in 0..batch * channels {
        let plane_base = plane * src_plane;
        for row in 0..even_h {
            let row_base = plane_base + row * w;
            cropped.extend((0..even_w).map(|col| latents[row_base + col]));
        }
    }
    (cropped, even_h, even_w)
}
68
/// Center-crop or zero-pad the spatial dims of a row-major
/// `[batch, C, H, W]` buffer to `(target_h, target_w)`.
///
/// Height and width are handled independently: an axis that is larger than
/// its target is center-cropped, an axis that is smaller is centered inside
/// zero padding, and an axis that already matches is copied as is. This also
/// covers the mixed case where one axis must be cropped while the other must
/// be padded. Odd differences put the extra row/column at the bottom/right
/// (offsets use integer division by two).
///
/// When both axes already match, the input is returned as a plain copy.
///
/// # Panics
///
/// Panics if a resize is needed and `latents` holds fewer than
/// `batch * channels * h * w` elements.
pub fn match_latent_spatial_size(
    latents: &[f32],
    batch: usize,
    channels: usize,
    h: usize,
    w: usize,
    target_h: usize,
    target_w: usize,
) -> Vec<f32> {
    // For one axis: (offset into the source, offset into the target, number
    // of elements shared by both).
    fn axis_window(src: usize, target: usize) -> (usize, usize, usize) {
        if src >= target {
            ((src - target) / 2, 0, target)
        } else {
            (0, (target - src) / 2, src)
        }
    }

    if h == target_h && w == target_w {
        return latents.to_vec();
    }

    let (src_y, dst_y, rows) = axis_window(h, target_h);
    let (src_x, dst_x, cols) = axis_window(w, target_w);

    let mut out = vec![0.0f32; batch * channels * target_h * target_w];
    if rows == 0 || cols == 0 {
        // Nothing overlaps; the result is all padding (or empty).
        return out;
    }

    let src_plane = h * w;
    let dst_plane = target_h * target_w;
    for plane in 0..batch * channels {
        for y in 0..rows {
            let src = plane * src_plane + (y + src_y) * w + src_x;
            let dst = plane * dst_plane + (y + dst_y) * target_w + dst_x;
            out[dst..dst + cols].copy_from_slice(&latents[src..src + cols]);
        }
    }
    out
}
122
123/// Post-process VAE-encoded latents → packed transformer input.
124pub fn pack_encoded_latents(
125    vae_weights: &Flux2VaeWeights,
126    vae_cfg: &Flux2VaeConfig,
127    encoded: Vec<f32>,
128    batch: usize,
129    enc_h: usize,
130    enc_w: usize,
131    eff_h: usize,
132    eff_w: usize,
133    latent_h: usize,
134    latent_w: usize,
135) -> Result<Vec<f32>> {
136    let (cropped, ch, cw) =
137        crop_latents_to_even(&encoded, batch, vae_cfg.latent_channels, enc_h, enc_w);
138    let encoded = match_latent_spatial_size(
139        &cropped,
140        batch,
141        vae_cfg.latent_channels,
142        ch,
143        cw,
144        eff_h,
145        eff_w,
146    );
147    let patch = patchify_latents(&encoded, batch, vae_cfg.latent_channels, latent_h, latent_w);
148    let norm = bn_normalize_patchified_latents(
149        &patch,
150        &vae_weights.bn_running_mean,
151        &vae_weights.bn_running_var,
152        vae_cfg.batch_norm_eps,
153    );
154    Ok(pack_latents(
155        &norm,
156        batch,
157        vae_cfg.bn_channels(),
158        latent_h,
159        latent_w,
160    ))
161}
162
163/// VAE encode → patchify → BN → pack for transformer input.
164pub fn encode_rgb_to_packed(
165    vae_weights: &Flux2VaeWeights,
166    vae_cfg: &Flux2VaeConfig,
167    rgb: &[f32],
168    batch: usize,
169    pixel_h: usize,
170    pixel_w: usize,
171    eff_h: usize,
172    eff_w: usize,
173    latent_h: usize,
174    latent_w: usize,
175) -> Result<Vec<f32>> {
176    let stride = vae_cfg.encode_spatial_stride();
177    let enc_h = pixel_h / stride;
178    let enc_w = pixel_w / stride;
179    ensure!(
180        enc_h > 0 && enc_w > 0,
181        "encoded spatial dims too small for {pixel_h}x{pixel_w}"
182    );
183    let encoded = flux2_vae_encode(vae_weights, vae_cfg, rgb, batch, pixel_h, pixel_w)?;
184    ensure!(
185        encoded.len() == batch * vae_cfg.latent_channels * enc_h * enc_w,
186        "encoded len {} != expected {}",
187        encoded.len(),
188        batch * vae_cfg.latent_channels * enc_h * enc_w
189    );
190    pack_encoded_latents(
191        vae_weights,
192        vae_cfg,
193        encoded,
194        batch,
195        enc_h,
196        enc_w,
197        eff_h,
198        eff_w,
199        latent_h,
200        latent_w,
201    )
202}
203
204/// img2img init: noisy blend of encoded source + fresh noise.
205pub fn prepare_img2img_latents(
206    vae_weights: &Flux2VaeWeights,
207    vae_cfg: &Flux2VaeConfig,
208    rgb: &[f32],
209    batch: usize,
210    pixel_h: usize,
211    pixel_w: usize,
212    latent_h: usize,
213    latent_w: usize,
214    eff_h: usize,
215    eff_w: usize,
216    noise: &[f32],
217    image_strength: f32,
218    num_inference_steps: usize,
219) -> Result<Vec<f32>> {
220    let clean = encode_rgb_to_packed(
221        vae_weights,
222        vae_cfg,
223        rgb,
224        batch,
225        pixel_h,
226        pixel_w,
227        eff_h,
228        eff_w,
229        latent_h,
230        latent_w,
231    )?;
232    ensure!(clean.len() == noise.len());
233    let sigmas = flow_match_sigmas(num_inference_steps);
234    let init_step = flow_match_init_timestep(image_strength, num_inference_steps);
235    let sigma = sigmas[init_step.min(sigmas.len() - 1)];
236    Ok(super::latent_ops::blend_latents_with_noise(
237        &clean, noise, sigma,
238    ))
239}
240
241/// Encode one or more reference images for edit-mode concat conditioning.
242pub fn prepare_reference_conditioning(
243    vae_weights: &Flux2VaeWeights,
244    vae_cfg: &Flux2VaeConfig,
245    images: &[(&[f32], usize, usize)],
246    batch: usize,
247    eff_h: usize,
248    eff_w: usize,
249    latent_h: usize,
250    latent_w: usize,
251) -> Result<Flux2ReferenceConditioning> {
252    ensure!(
253        !images.is_empty(),
254        "edit requires at least one reference image"
255    );
256    let channels = vae_cfg.bn_channels();
257    let mut packed_acc: Option<Vec<f32>> = None;
258    let mut ids_acc: Option<Vec<f32>> = None;
259    let mut total_seq = 0usize;
260
261    for (i, (rgb, ph, pw)) in images.iter().enumerate() {
262        let packed = encode_rgb_to_packed(
263            vae_weights,
264            vae_cfg,
265            rgb,
266            batch,
267            *ph,
268            *pw,
269            eff_h,
270            eff_w,
271            latent_h,
272            latent_w,
273        )?;
274        let seq = packed.len() / (batch * channels);
275        total_seq += seq;
276        let ids = prepare_latent_ids_with_t(batch, latent_h, latent_w, 10 + 10 * i as i32);
277        packed_acc = Some(match packed_acc {
278            Some(prev) => concat_packed_latents(&prev, &packed, batch, channels),
279            None => packed,
280        });
281        ids_acc = Some(match ids_acc {
282            Some(prev) => concat_latent_ids(&prev, &ids, batch),
283            None => ids,
284        });
285    }
286
287    Ok(Flux2ReferenceConditioning {
288        packed: packed_acc.unwrap(),
289        img_ids: ids_acc.unwrap(),
290        seq: total_seq,
291    })
292}
293
294/// Gen-only ids for txt2img / img2img / edit output tokens.
295pub fn prepare_generation_ids(batch: usize, latent_h: usize, latent_w: usize) -> Vec<f32> {
296    prepare_latent_ids(batch, latent_h, latent_w)
297}
298
299/// Placeholder for future path-based loading in rlx-models tests.
300pub fn encode_image_path_to_packed(
301    _vae_weights: &Flux2VaeWeights,
302    _vae_cfg: &Flux2VaeConfig,
303    _path: &Path,
304    _batch: usize,
305    _pixel_h: usize,
306    _pixel_w: usize,
307    _eff_h: usize,
308    _eff_w: usize,
309    _latent_h: usize,
310    _latent_w: usize,
311) -> Result<Vec<f32>> {
312    anyhow::bail!("use encode_rgb_to_packed with planar RGB from the caller")
313}