use anyhow::Result;
use candle_core::{DType, Device, IndexOp, Module, Tensor, D};
use candle_nn::{LayerNorm, Linear, RmsNorm, VarBuilder};
use crate::adaptive_offload::{
plan_adaptive_residency, AdaptiveResidencyPlan, ADAPTIVE_OFFLOAD_RUNTIME_HEADROOM,
};
use crate::flux::lora_bypass::{LoraLinear, LoraRegistry};
use crate::flux::pinned::{
largest_block_size_bytes, pinned_cap_bytes, prefetch_enabled_from_env, try_pin_to_host,
PinnedMemoryTracker, PinnedRegion,
};
use crate::progress::ProgressReporter;
#[cfg(feature = "cuda")]
use std::sync::Arc;
use candle_transformers::models::flux::model::{Config, EmbedNd};
// Handle type for the extra CUDA stream created by `init_prefetch` (CUDA builds only).
#[cfg(feature = "cuda")]
type PrefetchStream = Arc<candle_core::cuda_backend::cudarc::driver::CudaStream>;
// Raw byte buffer on the device, allocated by `init_prefetch` and sized to the largest
// streamed block (CUDA builds only).
#[cfg(feature = "cuda")]
type PrefetchBuffer = candle_core::cuda_backend::cudarc::driver::CudaSlice<u8>;
fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
const TIME_FACTOR: f64 = 1000.;
const MAX_PERIOD: f64 = 10000.;
if dim % 2 == 1 {
anyhow::bail!("{dim} is odd");
}
let dev = t.device();
let half = dim / 2;
let t = (t * TIME_FACTOR)?;
let arange = Tensor::arange(0, half as u32, dev)?.to_dtype(candle_core::DType::F32)?;
let freqs = (arange * (-MAX_PERIOD.ln() / half as f64))?.exp()?;
let args = t
.unsqueeze(1)?
.to_dtype(candle_core::DType::F32)?
.broadcast_mul(&freqs.unsqueeze(0)?)?;
let emb = Tensor::cat(&[args.cos()?, args.sin()?], D::Minus1)?.to_dtype(dtype)?;
Ok(emb)
}
/// Forwards to `crate::attention::attention_default_scale`, converting its error into
/// `anyhow::Error`.
fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    let attended = crate::attention::attention_default_scale(q, k, v)?;
    Ok(attended)
}
/// Applies rotary position factors `freq_cis` to a rank-4 tensor `x`.
///
/// The last axis of `x` is regrouped into consecutive pairs; the first element of each pair
/// is multiplied by slice 0 of `freq_cis`'s last axis, the second by slice 1, and the two
/// products are summed. The result is reshaped back to `x`'s original shape.
fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
    let original_shape = x.dims().to_vec();
    let (batch, heads, seq_len, head_dim) = x.dims4()?;

    let pairs = x.reshape((batch, heads, seq_len, head_dim / 2, 2))?;
    let first = pairs.narrow(D::Minus1, 0, 1)?;
    let second = pairs.narrow(D::Minus1, 1, 1)?;

    let factor_first = freq_cis.get_on_dim(D::Minus1, 0)?;
    let factor_second = freq_cis.get_on_dim(D::Minus1, 1)?;

    let rotated_first = factor_first.broadcast_mul(&first)?;
    let rotated_second = factor_second.broadcast_mul(&second)?;
    let rotated = (rotated_first + rotated_second)?;
    Ok(rotated.reshape(original_shape)?)
}
/// Rotary-encodes `q` and `k` with `pe`, runs scaled dot-product attention, then swaps axes
/// 1 and 2 and flattens everything from axis 2 onward into a single axis.
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
    let rotated_q = apply_rope(q, pe)?.contiguous()?;
    let rotated_k = apply_rope(k, pe)?.contiguous()?;
    let per_head = scaled_dot_product_attention(&rotated_q, &rotated_k, v)?;
    let merged = per_head.transpose(1, 2)?.flatten_from(2)?;
    Ok(merged)
}
/// Creates a bias-free `LayerNorm` over `dim` features with an all-ones weight and eps 1e-6.
///
/// Nothing is read from the checkpoint: `vb` only supplies the dtype and device for the
/// weight tensor.
fn layer_norm(dim: usize, vb: VarBuilder) -> Result<LayerNorm> {
    let (dtype, device) = (vb.dtype(), vb.device());
    let unit_weight = Tensor::ones(dim, dtype, device)?;
    Ok(LayerNorm::new_no_bias(unit_weight, 1e-6))
}
/// Returns a new `Linear` whose weight (and bias, when present) have been moved to `dev`.
fn linear_to_device(l: &Linear, dev: &Device) -> Result<Linear> {
    let weight = l.weight().to_device(dev)?;
    let bias = match l.bias() {
        Some(bias) => Some(bias.to_device(dev)?),
        None => None,
    };
    Ok(Linear::new(weight, bias))
}
/// Moves `l` to `dev` and wraps it as a [`LoraLinear`].
///
/// When `registry` is present and holds adapters for `key`, they are copied and attached
/// (`LoraLinear::WithAdapters`); otherwise the layer is wrapped as `LoraLinear::Plain`.
fn lora_linear_to_device(
    l: &Linear,
    dev: &Device,
    registry: Option<&LoraRegistry>,
    key: &str,
) -> Result<LoraLinear> {
    let inner = linear_to_device(l, dev)?;
    let adapters = match registry {
        Some(reg) => reg.adapters_for(key).to_vec(),
        None => Default::default(),
    };
    if !adapters.is_empty() {
        return Ok(LoraLinear::WithAdapters { inner, adapters });
    }
    Ok(LoraLinear::Plain(inner))
}
/// Rebuilds `ln` on `dev` from its weight and optional bias.
///
/// The copy is always constructed with eps 1e-6, whatever eps `ln` was created with.
fn layer_norm_to_device(ln: &LayerNorm, dev: &Device) -> Result<LayerNorm> {
    let weight = ln.weight().to_device(dev)?;
    let moved = if let Some(bias) = ln.bias() {
        LayerNorm::new(weight, bias.to_device(dev)?, 1e-6)
    } else {
        LayerNorm::new_no_bias(weight, 1e-6)
    };
    Ok(moved)
}
/// Rebuilds `rn` on `dev` from its weight, always using eps 1e-6.
fn rms_norm_to_device(rn: &RmsNorm, dev: &Device) -> Result<RmsNorm> {
    // The weight is only reachable through the inner LayerNorm, so unwrap a clone of `rn`.
    let as_layer_norm = rn.clone().into_inner();
    let weight = as_layer_norm.weight().to_device(dev)?;
    Ok(RmsNorm::new(weight, 1e-6))
}
fn visit_double_block_weights<F>(b: &DoubleBlock, mut f: F) -> usize
where
F: FnMut(&Tensor) -> usize,
{
let mut total = 0usize;
total += f(b.img_mod.lin.weight());
if let Some(t) = b.img_mod.lin.bias() {
total += f(t);
}
total += f(b.img_norm1.weight());
if let Some(t) = b.img_norm1.bias() {
total += f(t);
}
total += f(b.img_attn.qkv.weight());
if let Some(t) = b.img_attn.qkv.bias() {
total += f(t);
}
total += f(b.img_attn.query_norm.clone().into_inner().weight());
total += f(b.img_attn.key_norm.clone().into_inner().weight());
total += f(b.img_attn.proj.weight());
if let Some(t) = b.img_attn.proj.bias() {
total += f(t);
}
total += f(b.img_norm2.weight());
if let Some(t) = b.img_norm2.bias() {
total += f(t);
}
total += f(b.img_mlp.lin1.weight());
if let Some(t) = b.img_mlp.lin1.bias() {
total += f(t);
}
total += f(b.img_mlp.lin2.weight());
if let Some(t) = b.img_mlp.lin2.bias() {
total += f(t);
}
total += f(b.txt_mod.lin.weight());
if let Some(t) = b.txt_mod.lin.bias() {
total += f(t);
}
total += f(b.txt_norm1.weight());
if let Some(t) = b.txt_norm1.bias() {
total += f(t);
}
total += f(b.txt_attn.qkv.weight());
if let Some(t) = b.txt_attn.qkv.bias() {
total += f(t);
}
total += f(b.txt_attn.query_norm.clone().into_inner().weight());
total += f(b.txt_attn.key_norm.clone().into_inner().weight());
total += f(b.txt_attn.proj.weight());
if let Some(t) = b.txt_attn.proj.bias() {
total += f(t);
}
total += f(b.txt_norm2.weight());
if let Some(t) = b.txt_norm2.bias() {
total += f(t);
}
total += f(b.txt_mlp.lin1.weight());
if let Some(t) = b.txt_mlp.lin1.bias() {
total += f(t);
}
total += f(b.txt_mlp.lin2.weight());
if let Some(t) = b.txt_mlp.lin2.bias() {
total += f(t);
}
total
}
fn visit_single_block_weights<F>(b: &SingleBlock, mut f: F) -> usize
where
F: FnMut(&Tensor) -> usize,
{
let mut total = 0usize;
total += f(b.linear1.weight());
if let Some(t) = b.linear1.bias() {
total += f(t);
}
total += f(b.linear2.weight());
if let Some(t) = b.linear2.bias() {
total += f(t);
}
total += f(b.query_norm.clone().into_inner().weight());
total += f(b.key_norm.clone().into_inner().weight());
total += f(b.pre_norm.weight());
if let Some(t) = b.pre_norm.bias() {
total += f(t);
}
total += f(b.modulation.lin.weight());
if let Some(t) = b.modulation.lin.bias() {
total += f(t);
}
total
}
/// Size of `t` in bytes: element count times the byte width of its dtype.
fn tensor_bytes(t: &Tensor) -> usize {
    let bytes_per_element = t.dtype().size_in_bytes();
    t.elem_count() * bytes_per_element
}
/// Human-readable prefetch state for log messages.
///
/// * `"off"` — prefetch was not requested.
/// * `"scaffold-only"` — requested, and both the stream and the buffer are ready.
/// * `"unavailable"` — requested, but the stream or the buffer is missing.
fn prefetch_status_label(requested: bool, stream_ready: bool, buffer_ready: bool) -> &'static str {
    match (requested, stream_ready, buffer_ready) {
        (false, _, _) => "off",
        (true, true, true) => "scaffold-only",
        (true, _, _) => "unavailable",
    }
}
/// Best-effort setup of a second CUDA stream and a device byte buffer of
/// `largest_block_bytes` bytes (CUDA builds only).
///
/// Every failure degrades instead of erroring, so this function never returns `Err`:
/// * `gpu_device` is not a CUDA device → `(None, None)`.
/// * stream creation fails → warning logged, `(None, None)`.
/// * `largest_block_bytes == 0` → `(Some(stream), None)`; no allocation is attempted.
/// * buffer allocation fails → warning logged, `(Some(stream), None)`.
/// * otherwise → `(Some(stream), Some(buffer))`.
#[cfg(feature = "cuda")]
fn init_prefetch(
gpu_device: &Device,
largest_block_bytes: usize,
) -> Result<(Option<PrefetchStream>, Option<PrefetchBuffer>)> {
let cuda_dev = match gpu_device.as_cuda_device() {
Ok(d) => d,
Err(_) => return Ok((None, None)),
};
// The new stream is created from the context of the device's existing stream.
let ctx = cuda_dev.cuda_stream().context().clone();
let stream = match ctx.new_stream() {
Ok(s) => s,
Err(e) => {
tracing::warn!("FLUX offload: failed to create prefetch stream ({e:?}) — falling back to single-stream");
return Ok((None, None));
}
};
if largest_block_bytes == 0 {
return Ok((Some(stream), None));
}
// NOTE(review): `alloc` is presumably `unsafe` because the returned device memory is
// uninitialized — confirm against the cudarc docs. Nothing in this function reads the
// buffer; any later reader must write it first.
let buf = match unsafe { stream.alloc::<u8>(largest_block_bytes) } {
Ok(s) => s,
Err(e) => {
tracing::warn!(
"FLUX offload: prefetch buffer alloc failed ({largest_block_bytes} bytes, {e:?}) — \
falling back to single-stream"
);
return Ok((Some(stream), None));
}
};
Ok((Some(stream), Some(buf)))
}
/// CPU-side weights of a three-way modulation: one linear layer mapping `dim` to `3 * dim`.
struct Modulation1 {
    lin: Linear,
}

impl Modulation1 {
    /// Loads the layer from the `lin` sub-path of `vb`.
    fn load(dim: usize, vb: VarBuilder) -> Result<Self> {
        let lin = candle_nn::linear(dim, 3 * dim, vb.pp("lin"))?;
        Ok(Self { lin })
    }

    /// Copies the layer to `dev`, attaching any LoRA adapters registered under
    /// `"{base_key}.lin.weight"`.
    fn to_device(
        &self,
        dev: &Device,
        registry: Option<&LoraRegistry>,
        base_key: &str,
    ) -> Result<GpuModulation1> {
        let lora_key = format!("{base_key}.lin.weight");
        let lin = lora_linear_to_device(&self.lin, dev, registry, &lora_key)?;
        Ok(GpuModulation1 { lin })
    }
}
/// Device-side counterpart of `Modulation1`, with optional LoRA adapters on the linear layer.
struct GpuModulation1 {
    lin: LoraLinear,
}

impl GpuModulation1 {
    /// Applies SiLU then the linear layer to `vec_`, inserts an axis at position 1, and splits
    /// the last axis into three equal chunks, returned in order.
    ///
    /// `GpuSingleBlock::forward` binds the three chunks as `(shift, scale, gate)`.
    fn forward(&self, vec_: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
        let activated = vec_.silu()?;
        let projected = self.lin.forward(&activated)?.unsqueeze(1)?;
        let parts = projected.chunk(3, D::Minus1)?;
        Ok((parts[0].clone(), parts[1].clone(), parts[2].clone()))
    }
}
/// CPU-side weights of a six-way modulation: one linear layer mapping `dim` to `6 * dim`.
struct Modulation2 {
    lin: Linear,
}

impl Modulation2 {
    /// Loads the layer from the `lin` sub-path of `vb`.
    fn load(dim: usize, vb: VarBuilder) -> Result<Self> {
        let lin = candle_nn::linear(dim, 6 * dim, vb.pp("lin"))?;
        Ok(Self { lin })
    }

    /// Copies the layer to `dev`, attaching any LoRA adapters registered under
    /// `"{base_key}.lin.weight"`.
    fn to_device(
        &self,
        dev: &Device,
        registry: Option<&LoraRegistry>,
        base_key: &str,
    ) -> Result<GpuModulation2> {
        let lora_key = format!("{base_key}.lin.weight");
        let lin = lora_linear_to_device(&self.lin, dev, registry, &lora_key)?;
        Ok(GpuModulation2 { lin })
    }
}
/// Device-side counterpart of `Modulation2`, with optional LoRA adapters on the linear layer.
struct GpuModulation2 {
    lin: LoraLinear,
}

impl GpuModulation2 {
    /// Applies SiLU then the linear layer to `vec_`, inserts an axis at position 1, and splits
    /// the last axis into six equal chunks, returned as two triples (chunks 0–2, chunks 3–5).
    #[allow(clippy::type_complexity)]
    fn forward(
        &self,
        vec_: &Tensor,
    ) -> Result<((Tensor, Tensor, Tensor), (Tensor, Tensor, Tensor))> {
        let activated = vec_.silu()?;
        let projected = self.lin.forward(&activated)?.unsqueeze(1)?;
        let parts = projected.chunk(6, D::Minus1)?;
        let first = (parts[0].clone(), parts[1].clone(), parts[2].clone());
        let second = (parts[3].clone(), parts[4].clone(), parts[5].clone());
        Ok((first, second))
    }
}
/// CPU-side attention weights: fused QKV projection, per-head query/key RMS norms, and the
/// output projection.
struct SelfAttention {
    qkv: Linear,
    query_norm: RmsNorm,
    key_norm: RmsNorm,
    proj: Linear,
    num_heads: usize,
}

impl SelfAttention {
    /// Loads the weights from `vb`. The norm scales have length `dim / num_heads` and use
    /// eps 1e-6; `qkv_bias` controls whether the QKV projection has a bias.
    fn load(dim: usize, num_heads: usize, qkv_bias: bool, vb: VarBuilder) -> Result<Self> {
        let head_dim = dim / num_heads;
        let qkv = candle_nn::linear_b(dim, dim * 3, qkv_bias, vb.pp("qkv"))?;
        let query_scale = vb.get(head_dim, "norm.query_norm.scale")?;
        let key_scale = vb.get(head_dim, "norm.key_norm.scale")?;
        let proj = candle_nn::linear(dim, dim, vb.pp("proj"))?;

        let query_norm = RmsNorm::new(query_scale, 1e-6);
        let key_norm = RmsNorm::new(key_scale, 1e-6);
        Ok(Self {
            qkv,
            query_norm,
            key_norm,
            proj,
            num_heads,
        })
    }

    /// Copies all weights to `dev`. LoRA adapters are looked up under
    /// `"{base_key}.qkv.weight"` and `"{base_key}.proj.weight"`.
    fn to_device(
        &self,
        dev: &Device,
        registry: Option<&LoraRegistry>,
        base_key: &str,
    ) -> Result<GpuSelfAttention> {
        let qkv_key = format!("{base_key}.qkv.weight");
        let qkv = lora_linear_to_device(&self.qkv, dev, registry, &qkv_key)?;
        let query_norm = rms_norm_to_device(&self.query_norm, dev)?;
        let key_norm = rms_norm_to_device(&self.key_norm, dev)?;
        let proj_key = format!("{base_key}.proj.weight");
        let proj = lora_linear_to_device(&self.proj, dev, registry, &proj_key)?;

        Ok(GpuSelfAttention {
            qkv,
            query_norm,
            key_norm,
            proj,
            num_heads: self.num_heads,
        })
    }
}
/// Device-side attention weights, with optional LoRA adapters on both projections.
struct GpuSelfAttention {
    qkv: LoraLinear,
    query_norm: RmsNorm,
    key_norm: RmsNorm,
    proj: LoraLinear,
    num_heads: usize,
}

impl GpuSelfAttention {
    /// Projects `xs` through the fused QKV layer and returns `(q, k, v)`.
    ///
    /// The rank-3 projection output is reshaped to `(b, l, 3, num_heads, rest)`, indices
    /// 0/1/2 of axis 2 are taken as q/k/v, and axes 1 and 2 of each are swapped. `q` and `k`
    /// are then RMS-normalized; `v` is returned as is.
    fn qkv_split(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
        let fused = self.qkv.forward(xs)?;
        let (batch, seq_len, _) = fused.dims3()?;
        let fused = fused.reshape((batch, seq_len, 3, self.num_heads, ()))?;

        let head_major = |slot: usize| -> Result<Tensor> {
            let picked = fused.i((.., .., slot))?;
            Ok(picked.transpose(1, 2)?)
        };
        let q = head_major(0)?;
        let k = head_major(1)?;
        let v = head_major(2)?;

        let q = q.apply(&self.query_norm)?;
        let k = k.apply(&self.key_norm)?;
        Ok((q, k, v))
    }
}
/// CPU-side weights of a two-layer MLP, stored under checkpoint sub-paths `0` and `2`.
struct Mlp {
    lin1: Linear,
    lin2: Linear,
}

impl Mlp {
    /// Loads `in_sz -> mlp_sz` from sub-path `0` and `mlp_sz -> in_sz` from sub-path `2`.
    fn load(in_sz: usize, mlp_sz: usize, vb: VarBuilder) -> Result<Self> {
        let lin1 = candle_nn::linear(in_sz, mlp_sz, vb.pp("0"))?;
        let lin2 = candle_nn::linear(mlp_sz, in_sz, vb.pp("2"))?;
        Ok(Self { lin1, lin2 })
    }

    /// Copies both layers to `dev`. LoRA adapters are looked up under
    /// `"{base_key}.0.weight"` and `"{base_key}.2.weight"`.
    fn to_device(
        &self,
        dev: &Device,
        registry: Option<&LoraRegistry>,
        base_key: &str,
    ) -> Result<GpuMlp> {
        let first_key = format!("{base_key}.0.weight");
        let lin1 = lora_linear_to_device(&self.lin1, dev, registry, &first_key)?;
        let second_key = format!("{base_key}.2.weight");
        let lin2 = lora_linear_to_device(&self.lin2, dev, registry, &second_key)?;
        Ok(GpuMlp { lin1, lin2 })
    }
}
/// Device-side two-layer MLP, with optional LoRA adapters on each layer.
struct GpuMlp {
    lin1: LoraLinear,
    lin2: LoraLinear,
}

impl GpuMlp {
    /// Computes `lin2(gelu(lin1(xs)))`.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let hidden = self.lin1.forward(xs)?;
        let activated = hidden.gelu()?;
        self.lin2.forward(&activated)
    }
}
/// CPU-side weights of a double-stream block: the image stream and the text stream each have
/// their own modulation, two layer norms, attention, and MLP.
pub(crate) struct DoubleBlock {
    img_mod: Modulation2,
    img_norm1: LayerNorm,
    img_attn: SelfAttention,
    img_norm2: LayerNorm,
    img_mlp: Mlp,
    txt_mod: Modulation2,
    txt_norm1: LayerNorm,
    txt_attn: SelfAttention,
    txt_norm2: LayerNorm,
    txt_mlp: Mlp,
}

impl DoubleBlock {
    /// Loads one block from `vb`. The MLP width is `hidden_size * mlp_ratio`, truncated to an
    /// integer.
    fn load(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden = cfg.hidden_size;
        let mlp_sz = (hidden as f64 * cfg.mlp_ratio) as usize;
        let (heads, qkv_bias) = (cfg.num_heads, cfg.qkv_bias);

        let img_mod = Modulation2::load(hidden, vb.pp("img_mod"))?;
        let img_norm1 = layer_norm(hidden, vb.pp("img_norm1"))?;
        let img_attn = SelfAttention::load(hidden, heads, qkv_bias, vb.pp("img_attn"))?;
        let img_norm2 = layer_norm(hidden, vb.pp("img_norm2"))?;
        let img_mlp = Mlp::load(hidden, mlp_sz, vb.pp("img_mlp"))?;

        let txt_mod = Modulation2::load(hidden, vb.pp("txt_mod"))?;
        let txt_norm1 = layer_norm(hidden, vb.pp("txt_norm1"))?;
        let txt_attn = SelfAttention::load(hidden, heads, qkv_bias, vb.pp("txt_attn"))?;
        let txt_norm2 = layer_norm(hidden, vb.pp("txt_norm2"))?;
        let txt_mlp = Mlp::load(hidden, mlp_sz, vb.pp("txt_mlp"))?;

        Ok(Self {
            img_mod,
            img_norm1,
            img_attn,
            img_norm2,
            img_mlp,
            txt_mod,
            txt_norm1,
            txt_attn,
            txt_norm2,
            txt_mlp,
        })
    }

    /// Copies every weight to `dev`. LoRA adapters are looked up under keys prefixed with
    /// `"double_blocks.{idx}"`.
    fn to_device(
        &self,
        dev: &Device,
        registry: Option<&LoraRegistry>,
        idx: usize,
    ) -> Result<GpuDoubleBlock> {
        let key = |part: &str| format!("double_blocks.{idx}.{part}");

        let img_mod = self.img_mod.to_device(dev, registry, &key("img_mod"))?;
        let img_norm1 = layer_norm_to_device(&self.img_norm1, dev)?;
        let img_attn = self.img_attn.to_device(dev, registry, &key("img_attn"))?;
        let img_norm2 = layer_norm_to_device(&self.img_norm2, dev)?;
        let img_mlp = self.img_mlp.to_device(dev, registry, &key("img_mlp"))?;

        let txt_mod = self.txt_mod.to_device(dev, registry, &key("txt_mod"))?;
        let txt_norm1 = layer_norm_to_device(&self.txt_norm1, dev)?;
        let txt_attn = self.txt_attn.to_device(dev, registry, &key("txt_attn"))?;
        let txt_norm2 = layer_norm_to_device(&self.txt_norm2, dev)?;
        let txt_mlp = self.txt_mlp.to_device(dev, registry, &key("txt_mlp"))?;

        Ok(GpuDoubleBlock {
            img_mod,
            img_norm1,
            img_attn,
            img_norm2,
            img_mlp,
            txt_mod,
            txt_norm1,
            txt_attn,
            txt_norm2,
            txt_mlp,
        })
    }
}
/// Device-side double-stream block, ready to run.
struct GpuDoubleBlock {
    img_mod: GpuModulation2,
    img_norm1: LayerNorm,
    img_attn: GpuSelfAttention,
    img_norm2: LayerNorm,
    img_mlp: GpuMlp,
    txt_mod: GpuModulation2,
    txt_norm1: LayerNorm,
    txt_attn: GpuSelfAttention,
    txt_norm2: LayerNorm,
    txt_mlp: GpuMlp,
}

impl GpuDoubleBlock {
    /// Runs the block and returns the updated `(img, txt)` pair.
    ///
    /// Each stream is normalized and modulated, then projected to q/k/v. The two streams'
    /// q/k/v are concatenated along axis 2 (text first) for one joint attention call. The
    /// attention output is split back along axis 1 by the text length. Each stream then gets
    /// a gated attention residual followed by a gated MLP residual.
    fn forward(
        &self,
        img: &Tensor,
        txt: &Tensor,
        vec_: &Tensor,
        pe: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
        /// `norm(x) * (1 + scale) + shift`, with broadcasting.
        fn modulate(
            x: &Tensor,
            norm: &LayerNorm,
            scale: &Tensor,
            shift: &Tensor,
        ) -> Result<Tensor> {
            let normed = x.apply(norm)?;
            Ok(normed.broadcast_mul(&(scale + 1.)?)?.broadcast_add(shift)?)
        }

        // First triple modulates the attention path, second triple the MLP path.
        let ((img_shift_a, img_scale_a, img_gate_a), (img_shift_m, img_scale_m, img_gate_m)) =
            self.img_mod.forward(vec_)?;
        let ((txt_shift_a, txt_scale_a, txt_gate_a), (txt_shift_m, txt_scale_m, txt_gate_m)) =
            self.txt_mod.forward(vec_)?;

        let img_attn_in = modulate(img, &self.img_norm1, &img_scale_a, &img_shift_a)?;
        let (img_q, img_k, img_v) = self.img_attn.qkv_split(&img_attn_in)?;
        let txt_attn_in = modulate(txt, &self.txt_norm1, &txt_scale_a, &txt_shift_a)?;
        let (txt_q, txt_k, txt_v) = self.txt_attn.qkv_split(&txt_attn_in)?;

        // Joint attention over the concatenated sequence, text tokens first.
        let q = Tensor::cat(&[txt_q, img_q], 2)?;
        let k = Tensor::cat(&[txt_k, img_k], 2)?;
        let v = Tensor::cat(&[txt_v, img_v], 2)?;
        let joint = attention(&q, &k, &v, pe)?;

        let txt_len = txt.dim(1)?;
        let txt_attn_out = joint.narrow(1, 0, txt_len)?;
        let img_attn_out = joint.narrow(1, txt_len, joint.dim(1)? - txt_len)?;

        // Image stream: gated attention residual, then gated MLP residual.
        let img_proj = self.img_attn.proj.forward(&img_attn_out)?;
        let img = (img + img_gate_a.broadcast_mul(&img_proj)?)?;
        let img_mlp_in = modulate(&img, &self.img_norm2, &img_scale_m, &img_shift_m)?;
        let img_mlp_out = self.img_mlp.forward(&img_mlp_in)?;
        let img = (&img + img_gate_m.broadcast_mul(&img_mlp_out)?)?;

        // Text stream: same structure.
        let txt_proj = self.txt_attn.proj.forward(&txt_attn_out)?;
        let txt = (txt + txt_gate_a.broadcast_mul(&txt_proj)?)?;
        let txt_mlp_in = modulate(&txt, &self.txt_norm2, &txt_scale_m, &txt_shift_m)?;
        let txt_mlp_out = self.txt_mlp.forward(&txt_mlp_in)?;
        let txt = (&txt + txt_gate_m.broadcast_mul(&txt_mlp_out)?)?;

        Ok((img, txt))
    }
}
/// CPU-side weights of a single-stream block. `linear1` produces the fused QKV and MLP hidden
/// activations in one projection; `linear2` maps the attention output concatenated with the
/// activated MLP hidden state back to the hidden size.
pub(crate) struct SingleBlock {
    linear1: Linear,
    linear2: Linear,
    query_norm: RmsNorm,
    key_norm: RmsNorm,
    pre_norm: LayerNorm,
    modulation: Modulation1,
    h_sz: usize,
    mlp_sz: usize,
    num_heads: usize,
}

impl SingleBlock {
    /// Loads one block from `vb`. The MLP width is `hidden_size * mlp_ratio`, truncated to an
    /// integer; norm scales have length `hidden_size / num_heads` and use eps 1e-6.
    fn load(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden = cfg.hidden_size;
        let mlp_sz = (hidden as f64 * cfg.mlp_ratio) as usize;
        let head_dim = hidden / cfg.num_heads;

        let linear1 = candle_nn::linear(hidden, hidden * 3 + mlp_sz, vb.pp("linear1"))?;
        let linear2 = candle_nn::linear(hidden + mlp_sz, hidden, vb.pp("linear2"))?;
        let query_norm = RmsNorm::new(vb.get(head_dim, "norm.query_norm.scale")?, 1e-6);
        let key_norm = RmsNorm::new(vb.get(head_dim, "norm.key_norm.scale")?, 1e-6);
        let pre_norm = layer_norm(hidden, vb.pp("pre_norm"))?;
        let modulation = Modulation1::load(hidden, vb.pp("modulation"))?;

        Ok(Self {
            linear1,
            linear2,
            query_norm,
            key_norm,
            pre_norm,
            modulation,
            h_sz: hidden,
            mlp_sz,
            num_heads: cfg.num_heads,
        })
    }

    /// Copies every weight to `dev`. LoRA adapters are looked up under keys prefixed with
    /// `"single_blocks.{idx}"`.
    fn to_device(
        &self,
        dev: &Device,
        registry: Option<&LoraRegistry>,
        idx: usize,
    ) -> Result<GpuSingleBlock> {
        let key = |part: &str| format!("single_blocks.{idx}.{part}");

        let linear1 = lora_linear_to_device(&self.linear1, dev, registry, &key("linear1.weight"))?;
        let linear2 = lora_linear_to_device(&self.linear2, dev, registry, &key("linear2.weight"))?;
        let query_norm = rms_norm_to_device(&self.query_norm, dev)?;
        let key_norm = rms_norm_to_device(&self.key_norm, dev)?;
        let pre_norm = layer_norm_to_device(&self.pre_norm, dev)?;
        let modulation = self.modulation.to_device(dev, registry, &key("modulation"))?;

        Ok(GpuSingleBlock {
            linear1,
            linear2,
            query_norm,
            key_norm,
            pre_norm,
            modulation,
            h_sz: self.h_sz,
            mlp_sz: self.mlp_sz,
            num_heads: self.num_heads,
        })
    }
}
/// Device-side single-stream block, ready to run.
struct GpuSingleBlock {
    linear1: LoraLinear,
    linear2: LoraLinear,
    query_norm: RmsNorm,
    key_norm: RmsNorm,
    pre_norm: LayerNorm,
    modulation: GpuModulation1,
    h_sz: usize,
    mlp_sz: usize,
    num_heads: usize,
}

impl GpuSingleBlock {
    /// Runs the block on `xs` and returns `xs + gate * linear2(cat(attention, gelu(mlp)))`.
    ///
    /// The input is normalized and modulated, then sent through `linear1`. The first
    /// `3 * h_sz` output features are split into q/k/v for attention; the next `mlp_sz`
    /// features form the MLP hidden state.
    fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result<Tensor> {
        let (shift, scale, gate) = self.modulation.forward(vec_)?;

        let normed = xs.apply(&self.pre_norm)?;
        let modulated = normed
            .broadcast_mul(&(&scale + 1.)?)?
            .broadcast_add(&shift)?;
        let fused = self.linear1.forward(&modulated)?;

        // Attention half of the fused projection.
        let qkv_width = 3 * self.h_sz;
        let qkv = fused.narrow(D::Minus1, 0, qkv_width)?;
        let (batch, seq_len, _) = qkv.dims3()?;
        let qkv = qkv.reshape((batch, seq_len, 3, self.num_heads, ()))?;
        let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
        let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
        let v = qkv.i((.., .., 2))?.transpose(1, 2)?;

        // MLP half of the fused projection.
        let mlp_hidden = fused.narrow(D::Minus1, qkv_width, self.mlp_sz)?;

        let q = q.apply(&self.query_norm)?;
        let k = k.apply(&self.key_norm)?;
        let attn_out = attention(&q, &k, &v, pe)?;

        let joined = Tensor::cat(&[attn_out, mlp_hidden.gelu()?], 2)?;
        let projected = self.linear2.forward(&joined)?;
        let gated = gate.broadcast_mul(&projected)?;
        Ok((xs + gated)?)
    }
}
/// Output head: a modulated layer norm followed by a linear projection.
struct FinalLayer {
    norm_final: LayerNorm,
    linear: Linear,
    ada_ln_modulation: Linear,
}

impl FinalLayer {
    /// Loads the head from `vb`. The projection maps `h_sz` to `p_sz * p_sz * out_c`; the
    /// modulation layer maps `h_sz` to `2 * h_sz` and is read from `adaLN_modulation.1`.
    fn load(h_sz: usize, p_sz: usize, out_c: usize, vb: VarBuilder) -> Result<Self> {
        let norm_final = layer_norm(h_sz, vb.pp("norm_final"))?;
        let linear = candle_nn::linear(h_sz, p_sz * p_sz * out_c, vb.pp("linear"))?;
        let ada_ln_modulation = candle_nn::linear(h_sz, 2 * h_sz, vb.pp("adaLN_modulation.1"))?;
        Ok(Self {
            norm_final,
            linear,
            ada_ln_modulation,
        })
    }

    /// Returns a copy of the head with all weights on `dev`.
    fn to_device(&self, dev: &Device) -> Result<Self> {
        let norm_final = layer_norm_to_device(&self.norm_final, dev)?;
        let linear = linear_to_device(&self.linear, dev)?;
        let ada_ln_modulation = linear_to_device(&self.ada_ln_modulation, dev)?;
        Ok(Self {
            norm_final,
            linear,
            ada_ln_modulation,
        })
    }

    /// Computes `linear(norm(xs) * (1 + scale) + shift)`, where `(shift, scale)` are the two
    /// halves (split along axis 1) of the modulation layer applied to `silu(vec)`.
    fn forward(&self, xs: &Tensor, vec: &Tensor) -> Result<Tensor> {
        let modulation = vec.silu()?.apply(&self.ada_ln_modulation)?;
        let halves = modulation.chunk(2, 1)?;
        let shift = &halves[0];
        let scale = &halves[1];

        let normed = xs.apply(&self.norm_final)?;
        let scaled = normed.broadcast_mul(&(scale.unsqueeze(1)? + 1.0)?)?;
        let shifted = scaled.broadcast_add(&shift.unsqueeze(1)?)?;
        Ok(shifted.apply(&self.linear)?)
    }
}
/// Where a double-stream block's weights live between forward passes.
enum DoubleBlockSlot {
/// Already converted to a `GpuDoubleBlock`; `forward` runs it directly.
Resident(Box<GpuDoubleBlock>),
/// Kept as a CPU-side `DoubleBlock`; `forward` calls `to_device` on it for each pass and
/// drops the device copy afterwards.
Streamed(Box<DoubleBlock>),
}
/// Where a single-stream block's weights live between forward passes.
enum SingleBlockSlot {
/// Already converted to a `GpuSingleBlock`; `forward` runs it directly.
Resident(Box<GpuSingleBlock>),
/// Kept as a CPU-side `SingleBlock`; `forward` calls `to_device` on it for each pass and
/// drops the device copy afterwards.
Streamed(Box<SingleBlock>),
}
/// Heuristic: does `err` look like a CUDA out-of-memory failure?
///
/// The full error chain is rendered, lower-cased, and searched for a few known substrings.
/// This is text matching, so unrelated errors mentioning "memory allocation" also match.
fn is_probable_cuda_oom(err: &anyhow::Error) -> bool {
    const OOM_MARKERS: [&str; 3] = [
        "cuda_error_out_of_memory",
        "out of memory",
        "memory allocation",
    ];
    let rendered = format!("{err:#}").to_ascii_lowercase();
    OOM_MARKERS.iter().any(|marker| rendered.contains(*marker))
}
/// Turns the loaded CPU blocks into slots according to `plan`.
///
/// `plan.resident` is indexed with all double blocks first, then all single blocks; indices
/// missing from the plan count as not resident.
///
/// Works in two phases so that a failed upload leaves the inputs untouched and the caller can
/// retry with a different plan:
/// 1. Every block marked resident is copied to `gpu_device`. The CPU blocks are only
///    borrowed; on error the device copies made so far are dropped.
/// 2. Only after all uploads succeeded are the CPU blocks consumed: resident entries are
///    cleared (set to `None`), streamed entries are moved out into `Streamed` slots.
///
/// # Errors
/// Propagates upload errors, and reports "already consumed" if a needed entry is `None`.
fn materialize_block_slots(
    double_blocks: &mut [Option<DoubleBlock>],
    single_blocks: &mut [Option<SingleBlock>],
    plan: &AdaptiveResidencyPlan,
    gpu_device: &Device,
    registry: Option<&LoraRegistry>,
) -> Result<(Vec<DoubleBlockSlot>, Vec<SingleBlockSlot>)> {
    let wants_resident = |plan_idx: usize| plan.resident.get(plan_idx).copied().unwrap_or(false);

    // Phase 1: upload resident blocks without consuming anything.
    let mut uploaded_double: Vec<Option<GpuDoubleBlock>> = Vec::with_capacity(double_blocks.len());
    for (i, entry) in double_blocks.iter().enumerate() {
        let uploaded = if wants_resident(i) {
            let cpu_block = entry
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("double block {i} already consumed"))?;
            Some(cpu_block.to_device(gpu_device, registry, i)?)
        } else {
            None
        };
        uploaded_double.push(uploaded);
    }

    let single_offset = double_blocks.len();
    let mut uploaded_single: Vec<Option<GpuSingleBlock>> = Vec::with_capacity(single_blocks.len());
    for (i, entry) in single_blocks.iter().enumerate() {
        let uploaded = if wants_resident(single_offset + i) {
            let cpu_block = entry
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("single block {i} already consumed"))?;
            Some(cpu_block.to_device(gpu_device, registry, i)?)
        } else {
            None
        };
        uploaded_single.push(uploaded);
    }

    // Phase 2: every upload succeeded, so the CPU blocks can now be consumed.
    let mut double_slots = Vec::with_capacity(double_blocks.len());
    for (i, (entry, uploaded)) in double_blocks.iter_mut().zip(uploaded_double).enumerate() {
        let slot = match uploaded {
            Some(gpu_block) => {
                *entry = None;
                DoubleBlockSlot::Resident(Box::new(gpu_block))
            }
            None => {
                let cpu_block = entry
                    .take()
                    .ok_or_else(|| anyhow::anyhow!("double block {i} already consumed"))?;
                DoubleBlockSlot::Streamed(Box::new(cpu_block))
            }
        };
        double_slots.push(slot);
    }

    let mut single_slots = Vec::with_capacity(single_blocks.len());
    for (i, (entry, uploaded)) in single_blocks.iter_mut().zip(uploaded_single).enumerate() {
        let slot = match uploaded {
            Some(gpu_block) => {
                *entry = None;
                SingleBlockSlot::Resident(Box::new(gpu_block))
            }
            None => {
                let cpu_block = entry
                    .take()
                    .ok_or_else(|| anyhow::anyhow!("single block {i} already consumed"))?;
                SingleBlockSlot::Streamed(Box::new(cpu_block))
            }
        };
        single_slots.push(slot);
    }

    Ok((double_slots, single_slots))
}
/// Attempts to pin the host memory of every tensor belonging to a streamed block.
///
/// Pinning is best-effort: tensors that `try_pin_to_host` declines (`Ok(None)`) are skipped,
/// and failures are logged at debug level and ignored. The tracker is created with
/// `pinned_cap_bytes()` as its limit.
///
/// Returns the regions that were pinned and the tracker's `used_bytes()` afterwards.
fn pin_streamed_block_weights(
    double_blocks: &[DoubleBlockSlot],
    single_blocks: &[SingleBlockSlot],
) -> (Vec<PinnedRegion>, u64) {
    let tracker = PinnedMemoryTracker::new(pinned_cap_bytes());
    let mut regions: Vec<PinnedRegion> = Vec::new();

    // The visitor contract wants a usize per tensor; the sum is unused here, so return 0.
    let mut try_pin = |tensor: &Tensor| -> usize {
        match try_pin_to_host(tensor, &tracker) {
            Ok(Some(region)) => regions.push(region),
            Ok(None) => {}
            Err(e) => {
                tracing::debug!("try_pin_to_host failed: {e:?} (continuing)");
            }
        }
        0
    };

    for slot in double_blocks {
        match slot {
            DoubleBlockSlot::Streamed(cpu_block) => {
                visit_double_block_weights(cpu_block, &mut try_pin);
            }
            DoubleBlockSlot::Resident(_) => {}
        }
    }
    for slot in single_blocks {
        match slot {
            SingleBlockSlot::Streamed(cpu_block) => {
                visit_single_block_weights(cpu_block, &mut try_pin);
            }
            SingleBlockSlot::Resident(_) => {}
        }
    }

    let pinned_bytes = tracker.used_bytes();
    (regions, pinned_bytes)
}
/// Byte size of each streamed block — streamed double blocks first, then streamed single
/// blocks, each group in slot order. Resident blocks are skipped.
fn streamed_block_sizes(
    double_blocks: &[DoubleBlockSlot],
    single_blocks: &[SingleBlockSlot],
) -> Vec<usize> {
    let double_sizes = double_blocks.iter().filter_map(|slot| match slot {
        DoubleBlockSlot::Streamed(cpu_block) => {
            Some(visit_double_block_weights(cpu_block, tensor_bytes))
        }
        DoubleBlockSlot::Resident(_) => None,
    });
    let single_sizes = single_blocks.iter().filter_map(|slot| match slot {
        SingleBlockSlot::Streamed(cpu_block) => {
            Some(visit_single_block_weights(cpu_block, tensor_bytes))
        }
        SingleBlockSlot::Resident(_) => None,
    });
    double_sizes.chain(single_sizes).collect()
}
/// FLUX transformer whose blocks are split between resident (kept on the GPU) and streamed
/// (kept on the CPU and copied to the GPU for each forward pass).
///
/// The input projections, embedders and final layer are moved to the GPU when the model is
/// loaded.
pub(crate) struct OffloadedFluxTransformer {
// Regions returned by `pin_streamed_block_weights`; never read after construction.
// NOTE(review): presumably held so the pinning lasts as long as the model — confirm against
// `PinnedRegion`'s drop behaviour.
#[allow(dead_code)]
pinned_regions: Vec<PinnedRegion>,
// Input projections for the image and text tokens.
img_in: Linear,
txt_in: Linear,
// Embedders for the timestep embedding, the `y` vector and (optionally) guidance.
time_in: StemMlpEmbedder,
vector_in: StemMlpEmbedder,
// `Some` only when `cfg.guidance_embed` is set.
guidance_in: Option<StemMlpEmbedder>,
// Produces the positional encoding from the concatenated text/image ids.
pe_embedder: EmbedNd,
final_layer: FinalLayer,
double_blocks: Vec<DoubleBlockSlot>,
single_blocks: Vec<SingleBlockSlot>,
// Target device for streamed blocks during `forward`.
gpu_device: Device,
// Passed to `to_device` for streamed blocks so adapters are re-attached on every pass.
lora_registry: Option<LoraRegistry>,
// Prefetch stream and buffer from `init_prefetch`; stored but not read in this file.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
prefetch_stream: Option<PrefetchStream>,
#[cfg(feature = "cuda")]
#[allow(dead_code)]
prefetch_buffer: Option<PrefetchBuffer>,
}
impl OffloadedFluxTransformer {
/// Loads the model from `vb` and decides which blocks stay resident on `gpu_device`.
///
/// Steps:
/// 1. Load the stem layers (input projections, embedders, final layer) and move them to
///    `gpu_device`.
/// 2. Load every double and single block through `vb` and measure its size in bytes.
/// 3. Build a residency plan with `plan_adaptive_residency` from the block sizes, the
///    usable free VRAM of `gpu_ordinal` (0 if the query fails), `activation_budget`, and
///    `ADAPTIVE_OFFLOAD_RUNTIME_HEADROOM`.
/// 4. Materialize the plan. If that fails with what looks like a CUDA OOM, demote the
///    largest resident block and retry until it succeeds or nothing is left to demote.
/// 5. Try to pin the streamed blocks' host memory and, on CUDA builds with prefetch
///    enabled, set up the prefetch stream and buffer.
///
/// Progress and the resulting layout are reported through `progress`.
///
/// # Errors
/// Returns any load or upload error that is not handled by the OOM retry.
pub fn load(
vb: VarBuilder,
cfg: &Config,
gpu_device: &Device,
gpu_ordinal: usize,
activation_budget: u64,
lora_registry: Option<LoraRegistry>,
progress: &ProgressReporter,
) -> Result<Self> {
progress.info("Loading transformer blocks on CPU…");
// Stem layers are loaded through `vb` and immediately copied to the GPU.
let img_in = linear_to_device(
&candle_nn::linear(cfg.in_channels, cfg.hidden_size, vb.pp("img_in"))?,
gpu_device,
)?;
let txt_in = linear_to_device(
&candle_nn::linear(cfg.context_in_dim, cfg.hidden_size, vb.pp("txt_in"))?,
gpu_device,
)?;
// 256 is the timestep-embedding width used in `forward`.
let time_in =
StemMlpEmbedder::load(256, cfg.hidden_size, vb.pp("time_in"))?.to_device(gpu_device)?;
let vector_in = StemMlpEmbedder::load(cfg.vec_in_dim, cfg.hidden_size, vb.pp("vector_in"))?
.to_device(gpu_device)?;
let guidance_in = if cfg.guidance_embed {
Some(
StemMlpEmbedder::load(256, cfg.hidden_size, vb.pp("guidance_in"))?
.to_device(gpu_device)?,
)
} else {
None
};
let pe_dim = cfg.hidden_size / cfg.num_heads;
let pe_embedder = EmbedNd::new(pe_dim, cfg.theta, cfg.axes_dim.to_vec());
let final_layer =
FinalLayer::load(cfg.hidden_size, 1, cfg.in_channels, vb.pp("final_layer"))?
.to_device(gpu_device)?;
// Blocks are wrapped in `Option` so `materialize_block_slots` can move them out later.
let mut double_blocks = Vec::with_capacity(cfg.depth);
let vb_d = vb.pp("double_blocks");
for idx in 0..cfg.depth {
double_blocks.push(Some(DoubleBlock::load(cfg, vb_d.pp(idx))?));
}
let mut single_blocks = Vec::with_capacity(cfg.depth_single_blocks);
let vb_s = vb.pp("single_blocks");
for idx in 0..cfg.depth_single_blocks {
single_blocks.push(Some(SingleBlock::load(cfg, vb_s.pp(idx))?));
}
progress.info(&format!(
"Offloading: planning adaptive residency for {} double + {} single blocks",
double_blocks.len(),
single_blocks.len(),
));
// Sizes are listed double blocks first, then single blocks — the same indexing
// `materialize_block_slots` uses for `plan.resident`.
let mut block_sizes: Vec<usize> =
Vec::with_capacity(double_blocks.len() + single_blocks.len());
for b in &double_blocks {
block_sizes.push(visit_double_block_weights(
b.as_ref().expect("double block just loaded"),
tensor_bytes,
));
}
for b in &single_blocks {
block_sizes.push(visit_single_block_weights(
b.as_ref().expect("single block just loaded"),
tensor_bytes,
));
}
// A failed VRAM query is treated as "no free VRAM".
let free_vram = crate::device::usable_free_vram_bytes(gpu_ordinal).unwrap_or(0);
let mut plan = plan_adaptive_residency(
&block_sizes,
free_vram,
activation_budget,
ADAPTIVE_OFFLOAD_RUNTIME_HEADROOM,
);
let registry_ref = lora_registry.as_ref();
// Retry loop: on a probable CUDA OOM, shrink the plan and try again.
let (double_blocks, single_blocks, plan) = loop {
match materialize_block_slots(
&mut double_blocks,
&mut single_blocks,
&plan,
gpu_device,
registry_ref,
) {
Ok((double_slots, single_slots)) => break (double_slots, single_slots, plan),
Err(err)
if gpu_device.is_cuda()
&& plan.resident_count() > 0
&& is_probable_cuda_oom(&err) =>
{
progress.info(&format!(
"FLUX adaptive offload: resident allocation OOM at {} resident blocks; \
retrying with fewer resident blocks",
plan.resident_count()
));
// A failed synchronize is only logged; the retry goes ahead anyway.
if let Err(sync_err) = gpu_device.synchronize() {
tracing::warn!(
"FLUX adaptive offload: synchronize after OOM failed: {sync_err}"
);
}
// Nothing left to demote: give up with the original OOM error.
if !plan.demote_largest_resident(&block_sizes) {
return Err(err);
}
}
Err(err) => return Err(err),
}
};
progress.info(&format!(
"FLUX adaptive offload: {} resident / {} streamed blocks \
(resident {:.2} GB, streamed {:.2} GB per denoise pass, reserve {:.2} GB)",
plan.resident_count(),
plan.streamed_count(),
plan.resident_bytes as f64 / 1_000_000_000.0,
plan.streamed_bytes as f64 / 1_000_000_000.0,
plan.reserved_bytes() as f64 / 1_000_000_000.0,
));
let (pinned_regions, pinned_bytes) =
pin_streamed_block_weights(&double_blocks, &single_blocks);
// Prefetch is only considered when enabled via the environment and running on CUDA.
let prefetch_on = prefetch_enabled_from_env() && gpu_device.is_cuda();
let streamed_sizes = streamed_block_sizes(&double_blocks, &single_blocks);
let largest_block = largest_block_size_bytes(&streamed_sizes);
#[cfg(feature = "cuda")]
let (prefetch_stream, prefetch_buffer) = if prefetch_on {
init_prefetch(gpu_device, largest_block)?
} else {
(None, None)
};
// Without the `cuda` feature there is no stream or buffer, so both count as not ready.
let prefetch_label = {
#[cfg(feature = "cuda")]
{
prefetch_status_label(
prefetch_on,
prefetch_stream.is_some(),
prefetch_buffer.is_some(),
)
}
#[cfg(not(feature = "cuda"))]
{
prefetch_status_label(prefetch_on, false, false)
}
};
let pinned_gb = pinned_bytes as f64 / 1_000_000_000.0;
if pinned_regions.is_empty() {
progress.info(&format!(
"FLUX offload: prefetch={} (largest block {:.1} MB) — pinning skipped \
(no streamed CUDA tensors / unsupported tensors)",
prefetch_label,
largest_block as f64 / 1_000_000.0,
));
} else {
progress.info(&format!(
"FLUX offload: pinned {:.2} GB across {} tensors, prefetch={} \
(largest block {:.1} MB)",
pinned_gb,
pinned_regions.len(),
prefetch_label,
largest_block as f64 / 1_000_000.0,
));
}
Ok(Self {
pinned_regions,
img_in,
txt_in,
time_in,
vector_in,
guidance_in,
pe_embedder,
final_layer,
double_blocks,
single_blocks,
gpu_device: gpu_device.clone(),
lora_registry,
#[cfg(feature = "cuda")]
prefetch_stream,
#[cfg(feature = "cuda")]
prefetch_buffer,
})
}
/// Returns `true` when a LoRA registry is attached and it is not empty.
#[allow(dead_code)]
pub(crate) fn has_loras(&self) -> bool {
self.lora_registry
.as_ref()
.map(|r| !r.is_empty())
.unwrap_or(false)
}
/// Runs one forward pass of the transformer.
///
/// The conditioning vector is the `time_in` embedding of `timesteps`, plus the guidance
/// embedding (only when both `guidance_in` and `guidance` are present), plus `vector_in`
/// applied to `y`. Double blocks update `img` and `txt` separately; the two are then
/// concatenated along axis 1 (text first) for the single blocks. The text positions are
/// sliced off before the final layer.
///
/// Streamed blocks are copied to the GPU, run, and dropped after a device synchronize.
///
/// # Errors
/// Propagates any tensor, upload, or synchronize error.
#[allow(clippy::too_many_arguments)]
pub fn forward(
&self,
img: &Tensor,
img_ids: &Tensor,
txt: &Tensor,
txt_ids: &Tensor,
timesteps: &Tensor,
y: &Tensor,
guidance: Option<&Tensor>,
) -> Result<Tensor> {
let dtype = img.dtype();
let registry = self.lora_registry.as_ref();
// Positional encoding over the concatenated ids, text first.
let pe = {
let ids = Tensor::cat(&[txt_ids, img_ids], 1)?;
ids.apply(&self.pe_embedder)?
};
let mut txt = txt.apply(&self.txt_in)?;
let mut img = img.apply(&self.img_in)?;
let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?;
// Guidance is silently skipped unless both the embedder and the input are present.
let vec_ = match (self.guidance_in.as_ref(), guidance) {
(Some(g_in), Some(guidance)) => {
(vec_ + timestep_embedding(guidance, 256, dtype)?.apply(g_in))?
}
_ => vec_,
};
let vec_ = (vec_ + y.apply(&self.vector_in))?;
for (i, block) in self.double_blocks.iter().enumerate() {
match block {
DoubleBlockSlot::Resident(gpu_block) => {
(img, txt) = gpu_block.forward(&img, &txt, &vec_, &pe)?;
}
DoubleBlockSlot::Streamed(block) => {
// Upload, run, then synchronize before the temporary device copy is dropped.
let gpu_block = block.to_device(&self.gpu_device, registry, i)?;
(img, txt) = gpu_block.forward(&img, &txt, &vec_, &pe)?;
self.gpu_device.synchronize()?;
drop(gpu_block);
}
}
tracing::trace!("double block {i} done");
}
// Single blocks operate on one sequence: text tokens followed by image tokens.
let mut img = Tensor::cat(&[&txt, &img], 1)?;
let txt_len = txt.dim(1)?;
for (i, block) in self.single_blocks.iter().enumerate() {
match block {
SingleBlockSlot::Resident(gpu_block) => {
img = gpu_block.forward(&img, &vec_, &pe)?;
}
SingleBlockSlot::Streamed(block) => {
// Same upload / run / synchronize / drop sequence as for double blocks.
let gpu_block = block.to_device(&self.gpu_device, registry, i)?;
img = gpu_block.forward(&img, &vec_, &pe)?;
self.gpu_device.synchronize()?;
drop(gpu_block);
}
}
tracing::trace!("single block {i} done");
}
// Drop the leading text positions; only the image positions reach the final layer.
let img = img.i((.., txt_len..))?;
self.final_layer.forward(&img, &vec_)
}
}
/// Two-layer embedder: `out_layer(silu(in_layer(x)))`.
struct StemMlpEmbedder {
    in_layer: Linear,
    out_layer: Linear,
}

impl StemMlpEmbedder {
    /// Loads `in_sz -> h_sz` from `in_layer` and `h_sz -> h_sz` from `out_layer`.
    fn load(in_sz: usize, h_sz: usize, vb: VarBuilder) -> Result<Self> {
        let in_layer = candle_nn::linear(in_sz, h_sz, vb.pp("in_layer"))?;
        let out_layer = candle_nn::linear(h_sz, h_sz, vb.pp("out_layer"))?;
        Ok(Self {
            in_layer,
            out_layer,
        })
    }

    /// Returns a copy of the embedder with both layers on `dev`.
    fn to_device(&self, dev: &Device) -> Result<Self> {
        let in_layer = linear_to_device(&self.in_layer, dev)?;
        let out_layer = linear_to_device(&self.out_layer, dev)?;
        Ok(Self {
            in_layer,
            out_layer,
        })
    }
}

impl Module for StemMlpEmbedder {
    /// Applies the input layer, SiLU, then the output layer.
    fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
        let hidden = xs.apply(&self.in_layer)?;
        let activated = hidden.silu()?;
        activated.apply(&self.out_layer)
    }
}
#[cfg(test)]
mod tests {
use super::prefetch_status_label;
// Pins the label for each state: not requested, requested with nothing ready, requested
// with only the stream ready, and requested with stream and buffer ready. Even the fully
// ready case is labelled "scaffold-only".
#[test]
fn prefetch_status_label_distinguishes_scaffold_from_real_async() {
assert_eq!(prefetch_status_label(false, false, false), "off");
assert_eq!(prefetch_status_label(true, false, false), "unavailable");
assert_eq!(prefetch_status_label(true, true, false), "unavailable");
assert_eq!(prefetch_status_label(true, true, true), "scaffold-only");
}
}