use crate::adaptive_offload::{
plan_adaptive_residency, AdaptiveResidencyPlan, ADAPTIVE_OFFLOAD_RUNTIME_HEADROOM,
};
use crate::progress::ProgressReporter;
use candle_core::{DType, IndexOp, Module, Result, Tensor, D};
use candle_nn::{LayerNorm, Linear, RmsNorm, VarBuilder};
use std::sync::{Arc, OnceLock};
/// A bias-optional linear projection whose weights may be stored in one of
/// several on-disk formats.
///
/// * `Standard` — a regular dense [`candle_nn::Linear`].
/// * `Fp8` — an `F8E4M3` weight (plus optional `scale_weight` multiplier) that is
///   cast to the activation dtype on every forward call.
/// * `Nvfp4Streaming` — NVFP4-packed weights kept on the CPU; they are
///   dequantized (via `crate::nvfp4::dequant_nvfp4_to_bf16_cpu`) at most once,
///   memoized in `cache`, and copied to the activation's device per forward call.
#[derive(Debug, Clone)]
pub(crate) enum Flux2Linear {
    Standard(candle_nn::Linear),
    Fp8 {
        weight: Tensor,
        scale: Option<Tensor>,
        bias: Option<Tensor>,
    },
    Nvfp4Streaming {
        /// Packed weights, rank 2 `(N_full, K / 2)`, stored on the CPU.
        packed: Tensor,
        /// `F8E4M3` block scales, stored on the CPU.
        block_scales: Tensor,
        /// Global dequantization scale.
        tensor_scale: f32,
        /// Number of output features this module produces (after slicing).
        out_dim: usize,
        #[allow(dead_code)] // only validated against the checkpoint K at load time
        in_dim: usize,
        /// `(axis, component, n_components)`: when the checkpoint stores a fused
        /// weight of `n_components * out_dim` rows, this module uses rows
        /// `component * out_dim .. (component + 1) * out_dim` along `axis` (always 0).
        slice: Option<(usize, usize, usize)>,
        bias: Option<Tensor>,
        /// Lazily populated dequantized weight; the `Arc` makes clones share it.
        cache: Arc<OnceLock<Tensor>>,
    },
}
impl Flux2Linear {
    /// Loads a linear layer from `vb`, auto-detecting the stored weight format.
    ///
    /// Detection order:
    /// 1. `weight.nvfp4_packed` present → [`Flux2Linear::Nvfp4Streaming`].
    /// 2. `weight` has dtype `F8E4M3` → [`Flux2Linear::Fp8`].
    /// 3. otherwise → [`Flux2Linear::Standard`].
    ///
    /// # Errors
    /// Fails when a required tensor is missing, or when NVFP4 metadata is
    /// inconsistent with `(in_dim, out_dim)`.
    fn load_with_bias(
        in_dim: usize,
        out_dim: usize,
        has_bias: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        if vb.contains_tensor("weight.nvfp4_packed") {
            return Self::load_nvfp4(in_dim, out_dim, has_bias, &vb);
        }
        let weight = vb.get((out_dim, in_dim), "weight")?;
        if weight.dtype() == DType::F8E4M3 {
            let scale = vb.get_unchecked("scale_weight").ok();
            // Best-effort: an FP8 checkpoint without a bias tensor still loads.
            let bias = if has_bias {
                vb.get_unchecked("bias").ok()
            } else {
                None
            };
            return Ok(Self::Fp8 {
                weight,
                scale,
                bias,
            });
        }
        let bias = if has_bias {
            Some(vb.get(out_dim, "bias")?)
        } else {
            None
        };
        Ok(Self::Standard(candle_nn::Linear::new(weight, bias)))
    }

    /// Loads and validates the NVFP4 streaming variant; quantized tensors are
    /// moved to the CPU.
    fn load_nvfp4(
        in_dim: usize,
        out_dim: usize,
        has_bias: bool,
        vb: &VarBuilder,
    ) -> Result<Self> {
        let cpu = candle_core::Device::Cpu;
        let packed = vb
            .get_unchecked_dtype("weight.nvfp4_packed", DType::U8)?
            .to_device(&cpu)?;
        let block_scales = vb
            .get_unchecked_dtype("weight.nvfp4_block_scales", DType::F8E4M3)?
            .to_device(&cpu)?;
        // Accept both a rank-0 scalar and a single-element vector for the scale.
        let tensor_scale_values: Vec<f32> = vb
            .get_unchecked_dtype("weight.nvfp4_tensor_scale", DType::F32)?
            .to_dtype(DType::F32)?
            .flatten_all()?
            .to_vec1()?;
        let tensor_scale = match tensor_scale_values.as_slice() {
            &[scale] => scale,
            other => {
                candle_core::bail!(
                    "NVFP4 tensor scale must hold exactly one value, got {}",
                    other.len()
                )
            }
        };
        let slice = if vb.contains_tensor("weight.nvfp4_slice_meta") {
            let meta: Vec<u32> = vb
                .get_unchecked_dtype("weight.nvfp4_slice_meta", DType::U32)?
                .to_device(&cpu)?
                .flatten_all()?
                .to_vec1()?;
            let (axis, component, n_components) = match meta.as_slice() {
                &[axis, component, n_components] => {
                    (axis as usize, component as usize, n_components as usize)
                }
                _ => {
                    candle_core::bail!(
                        "NVFP4 slice meta tensor must have length 3, got {}",
                        meta.len()
                    )
                }
            };
            // The N_full check below assumes fused components are stacked as
            // rows, so the forward-time `narrow` is only meaningful on axis 0.
            if axis != 0 {
                candle_core::bail!(
                    "NVFP4 slice meta: only axis 0 is supported, got axis {}",
                    axis
                );
            }
            // Catch an out-of-range component here instead of at forward time.
            if component >= n_components {
                candle_core::bail!(
                    "NVFP4 slice meta: component {} out of range for {} components",
                    component,
                    n_components
                );
            }
            Some((axis, component, n_components))
        } else {
            None
        };
        let packed_dims = packed.dims();
        if packed_dims.len() != 2 {
            candle_core::bail!("NVFP4 packed weight must be rank 2, got {:?}", packed_dims,);
        }
        let n_full = packed_dims[0];
        // Two 4-bit values are packed per byte.
        let k = packed_dims[1] * 2;
        if k != in_dim {
            candle_core::bail!(
                "NVFP4: in_dim mismatch — checkpoint K={}, module expected {}",
                k,
                in_dim,
            );
        }
        let expected_n_full = match slice {
            Some((_, _, n_components)) => out_dim * n_components,
            None => out_dim,
        };
        if n_full != expected_n_full {
            candle_core::bail!(
                "NVFP4: out_dim mismatch — checkpoint N_full={}, module expected {} (out_dim={}, slice={:?})",
                n_full,
                expected_n_full,
                out_dim,
                slice,
            );
        }
        // Best-effort, as for FP8: a missing bias tensor yields `None`.
        let bias = if has_bias {
            vb.get_unchecked("bias").ok()
        } else {
            None
        };
        Ok(Self::Nvfp4Streaming {
            packed,
            block_scales,
            tensor_scale,
            out_dim,
            in_dim,
            slice,
            bias,
            cache: Arc::new(OnceLock::new()),
        })
    }

    /// Returns a copy of this layer with all tensors moved to `device`.
    ///
    /// # Errors
    /// NVFP4 streaming layers cannot be relocated and always return an error.
    fn to_device(&self, device: &candle_core::Device) -> Result<Self> {
        match self {
            Self::Standard(linear) => Ok(Self::Standard(linear_to_device(linear, device)?)),
            Self::Fp8 {
                weight,
                scale,
                bias,
            } => Ok(Self::Fp8 {
                weight: weight.to_device(device)?,
                scale: scale.as_ref().map(|t| t.to_device(device)).transpose()?,
                bias: bias.as_ref().map(|t| t.to_device(device)).transpose()?,
            }),
            Self::Nvfp4Streaming { .. } => {
                candle_core::bail!("Flux.2 block offload does not support NVFP4 streaming layers")
            }
        }
    }
}
/// Copies a dense `Linear` (weight first, then the optional bias) onto `device`.
fn linear_to_device(linear: &Linear, device: &candle_core::Device) -> Result<Linear> {
    let moved_weight = linear.weight().to_device(device)?;
    let moved_bias = match linear.bias() {
        Some(b) => Some(b.to_device(device)?),
        None => None,
    };
    Ok(Linear::new(moved_weight, moved_bias))
}
/// Rebuilds `norm` on `device` with eps 1e-6, keeping a bias only if one exists.
fn layer_norm_to_device(norm: &LayerNorm, device: &candle_core::Device) -> Result<LayerNorm> {
    const EPS: f64 = 1e-6;
    let weight = norm.weight().to_device(device)?;
    let rebuilt = if let Some(bias) = norm.bias() {
        LayerNorm::new(weight, bias.to_device(device)?, EPS)
    } else {
        LayerNorm::new_no_bias(weight, EPS)
    };
    Ok(rebuilt)
}
/// Rebuilds an `RmsNorm` on `device` from its weight, using eps 1e-6.
fn rms_norm_to_device(norm: &RmsNorm, device: &candle_core::Device) -> Result<RmsNorm> {
    let layer = norm.clone().into_inner();
    let moved = layer.weight().to_device(device)?;
    Ok(RmsNorm::new(moved, 1e-6))
}
/// Size of `t`'s elements in bytes (element count × dtype width).
fn tensor_bytes(t: &Tensor) -> usize {
    let width = t.dtype().size_in_bytes();
    width * t.elem_count()
}
fn flux2_linear_bytes(linear: &Flux2Linear) -> usize {
match linear {
Flux2Linear::Standard(linear) => {
tensor_bytes(linear.weight()) + linear.bias().map(tensor_bytes).unwrap_or(0)
}
Flux2Linear::Fp8 {
weight,
scale,
bias,
} => {
tensor_bytes(weight)
+ scale.as_ref().map(tensor_bytes).unwrap_or(0)
+ bias.as_ref().map(tensor_bytes).unwrap_or(0)
}
Flux2Linear::Nvfp4Streaming {
packed,
block_scales,
bias,
cache,
..
} => {
tensor_bytes(packed)
+ tensor_bytes(block_scales)
+ bias.as_ref().map(tensor_bytes).unwrap_or(0)
+ cache.get().map(tensor_bytes).unwrap_or(0)
}
}
}
/// Bytes held by a LayerNorm's weight plus its bias (if any).
fn layer_norm_bytes(norm: &LayerNorm) -> usize {
    let bias = norm.bias().map_or(0, tensor_bytes);
    tensor_bytes(norm.weight()) + bias
}
/// Bytes held by an RmsNorm's weight.
fn rms_norm_bytes(norm: &RmsNorm) -> usize {
    let inner = norm.clone().into_inner();
    tensor_bytes(inner.weight())
}
impl Module for Flux2Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
match self {
Self::Standard(l) => l.forward(x),
Self::Fp8 {
weight,
scale,
bias,
} => {
let dtype = x.dtype();
let w = weight.to_dtype(dtype)?;
let w = match scale {
Some(s) => w.broadcast_mul(&s.to_dtype(dtype)?)?,
None => w,
};
let w = w.t()?;
let out = match *x.dims() {
[b1, b2, m, k] => {
x.reshape((b1 * b2 * m, k))?
.matmul(&w)?
.reshape((b1, b2, m, ()))?
}
[bsize, m, k] => {
x.reshape((bsize * m, k))?
.matmul(&w)?
.reshape((bsize, m, ()))?
}
_ => x.matmul(&w)?,
};
match bias {
Some(b) => out.broadcast_add(&b.to_dtype(dtype)?),
None => Ok(out),
}
}
Self::Nvfp4Streaming {
packed,
block_scales,
tensor_scale,
out_dim,
slice,
bias,
cache,
..
} => {
let _backend = crate::nvfp4::resolve_nvfp4_backend(x.device())?;
let bf16_full = match cache.get() {
Some(t) => t,
None => {
let dequanted = crate::nvfp4::dequant_nvfp4_to_bf16_cpu(
packed,
block_scales,
*tensor_scale,
)?;
let _ = cache.set(dequanted);
cache.get().expect("cache populated above")
}
};
let bf16_sliced_cpu = match slice {
Some((axis, component, _n_components)) => {
bf16_full.narrow(*axis, component * out_dim, *out_dim)?
}
None => bf16_full.clone(),
};
let dtype = x.dtype();
let w_dev = bf16_sliced_cpu.to_device(x.device())?.to_dtype(dtype)?;
let w = w_dev.t()?;
let out = match *x.dims() {
[b1, b2, m, k] => {
x.reshape((b1 * b2 * m, k))?
.matmul(&w)?
.reshape((b1, b2, m, ()))?
}
[bsize, m, k] => {
x.reshape((bsize * m, k))?
.matmul(&w)?
.reshape((bsize, m, ()))?
}
_ => x.matmul(&w)?,
};
match bias {
Some(b) => out.broadcast_add(&b.to_dtype(dtype)?),
None => Ok(out),
}
}
}
}
}
/// Loads a bias-free [`Flux2Linear`] mapping `in_dim` features to `out_dim`.
fn flux2_linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Flux2Linear> {
    const HAS_BIAS: bool = false;
    Flux2Linear::load_with_bias(in_dim, out_dim, HAS_BIAS, vb)
}
/// Architecture hyper-parameters for a Flux.2 transformer.
#[derive(Debug, Clone)]
pub struct Flux2Config {
    /// Channels of the packed latent tokens (also the output width).
    pub in_channels: usize,
    /// Width of the optional pooled vector input; 0 disables `vector_in`.
    pub vec_in_dim: usize,
    /// Width of the text-encoder context embeddings.
    pub context_in_dim: usize,
    pub hidden_size: usize,
    /// MLP hidden width as a multiple of `hidden_size`.
    pub mlp_ratio: f64,
    pub num_heads: usize,
    /// Number of double-stream blocks.
    pub depth: usize,
    /// Number of single-stream blocks.
    pub depth_single_blocks: usize,
    /// Per-axis rotary embedding widths.
    pub axes_dim: Vec<usize>,
    /// Rotary base frequency.
    pub theta: usize,
    /// Whether a guidance embedder is present.
    pub guidance_embed: bool,
}
impl Flux2Config {
    /// Configuration of the Flux.2 "klein" model.
    pub fn klein() -> Self {
        Self {
            in_channels: 128,
            vec_in_dim: 0,
            context_in_dim: 7680,
            hidden_size: 3072,
            mlp_ratio: 3.0,
            num_heads: 24,
            depth: 5,
            depth_single_blocks: 20,
            axes_dim: vec![32, 32, 32, 32],
            theta: 2000,
            guidance_embed: false,
        }
    }
    /// Configuration of the 9B "klein" variant: `klein()` with a wider, deeper trunk.
    pub fn klein_9b() -> Self {
        Self {
            context_in_dim: 12288,
            hidden_size: 4096,
            num_heads: 32,
            depth: 8,
            depth_single_blocks: 24,
            ..Self::klein()
        }
    }
}
/// Builds a LayerNorm with a unit weight and no bias (eps 1e-6) on `vb`'s
/// device/dtype; nothing is read from the checkpoint.
fn layer_norm(dim: usize, vb: &VarBuilder) -> Result<LayerNorm> {
    let (dtype, device) = (vb.dtype(), vb.device());
    Tensor::ones(dim, dtype, device).map(|ones| LayerNorm::new_no_bias(ones, 1e-6))
}
pub(crate) fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
crate::attention::attention_default_scale(q, k, v)
}
/// Builds rotary 2×2 rotation matrices for positions `pos`.
///
/// The frequency for pair `j` is `theta^(-2j / dim)`. The result has shape
/// `(b, n, dim / 2, 2, 2)`, where `(b, n)` come from `pos`'s dims and each
/// trailing 2×2 is `[[cos, -sin], [sin, cos]]`.
///
/// # Errors
/// Fails if `dim` is odd.
pub(crate) fn rope(pos: &Tensor, dim: usize, theta: usize) -> Result<Tensor> {
    if dim % 2 != 0 {
        candle_core::bail!("dim {dim} is odd")
    }
    let base = theta as f64;
    let half = dim / 2;
    let inv_freq: Vec<f32> = (0..half)
        .map(|j| 1f32 / base.powf((2 * j) as f64 / dim as f64) as f32)
        .collect();
    let inv_freq =
        Tensor::from_vec(inv_freq, (1, 1, half), pos.device())?.to_dtype(pos.dtype())?;
    let angles = pos.unsqueeze(2)?.broadcast_mul(&inv_freq)?;
    let (cos, sin) = (angles.cos()?, angles.sin()?);
    let neg_sin = sin.neg()?;
    let rotation = Tensor::stack(&[&cos, &neg_sin, &sin, &cos], 3)?;
    let (b, n, d, _) = rotation.dims4()?;
    rotation.reshape((b, n, d, 2, 2))
}
/// Rotates consecutive feature pairs of the rank-4 tensor `x` by the matrices
/// in `freq_cis` (as produced by [`rope`]), returning a tensor of `x`'s shape.
pub(crate) fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
    let original_shape = x.dims().to_vec();
    let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
    let pairs = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
    let first = pairs.narrow(D::Minus1, 0, 1)?;
    let second = pairs.narrow(D::Minus1, 1, 1)?;
    let col0 = freq_cis.get_on_dim(D::Minus1, 0)?;
    let col1 = freq_cis.get_on_dim(D::Minus1, 1)?;
    let rotated = (col0.broadcast_mul(&first)? + col1.broadcast_mul(&second)?)?;
    rotated.reshape(original_shape)
}
/// Applies RoPE to `q`/`k`, runs attention, then merges heads: the output is
/// transposed on dims 1/2 and flattened from dim 2.
pub(crate) fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
    let rotated_q = apply_rope(q, pe)?.contiguous()?;
    let rotated_k = apply_rope(k, pe)?.contiguous()?;
    let attended = scaled_dot_product_attention(&rotated_q, &rotated_k, v)?;
    let merged = attended.transpose(1, 2)?;
    merged.flatten_from(2)
}
/// Sinusoidal embedding of timesteps `t`, scaled by 1000.
///
/// Frequencies are computed in F32 and the `[cos | sin]` halves are concatenated
/// on the last dim before casting to `dtype`.
///
/// # Errors
/// Fails if `dim` is odd.
pub(crate) fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
    const TIME_FACTOR: f64 = 1000.;
    const MAX_PERIOD: f64 = 10000.;
    if dim % 2 == 1 {
        candle_core::bail!("{dim} is odd")
    }
    let half = dim / 2;
    let scaled_t = (t * TIME_FACTOR)?;
    let freqs = {
        let idx = Tensor::arange(0, half as u32, t.device())?.to_dtype(DType::F32)?;
        (idx * (-MAX_PERIOD.ln() / half as f64))?.exp()?
    };
    let args = scaled_t
        .unsqueeze(1)?
        .to_dtype(DType::F32)?
        .broadcast_mul(&freqs.unsqueeze(0)?)?;
    let (cos, sin) = (args.cos()?, args.sin()?);
    Tensor::cat(&[cos, sin], D::Minus1)?.to_dtype(dtype)
}
/// Multi-axis rotary position embedder: one [`rope`] table per position axis,
/// concatenated along dim 2.
#[derive(Debug, Clone)]
pub(crate) struct EmbedNd {
    theta: usize,
    /// Rotary width for each position axis, in axis order.
    axes_dim: Vec<usize>,
}
impl EmbedNd {
    /// Creates an embedder with base frequency `theta` and per-axis widths `axes_dim`.
    pub(crate) fn new(theta: usize, axes_dim: Vec<usize>) -> Self {
        Self { theta, axes_dim }
    }
}
impl candle_core::Module for EmbedNd {
    /// Embeds position ids whose last dim indexes the position axes.
    ///
    /// # Errors
    /// Returns an error (rather than panicking) when `ids` has more axes than
    /// configured in `axes_dim`.
    fn forward(&self, ids: &Tensor) -> Result<Tensor> {
        let n_axes = ids.dim(D::Minus1)?;
        if n_axes > self.axes_dim.len() {
            candle_core::bail!(
                "EmbedNd: ids have {} position axes but only {} axis dims are configured",
                n_axes,
                self.axes_dim.len()
            );
        }
        let emb = self.axes_dim[..n_axes]
            .iter()
            .enumerate()
            .map(|(idx, &axis_dim)| rope(&ids.get_on_dim(D::Minus1, idx)?, axis_dim, self.theta))
            .collect::<Result<Vec<_>>>()?;
        Tensor::cat(&emb, 2)?.unsqueeze(1)
    }
}
/// Two-layer bias-free MLP with SiLU (`linear_1` → SiLU → `linear_2`), used to
/// embed timesteps/guidance.
#[derive(Debug, Clone)]
struct MlpEmbedder {
    in_layer: Flux2Linear,
    out_layer: Flux2Linear,
}
impl MlpEmbedder {
    /// Loads `linear_1` (`in_sz → h_sz`) and `linear_2` (`h_sz → h_sz`) from `vb`.
    fn new(in_sz: usize, h_sz: usize, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            in_layer: flux2_linear_no_bias(in_sz, h_sz, vb.pp("linear_1"))?,
            out_layer: flux2_linear_no_bias(h_sz, h_sz, vb.pp("linear_2"))?,
        })
    }
    /// Returns a copy with both layers moved to `device`.
    fn to_device(&self, device: &candle_core::Device) -> Result<Self> {
        let in_layer = self.in_layer.to_device(device)?;
        let out_layer = self.out_layer.to_device(device)?;
        Ok(Self {
            in_layer,
            out_layer,
        })
    }
}
impl candle_core::Module for MlpEmbedder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let hidden = self.in_layer.forward(xs)?.silu()?;
        self.out_layer.forward(&hidden)
    }
}
/// One adaptive-norm modulation triple: `shift`, `scale` and residual `gate`.
struct ModulationOut {
    shift: Tensor,
    scale: Tensor,
    gate: Tensor,
}
impl ModulationOut {
    /// Computes `xs * (1 + scale) + shift` with broadcasting.
    fn scale_shift(&self, xs: &Tensor) -> Result<Tensor> {
        let one_plus_scale = (&self.scale + 1.)?;
        let scaled = xs.broadcast_mul(&one_plus_scale)?;
        scaled.broadcast_add(&self.shift)
    }
    /// Multiplies `xs` by the gate with broadcasting.
    fn gate(&self, xs: &Tensor) -> Result<Tensor> {
        self.gate.broadcast_mul(xs)
    }
}
/// Produces a single [`ModulationOut`] from the conditioning vector.
#[derive(Debug, Clone)]
struct Modulation1 {
    lin: Flux2Linear,
}
impl Modulation1 {
    /// Loads the `linear` projection (`dim → 3 * dim`).
    fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
        flux2_linear_no_bias(dim, 3 * dim, vb.pp("linear")).map(|lin| Self { lin })
    }
    /// Returns a copy with the projection moved to `device`.
    fn to_device(&self, device: &candle_core::Device) -> Result<Self> {
        let lin = self.lin.to_device(device)?;
        Ok(Self { lin })
    }
    /// Projects `silu(vec_)`, adds a sequence axis, and splits the last dim into
    /// `(shift, scale, gate)`.
    fn forward(&self, vec_: &Tensor) -> Result<ModulationOut> {
        let projected = self.lin.forward(&vec_.silu()?)?.unsqueeze(1)?;
        let ys = projected.chunk(3, D::Minus1)?;
        match ys.as_slice() {
            [shift, scale, gate] => Ok(ModulationOut {
                shift: shift.clone(),
                scale: scale.clone(),
                gate: gate.clone(),
            }),
            _ => {
                candle_core::bail!("unexpected len from chunk {ys:?}")
            }
        }
    }
}
/// Produces two [`ModulationOut`]s (attention branch, then MLP branch) from the
/// conditioning vector.
#[derive(Debug, Clone)]
struct Modulation2 {
    lin: Flux2Linear,
}
impl Modulation2 {
    /// Loads the `linear` projection (`dim → 6 * dim`).
    fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
        flux2_linear_no_bias(dim, 6 * dim, vb.pp("linear")).map(|lin| Self { lin })
    }
    /// Returns a copy with the projection moved to `device`.
    fn to_device(&self, device: &candle_core::Device) -> Result<Self> {
        let lin = self.lin.to_device(device)?;
        Ok(Self { lin })
    }
    /// Projects `silu(vec_)`, adds a sequence axis, and splits the last dim into
    /// two `(shift, scale, gate)` triples.
    fn forward(&self, vec_: &Tensor) -> Result<(ModulationOut, ModulationOut)> {
        let projected = self.lin.forward(&vec_.silu()?)?.unsqueeze(1)?;
        let ys = projected.chunk(6, D::Minus1)?;
        match ys.as_slice() {
            [shift1, scale1, gate1, shift2, scale2, gate2] => Ok((
                ModulationOut {
                    shift: shift1.clone(),
                    scale: scale1.clone(),
                    gate: gate1.clone(),
                },
                ModulationOut {
                    shift: shift2.clone(),
                    scale: scale2.clone(),
                    gate: gate2.clone(),
                },
            )),
            _ => {
                candle_core::bail!("unexpected len from chunk {ys:?}")
            }
        }
    }
}
#[derive(Debug, Clone)]
struct Mlp {
lin1: Flux2Linear,
lin2: Flux2Linear,
mlp_sz: usize,
}
impl Mlp {
fn new(in_sz: usize, mlp_sz: usize, vb: VarBuilder) -> Result<Self> {
let lin1 = flux2_linear_no_bias(in_sz, mlp_sz * 2, vb.pp("linear_in"))?;
let lin2 = flux2_linear_no_bias(mlp_sz, in_sz, vb.pp("linear_out"))?;
Ok(Self { lin1, lin2, mlp_sz })
}
fn to_device(&self, device: &candle_core::Device) -> Result<Self> {
Ok(Self {
lin1: self.lin1.to_device(device)?,
lin2: self.lin2.to_device(device)?,
mlp_sz: self.mlp_sz,
})
}
}
fn mlp_bytes(mlp: &Mlp) -> usize {
flux2_linear_bytes(&mlp.lin1) + flux2_linear_bytes(&mlp.lin2)
}
impl candle_core::Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let x = xs.apply(&self.lin1)?;
let gate = x.narrow(D::Minus1, 0, self.mlp_sz)?.silu()?;
let val = x.narrow(D::Minus1, self.mlp_sz, self.mlp_sz)?;
(gate * val)?.apply(&self.lin2)
}
}
#[derive(Debug, Clone)]
struct DoubleAttention {
to_q: Flux2Linear,
to_k: Flux2Linear,
to_v: Flux2Linear,
to_out: Flux2Linear,
norm_q: RmsNorm,
norm_k: RmsNorm,
num_heads: usize,
}
impl DoubleAttention {
fn new_img(dim: usize, num_heads: usize, vb: VarBuilder) -> Result<Self> {
let head_dim = dim / num_heads;
Ok(Self {
to_q: flux2_linear_no_bias(dim, dim, vb.pp("to_q"))?,
to_k: flux2_linear_no_bias(dim, dim, vb.pp("to_k"))?,
to_v: flux2_linear_no_bias(dim, dim, vb.pp("to_v"))?,
to_out: flux2_linear_no_bias(dim, dim, vb.pp("to_out").pp("0"))?,
norm_q: RmsNorm::new(vb.get(head_dim, "norm_q.weight")?, 1e-6),
norm_k: RmsNorm::new(vb.get(head_dim, "norm_k.weight")?, 1e-6),
num_heads,
})
}
fn new_txt(dim: usize, num_heads: usize, vb: VarBuilder) -> Result<Self> {
let head_dim = dim / num_heads;
Ok(Self {
to_q: flux2_linear_no_bias(dim, dim, vb.pp("add_q_proj"))?,
to_k: flux2_linear_no_bias(dim, dim, vb.pp("add_k_proj"))?,
to_v: flux2_linear_no_bias(dim, dim, vb.pp("add_v_proj"))?,
to_out: flux2_linear_no_bias(dim, dim, vb.pp("to_add_out"))?,
norm_q: RmsNorm::new(vb.get(head_dim, "norm_added_q.weight")?, 1e-6),
norm_k: RmsNorm::new(vb.get(head_dim, "norm_added_k.weight")?, 1e-6),
num_heads,
})
}
fn qkv(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
let (b, l, _) = xs.dims3()?;
let q = xs
.apply(&self.to_q)?
.reshape((b, l, self.num_heads, ()))?
.transpose(1, 2)?
.apply(&self.norm_q)?;
let k = xs
.apply(&self.to_k)?
.reshape((b, l, self.num_heads, ()))?
.transpose(1, 2)?
.apply(&self.norm_k)?;
let v = xs
.apply(&self.to_v)?
.reshape((b, l, self.num_heads, ()))?
.transpose(1, 2)?;
Ok((q, k, v))
}
fn to_device(&self, device: &candle_core::Device) -> Result<Self> {
Ok(Self {
to_q: self.to_q.to_device(device)?,
to_k: self.to_k.to_device(device)?,
to_v: self.to_v.to_device(device)?,
to_out: self.to_out.to_device(device)?,
norm_q: rms_norm_to_device(&self.norm_q, device)?,
norm_k: rms_norm_to_device(&self.norm_k, device)?,
num_heads: self.num_heads,
})
}
}
fn double_attention_bytes(attention: &DoubleAttention) -> usize {
flux2_linear_bytes(&attention.to_q)
+ flux2_linear_bytes(&attention.to_k)
+ flux2_linear_bytes(&attention.to_v)
+ flux2_linear_bytes(&attention.to_out)
+ rms_norm_bytes(&attention.norm_q)
+ rms_norm_bytes(&attention.norm_k)
}
/// Double-stream transformer block: image and text tokens keep separate
/// projections/MLPs but attend jointly over the concatenated `[txt, img]` sequence.
#[derive(Debug, Clone)]
struct DoubleStreamBlock {
    img_norm1: LayerNorm,
    img_attn: DoubleAttention,
    img_norm2: LayerNorm,
    img_mlp: Mlp,
    txt_attn: DoubleAttention,
    txt_norm1: LayerNorm,
    txt_norm2: LayerNorm,
    txt_mlp: Mlp,
}
impl DoubleStreamBlock {
    /// Builds the block from the weights under `vb`.
    ///
    /// The four LayerNorms come from `layer_norm` (unit weight, no bias) and are
    /// not read from the checkpoint. Both attention halves load from the `attn`
    /// prefix; the MLPs load from `ff` (image) and `ff_context` (text).
    fn new(cfg: &Flux2Config, vb: VarBuilder) -> Result<Self> {
        let h_sz = cfg.hidden_size;
        let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
        let attn_vb = vb.pp("attn");
        Ok(Self {
            img_norm1: layer_norm(h_sz, &vb)?,
            img_attn: DoubleAttention::new_img(h_sz, cfg.num_heads, attn_vb.clone())?,
            img_norm2: layer_norm(h_sz, &vb)?,
            img_mlp: Mlp::new(h_sz, mlp_sz, vb.pp("ff"))?,
            txt_attn: DoubleAttention::new_txt(h_sz, cfg.num_heads, attn_vb)?,
            txt_norm1: layer_norm(h_sz, &vb)?,
            txt_norm2: layer_norm(h_sz, &vb)?,
            txt_mlp: Mlp::new(h_sz, mlp_sz, vb.pp("ff_context"))?,
        })
    }
    /// Returns a copy of the block with every parameter moved to `device`.
    fn to_device(&self, device: &candle_core::Device) -> Result<Self> {
        Ok(Self {
            img_norm1: layer_norm_to_device(&self.img_norm1, device)?,
            img_attn: self.img_attn.to_device(device)?,
            img_norm2: layer_norm_to_device(&self.img_norm2, device)?,
            img_mlp: self.img_mlp.to_device(device)?,
            txt_attn: self.txt_attn.to_device(device)?,
            txt_norm1: layer_norm_to_device(&self.txt_norm1, device)?,
            txt_norm2: layer_norm_to_device(&self.txt_norm2, device)?,
            txt_mlp: self.txt_mlp.to_device(device)?,
        })
    }
    /// Runs the block and returns the updated `(img, txt)` token streams.
    ///
    /// `*_mod1` modulates/gates the attention branch and `*_mod2` the MLP
    /// branch. `pe` is the rotary table for the concatenated `[txt, img]`
    /// sequence.
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        img: &Tensor,
        txt: &Tensor,
        img_mod1: &ModulationOut,
        img_mod2: &ModulationOut,
        txt_mod1: &ModulationOut,
        txt_mod2: &ModulationOut,
        pe: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
        let img_modulated = img_mod1.scale_shift(&img.apply(&self.img_norm1)?)?;
        let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?;
        let txt_modulated = txt_mod1.scale_shift(&txt.apply(&self.txt_norm1)?)?;
        let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?;
        // Joint attention: text tokens first, then image tokens (sequence dim 2).
        let q = Tensor::cat(&[txt_q, img_q], 2)?;
        let k = Tensor::cat(&[txt_k, img_k], 2)?;
        let v = Tensor::cat(&[txt_v, img_v], 2)?;
        let attn = attention(&q, &k, &v, pe)?;
        // Split the joint output back into the text prefix and image suffix.
        let txt_attn_out = attn.narrow(1, 0, txt.dim(1)?)?;
        let img_attn_out = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?;
        // Image residuals: gated attention output, then gated modulated MLP.
        let img = (img + img_mod1.gate(&img_attn_out.apply(&self.img_attn.to_out)?))?;
        let img = (&img
            + img_mod2.gate(
                &img_mod2
                    .scale_shift(&img.apply(&self.img_norm2)?)?
                    .apply(&self.img_mlp)?,
            )?)?;
        // Text residuals: same structure with the text-stream weights.
        let txt = (txt + txt_mod1.gate(&txt_attn_out.apply(&self.txt_attn.to_out)?))?;
        let txt = (&txt
            + txt_mod2.gate(
                &txt_mod2
                    .scale_shift(&txt.apply(&self.txt_norm2)?)?
                    .apply(&self.txt_mlp)?,
            )?)?;
        Ok((img, txt))
    }
}
/// Total bytes held by every norm, attention half and MLP of a [`DoubleStreamBlock`].
fn double_stream_block_bytes(block: &DoubleStreamBlock) -> usize {
    layer_norm_bytes(&block.img_norm1)
        + double_attention_bytes(&block.img_attn)
        + layer_norm_bytes(&block.img_norm2)
        + mlp_bytes(&block.img_mlp)
        + double_attention_bytes(&block.txt_attn)
        + layer_norm_bytes(&block.txt_norm1)
        + layer_norm_bytes(&block.txt_norm2)
        + mlp_bytes(&block.txt_mlp)
}
/// Single-stream transformer block: one fused projection produces q/k/v and
/// the gated-MLP input; attention and MLP outputs are concatenated and
/// projected back by `linear2`.
#[derive(Debug, Clone)]
struct SingleStreamBlock {
    linear1: Flux2Linear,
    linear2: Flux2Linear,
    norm_q: RmsNorm,
    norm_k: RmsNorm,
    pre_norm: LayerNorm,
    h_sz: usize,
    mlp_sz: usize,
    num_heads: usize,
}
impl SingleStreamBlock {
    /// Loads the block from `vb` (`attn.to_qkv_mlp_proj`, `attn.to_out`,
    /// `attn.norm_q`, `attn.norm_k`); `pre_norm` is a unit-weight LayerNorm.
    fn new(cfg: &Flux2Config, vb: VarBuilder) -> Result<Self> {
        let h_sz = cfg.hidden_size;
        let num_heads = cfg.num_heads;
        let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
        let head_dim = h_sz / num_heads;
        let attn_vb = vb.pp("attn");
        // Fused output: q, k, v (3 * h_sz) followed by gate and value (2 * mlp_sz).
        let linear1 =
            flux2_linear_no_bias(h_sz, 3 * h_sz + 2 * mlp_sz, attn_vb.pp("to_qkv_mlp_proj"))?;
        let linear2 = flux2_linear_no_bias(h_sz + mlp_sz, h_sz, attn_vb.pp("to_out"))?;
        let norm_q = RmsNorm::new(attn_vb.get(head_dim, "norm_q.weight")?, 1e-6);
        let norm_k = RmsNorm::new(attn_vb.get(head_dim, "norm_k.weight")?, 1e-6);
        let pre_norm = layer_norm(h_sz, &vb)?;
        Ok(Self {
            linear1,
            linear2,
            norm_q,
            norm_k,
            pre_norm,
            h_sz,
            mlp_sz,
            num_heads,
        })
    }
    /// Returns a copy with every parameter moved to `device`.
    fn to_device(&self, device: &candle_core::Device) -> Result<Self> {
        let linear1 = self.linear1.to_device(device)?;
        let linear2 = self.linear2.to_device(device)?;
        let norm_q = rms_norm_to_device(&self.norm_q, device)?;
        let norm_k = rms_norm_to_device(&self.norm_k, device)?;
        let pre_norm = layer_norm_to_device(&self.pre_norm, device)?;
        Ok(Self {
            linear1,
            linear2,
            norm_q,
            norm_k,
            pre_norm,
            h_sz: self.h_sz,
            mlp_sz: self.mlp_sz,
            num_heads: self.num_heads,
        })
    }
    /// Applies the block to `xs` (rank 3) with a gated residual connection.
    fn forward(&self, xs: &Tensor, mod_out: &ModulationOut, pe: &Tensor) -> Result<Tensor> {
        let normed = xs.apply(&self.pre_norm)?;
        let fused = self.linear1.forward(&mod_out.scale_shift(&normed)?)?;
        let qkv = fused.narrow(D::Minus1, 0, 3 * self.h_sz)?;
        let (b, l, _) = qkv.dims3()?;
        let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
        let head = |i: usize| -> Result<Tensor> { qkv.i((.., .., i))?.transpose(1, 2) };
        let q = head(0)?.apply(&self.norm_q)?;
        let k = head(1)?.apply(&self.norm_k)?;
        let v = head(2)?;
        let mlp_portion = fused.narrow(D::Minus1, 3 * self.h_sz, 2 * self.mlp_sz)?;
        let attn = attention(&q, &k, &v, pe)?;
        let gate = mlp_portion.narrow(D::Minus1, 0, self.mlp_sz)?.silu()?;
        let value = mlp_portion.narrow(D::Minus1, self.mlp_sz, self.mlp_sz)?;
        let gated = (gate * value)?;
        let projected = self.linear2.forward(&Tensor::cat(&[attn, gated], 2)?)?;
        xs + mod_out.gate(&projected)
    }
}
/// Total bytes held by the projections and norms of a [`SingleStreamBlock`].
fn single_stream_block_bytes(block: &SingleStreamBlock) -> usize {
    let projections = flux2_linear_bytes(&block.linear1) + flux2_linear_bytes(&block.linear2);
    let norms = rms_norm_bytes(&block.norm_q)
        + rms_norm_bytes(&block.norm_k)
        + layer_norm_bytes(&block.pre_norm);
    projections + norms
}
/// Output head: adaptive LayerNorm (scale/shift from the conditioning vector)
/// followed by the `proj_out` projection.
#[derive(Debug, Clone)]
struct LastLayer {
    norm_final: LayerNorm,
    linear: Flux2Linear,
    ada_ln_modulation: Flux2Linear,
}
impl LastLayer {
    /// Loads `proj_out` (`h_sz → out_c`) and `norm_out.linear` (`h_sz → 2 * h_sz`).
    fn new(h_sz: usize, out_c: usize, vb: VarBuilder) -> Result<Self> {
        let norm_final = layer_norm(h_sz, &vb)?;
        let linear = flux2_linear_no_bias(h_sz, out_c, vb.pp("proj_out"))?;
        let ada_ln_modulation =
            flux2_linear_no_bias(h_sz, 2 * h_sz, vb.pp("norm_out").pp("linear"))?;
        Ok(Self {
            norm_final,
            linear,
            ada_ln_modulation,
        })
    }
    /// Returns a copy with every parameter moved to `device`.
    fn to_device(&self, device: &candle_core::Device) -> Result<Self> {
        let norm_final = layer_norm_to_device(&self.norm_final, device)?;
        let linear = self.linear.to_device(device)?;
        let ada_ln_modulation = self.ada_ln_modulation.to_device(device)?;
        Ok(Self {
            norm_final,
            linear,
            ada_ln_modulation,
        })
    }
    /// Computes `proj_out(norm(xs) * (1 + scale) + shift)`, where `(scale, shift)`
    /// are the two halves (split on dim 1) of `ada_ln_modulation(silu(vec))`.
    fn forward(&self, xs: &Tensor, vec: &Tensor) -> Result<Tensor> {
        let modulation = self.ada_ln_modulation.forward(&vec.silu()?)?;
        let halves = modulation.chunk(2, 1)?;
        let scale = halves[0].unsqueeze(1)?;
        let shift = halves[1].unsqueeze(1)?;
        let modulated = xs
            .apply(&self.norm_final)?
            .broadcast_mul(&(scale + 1.0)?)?
            .broadcast_add(&shift)?;
        self.linear.forward(&modulated)
    }
}
/// Dense Flux.2 transformer with every layer on the device its `VarBuilder`
/// was created for.
#[derive(Debug, Clone)]
pub struct Flux2Transformer {
    /// Projects latent tokens into the hidden size (`x_embedder`).
    img_in: Flux2Linear,
    /// Projects text-encoder context into the hidden size (`context_embedder`).
    txt_in: Flux2Linear,
    /// Timestep embedder.
    time_in: MlpEmbedder,
    /// Present only when `vec_in_dim > 0`.
    vector_in: Option<MlpEmbedder>,
    /// Present only when `guidance_embed` is set.
    guidance_in: Option<MlpEmbedder>,
    pe_embedder: EmbedNd,
    /// Modulation shared by all double-stream blocks (image stream).
    double_mod_img: Modulation2,
    /// Modulation shared by all double-stream blocks (text stream).
    double_mod_txt: Modulation2,
    /// Modulation shared by all single-stream blocks.
    single_mod: Modulation1,
    double_blocks: Vec<DoubleStreamBlock>,
    single_blocks: Vec<SingleStreamBlock>,
    final_layer: LastLayer,
}
/// Identifies one transformer block by kind and index within its kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Flux2StreamingBlock {
    Double(usize),
    Single(usize),
}
/// Returns the execution order of all blocks: every double-stream block in
/// index order, followed by every single-stream block in index order.
pub(crate) fn flux2_streaming_block_plan(cfg: &Flux2Config) -> Vec<Flux2StreamingBlock> {
    let doubles = (0..cfg.depth).map(Flux2StreamingBlock::Double);
    let singles = (0..cfg.depth_single_blocks).map(Flux2StreamingBlock::Single);
    doubles.chain(singles).collect()
}
/// A double-stream block either already on the compute device (`Resident`) or
/// left as originally loaded and copied to the device per forward pass (`Streamed`).
enum DoubleBlockSlot {
    Resident(DoubleStreamBlock),
    Streamed(DoubleStreamBlock),
}
/// A single-stream block either already on the compute device (`Resident`) or
/// left as originally loaded and copied to the device per forward pass (`Streamed`).
enum SingleBlockSlot {
    Resident(SingleStreamBlock),
    Streamed(SingleStreamBlock),
}
/// Heuristically detects out-of-memory failures by case-insensitively searching
/// the error message for known OOM phrases.
fn is_probable_cuda_oom(err: &candle_core::Error) -> bool {
    const OOM_MARKERS: [&str; 3] = [
        "cuda_error_out_of_memory",
        "out of memory",
        "memory allocation",
    ];
    let lowered = err.to_string().to_ascii_lowercase();
    OOM_MARKERS.iter().any(|marker| lowered.contains(*marker))
}
/// Moves the blocks marked resident in `plan` onto `device` and keeps clones of
/// the rest as streamed slots.
///
/// `plan.resident` is indexed with double blocks first and single blocks after
/// them; an index missing from the plan means "streamed".
///
/// # Errors
/// Stops at, and returns, the first failing device copy.
fn materialize_flux2_block_slots(
    double_blocks: &[DoubleStreamBlock],
    single_blocks: &[SingleStreamBlock],
    plan: &AdaptiveResidencyPlan,
    device: &candle_core::Device,
) -> Result<(Vec<DoubleBlockSlot>, Vec<SingleBlockSlot>)> {
    let is_resident = |idx: usize| plan.resident.get(idx).copied().unwrap_or(false);
    let double_slots = double_blocks
        .iter()
        .enumerate()
        .map(|(i, block)| -> Result<DoubleBlockSlot> {
            if is_resident(i) {
                Ok(DoubleBlockSlot::Resident(block.to_device(device)?))
            } else {
                Ok(DoubleBlockSlot::Streamed(block.clone()))
            }
        })
        .collect::<Result<Vec<_>>>()?;
    let single_offset = double_blocks.len();
    let single_slots = single_blocks
        .iter()
        .enumerate()
        .map(|(i, block)| -> Result<SingleBlockSlot> {
            if is_resident(single_offset + i) {
                Ok(SingleBlockSlot::Resident(block.to_device(device)?))
            } else {
                Ok(SingleBlockSlot::Streamed(block.clone()))
            }
        })
        .collect::<Result<Vec<_>>>()?;
    Ok((double_slots, single_slots))
}
/// Flux.2 transformer that keeps a subset of its blocks resident on `device`
/// and copies the remaining blocks (as loaded from `cpu_vb`) onto `device` for
/// each forward pass.
///
/// All non-block layers (embedders, modulations, final layer) are always moved
/// to `device`. Block residency is decided once at construction time by
/// [`plan_adaptive_residency`].
pub(crate) struct OffloadedFlux2Transformer {
    /// Execution order of all blocks (double-stream first, then single-stream).
    block_plan: Vec<Flux2StreamingBlock>,
    img_in: Flux2Linear,
    txt_in: Flux2Linear,
    time_in: MlpEmbedder,
    vector_in: Option<MlpEmbedder>,
    guidance_in: Option<MlpEmbedder>,
    pe_embedder: EmbedNd,
    double_mod_img: Modulation2,
    double_mod_txt: Modulation2,
    single_mod: Modulation1,
    double_blocks: Vec<DoubleBlockSlot>,
    single_blocks: Vec<SingleBlockSlot>,
    final_layer: LastLayer,
    /// Device the forward pass computes on.
    device: candle_core::Device,
}
impl OffloadedFlux2Transformer {
    /// Loads the dense model from `cpu_vb` and distributes it between `device`
    /// and the loaded copy according to the free VRAM of GPU `gpu_ordinal`.
    ///
    /// `activation_budget` is passed to [`plan_adaptive_residency`].
    pub(crate) fn new(
        cfg: &Flux2Config,
        cpu_vb: VarBuilder,
        device: &candle_core::Device,
        gpu_ordinal: usize,
        activation_budget: u64,
        progress: &ProgressReporter,
    ) -> Result<Self> {
        let block_plan = flux2_streaming_block_plan(cfg);
        let dense = Flux2Transformer::new(cfg, cpu_vb)?;
        Self::from_dense(
            dense,
            block_plan,
            device,
            gpu_ordinal,
            activation_budget,
            progress,
        )
    }
    /// Moves all non-block layers to `device`, then materializes resident blocks.
    ///
    /// If a CUDA allocation fails with what looks like OOM while resident
    /// blocks remain, the largest resident block is demoted to streamed and
    /// materialization is retried until it succeeds or nothing can be demoted.
    fn from_dense(
        dense: Flux2Transformer,
        block_plan: Vec<Flux2StreamingBlock>,
        device: &candle_core::Device,
        gpu_ordinal: usize,
        activation_budget: u64,
        progress: &ProgressReporter,
    ) -> Result<Self> {
        let Flux2Transformer {
            img_in,
            txt_in,
            time_in,
            vector_in,
            guidance_in,
            pe_embedder,
            double_mod_img,
            double_mod_txt,
            single_mod,
            double_blocks,
            single_blocks,
            final_layer,
        } = dense;
        let img_in = img_in.to_device(device)?;
        let txt_in = txt_in.to_device(device)?;
        let time_in = time_in.to_device(device)?;
        let vector_in = vector_in
            .as_ref()
            .map(|embedder| embedder.to_device(device))
            .transpose()?;
        let guidance_in = guidance_in
            .as_ref()
            .map(|embedder| embedder.to_device(device))
            .transpose()?;
        let double_mod_img = double_mod_img.to_device(device)?;
        let double_mod_txt = double_mod_txt.to_device(device)?;
        let single_mod = single_mod.to_device(device)?;
        let final_layer = final_layer.to_device(device)?;
        // Per-block sizes, indexed double blocks first, then single blocks.
        let mut block_sizes = Vec::with_capacity(double_blocks.len() + single_blocks.len());
        block_sizes.extend(double_blocks.iter().map(double_stream_block_bytes));
        block_sizes.extend(single_blocks.iter().map(single_stream_block_bytes));
        let free_vram = crate::device::usable_free_vram_bytes(gpu_ordinal).unwrap_or(0);
        let mut plan = plan_adaptive_residency(
            &block_sizes,
            free_vram,
            activation_budget,
            ADAPTIVE_OFFLOAD_RUNTIME_HEADROOM,
        );
        let (double_blocks, single_blocks, plan) = loop {
            match materialize_flux2_block_slots(&double_blocks, &single_blocks, &plan, device) {
                Ok((double_slots, single_slots)) => break (double_slots, single_slots, plan),
                Err(err)
                    if device.is_cuda()
                        && plan.resident_count() > 0
                        && is_probable_cuda_oom(&err) =>
                {
                    progress.info(&format!(
                        "Flux.2 adaptive offload: resident allocation OOM at {} resident blocks; \
                         retrying with fewer resident blocks",
                        plan.resident_count()
                    ));
                    // Best-effort: let pending frees settle before retrying.
                    if let Err(sync_err) = device.synchronize() {
                        tracing::warn!(
                            "Flux.2 adaptive offload: synchronize after OOM failed: {sync_err}"
                        );
                    }
                    if !plan.demote_largest_resident(&block_sizes) {
                        return Err(err);
                    }
                }
                Err(err) => return Err(err),
            }
        };
        progress.info(&format!(
            "Flux.2 adaptive offload: {} resident / {} streamed blocks \
             (resident {:.2} GB, streamed {:.2} GB per denoise pass, reserve {:.2} GB)",
            plan.resident_count(),
            plan.streamed_count(),
            plan.resident_bytes as f64 / 1_000_000_000.0,
            plan.streamed_bytes as f64 / 1_000_000_000.0,
            plan.reserved_bytes() as f64 / 1_000_000_000.0,
        ));
        Ok(Self {
            block_plan,
            img_in,
            txt_in,
            time_in,
            vector_in,
            guidance_in,
            pe_embedder,
            double_mod_img,
            double_mod_txt,
            single_mod,
            double_blocks,
            single_blocks,
            final_layer,
            device: device.clone(),
        })
    }
    /// Runs the transformer; all inputs are first copied to `self.device`.
    ///
    /// Streamed blocks are copied to the device, executed, and dropped after a
    /// device synchronize, so at most one streamed block is on the device at a time.
    ///
    /// # Errors
    /// Fails if `txt` or `img` is not rank 3, or on any tensor/device error.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn forward(
        &self,
        img: &Tensor,
        img_ids: &Tensor,
        txt: &Tensor,
        txt_ids: &Tensor,
        timesteps: &Tensor,
        y: &Tensor,
        guidance: Option<&Tensor>,
    ) -> Result<Tensor> {
        if txt.rank() != 3 || img.rank() != 3 {
            candle_core::bail!("expected rank 3, got txt={} img={}", txt.rank(), img.rank())
        }
        let device = &self.device;
        let dtype = img.dtype();
        let img = img.to_device(device)?;
        let txt = txt.to_device(device)?;
        let img_ids = img_ids.to_device(device)?;
        let txt_ids = txt_ids.to_device(device)?;
        let timesteps = timesteps.to_device(device)?;
        let y = y.to_device(device)?;
        let guidance = guidance.map(|g| g.to_device(device)).transpose()?;
        // Rotary table for the joint [txt, img] token sequence.
        let pe = {
            let ids = Tensor::cat(&[&txt_ids, &img_ids], 1)?;
            ids.apply(&self.pe_embedder)?
        };
        let mut txt = txt.apply(&self.txt_in)?;
        let mut img = img.apply(&self.img_in)?;
        let mut vec_ = timestep_embedding(&timesteps, 256, dtype)?.apply(&self.time_in)?;
        if let (Some(g_in), Some(guidance)) = (self.guidance_in.as_ref(), guidance.as_ref()) {
            vec_ = (vec_ + timestep_embedding(guidance, 256, dtype)?.apply(g_in))?;
        }
        if let Some(vec_in) = self.vector_in.as_ref() {
            vec_ = (vec_ + y.apply(vec_in))?;
        }
        let (img_mod1, img_mod2) = self.double_mod_img.forward(&vec_)?;
        let (txt_mod1, txt_mod2) = self.double_mod_txt.forward(&vec_)?;
        debug_assert_eq!(
            self.block_plan.len(),
            self.double_blocks.len() + self.single_blocks.len()
        );
        for block in &self.double_blocks {
            match block {
                DoubleBlockSlot::Resident(block) => {
                    (img, txt) = block
                        .forward(&img, &txt, &img_mod1, &img_mod2, &txt_mod1, &txt_mod2, &pe)?;
                }
                DoubleBlockSlot::Streamed(block) => {
                    let block = block.to_device(device)?;
                    (img, txt) = block
                        .forward(&img, &txt, &img_mod1, &img_mod2, &txt_mod1, &txt_mod2, &pe)?;
                    // Finish pending work before freeing the temporary device copy.
                    device.synchronize()?;
                    drop(block);
                }
            }
        }
        let single_mod = self.single_mod.forward(&vec_)?;
        let mut img = Tensor::cat(&[&txt, &img], 1)?;
        for block in &self.single_blocks {
            match block {
                SingleBlockSlot::Resident(block) => {
                    img = block.forward(&img, &single_mod, &pe)?;
                }
                SingleBlockSlot::Streamed(block) => {
                    let block = block.to_device(device)?;
                    img = block.forward(&img, &single_mod, &pe)?;
                    device.synchronize()?;
                    drop(block);
                }
            }
        }
        // Drop the text prefix; only image tokens go through the final layer.
        let img = img.i((.., txt.dim(1)?..))?;
        self.final_layer.forward(&img, &vec_)
    }
}
impl Flux2Transformer {
    /// Loads every layer of the transformer described by `cfg` from `vb`.
    ///
    /// `vector_in` is loaded only when `cfg.vec_in_dim > 0`, and `guidance_in`
    /// only when `cfg.guidance_embed` is set.
    pub fn new(cfg: &Flux2Config, vb: VarBuilder) -> Result<Self> {
        let hidden = cfg.hidden_size;
        let img_in = flux2_linear_no_bias(cfg.in_channels, hidden, vb.pp("x_embedder"))?;
        let txt_in = flux2_linear_no_bias(cfg.context_in_dim, hidden, vb.pp("context_embedder"))?;
        let tg_vb = vb.pp("time_guidance_embed");
        let time_in = MlpEmbedder::new(256, hidden, tg_vb.pp("timestep_embedder"))?;
        let vector_in = (cfg.vec_in_dim > 0)
            .then(|| MlpEmbedder::new(cfg.vec_in_dim, hidden, vb.pp("vector_in")))
            .transpose()?;
        let guidance_in = cfg
            .guidance_embed
            .then(|| MlpEmbedder::new(256, hidden, tg_vb.pp("guidance_embedder")))
            .transpose()?;
        let double_mod_img = Modulation2::new(hidden, vb.pp("double_stream_modulation_img"))?;
        let double_mod_txt = Modulation2::new(hidden, vb.pp("double_stream_modulation_txt"))?;
        let single_mod = Modulation1::new(hidden, vb.pp("single_stream_modulation"))?;
        let double_vb = vb.pp("transformer_blocks");
        let double_blocks = (0..cfg.depth)
            .map(|idx| DoubleStreamBlock::new(cfg, double_vb.pp(idx)))
            .collect::<Result<Vec<_>>>()?;
        let single_vb = vb.pp("single_transformer_blocks");
        let single_blocks = (0..cfg.depth_single_blocks)
            .map(|idx| SingleStreamBlock::new(cfg, single_vb.pp(idx)))
            .collect::<Result<Vec<_>>>()?;
        let final_layer = LastLayer::new(hidden, cfg.in_channels, vb)?;
        let pe_embedder = EmbedNd::new(cfg.theta, cfg.axes_dim.clone());
        Ok(Self {
            img_in,
            txt_in,
            time_in,
            vector_in,
            guidance_in,
            pe_embedder,
            double_mod_img,
            double_mod_txt,
            single_mod,
            double_blocks,
            single_blocks,
            final_layer,
        })
    }
    /// Runs the dense transformer on rank-3 `img`/`txt` token tensors and
    /// returns the prediction for the image tokens only.
    ///
    /// # Errors
    /// Fails if `txt` or `img` is not rank 3, or on any tensor error.
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        img: &Tensor,
        img_ids: &Tensor,
        txt: &Tensor,
        txt_ids: &Tensor,
        timesteps: &Tensor,
        y: &Tensor,
        guidance: Option<&Tensor>,
    ) -> Result<Tensor> {
        if txt.rank() != 3 || img.rank() != 3 {
            candle_core::bail!("expected rank 3, got txt={} img={}", txt.rank(), img.rank())
        }
        let dtype = img.dtype();
        // Rotary table for the joint [txt, img] token sequence.
        let pe = Tensor::cat(&[txt_ids, img_ids], 1)?.apply(&self.pe_embedder)?;
        let txt = txt.apply(&self.txt_in)?;
        let img = img.apply(&self.img_in)?;
        let mut vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?;
        if let (Some(g_in), Some(g)) = (&self.guidance_in, guidance) {
            vec_ = (vec_ + timestep_embedding(g, 256, dtype)?.apply(g_in))?;
        }
        if let Some(vec_in) = &self.vector_in {
            vec_ = (vec_ + y.apply(vec_in))?;
        }
        let (img_mod1, img_mod2) = self.double_mod_img.forward(&vec_)?;
        let (txt_mod1, txt_mod2) = self.double_mod_txt.forward(&vec_)?;
        let (img, txt) =
            self.double_blocks
                .iter()
                .try_fold((img, txt), |(img, txt), block| {
                    block.forward(&img, &txt, &img_mod1, &img_mod2, &txt_mod1, &txt_mod2, &pe)
                })?;
        let single_mod = self.single_mod.forward(&vec_)?;
        let joint = Tensor::cat(&[&txt, &img], 1)?;
        let joint = self
            .single_blocks
            .iter()
            .try_fold(joint, |hidden, block| block.forward(&hidden, &single_mod, &pe))?;
        // Drop the text prefix; only image tokens go through the final layer.
        let img_tokens = joint.i((.., txt.dim(1)?..))?;
        self.final_layer.forward(&img_tokens, &vec_)
    }
}
/// Runtime-selected FLUX.2 transformer backend driven by the denoising loop.
///
/// - `BF16`: the fully resident [`Flux2Transformer`].
/// - `Offloaded`: an `OffloadedFlux2Transformer` (presumably streams weights
///   between host and device; confirm against its definition).
/// - `Quantized`: the transformer from `super::quantized_transformer`.
///
/// Every variant exposes the same `forward` call shape, which `denoise` dispatches on.
// `clippy::large_enum_variant` is silenced because the variants wrap whole models
// whose sizes differ. NOTE(review): presumably a pipeline holds a single wrapper
// by value, so boxing the large variant would add an indirection without saving
// memory. Confirm before changing, since boxing would also alter every constructor site.
#[allow(clippy::large_enum_variant)]
pub(crate) enum Flux2TransformerWrapper {
    BF16(Flux2Transformer),
    Offloaded(OffloadedFlux2Transformer),
    Quantized(super::quantized_transformer::QuantizedFlux2Transformer),
}
impl Flux2TransformerWrapper {
    /// Integrates the sampling ODE with explicit Euler steps over `timesteps`.
    ///
    /// Each consecutive pair `(t_curr, t_next)` is one step. The selected backend
    /// predicts a velocity at `t_curr`, and the latents move by
    /// `velocity * (t_next - t_curr)`. When `inpaint_ctx` is given, the latents are
    /// passed through `crate::img2img::apply_flow_match_inpaint` at `t_next` after
    /// each step. A `ProgressEvent::DenoiseStep` with a 1-based step index and the
    /// step's wall time is emitted after every step.
    ///
    /// `guidance` is broadcast to a per-batch f32 tensor and always handed to the
    /// backend as `Some(..)`. Whether it is used is up to the backend.
    ///
    /// With fewer than two timesteps no step runs and a clone of `img` is returned.
    ///
    /// # Errors
    /// Propagates any tensor, backend, or inpainting error.
    #[allow(clippy::too_many_arguments)]
    pub fn denoise(
        &self,
        img: &Tensor,
        img_ids: &Tensor,
        txt: &Tensor,
        txt_ids: &Tensor,
        vec_: &Tensor,
        timesteps: &[f64],
        guidance: f64,
        progress: &crate::progress::ProgressReporter,
        inpaint_ctx: Option<&crate::img_utils::InpaintContext>,
    ) -> anyhow::Result<Tensor> {
        use crate::progress::ProgressEvent;
        use std::time::Instant;

        let batch = img.dim(0)?;
        let device = img.device();
        let guidance_t = Tensor::full(guidance as f32, batch, device)?;

        // One velocity prediction from whichever backend this wrapper holds.
        let predict = |latents: &Tensor, t: &Tensor| -> anyhow::Result<Tensor> {
            let velocity = match self {
                Self::BF16(m) => {
                    m.forward(latents, img_ids, txt, txt_ids, t, vec_, Some(&guidance_t))?
                }
                Self::Offloaded(m) => {
                    m.forward(latents, img_ids, txt, txt_ids, t, vec_, Some(&guidance_t))?
                }
                Self::Quantized(m) => {
                    m.forward(latents, img_ids, txt, txt_ids, t, vec_, Some(&guidance_t))?
                }
            };
            Ok(velocity)
        };

        let total_steps = timesteps.len().saturating_sub(1);
        let mut latents = img.clone();
        for (step, pair) in timesteps.windows(2).enumerate() {
            let started = Instant::now();
            // `windows(2)` only ever yields length-2 slices, so this never skips.
            let &[t_curr, t_next] = pair else {
                continue;
            };

            let t = Tensor::full(t_curr as f32, batch, device)?;
            let velocity = predict(&latents, &t)?;
            latents = (latents + velocity * (t_next - t_curr))?;

            if let Some(ctx) = inpaint_ctx {
                latents = crate::img2img::apply_flow_match_inpaint(&latents, ctx, t_next)?;
            }

            progress.emit(ProgressEvent::DenoiseStep {
                step: step + 1,
                total: total_steps,
                elapsed: started.elapsed(),
            });
        }
        Ok(latents)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Casts `t` to F32 and flattens it into a `Vec<f32>` for element-wise checks.
    fn to_f32_vec(t: &Tensor) -> Vec<f32> {
        t.to_dtype(DType::F32)
            .unwrap()
            .flatten_all()
            .unwrap()
            .to_vec1()
            .unwrap()
    }

    #[test]
    fn klein_config_dimensions() {
        let cfg = Flux2Config::klein();
        assert_eq!(cfg.in_channels, 128);
        assert_eq!(cfg.hidden_size, 3072);
        assert_eq!(cfg.num_heads, 24);
        // Per-head dimension.
        assert_eq!(cfg.hidden_size / cfg.num_heads, 128);
        assert_eq!(cfg.depth, 5);
        assert_eq!(cfg.depth_single_blocks, 20);
        assert_eq!(cfg.axes_dim, vec![32, 32, 32, 32]);
        assert_eq!(cfg.theta, 2000);
        assert!(!cfg.guidance_embed);
    }

    #[test]
    fn klein_mlp_sizes() {
        let cfg = Flux2Config::klein();
        let h_sz = cfg.hidden_size;
        let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
        assert_eq!(mlp_sz, 9216);
        // NOTE(review): presumably the fused single-block input projection
        // (3 * hidden for Q/K/V plus 2 * mlp for the gated MLP) and the fused
        // output projection (hidden + mlp). Confirm against SingleStreamBlock.
        assert_eq!(h_sz * 3 + mlp_sz * 2, 27648);
        assert_eq!(h_sz + mlp_sz, 12288);
    }

    #[test]
    fn klein_context_dim_matches_qwen3() {
        let cfg = Flux2Config::klein();
        assert_eq!(cfg.context_in_dim, 7680);
        assert_eq!(cfg.context_in_dim, 2560 * 3);
    }

    #[test]
    fn klein_9b_config_dimensions() {
        let cfg = Flux2Config::klein_9b();
        assert_eq!(cfg.in_channels, 128);
        assert_eq!(cfg.hidden_size, 4096);
        assert_eq!(cfg.num_heads, 32);
        // Per-head dimension.
        assert_eq!(cfg.hidden_size / cfg.num_heads, 128);
        assert_eq!(cfg.depth, 8);
        assert_eq!(cfg.depth_single_blocks, 24);
        assert_eq!(cfg.context_in_dim, 12288);
        assert_eq!(cfg.context_in_dim, 4096 * 3);
        assert!(!cfg.guidance_embed);
    }

    #[test]
    fn flux2_streaming_block_plan_preserves_reference_order() {
        let mut cfg = Flux2Config::klein();
        cfg.depth = 2;
        cfg.depth_single_blocks = 3;
        assert_eq!(
            flux2_streaming_block_plan(&cfg),
            vec![
                Flux2StreamingBlock::Double(0),
                Flux2StreamingBlock::Double(1),
                Flux2StreamingBlock::Single(0),
                Flux2StreamingBlock::Single(1),
                Flux2StreamingBlock::Single(2),
            ]
        );
    }

    #[test]
    fn timestep_embedding_shape() {
        let dev = candle_core::Device::Cpu;
        let t = Tensor::full(0.5f32, 2, &dev).unwrap();
        let emb = timestep_embedding(&t, 256, DType::F32).unwrap();
        assert_eq!(emb.dims(), &[2, 256]);
    }

    #[test]
    fn rope_4d_shape() {
        let dev = candle_core::Device::Cpu;
        let pos = Tensor::zeros((1, 16), DType::F32, &dev).unwrap();
        let r = rope(&pos, 32, 2000).unwrap();
        assert_eq!(r.dims(), &[1, 16, 16, 2, 2]);
    }

    #[test]
    fn test_timestep_embedding_dtype_preserved() {
        let dev = candle_core::Device::Cpu;
        let t = Tensor::full(0.5f32, 2, &dev).unwrap();
        let emb = timestep_embedding(&t, 128, DType::BF16).unwrap();
        assert_eq!(emb.dtype(), DType::BF16);
        assert_eq!(emb.dims(), &[2, 128]);
    }

    #[test]
    fn test_timestep_embedding_values_bounded() {
        let dev = candle_core::Device::Cpu;
        let t = Tensor::full(0.7f32, 1, &dev).unwrap();
        let emb = timestep_embedding(&t, 64, DType::F32).unwrap();
        let flat = emb.flatten_all().unwrap();
        let vals: Vec<f32> = flat.to_vec1().unwrap();
        for v in &vals {
            assert!(
                *v >= -1.0 && *v <= 1.0,
                "embedding value {v} outside [-1, 1] (sin/cos bounds)"
            );
        }
    }

    #[test]
    fn test_rope_odd_dim_fails() {
        let dev = candle_core::Device::Cpu;
        let pos = Tensor::zeros((1, 4), DType::F32, &dev).unwrap();
        let result = rope(&pos, 33, 2000);
        assert!(result.is_err(), "rope with odd dim should fail");
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("odd"),
            "error should mention 'odd', got: {err_msg}"
        );
    }

    // NOTE(review): despite the name, this exercises an F32 weight/activation pair.
    #[test]
    fn flux2_linear_standard_bf16_forward() {
        let dev = candle_core::Device::Cpu;
        let weight = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], (2, 2), &dev).unwrap();
        let lin = Flux2Linear::Standard(candle_nn::Linear::new(weight, None));
        let x = Tensor::from_vec(vec![1.0f32, 0.0], (1, 2), &dev).unwrap();
        let out = lin.forward(&x).unwrap();
        let v: Vec<f32> = out.flatten_all().unwrap().to_vec1().unwrap();
        // x @ W^T selects the first column of W.
        assert_eq!(v, vec![1.0, 3.0]);
    }

    #[test]
    fn flux2_linear_fp8_forward_matches_bf16_reference() {
        let dev = candle_core::Device::Cpu;
        let weight = Tensor::from_vec(vec![2.0f32; 8], (2, 4), &dev)
            .unwrap()
            .to_dtype(DType::F8E4M3)
            .unwrap();
        let lin = Flux2Linear::Fp8 {
            weight,
            scale: None,
            bias: None,
        };
        let x = Tensor::from_vec(vec![1.0f32; 4], (1, 4), &dev).unwrap();
        let out = lin.forward(&x).unwrap();
        assert_eq!(
            out.dtype(),
            DType::F32,
            "FP8 forward must preserve activation dtype",
        );
        let v: Vec<f32> = out.flatten_all().unwrap().to_vec1().unwrap();
        for val in &v {
            assert!((val - 8.0).abs() < 1e-3, "got {val}, want 8.0");
        }
    }

    #[test]
    fn flux2_linear_fp8_basic_matmul() {
        let dev = candle_core::Device::Cpu;
        let weight = Tensor::from_vec(vec![3.0f32], (1, 1), &dev)
            .unwrap()
            .to_dtype(DType::F8E4M3)
            .unwrap();
        let lin = Flux2Linear::Fp8 {
            weight,
            scale: None,
            bias: None,
        };
        let x = Tensor::from_vec(vec![2.0f32], (1, 1), &dev).unwrap();
        let out = lin.forward(&x).unwrap();
        let v: f32 = out.flatten_all().unwrap().to_vec1::<f32>().unwrap()[0];
        assert!((v - 6.0).abs() < 1e-3);
    }

    #[test]
    fn flux2_linear_fp8_applies_bias_after_matmul() {
        let dev = candle_core::Device::Cpu;
        let weight = Tensor::from_vec(vec![1.0f32], (1, 1), &dev)
            .unwrap()
            .to_dtype(DType::F8E4M3)
            .unwrap();
        let bias = Tensor::from_vec(vec![10.0f32], 1, &dev).unwrap();
        let lin = Flux2Linear::Fp8 {
            weight,
            scale: None,
            bias: Some(bias),
        };
        let x = Tensor::from_vec(vec![3.0f32], (1, 1), &dev).unwrap();
        let out = lin.forward(&x).unwrap();
        let v: f32 = out.flatten_all().unwrap().to_vec1::<f32>().unwrap()[0];
        assert!((v - 13.0).abs() < 1e-2);
    }

    #[test]
    fn flux2_linear_fp8_applies_sidecar_scale_at_forward() {
        let dev = candle_core::Device::Cpu;
        let weight = Tensor::from_vec(vec![2.0f32; 8], (2, 4), &dev)
            .unwrap()
            .to_dtype(DType::F8E4M3)
            .unwrap();
        let scale = Tensor::from_vec(vec![0.5f32], 1, &dev).unwrap();
        let lin = Flux2Linear::Fp8 {
            weight,
            scale: Some(scale),
            bias: None,
        };
        let x = Tensor::from_vec(vec![1.0f32; 4], (1, 4), &dev).unwrap();
        let out = lin.forward(&x).unwrap();
        let v: Vec<f32> = out.flatten_all().unwrap().to_vec1().unwrap();
        for val in &v {
            assert!(
                (val - 4.0).abs() < 1e-3,
                "got {val}, want 4.0 (sidecar scale 0.5 applied to FP8(2.0))",
            );
        }
    }

    /// Builds an NVFP4 fixture of `n_rows` output rows and K = 16 inputs.
    ///
    /// Each packed byte holds two inputs (the loader doubles the packed width to get
    /// K), so `(n_rows, 8)` bytes give K = 16, which is one block per row. Every byte
    /// is `0x22` and every natural block scale is 1.0. The streaming tests below rely
    /// on each decoded weight times its block scale being 1.0.
    /// NOTE(review): the scale tensor shape assumes `swizzle_block_scales` returns a
    /// layout padded to a multiple of 128 rows by 4 columns. Confirm against `crate::nvfp4`.
    fn nvfp4_unit_fixture(n_rows: usize) -> (Tensor, Tensor) {
        use crate::nvfp4::swizzle_block_scales;
        use candle_core::Device;
        let dev = Device::Cpu;
        let packed_bytes = vec![0x22u8; n_rows * 8];
        let packed = Tensor::from_vec(packed_bytes, (n_rows, 8), &dev).unwrap();
        let natural_scales: Vec<f32> = vec![1.0f32; n_rows];
        let swizzled = swizzle_block_scales(&natural_scales, n_rows, 1).unwrap();
        let padded_rows = n_rows.div_ceil(128) * 128;
        let padded_cols = 4;
        let scales_f32 = Tensor::from_vec(swizzled, (padded_rows, padded_cols), &dev).unwrap();
        let block_scales = scales_f32.to_dtype(DType::F8E4M3).unwrap();
        (packed, block_scales)
    }

    #[test]
    fn flux2_linear_nvfp4_streaming_round_trip_matches_standard() {
        let dev = candle_core::Device::Cpu;
        let n_full = 4;
        let k = 16;
        let tensor_scale = 0.25f32;
        let (packed, block_scales) = nvfp4_unit_fixture(n_full);
        let streaming = Flux2Linear::Nvfp4Streaming {
            packed,
            block_scales,
            tensor_scale,
            out_dim: n_full,
            in_dim: k,
            slice: None,
            bias: None,
            cache: Arc::new(OnceLock::new()),
        };
        // Dense reference: every weight equals the tensor scale.
        let ref_w = Tensor::from_vec(vec![tensor_scale; n_full * k], (n_full, k), &dev).unwrap();
        let ref_lin = Flux2Linear::Standard(candle_nn::Linear::new(ref_w, None));
        let x = Tensor::from_vec(vec![1.0f32; k], (1, k), &dev).unwrap();
        let s = to_f32_vec(&streaming.forward(&x).unwrap());
        let r = to_f32_vec(&ref_lin.forward(&x).unwrap());
        assert_eq!(s.len(), r.len());
        for (i, (a, b)) in s.iter().zip(r.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-2,
                "streaming[{i}]={a}, reference={b} — diverged beyond BF16 tolerance",
            );
            assert!(
                (a - 4.0).abs() < 1e-2,
                "streaming[{i}]={a}, want 4.0 (sum of 16 × 0.25)",
            );
        }
    }

    #[test]
    fn flux2_linear_nvfp4_streaming_caches_bf16() {
        let dev = candle_core::Device::Cpu;
        let n_full = 2;
        let k = 16;
        let (packed, block_scales) = nvfp4_unit_fixture(n_full);
        let cache = Arc::new(OnceLock::new());
        let streaming = Flux2Linear::Nvfp4Streaming {
            packed,
            block_scales,
            tensor_scale: 1.0,
            out_dim: n_full,
            in_dim: k,
            slice: None,
            bias: None,
            cache: cache.clone(),
        };
        assert!(cache.get().is_none(), "cache empty before first forward");
        let x = Tensor::from_vec(vec![1.0f32; k], (1, k), &dev).unwrap();
        let _ = streaming.forward(&x).unwrap();
        assert!(
            cache.get().is_some(),
            "cache must be populated after first forward",
        );
        let cached = cache.get().unwrap();
        assert_eq!(cached.dtype(), DType::BF16);
        assert_eq!(cached.dims(), &[n_full, k]);
    }

    #[test]
    fn flux2_linear_nvfp4_streaming_slice_q_k_v_share_cache() {
        let dev = candle_core::Device::Cpu;
        let out_dim = 2;
        // Three stacked components (Q, K, V) share one packed weight.
        let n_full = out_dim * 3;
        let k = 16;
        let tensor_scale = 0.25f32;
        let (packed, block_scales) = nvfp4_unit_fixture(n_full);
        let shared_cache = Arc::new(OnceLock::new());
        let linears: Vec<Flux2Linear> = (0..3)
            .map(|component| Flux2Linear::Nvfp4Streaming {
                packed: packed.clone(),
                block_scales: block_scales.clone(),
                tensor_scale,
                out_dim,
                in_dim: k,
                slice: Some((0, component, 3)),
                bias: None,
                cache: shared_cache.clone(),
            })
            .collect();
        assert!(
            shared_cache.get().is_none(),
            "shared cache empty before any forward",
        );
        let x = Tensor::from_vec(vec![1.0f32; k], (1, k), &dev).unwrap();
        let out_q = linears[0].forward(&x).unwrap();
        assert!(
            shared_cache.get().is_some(),
            "Q-forward must populate cache"
        );
        let cached_after_q = shared_cache.get().unwrap().clone();
        let _out_k = linears[1].forward(&x).unwrap();
        let cached_after_k = shared_cache.get().unwrap().clone();
        assert_eq!(
            to_f32_vec(&cached_after_q),
            to_f32_vec(&cached_after_k),
            "shared cache must be unchanged after subsequent forwards",
        );
        for (component, lin) in linears.iter().enumerate() {
            let v = to_f32_vec(&lin.forward(&x).unwrap());
            assert_eq!(v.len(), out_dim);
            for (i, &val) in v.iter().enumerate() {
                assert!(
                    (val - 4.0).abs() < 1e-2,
                    "component {component} out[{i}] = {val}, want 4.0",
                );
            }
        }
        for (i, &val) in to_f32_vec(&out_q).iter().enumerate() {
            assert!((val - 4.0).abs() < 1e-2, "Q[{i}]={val}");
        }
    }

    #[test]
    fn test_klein_config_vec_in_dim_zero() {
        let cfg = Flux2Config::klein();
        assert_eq!(
            cfg.vec_in_dim, 0,
            "Klein vec_in_dim must be 0 (no pooled text vector); \
             vec_in_dim > 0 would create an unused MlpEmbedder for Klein"
        );
        assert!(
            !cfg.guidance_embed,
            "Klein is a distilled model; guidance_embed must be false"
        );
    }

    #[test]
    fn last_layer_forward_uses_diffusers_scale_then_shift_ordering() {
        use candle_core::Device;
        use candle_nn::VarBuilder;
        use std::collections::HashMap;
        let dev = Device::Cpu;
        let h_sz = 2usize;
        let out_c = 2usize;
        let scale_val = 3.0f32;
        let shift_val = 0.5f32;
        // silu(1.0) = 1 / (1 + e^-1). With vec_ = ones, each adaLN output row
        // sums h_sz * silu(1) * w, so the weights are chosen to produce exactly
        // `scale_val` for the first half and `shift_val` for the second half.
        let silu_one = 0.731_058_6f32;
        let dot_factor = h_sz as f32 * silu_one;
        let w_scale = scale_val / dot_factor;
        let w_shift = shift_val / dot_factor;
        // (2 * h_sz, h_sz): two scale rows, then two shift rows.
        let ada_weight: Vec<f32> = vec![
            w_scale, w_scale, w_scale, w_scale, //
            w_shift, w_shift, w_shift, w_shift,
        ];
        // Identity projection so the output equals the modulated hidden state.
        let proj_weight = vec![1.0f32, 0.0, 0.0, 1.0];
        let mut map: HashMap<String, candle_core::Tensor> = HashMap::new();
        map.insert(
            "norm_out.linear.weight".to_string(),
            Tensor::from_vec(ada_weight, (2 * h_sz, h_sz), &dev).unwrap(),
        );
        map.insert(
            "proj_out.weight".to_string(),
            Tensor::from_vec(proj_weight, (out_c, h_sz), &dev).unwrap(),
        );
        let vb = VarBuilder::from_tensors(map, DType::F32, &dev);
        let layer = LastLayer::new(h_sz, out_c, vb).unwrap();
        // Zero input: if the first chunk is the scale, the output is just the shift.
        let xs = Tensor::zeros((1, 1, h_sz), DType::F32, &dev).unwrap();
        let vec_ = Tensor::ones((1, h_sz), DType::F32, &dev).unwrap();
        let out = layer.forward(&xs, &vec_).unwrap();
        let vals: Vec<f32> = out.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(vals.len(), out_c);
        let tol = 0.08;
        assert!(
            (vals[0] - shift_val).abs() < tol,
            "LastLayer output[0]={:.4}: expected shift={shift_val:.4} \
             (diffusers scale-then-shift ordering). \
             Got scale={scale_val:.4} instead? The c0c2b80 regression is present.",
            vals[0],
        );
        assert!(
            (vals[1] - shift_val).abs() < tol,
            "LastLayer output[1]={:.4}: expected shift={shift_val:.4}",
            vals[1],
        );
    }
}