Skip to main content

rlx_flux2/
pipeline.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Multi-step rectified-flow sampling and end-to-end generate helpers.
17
18use super::conditioning::Flux2ReferenceConditioning;
19use super::latent_ops::{
20    concat_latent_ids, concat_packed_latents, prepare_latent_ids, slice_gen_noise,
21};
22use super::scheduler::{flow_match_euler_step, flow_match_sigmas};
23use crate::runner::Flux2Runner;
24use anyhow::{Result, ensure};
25
26/// Sampling / generation options.
27#[derive(Debug, Clone)]
28pub struct Flux2SampleParams<'a> {
29    pub encoder_hidden_states: &'a [f32],
30    pub encoder_negative: Option<&'a [f32]>,
31    pub txt_ids: &'a [f32],
32    pub neg_txt_ids: Option<&'a [f32]>,
33    pub num_inference_steps: usize,
34    pub cfg_scale: f32,
35    pub guidance: Option<&'a [f32]>,
36    pub latent_h: usize,
37    pub latent_w: usize,
38    pub seed: u64,
39    /// img2img: starting step index (from [`flow_match_init_timestep`]).
40    pub init_timestep: usize,
41    /// img2img: pre-blended latents (skips fresh noise init when set).
42    pub initial_latents: Option<&'a [f32]>,
43    /// Edit mode: fixed reference tokens concatenated before each forward.
44    pub reference: Option<&'a Flux2ReferenceConditioning>,
45}
46
/// Result of a sampling run: the denoised latents plus the ids that describe their layout.
///
/// Only the generated tokens are reported; reference tokens used in edit mode are not included.
#[derive(Debug)]
pub struct Flux2SampleOutput {
    /// Final latents, flattened; the sampler keeps this at
    /// `batch * img_seq * in_channels` values.
    pub latents: Vec<f32>,
    /// Ids of the generated image tokens, as produced by `prepare_latent_ids`.
    pub img_ids: Vec<f32>,
    /// Number of generated image tokens per batch element (`latent_h * latent_w`).
    pub img_seq: usize,
}
53
/// Initialize Gaussian latents `[batch, img_seq, in_channels]`.
///
/// Returns `batch * img_seq * channels` values, flattened, generated deterministically from
/// `seed`: an xorshift64 stream feeds the cosine branch of the Box–Muller transform, and each
/// sample is scaled by `0.5` (so the output has standard deviation 0.5, not 1).
///
/// Every returned value is finite. The uniform fed to `ln` is kept inside `[1e-7, 1)`, which
/// bounds each sample's magnitude by `0.5 * sqrt(-2 * ln(1e-7))` (about 2.84).
pub fn init_latent_noise(batch: usize, img_seq: usize, channels: usize, seed: u64) -> Vec<f32> {
    // xorshift never leaves the all-zero state, so the one seed that scrambles to zero is
    // replaced by this constant.
    const NONZERO_FALLBACK: u64 = 0x9E37_79B9_7F4A_7C15;
    // 2^-24: maps a 24-bit integer onto [0, 1) without rounding in f32.
    const UNIT_SCALE: f32 = 1.0 / (1u32 << 24) as f32;
    // Lower clamp for the uniform passed to `ln`; keeps the radius finite.
    const MIN_UNIFORM: f32 = 1e-7;

    // Advance the xorshift64 state and return a uniform sample in [0, 1).
    fn next_unit(state: &mut u64) -> f32 {
        *state ^= *state << 13;
        *state ^= *state >> 7;
        *state ^= *state << 17;
        // Only the top 24 bits are used: they fit an f32 mantissa exactly and keep the result
        // strictly below 1. Dividing the whole 64-bit state by `u32::MAX` (as before) yields
        // values far above 1, whose logarithm is positive and turns the radius into NaN.
        ((*state >> 40) as f32) * UNIT_SCALE
    }

    // Scramble the seed (splitmix64 finalizer) so small seeds such as 0, 1, 2 do not start the
    // stream with mostly-zero high bits.
    let mut state = seed.wrapping_add(0x9E37_79B9_7F4A_7C15);
    state = (state ^ (state >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    state = (state ^ (state >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    state ^= state >> 31;
    if state == 0 {
        state = NONZERO_FALLBACK;
    }

    let n = batch * img_seq * channels;
    let mut out = vec![0.0f32; n];
    for v in &mut out {
        let u1 = next_unit(&mut state).max(MIN_UNIFORM);
        let u2 = next_unit(&mut state);
        let radius = (-2.0 * u1.ln()).sqrt();
        let theta = 2.0 * std::f32::consts::PI * u2;
        *v = radius * theta.cos() * 0.5;
    }
    out
}
74
/// Number of image tokens per batch element in a flat `img_ids` buffer.
///
/// Computes `img_ids.len() / (batch * 4)` with integer (truncating) division. Returns `0`
/// instead of panicking when `batch` is zero or `batch * 4` overflows.
fn img_seq_from_ids(img_ids: &[f32], batch: usize) -> usize {
    // NOTE(review): assumes each token carries 4 id values — confirm against
    // `prepare_latent_ids`.
    const ID_COMPONENTS: usize = 4;
    batch
        .checked_mul(ID_COMPONENTS)
        .and_then(|stride| img_ids.len().checked_div(stride))
        .unwrap_or(0)
}
78
79/// Run Flow-Match Euler steps on the denoiser (compiled or native per [`Flux2Runner`]).
80pub fn sample_rectified_flow(
81    runner: &Flux2Runner,
82    params: &Flux2SampleParams<'_>,
83) -> Result<Flux2SampleOutput> {
84    let cfg = runner.config();
85    let batch = runner.batch();
86    let txt_seq = runner.txt_seq();
87    let gen_seq = params.latent_h * params.latent_w;
88    ensure!(params.encoder_hidden_states.len() == batch * txt_seq * cfg.joint_attention_dim);
89
90    let gen_ids = prepare_latent_ids(batch, params.latent_h, params.latent_w);
91    let (img_ids, _total_seq) = if let Some(r) = params.reference {
92        (
93            concat_latent_ids(&gen_ids, &r.img_ids, batch),
94            gen_seq + r.seq,
95        )
96    } else {
97        (gen_ids.clone(), gen_seq)
98    };
99
100    runner.warmup_denoiser(&img_ids, params.txt_ids)?;
101
102    let mut latents = if let Some(init) = params.initial_latents {
103        init.to_vec()
104    } else {
105        init_latent_noise(batch, gen_seq, cfg.in_channels, params.seed)
106    };
107    ensure!(latents.len() == batch * gen_seq * cfg.in_channels);
108    let sigmas = flow_match_sigmas(params.num_inference_steps);
109    let default_guidance = vec![3.5f32; batch];
110    let guidance = params.guidance.unwrap_or(&default_guidance);
111    let init_step = params.init_timestep.min(params.num_inference_steps);
112
113    for i in init_step..params.num_inference_steps {
114        let sigma = sigmas[i];
115        let sigma_next = sigmas[i + 1];
116        let timestep = vec![sigma; batch];
117
118        let hidden = if let Some(r) = params.reference {
119            concat_packed_latents(&latents, &r.packed, batch, cfg.in_channels)
120        } else {
121            latents.clone()
122        };
123
124        let noise = if params.cfg_scale > 1.0 {
125            if let (Some(neg_e), Some(neg_ids)) = (params.encoder_negative, params.neg_txt_ids) {
126                runner
127                    .forward_cfg(
128                        &hidden,
129                        params.encoder_hidden_states,
130                        neg_e,
131                        &timestep,
132                        Some(guidance),
133                        &img_ids,
134                        params.txt_ids,
135                        neg_ids,
136                        params.cfg_scale,
137                    )?
138                    .noise_pred
139            } else {
140                runner
141                    .forward(
142                        &hidden,
143                        params.encoder_hidden_states,
144                        &timestep,
145                        Some(guidance),
146                        &img_ids,
147                        params.txt_ids,
148                    )?
149                    .noise_pred
150            }
151        } else {
152            runner
153                .forward(
154                    &hidden,
155                    params.encoder_hidden_states,
156                    &timestep,
157                    Some(guidance),
158                    &img_ids,
159                    params.txt_ids,
160                )?
161                .noise_pred
162        };
163
164        let noise = if params.reference.is_some() {
165            slice_gen_noise(&noise, batch, cfg.in_channels, gen_seq)
166        } else {
167            noise
168        };
169
170        ensure!(noise.len() == latents.len());
171        flow_match_euler_step(&mut latents, &noise, sigma, sigma_next);
172    }
173
174    Ok(Flux2SampleOutput {
175        latents,
176        img_ids: gen_ids,
177        img_seq: gen_seq,
178    })
179}
180
181/// Sample then VAE-decode to RGB u8 when VAE weights are loaded.
182pub fn generate_to_rgb(
183    runner: &Flux2Runner,
184    params: &Flux2SampleParams<'_>,
185) -> Result<(Vec<u8>, u32, u32)> {
186    let sample = sample_rectified_flow(runner, params)?;
187    let (rgb, h, w) = runner.decode_to_rgb(
188        &sample.latents,
189        &sample.img_ids,
190        params.latent_h,
191        params.latent_w,
192    )?;
193    Ok((rgb, h, w))
194}
195
196/// Write planar RGB u8 HWC to a simple PPM (for `rlx-flux2 --output`).
197pub fn write_ppm(path: &std::path::Path, rgb: &[u8], width: u32, height: u32) -> Result<()> {
198    use std::io::Write;
199    let mut f = std::fs::File::create(path)?;
200    writeln!(f, "P6")?;
201    writeln!(f, "{width} {height}")?;
202    writeln!(f, "255")?;
203    let w = width as usize;
204    let h = height as usize;
205    for y in 0..h {
206        for x in 0..w {
207            let i = (y * w + x) * 3;
208            f.write_all(&rgb[i..i + 3])?;
209        }
210    }
211    Ok(())
212}
213
214#[allow(dead_code)]
215fn _img_seq_helper(ids: &[f32], batch: usize) -> usize {
216    img_seq_from_ids(ids, batch)
217}