1use super::latent_ops::{
19 bn_normalize_patchified_latents, concat_latent_ids, concat_packed_latents, pack_latents,
20 patchify_latents, prepare_latent_ids, prepare_latent_ids_with_t,
21};
22use super::scheduler::{flow_match_init_timestep, flow_match_sigmas};
23use super::vae::{Flux2VaeConfig, Flux2VaeWeights, flux2_vae_encode};
24use anyhow::{Result, ensure};
25use std::path::Path;
26
/// Reference-image conditioning returned by [`prepare_reference_conditioning`].
#[derive(Debug, Clone)]
pub struct Flux2ReferenceConditioning {
    /// Packed latents of every reference image, joined in input order with
    /// `concat_packed_latents`.
    pub packed: Vec<f32>,
    /// Latent ids for every reference image, joined in the same order with
    /// `concat_latent_ids`.
    pub img_ids: Vec<f32>,
    /// Sum over all reference images of `packed.len() / (batch * channels)`.
    pub seq: usize,
}
34
/// Trims a latent tensor so that both spatial extents are even.
///
/// `latents` is indexed as `((b * channels + c) * h + y) * w + x`. When `h` or
/// `w` is odd, the last row and/or last column of every plane is dropped.
///
/// Returns the (possibly cropped) data together with the resulting height and
/// width. If both extents are already even the input is copied unchanged.
///
/// # Panics
///
/// Panics if `latents` is too short for the elements that must be read.
pub fn crop_latents_to_even(
    latents: &[f32],
    batch: usize,
    channels: usize,
    h: usize,
    w: usize,
) -> (Vec<f32>, usize, usize) {
    // Round each extent down to the nearest even value.
    let even_h = h - h % 2;
    let even_w = w - w % 2;
    if (even_h, even_w) == (h, w) {
        return (latents.to_vec(), h, w);
    }

    let planes = batch * channels;
    let mut cropped = Vec::with_capacity(planes * even_h * even_w);
    // With zero output width nothing is read from `latents` at all.
    if even_w > 0 {
        for plane in 0..planes {
            let plane_start = plane * h * w;
            for row in 0..even_h {
                let row_start = plane_start + row * w;
                // Keep the first `even_w` entries of each retained row.
                cropped.extend_from_slice(&latents[row_start..row_start + even_w]);
            }
        }
    }
    (cropped, even_h, even_w)
}
68
/// Center-crops and/or zero-pads a latent tensor to `target_h` x `target_w`.
///
/// `latents` is indexed as `((b * channels + c) * h + y) * w + x`; the output
/// uses the same layout with the target extents.
///
/// Each spatial axis is handled independently: an axis that is larger than its
/// target is center-cropped, an axis that is smaller is centered inside a
/// zero-filled output. When the difference is odd, the extra row/column is
/// removed from (or the extra padding is added to) the bottom/right side.
///
/// If the size already matches, the input is copied unchanged.
///
/// # Panics
///
/// Panics if `latents` is too short for the elements that must be read.
pub fn match_latent_spatial_size(
    latents: &[f32],
    batch: usize,
    channels: usize,
    h: usize,
    w: usize,
    target_h: usize,
    target_w: usize,
) -> Vec<f32> {
    // For one axis, returns `(src_offset, dst_offset, len)` of the centered
    // span shared by a source of `src_len` and a destination of `dst_len`.
    fn center_span(src_len: usize, dst_len: usize) -> (usize, usize, usize) {
        if src_len >= dst_len {
            ((src_len - dst_len) / 2, 0, dst_len)
        } else {
            (0, (dst_len - src_len) / 2, src_len)
        }
    }

    if h == target_h && w == target_w {
        return latents.to_vec();
    }

    // Deciding crop-vs-pad per axis keeps mixed cases (e.g. taller but
    // narrower than the target) centered on both axes.
    let (src_y0, dst_y0, copy_h) = center_span(h, target_h);
    let (src_x0, dst_x0, copy_w) = center_span(w, target_w);

    let mut out = vec![0.0f32; batch * channels * target_h * target_w];
    if copy_w == 0 {
        // Nothing to copy; avoids slicing at offsets that may not exist.
        return out;
    }
    for plane in 0..batch * channels {
        let src_plane = plane * h * w;
        let dst_plane = plane * target_h * target_w;
        for row in 0..copy_h {
            let src = src_plane + (src_y0 + row) * w + src_x0;
            let dst = dst_plane + (dst_y0 + row) * target_w + dst_x0;
            out[dst..dst + copy_w].copy_from_slice(&latents[src..src + copy_w]);
        }
    }
    out
}
122
123pub fn pack_encoded_latents(
125 vae_weights: &Flux2VaeWeights,
126 vae_cfg: &Flux2VaeConfig,
127 encoded: Vec<f32>,
128 batch: usize,
129 enc_h: usize,
130 enc_w: usize,
131 eff_h: usize,
132 eff_w: usize,
133 latent_h: usize,
134 latent_w: usize,
135) -> Result<Vec<f32>> {
136 let (cropped, ch, cw) =
137 crop_latents_to_even(&encoded, batch, vae_cfg.latent_channels, enc_h, enc_w);
138 let encoded = match_latent_spatial_size(
139 &cropped,
140 batch,
141 vae_cfg.latent_channels,
142 ch,
143 cw,
144 eff_h,
145 eff_w,
146 );
147 let patch = patchify_latents(&encoded, batch, vae_cfg.latent_channels, latent_h, latent_w);
148 let norm = bn_normalize_patchified_latents(
149 &patch,
150 &vae_weights.bn_running_mean,
151 &vae_weights.bn_running_var,
152 vae_cfg.batch_norm_eps,
153 );
154 Ok(pack_latents(
155 &norm,
156 batch,
157 vae_cfg.bn_channels(),
158 latent_h,
159 latent_w,
160 ))
161}
162
163pub fn encode_rgb_to_packed(
165 vae_weights: &Flux2VaeWeights,
166 vae_cfg: &Flux2VaeConfig,
167 rgb: &[f32],
168 batch: usize,
169 pixel_h: usize,
170 pixel_w: usize,
171 eff_h: usize,
172 eff_w: usize,
173 latent_h: usize,
174 latent_w: usize,
175) -> Result<Vec<f32>> {
176 let stride = vae_cfg.encode_spatial_stride();
177 let enc_h = pixel_h / stride;
178 let enc_w = pixel_w / stride;
179 ensure!(
180 enc_h > 0 && enc_w > 0,
181 "encoded spatial dims too small for {pixel_h}x{pixel_w}"
182 );
183 let encoded = flux2_vae_encode(vae_weights, vae_cfg, rgb, batch, pixel_h, pixel_w)?;
184 ensure!(
185 encoded.len() == batch * vae_cfg.latent_channels * enc_h * enc_w,
186 "encoded len {} != expected {}",
187 encoded.len(),
188 batch * vae_cfg.latent_channels * enc_h * enc_w
189 );
190 pack_encoded_latents(
191 vae_weights,
192 vae_cfg,
193 encoded,
194 batch,
195 enc_h,
196 enc_w,
197 eff_h,
198 eff_w,
199 latent_h,
200 latent_w,
201 )
202}
203
204pub fn prepare_img2img_latents(
206 vae_weights: &Flux2VaeWeights,
207 vae_cfg: &Flux2VaeConfig,
208 rgb: &[f32],
209 batch: usize,
210 pixel_h: usize,
211 pixel_w: usize,
212 latent_h: usize,
213 latent_w: usize,
214 eff_h: usize,
215 eff_w: usize,
216 noise: &[f32],
217 image_strength: f32,
218 num_inference_steps: usize,
219) -> Result<Vec<f32>> {
220 let clean = encode_rgb_to_packed(
221 vae_weights,
222 vae_cfg,
223 rgb,
224 batch,
225 pixel_h,
226 pixel_w,
227 eff_h,
228 eff_w,
229 latent_h,
230 latent_w,
231 )?;
232 ensure!(clean.len() == noise.len());
233 let sigmas = flow_match_sigmas(num_inference_steps);
234 let init_step = flow_match_init_timestep(image_strength, num_inference_steps);
235 let sigma = sigmas[init_step.min(sigmas.len() - 1)];
236 Ok(super::latent_ops::blend_latents_with_noise(
237 &clean, noise, sigma,
238 ))
239}
240
241pub fn prepare_reference_conditioning(
243 vae_weights: &Flux2VaeWeights,
244 vae_cfg: &Flux2VaeConfig,
245 images: &[(&[f32], usize, usize)],
246 batch: usize,
247 eff_h: usize,
248 eff_w: usize,
249 latent_h: usize,
250 latent_w: usize,
251) -> Result<Flux2ReferenceConditioning> {
252 ensure!(
253 !images.is_empty(),
254 "edit requires at least one reference image"
255 );
256 let channels = vae_cfg.bn_channels();
257 let mut packed_acc: Option<Vec<f32>> = None;
258 let mut ids_acc: Option<Vec<f32>> = None;
259 let mut total_seq = 0usize;
260
261 for (i, (rgb, ph, pw)) in images.iter().enumerate() {
262 let packed = encode_rgb_to_packed(
263 vae_weights,
264 vae_cfg,
265 rgb,
266 batch,
267 *ph,
268 *pw,
269 eff_h,
270 eff_w,
271 latent_h,
272 latent_w,
273 )?;
274 let seq = packed.len() / (batch * channels);
275 total_seq += seq;
276 let ids = prepare_latent_ids_with_t(batch, latent_h, latent_w, 10 + 10 * i as i32);
277 packed_acc = Some(match packed_acc {
278 Some(prev) => concat_packed_latents(&prev, &packed, batch, channels),
279 None => packed,
280 });
281 ids_acc = Some(match ids_acc {
282 Some(prev) => concat_latent_ids(&prev, &ids, batch),
283 None => ids,
284 });
285 }
286
287 Ok(Flux2ReferenceConditioning {
288 packed: packed_acc.unwrap(),
289 img_ids: ids_acc.unwrap(),
290 seq: total_seq,
291 })
292}
293
/// Returns the latent ids for the latents being generated.
///
/// Delegates to `prepare_latent_ids` with the same `batch`, `latent_h` and
/// `latent_w`, unchanged. Reference images use `prepare_latent_ids_with_t`
/// instead (see [`prepare_reference_conditioning`]).
pub fn prepare_generation_ids(batch: usize, latent_h: usize, latent_w: usize) -> Vec<f32> {
    prepare_latent_ids(batch, latent_h, latent_w)
}
298
299pub fn encode_image_path_to_packed(
301 _vae_weights: &Flux2VaeWeights,
302 _vae_cfg: &Flux2VaeConfig,
303 _path: &Path,
304 _batch: usize,
305 _pixel_h: usize,
306 _pixel_w: usize,
307 _eff_h: usize,
308 _eff_w: usize,
309 _latent_h: usize,
310 _latent_w: usize,
311) -> Result<Vec<f32>> {
312 anyhow::bail!("use encode_rgb_to_packed with planar RGB from the caller")
313}