1use anyhow::{Context, Result, anyhow, bail, ensure};
19use rlx_runtime::Device;
20use std::path::PathBuf;
21use std::sync::Mutex;
22
/// Result of one denoiser forward pass, as returned by `Flux2Runner::forward`
/// and `Flux2Runner::forward_cfg`.
#[derive(Debug, Clone)]
pub struct Flux2Output {
    /// Flat noise-prediction buffer produced by the denoiser.
    // NOTE(review): presumably laid out as (batch, img_seq, out_dim) — confirm against the kernels.
    pub noise_pred: Vec<f32>,
    /// Image sequence length, computed as `hidden_states.len() / (batch * in_channels)`.
    pub img_seq: usize,
    /// Value of `Flux2Config::proj_out_dim()` for the loaded transformer.
    pub out_dim: usize,
}
30
/// Builder for [`Flux2Runner`]; every option is optional except `weights`.
///
/// All fields start at their `Default` value (`None`, `false`, `0.0`) and are
/// interpreted by [`Flux2RunnerBuilder::build`].
#[derive(Debug, Clone, Default)]
pub struct Flux2RunnerBuilder {
    /// Denoiser weights location (required); resolved to a concrete file in `build`.
    weights: Option<PathBuf>,
    /// Explicit transformer config; takes priority over every other config source.
    config: Option<crate::Flux2Config>,
    /// Transformer config file; only consulted for non-GGUF weights.
    config_path: Option<PathBuf>,
    /// Text-encoder directory; when unset, `build` tries to resolve one from the weights file.
    text_encoder_dir: Option<PathBuf>,
    /// Text-encoder config file; defaults to `<text_encoder_dir>/config.json`.
    text_encoder_config_path: Option<PathBuf>,
    /// VAE directory; when unset, `build` tries to resolve one from the weights file.
    vae_dir: Option<PathBuf>,
    /// VAE config file; defaults to `<vae_dir>/config.json`.
    vae_config_path: Option<PathBuf>,
    /// Explicit tokenizer path handed to `resolve_tokenizer_path` when encoding prompts.
    tokenizer_path: Option<PathBuf>,
    /// Batch size; defaults to 1.
    batch: Option<usize>,
    /// Image sequence length stored on the runner; defaults to 256.
    img_seq: Option<usize>,
    /// Text sequence length (prompt padding length); defaults to 128.
    txt_seq: Option<usize>,
    /// Execution device; defaults to `Device::Cpu`.
    device: Option<Device>,
    /// Request the compiled denoiser; OR-ed with `flux2_prefers_compiled_hir(device)`.
    compiled_denoiser: bool,
    /// Request the compiled text encoder; `build` rejects it when
    /// `flux2_prefers_compiled_te(device)` is false.
    compiled_text_encoder: bool,
    /// Request the compiled VAE; OR-ed with `flux2_prefers_compiled_hir(device)`.
    compiled_vae: bool,
    /// NVFP4 override: `None` auto-detects; `Some(false)` also disables packed GGUF linears.
    nvfp4: Option<bool>,
    /// Skip loading the text encoder entirely.
    skip_text_encoder: bool,
    /// Directory wrapped in `rlx_runtime::AotCache` when graphs are compiled.
    aot_cache_dir: Option<PathBuf>,
    /// Whether `encode_prompt_pair` drops the text encoder afterwards;
    /// defaults to `!skip_text_encoder && compiled_denoiser`.
    drop_text_encoder_after_encode: Option<bool>,
    /// LoRA file or directory merged into the denoiser weights at build time.
    lora_path: Option<PathBuf>,
    /// Scale passed to the LoRA merge (0.0 until `lora(...)` is called).
    lora_scale: f32,
    /// Compile the denoiser through `compile_flux2_forward_via_flow`.
    use_flow_api: bool,
    /// Enable the dual-time timestep embedder; `build` also forces it on when a LoRA is set.
    dual_time_embedder: bool,
}
68
69impl Flux2RunnerBuilder {
70 pub fn weights<P: Into<PathBuf>>(mut self, p: P) -> Self {
71 self.weights = Some(p.into());
72 self
73 }
74 pub fn config(mut self, cfg: crate::Flux2Config) -> Self {
75 self.config = Some(cfg);
76 self
77 }
78 pub fn config_path<P: Into<PathBuf>>(mut self, p: P) -> Self {
79 self.config_path = Some(p.into());
80 self
81 }
82 pub fn batch(mut self, n: usize) -> Self {
83 self.batch = Some(n);
84 self
85 }
86 pub fn img_seq(mut self, n: usize) -> Self {
87 self.img_seq = Some(n);
88 self
89 }
90 pub fn txt_seq(mut self, n: usize) -> Self {
91 self.txt_seq = Some(n);
92 self
93 }
94 pub fn device(mut self, d: Device) -> Self {
95 self.device = Some(d);
96 self
97 }
98
99 pub fn compiled_denoiser(mut self, yes: bool) -> Self {
101 self.compiled_denoiser = yes;
102 self
103 }
104
105 pub fn compiled_text_encoder(mut self, yes: bool) -> Self {
107 self.compiled_text_encoder = yes;
108 self
109 }
110
111 pub fn compiled_vae(mut self, yes: bool) -> Self {
113 self.compiled_vae = yes;
114 self
115 }
116
117 pub fn text_encoder_dir<P: Into<PathBuf>>(mut self, path: P) -> Self {
118 self.text_encoder_dir = Some(path.into());
119 self
120 }
121
122 pub fn text_encoder_config_path<P: Into<PathBuf>>(mut self, path: P) -> Self {
123 self.text_encoder_config_path = Some(path.into());
124 self
125 }
126
127 pub fn tokenizer_path<P: Into<PathBuf>>(mut self, path: P) -> Self {
128 self.tokenizer_path = Some(path.into());
129 self
130 }
131
132 pub fn vae_dir<P: Into<PathBuf>>(mut self, path: P) -> Self {
133 self.vae_dir = Some(path.into());
134 self
135 }
136
137 pub fn vae_config_path<P: Into<PathBuf>>(mut self, path: P) -> Self {
138 self.vae_config_path = Some(path.into());
139 self
140 }
141
142 pub fn nvfp4(mut self, enable: bool) -> Self {
144 self.nvfp4 = Some(enable);
145 self
146 }
147
148 pub fn skip_text_encoder(mut self, yes: bool) -> Self {
150 self.skip_text_encoder = yes;
151 self
152 }
153
154 pub fn aot_cache_dir<P: Into<PathBuf>>(mut self, path: P) -> Self {
156 self.aot_cache_dir = Some(path.into());
157 self
158 }
159
160 pub fn drop_text_encoder_after_encode(mut self, yes: bool) -> Self {
162 self.drop_text_encoder_after_encode = Some(yes);
163 self
164 }
165
166 pub fn lora<P: Into<PathBuf>>(mut self, path: P, scale: f32) -> Self {
168 self.lora_path = Some(path.into());
169 self.lora_scale = scale;
170 self.dual_time_embedder = true;
171 self
172 }
173
174 pub fn dual_time_embedder(mut self, yes: bool) -> Self {
176 self.dual_time_embedder = yes;
177 self
178 }
179
180 pub fn use_flow_api(mut self, yes: bool) -> Self {
182 self.use_flow_api = yes;
183 self
184 }
185
186 pub fn session_key(&self) -> Option<crate::Flux2SessionKey> {
188 self.weights.as_ref().map(|w| crate::Flux2SessionKey {
189 weights: w.clone(),
190 device: self.device.unwrap_or(Device::Cpu),
191 config_path: self.config_path.clone(),
192 lora_path: self.lora_path.clone(),
193 lora_scale_bits: self.lora_scale.to_bits(),
194 nvfp4: self.nvfp4,
195 })
196 }
197
198 pub fn build(self) -> Result<Flux2Runner> {
199 use crate::Flux2VaeConfig;
200 use crate::{
201 ExtractFlux2Opts, Flux2Config, extract_flux2_weights_with_opts,
202 load_flux2_nvfp4_from_file, load_flux2_vae_weights, load_flux2_weight_map,
203 load_text_encoder_weights, load_typed_linears_from_file, prepare_weight_map,
204 resolve_text_encoder_dir, resolve_transformer_config, resolve_vae_dir,
205 safetensors_has_nvfp4,
206 };
207 use rlx_core::gguf_support::{ResolveWeightsOptions, resolve_weights_file_with_options};
208 use rlx_gguf::GgufFile;
209 use rlx_qwen3::Qwen3Config;
210
211 let weights_path = self
212 .weights
213 .as_ref()
214 .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?;
215 let weights_file = resolve_weights_file_with_options(
216 weights_path,
217 &ResolveWeightsOptions {
218 prefer_gguf_substring: Some("Q4_K_M"),
219 ..Default::default()
220 },
221 )?;
222 let is_gguf = weights_file.extension().and_then(|s| s.to_str()) == Some("gguf");
223
224 let cfg = match (self.config, self.config_path.clone()) {
225 (Some(c), _) => c,
226 (_, Some(p)) if !is_gguf => Flux2Config::from_file(&p)?,
227 _ if is_gguf => {
228 let raw = GgufFile::from_path(&weights_file)
229 .with_context(|| format!("opening GGUF {weights_file:?}"))?;
230 Flux2Config::from_gguf(&raw)?
231 }
232 _ => {
233 if let Some(p) = resolve_transformer_config(&weights_file, None) {
234 Flux2Config::from_file(&p)?
235 } else {
236 Flux2Config::flux2_dev()
237 }
238 }
239 };
240 let path = weights_file
241 .to_str()
242 .ok_or_else(|| anyhow!("non-utf8 weights path"))?;
243
244 let device = self.device.unwrap_or(Device::Cpu);
245 use crate::{
246 assert_flux2_device_available, flux2_prefers_compiled_hir, flux2_prefers_compiled_te,
247 };
248 rlx_core::validate_standard_device("flux2", device)?;
249 assert_flux2_device_available(device)?;
250 if self.compiled_text_encoder && !flux2_prefers_compiled_te(device) {
251 anyhow::bail!(
252 "compiled text encoder on {device:?} can take hours and exhaust VRAM; \
253 use native CPU TE (default on CUDA/ROCm/wgpu/Vulkan)"
254 );
255 }
256 let compiled_denoiser = self.compiled_denoiser || flux2_prefers_compiled_hir(device);
257 let compiled_text_encoder = self.compiled_text_encoder || flux2_prefers_compiled_te(device);
258 let compiled_vae = self.compiled_vae || flux2_prefers_compiled_hir(device);
259
260 let use_nvfp4 = if is_gguf {
261 false
262 } else {
263 match self.nvfp4 {
264 Some(yes) => yes,
265 None => safetensors_has_nvfp4(&weights_file).unwrap_or(false),
266 }
267 };
268 let use_gguf_packed = is_gguf
269 && match self.nvfp4 {
270 Some(false) => false,
271 _ => crate::packed_gguf::gguf_has_packed_linears(&weights_file).unwrap_or(false),
272 };
273 let packed = if use_nvfp4 {
274 Some(load_flux2_nvfp4_from_file(&weights_file)?)
275 } else {
276 None
277 };
278
279 let mut exclude_f32 = std::collections::HashSet::new();
280 if let Some(p) = &packed {
281 exclude_f32.extend(p.exclude_f32_keys());
282 }
283 let typed_linears = if is_gguf {
284 None
285 } else if compiled_denoiser && self.lora_path.is_none() {
286 Some(load_typed_linears_from_file(&weights_file, &exclude_f32)?)
287 } else {
288 if self.lora_path.is_some() && compiled_denoiser {
289 eprintln!(
290 "[flux2] LoRA active — using merged f32 weights (typed BF16 linears disabled)"
291 );
292 }
293 None
294 };
295 if let Some(t) = &typed_linears {
296 exclude_f32.extend(t.skip_keys());
297 }
298
299 let (mut wm, gguf_packed) = if use_gguf_packed {
300 if self.lora_path.is_some() {
301 bail!("LoRA merge is not supported on GGUF denoiser weights; use safetensors");
302 }
303 eprintln!("[flux2] loading denoiser GGUF with packed DequantMatMul {weights_file:?}");
304 {
305 let (wm, g) = crate::packed_gguf::load_flux2_from_gguf(&weights_file)?;
306 (wm, Some(g))
307 }
308 } else if is_gguf {
309 if self.lora_path.is_some() {
310 bail!("LoRA merge is not supported on GGUF denoiser weights; use safetensors");
311 }
312 eprintln!("[flux2] loading denoiser from GGUF (F32 drain) {weights_file:?}");
313 (load_flux2_weight_map(&weights_file)?, None)
314 } else {
315 use rlx_core::weight_map::WeightMap;
316 (WeightMap::from_file_excluding(path, &exclude_f32)?, None)
317 };
318 let packed = match (packed, gguf_packed) {
319 (Some(nv), None) => Some(nv),
320 (None, Some(g)) => Some(g),
321 (None, None) => None,
322 (Some(_), Some(_)) => unreachable!("nvfp4 and gguf packed are mutually exclusive"),
323 };
324 if let Some(lora_path) = &self.lora_path {
325 let n = if lora_path.is_dir() {
326 crate::load_and_apply_flux2_lora_dir(&mut wm, lora_path, self.lora_scale)?
327 } else {
328 crate::load_and_apply_flux2_lora(&mut wm, lora_path, self.lora_scale)?
329 };
330 eprintln!(
331 "[flux2] merged {n} LoRA layers from {:?} (scale={})",
332 lora_path, self.lora_scale
333 );
334 }
335 let extract_opts = ExtractFlux2Opts {
336 typed_linears: typed_linears.as_ref(),
337 packed_linears: packed.as_ref(),
338 dual_time_embedder: self.dual_time_embedder || self.lora_path.is_some(),
339 };
340 if extract_opts.dual_time_embedder {
341 eprintln!("[flux2] dual-time timestep embedder enabled (flow-map / Diamond Maps)");
342 }
343 let model = extract_flux2_weights_with_opts(prepare_weight_map(wm), &cfg, extract_opts)?;
344
345 let te_dir = if self.skip_text_encoder {
346 None
347 } else {
348 self.text_encoder_dir
349 .or_else(|| resolve_text_encoder_dir(&weights_file))
350 };
351 let (text_encoder, text_encoder_cfg) = if let Some(dir) = te_dir {
352 let te_cfg_path = self
353 .text_encoder_config_path
354 .unwrap_or_else(|| dir.join("config.json"));
355 let te_cfg = Qwen3Config::from_file(&te_cfg_path)?;
356 let te = load_text_encoder_weights(&dir, &te_cfg)?;
357 (Some(te), Some(te_cfg))
358 } else {
359 (None, None)
360 };
361
362 let vae_dir = self.vae_dir.or_else(|| resolve_vae_dir(&weights_file));
363 let (vae, vae_cfg) = if let Some(dir) = vae_dir {
364 let vae_cfg_path = self
365 .vae_config_path
366 .unwrap_or_else(|| dir.join("config.json"));
367 let vae_cfg = Flux2VaeConfig::from_file(&vae_cfg_path)?;
368 let vae = load_flux2_vae_weights(&dir, &vae_cfg)?;
369 (Some(vae), Some(vae_cfg))
370 } else {
371 (None, None)
372 };
373
374 let drop_text_encoder_after_encode = self
375 .drop_text_encoder_after_encode
376 .unwrap_or(!self.skip_text_encoder && compiled_denoiser);
377
378 Ok(Flux2Runner {
379 model,
380 cfg,
381 batch: self.batch.unwrap_or(1),
382 img_seq: self.img_seq.unwrap_or(256),
383 txt_seq: self.txt_seq.unwrap_or(128),
384 device,
385 compiled_denoiser,
386 compiled_text_encoder,
387 compiled_vae,
388 packed,
389 typed_linears,
390 aot_cache_dir: self.aot_cache_dir,
391 drop_text_encoder_after_encode,
392 use_flow_api: self.use_flow_api,
393 text_encoder: Mutex::new(text_encoder),
394 text_encoder_cfg: Mutex::new(text_encoder_cfg),
395 vae,
396 vae_cfg,
397 tokenizer_path: self.tokenizer_path,
398 model_root: weights_file,
399 denoiser: Mutex::new(None),
400 text_encoder_compiled: Mutex::new(None),
401 vae_compiled: Mutex::new(None),
402 vae_encoder_compiled: Mutex::new(None),
403 cfg_compiled: Mutex::new(None),
404 })
405 }
406}
407
/// Compiled denoiser graph together with the inputs it was compiled for.
///
/// `ensure_denoiser_compiled` recompiles whenever any field other than
/// `compiled` differs from the current request.
struct Flux2DenoiserCache {
    /// The compiled graph; run with `hidden`, `encoder` and `temb` inputs.
    compiled: rlx_runtime::CompiledGraph,
    device: Device,
    batch: usize,
    /// Derived from the position ids as `img_ids.len() / (batch * 4)`.
    img_seq: usize,
    txt_seq: usize,
    /// Copy of the image position ids passed to compilation.
    img_ids: Vec<f32>,
    /// Copy of the text position ids passed to compilation.
    txt_ids: Vec<f32>,
}
417
/// Compiled text-encoder graph plus the (device, batch, txt_seq) it was built for;
/// `ensure_text_encoder_compiled` recompiles when any of the three changes.
struct Flux2TextEncoderCache {
    /// The compiled graph; run with an `input_ids` input.
    compiled: rlx_runtime::CompiledGraph,
    device: Device,
    batch: usize,
    txt_seq: usize,
}
424
/// Compiled VAE graph stored in `Flux2Runner::vae_compiled`.
// NOTE(review): presumably the decoder graph, keyed by (device, batch, h, w) like the
// other caches — the code that fills it is outside this view; confirm there.
struct Flux2VaeCache {
    compiled: rlx_runtime::CompiledGraph,
    device: Device,
    batch: usize,
    h: usize,
    w: usize,
}
432
/// Compiled VAE encoder graph stored in `Flux2Runner::vae_encoder_compiled`;
/// `vae_encode_rgb_compiled` runs it with an `rgb` input.
// NOTE(review): presumably keyed by (device, batch, h, w); the function that fills
// this cache is outside this view — confirm what `h`/`w` measure there.
struct Flux2VaeEncoderCache {
    compiled: rlx_runtime::CompiledGraph,
    device: Device,
    batch: usize,
    h: usize,
    w: usize,
}
440
/// Compiled classifier-free-guidance combine graph plus the shape it was built for;
/// `ensure_cfg_compiled` recompiles when device, batch, `img_seq` or `out_dim` changes.
struct Flux2CfgCache {
    /// The compiled graph; run with `pos`, `neg` and `guidance_scale` inputs.
    compiled: rlx_runtime::CompiledGraph,
    device: Device,
    batch: usize,
    img_seq: usize,
    out_dim: usize,
}
448
/// Loaded FLUX.2 pipeline: denoiser weights, optional text encoder and VAE,
/// and lazily compiled graphs for each stage.
///
/// Create one through [`Flux2Runner::builder`].
pub struct Flux2Runner {
    /// Extracted denoiser weights.
    model: crate::Flux2Weights,
    /// Transformer configuration.
    cfg: crate::Flux2Config,
    batch: usize,
    img_seq: usize,
    txt_seq: usize,
    device: Device,
    /// Run the denoiser through a compiled graph instead of the native forward.
    compiled_denoiser: bool,
    /// Run the text encoder through a compiled graph instead of the native forward.
    compiled_text_encoder: bool,
    /// Run VAE encode/decode through compiled graphs instead of the native path.
    compiled_vae: bool,
    /// Packed linear parameters (NVFP4 safetensors or packed GGUF), if any were loaded.
    packed: Option<crate::Flux2PackedParams>,
    /// Typed linear weights, loaded only for compiled non-GGUF runs without a LoRA.
    typed_linears: Option<crate::TypedLinearStore>,
    aot_cache_dir: Option<PathBuf>,
    /// Whether `encode_prompt_pair` drops the text encoder when it finishes.
    drop_text_encoder_after_encode: bool,
    use_flow_api: bool,
    /// Text-encoder weights; behind a mutex so they can be dropped through `&self`.
    text_encoder: Mutex<Option<crate::Flux2TextEncoderWeights>>,
    /// Text-encoder config; cleared together with the weights.
    text_encoder_cfg: Mutex<Option<rlx_qwen3::Qwen3Config>>,
    vae: Option<crate::Flux2VaeWeights>,
    vae_cfg: Option<crate::Flux2VaeConfig>,
    tokenizer_path: Option<PathBuf>,
    /// Resolved weights file; used as the anchor when searching for a tokenizer.
    model_root: PathBuf,
    /// Lazily populated compiled-graph caches, one per stage.
    denoiser: Mutex<Option<Flux2DenoiserCache>>,
    text_encoder_compiled: Mutex<Option<Flux2TextEncoderCache>>,
    vae_compiled: Mutex<Option<Flux2VaeCache>>,
    vae_encoder_compiled: Mutex<Option<Flux2VaeEncoderCache>>,
    cfg_compiled: Mutex<Option<Flux2CfgCache>>,
}
479
480impl Flux2Runner {
481 pub fn builder() -> Flux2RunnerBuilder {
482 Flux2RunnerBuilder::default()
483 }
484
485 fn aot_cache(&self) -> Option<rlx_runtime::AotCache> {
486 self.aot_cache_dir
487 .as_ref()
488 .map(|p| rlx_runtime::AotCache::new(p.clone()))
489 }
490
491 pub fn drop_text_encoder_weights(&self) -> Result<()> {
492 if let Ok(mut te) = self.text_encoder.lock() {
493 if te.is_some() {
494 eprintln!("[flux2] dropping text encoder weights (~8GB RAM)");
495 *te = None;
496 }
497 }
498 if let Ok(mut cfg) = self.text_encoder_cfg.lock() {
499 *cfg = None;
500 }
501 if let Ok(mut cache) = self.text_encoder_compiled.lock() {
502 *cache = None;
503 }
504 Ok(())
505 }
506 pub fn config(&self) -> &crate::Flux2Config {
507 &self.cfg
508 }
509 pub fn device(&self) -> Device {
510 self.device
511 }
512
513 pub fn batch(&self) -> usize {
514 self.batch
515 }
516
517 pub fn img_seq(&self) -> usize {
518 self.img_seq
519 }
520
521 pub fn txt_seq(&self) -> usize {
522 self.txt_seq
523 }
524
525 pub fn uses_nvfp4(&self) -> bool {
526 self.packed.is_some()
527 }
528
529 pub fn has_text_encoder(&self) -> bool {
530 self.text_encoder
531 .lock()
532 .map(|g| g.is_some())
533 .unwrap_or(false)
534 }
535
536 pub fn has_vae(&self) -> bool {
537 self.vae.is_some()
538 }
539
540 pub fn uses_compiled_denoiser(&self) -> bool {
542 self.compiled_denoiser
543 }
544
545 pub fn uses_compiled_text_encoder(&self) -> bool {
547 self.compiled_text_encoder
548 }
549
550 pub fn uses_compiled_vae(&self) -> bool {
551 self.compiled_vae
552 }
553
554 pub fn warmup_denoiser(&self, img_ids: &[f32], txt_ids: &[f32]) -> Result<()> {
556 if self.uses_compiled_denoiser() {
557 self.ensure_denoiser_compiled(img_ids, txt_ids)?;
558 }
559 Ok(())
560 }
561
562 fn ensure_denoiser_compiled(&self, img_ids: &[f32], txt_ids: &[f32]) -> Result<()> {
563 use crate::{compile_flux2_forward, compile_flux2_forward_via_flow};
564
565 let mut guard = self
566 .denoiser
567 .lock()
568 .map_err(|e| anyhow!("denoiser cache lock poisoned: {e}"))?;
569 let img_seq = img_ids.len() / (self.batch * 4);
570 let recompile = guard.as_ref().is_none_or(|c| {
571 c.device != self.device
572 || c.batch != self.batch
573 || c.img_seq != img_seq
574 || c.txt_seq != self.txt_seq
575 || c.img_ids != img_ids
576 || c.txt_ids != txt_ids
577 });
578 if recompile {
579 eprintln!(
580 "[flux2] compiling denoiser HIR on {:?} (img_seq={img_seq}, txt_seq={}, flow={})…",
581 self.device, self.txt_seq, self.use_flow_api
582 );
583 let aot = self.aot_cache();
584 let (compiled, _) = if self.use_flow_api {
585 compile_flux2_forward_via_flow(
586 &self.cfg,
587 &self.model,
588 self.batch,
589 img_seq,
590 self.txt_seq,
591 img_ids,
592 txt_ids,
593 self.device,
594 self.packed.as_ref(),
595 self.typed_linears.as_ref(),
596 aot.as_ref(),
597 )?
598 } else {
599 compile_flux2_forward(
600 &self.cfg,
601 &self.model,
602 self.batch,
603 img_seq,
604 self.txt_seq,
605 img_ids,
606 txt_ids,
607 self.device,
608 self.packed.as_ref(),
609 self.typed_linears.as_ref(),
610 aot.as_ref(),
611 )?
612 };
613 *guard = Some(Flux2DenoiserCache {
614 compiled,
615 device: self.device,
616 batch: self.batch,
617 img_seq,
618 txt_seq: self.txt_seq,
619 img_ids: img_ids.to_vec(),
620 txt_ids: txt_ids.to_vec(),
621 });
622 }
623 Ok(())
624 }
625
626 pub fn encode_prompt(&self, prompt: &str) -> Result<(Vec<f32>, Vec<f32>)> {
631 if self.uses_compiled_text_encoder() {
632 return self.encode_prompt_compiled(prompt);
633 }
634 eprintln!("[flux2] text encoder: native CPU forward");
635 self.encode_prompt_native(prompt)
636 }
637
638 pub fn encode_prompt_native(&self, prompt: &str) -> Result<(Vec<f32>, Vec<f32>)> {
640 use crate::{
641 DEFAULT_TEXT_ENCODER_LAYERS, encode_flux2_prompt, encode_prompt_padded,
642 resolve_tokenizer_path,
643 };
644
645 let tok_path = resolve_tokenizer_path(&self.model_root, self.tokenizer_path.as_deref())
646 .ok_or_else(|| {
647 anyhow!(
648 "no tokenizer found near {:?}; pass .tokenizer_path(...)",
649 self.model_root
650 )
651 })?;
652 let input_ids = encode_prompt_padded(&tok_path, prompt, self.txt_seq)?;
653 let (out, txt_ids) = {
654 let te_guard = self
655 .text_encoder
656 .lock()
657 .map_err(|e| anyhow!("text encoder lock poisoned: {e}"))?;
658 let te = te_guard.as_ref().ok_or_else(|| {
659 anyhow!("text encoder not loaded (pass .text_encoder_dir(...) on build)")
660 })?;
661 let cfg_guard = self
662 .text_encoder_cfg
663 .lock()
664 .map_err(|e| anyhow!("text encoder cfg lock poisoned: {e}"))?;
665 let te_cfg = cfg_guard
666 .as_ref()
667 .ok_or_else(|| anyhow!("text encoder config missing"))?;
668 encode_flux2_prompt(
669 te,
670 te_cfg,
671 &input_ids,
672 self.batch,
673 self.txt_seq,
674 DEFAULT_TEXT_ENCODER_LAYERS,
675 )?
676 };
677 ensure!(
678 out.joint_dim == self.cfg.joint_attention_dim,
679 "text encoder joint_dim {} != transformer joint_attention_dim {}",
680 out.joint_dim,
681 self.cfg.joint_attention_dim
682 );
683 Ok((out.prompt_embeds, txt_ids))
684 }
685
686 fn ensure_text_encoder_compiled(&self) -> Result<()> {
687 use crate::{DEFAULT_TEXT_ENCODER_LAYERS, compile_flux2_text_encoder_hir};
688
689 let te = {
690 let guard = self
691 .text_encoder
692 .lock()
693 .map_err(|e| anyhow!("text encoder lock poisoned: {e}"))?;
694 guard
695 .as_ref()
696 .ok_or_else(|| anyhow!("text encoder not loaded"))?
697 .clone()
698 };
699 let te_cfg = {
700 let guard = self
701 .text_encoder_cfg
702 .lock()
703 .map_err(|e| anyhow!("text encoder cfg lock poisoned: {e}"))?;
704 guard
705 .as_ref()
706 .ok_or_else(|| anyhow!("text encoder config missing"))?
707 .clone()
708 };
709
710 let mut guard = self
711 .text_encoder_compiled
712 .lock()
713 .map_err(|e| anyhow!("text encoder cache lock poisoned: {e}"))?;
714 let recompile = guard.as_ref().is_none_or(|c| {
715 c.device != self.device || c.batch != self.batch || c.txt_seq != self.txt_seq
716 });
717 if recompile {
718 eprintln!(
719 "[flux2] compiling text encoder HIR on {:?} (txt_seq={})…",
720 self.device, self.txt_seq
721 );
722 let aot = self.aot_cache();
723 let (compiled, _) = compile_flux2_text_encoder_hir(
724 &te_cfg,
725 &te,
726 self.batch,
727 self.txt_seq,
728 DEFAULT_TEXT_ENCODER_LAYERS,
729 self.device,
730 aot.as_ref(),
731 )?;
732 *guard = Some(Flux2TextEncoderCache {
733 compiled,
734 device: self.device,
735 batch: self.batch,
736 txt_seq: self.txt_seq,
737 });
738 }
739 Ok(())
740 }
741
742 fn ensure_cfg_compiled(&self, img_seq: usize) -> Result<()> {
743 use crate::compile_flux2_cfg_combine;
744
745 let out_dim = self.cfg.proj_out_dim();
746 let mut guard = self
747 .cfg_compiled
748 .lock()
749 .map_err(|e| anyhow!("cfg cache lock poisoned: {e}"))?;
750 let recompile = guard.as_ref().is_none_or(|c| {
751 c.device != self.device
752 || c.batch != self.batch
753 || c.img_seq != img_seq
754 || c.out_dim != out_dim
755 });
756 if recompile {
757 let aot = self.aot_cache();
758 let compiled =
759 compile_flux2_cfg_combine(self.batch, img_seq, out_dim, self.device, aot.as_ref())?;
760 *guard = Some(Flux2CfgCache {
761 compiled,
762 device: self.device,
763 batch: self.batch,
764 img_seq,
765 out_dim,
766 });
767 }
768 Ok(())
769 }
770
771 fn cfg_combine_compiled(
772 &self,
773 pos: &[f32],
774 neg: &[f32],
775 scale: f32,
776 img_seq: usize,
777 ) -> Result<Vec<f32>> {
778 self.ensure_cfg_compiled(img_seq)?;
779 let mut guard = self
780 .cfg_compiled
781 .lock()
782 .map_err(|e| anyhow!("cfg cache lock poisoned: {e}"))?;
783 let cache = guard
784 .as_mut()
785 .ok_or_else(|| anyhow!("cfg compile cache missing"))?;
786 Ok(cache
787 .compiled
788 .run(&[("pos", pos), ("neg", neg), ("guidance_scale", &[scale])])
789 .remove(0))
790 }
791
792 pub fn encode_prompt_compiled(&self, prompt: &str) -> Result<(Vec<f32>, Vec<f32>)> {
794 use crate::{
795 DEFAULT_TEXT_ENCODER_LAYERS, encode_prompt_padded, prepare_text_ids,
796 resolve_tokenizer_path,
797 };
798
799 let te_cfg = self
800 .text_encoder_cfg
801 .lock()
802 .map_err(|e| anyhow!("text encoder cfg lock poisoned: {e}"))?;
803 let te_cfg = te_cfg
804 .as_ref()
805 .ok_or_else(|| anyhow!("text encoder config missing"))?;
806
807 let tok_path = resolve_tokenizer_path(&self.model_root, self.tokenizer_path.as_deref())
808 .ok_or_else(|| anyhow!("no tokenizer found near {:?}", self.model_root))?;
809 let input_ids = encode_prompt_padded(&tok_path, prompt, self.txt_seq)?;
810 let ids_f32: Vec<f32> = input_ids.iter().map(|&x| x as f32).collect();
811
812 self.ensure_text_encoder_compiled()?;
813 let mut guard = self
814 .text_encoder_compiled
815 .lock()
816 .map_err(|e| anyhow!("text encoder cache lock poisoned: {e}"))?;
817 let cache = guard
818 .as_mut()
819 .ok_or_else(|| anyhow!("text encoder compile cache missing"))?;
820 let embeds = cache
821 .compiled
822 .run(&[("input_ids", ids_f32.as_slice())])
823 .remove(0);
824 let joint = te_cfg.hidden_size * DEFAULT_TEXT_ENCODER_LAYERS.len();
825 ensure!(
826 joint == self.cfg.joint_attention_dim,
827 "text encoder joint_dim {joint} != transformer {}",
828 self.cfg.joint_attention_dim
829 );
830 let txt_ids = prepare_text_ids(self.batch, self.txt_seq);
831 Ok((embeds, txt_ids))
832 }
833
834 pub fn forward(
836 &self,
837 hidden_states: &[f32],
838 encoder_hidden_states: &[f32],
839 timestep: &[f32],
840 guidance: Option<&[f32]>,
841 img_ids: &[f32],
842 txt_ids: &[f32],
843 ) -> Result<Flux2Output> {
844 let noise_pred = self.forward_noise(
845 hidden_states,
846 encoder_hidden_states,
847 timestep,
848 guidance,
849 img_ids,
850 txt_ids,
851 )?;
852 Ok(Flux2Output {
853 noise_pred,
854 img_seq: hidden_states.len() / (self.batch * self.cfg.in_channels),
855 out_dim: self.cfg.proj_out_dim(),
856 })
857 }
858
859 pub fn vae_encode_rgb(&self, rgb: &[f32], pixel_h: usize, pixel_w: usize) -> Result<Vec<f32>> {
861 if self.uses_compiled_vae() {
862 return self.vae_encode_rgb_compiled(rgb, pixel_h, pixel_w);
863 }
864 let vae = self.vae.as_ref().ok_or_else(|| anyhow!("VAE not loaded"))?;
865 let vae_cfg = self
866 .vae_cfg
867 .as_ref()
868 .ok_or_else(|| anyhow!("VAE config missing"))?;
869 crate::flux2_vae_encode(vae, vae_cfg, rgb, self.batch, pixel_h, pixel_w)
870 }
871
872 fn vae_encode_rgb_compiled(
873 &self,
874 rgb: &[f32],
875 pixel_h: usize,
876 pixel_w: usize,
877 ) -> Result<Vec<f32>> {
878 let vae_cfg = self
879 .vae_cfg
880 .as_ref()
881 .ok_or_else(|| anyhow!("VAE config missing"))?;
882
883 self.ensure_vae_encoder_compiled(pixel_h, pixel_w)?;
884 let mut guard = self
885 .vae_encoder_compiled
886 .lock()
887 .map_err(|e| anyhow!("vae encoder cache lock poisoned: {e}"))?;
888 let cache = guard
889 .as_mut()
890 .ok_or_else(|| anyhow!("vae encoder compile cache missing"))?;
891 let mut latent = cache.compiled.run(&[("rgb", rgb)]).remove(0);
892 if vae_cfg.scaling_factor != 1.0 || vae_cfg.shift_factor != 0.0 {
893 for v in &mut latent {
894 *v = (*v - vae_cfg.shift_factor) * vae_cfg.scaling_factor;
895 }
896 }
897 Ok(latent)
898 }
899
900 pub fn encode_rgb_to_packed(
902 &self,
903 rgb: &[f32],
904 pixel_h: usize,
905 pixel_w: usize,
906 latent_h: usize,
907 latent_w: usize,
908 eff_h: usize,
909 eff_w: usize,
910 ) -> Result<Vec<f32>> {
911 use crate::pack_encoded_latents;
912
913 let vae = self
914 .vae
915 .as_ref()
916 .ok_or_else(|| anyhow!("VAE not loaded (required for img2img / edit)"))?;
917 let vae_cfg = self
918 .vae_cfg
919 .as_ref()
920 .ok_or_else(|| anyhow!("VAE config missing"))?;
921 let stride = vae_cfg.encode_spatial_stride();
922 let enc_h = pixel_h / stride;
923 let enc_w = pixel_w / stride;
924 ensure!(
925 enc_h > 0 && enc_w > 0,
926 "encoded spatial dims too small for {pixel_h}x{pixel_w}"
927 );
928 let encoded = self.vae_encode_rgb(rgb, pixel_h, pixel_w)?;
929 ensure!(
930 encoded.len() == self.batch * vae_cfg.latent_channels * enc_h * enc_w,
931 "encoded len {} != expected {}",
932 encoded.len(),
933 self.batch * vae_cfg.latent_channels * enc_h * enc_w
934 );
935 pack_encoded_latents(
936 vae, vae_cfg, encoded, self.batch, enc_h, enc_w, eff_h, eff_w, latent_h, latent_w,
937 )
938 }
939
940 pub fn has_vae_encoder(&self) -> bool {
941 self.vae.is_some()
942 }
943
944 pub fn prepare_img2img_packed(
946 &self,
947 rgb: &[f32],
948 pixel_h: usize,
949 pixel_w: usize,
950 latent_h: usize,
951 latent_w: usize,
952 eff_h: usize,
953 eff_w: usize,
954 noise: &[f32],
955 image_strength: f32,
956 num_inference_steps: usize,
957 ) -> Result<Vec<f32>> {
958 use crate::latent_ops::blend_latents_with_noise;
959 use crate::{flow_match_init_timestep, flow_match_sigmas};
960
961 let clean =
962 self.encode_rgb_to_packed(rgb, pixel_h, pixel_w, latent_h, latent_w, eff_h, eff_w)?;
963 ensure!(clean.len() == noise.len());
964 let sigmas = flow_match_sigmas(num_inference_steps);
965 let init_step = flow_match_init_timestep(image_strength, num_inference_steps);
966 let sigma = sigmas[init_step.min(sigmas.len() - 1)];
967 Ok(blend_latents_with_noise(&clean, noise, sigma))
968 }
969
970 pub fn prepare_edit_conditioning(
972 &self,
973 images: &[(&[f32], usize, usize)],
974 eff_h: usize,
975 eff_w: usize,
976 latent_h: usize,
977 latent_w: usize,
978 ) -> Result<crate::Flux2ReferenceConditioning> {
979 use crate::{
980 Flux2ReferenceConditioning, concat_latent_ids, concat_packed_latents,
981 prepare_latent_ids_with_t,
982 };
983
984 ensure!(
985 !images.is_empty(),
986 "edit requires at least one reference image"
987 );
988 let vae_cfg = self
989 .vae_cfg
990 .as_ref()
991 .ok_or_else(|| anyhow!("VAE config missing"))?;
992 let channels = vae_cfg.bn_channels();
993 let mut packed_acc: Option<Vec<f32>> = None;
994 let mut ids_acc: Option<Vec<f32>> = None;
995 let mut total_seq = 0usize;
996
997 for (i, (rgb, ph, pw)) in images.iter().enumerate() {
998 let packed =
999 self.encode_rgb_to_packed(rgb, *ph, *pw, latent_h, latent_w, eff_h, eff_w)?;
1000 let seq = packed.len() / (self.batch * channels);
1001 total_seq += seq;
1002 let ids = prepare_latent_ids_with_t(self.batch, latent_h, latent_w, 10 + 10 * i as i32);
1003 packed_acc = Some(match packed_acc {
1004 Some(prev) => concat_packed_latents(&prev, &packed, self.batch, channels),
1005 None => packed,
1006 });
1007 ids_acc = Some(match ids_acc {
1008 Some(prev) => concat_latent_ids(&prev, &ids, self.batch),
1009 None => ids,
1010 });
1011 }
1012
1013 Ok(Flux2ReferenceConditioning {
1014 packed: packed_acc.unwrap(),
1015 img_ids: ids_acc.unwrap(),
1016 seq: total_seq,
1017 })
1018 }
1019
1020 pub fn forward_noise(
1022 &self,
1023 hidden_states: &[f32],
1024 encoder_hidden_states: &[f32],
1025 timestep: &[f32],
1026 guidance: Option<&[f32]>,
1027 img_ids: &[f32],
1028 txt_ids: &[f32],
1029 ) -> Result<Vec<f32>> {
1030 if self.uses_compiled_denoiser() {
1031 self.forward_noise_compiled(
1032 hidden_states,
1033 encoder_hidden_states,
1034 timestep,
1035 guidance,
1036 img_ids,
1037 txt_ids,
1038 )
1039 } else {
1040 self.forward_noise_native(
1041 hidden_states,
1042 encoder_hidden_states,
1043 timestep,
1044 guidance,
1045 img_ids,
1046 txt_ids,
1047 )
1048 }
1049 }
1050
1051 pub fn forward_noise_native(
1053 &self,
1054 hidden_states: &[f32],
1055 encoder_hidden_states: &[f32],
1056 timestep: &[f32],
1057 guidance: Option<&[f32]>,
1058 img_ids: &[f32],
1059 txt_ids: &[f32],
1060 ) -> Result<Vec<f32>> {
1061 use crate::{Flux2ForwardInput, flux2_transformer_forward};
1062
1063 flux2_transformer_forward(
1064 &self.model,
1065 &self.cfg,
1066 Flux2ForwardInput {
1067 hidden_states,
1068 encoder_hidden_states,
1069 timestep,
1070 timestep_target: None,
1071 guidance,
1072 img_ids,
1073 txt_ids,
1074 batch: self.batch,
1075 img_seq: hidden_states.len() / (self.batch * self.cfg.in_channels),
1076 txt_seq: self.txt_seq,
1077 },
1078 )
1079 }
1080
1081 pub fn forward_noise_dual_native(
1083 &self,
1084 hidden_states: &[f32],
1085 encoder_hidden_states: &[f32],
1086 timestep: &[f32],
1087 timestep_target: &[f32],
1088 guidance: Option<&[f32]>,
1089 img_ids: &[f32],
1090 txt_ids: &[f32],
1091 ) -> Result<Vec<f32>> {
1092 use crate::{Flux2ForwardInput, flux2_transformer_forward};
1093
1094 flux2_transformer_forward(
1095 &self.model,
1096 &self.cfg,
1097 Flux2ForwardInput {
1098 hidden_states,
1099 encoder_hidden_states,
1100 timestep,
1101 timestep_target: Some(timestep_target),
1102 guidance,
1103 img_ids,
1104 txt_ids,
1105 batch: self.batch,
1106 img_seq: hidden_states.len() / (self.batch * self.cfg.in_channels),
1107 txt_seq: self.txt_seq,
1108 },
1109 )
1110 }
1111
1112 pub fn forward_noise_compiled(
1114 &self,
1115 hidden_states: &[f32],
1116 encoder_hidden_states: &[f32],
1117 timestep: &[f32],
1118 guidance: Option<&[f32]>,
1119 img_ids: &[f32],
1120 txt_ids: &[f32],
1121 ) -> Result<Vec<f32>> {
1122 use crate::host_temb;
1123
1124 self.ensure_denoiser_compiled(img_ids, txt_ids)?;
1125 let temb = host_temb(&self.model, &self.cfg, timestep, guidance)?;
1126 let mut guard = self
1127 .denoiser
1128 .lock()
1129 .map_err(|e| anyhow!("denoiser cache lock poisoned: {e}"))?;
1130 let cache = guard
1131 .as_mut()
1132 .ok_or_else(|| anyhow!("denoiser compile cache missing after ensure"))?;
1133 Ok(cache
1134 .compiled
1135 .run(&[
1136 ("hidden", hidden_states),
1137 ("encoder", encoder_hidden_states),
1138 ("temb", temb.as_slice()),
1139 ])
1140 .remove(0))
1141 }
1142
1143 pub fn forward_noise_dual_compiled(
1145 &self,
1146 hidden_states: &[f32],
1147 encoder_hidden_states: &[f32],
1148 timestep: &[f32],
1149 timestep_target: &[f32],
1150 guidance: Option<&[f32]>,
1151 img_ids: &[f32],
1152 txt_ids: &[f32],
1153 ) -> Result<Vec<f32>> {
1154 use crate::host_temb_dual;
1155
1156 self.ensure_denoiser_compiled(img_ids, txt_ids)?;
1157 let temb = host_temb_dual(&self.model, &self.cfg, timestep, timestep_target, guidance)?;
1158 let mut guard = self
1159 .denoiser
1160 .lock()
1161 .map_err(|e| anyhow!("denoiser cache lock poisoned: {e}"))?;
1162 let cache = guard
1163 .as_mut()
1164 .ok_or_else(|| anyhow!("denoiser compile cache missing after ensure"))?;
1165 Ok(cache
1166 .compiled
1167 .run(&[
1168 ("hidden", hidden_states),
1169 ("encoder", encoder_hidden_states),
1170 ("temb", temb.as_slice()),
1171 ])
1172 .remove(0))
1173 }
1174
1175 pub fn forward_cfg(
1178 &self,
1179 hidden_states: &[f32],
1180 pos_encoder: &[f32],
1181 neg_encoder: &[f32],
1182 timestep: &[f32],
1183 guidance: Option<&[f32]>,
1184 img_ids: &[f32],
1185 pos_txt_ids: &[f32],
1186 neg_txt_ids: &[f32],
1187 cfg_scale: f32,
1188 ) -> Result<Flux2Output> {
1189 use crate::cfg_combine;
1190
1191 let pos = self.forward_noise(
1192 hidden_states,
1193 pos_encoder,
1194 timestep,
1195 guidance,
1196 img_ids,
1197 pos_txt_ids,
1198 )?;
1199 if cfg_scale <= 1.0 {
1200 return Ok(Flux2Output {
1201 noise_pred: pos,
1202 img_seq: hidden_states.len() / (self.batch * self.cfg.in_channels),
1203 out_dim: self.cfg.proj_out_dim(),
1204 });
1205 }
1206 let neg = self.forward_noise(
1207 hidden_states,
1208 neg_encoder,
1209 timestep,
1210 guidance,
1211 img_ids,
1212 neg_txt_ids,
1213 )?;
1214 let img_seq = hidden_states.len() / (self.batch * self.cfg.in_channels);
1215 let noise_pred = if self.uses_compiled_denoiser() {
1216 self.cfg_combine_compiled(&pos, &neg, cfg_scale, img_seq)?
1217 } else {
1218 cfg_combine(&pos, &neg, cfg_scale)
1219 };
1220 Ok(Flux2Output {
1221 noise_pred,
1222 img_seq,
1223 out_dim: self.cfg.proj_out_dim(),
1224 })
1225 }
1226
1227 #[allow(clippy::type_complexity)]
1229 pub fn encode_prompt_pair(
1230 &self,
1231 prompt: &str,
1232 negative_prompt: Option<&str>,
1233 ) -> Result<(Vec<f32>, Vec<f32>, Option<Vec<f32>>, Option<Vec<f32>>)> {
1234 let (pos, pos_ids) = self.encode_prompt(prompt)?;
1235 let (neg, neg_ids) = match negative_prompt {
1236 Some(n) => {
1237 let (e, ids) = self.encode_prompt(n)?;
1238 (Some(e), Some(ids))
1239 }
1240 None => (None, None),
1241 };
1242 if self.drop_text_encoder_after_encode {
1243 self.drop_text_encoder_weights()?;
1244 }
1245 Ok((pos, pos_ids, neg, neg_ids))
1246 }
1247
1248 pub fn vae_config(&self) -> Option<&crate::Flux2VaeConfig> {
1249 self.vae_cfg.as_ref()
1250 }
1251
1252 pub fn decode_to_rgb(
1254 &self,
1255 packed_latents: &[f32],
1256 img_ids: &[f32],
1257 latent_h: usize,
1258 latent_w: usize,
1259 ) -> Result<(Vec<u8>, u32, u32)> {
1260 if self.uses_compiled_vae() {
1261 return self.decode_to_rgb_compiled(packed_latents, img_ids, latent_h, latent_w);
1262 }
1263 self.decode_to_rgb_native(packed_latents, img_ids, latent_h, latent_w)
1264 }
1265
1266 pub fn decode_to_rgb_native(
1268 &self,
1269 packed_latents: &[f32],
1270 img_ids: &[f32],
1271 latent_h: usize,
1272 latent_w: usize,
1273 ) -> Result<(Vec<u8>, u32, u32)> {
1274 use crate::{flux2_decode_packed_latents, flux2_rgb_to_u8};
1275
1276 let vae = self
1277 .vae
1278 .as_ref()
1279 .ok_or_else(|| anyhow!("VAE not loaded (place vae/ next to weights)"))?;
1280 let vae_cfg = self
1281 .vae_cfg
1282 .as_ref()
1283 .ok_or_else(|| anyhow!("VAE config missing"))?;
1284 let packed_channels = self.cfg.in_channels;
1285 let img_seq = img_ids.len() / (self.batch * 4);
1286 let rgb = flux2_decode_packed_latents(
1287 vae,
1288 vae_cfg,
1289 packed_latents,
1290 img_ids,
1291 self.batch,
1292 img_seq,
1293 packed_channels,
1294 latent_h,
1295 latent_w,
1296 )?;
1297 let up_stages = vae_cfg.block_out_channels.len().saturating_sub(1);
1298 let scale = 2usize.pow(up_stages as u32 + 1);
1299 let h_px = latent_h * scale;
1300 let w_px = latent_w * scale;
1301 let u8 = flux2_rgb_to_u8(&rgb, self.batch, 3, h_px, w_px);
1302 Ok((u8, h_px as u32, w_px as u32))
1303 }
1304
1305 pub fn decode_to_rgb_compiled(
1307 &self,
1308 packed_latents: &[f32],
1309 img_ids: &[f32],
1310 latent_h: usize,
1311 latent_w: usize,
1312 ) -> Result<(Vec<u8>, u32, u32)> {
1313 use crate::{
1314 denorm_patchified_latents, flux2_rgb_to_u8, unpack_latents_with_ids, unpatchify_latents,
1315 };
1316
1317 let vae_cfg = self
1318 .vae_cfg
1319 .as_ref()
1320 .ok_or_else(|| anyhow!("VAE config missing"))?;
1321 let vae = self.vae.as_ref().ok_or_else(|| anyhow!("VAE not loaded"))?;
1322 let packed_channels = self.cfg.in_channels;
1323 let img_seq = img_ids.len() / (self.batch * 4);
1324 let spatial = unpack_latents_with_ids(
1325 packed_latents,
1326 img_ids,
1327 self.batch,
1328 img_seq,
1329 packed_channels,
1330 latent_h,
1331 latent_w,
1332 )?;
1333 let denorm = denorm_patchified_latents(
1334 &spatial,
1335 &vae.bn_running_mean,
1336 &vae.bn_running_var,
1337 vae_cfg.batch_norm_eps,
1338 );
1339 let mut latents =
1340 unpatchify_latents(&denorm, self.batch, packed_channels, latent_h, latent_w);
1341 if vae_cfg.scaling_factor != 1.0 || vae_cfg.shift_factor != 0.0 {
1342 for v in &mut latents {
1343 *v = *v / vae_cfg.scaling_factor + vae_cfg.shift_factor;
1344 }
1345 }
1346 let h2 = latent_h * 2;
1347 let w2 = latent_w * 2;
1348
1349 self.ensure_vae_compiled(h2, w2)?;
1350 let mut guard = self
1351 .vae_compiled
1352 .lock()
1353 .map_err(|e| anyhow!("vae cache lock poisoned: {e}"))?;
1354 let cache = guard
1355 .as_mut()
1356 .ok_or_else(|| anyhow!("vae compile cache missing"))?;
1357 let rgb = cache
1358 .compiled
1359 .run(&[("latents", latents.as_slice())])
1360 .remove(0);
1361
1362 let up_stages = vae_cfg.block_out_channels.len().saturating_sub(1);
1363 let scale = 2usize.pow(up_stages as u32 + 1);
1364 let h_px = latent_h * scale;
1365 let w_px = latent_w * scale;
1366 let u8 = flux2_rgb_to_u8(&rgb, self.batch, 3, h_px, w_px);
1367 Ok((u8, h_px as u32, w_px as u32))
1368 }
1369
1370 fn ensure_vae_compiled(&self, h: usize, w: usize) -> Result<()> {
1371 use crate::compile_flux2_vae_hir;
1372
1373 let vae = self.vae.as_ref().ok_or_else(|| anyhow!("VAE not loaded"))?;
1374 let vae_cfg = self
1375 .vae_cfg
1376 .as_ref()
1377 .ok_or_else(|| anyhow!("VAE config missing"))?;
1378
1379 let mut guard = self
1380 .vae_compiled
1381 .lock()
1382 .map_err(|e| anyhow!("vae cache lock poisoned: {e}"))?;
1383 let recompile = guard.as_ref().is_none_or(|c| {
1384 c.device != self.device || c.batch != self.batch || c.h != h || c.w != w
1385 });
1386 if recompile {
1387 let aot = self.aot_cache();
1388 let (compiled, _) =
1389 compile_flux2_vae_hir(vae_cfg, vae, self.batch, h, w, self.device, aot.as_ref())?;
1390 *guard = Some(Flux2VaeCache {
1391 compiled,
1392 device: self.device,
1393 batch: self.batch,
1394 h,
1395 w,
1396 });
1397 }
1398 Ok(())
1399 }
1400
1401 fn ensure_vae_encoder_compiled(&self, h: usize, w: usize) -> Result<()> {
1402 use crate::compile_flux2_vae_encoder_hir;
1403
1404 let vae = self.vae.as_ref().ok_or_else(|| anyhow!("VAE not loaded"))?;
1405 let vae_cfg = self
1406 .vae_cfg
1407 .as_ref()
1408 .ok_or_else(|| anyhow!("VAE config missing"))?;
1409
1410 let mut guard = self
1411 .vae_encoder_compiled
1412 .lock()
1413 .map_err(|e| anyhow!("vae encoder cache lock poisoned: {e}"))?;
1414 let recompile = guard.as_ref().is_none_or(|c| {
1415 c.device != self.device || c.batch != self.batch || c.h != h || c.w != w
1416 });
1417 if recompile {
1418 let aot = self.aot_cache();
1419 let (compiled, _) = compile_flux2_vae_encoder_hir(
1420 vae_cfg,
1421 vae,
1422 self.batch,
1423 h,
1424 w,
1425 self.device,
1426 aot.as_ref(),
1427 )?;
1428 *guard = Some(Flux2VaeEncoderCache {
1429 compiled,
1430 device: self.device,
1431 batch: self.batch,
1432 h,
1433 w,
1434 });
1435 }
1436 Ok(())
1437 }
1438}