1use super::conditioning::Flux2ReferenceConditioning;
19use super::latent_ops::{
20 concat_latent_ids, concat_packed_latents, prepare_latent_ids, slice_gen_noise,
21};
22use super::scheduler::{flow_match_euler_step, flow_match_sigmas};
23use crate::runner::Flux2Runner;
24use anyhow::{Result, ensure};
25
26#[derive(Debug, Clone)]
28pub struct Flux2SampleParams<'a> {
29 pub encoder_hidden_states: &'a [f32],
30 pub encoder_negative: Option<&'a [f32]>,
31 pub txt_ids: &'a [f32],
32 pub neg_txt_ids: Option<&'a [f32]>,
33 pub num_inference_steps: usize,
34 pub cfg_scale: f32,
35 pub guidance: Option<&'a [f32]>,
36 pub latent_h: usize,
37 pub latent_w: usize,
38 pub seed: u64,
39 pub init_timestep: usize,
41 pub initial_latents: Option<&'a [f32]>,
43 pub reference: Option<&'a Flux2ReferenceConditioning>,
45}
46
/// Result of [`sample_rectified_flow`].
#[derive(Debug)]
pub struct Flux2SampleOutput {
    /// Latents as they stand after the last executed sampling step.
    pub latents: Vec<f32>,
    /// Ids for the generated grid only; ids contributed by reference
    /// conditioning are not included.
    pub img_ids: Vec<f32>,
    /// Generated sequence length, i.e. `latent_h * latent_w`.
    pub img_seq: usize,
}
53
/// Builds `batch * img_seq * channels` deterministic pseudo-random noise values.
///
/// Each value is a Box–Muller draw (`sqrt(-2 ln u1) * cos(2π u2)`) scaled by
/// `0.5`, with `u1`/`u2` taken from a 64-bit xorshift generator seeded from
/// `seed`. The same arguments always yield the same buffer, and every element
/// is finite.
pub fn init_latent_noise(batch: usize, img_seq: usize, channels: usize, seed: u64) -> Vec<f32> {
    // Odd multiplier: a bijection on u64 that spreads small seeds across the
    // high bits, which are the ones sampled below.
    const SEED_SCRAMBLE: u64 = 0x9E37_79B9_7F4A_7C15;
    // 2^24. Draws keep 24 bits so each one is exactly representable in f32.
    const UNIT_SCALE: f32 = 16_777_216.0;

    // xorshift can never leave the all-zero state; only `seed == u64::MAX`
    // maps there, so give that one seed a fixed non-zero start instead.
    let mut state = match seed.wrapping_add(1).wrapping_mul(SEED_SCRAMBLE) {
        0 => SEED_SCRAMBLE,
        s => s,
    };

    // Advances the generator and returns a uniform value in [0, 1).
    // Dividing the full 64-bit state by `u32::MAX` would produce values far
    // above 1, making `-2 ln u` negative and its square root NaN.
    let mut next_unit = || -> f32 {
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        (state >> 40) as f32 / UNIT_SCALE
    };

    let n = batch * img_seq * channels;
    (0..n)
        .map(|_| {
            // Clamp away from zero so the logarithm stays finite.
            let u1 = next_unit().max(1e-7);
            let u2 = next_unit();
            let r = (-2.0 * u1.ln()).sqrt();
            let theta = 2.0 * std::f32::consts::PI * u2;
            r * theta.cos() * 0.5
        })
        .collect()
}
74
/// Derives a per-batch-item sequence length from a flat id buffer.
///
/// Returns `img_ids.len() / (batch * 4)` using integer division, so any
/// trailing partial entry is ignored. Panics when `batch` is zero.
fn img_seq_from_ids(img_ids: &[f32], batch: usize) -> usize {
    // NOTE(review): 4 is presumably the number of id components stored per
    // sequence position — confirm against `prepare_latent_ids`.
    let values_per_position = 4;
    let values_per_step = batch * values_per_position;
    img_ids.len() / values_per_step
}
78
79pub fn sample_rectified_flow(
81 runner: &Flux2Runner,
82 params: &Flux2SampleParams<'_>,
83) -> Result<Flux2SampleOutput> {
84 let cfg = runner.config();
85 let batch = runner.batch();
86 let txt_seq = runner.txt_seq();
87 let gen_seq = params.latent_h * params.latent_w;
88 ensure!(params.encoder_hidden_states.len() == batch * txt_seq * cfg.joint_attention_dim);
89
90 let gen_ids = prepare_latent_ids(batch, params.latent_h, params.latent_w);
91 let (img_ids, _total_seq) = if let Some(r) = params.reference {
92 (
93 concat_latent_ids(&gen_ids, &r.img_ids, batch),
94 gen_seq + r.seq,
95 )
96 } else {
97 (gen_ids.clone(), gen_seq)
98 };
99
100 runner.warmup_denoiser(&img_ids, params.txt_ids)?;
101
102 let mut latents = if let Some(init) = params.initial_latents {
103 init.to_vec()
104 } else {
105 init_latent_noise(batch, gen_seq, cfg.in_channels, params.seed)
106 };
107 ensure!(latents.len() == batch * gen_seq * cfg.in_channels);
108 let sigmas = flow_match_sigmas(params.num_inference_steps);
109 let default_guidance = vec![3.5f32; batch];
110 let guidance = params.guidance.unwrap_or(&default_guidance);
111 let init_step = params.init_timestep.min(params.num_inference_steps);
112
113 for i in init_step..params.num_inference_steps {
114 let sigma = sigmas[i];
115 let sigma_next = sigmas[i + 1];
116 let timestep = vec![sigma; batch];
117
118 let hidden = if let Some(r) = params.reference {
119 concat_packed_latents(&latents, &r.packed, batch, cfg.in_channels)
120 } else {
121 latents.clone()
122 };
123
124 let noise = if params.cfg_scale > 1.0 {
125 if let (Some(neg_e), Some(neg_ids)) = (params.encoder_negative, params.neg_txt_ids) {
126 runner
127 .forward_cfg(
128 &hidden,
129 params.encoder_hidden_states,
130 neg_e,
131 ×tep,
132 Some(guidance),
133 &img_ids,
134 params.txt_ids,
135 neg_ids,
136 params.cfg_scale,
137 )?
138 .noise_pred
139 } else {
140 runner
141 .forward(
142 &hidden,
143 params.encoder_hidden_states,
144 ×tep,
145 Some(guidance),
146 &img_ids,
147 params.txt_ids,
148 )?
149 .noise_pred
150 }
151 } else {
152 runner
153 .forward(
154 &hidden,
155 params.encoder_hidden_states,
156 ×tep,
157 Some(guidance),
158 &img_ids,
159 params.txt_ids,
160 )?
161 .noise_pred
162 };
163
164 let noise = if params.reference.is_some() {
165 slice_gen_noise(&noise, batch, cfg.in_channels, gen_seq)
166 } else {
167 noise
168 };
169
170 ensure!(noise.len() == latents.len());
171 flow_match_euler_step(&mut latents, &noise, sigma, sigma_next);
172 }
173
174 Ok(Flux2SampleOutput {
175 latents,
176 img_ids: gen_ids,
177 img_seq: gen_seq,
178 })
179}
180
181pub fn generate_to_rgb(
183 runner: &Flux2Runner,
184 params: &Flux2SampleParams<'_>,
185) -> Result<(Vec<u8>, u32, u32)> {
186 let sample = sample_rectified_flow(runner, params)?;
187 let (rgb, h, w) = runner.decode_to_rgb(
188 &sample.latents,
189 &sample.img_ids,
190 params.latent_h,
191 params.latent_w,
192 )?;
193 Ok((rgb, h, w))
194}
195
196pub fn write_ppm(path: &std::path::Path, rgb: &[u8], width: u32, height: u32) -> Result<()> {
198 use std::io::Write;
199 let mut f = std::fs::File::create(path)?;
200 writeln!(f, "P6")?;
201 writeln!(f, "{width} {height}")?;
202 writeln!(f, "255")?;
203 let w = width as usize;
204 let h = height as usize;
205 for y in 0..h {
206 for x in 0..w {
207 let i = (y * w + x) * 3;
208 f.write_all(&rgb[i..i + 3])?;
209 }
210 }
211 Ok(())
212}
213
214#[allow(dead_code)]
215fn _img_seq_helper(ids: &[f32], batch: usize) -> usize {
216 img_seq_from_ids(ids, batch)
217}