use anyhow::{bail, Result};
use candle_core::{DType, Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use mold_core::{GenerateRequest, GenerateResponse, ImageData, ModelPaths};
use std::sync::Mutex;
use std::time::Instant;
use super::sampling::{self, Flux2State};
use super::transformer::{Flux2Config, Flux2TransformerWrapper};
use super::vae::{Flux2AutoEncoder, Flux2VaeConfig};
use crate::cache::{
clear_cache, get_or_insert_cached_tensor, prompt_text_key, CachedTensor, LruCache,
DEFAULT_PROMPT_CACHE_CAPACITY,
};
use crate::device::{
check_memory_budget, fmt_gb, free_vram_bytes, memory_status_string, preflight_memory_check,
};
use crate::encoders;
use crate::engine::{rand_seed, InferenceEngine, LoadStrategy};
use crate::image::{build_output_metadata, encode_image};
use crate::progress::{ProgressCallback, ProgressReporter};
/// Model components kept resident after an eager `Flux2Engine::load`.
struct LoadedFlux2 {
/// Denoising transformer used by `generate`.
transformer: Flux2TransformerWrapper,
/// Qwen3 text encoder used to turn the prompt into conditioning embeddings.
text_encoder: encoders::qwen3::Qwen3Encoder,
/// Autoencoder whose `decode` turns denoised latents into an image tensor.
vae: Flux2AutoEncoder,
/// Device the transformer and VAE were loaded onto.
device: Device,
/// Dtype the transformer and VAE weights were loaded with.
dtype: DType,
}
/// Inference engine for Flux.2 models.
///
/// Depending on `load_strategy` the engine either keeps all components resident
/// (`loaded`) or loads, uses and drops them one at a time for every request.
pub struct Flux2Engine {
/// `Some` once `load()` has populated the components; reset to `None` by `unload()`.
loaded: Option<LoadedFlux2>,
/// Name returned by `model_name()` and attached to log messages.
model_name: String,
/// Locations of the tokenizer, text encoder, transformer and VAE files.
paths: ModelPaths,
/// Reporter used for stage/info progress messages.
progress: ProgressReporter,
/// Optional Qwen3 variant request forwarded to `resolve_qwen3_variant`.
qwen3_variant: Option<String>,
/// Selects eager loading or the sequential load-use-drop path.
load_strategy: LoadStrategy,
/// LRU cache of prompt embeddings, keyed by `prompt_text_key(prompt)`.
prompt_cache: Mutex<LruCache<String, CachedTensor>>,
}
impl Flux2Engine {
/// Builds an engine with nothing loaded yet.
///
/// The progress reporter starts out in its default state and the prompt cache is
/// created empty with `DEFAULT_PROMPT_CACHE_CAPACITY` entries of capacity.
pub fn new(
    model_name: String,
    paths: ModelPaths,
    qwen3_variant: Option<String>,
    load_strategy: LoadStrategy,
) -> Self {
    let progress = ProgressReporter::default();
    let prompt_cache = Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY));
    Self {
        model_name,
        paths,
        qwen3_variant,
        load_strategy,
        progress,
        prompt_cache,
        loaded: None,
    }
}
fn validate_paths(&self) -> Result<std::path::PathBuf> {
let text_tokenizer_path = self
.paths
.text_tokenizer
.as_ref()
.ok_or_else(|| anyhow::anyhow!("text tokenizer path required for Flux.2 models"))?;
if !text_tokenizer_path.exists() {
bail!(
"text tokenizer file not found: {}",
text_tokenizer_path.display()
);
}
let encoder_paths = self.text_encoder_paths();
if encoder_paths.is_empty() {
bail!("text encoder paths required for Flux.2 models");
}
for path in &encoder_paths {
if !path.exists() {
bail!("text encoder file not found: {}", path.display());
}
}
if !self.paths.transformer.exists() {
bail!(
"transformer file not found: {}",
self.paths.transformer.display()
);
}
if !self.paths.vae.exists() {
bail!("VAE file not found: {}", self.paths.vae.display());
}
Ok(text_tokenizer_path.clone())
}
/// Returns the text encoder weight files to use.
///
/// `paths.text_encoder_files` wins when it is non-empty; otherwise the single
/// `paths.t5_encoder` entry is used if set, and an empty list is returned if not.
fn text_encoder_paths(&self) -> Vec<std::path::PathBuf> {
    let explicit = &self.paths.text_encoder_files;
    if explicit.is_empty() {
        match self.paths.t5_encoder.as_ref() {
            Some(fallback) => vec![fallback.clone()],
            None => Vec::new(),
        }
    } else {
        explicit.clone()
    }
}
/// Layer selection handed to `Qwen3Encoder::encode_with_layers` by `encode_and_stack`.
// NOTE(review): presumably these are hidden-layer indices whose outputs get stacked — confirm against the encoder.
const QWEN3_HIDDEN_LAYERS: [usize; 3] = [9, 18, 27];
/// Encodes `prompt` with the Qwen3 encoder using `QWEN3_HIDDEN_LAYERS`.
///
/// Only the tensor part of the encoder's result is returned; the accompanying
/// second value is discarded.
fn encode_and_stack(
    encoder: &mut encoders::qwen3::Qwen3Encoder,
    prompt: &str,
    target_device: &Device,
    target_dtype: DType,
) -> Result<Tensor> {
    let output = encoder.encode_with_layers(
        prompt,
        target_device,
        target_dtype,
        &Self::QWEN3_HIDDEN_LAYERS,
    )?;
    Ok(output.0)
}
/// Returns the conditioning tensor for `prompt`, going through `prompt_cache`.
///
/// The cache is consulted via `get_or_insert_cached_tensor`; the encoder is only run
/// (and the "Encoding prompt (Qwen3)" stage only reported) from the closure that the
/// cache helper is given. When the helper reports a hit, a cache-hit progress event
/// is emitted instead.
fn encode_prompt_cached(
    progress: &ProgressReporter,
    prompt_cache: &Mutex<LruCache<String, CachedTensor>>,
    encoder: &mut encoders::qwen3::Qwen3Encoder,
    prompt: &str,
    target_device: &Device,
    target_dtype: DType,
) -> Result<Tensor> {
    const STAGE: &str = "Encoding prompt (Qwen3)";

    let (embedding, was_cached) = get_or_insert_cached_tensor(
        prompt_cache,
        prompt_text_key(prompt),
        target_device,
        target_dtype,
        || {
            progress.stage_start(STAGE);
            let started = Instant::now();
            let fresh = Self::encode_and_stack(encoder, prompt, target_device, target_dtype)?;
            progress.stage_done(STAGE, started.elapsed());
            Ok(fresh)
        },
    )?;

    if was_cached {
        progress.cache_hit("prompt conditioning");
    }
    Ok(embedding)
}
pub fn load(&mut self) -> Result<()> {
if self.loaded.is_some() {
return Ok(());
}
if self.load_strategy == LoadStrategy::Sequential {
return Ok(());
}
tracing::info!(model = %self.model_name, "loading Flux.2 Klein model components...");
let text_tokenizer_path = self.validate_paths()?;
let cpu = Device::Cpu;
let device = crate::device::create_device(&self.progress)?;
let gpu_dtype = crate::engine::gpu_dtype(&device);
tracing::info!("GPU device: {:?}, GPU dtype: {:?}", device, gpu_dtype);
self.progress
.stage_start("Loading Flux.2 transformer (GPU, BF16)");
let xformer_stage = Instant::now();
tracing::info!(
path = %self.paths.transformer.display(),
"loading Flux.2 transformer on GPU..."
);
let flux2_cfg = Flux2Config::klein();
let xformer_paths = if !self.paths.transformer_shards.is_empty() {
self.paths.transformer_shards.clone()
} else {
vec![self.paths.transformer.clone()]
};
let flux_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&xformer_paths, gpu_dtype, &device)? };
let transformer = Flux2TransformerWrapper::BF16(super::transformer::Flux2Transformer::new(
&flux2_cfg, flux_vb,
)?);
self.progress.stage_done(
"Loading Flux.2 transformer (GPU, BF16)",
xformer_stage.elapsed(),
);
tracing::info!("Flux.2 transformer loaded on GPU");
self.progress.stage_start("Loading VAE (GPU)");
let vae_stage = Instant::now();
tracing::info!(path = %self.paths.vae.display(), "loading VAE on GPU...");
let vae_cfg = Flux2VaeConfig::klein();
let vae_vb = unsafe {
VarBuilder::from_mmaped_safetensors(
std::slice::from_ref(&self.paths.vae),
gpu_dtype,
&device,
)?
};
let vae = Flux2AutoEncoder::new(&vae_cfg, vae_vb)?;
self.progress
.stage_done("Loading VAE (GPU)", vae_stage.elapsed());
tracing::info!("VAE loaded on GPU");
let free = free_vram_bytes().unwrap_or(0);
if free > 0 {
self.progress.info(&format!(
"Free VRAM after transformer+VAE: {}",
fmt_gb(free)
));
}
self.progress.stage_start("Selecting Qwen3 encoder");
let resolve_start = Instant::now();
let (encoder_paths, is_gguf, on_gpu, device_label) = {
let bf16_paths = self.text_encoder_paths();
let have_bf16 = !bf16_paths.is_empty() && bf16_paths.iter().all(|p| p.exists());
crate::encoders::variant_resolution::resolve_qwen3_variant(
&self.progress,
self.qwen3_variant.as_deref(),
&device,
free,
&bf16_paths,
have_bf16,
true,
)?
};
self.progress
.stage_done("Selecting Qwen3 encoder", resolve_start.elapsed());
let enc_device = if on_gpu { &device } else { &cpu };
let enc_dtype = if on_gpu { gpu_dtype } else { DType::F32 };
let enc_stage_label = format!("Loading Qwen3 encoder ({device_label})");
self.progress.stage_start(&enc_stage_label);
let enc_stage = Instant::now();
let text_encoder = if is_gguf {
encoders::qwen3::Qwen3Encoder::load_gguf(
&encoder_paths[0],
&text_tokenizer_path,
enc_device,
)?
} else {
encoders::qwen3::Qwen3Encoder::load_bf16(
&encoder_paths,
&text_tokenizer_path,
enc_device,
enc_dtype,
)?
};
self.progress
.stage_done(&enc_stage_label, enc_stage.elapsed());
tracing::info!(device = %device_label, "Qwen3 encoder loaded");
self.loaded = Some(LoadedFlux2 {
transformer,
text_encoder,
vae,
device,
dtype: gpu_dtype,
});
tracing::info!(model = %self.model_name, "all Flux.2 model components loaded successfully");
Ok(())
}
fn generate_sequential(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
let text_tokenizer_path = self.validate_paths()?;
if let Some(warning) = check_memory_budget(&self.paths, LoadStrategy::Sequential) {
self.progress.info(&warning);
}
let device = crate::device::create_device(&self.progress)?;
let gpu_dtype = crate::engine::gpu_dtype(&device);
let start = Instant::now();
let seed = req.seed.unwrap_or_else(rand_seed);
let width = req.width as usize;
let height = req.height as usize;
tracing::info!(
prompt = %req.prompt,
seed, width, height,
steps = req.steps,
"starting sequential Flux.2 generation"
);
self.progress
.info("Using sequential loading (load-use-drop) to minimize peak memory");
let free = free_vram_bytes().unwrap_or(0);
self.progress.stage_start("Selecting Qwen3 encoder");
let resolve_start = Instant::now();
let (encoder_paths, is_gguf, on_gpu, device_label) = {
let bf16_paths = self.text_encoder_paths();
let have_bf16 = !bf16_paths.is_empty() && bf16_paths.iter().all(|p| p.exists());
crate::encoders::variant_resolution::resolve_qwen3_variant(
&self.progress,
self.qwen3_variant.as_deref(),
&device,
free,
&bf16_paths,
have_bf16,
true,
)?
};
self.progress
.stage_done("Selecting Qwen3 encoder", resolve_start.elapsed());
let enc_device = if on_gpu { &device } else { &Device::Cpu };
let enc_dtype = if on_gpu { gpu_dtype } else { DType::F32 };
let enc_size: u64 = encoder_paths
.iter()
.filter_map(|p| std::fs::metadata(p).ok().map(|m| m.len()))
.sum();
preflight_memory_check("Qwen3 encoder", enc_size)?;
if let Some(status) = memory_status_string() {
self.progress.info(&status);
}
let enc_stage_label = format!("Loading Qwen3 encoder ({device_label})");
self.progress.stage_start(&enc_stage_label);
let enc_stage = Instant::now();
let mut text_encoder = if is_gguf {
encoders::qwen3::Qwen3Encoder::load_gguf(
&encoder_paths[0],
&text_tokenizer_path,
enc_device,
)?
} else {
encoders::qwen3::Qwen3Encoder::load_bf16(
&encoder_paths,
&text_tokenizer_path,
enc_device,
enc_dtype,
)?
};
self.progress
.stage_done(&enc_stage_label, enc_stage.elapsed());
let txt_emb = Self::encode_prompt_cached(
&self.progress,
&self.prompt_cache,
&mut text_encoder,
&req.prompt,
&device,
gpu_dtype,
)?;
drop(text_encoder);
self.progress.info("Freed Qwen3 encoder");
tracing::info!("Qwen3 encoder dropped (sequential mode)");
let xformer_size = std::fs::metadata(&self.paths.transformer)
.map(|m| m.len())
.unwrap_or(0);
let vae_file_size = std::fs::metadata(&self.paths.vae)
.map(|m| m.len())
.unwrap_or(0);
preflight_memory_check("Flux.2 transformer + VAE", xformer_size + vae_file_size)?;
if let Some(status) = memory_status_string() {
self.progress.info(&status);
}
let flux2_cfg = Flux2Config::klein();
self.progress
.stage_start("Loading Flux.2 transformer (GPU, BF16)");
let xformer_stage = Instant::now();
let xformer_paths = if !self.paths.transformer_shards.is_empty() {
self.paths.transformer_shards.clone()
} else {
vec![self.paths.transformer.clone()]
};
let flux_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&xformer_paths, gpu_dtype, &device)? };
let transformer = Flux2TransformerWrapper::BF16(super::transformer::Flux2Transformer::new(
&flux2_cfg, flux_vb,
)?);
self.progress.stage_done(
"Loading Flux.2 transformer (GPU, BF16)",
xformer_stage.elapsed(),
);
self.progress.stage_start("Loading VAE (GPU)");
let vae_stage = Instant::now();
let vae_cfg = Flux2VaeConfig::klein();
let vae_vb = unsafe {
VarBuilder::from_mmaped_safetensors(
std::slice::from_ref(&self.paths.vae),
gpu_dtype,
&device,
)?
};
let vae = Flux2AutoEncoder::new(&vae_cfg, vae_vb)?;
self.progress
.stage_done("Loading VAE (GPU)", vae_stage.elapsed());
let latent_h = height.div_ceil(8);
let latent_w = width.div_ceil(8);
let img =
crate::engine::seeded_randn(seed, &[1, 32, latent_h, latent_w], &device, gpu_dtype)?;
let state = Flux2State::new(&txt_emb, &img)?;
let image_seq_len = (height / 16) * (width / 16);
let timesteps = sampling::get_schedule(req.steps as usize, image_seq_len);
let denoise_label = format!("Denoising ({} steps)", timesteps.len() - 1);
self.progress.stage_start(&denoise_label);
let denoise_start = Instant::now();
let img = transformer.denoise(
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
×teps,
req.guidance,
&self.progress,
)?;
let img = sampling::unpack(&img, height, width)?;
self.progress
.stage_done(&denoise_label, denoise_start.elapsed());
drop(transformer);
self.progress.info("Freed Flux.2 transformer");
drop(state);
drop(txt_emb);
device.synchronize()?;
tracing::info!("Transformer dropped (sequential mode), decoding VAE...");
self.progress.stage_start("VAE decode");
let vae_decode_start = Instant::now();
let img = vae.decode(&img.to_dtype(gpu_dtype)?)?;
let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(DType::U8)?;
let img = img.i(0)?;
self.progress
.stage_done("VAE decode", vae_decode_start.elapsed());
let output_metadata = build_output_metadata(req, seed, None);
let image_bytes = encode_image(
&img,
req.output_format,
req.width,
req.height,
output_metadata.as_ref(),
)?;
let generation_time_ms = start.elapsed().as_millis() as u64;
tracing::info!(generation_time_ms, seed, "sequential generation complete");
Ok(GenerateResponse {
images: vec![ImageData {
data: image_bytes,
format: req.output_format,
width: req.width,
height: req.height,
index: 0,
}],
generation_time_ms,
model: req.model.clone(),
seed_used: seed,
})
}
}
impl InferenceEngine for Flux2Engine {
/// Generates one image for `req`.
///
/// Scheduler selection, source images and masks are not supported here: each is
/// logged as a warning and ignored. Engines configured with
/// `LoadStrategy::Sequential` delegate to `generate_sequential`; otherwise the
/// components stored by `load()` are used.
///
/// # Errors
///
/// Fails when the model has not been loaded (non-sequential mode), when the text
/// encoder cannot be reloaded, or when any tensor operation or image encoding
/// step fails.
fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
    if req.scheduler.is_some() {
        tracing::warn!(
            "scheduler selection not supported for Flux.2 (flow-matching), ignoring"
        );
    }
    if req.source_image.is_some() {
        tracing::warn!("img2img not yet supported for Flux.2 — generating from text only");
    }
    if req.mask_image.is_some() {
        tracing::warn!("inpainting not yet supported for Flux.2 -- ignoring mask");
    }
    if self.load_strategy == LoadStrategy::Sequential {
        return self.generate_sequential(req);
    }
    let progress = &self.progress;
    let loaded = self
        .loaded
        .as_mut()
        .ok_or_else(|| anyhow::anyhow!("model not loaded — call load() first"))?;
    let start = Instant::now();
    let seed = req.seed.unwrap_or_else(rand_seed);
    let width = req.width as usize;
    let height = req.height as usize;
    tracing::info!(
        prompt = %req.prompt,
        seed, width, height,
        steps = req.steps,
        "starting Flux.2 generation"
    );

    // The encoder's weights may have been dropped by a previous request (see below).
    if loaded.text_encoder.model.is_none() {
        progress.stage_start("Reloading Qwen3 encoder");
        let reload_start = Instant::now();
        loaded.text_encoder.reload()?;
        progress.stage_done("Reloading Qwen3 encoder", reload_start.elapsed());
    }
    let txt_emb = Self::encode_prompt_cached(
        progress,
        &self.prompt_cache,
        &mut loaded.text_encoder,
        &req.prompt,
        &loaded.device,
        loaded.dtype,
    )?;
    tracing::info!("Qwen3 encoding complete");
    if loaded.text_encoder.on_gpu {
        loaded.text_encoder.drop_weights();
        tracing::info!("Qwen3 encoder dropped from GPU to free VRAM for denoising");
    }

    // --- Denoise ---
    let latent_h = height.div_ceil(8);
    let latent_w = width.div_ceil(8);
    let img = crate::engine::seeded_randn(
        seed,
        &[1, 32, latent_h, latent_w],
        &loaded.device,
        loaded.dtype,
    )?;
    let state = Flux2State::new(&txt_emb, &img)?;
    let image_seq_len = (height / 16) * (width / 16);
    let timesteps = sampling::get_schedule(req.steps as usize, image_seq_len);
    let denoise_label = format!("Denoising ({} steps)", timesteps.len() - 1);
    progress.stage_start(&denoise_label);
    let denoise_start = Instant::now();
    tracing::info!(steps = timesteps.len() - 1, "running denoising loop...");
    let img = loaded.transformer.denoise(
        &state.img,
        &state.img_ids,
        &state.txt,
        &state.txt_ids,
        &state.vec,
        &timesteps,
        req.guidance,
        progress,
    )?;
    let img = sampling::unpack(&img, height, width)?;
    progress.stage_done(&denoise_label, denoise_start.elapsed());
    tracing::info!("denoising complete, decoding VAE...");
    // Release the denoising inputs before the VAE decode runs.
    drop(state);
    drop(txt_emb);

    // --- VAE decode and image encoding ---
    progress.stage_start("VAE decode");
    let vae_decode_start = Instant::now();
    let img = loaded.vae.decode(&img.to_dtype(loaded.dtype)?)?;
    // Map values from [-1, 1] to [0, 255] and convert to bytes.
    let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(DType::U8)?;
    let img = img.i(0)?;
    progress.stage_done("VAE decode", vae_decode_start.elapsed());
    tracing::info!("VAE decode complete, encoding output image...");
    let output_metadata = build_output_metadata(req, seed, None);
    let image_bytes = encode_image(
        &img,
        req.output_format,
        req.width,
        req.height,
        output_metadata.as_ref(),
    )?;
    let generation_time_ms = start.elapsed().as_millis() as u64;
    tracing::info!(generation_time_ms, seed, "generation complete");
    Ok(GenerateResponse {
        images: vec![ImageData {
            data: image_bytes,
            format: req.output_format,
            width: req.width,
            height: req.height,
            index: 0,
        }],
        generation_time_ms,
        model: req.model.clone(),
        seed_used: seed,
    })
}
/// Name this engine was constructed with.
fn model_name(&self) -> &str {
    self.model_name.as_str()
}
/// Whether the engine reports itself as ready to generate.
///
/// Sequential engines always report `true`; otherwise this reflects whether
/// `load()` has populated the components.
fn is_loaded(&self) -> bool {
    if self.load_strategy == LoadStrategy::Sequential {
        return true;
    }
    self.loaded.is_some()
}
/// Trait entry point for loading; forwards to `Flux2Engine::load`.
// NOTE(review): the type-qualified path is presumably meant to select the inherent
// method rather than this trait method — confirm before rewriting the call.
fn load(&mut self) -> Result<()> {
Flux2Engine::load(self)
}
/// Drops any loaded components and then empties the prompt cache.
fn unload(&mut self) {
    drop(self.loaded.take());
    clear_cache(&self.prompt_cache);
}
/// Installs `callback` on the engine's progress reporter.
fn set_on_progress(&mut self, callback: ProgressCallback) {
self.progress.set_callback(callback);
}
/// Removes the callback from the engine's progress reporter.
fn clear_on_progress(&mut self) {
self.progress.clear_callback();
}
}