Skip to main content

rlx_flux2/
runner.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! High-level FLUX.2 runner (denoiser + text encoder + VAE caches).
17
18use anyhow::{Context, Result, anyhow, bail, ensure};
19use rlx_runtime::Device;
20use std::path::PathBuf;
21use std::sync::Mutex;
22
/// Noise prediction from [`Flux2Runner::forward`].
#[derive(Debug, Clone)]
pub struct Flux2Output {
    /// Flat `f32` denoiser output as returned by the forward pass
    /// (presumably laid out as `[batch, img_seq, out_dim]` — TODO confirm).
    pub noise_pred: Vec<f32>,
    /// Image tokens per sample, computed in `forward` as
    /// `hidden_states.len() / (batch * cfg.in_channels)`.
    pub img_seq: usize,
    /// Per-token output width, taken from `cfg.proj_out_dim()`.
    pub out_dim: usize,
}
30
/// Builder for [`Flux2Runner`].
///
/// Everything except the weights path is optional; unset values are resolved (discovered next
/// to the weights or defaulted) in [`Flux2RunnerBuilder::build`].
#[derive(Debug, Clone, Default)]
pub struct Flux2RunnerBuilder {
    /// Denoiser weights location (required by `build`).
    weights: Option<PathBuf>,
    /// Explicit transformer config; takes precedence over every other config source.
    config: Option<crate::Flux2Config>,
    /// Transformer config file (not consulted when the weights resolve to a GGUF file).
    config_path: Option<PathBuf>,
    /// Text encoder directory (default: discovered next to the weights).
    text_encoder_dir: Option<PathBuf>,
    /// Text encoder config (default: `config.json` inside the text encoder directory).
    text_encoder_config_path: Option<PathBuf>,
    /// VAE directory (default: discovered next to the weights).
    vae_dir: Option<PathBuf>,
    /// VAE config (default: `config.json` inside the VAE directory).
    vae_config_path: Option<PathBuf>,
    /// Tokenizer file (default: looked up relative to the weights).
    tokenizer_path: Option<PathBuf>,
    /// Batch size (default 1).
    batch: Option<usize>,
    /// Image sequence length (default 256).
    img_seq: Option<usize>,
    /// Text sequence length that prompts are padded to (default 128).
    txt_seq: Option<usize>,
    /// Target device (default CPU).
    device: Option<Device>,
    /// Use HIR denoiser on CPU too (default: native on CPU, compiled on GPU backends).
    compiled_denoiser: bool,
    /// Use HIR text encoder on CPU too (default: native CPU; compiled on Metal/MLX only).
    compiled_text_encoder: bool,
    /// Use HIR VAE decoder on CPU too (default: native on CPU, compiled on GPU backends).
    compiled_vae: bool,
    /// Load NVFP4 packed linears from weights (`None` = auto-detect U8+F8 pairs).
    nvfp4: Option<bool>,
    /// Skip loading Qwen3 text encoder weights (saves ~8GB RAM; img2img/edit with empty prompt).
    skip_text_encoder: bool,
    /// Persist compiled LIR to disk (speeds up repeat runs).
    aot_cache_dir: Option<PathBuf>,
    /// After prompt encode, drop TE weights + compiled cache to free RAM before denoiser.
    drop_text_encoder_after_encode: Option<bool>,
    /// Optional LoRA safetensors file or directory (merged into base weights before extract).
    lora_path: Option<PathBuf>,
    /// LoRA merge strength; only used when `lora_path` is set (`Default` yields 0.0).
    lora_scale: f32,
    /// Compile denoiser via tier-0 [`crate::Flux2Flow`] API (AOT key suffix `_flow`).
    use_flow_api: bool,
    /// Second timestep embedder for flow-map dual-time (auto on when LoRA is set).
    dual_time_embedder: bool,
}
68
69impl Flux2RunnerBuilder {
70    pub fn weights<P: Into<PathBuf>>(mut self, p: P) -> Self {
71        self.weights = Some(p.into());
72        self
73    }
74    pub fn config(mut self, cfg: crate::Flux2Config) -> Self {
75        self.config = Some(cfg);
76        self
77    }
78    pub fn config_path<P: Into<PathBuf>>(mut self, p: P) -> Self {
79        self.config_path = Some(p.into());
80        self
81    }
82    pub fn batch(mut self, n: usize) -> Self {
83        self.batch = Some(n);
84        self
85    }
86    pub fn img_seq(mut self, n: usize) -> Self {
87        self.img_seq = Some(n);
88        self
89    }
90    pub fn txt_seq(mut self, n: usize) -> Self {
91        self.txt_seq = Some(n);
92        self
93    }
94    pub fn device(mut self, d: Device) -> Self {
95        self.device = Some(d);
96        self
97    }
98
99    /// Run the denoiser via compiled HIR on CPU as well (for parity / bench).
100    pub fn compiled_denoiser(mut self, yes: bool) -> Self {
101        self.compiled_denoiser = yes;
102        self
103    }
104
105    /// Run the text encoder via compiled HIR on CPU as well (for parity / bench).
106    pub fn compiled_text_encoder(mut self, yes: bool) -> Self {
107        self.compiled_text_encoder = yes;
108        self
109    }
110
111    /// Run the VAE decoder via compiled HIR on CPU as well (for parity / bench).
112    pub fn compiled_vae(mut self, yes: bool) -> Self {
113        self.compiled_vae = yes;
114        self
115    }
116
117    pub fn text_encoder_dir<P: Into<PathBuf>>(mut self, path: P) -> Self {
118        self.text_encoder_dir = Some(path.into());
119        self
120    }
121
122    pub fn text_encoder_config_path<P: Into<PathBuf>>(mut self, path: P) -> Self {
123        self.text_encoder_config_path = Some(path.into());
124        self
125    }
126
127    pub fn tokenizer_path<P: Into<PathBuf>>(mut self, path: P) -> Self {
128        self.tokenizer_path = Some(path.into());
129        self
130    }
131
132    pub fn vae_dir<P: Into<PathBuf>>(mut self, path: P) -> Self {
133        self.vae_dir = Some(path.into());
134        self
135    }
136
137    pub fn vae_config_path<P: Into<PathBuf>>(mut self, path: P) -> Self {
138        self.vae_config_path = Some(path.into());
139        self
140    }
141
142    /// Force NVFP4 packed weights on/off (`None` = auto-detect in safetensors).
143    pub fn nvfp4(mut self, enable: bool) -> Self {
144        self.nvfp4 = Some(enable);
145        self
146    }
147
148    /// Do not load `text_encoder/` even when present (saves RAM; use with empty prompt).
149    pub fn skip_text_encoder(mut self, yes: bool) -> Self {
150        self.skip_text_encoder = yes;
151        self
152    }
153
154    /// Directory for AOT compile cache (denoiser / TE / VAE / CFG graphs).
155    pub fn aot_cache_dir<P: Into<PathBuf>>(mut self, path: P) -> Self {
156        self.aot_cache_dir = Some(path.into());
157        self
158    }
159
160    /// Drop text-encoder weights after first encode (default: true on GPU compiled paths).
161    pub fn drop_text_encoder_after_encode(mut self, yes: bool) -> Self {
162        self.drop_text_encoder_after_encode = Some(yes);
163        self
164    }
165
166    /// Merge LoRA adapter weights from `path` with strength `scale` before loading the denoiser.
167    pub fn lora<P: Into<PathBuf>>(mut self, path: P, scale: f32) -> Self {
168        self.lora_path = Some(path.into());
169        self.lora_scale = scale;
170        self.dual_time_embedder = true;
171        self
172    }
173
174    /// Use separate (or cloned) timestep embedder weights for dual-time flow-map forwards.
175    pub fn dual_time_embedder(mut self, yes: bool) -> Self {
176        self.dual_time_embedder = yes;
177        self
178    }
179
180    /// Build the denoiser via [`crate::Flux2Flow`] instead of direct HIR builder.
181    pub fn use_flow_api(mut self, yes: bool) -> Self {
182        self.use_flow_api = yes;
183        self
184    }
185
186    /// Cache key for [`crate::Flux2SessionCache`].
187    pub fn session_key(&self) -> Option<crate::Flux2SessionKey> {
188        self.weights.as_ref().map(|w| crate::Flux2SessionKey {
189            weights: w.clone(),
190            device: self.device.unwrap_or(Device::Cpu),
191            config_path: self.config_path.clone(),
192            lora_path: self.lora_path.clone(),
193            lora_scale_bits: self.lora_scale.to_bits(),
194            nvfp4: self.nvfp4,
195        })
196    }
197
198    pub fn build(self) -> Result<Flux2Runner> {
199        use crate::Flux2VaeConfig;
200        use crate::{
201            ExtractFlux2Opts, Flux2Config, extract_flux2_weights_with_opts,
202            load_flux2_nvfp4_from_file, load_flux2_vae_weights, load_flux2_weight_map,
203            load_text_encoder_weights, load_typed_linears_from_file, prepare_weight_map,
204            resolve_text_encoder_dir, resolve_transformer_config, resolve_vae_dir,
205            safetensors_has_nvfp4,
206        };
207        use rlx_core::gguf_support::{ResolveWeightsOptions, resolve_weights_file_with_options};
208        use rlx_gguf::GgufFile;
209        use rlx_qwen3::Qwen3Config;
210
211        let weights_path = self
212            .weights
213            .as_ref()
214            .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?;
215        let weights_file = resolve_weights_file_with_options(
216            weights_path,
217            &ResolveWeightsOptions {
218                prefer_gguf_substring: Some("Q4_K_M"),
219                ..Default::default()
220            },
221        )?;
222        let is_gguf = weights_file.extension().and_then(|s| s.to_str()) == Some("gguf");
223
224        let cfg = match (self.config, self.config_path.clone()) {
225            (Some(c), _) => c,
226            (_, Some(p)) if !is_gguf => Flux2Config::from_file(&p)?,
227            _ if is_gguf => {
228                let raw = GgufFile::from_path(&weights_file)
229                    .with_context(|| format!("opening GGUF {weights_file:?}"))?;
230                Flux2Config::from_gguf(&raw)?
231            }
232            _ => {
233                if let Some(p) = resolve_transformer_config(&weights_file, None) {
234                    Flux2Config::from_file(&p)?
235                } else {
236                    Flux2Config::flux2_dev()
237                }
238            }
239        };
240        let path = weights_file
241            .to_str()
242            .ok_or_else(|| anyhow!("non-utf8 weights path"))?;
243
244        let device = self.device.unwrap_or(Device::Cpu);
245        use crate::{
246            assert_flux2_device_available, flux2_prefers_compiled_hir, flux2_prefers_compiled_te,
247        };
248        rlx_core::validate_standard_device("flux2", device)?;
249        assert_flux2_device_available(device)?;
250        if self.compiled_text_encoder && !flux2_prefers_compiled_te(device) {
251            anyhow::bail!(
252                "compiled text encoder on {device:?} can take hours and exhaust VRAM; \
253                 use native CPU TE (default on CUDA/ROCm/wgpu/Vulkan)"
254            );
255        }
256        let compiled_denoiser = self.compiled_denoiser || flux2_prefers_compiled_hir(device);
257        let compiled_text_encoder = self.compiled_text_encoder || flux2_prefers_compiled_te(device);
258        let compiled_vae = self.compiled_vae || flux2_prefers_compiled_hir(device);
259
260        let use_nvfp4 = if is_gguf {
261            false
262        } else {
263            match self.nvfp4 {
264                Some(yes) => yes,
265                None => safetensors_has_nvfp4(&weights_file).unwrap_or(false),
266            }
267        };
268        let use_gguf_packed = is_gguf
269            && match self.nvfp4 {
270                Some(false) => false,
271                _ => crate::packed_gguf::gguf_has_packed_linears(&weights_file).unwrap_or(false),
272            };
273        let packed = if use_nvfp4 {
274            Some(load_flux2_nvfp4_from_file(&weights_file)?)
275        } else {
276            None
277        };
278
279        let mut exclude_f32 = std::collections::HashSet::new();
280        if let Some(p) = &packed {
281            exclude_f32.extend(p.exclude_f32_keys());
282        }
283        let typed_linears = if is_gguf {
284            None
285        } else if compiled_denoiser && self.lora_path.is_none() {
286            Some(load_typed_linears_from_file(&weights_file, &exclude_f32)?)
287        } else {
288            if self.lora_path.is_some() && compiled_denoiser {
289                eprintln!(
290                    "[flux2] LoRA active — using merged f32 weights (typed BF16 linears disabled)"
291                );
292            }
293            None
294        };
295        if let Some(t) = &typed_linears {
296            exclude_f32.extend(t.skip_keys());
297        }
298
299        let (mut wm, gguf_packed) = if use_gguf_packed {
300            if self.lora_path.is_some() {
301                bail!("LoRA merge is not supported on GGUF denoiser weights; use safetensors");
302            }
303            eprintln!("[flux2] loading denoiser GGUF with packed DequantMatMul {weights_file:?}");
304            {
305                let (wm, g) = crate::packed_gguf::load_flux2_from_gguf(&weights_file)?;
306                (wm, Some(g))
307            }
308        } else if is_gguf {
309            if self.lora_path.is_some() {
310                bail!("LoRA merge is not supported on GGUF denoiser weights; use safetensors");
311            }
312            eprintln!("[flux2] loading denoiser from GGUF (F32 drain) {weights_file:?}");
313            (load_flux2_weight_map(&weights_file)?, None)
314        } else {
315            use rlx_core::weight_map::WeightMap;
316            (WeightMap::from_file_excluding(path, &exclude_f32)?, None)
317        };
318        let packed = match (packed, gguf_packed) {
319            (Some(nv), None) => Some(nv),
320            (None, Some(g)) => Some(g),
321            (None, None) => None,
322            (Some(_), Some(_)) => unreachable!("nvfp4 and gguf packed are mutually exclusive"),
323        };
324        if let Some(lora_path) = &self.lora_path {
325            let n = if lora_path.is_dir() {
326                crate::load_and_apply_flux2_lora_dir(&mut wm, lora_path, self.lora_scale)?
327            } else {
328                crate::load_and_apply_flux2_lora(&mut wm, lora_path, self.lora_scale)?
329            };
330            eprintln!(
331                "[flux2] merged {n} LoRA layers from {:?} (scale={})",
332                lora_path, self.lora_scale
333            );
334        }
335        let extract_opts = ExtractFlux2Opts {
336            typed_linears: typed_linears.as_ref(),
337            packed_linears: packed.as_ref(),
338            dual_time_embedder: self.dual_time_embedder || self.lora_path.is_some(),
339        };
340        if extract_opts.dual_time_embedder {
341            eprintln!("[flux2] dual-time timestep embedder enabled (flow-map / Diamond Maps)");
342        }
343        let model = extract_flux2_weights_with_opts(prepare_weight_map(wm), &cfg, extract_opts)?;
344
345        let te_dir = if self.skip_text_encoder {
346            None
347        } else {
348            self.text_encoder_dir
349                .or_else(|| resolve_text_encoder_dir(&weights_file))
350        };
351        let (text_encoder, text_encoder_cfg) = if let Some(dir) = te_dir {
352            let te_cfg_path = self
353                .text_encoder_config_path
354                .unwrap_or_else(|| dir.join("config.json"));
355            let te_cfg = Qwen3Config::from_file(&te_cfg_path)?;
356            let te = load_text_encoder_weights(&dir, &te_cfg)?;
357            (Some(te), Some(te_cfg))
358        } else {
359            (None, None)
360        };
361
362        let vae_dir = self.vae_dir.or_else(|| resolve_vae_dir(&weights_file));
363        let (vae, vae_cfg) = if let Some(dir) = vae_dir {
364            let vae_cfg_path = self
365                .vae_config_path
366                .unwrap_or_else(|| dir.join("config.json"));
367            let vae_cfg = Flux2VaeConfig::from_file(&vae_cfg_path)?;
368            let vae = load_flux2_vae_weights(&dir, &vae_cfg)?;
369            (Some(vae), Some(vae_cfg))
370        } else {
371            (None, None)
372        };
373
374        let drop_text_encoder_after_encode = self
375            .drop_text_encoder_after_encode
376            .unwrap_or(!self.skip_text_encoder && compiled_denoiser);
377
378        Ok(Flux2Runner {
379            model,
380            cfg,
381            batch: self.batch.unwrap_or(1),
382            img_seq: self.img_seq.unwrap_or(256),
383            txt_seq: self.txt_seq.unwrap_or(128),
384            device,
385            compiled_denoiser,
386            compiled_text_encoder,
387            compiled_vae,
388            packed,
389            typed_linears,
390            aot_cache_dir: self.aot_cache_dir,
391            drop_text_encoder_after_encode,
392            use_flow_api: self.use_flow_api,
393            text_encoder: Mutex::new(text_encoder),
394            text_encoder_cfg: Mutex::new(text_encoder_cfg),
395            vae,
396            vae_cfg,
397            tokenizer_path: self.tokenizer_path,
398            model_root: weights_file,
399            denoiser: Mutex::new(None),
400            text_encoder_compiled: Mutex::new(None),
401            vae_compiled: Mutex::new(None),
402            vae_encoder_compiled: Mutex::new(None),
403            cfg_compiled: Mutex::new(None),
404        })
405    }
406}
407
/// Compiled denoiser graph plus the inputs it was compiled for.
///
/// Every field after `compiled` is a cache key. Position ids are stored in full because they
/// are baked into the graph, so any change to them forces a recompile.
struct Flux2DenoiserCache {
    compiled: rlx_runtime::CompiledGraph,
    device: Device,
    batch: usize,
    img_seq: usize,
    txt_seq: usize,
    img_ids: Vec<f32>,
    txt_ids: Vec<f32>,
}
417
/// Compiled text-encoder graph, keyed by device, batch and text sequence length.
struct Flux2TextEncoderCache {
    compiled: rlx_runtime::CompiledGraph,
    device: Device,
    batch: usize,
    txt_seq: usize,
}
424
/// Compiled VAE graph, keyed by device, batch and spatial size `h` x `w`.
///
/// NOTE(review): it is unclear from this view whether `h`/`w` are latent or pixel
/// dimensions — confirm at the construction site.
struct Flux2VaeCache {
    compiled: rlx_runtime::CompiledGraph,
    device: Device,
    batch: usize,
    h: usize,
    w: usize,
}
432
/// Compiled VAE encoder graph, keyed by device, batch and spatial size `h` x `w`.
///
/// The compiled graph is fed an `rgb` input by `vae_encode_rgb_compiled`; `h`/`w` are
/// presumably the pixel height/width passed there — TODO confirm.
struct Flux2VaeEncoderCache {
    compiled: rlx_runtime::CompiledGraph,
    device: Device,
    batch: usize,
    h: usize,
    w: usize,
}
440
/// Compiled classifier-free-guidance combine graph.
///
/// The graph takes `pos`, `neg` and `guidance_scale` inputs. It is keyed by device, batch,
/// image sequence length and per-token output width.
struct Flux2CfgCache {
    compiled: rlx_runtime::CompiledGraph,
    device: Device,
    batch: usize,
    img_seq: usize,
    out_dim: usize,
}
448
/// FLUX.2 denoiser runner — native CPU or compiled HIR on any [`Device`].
///
/// The runner owns the denoiser weights and the optional text encoder / VAE. Compiled graphs
/// are cached lazily behind `Mutex`es so that all public methods can take `&self`.
pub struct Flux2Runner {
    /// Denoiser weights produced by `extract_flux2_weights_with_opts` in the builder.
    model: crate::Flux2Weights,
    /// Transformer configuration.
    cfg: crate::Flux2Config,
    /// Batch size shared by every graph and encode.
    batch: usize,
    /// Image sequence length from the builder (default 256). The compiled denoiser derives its
    /// own value from the `img_ids` it is given.
    img_seq: usize,
    /// Text sequence length that prompts are padded to.
    txt_seq: usize,
    /// Device targeted by the compiled graphs.
    device: Device,
    /// When true, use HIR even on CPU; otherwise CPU uses native forward.
    compiled_denoiser: bool,
    /// When true, use HIR text encoder even on CPU.
    compiled_text_encoder: bool,
    /// When true, VAE encode goes through the compiled graph.
    compiled_vae: bool,
    /// NVFP4 or packed-GGUF linears (at most one source is populated in the builder).
    packed: Option<crate::Flux2PackedParams>,
    /// Typed linears; loaded only for the compiled denoiser without a LoRA.
    typed_linears: Option<crate::TypedLinearStore>,
    /// Optional on-disk AOT compile cache directory.
    aot_cache_dir: Option<PathBuf>,
    /// Whether text-encoder weights should be dropped after the first encode
    /// (consumed outside this view — NOTE(review): confirm the caller honors it).
    drop_text_encoder_after_encode: bool,
    /// Compile the denoiser through the `Flux2Flow` API instead of the direct builder.
    use_flow_api: bool,
    /// Text encoder weights; a `Mutex` so `drop_text_encoder_weights` can free them via `&self`.
    text_encoder: Mutex<Option<crate::Flux2TextEncoderWeights>>,
    /// Text encoder config; cleared together with the weights.
    text_encoder_cfg: Mutex<Option<rlx_qwen3::Qwen3Config>>,
    /// Optional VAE weights.
    vae: Option<crate::Flux2VaeWeights>,
    /// Optional VAE config.
    vae_cfg: Option<crate::Flux2VaeConfig>,
    /// Explicit tokenizer path, if one was given to the builder.
    tokenizer_path: Option<PathBuf>,
    /// Resolved weights file; used as the anchor when searching for the tokenizer.
    model_root: PathBuf,
    /// Lazily compiled denoiser graph.
    denoiser: Mutex<Option<Flux2DenoiserCache>>,
    /// Lazily compiled text-encoder graph.
    text_encoder_compiled: Mutex<Option<Flux2TextEncoderCache>>,
    /// Lazily compiled VAE graph.
    vae_compiled: Mutex<Option<Flux2VaeCache>>,
    /// Lazily compiled VAE encoder graph.
    vae_encoder_compiled: Mutex<Option<Flux2VaeEncoderCache>>,
    /// Lazily compiled CFG-combine graph.
    cfg_compiled: Mutex<Option<Flux2CfgCache>>,
}
479
480impl Flux2Runner {
481    pub fn builder() -> Flux2RunnerBuilder {
482        Flux2RunnerBuilder::default()
483    }
484
485    fn aot_cache(&self) -> Option<rlx_runtime::AotCache> {
486        self.aot_cache_dir
487            .as_ref()
488            .map(|p| rlx_runtime::AotCache::new(p.clone()))
489    }
490
491    pub fn drop_text_encoder_weights(&self) -> Result<()> {
492        if let Ok(mut te) = self.text_encoder.lock() {
493            if te.is_some() {
494                eprintln!("[flux2] dropping text encoder weights (~8GB RAM)");
495                *te = None;
496            }
497        }
498        if let Ok(mut cfg) = self.text_encoder_cfg.lock() {
499            *cfg = None;
500        }
501        if let Ok(mut cache) = self.text_encoder_compiled.lock() {
502            *cache = None;
503        }
504        Ok(())
505    }
506    pub fn config(&self) -> &crate::Flux2Config {
507        &self.cfg
508    }
509    pub fn device(&self) -> Device {
510        self.device
511    }
512
513    pub fn batch(&self) -> usize {
514        self.batch
515    }
516
517    pub fn img_seq(&self) -> usize {
518        self.img_seq
519    }
520
521    pub fn txt_seq(&self) -> usize {
522        self.txt_seq
523    }
524
525    pub fn uses_nvfp4(&self) -> bool {
526        self.packed.is_some()
527    }
528
529    pub fn has_text_encoder(&self) -> bool {
530        self.text_encoder
531            .lock()
532            .map(|g| g.is_some())
533            .unwrap_or(false)
534    }
535
536    pub fn has_vae(&self) -> bool {
537        self.vae.is_some()
538    }
539
540    /// True when denoiser forwards use compiled HIR ([`Self::device`]).
541    pub fn uses_compiled_denoiser(&self) -> bool {
542        self.compiled_denoiser
543    }
544
545    /// True when text encoding uses compiled HIR on [`Self::device`].
546    pub fn uses_compiled_text_encoder(&self) -> bool {
547        self.compiled_text_encoder
548    }
549
550    pub fn uses_compiled_vae(&self) -> bool {
551        self.compiled_vae
552    }
553
554    /// Pre-compile the denoiser HIR for the given position ids (RoPE tables are baked in).
555    pub fn warmup_denoiser(&self, img_ids: &[f32], txt_ids: &[f32]) -> Result<()> {
556        if self.uses_compiled_denoiser() {
557            self.ensure_denoiser_compiled(img_ids, txt_ids)?;
558        }
559        Ok(())
560    }
561
562    fn ensure_denoiser_compiled(&self, img_ids: &[f32], txt_ids: &[f32]) -> Result<()> {
563        use crate::{compile_flux2_forward, compile_flux2_forward_via_flow};
564
565        let mut guard = self
566            .denoiser
567            .lock()
568            .map_err(|e| anyhow!("denoiser cache lock poisoned: {e}"))?;
569        let img_seq = img_ids.len() / (self.batch * 4);
570        let recompile = guard.as_ref().is_none_or(|c| {
571            c.device != self.device
572                || c.batch != self.batch
573                || c.img_seq != img_seq
574                || c.txt_seq != self.txt_seq
575                || c.img_ids != img_ids
576                || c.txt_ids != txt_ids
577        });
578        if recompile {
579            eprintln!(
580                "[flux2] compiling denoiser HIR on {:?} (img_seq={img_seq}, txt_seq={}, flow={})…",
581                self.device, self.txt_seq, self.use_flow_api
582            );
583            let aot = self.aot_cache();
584            let (compiled, _) = if self.use_flow_api {
585                compile_flux2_forward_via_flow(
586                    &self.cfg,
587                    &self.model,
588                    self.batch,
589                    img_seq,
590                    self.txt_seq,
591                    img_ids,
592                    txt_ids,
593                    self.device,
594                    self.packed.as_ref(),
595                    self.typed_linears.as_ref(),
596                    aot.as_ref(),
597                )?
598            } else {
599                compile_flux2_forward(
600                    &self.cfg,
601                    &self.model,
602                    self.batch,
603                    img_seq,
604                    self.txt_seq,
605                    img_ids,
606                    txt_ids,
607                    self.device,
608                    self.packed.as_ref(),
609                    self.typed_linears.as_ref(),
610                    aot.as_ref(),
611                )?
612            };
613            *guard = Some(Flux2DenoiserCache {
614                compiled,
615                device: self.device,
616                batch: self.batch,
617                img_seq,
618                txt_seq: self.txt_seq,
619                img_ids: img_ids.to_vec(),
620                txt_ids: txt_ids.to_vec(),
621            });
622        }
623        Ok(())
624    }
625
626    /// Encode a text prompt into FLUX.2 `encoder_hidden_states` and `txt_ids`.
627    ///
628    /// Uses compiled HIR on Metal / MLX when [`Self::uses_compiled_text_encoder`];
629    /// native CPU on CUDA / ROCm / wgpu / Vulkan and CPU otherwise.
630    pub fn encode_prompt(&self, prompt: &str) -> Result<(Vec<f32>, Vec<f32>)> {
631        if self.uses_compiled_text_encoder() {
632            return self.encode_prompt_compiled(prompt);
633        }
634        eprintln!("[flux2] text encoder: native CPU forward");
635        self.encode_prompt_native(prompt)
636    }
637
638    /// Native CPU text encoder (no IR compile).
639    pub fn encode_prompt_native(&self, prompt: &str) -> Result<(Vec<f32>, Vec<f32>)> {
640        use crate::{
641            DEFAULT_TEXT_ENCODER_LAYERS, encode_flux2_prompt, encode_prompt_padded,
642            resolve_tokenizer_path,
643        };
644
645        let tok_path = resolve_tokenizer_path(&self.model_root, self.tokenizer_path.as_deref())
646            .ok_or_else(|| {
647                anyhow!(
648                    "no tokenizer found near {:?}; pass .tokenizer_path(...)",
649                    self.model_root
650                )
651            })?;
652        let input_ids = encode_prompt_padded(&tok_path, prompt, self.txt_seq)?;
653        let (out, txt_ids) = {
654            let te_guard = self
655                .text_encoder
656                .lock()
657                .map_err(|e| anyhow!("text encoder lock poisoned: {e}"))?;
658            let te = te_guard.as_ref().ok_or_else(|| {
659                anyhow!("text encoder not loaded (pass .text_encoder_dir(...) on build)")
660            })?;
661            let cfg_guard = self
662                .text_encoder_cfg
663                .lock()
664                .map_err(|e| anyhow!("text encoder cfg lock poisoned: {e}"))?;
665            let te_cfg = cfg_guard
666                .as_ref()
667                .ok_or_else(|| anyhow!("text encoder config missing"))?;
668            encode_flux2_prompt(
669                te,
670                te_cfg,
671                &input_ids,
672                self.batch,
673                self.txt_seq,
674                DEFAULT_TEXT_ENCODER_LAYERS,
675            )?
676        };
677        ensure!(
678            out.joint_dim == self.cfg.joint_attention_dim,
679            "text encoder joint_dim {} != transformer joint_attention_dim {}",
680            out.joint_dim,
681            self.cfg.joint_attention_dim
682        );
683        Ok((out.prompt_embeds, txt_ids))
684    }
685
686    fn ensure_text_encoder_compiled(&self) -> Result<()> {
687        use crate::{DEFAULT_TEXT_ENCODER_LAYERS, compile_flux2_text_encoder_hir};
688
689        let te = {
690            let guard = self
691                .text_encoder
692                .lock()
693                .map_err(|e| anyhow!("text encoder lock poisoned: {e}"))?;
694            guard
695                .as_ref()
696                .ok_or_else(|| anyhow!("text encoder not loaded"))?
697                .clone()
698        };
699        let te_cfg = {
700            let guard = self
701                .text_encoder_cfg
702                .lock()
703                .map_err(|e| anyhow!("text encoder cfg lock poisoned: {e}"))?;
704            guard
705                .as_ref()
706                .ok_or_else(|| anyhow!("text encoder config missing"))?
707                .clone()
708        };
709
710        let mut guard = self
711            .text_encoder_compiled
712            .lock()
713            .map_err(|e| anyhow!("text encoder cache lock poisoned: {e}"))?;
714        let recompile = guard.as_ref().is_none_or(|c| {
715            c.device != self.device || c.batch != self.batch || c.txt_seq != self.txt_seq
716        });
717        if recompile {
718            eprintln!(
719                "[flux2] compiling text encoder HIR on {:?} (txt_seq={})…",
720                self.device, self.txt_seq
721            );
722            let aot = self.aot_cache();
723            let (compiled, _) = compile_flux2_text_encoder_hir(
724                &te_cfg,
725                &te,
726                self.batch,
727                self.txt_seq,
728                DEFAULT_TEXT_ENCODER_LAYERS,
729                self.device,
730                aot.as_ref(),
731            )?;
732            *guard = Some(Flux2TextEncoderCache {
733                compiled,
734                device: self.device,
735                batch: self.batch,
736                txt_seq: self.txt_seq,
737            });
738        }
739        Ok(())
740    }
741
742    fn ensure_cfg_compiled(&self, img_seq: usize) -> Result<()> {
743        use crate::compile_flux2_cfg_combine;
744
745        let out_dim = self.cfg.proj_out_dim();
746        let mut guard = self
747            .cfg_compiled
748            .lock()
749            .map_err(|e| anyhow!("cfg cache lock poisoned: {e}"))?;
750        let recompile = guard.as_ref().is_none_or(|c| {
751            c.device != self.device
752                || c.batch != self.batch
753                || c.img_seq != img_seq
754                || c.out_dim != out_dim
755        });
756        if recompile {
757            let aot = self.aot_cache();
758            let compiled =
759                compile_flux2_cfg_combine(self.batch, img_seq, out_dim, self.device, aot.as_ref())?;
760            *guard = Some(Flux2CfgCache {
761                compiled,
762                device: self.device,
763                batch: self.batch,
764                img_seq,
765                out_dim,
766            });
767        }
768        Ok(())
769    }
770
771    fn cfg_combine_compiled(
772        &self,
773        pos: &[f32],
774        neg: &[f32],
775        scale: f32,
776        img_seq: usize,
777    ) -> Result<Vec<f32>> {
778        self.ensure_cfg_compiled(img_seq)?;
779        let mut guard = self
780            .cfg_compiled
781            .lock()
782            .map_err(|e| anyhow!("cfg cache lock poisoned: {e}"))?;
783        let cache = guard
784            .as_mut()
785            .ok_or_else(|| anyhow!("cfg compile cache missing"))?;
786        Ok(cache
787            .compiled
788            .run(&[("pos", pos), ("neg", neg), ("guidance_scale", &[scale])])
789            .remove(0))
790    }
791
792    /// Encode via compiled text-encoder HIR on [`Self::device`].
793    pub fn encode_prompt_compiled(&self, prompt: &str) -> Result<(Vec<f32>, Vec<f32>)> {
794        use crate::{
795            DEFAULT_TEXT_ENCODER_LAYERS, encode_prompt_padded, prepare_text_ids,
796            resolve_tokenizer_path,
797        };
798
799        let te_cfg = self
800            .text_encoder_cfg
801            .lock()
802            .map_err(|e| anyhow!("text encoder cfg lock poisoned: {e}"))?;
803        let te_cfg = te_cfg
804            .as_ref()
805            .ok_or_else(|| anyhow!("text encoder config missing"))?;
806
807        let tok_path = resolve_tokenizer_path(&self.model_root, self.tokenizer_path.as_deref())
808            .ok_or_else(|| anyhow!("no tokenizer found near {:?}", self.model_root))?;
809        let input_ids = encode_prompt_padded(&tok_path, prompt, self.txt_seq)?;
810        let ids_f32: Vec<f32> = input_ids.iter().map(|&x| x as f32).collect();
811
812        self.ensure_text_encoder_compiled()?;
813        let mut guard = self
814            .text_encoder_compiled
815            .lock()
816            .map_err(|e| anyhow!("text encoder cache lock poisoned: {e}"))?;
817        let cache = guard
818            .as_mut()
819            .ok_or_else(|| anyhow!("text encoder compile cache missing"))?;
820        let embeds = cache
821            .compiled
822            .run(&[("input_ids", ids_f32.as_slice())])
823            .remove(0);
824        let joint = te_cfg.hidden_size * DEFAULT_TEXT_ENCODER_LAYERS.len();
825        ensure!(
826            joint == self.cfg.joint_attention_dim,
827            "text encoder joint_dim {joint} != transformer {}",
828            self.cfg.joint_attention_dim
829        );
830        let txt_ids = prepare_text_ids(self.batch, self.txt_seq);
831        Ok((embeds, txt_ids))
832    }
833
834    /// One denoiser forward: latents + text context → noise prediction.
835    pub fn forward(
836        &self,
837        hidden_states: &[f32],
838        encoder_hidden_states: &[f32],
839        timestep: &[f32],
840        guidance: Option<&[f32]>,
841        img_ids: &[f32],
842        txt_ids: &[f32],
843    ) -> Result<Flux2Output> {
844        let noise_pred = self.forward_noise(
845            hidden_states,
846            encoder_hidden_states,
847            timestep,
848            guidance,
849            img_ids,
850            txt_ids,
851        )?;
852        Ok(Flux2Output {
853            noise_pred,
854            img_seq: hidden_states.len() / (self.batch * self.cfg.in_channels),
855            out_dim: self.cfg.proj_out_dim(),
856        })
857    }
858
859    /// VAE encode RGB planar `[-1,1]` NCHW → latent (compiled on GPU when enabled).
860    pub fn vae_encode_rgb(&self, rgb: &[f32], pixel_h: usize, pixel_w: usize) -> Result<Vec<f32>> {
861        if self.uses_compiled_vae() {
862            return self.vae_encode_rgb_compiled(rgb, pixel_h, pixel_w);
863        }
864        let vae = self.vae.as_ref().ok_or_else(|| anyhow!("VAE not loaded"))?;
865        let vae_cfg = self
866            .vae_cfg
867            .as_ref()
868            .ok_or_else(|| anyhow!("VAE config missing"))?;
869        crate::flux2_vae_encode(vae, vae_cfg, rgb, self.batch, pixel_h, pixel_w)
870    }
871
872    fn vae_encode_rgb_compiled(
873        &self,
874        rgb: &[f32],
875        pixel_h: usize,
876        pixel_w: usize,
877    ) -> Result<Vec<f32>> {
878        let vae_cfg = self
879            .vae_cfg
880            .as_ref()
881            .ok_or_else(|| anyhow!("VAE config missing"))?;
882
883        self.ensure_vae_encoder_compiled(pixel_h, pixel_w)?;
884        let mut guard = self
885            .vae_encoder_compiled
886            .lock()
887            .map_err(|e| anyhow!("vae encoder cache lock poisoned: {e}"))?;
888        let cache = guard
889            .as_mut()
890            .ok_or_else(|| anyhow!("vae encoder compile cache missing"))?;
891        let mut latent = cache.compiled.run(&[("rgb", rgb)]).remove(0);
892        if vae_cfg.scaling_factor != 1.0 || vae_cfg.shift_factor != 0.0 {
893            for v in &mut latent {
894                *v = (*v - vae_cfg.shift_factor) * vae_cfg.scaling_factor;
895            }
896        }
897        Ok(latent)
898    }
899
900    /// Encode planar RGB `[-1,1]` NCHW to packed transformer latents.
901    pub fn encode_rgb_to_packed(
902        &self,
903        rgb: &[f32],
904        pixel_h: usize,
905        pixel_w: usize,
906        latent_h: usize,
907        latent_w: usize,
908        eff_h: usize,
909        eff_w: usize,
910    ) -> Result<Vec<f32>> {
911        use crate::pack_encoded_latents;
912
913        let vae = self
914            .vae
915            .as_ref()
916            .ok_or_else(|| anyhow!("VAE not loaded (required for img2img / edit)"))?;
917        let vae_cfg = self
918            .vae_cfg
919            .as_ref()
920            .ok_or_else(|| anyhow!("VAE config missing"))?;
921        let stride = vae_cfg.encode_spatial_stride();
922        let enc_h = pixel_h / stride;
923        let enc_w = pixel_w / stride;
924        ensure!(
925            enc_h > 0 && enc_w > 0,
926            "encoded spatial dims too small for {pixel_h}x{pixel_w}"
927        );
928        let encoded = self.vae_encode_rgb(rgb, pixel_h, pixel_w)?;
929        ensure!(
930            encoded.len() == self.batch * vae_cfg.latent_channels * enc_h * enc_w,
931            "encoded len {} != expected {}",
932            encoded.len(),
933            self.batch * vae_cfg.latent_channels * enc_h * enc_w
934        );
935        pack_encoded_latents(
936            vae, vae_cfg, encoded, self.batch, enc_h, enc_w, eff_h, eff_w, latent_h, latent_w,
937        )
938    }
939
940    pub fn has_vae_encoder(&self) -> bool {
941        self.vae.is_some()
942    }
943
944    /// img2img: encode source RGB and blend with noise at the strength-derived sigma.
945    pub fn prepare_img2img_packed(
946        &self,
947        rgb: &[f32],
948        pixel_h: usize,
949        pixel_w: usize,
950        latent_h: usize,
951        latent_w: usize,
952        eff_h: usize,
953        eff_w: usize,
954        noise: &[f32],
955        image_strength: f32,
956        num_inference_steps: usize,
957    ) -> Result<Vec<f32>> {
958        use crate::latent_ops::blend_latents_with_noise;
959        use crate::{flow_match_init_timestep, flow_match_sigmas};
960
961        let clean =
962            self.encode_rgb_to_packed(rgb, pixel_h, pixel_w, latent_h, latent_w, eff_h, eff_w)?;
963        ensure!(clean.len() == noise.len());
964        let sigmas = flow_match_sigmas(num_inference_steps);
965        let init_step = flow_match_init_timestep(image_strength, num_inference_steps);
966        let sigma = sigmas[init_step.min(sigmas.len() - 1)];
967        Ok(blend_latents_with_noise(&clean, noise, sigma))
968    }
969
970    /// Edit mode: encode reference images into concat conditioning tokens.
971    pub fn prepare_edit_conditioning(
972        &self,
973        images: &[(&[f32], usize, usize)],
974        eff_h: usize,
975        eff_w: usize,
976        latent_h: usize,
977        latent_w: usize,
978    ) -> Result<crate::Flux2ReferenceConditioning> {
979        use crate::{
980            Flux2ReferenceConditioning, concat_latent_ids, concat_packed_latents,
981            prepare_latent_ids_with_t,
982        };
983
984        ensure!(
985            !images.is_empty(),
986            "edit requires at least one reference image"
987        );
988        let vae_cfg = self
989            .vae_cfg
990            .as_ref()
991            .ok_or_else(|| anyhow!("VAE config missing"))?;
992        let channels = vae_cfg.bn_channels();
993        let mut packed_acc: Option<Vec<f32>> = None;
994        let mut ids_acc: Option<Vec<f32>> = None;
995        let mut total_seq = 0usize;
996
997        for (i, (rgb, ph, pw)) in images.iter().enumerate() {
998            let packed =
999                self.encode_rgb_to_packed(rgb, *ph, *pw, latent_h, latent_w, eff_h, eff_w)?;
1000            let seq = packed.len() / (self.batch * channels);
1001            total_seq += seq;
1002            let ids = prepare_latent_ids_with_t(self.batch, latent_h, latent_w, 10 + 10 * i as i32);
1003            packed_acc = Some(match packed_acc {
1004                Some(prev) => concat_packed_latents(&prev, &packed, self.batch, channels),
1005                None => packed,
1006            });
1007            ids_acc = Some(match ids_acc {
1008                Some(prev) => concat_latent_ids(&prev, &ids, self.batch),
1009                None => ids,
1010            });
1011        }
1012
1013        Ok(Flux2ReferenceConditioning {
1014            packed: packed_acc.unwrap(),
1015            img_ids: ids_acc.unwrap(),
1016            seq: total_seq,
1017        })
1018    }
1019
1020    /// Denoiser noise prediction (compiled on [`Self::device`] when not CPU-native).
1021    pub fn forward_noise(
1022        &self,
1023        hidden_states: &[f32],
1024        encoder_hidden_states: &[f32],
1025        timestep: &[f32],
1026        guidance: Option<&[f32]>,
1027        img_ids: &[f32],
1028        txt_ids: &[f32],
1029    ) -> Result<Vec<f32>> {
1030        if self.uses_compiled_denoiser() {
1031            self.forward_noise_compiled(
1032                hidden_states,
1033                encoder_hidden_states,
1034                timestep,
1035                guidance,
1036                img_ids,
1037                txt_ids,
1038            )
1039        } else {
1040            self.forward_noise_native(
1041                hidden_states,
1042                encoder_hidden_states,
1043                timestep,
1044                guidance,
1045                img_ids,
1046                txt_ids,
1047            )
1048        }
1049    }
1050
1051    /// Native CPU reference forward (no IR compile).
1052    pub fn forward_noise_native(
1053        &self,
1054        hidden_states: &[f32],
1055        encoder_hidden_states: &[f32],
1056        timestep: &[f32],
1057        guidance: Option<&[f32]>,
1058        img_ids: &[f32],
1059        txt_ids: &[f32],
1060    ) -> Result<Vec<f32>> {
1061        use crate::{Flux2ForwardInput, flux2_transformer_forward};
1062
1063        flux2_transformer_forward(
1064            &self.model,
1065            &self.cfg,
1066            Flux2ForwardInput {
1067                hidden_states,
1068                encoder_hidden_states,
1069                timestep,
1070                timestep_target: None,
1071                guidance,
1072                img_ids,
1073                txt_ids,
1074                batch: self.batch,
1075                img_seq: hidden_states.len() / (self.batch * self.cfg.in_channels),
1076                txt_seq: self.txt_seq,
1077            },
1078        )
1079    }
1080
1081    /// Native forward with dual-time embedding (flow-map).
1082    pub fn forward_noise_dual_native(
1083        &self,
1084        hidden_states: &[f32],
1085        encoder_hidden_states: &[f32],
1086        timestep: &[f32],
1087        timestep_target: &[f32],
1088        guidance: Option<&[f32]>,
1089        img_ids: &[f32],
1090        txt_ids: &[f32],
1091    ) -> Result<Vec<f32>> {
1092        use crate::{Flux2ForwardInput, flux2_transformer_forward};
1093
1094        flux2_transformer_forward(
1095            &self.model,
1096            &self.cfg,
1097            Flux2ForwardInput {
1098                hidden_states,
1099                encoder_hidden_states,
1100                timestep,
1101                timestep_target: Some(timestep_target),
1102                guidance,
1103                img_ids,
1104                txt_ids,
1105                batch: self.batch,
1106                img_seq: hidden_states.len() / (self.batch * self.cfg.in_channels),
1107                txt_seq: self.txt_seq,
1108            },
1109        )
1110    }
1111
1112    /// Compiled HIR denoiser on [`Self::device`] (Metal / MLX / CUDA / CPU).
1113    pub fn forward_noise_compiled(
1114        &self,
1115        hidden_states: &[f32],
1116        encoder_hidden_states: &[f32],
1117        timestep: &[f32],
1118        guidance: Option<&[f32]>,
1119        img_ids: &[f32],
1120        txt_ids: &[f32],
1121    ) -> Result<Vec<f32>> {
1122        use crate::host_temb;
1123
1124        self.ensure_denoiser_compiled(img_ids, txt_ids)?;
1125        let temb = host_temb(&self.model, &self.cfg, timestep, guidance)?;
1126        let mut guard = self
1127            .denoiser
1128            .lock()
1129            .map_err(|e| anyhow!("denoiser cache lock poisoned: {e}"))?;
1130        let cache = guard
1131            .as_mut()
1132            .ok_or_else(|| anyhow!("denoiser compile cache missing after ensure"))?;
1133        Ok(cache
1134            .compiled
1135            .run(&[
1136                ("hidden", hidden_states),
1137                ("encoder", encoder_hidden_states),
1138                ("temb", temb.as_slice()),
1139            ])
1140            .remove(0))
1141    }
1142
1143    /// Compiled forward with dual-time temb (flow-map).
1144    pub fn forward_noise_dual_compiled(
1145        &self,
1146        hidden_states: &[f32],
1147        encoder_hidden_states: &[f32],
1148        timestep: &[f32],
1149        timestep_target: &[f32],
1150        guidance: Option<&[f32]>,
1151        img_ids: &[f32],
1152        txt_ids: &[f32],
1153    ) -> Result<Vec<f32>> {
1154        use crate::host_temb_dual;
1155
1156        self.ensure_denoiser_compiled(img_ids, txt_ids)?;
1157        let temb = host_temb_dual(&self.model, &self.cfg, timestep, timestep_target, guidance)?;
1158        let mut guard = self
1159            .denoiser
1160            .lock()
1161            .map_err(|e| anyhow!("denoiser cache lock poisoned: {e}"))?;
1162        let cache = guard
1163            .as_mut()
1164            .ok_or_else(|| anyhow!("denoiser compile cache missing after ensure"))?;
1165        Ok(cache
1166            .compiled
1167            .run(&[
1168                ("hidden", hidden_states),
1169                ("encoder", encoder_hidden_states),
1170                ("temb", temb.as_slice()),
1171            ])
1172            .remove(0))
1173    }
1174
1175    /// Classifier-free guidance: positive + negative text, then
1176    /// `neg + cfg_scale * (pos - neg)` on the noise prediction.
1177    pub fn forward_cfg(
1178        &self,
1179        hidden_states: &[f32],
1180        pos_encoder: &[f32],
1181        neg_encoder: &[f32],
1182        timestep: &[f32],
1183        guidance: Option<&[f32]>,
1184        img_ids: &[f32],
1185        pos_txt_ids: &[f32],
1186        neg_txt_ids: &[f32],
1187        cfg_scale: f32,
1188    ) -> Result<Flux2Output> {
1189        use crate::cfg_combine;
1190
1191        let pos = self.forward_noise(
1192            hidden_states,
1193            pos_encoder,
1194            timestep,
1195            guidance,
1196            img_ids,
1197            pos_txt_ids,
1198        )?;
1199        if cfg_scale <= 1.0 {
1200            return Ok(Flux2Output {
1201                noise_pred: pos,
1202                img_seq: hidden_states.len() / (self.batch * self.cfg.in_channels),
1203                out_dim: self.cfg.proj_out_dim(),
1204            });
1205        }
1206        let neg = self.forward_noise(
1207            hidden_states,
1208            neg_encoder,
1209            timestep,
1210            guidance,
1211            img_ids,
1212            neg_txt_ids,
1213        )?;
1214        let img_seq = hidden_states.len() / (self.batch * self.cfg.in_channels);
1215        let noise_pred = if self.uses_compiled_denoiser() {
1216            self.cfg_combine_compiled(&pos, &neg, cfg_scale, img_seq)?
1217        } else {
1218            cfg_combine(&pos, &neg, cfg_scale)
1219        };
1220        Ok(Flux2Output {
1221            noise_pred,
1222            img_seq,
1223            out_dim: self.cfg.proj_out_dim(),
1224        })
1225    }
1226
1227    /// Tokenize and encode positive + optional negative prompts.
1228    #[allow(clippy::type_complexity)]
1229    pub fn encode_prompt_pair(
1230        &self,
1231        prompt: &str,
1232        negative_prompt: Option<&str>,
1233    ) -> Result<(Vec<f32>, Vec<f32>, Option<Vec<f32>>, Option<Vec<f32>>)> {
1234        let (pos, pos_ids) = self.encode_prompt(prompt)?;
1235        let (neg, neg_ids) = match negative_prompt {
1236            Some(n) => {
1237                let (e, ids) = self.encode_prompt(n)?;
1238                (Some(e), Some(ids))
1239            }
1240            None => (None, None),
1241        };
1242        if self.drop_text_encoder_after_encode {
1243            self.drop_text_encoder_weights()?;
1244        }
1245        Ok((pos, pos_ids, neg, neg_ids))
1246    }
1247
1248    pub fn vae_config(&self) -> Option<&crate::Flux2VaeConfig> {
1249        self.vae_cfg.as_ref()
1250    }
1251
1252    /// Decode denoised packed latents to interleaved RGB u8 (HWC) and pixel `(height, width)`.
1253    pub fn decode_to_rgb(
1254        &self,
1255        packed_latents: &[f32],
1256        img_ids: &[f32],
1257        latent_h: usize,
1258        latent_w: usize,
1259    ) -> Result<(Vec<u8>, u32, u32)> {
1260        if self.uses_compiled_vae() {
1261            return self.decode_to_rgb_compiled(packed_latents, img_ids, latent_h, latent_w);
1262        }
1263        self.decode_to_rgb_native(packed_latents, img_ids, latent_h, latent_w)
1264    }
1265
1266    /// Native CPU decode (unpack / BN / unpatchify + VAE decoder).
1267    pub fn decode_to_rgb_native(
1268        &self,
1269        packed_latents: &[f32],
1270        img_ids: &[f32],
1271        latent_h: usize,
1272        latent_w: usize,
1273    ) -> Result<(Vec<u8>, u32, u32)> {
1274        use crate::{flux2_decode_packed_latents, flux2_rgb_to_u8};
1275
1276        let vae = self
1277            .vae
1278            .as_ref()
1279            .ok_or_else(|| anyhow!("VAE not loaded (place vae/ next to weights)"))?;
1280        let vae_cfg = self
1281            .vae_cfg
1282            .as_ref()
1283            .ok_or_else(|| anyhow!("VAE config missing"))?;
1284        let packed_channels = self.cfg.in_channels;
1285        let img_seq = img_ids.len() / (self.batch * 4);
1286        let rgb = flux2_decode_packed_latents(
1287            vae,
1288            vae_cfg,
1289            packed_latents,
1290            img_ids,
1291            self.batch,
1292            img_seq,
1293            packed_channels,
1294            latent_h,
1295            latent_w,
1296        )?;
1297        let up_stages = vae_cfg.block_out_channels.len().saturating_sub(1);
1298        let scale = 2usize.pow(up_stages as u32 + 1);
1299        let h_px = latent_h * scale;
1300        let w_px = latent_w * scale;
1301        let u8 = flux2_rgb_to_u8(&rgb, self.batch, 3, h_px, w_px);
1302        Ok((u8, h_px as u32, w_px as u32))
1303    }
1304
1305    /// Compiled VAE decoder on [`Self::device`] (unpack/BN/unpatchify stay on CPU).
1306    pub fn decode_to_rgb_compiled(
1307        &self,
1308        packed_latents: &[f32],
1309        img_ids: &[f32],
1310        latent_h: usize,
1311        latent_w: usize,
1312    ) -> Result<(Vec<u8>, u32, u32)> {
1313        use crate::{
1314            denorm_patchified_latents, flux2_rgb_to_u8, unpack_latents_with_ids, unpatchify_latents,
1315        };
1316
1317        let vae_cfg = self
1318            .vae_cfg
1319            .as_ref()
1320            .ok_or_else(|| anyhow!("VAE config missing"))?;
1321        let vae = self.vae.as_ref().ok_or_else(|| anyhow!("VAE not loaded"))?;
1322        let packed_channels = self.cfg.in_channels;
1323        let img_seq = img_ids.len() / (self.batch * 4);
1324        let spatial = unpack_latents_with_ids(
1325            packed_latents,
1326            img_ids,
1327            self.batch,
1328            img_seq,
1329            packed_channels,
1330            latent_h,
1331            latent_w,
1332        )?;
1333        let denorm = denorm_patchified_latents(
1334            &spatial,
1335            &vae.bn_running_mean,
1336            &vae.bn_running_var,
1337            vae_cfg.batch_norm_eps,
1338        );
1339        let mut latents =
1340            unpatchify_latents(&denorm, self.batch, packed_channels, latent_h, latent_w);
1341        if vae_cfg.scaling_factor != 1.0 || vae_cfg.shift_factor != 0.0 {
1342            for v in &mut latents {
1343                *v = *v / vae_cfg.scaling_factor + vae_cfg.shift_factor;
1344            }
1345        }
1346        let h2 = latent_h * 2;
1347        let w2 = latent_w * 2;
1348
1349        self.ensure_vae_compiled(h2, w2)?;
1350        let mut guard = self
1351            .vae_compiled
1352            .lock()
1353            .map_err(|e| anyhow!("vae cache lock poisoned: {e}"))?;
1354        let cache = guard
1355            .as_mut()
1356            .ok_or_else(|| anyhow!("vae compile cache missing"))?;
1357        let rgb = cache
1358            .compiled
1359            .run(&[("latents", latents.as_slice())])
1360            .remove(0);
1361
1362        let up_stages = vae_cfg.block_out_channels.len().saturating_sub(1);
1363        let scale = 2usize.pow(up_stages as u32 + 1);
1364        let h_px = latent_h * scale;
1365        let w_px = latent_w * scale;
1366        let u8 = flux2_rgb_to_u8(&rgb, self.batch, 3, h_px, w_px);
1367        Ok((u8, h_px as u32, w_px as u32))
1368    }
1369
1370    fn ensure_vae_compiled(&self, h: usize, w: usize) -> Result<()> {
1371        use crate::compile_flux2_vae_hir;
1372
1373        let vae = self.vae.as_ref().ok_or_else(|| anyhow!("VAE not loaded"))?;
1374        let vae_cfg = self
1375            .vae_cfg
1376            .as_ref()
1377            .ok_or_else(|| anyhow!("VAE config missing"))?;
1378
1379        let mut guard = self
1380            .vae_compiled
1381            .lock()
1382            .map_err(|e| anyhow!("vae cache lock poisoned: {e}"))?;
1383        let recompile = guard.as_ref().is_none_or(|c| {
1384            c.device != self.device || c.batch != self.batch || c.h != h || c.w != w
1385        });
1386        if recompile {
1387            let aot = self.aot_cache();
1388            let (compiled, _) =
1389                compile_flux2_vae_hir(vae_cfg, vae, self.batch, h, w, self.device, aot.as_ref())?;
1390            *guard = Some(Flux2VaeCache {
1391                compiled,
1392                device: self.device,
1393                batch: self.batch,
1394                h,
1395                w,
1396            });
1397        }
1398        Ok(())
1399    }
1400
1401    fn ensure_vae_encoder_compiled(&self, h: usize, w: usize) -> Result<()> {
1402        use crate::compile_flux2_vae_encoder_hir;
1403
1404        let vae = self.vae.as_ref().ok_or_else(|| anyhow!("VAE not loaded"))?;
1405        let vae_cfg = self
1406            .vae_cfg
1407            .as_ref()
1408            .ok_or_else(|| anyhow!("VAE config missing"))?;
1409
1410        let mut guard = self
1411            .vae_encoder_compiled
1412            .lock()
1413            .map_err(|e| anyhow!("vae encoder cache lock poisoned: {e}"))?;
1414        let recompile = guard.as_ref().is_none_or(|c| {
1415            c.device != self.device || c.batch != self.batch || c.h != h || c.w != w
1416        });
1417        if recompile {
1418            let aot = self.aot_cache();
1419            let (compiled, _) = compile_flux2_vae_encoder_hir(
1420                vae_cfg,
1421                vae,
1422                self.batch,
1423                h,
1424                w,
1425                self.device,
1426                aot.as_ref(),
1427            )?;
1428            *guard = Some(Flux2VaeEncoderCache {
1429                compiled,
1430                device: self.device,
1431                batch: self.batch,
1432                h,
1433                w,
1434            });
1435        }
1436        Ok(())
1437    }
1438}