use anyhow::Result;
use candle_core::Tensor;
use candle_transformers::models::flux::{self, WithForward};
use std::time::Instant;
use crate::flux::offload::OffloadedFluxTransformer;
use crate::flux::quantized_transformer::QuantizedFluxTransformer;
use crate::img_utils::InpaintContext;
use crate::progress::{ProgressEvent, ProgressReporter};
/// The FLUX transformer implementations that the sampling loop can drive.
///
/// Each variant wraps exactly one model type. `FluxTransformer::denoise`
/// matches on the variant and calls that model's `forward` with the same
/// argument list, so every wrapped type is interchangeable from the caller's
/// point of view.
// The wrapped models are stored inline, so clippy flags the size difference
// between variants; the lint is silenced instead of boxing the payloads.
// NOTE(review): presumably acceptable because very few instances ever exist
// at once — confirm before adding further variants.
#[allow(clippy::large_enum_variant)]
pub(crate) enum FluxTransformer {
/// `candle_transformers`' `flux::model::Flux`.
/// NOTE(review): the variant name suggests BF16 weights — confirm against the loader.
BF16(flux::model::Flux),
/// `candle_transformers`' `flux::quantized_model::Flux`.
Quantized(flux::quantized_model::Flux),
/// This crate's `QuantizedFluxTransformer` from `crate::flux::quantized_transformer`.
/// NOTE(review): what exactly is "bypassed" is not visible from this file — confirm.
QuantizedBypass(QuantizedFluxTransformer),
/// This crate's `OffloadedFluxTransformer` from `crate::flux::offload`.
/// NOTE(review): the offloading strategy is defined in that module — see there.
Offloaded(OffloadedFluxTransformer),
}
impl FluxTransformer {
    /// Runs the sampling loop over `timesteps` and returns the final latents.
    ///
    /// Every consecutive pair `(t_curr, t_prev)` in `timesteps` is one step:
    ///
    /// 1. the wrapped model is evaluated on the current latents at `t_curr`,
    /// 2. the latents are advanced by `pred * (t_prev - t_curr)`,
    /// 3. if `inpaint_ctx` is `Some`, the result is passed through
    ///    `apply_flux_inpaint_step` at `t_prev`,
    /// 4. a `ProgressEvent::DenoiseStep` is emitted carrying the 1-based step
    ///    index, the total step count (`timesteps.len() - 1`, saturating at 0)
    ///    and the wall-clock time the step took.
    ///
    /// The size of dimension 0 of `img` is used as the length of the
    /// per-step timestep tensor and of the guidance tensor; both are created
    /// as `f32` on the same device as `img`.
    ///
    /// With fewer than two timesteps no step runs and a clone of `img` is
    /// returned.
    ///
    /// # Errors
    ///
    /// Propagates any failure from tensor construction, the model's forward
    /// pass, the update arithmetic, or the inpaint blend.
    #[allow(clippy::too_many_arguments)]
    pub fn denoise(
        &self,
        img: &Tensor,
        img_ids: &Tensor,
        txt: &Tensor,
        txt_ids: &Tensor,
        vec_: &Tensor,
        timesteps: &[f64],
        guidance: f64,
        progress: &ProgressReporter,
        inpaint_ctx: Option<&InpaintContext>,
    ) -> Result<Tensor> {
        let batch = img.dim(0)?;
        let device = img.device();
        // Guidance is constant across steps, so build its tensor once.
        let guidance_tensor = Tensor::full(guidance as f32, batch, device)?;
        let total = timesteps.len().saturating_sub(1);

        let mut latents = img.clone();
        for (index, pair) in timesteps.windows(2).enumerate() {
            let started = Instant::now();
            // `windows(2)` only yields two-element slices; the `else` arm is
            // there solely to make the pattern irrefutable for the compiler.
            let &[t_curr, t_prev] = pair else { continue };

            let t_vec = Tensor::full(t_curr as f32, batch, device)?;
            let pred = self.run_forward(
                &latents,
                img_ids,
                txt,
                txt_ids,
                &t_vec,
                vec_,
                Some(&guidance_tensor),
            )?;

            // Advance the latents from `t_curr` to `t_prev` along the prediction.
            let delta = (pred * (t_prev - t_curr))?;
            latents = (latents + delta)?;

            if let Some(ctx) = inpaint_ctx {
                latents = apply_flux_inpaint_step(&latents, ctx, t_prev)?;
            }

            progress.emit(ProgressEvent::DenoiseStep {
                step: index + 1,
                total,
                elapsed: started.elapsed(),
            });
        }

        Ok(latents)
    }

    /// Calls `forward` on whichever model this variant wraps and returns its
    /// prediction, converting the model's error into `anyhow::Error`.
    // The parameter list mirrors the wrapped models' `forward`, hence the
    // lint exemption.
    #[allow(clippy::too_many_arguments)]
    fn run_forward(
        &self,
        img: &Tensor,
        img_ids: &Tensor,
        txt: &Tensor,
        txt_ids: &Tensor,
        t_vec: &Tensor,
        y: &Tensor,
        guidance: Option<&Tensor>,
    ) -> Result<Tensor> {
        // `?` is applied per arm because the wrapped types need not share an
        // error type.
        let pred = match self {
            Self::BF16(m) => m.forward(img, img_ids, txt, txt_ids, t_vec, y, guidance)?,
            Self::Quantized(m) => m.forward(img, img_ids, txt, txt_ids, t_vec, y, guidance)?,
            Self::QuantizedBypass(m) => m.forward(img, img_ids, txt, txt_ids, t_vec, y, guidance)?,
            Self::Offloaded(m) => m.forward(img, img_ids, txt, txt_ids, t_vec, y, guidance)?,
        };
        Ok(pred)
    }
}
fn apply_flux_inpaint_step(img: &Tensor, ctx: &InpaintContext, timestep: f64) -> Result<Tensor> {
crate::img2img::apply_flow_match_inpaint(img, ctx, timestep)
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device, Tensor};

    /// A mask whose last dimension is 1 must blend against latents whose last
    /// dimension is 64 without changing the latent shape.
    #[test]
    fn packed_inpaint_mask_broadcasts_across_flux_channels() {
        let cpu = Device::Cpu;
        let latent_shape: (usize, usize, usize) = (1, 4, 64);
        let mask_shape: (usize, usize, usize) = (1, 4, 1);

        let current = Tensor::ones(latent_shape, DType::F32, &cpu).unwrap();
        let original_latents = Tensor::zeros(latent_shape, DType::F32, &cpu).unwrap();
        let mask = Tensor::ones(mask_shape, DType::F32, &cpu).unwrap();
        let noise = Tensor::zeros(latent_shape, DType::F32, &cpu).unwrap();
        let ctx = InpaintContext {
            original_latents,
            mask,
            noise,
        };

        let blended = apply_flux_inpaint_step(&current, &ctx, 0.0).unwrap();

        // Only the shape is checked; the blended values depend on
        // `apply_flow_match_inpaint`, which is tested where it is defined.
        let (batch, tokens, channels) = latent_shape;
        assert_eq!(blended.dims(), &[batch, tokens, channels]);
    }
}