1use super::config::Flux2VaeConfig;
19use super::weights::{
20 AttnBlockWeights, Conv2dWeight, DownEncoderBlockWeights, Flux2VaeWeights, GroupNormWeight,
21 ResnetBlockWeights, UpDecoderBlockWeights,
22};
23use crate::builder::Flux2GraphParams;
24use crate::compile_util::{
25 compile_hir_cached, flux2_vae_decoder_aot_key, flux2_vae_encoder_aot_key,
26};
27use anyhow::Result;
28use rlx_ir::hir::{HirGraphExt, HirModule, HirMut, HirNodeId};
29use rlx_ir::op::{Activation, MaskKind};
30use rlx_ir::{DType, Op, Shape};
31use rlx_runtime::Device;
32
33pub struct Flux2VaeGraph {
34 pub hir: HirModule,
35 pub params: Flux2GraphParams,
36}
37
38pub fn build_flux2_vae_hir(
39 cfg: &Flux2VaeConfig,
40 weights: &Flux2VaeWeights,
41 batch: usize,
42 h: usize,
43 w: usize,
44) -> Result<Flux2VaeGraph> {
45 let lc = cfg.latent_channels;
46 let f = DType::F32;
47 let mut hir =
48 HirModule::new("flux2_vae_decoder").with_fusion_policy(rlx_ir::hir::FusionPolicy::Direct);
49 let mut params = Flux2GraphParams::new();
50 let latents = hir.input("latents", Shape::new(&[batch, lc, h, w], f));
51 let mut b = VaeHirBuilder::from_emit_parts(&mut hir, &mut params, cfg, weights, batch, h, w);
52 let (out, _, _, _) = b.emit_decoder(latents)?;
53 hir.outputs = vec![out];
54 Ok(Flux2VaeGraph { hir, params })
55}
56
57pub fn build_flux2_vae_encoder_hir(
58 cfg: &Flux2VaeConfig,
59 weights: &Flux2VaeWeights,
60 batch: usize,
61 h: usize,
62 w: usize,
63) -> Result<Flux2VaeGraph> {
64 let in_c = cfg.in_channels;
65 let f = DType::F32;
66 let mut hir =
67 HirModule::new("flux2_vae_encoder").with_fusion_policy(rlx_ir::hir::FusionPolicy::Direct);
68 let mut params = Flux2GraphParams::new();
69 let rgb = hir.input("rgb", Shape::new(&[batch, in_c, h, w], f));
70 let mut b = VaeHirBuilder::from_emit_parts(&mut hir, &mut params, cfg, weights, batch, h, w);
71 let out = b.emit_encoder(rgb)?;
72 hir.outputs = vec![out];
73 Ok(Flux2VaeGraph { hir, params })
74}
75
76pub fn compile_flux2_vae_hir(
77 cfg: &Flux2VaeConfig,
78 weights: &Flux2VaeWeights,
79 batch: usize,
80 h: usize,
81 w: usize,
82 device: Device,
83 aot: Option<&rlx_runtime::AotCache>,
84) -> Result<(rlx_runtime::CompiledGraph, Flux2GraphParams)> {
85 crate::device::assert_flux2_device_available(device)?;
86 let g = build_flux2_vae_hir(cfg, weights, batch, h, w)?;
87 let key = flux2_vae_decoder_aot_key(device, batch, h, w);
88 let mut compiled = compile_hir_cached(
89 device,
90 aot,
91 &key,
92 g.hir,
93 &crate::compile_util::flux2_compile_profile(),
94 )?;
95 for (name, data) in &g.params {
96 compiled.set_param(name, data);
97 }
98 Ok((compiled, g.params))
99}
100
101pub fn compile_flux2_vae_encoder_hir(
102 cfg: &Flux2VaeConfig,
103 weights: &Flux2VaeWeights,
104 batch: usize,
105 h: usize,
106 w: usize,
107 device: Device,
108 aot: Option<&rlx_runtime::AotCache>,
109) -> Result<(rlx_runtime::CompiledGraph, Flux2GraphParams)> {
110 crate::device::assert_flux2_device_available(device)?;
111 let g = build_flux2_vae_encoder_hir(cfg, weights, batch, h, w)?;
112 let key = flux2_vae_encoder_aot_key(device, batch, h, w);
113 let mut compiled = compile_hir_cached(
114 device,
115 aot,
116 &key,
117 g.hir,
118 &crate::compile_util::flux2_compile_profile(),
119 )?;
120 for (name, data) in &g.params {
121 compiled.set_param(name, data);
122 }
123 Ok((compiled, g.params))
124}
125
126pub(crate) struct VaeHirBuilder<'a> {
127 hir: &'a mut HirModule,
128 params: &'a mut Flux2GraphParams,
129 cfg: &'a Flux2VaeConfig,
130 weights: &'a Flux2VaeWeights,
131 batch: usize,
132 h: usize,
133 w: usize,
134 f: DType,
135 eps: f32,
136 groups: usize,
137}
138
139impl<'a> VaeHirBuilder<'a> {
140 pub(crate) fn from_emit_parts(
141 hir: &'a mut HirModule,
142 params: &'a mut Flux2GraphParams,
143 cfg: &'a Flux2VaeConfig,
144 weights: &'a Flux2VaeWeights,
145 batch: usize,
146 h: usize,
147 w: usize,
148 ) -> Self {
149 Self {
150 hir,
151 params,
152 cfg,
153 weights,
154 batch,
155 h,
156 w,
157 f: DType::F32,
158 eps: 1e-6,
159 groups: cfg.norm_num_groups,
160 }
161 }
162
163 pub(crate) fn emit_decoder(
164 &mut self,
165 mut x: HirNodeId,
166 ) -> Result<(HirNodeId, usize, usize, usize)> {
167 let lc = self.cfg.latent_channels;
168 let mut channels = lc;
169 let mut h = self.h;
170 let mut w = self.w;
171
172 if let Some(pqc) = &self.weights.post_quant_conv {
173 x = self.conv2d_bias(x, pqc, "post_quant_conv", channels, h, w)?;
174 channels = pqc.out_c;
175 }
176 x = self.conv2d_bias(x, &self.weights.conv_in, "conv_in", channels, h, w)?;
177 channels = self.weights.conv_in.out_c;
178
179 for (i, resnet) in self.weights.mid_resnets.iter().enumerate() {
180 x = self.resnet_block(x, resnet, &format!("mid.{i}"), channels, h, w)?;
181 channels = resnet.conv2.out_c;
182 }
183 if let Some(attn) = &self.weights.mid_attn {
184 x = self.spatial_attention(x, attn, "mid.attn", channels, h, w)?;
185 }
186
187 for (i, block) in self.weights.up_blocks.iter().enumerate() {
188 let (cur, c, hh, ww) = self.up_block(x, block, &format!("up.{i}"), channels, h, w)?;
189 x = cur;
190 channels = c;
191 h = hh;
192 w = ww;
193 }
194
195 let shape = self.nchw(channels, h, w);
196 x = self.group_norm(
197 x,
198 &self.weights.conv_norm_out,
199 "conv_norm_out",
200 shape.clone(),
201 )?;
202 x = self.g().activation(Activation::Silu, x, shape.clone());
203 x = self.conv2d_bias(x, &self.weights.conv_out, "conv_out", channels, h, w)?;
204 let out_c = self.weights.conv_out.out_c;
205 Ok((x, out_c, h, w))
206 }
207
208 fn nchw(&self, c: usize, h: usize, w: usize) -> Shape {
209 Shape::new(&[self.batch, c, h, w], self.f)
210 }
211
212 fn register_param(&mut self, name: &str, data: Vec<f32>, shape: Shape) -> HirNodeId {
213 let id = self.hir.param(name, shape);
214 self.params.insert(name.to_string(), data);
215 id
216 }
217
218 fn g(&mut self) -> HirMut<'_> {
219 HirMut::new(self.hir)
220 }
221
222 pub(crate) fn emit_encoder(&mut self, mut x: HirNodeId) -> Result<HirNodeId> {
223 let in_c = self.cfg.in_channels;
224 let mut channels = in_c;
225 let mut h = self.h;
226 let mut w = self.w;
227
228 x = self.conv2d_bias(
229 x,
230 &self.weights.encoder_conv_in,
231 "encoder.conv_in",
232 channels,
233 h,
234 w,
235 )?;
236 channels = self.weights.encoder_conv_in.out_c;
237
238 for (i, block) in self.weights.encoder_down_blocks.iter().enumerate() {
239 let (cur, c, hh, ww) =
240 self.down_block(x, block, &format!("encoder.down.{i}"), channels, h, w)?;
241 x = cur;
242 channels = c;
243 h = hh;
244 w = ww;
245 }
246
247 for (i, resnet) in self.weights.encoder_mid_resnets.iter().enumerate() {
248 x = self.resnet_block(x, resnet, &format!("encoder.mid.{i}"), channels, h, w)?;
249 channels = resnet.conv2.out_c;
250 }
251 if let Some(attn) = &self.weights.encoder_mid_attn {
252 x = self.spatial_attention(x, attn, "encoder.mid.attn", channels, h, w)?;
253 }
254
255 let shape = self.nchw(channels, h, w);
256 x = self.group_norm(
257 x,
258 &self.weights.encoder_conv_norm_out,
259 "encoder.conv_norm_out",
260 shape.clone(),
261 )?;
262 x = self.g().activation(Activation::Silu, x, shape.clone());
263 x = self.conv2d_bias(
264 x,
265 &self.weights.encoder_conv_out,
266 "encoder.conv_out",
267 channels,
268 h,
269 w,
270 )?;
271 channels = self.weights.encoder_conv_out.out_c;
272
273 x = self.conv2d_bias(x, &self.weights.quant_conv, "quant_conv", channels, h, w)?;
274 let mean_c = self.weights.quant_conv.out_c / 2;
275 Ok(self.g().narrow_(x, 1, 0, mean_c))
276 }
277
278 fn group_norm(
279 &mut self,
280 x: HirNodeId,
281 gn: &GroupNormWeight,
282 name: &str,
283 shape: Shape,
284 ) -> Result<HirNodeId> {
285 let c = shape.dim(1).unwrap_static();
286 let g = self.register_param(
287 &format!("{name}.weight"),
288 gn.gamma.clone(),
289 Shape::new(&[c], self.f),
290 );
291 let b = self.register_param(
292 &format!("{name}.bias"),
293 gn.beta.clone(),
294 Shape::new(&[c], self.f),
295 );
296 let groups = self.groups;
297 let eps = self.eps;
298 Ok(self.g().group_norm(x, g, b, groups, eps))
299 }
300
301 fn conv2d_bias(
302 &mut self,
303 x: HirNodeId,
304 conv: &Conv2dWeight,
305 name: &str,
306 _in_c: usize,
307 h: usize,
308 w: usize,
309 ) -> Result<HirNodeId> {
310 let is_1x1 = conv.weight.len() == conv.out_c * conv.in_c;
311 let (kh, kw) = if is_1x1 { (1, 1) } else { (3, 3) };
312 let (pad, stride) = if is_1x1 {
313 ([0, 0], [1, 1])
314 } else {
315 ([1, 1], [1, 1])
316 };
317 let w_shape = if is_1x1 {
318 Shape::new(&[conv.out_c, conv.in_c, 1, 1], self.f)
319 } else {
320 Shape::new(&[conv.out_c, conv.in_c, 3, 3], self.f)
321 };
322 let weight = self.register_param(&format!("{name}.weight"), conv.weight.clone(), w_shape);
323 let out_shape = self.nchw(conv.out_c, h, w);
324 let y = self
325 .g()
326 .conv2d(x, weight, [kh, kw], stride, pad, 1, out_shape.clone());
327 let bias = self.register_param(
328 &format!("{name}.bias"),
329 conv.bias.clone(),
330 Shape::new(&[conv.out_c], self.f),
331 );
332 let bias4 = self.g().reshape_(bias, vec![1, conv.out_c as i64, 1, 1]);
333 let batch = self.batch;
334 let expanded = self.g().add_node(
335 Op::Expand {
336 target_shape: vec![batch as i64, conv.out_c as i64, h as i64, w as i64],
337 },
338 vec![bias4],
339 out_shape.clone(),
340 );
341 Ok(self.g().add(y, expanded))
342 }
343
344 fn resnet_block(
345 &mut self,
346 x: HirNodeId,
347 b: &ResnetBlockWeights,
348 name: &str,
349 in_c: usize,
350 h: usize,
351 w: usize,
352 ) -> Result<HirNodeId> {
353 let shape = self.nchw(in_c, h, w);
354 let mut residual = x;
355 let mut h1 = self.group_norm(x, &b.norm1, &format!("{name}.norm1"), shape.clone())?;
356 h1 = self.g().activation(Activation::Silu, h1, shape.clone());
357 h1 = self.conv2d_bias(h1, &b.conv1, &format!("{name}.conv1"), in_c, h, w)?;
358 let c1 = b.conv1.out_c;
359 let s1 = self.nchw(c1, h, w);
360 h1 = self.group_norm(h1, &b.norm2, &format!("{name}.norm2"), s1.clone())?;
361 h1 = self.g().activation(Activation::Silu, h1, s1.clone());
362 h1 = self.conv2d_bias(h1, &b.conv2, &format!("{name}.conv2"), c1, h, w)?;
363 let out_c = b.conv2.out_c;
364 if let Some(sc) = &b.shortcut {
365 residual = self.conv2d_bias(residual, sc, &format!("{name}.shortcut"), in_c, h, w)?;
366 }
367 let _out_shape = self.nchw(out_c, h, w);
368 Ok(self.g().add(h1, residual))
369 }
370
371 fn spatial_attention(
372 &mut self,
373 x: HirNodeId,
374 attn: &AttnBlockWeights,
375 name: &str,
376 channels: usize,
377 h: usize,
378 w: usize,
379 ) -> Result<HirNodeId> {
380 let shape = self.nchw(channels, h, w);
381 let normed = self.group_norm(x, &attn.norm, &format!("{name}.norm"), shape.clone())?;
382 let q = self.conv2d_bias(normed, &attn.to_q, &format!("{name}.to_q"), channels, h, w)?;
383 let k = self.conv2d_bias(normed, &attn.to_k, &format!("{name}.to_k"), channels, h, w)?;
384 let v = self.conv2d_bias(normed, &attn.to_v, &format!("{name}.to_v"), channels, h, w)?;
385 let seq = h * w;
386 let batch = self.batch;
387 let bsh = Shape::new(&[batch, seq, channels], self.f);
388 let q2 = self
389 .g()
390 .reshape_(q, vec![batch as i64, seq as i64, channels as i64]);
391 let k2 = self
392 .g()
393 .reshape_(k, vec![batch as i64, seq as i64, channels as i64]);
394 let v2 = self
395 .g()
396 .reshape_(v, vec![batch as i64, seq as i64, channels as i64]);
397 let fixed = self
398 .g()
399 .attention_kind(q2, k2, v2, 1, channels, MaskKind::None, bsh.clone());
400 let fixed4 = self.g().reshape_(
401 fixed,
402 vec![batch as i64, channels as i64, h as i64, w as i64],
403 );
404 let proj = self.conv2d_bias(
405 fixed4,
406 &attn.to_out,
407 &format!("{name}.to_out"),
408 channels,
409 h,
410 w,
411 )?;
412 Ok(self.g().add(x, proj))
413 }
414
415 fn up_block(
416 &mut self,
417 x: HirNodeId,
418 block: &UpDecoderBlockWeights,
419 name: &str,
420 mut in_c: usize,
421 h: usize,
422 w: usize,
423 ) -> Result<(HirNodeId, usize, usize, usize)> {
424 let mut cur = x;
425 for (j, resnet) in block.resnets.iter().enumerate() {
426 let out_c = resnet.conv2.out_c;
427 cur = self.resnet_block(cur, resnet, &format!("{name}.resnet.{j}"), in_c, h, w)?;
428 in_c = out_c;
429 }
430 let mut out_h = h;
431 let mut out_w = w;
432 if let Some(up) = &block.upsample {
433 let uped = self.g().resize_nearest_2x(cur);
434 out_h = h * 2;
435 out_w = w * 2;
436 cur = self.conv2d_bias(uped, up, &format!("{name}.upsample"), in_c, out_h, out_w)?;
437 in_c = up.out_c;
438 }
439 Ok((cur, in_c, out_h, out_w))
440 }
441
442 fn down_block(
443 &mut self,
444 x: HirNodeId,
445 block: &DownEncoderBlockWeights,
446 name: &str,
447 mut in_c: usize,
448 h: usize,
449 w: usize,
450 ) -> Result<(HirNodeId, usize, usize, usize)> {
451 let mut cur = x;
452 for (j, resnet) in block.resnets.iter().enumerate() {
453 let out_c = resnet.conv2.out_c;
454 cur = self.resnet_block(cur, resnet, &format!("{name}.resnet.{j}"), in_c, h, w)?;
455 in_c = out_c;
456 }
457 let mut out_h = h;
458 let mut out_w = w;
459 if let Some(down) = &block.downsample {
460 out_h = (h + 1 - 3) / 2 + 1;
461 out_w = (w + 1 - 3) / 2 + 1;
462 cur = self.conv2d_downsample(
463 cur,
464 down,
465 &format!("{name}.downsample"),
466 in_c,
467 h,
468 w,
469 out_h,
470 out_w,
471 )?;
472 in_c = down.out_c;
473 }
474 Ok((cur, in_c, out_h, out_w))
475 }
476
477 fn conv2d_downsample(
478 &mut self,
479 x: HirNodeId,
480 conv: &Conv2dWeight,
481 name: &str,
482 _in_c: usize,
483 _h: usize,
484 _w: usize,
485 out_h: usize,
486 out_w: usize,
487 ) -> Result<HirNodeId> {
488 let w_shape = Shape::new(&[conv.out_c, conv.in_c, 3, 3], self.f);
489 let weight = self.register_param(&format!("{name}.weight"), conv.weight.clone(), w_shape);
490 let out_shape = self.nchw(conv.out_c, out_h, out_w);
491 let y = self
492 .g()
493 .conv2d(x, weight, [3, 3], [2, 2], [1, 1], 1, out_shape.clone());
494 let bias = self.register_param(
495 &format!("{name}.bias"),
496 conv.bias.clone(),
497 Shape::new(&[conv.out_c], self.f),
498 );
499 let bias4 = self.g().reshape_(bias, vec![1, conv.out_c as i64, 1, 1]);
500 let batch = self.batch;
501 let expanded = self.g().add_node(
502 Op::Expand {
503 target_shape: vec![batch as i64, conv.out_c as i64, out_h as i64, out_w as i64],
504 },
505 vec![bias4],
506 out_shape.clone(),
507 );
508 Ok(self.g().add(y, expanded))
509 }
510}
511
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vae::{Flux2VaeConfig, flux2_vae_decode, synthetic_vae_weights};
    use rlx_runtime::Device;

    /// Largest absolute element-wise difference between two value sequences
    /// (0.0 when either is empty).
    fn max_abs_diff<'a>(
        lhs: impl IntoIterator<Item = &'a f32>,
        rhs: impl IntoIterator<Item = &'a f32>,
    ) -> f32 {
        lhs.into_iter()
            .zip(rhs)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max)
    }

    /// The decoder HIR built from synthetic weights lowers to MIR.
    #[test]
    fn vae_hir_lowers() {
        let cfg = Flux2VaeConfig::tiny();
        let w = synthetic_vae_weights(&cfg);
        let g = build_flux2_vae_hir(&cfg, &w, 1, 4, 4).unwrap();
        g.hir.lower_to_mir().expect("lower");
    }

    /// The compiled encoder agrees with the native encoder on a 32x32 input.
    #[test]
    fn compiled_vae_encoder_matches_native() {
        let cfg = Flux2VaeConfig::tiny();
        let w = synthetic_vae_weights(&cfg);
        let batch = 1usize;
        let h = 32usize;
        let w_px = 32usize;
        // Size the input from the config: the graph input is declared with
        // `cfg.in_channels` channels, so a hard-coded 3 would silently break
        // for any other configuration.
        let rgb: Vec<f32> = (0..batch * cfg.in_channels * h * w_px)
            .map(|i| (i as f32 * 0.001).sin())
            .collect();

        let native =
            super::super::encoder::flux2_vae_encode(&w, &cfg, &rgb, batch, h, w_px).unwrap();

        let (mut compiled, _) =
            compile_flux2_vae_encoder_hir(&cfg, &w, batch, h, w_px, Device::Cpu, None).unwrap();
        let mut out = compiled.run(&[("rgb", rgb.as_slice())]).remove(0);
        // Apply the shift/scale to the graph output before comparing with the
        // native result.
        if cfg.scaling_factor != 1.0 || cfg.shift_factor != 0.0 {
            for v in &mut out {
                *v = (*v - cfg.shift_factor) * cfg.scaling_factor;
            }
        }

        assert_eq!(out.len(), native.len());
        let max = max_abs_diff(&out, &native);
        assert!(max < 5e-2, "HIR encoder vs native max_abs_diff={max}");
    }

    /// The compiled decoder agrees with the native decoder on constant latents
    /// and produces the expected number of output values.
    #[test]
    fn compiled_vae_matches_native() {
        let cfg = Flux2VaeConfig::tiny();
        let w = synthetic_vae_weights(&cfg);
        let batch = 1usize;
        let h = 4usize;
        let w_px = 4usize;
        let latents = vec![0.1f32; batch * cfg.latent_channels * h * w_px];

        let native = flux2_vae_decode(&w, &cfg, &latents, batch, h, w_px).unwrap();

        let (mut compiled, _) =
            compile_flux2_vae_hir(&cfg, &w, batch, h, w_px, Device::Cpu, None).unwrap();
        let out = compiled.run(&[("latents", latents.as_slice())]).remove(0);

        assert_eq!(out.len(), native.len());
        // Expected spatial scale: a factor of two for every block after the first.
        let up = 2usize.pow(cfg.block_out_channels.len().saturating_sub(1) as u32);
        assert_eq!(out.len(), batch * cfg.out_channels * h * up * w_px * up);
        let max = max_abs_diff(&out, &native);
        assert!(max < 2e-2, "HIR vs native VAE max_abs_diff={max}");
    }
}
583}