#![allow(dead_code)]
use anyhow::Result;
use candle_core::{quantized::QTensor, DType, IndexOp, Module, Tensor, D};
use candle_nn::{LayerNorm, RmsNorm};
use candle_transformers::models::flux::model::{Config, EmbedNd};
use candle_transformers::quantized_nn::Linear as QuantizedLinear;
use candle_transformers::quantized_var_builder::VarBuilder;
use std::sync::Arc;
use crate::flux::lora_bypass::{LoraLinear, LoraRegistry};
use crate::progress::ProgressReporter;
/// Builds a sinusoidal embedding of `t` with `dim` channels, cast to `dtype`.
///
/// `t` is scaled by 1000 and multiplied (in `f32`) against `dim / 2`
/// exponentially decaying frequencies derived from a max period of 10000.
/// The cosine half and the sine half are concatenated along the last axis,
/// cosine first.
///
/// # Errors
/// Returns an error when `dim` is odd or when any tensor operation fails.
fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
    const TIME_FACTOR: f64 = 1000.;
    const MAX_PERIOD: f64 = 10000.;

    if dim % 2 != 0 {
        anyhow::bail!("{dim} is odd");
    }

    let device = t.device();
    let half = dim / 2;
    let scaled = (t * TIME_FACTOR)?;

    // Frequencies exp(-ln(MAX_PERIOD) * i / half) for i in 0..half.
    let steps = Tensor::arange(0, half as u32, device)?.to_dtype(DType::F32)?;
    let log_step = -MAX_PERIOD.ln() / half as f64;
    let freqs = (steps * log_step)?.exp()?;

    // Add a trailing axis to the timesteps and a leading axis to the
    // frequencies so the product broadcasts to one row per timestep.
    let column = scaled.unsqueeze(1)?.to_dtype(DType::F32)?;
    let row = freqs.unsqueeze(0)?;
    let args = column.broadcast_mul(&row)?;

    let cos_half = args.cos()?;
    let sin_half = args.sin()?;
    let embedding = Tensor::cat(&[cos_half, sin_half], D::Minus1)?;
    Ok(embedding.to_dtype(dtype)?)
}
/// Runs attention over `q`, `k`, `v` by delegating to
/// `crate::attention::attention_default_scale`, converting its error into
/// this file's `anyhow` result type.
fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    let out = crate::attention::attention_default_scale(q, k, v)?;
    Ok(out)
}
/// Applies the positional coefficients in `freq_cis` to the rank-4 tensor `x`.
///
/// The last axis of `x` is viewed as `(n_embd / 2, 2)` pairs. The first
/// element of each pair is multiplied by index 0 of the last axis of
/// `freq_cis`, the second element by index 1, and the two products are summed.
/// The result is reshaped back to the original shape of `x`.
fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
    let out_shape = x.dims().to_vec();
    let (batch, heads, seq, width) = x.dims4()?;

    let pairs = x.reshape((batch, heads, seq, width / 2, 2))?;
    let first = pairs.narrow(D::Minus1, 0, 1)?;
    let second = pairs.narrow(D::Minus1, 1, 1)?;

    let coef_first = freq_cis.get_on_dim(D::Minus1, 0)?;
    let coef_second = freq_cis.get_on_dim(D::Minus1, 1)?;

    let lhs = coef_first.broadcast_mul(&first)?;
    let rhs = coef_second.broadcast_mul(&second)?;
    let combined = (lhs + rhs)?;
    Ok(combined.reshape(out_shape)?)
}
/// Attention with positional coefficients applied to the queries and keys.
///
/// `pe` is applied to `q` and `k` via `apply_rope` (results made contiguous);
/// `v` is passed through untouched. The attention output has axes 1 and 2
/// swapped and is then flattened from axis 2 onward.
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
    let rotated_q = apply_rope(q, pe)?.contiguous()?;
    let rotated_k = apply_rope(k, pe)?.contiguous()?;
    let per_head = scaled_dot_product_attention(&rotated_q, &rotated_k, v)?;
    let merged = per_head.transpose(1, 2)?.flatten_from(2)?;
    Ok(merged)
}
/// Creates a bias-free `LayerNorm` over `dim` features with an all-ones `f32`
/// weight on the device of `vb` and an epsilon of `1e-6`.
///
/// Nothing is read from `vb`; it is only used for its device.
fn layer_norm(dim: usize, vb: &VarBuilder) -> Result<LayerNorm> {
    let device = vb.device();
    let unit_weight = Tensor::ones(dim, DType::F32, device)?;
    let norm = LayerNorm::new_no_bias(unit_weight, 1e-6);
    Ok(norm)
}
/// Loads a quantized linear layer from `vb` and wraps it in a `LoraLinear`.
///
/// * `bias` selects between the biased and the bias-free loader.
/// * `registry` / `key` — when a registry is given, the adapters it holds for
///   `key` are copied and attached to the layer.
///
/// Returns `LoraLinear::Quantized` when there are no adapters to attach and
/// `LoraLinear::WithAdaptersQuantized` otherwise.
fn load_quantized_linear(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    vb: VarBuilder,
    registry: Option<&LoraRegistry>,
    key: &str,
) -> Result<LoraLinear> {
    use candle_transformers::quantized_nn::{linear_b, linear_no_bias};

    let inner = if bias {
        linear_b(in_dim, out_dim, true, vb)?
    } else {
        linear_no_bias(in_dim, out_dim, vb)?
    };

    let adapters = match registry {
        Some(reg) => reg.adapters_for(key).to_vec(),
        None => Default::default(),
    };

    if adapters.is_empty() {
        return Ok(LoraLinear::Quantized(inner));
    }
    Ok(LoraLinear::WithAdaptersQuantized { inner, adapters })
}
/// Convenience wrapper around `load_quantized_linear` that always loads the
/// layer with a bias term.
fn quantized_linear(
    in_dim: usize,
    out_dim: usize,
    vb: VarBuilder,
    registry: Option<&LoraRegistry>,
    key: &str,
) -> Result<LoraLinear> {
    const WITH_BIAS: bool = true;
    load_quantized_linear(in_dim, out_dim, WITH_BIAS, vb, registry, key)
}
/// Builds an `RmsNorm` (epsilon `1e-6`) from the quantized tensor stored in
/// `vb` under `name` with `dim` elements, dequantized onto the device of `vb`.
fn rms_norm_from_qtensor(dim: usize, vb: VarBuilder, name: &str) -> Result<RmsNorm> {
    let quantized = vb.get(dim, name)?;
    let scale = quantized.dequantize(vb.device())?;
    Ok(RmsNorm::new(scale, 1e-6))
}
struct Modulation1 {
lin: LoraLinear,
}
impl Modulation1 {
fn load(
dim: usize,
vb: VarBuilder,
registry: Option<&LoraRegistry>,
base_key: &str,
) -> Result<Self> {
let lin = quantized_linear(
dim,
3 * dim,
vb.pp("lin"),
registry,
&format!("{base_key}.lin.weight"),
)?;
Ok(Self { lin })
}
fn forward(&self, vec_: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
let pre = vec_.silu()?;
let ys = self.lin.forward(&pre)?.unsqueeze(1)?.chunk(3, D::Minus1)?;
Ok((ys[0].clone(), ys[1].clone(), ys[2].clone()))
}
fn rebind_lora(&mut self, registry: Option<&LoraRegistry>, base_key: &str) {
rebind(&mut self.lin, registry, &format!("{base_key}.lin.weight"));
}
}
struct Modulation2 {
lin: LoraLinear,
}
impl Modulation2 {
fn load(
dim: usize,
vb: VarBuilder,
registry: Option<&LoraRegistry>,
base_key: &str,
) -> Result<Self> {
let lin = quantized_linear(
dim,
6 * dim,
vb.pp("lin"),
registry,
&format!("{base_key}.lin.weight"),
)?;
Ok(Self { lin })
}
#[allow(clippy::type_complexity)]
fn forward(
&self,
vec_: &Tensor,
) -> Result<((Tensor, Tensor, Tensor), (Tensor, Tensor, Tensor))> {
let pre = vec_.silu()?;
let ys = self.lin.forward(&pre)?.unsqueeze(1)?.chunk(6, D::Minus1)?;
Ok((
(ys[0].clone(), ys[1].clone(), ys[2].clone()),
(ys[3].clone(), ys[4].clone(), ys[5].clone()),
))
}
fn rebind_lora(&mut self, registry: Option<&LoraRegistry>, base_key: &str) {
rebind(&mut self.lin, registry, &format!("{base_key}.lin.weight"));
}
}
/// Attention projections for one stream: a fused QKV linear, per-head RMS
/// norms for queries and keys, and an output projection.
struct SelfAttention {
    qkv: LoraLinear,
    query_norm: RmsNorm,
    key_norm: RmsNorm,
    proj: LoraLinear,
    num_heads: usize,
}

impl SelfAttention {
    /// Loads the fused `qkv` projection (`dim -> 3 * dim`, bias controlled by
    /// `qkv_bias`), the query/key norms (sized `dim / num_heads`) and the
    /// `proj` projection (`dim -> dim`, with bias) from `vb`.
    ///
    /// LoRA adapters are looked up under `"{base_key}.qkv.weight"` and
    /// `"{base_key}.proj.weight"`.
    fn load(
        dim: usize,
        num_heads: usize,
        qkv_bias: bool,
        vb: VarBuilder,
        registry: Option<&LoraRegistry>,
        base_key: &str,
    ) -> Result<Self> {
        let head_dim = dim / num_heads;

        let qkv_key = format!("{base_key}.qkv.weight");
        let qkv = load_quantized_linear(dim, dim * 3, qkv_bias, vb.pp("qkv"), registry, &qkv_key)?;

        let query_norm = rms_norm_from_qtensor(head_dim, vb.pp("norm"), "query_norm.scale")?;
        let key_norm = rms_norm_from_qtensor(head_dim, vb.pp("norm"), "key_norm.scale")?;

        let proj_key = format!("{base_key}.proj.weight");
        let proj = quantized_linear(dim, dim, vb.pp("proj"), registry, &proj_key)?;

        Ok(Self {
            qkv,
            query_norm,
            key_norm,
            proj,
            num_heads,
        })
    }

    /// Projects `xs` through the fused QKV layer and returns `(q, k, v)`.
    ///
    /// The projection is reshaped to `(b, l, 3, num_heads, rest)`; each of the
    /// three slices along axis 2 has its axes 1 and 2 swapped. The query and
    /// key then go through their respective norms; the value is left as is.
    fn qkv_split(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
        let fused = self.qkv.forward(xs)?;
        let (batch, seq, _) = fused.dims3()?;
        let fused = fused.reshape((batch, seq, 3, self.num_heads, ()))?;

        let slot = |index: usize| -> Result<Tensor> {
            Ok(fused.i((.., .., index))?.transpose(1, 2)?)
        };
        let raw_q = slot(0)?;
        let raw_k = slot(1)?;
        let v = slot(2)?;

        let q = raw_q.apply(&self.query_norm)?;
        let k = raw_k.apply(&self.key_norm)?;
        Ok((q, k, v))
    }

    /// Re-resolves the adapters of `qkv` and `proj` against `registry`.
    fn rebind_lora(&mut self, registry: Option<&LoraRegistry>, base_key: &str) {
        let qkv_key = format!("{base_key}.qkv.weight");
        let proj_key = format!("{base_key}.proj.weight");
        rebind(&mut self.qkv, registry, &qkv_key);
        rebind(&mut self.proj, registry, &proj_key);
    }
}
/// Two-layer feed-forward network with a GELU between the layers.
struct Mlp {
    lin1: LoraLinear,
    lin2: LoraLinear,
}

impl Mlp {
    /// Loads `lin1` (`in_sz -> mlp_sz`) from `vb.pp("0")` and `lin2`
    /// (`mlp_sz -> in_sz`) from `vb.pp("2")`, both with bias.
    ///
    /// LoRA adapters are looked up under `"{base_key}.0.weight"` and
    /// `"{base_key}.2.weight"`.
    fn load(
        in_sz: usize,
        mlp_sz: usize,
        vb: VarBuilder,
        registry: Option<&LoraRegistry>,
        base_key: &str,
    ) -> Result<Self> {
        let first_key = format!("{base_key}.0.weight");
        let lin1 = quantized_linear(in_sz, mlp_sz, vb.pp("0"), registry, &first_key)?;

        let second_key = format!("{base_key}.2.weight");
        let lin2 = quantized_linear(mlp_sz, in_sz, vb.pp("2"), registry, &second_key)?;

        Ok(Self { lin1, lin2 })
    }

    /// Computes `lin2(gelu(lin1(xs)))`.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let expanded = self.lin1.forward(xs)?;
        let activated = expanded.gelu()?;
        self.lin2.forward(&activated)
    }

    /// Re-resolves the adapters of both layers against `registry`.
    fn rebind_lora(&mut self, registry: Option<&LoraRegistry>, base_key: &str) {
        let first_key = format!("{base_key}.0.weight");
        let second_key = format!("{base_key}.2.weight");
        rebind(&mut self.lin1, registry, &first_key);
        rebind(&mut self.lin2, registry, &second_key);
    }
}
/// Transformer block with separate image and text streams that share one
/// joint attention call.
struct DoubleBlock {
    img_mod: Modulation2,
    img_norm1: LayerNorm,
    img_attn: SelfAttention,
    img_norm2: LayerNorm,
    img_mlp: Mlp,
    txt_mod: Modulation2,
    txt_norm1: LayerNorm,
    txt_attn: SelfAttention,
    txt_norm2: LayerNorm,
    txt_mlp: Mlp,
}

impl DoubleBlock {
    /// Loads the block at position `idx` from `vb`.
    ///
    /// Sizes come from `cfg` (`hidden_size`, `mlp_ratio`, `num_heads`,
    /// `qkv_bias`). LoRA keys are rooted at `"double_blocks.{idx}"`.
    fn load(
        cfg: &Config,
        vb: VarBuilder,
        registry: Option<&LoraRegistry>,
        idx: usize,
    ) -> Result<Self> {
        let hidden = cfg.hidden_size;
        let mlp_sz = (hidden as f64 * cfg.mlp_ratio) as usize;
        let base = format!("double_blocks.{idx}");

        // Image stream.
        let img_mod =
            Modulation2::load(hidden, vb.pp("img_mod"), registry, &format!("{base}.img_mod"))?;
        let img_norm1 = layer_norm(hidden, &vb.pp("img_norm1"))?;
        let img_attn = SelfAttention::load(
            hidden,
            cfg.num_heads,
            cfg.qkv_bias,
            vb.pp("img_attn"),
            registry,
            &format!("{base}.img_attn"),
        )?;
        let img_norm2 = layer_norm(hidden, &vb.pp("img_norm2"))?;
        let img_mlp = Mlp::load(
            hidden,
            mlp_sz,
            vb.pp("img_mlp"),
            registry,
            &format!("{base}.img_mlp"),
        )?;

        // Text stream.
        let txt_mod =
            Modulation2::load(hidden, vb.pp("txt_mod"), registry, &format!("{base}.txt_mod"))?;
        let txt_norm1 = layer_norm(hidden, &vb.pp("txt_norm1"))?;
        let txt_attn = SelfAttention::load(
            hidden,
            cfg.num_heads,
            cfg.qkv_bias,
            vb.pp("txt_attn"),
            registry,
            &format!("{base}.txt_attn"),
        )?;
        let txt_norm2 = layer_norm(hidden, &vb.pp("txt_norm2"))?;
        let txt_mlp = Mlp::load(
            hidden,
            mlp_sz,
            vb.pp("txt_mlp"),
            registry,
            &format!("{base}.txt_mlp"),
        )?;

        Ok(Self {
            img_mod,
            img_norm1,
            img_attn,
            img_norm2,
            img_mlp,
            txt_mod,
            txt_norm1,
            txt_attn,
            txt_norm2,
            txt_mlp,
        })
    }

    /// Runs the block and returns the updated `(img, txt)` streams.
    ///
    /// Each stream is normalized and modulated, then projected to q/k/v. The
    /// text and image q/k/v are concatenated along axis 2 (text first) for a
    /// single attention call. The attention output is split back along axis 1
    /// using the text length, and each stream receives a gated attention
    /// residual followed by a gated MLP residual.
    fn forward(
        &self,
        img: &Tensor,
        txt: &Tensor,
        vec_: &Tensor,
        pe: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
        // norm -> * (1 + scale) -> + shift
        let modulate =
            |xs: &Tensor, norm: &LayerNorm, scale: &Tensor, shift: &Tensor| -> Result<Tensor> {
                let normed = xs.apply(norm)?;
                let scaled = normed.broadcast_mul(&(scale + 1.)?)?;
                Ok(scaled.broadcast_add(shift)?)
            };

        let ((img_shift1, img_scale1, img_gate1), (img_shift2, img_scale2, img_gate2)) =
            self.img_mod.forward(vec_)?;
        let ((txt_shift1, txt_scale1, txt_gate1), (txt_shift2, txt_scale2, txt_gate2)) =
            self.txt_mod.forward(vec_)?;

        let img_pre_attn = modulate(img, &self.img_norm1, &img_scale1, &img_shift1)?;
        let (img_q, img_k, img_v) = self.img_attn.qkv_split(&img_pre_attn)?;

        let txt_pre_attn = modulate(txt, &self.txt_norm1, &txt_scale1, &txt_shift1)?;
        let (txt_q, txt_k, txt_v) = self.txt_attn.qkv_split(&txt_pre_attn)?;

        // Joint attention over the concatenated sequence, text tokens first.
        let q = Tensor::cat(&[txt_q, img_q], 2)?;
        let k = Tensor::cat(&[txt_k, img_k], 2)?;
        let v = Tensor::cat(&[txt_v, img_v], 2)?;
        let joint = attention(&q, &k, &v, pe)?;

        let txt_len = txt.dim(1)?;
        let txt_attn_out = joint.narrow(1, 0, txt_len)?;
        let img_attn_out = joint.narrow(1, txt_len, joint.dim(1)? - txt_len)?;

        // Image stream residuals.
        let img_proj = self.img_attn.proj.forward(&img_attn_out)?;
        let img = (img + img_gate1.broadcast_mul(&img_proj)?)?;
        let img_pre_mlp = modulate(&img, &self.img_norm2, &img_scale2, &img_shift2)?;
        let img_mlp_out = self.img_mlp.forward(&img_pre_mlp)?;
        let img = (&img + img_gate2.broadcast_mul(&img_mlp_out)?)?;

        // Text stream residuals.
        let txt_proj = self.txt_attn.proj.forward(&txt_attn_out)?;
        let txt = (txt + txt_gate1.broadcast_mul(&txt_proj)?)?;
        let txt_pre_mlp = modulate(&txt, &self.txt_norm2, &txt_scale2, &txt_shift2)?;
        let txt_mlp_out = self.txt_mlp.forward(&txt_pre_mlp)?;
        let txt = (&txt + txt_gate2.broadcast_mul(&txt_mlp_out)?)?;

        Ok((img, txt))
    }

    /// Re-resolves the adapters of every LoRA-aware sub-module against
    /// `registry`, using keys rooted at `"double_blocks.{idx}"`.
    fn rebind_lora(&mut self, registry: Option<&LoraRegistry>, idx: usize) {
        let base = format!("double_blocks.{idx}");

        let img_mod_key = format!("{base}.img_mod");
        self.img_mod.rebind_lora(registry, &img_mod_key);
        let img_attn_key = format!("{base}.img_attn");
        self.img_attn.rebind_lora(registry, &img_attn_key);
        let img_mlp_key = format!("{base}.img_mlp");
        self.img_mlp.rebind_lora(registry, &img_mlp_key);

        let txt_mod_key = format!("{base}.txt_mod");
        self.txt_mod.rebind_lora(registry, &txt_mod_key);
        let txt_attn_key = format!("{base}.txt_attn");
        self.txt_attn.rebind_lora(registry, &txt_attn_key);
        let txt_mlp_key = format!("{base}.txt_mlp");
        self.txt_mlp.rebind_lora(registry, &txt_mlp_key);
    }
}
/// Transformer block operating on a single sequence, with attention and MLP
/// inputs produced by one fused projection (`linear1`) and merged by `linear2`.
struct SingleBlock {
    linear1: LoraLinear,
    linear2: LoraLinear,
    query_norm: RmsNorm,
    key_norm: RmsNorm,
    pre_norm: LayerNorm,
    modulation: Modulation1,
    h_sz: usize,
    mlp_sz: usize,
    num_heads: usize,
}

impl SingleBlock {
    /// Loads the block at position `idx` from `vb`.
    ///
    /// `linear1` maps `h -> 3 * h + mlp_sz` and `linear2` maps
    /// `h + mlp_sz -> h`, where `h = cfg.hidden_size` and
    /// `mlp_sz = h * cfg.mlp_ratio` (truncated). LoRA keys are rooted at
    /// `"single_blocks.{idx}"`.
    fn load(
        cfg: &Config,
        vb: VarBuilder,
        registry: Option<&LoraRegistry>,
        idx: usize,
    ) -> Result<Self> {
        let hidden = cfg.hidden_size;
        let mlp_sz = (hidden as f64 * cfg.mlp_ratio) as usize;
        let head_dim = hidden / cfg.num_heads;
        let base = format!("single_blocks.{idx}");

        let linear1 = quantized_linear(
            hidden,
            hidden * 3 + mlp_sz,
            vb.pp("linear1"),
            registry,
            &format!("{base}.linear1.weight"),
        )?;
        let linear2 = quantized_linear(
            hidden + mlp_sz,
            hidden,
            vb.pp("linear2"),
            registry,
            &format!("{base}.linear2.weight"),
        )?;
        let query_norm = rms_norm_from_qtensor(head_dim, vb.pp("norm"), "query_norm.scale")?;
        let key_norm = rms_norm_from_qtensor(head_dim, vb.pp("norm"), "key_norm.scale")?;
        let pre_norm = layer_norm(hidden, &vb.pp("pre_norm"))?;
        let modulation = Modulation1::load(
            hidden,
            vb.pp("modulation"),
            registry,
            &format!("{base}.modulation"),
        )?;

        Ok(Self {
            linear1,
            linear2,
            query_norm,
            key_norm,
            pre_norm,
            modulation,
            h_sz: hidden,
            mlp_sz,
            num_heads: cfg.num_heads,
        })
    }

    /// Runs the block on `xs` and returns `xs + gate * linear2(cat(attn, gelu(mlp)))`.
    ///
    /// `xs` is normalized and modulated, then projected by `linear1`. The
    /// first `3 * h_sz` output features are split into q/k/v (q and k are
    /// normed), and the following `mlp_sz` features form the MLP branch.
    fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result<Tensor> {
        let (shift, scale, gate) = self.modulation.forward(vec_)?;

        let normed = xs.apply(&self.pre_norm)?;
        let modulated = normed
            .broadcast_mul(&(&scale + 1.)?)?
            .broadcast_add(&shift)?;
        let fused = self.linear1.forward(&modulated)?;

        // Attention part of the fused projection.
        let attn_width = 3 * self.h_sz;
        let qkv = fused.narrow(D::Minus1, 0, attn_width)?;
        let (batch, seq, _) = qkv.dims3()?;
        let qkv = qkv.reshape((batch, seq, 3, self.num_heads, ()))?;
        let raw_q = qkv.i((.., .., 0))?.transpose(1, 2)?;
        let raw_k = qkv.i((.., .., 1))?.transpose(1, 2)?;
        let v = qkv.i((.., .., 2))?.transpose(1, 2)?;

        // MLP part of the fused projection.
        let mlp = fused.narrow(D::Minus1, attn_width, self.mlp_sz)?;

        let q = raw_q.apply(&self.query_norm)?;
        let k = raw_k.apply(&self.key_norm)?;
        let attn = attention(&q, &k, &v, pe)?;

        let merged = Tensor::cat(&[attn, mlp.gelu()?], 2)?;
        let projected = self.linear2.forward(&merged)?;
        let gated = gate.broadcast_mul(&projected)?;
        Ok((xs + gated)?)
    }

    /// Re-resolves the adapters of `linear1`, `linear2` and the modulation
    /// layer against `registry`, using keys rooted at `"single_blocks.{idx}"`.
    fn rebind_lora(&mut self, registry: Option<&LoraRegistry>, idx: usize) {
        let base = format!("single_blocks.{idx}");

        let linear1_key = format!("{base}.linear1.weight");
        rebind(&mut self.linear1, registry, &linear1_key);

        let linear2_key = format!("{base}.linear2.weight");
        rebind(&mut self.linear2, registry, &linear2_key);

        let modulation_key = format!("{base}.modulation");
        self.modulation.rebind_lora(registry, &modulation_key);
    }
}
/// Output head: a modulated layer norm followed by a linear projection.
/// Neither linear layer here is LoRA-aware.
struct FinalLayer {
    norm_final: LayerNorm,
    linear: QuantizedLinear,
    ada_ln_modulation: QuantizedLinear,
}

impl FinalLayer {
    /// Loads the head from `vb`.
    ///
    /// `linear` maps `h_sz -> p_sz * p_sz * out_c` and the modulation layer
    /// (stored under `"adaLN_modulation.1"`) maps `h_sz -> 2 * h_sz`.
    fn load(h_sz: usize, p_sz: usize, out_c: usize, vb: VarBuilder) -> Result<Self> {
        use candle_transformers::quantized_nn::linear;

        let norm_final = layer_norm(h_sz, &vb.pp("norm_final"))?;
        let projection = linear(h_sz, p_sz * p_sz * out_c, vb.pp("linear"))?;
        let ada_ln_modulation = linear(h_sz, 2 * h_sz, vb.pp("adaLN_modulation.1"))?;

        Ok(Self {
            norm_final,
            linear: projection,
            ada_ln_modulation,
        })
    }

    /// Normalizes `xs`, applies the shift/scale derived from `vec`, and
    /// projects the result through `linear`.
    ///
    /// The modulation output is split in two along axis 1: the first half is
    /// the shift, the second half the scale.
    fn forward(&self, xs: &Tensor, vec: &Tensor) -> Result<Tensor> {
        let modulation = vec.silu()?.apply(&self.ada_ln_modulation)?;
        let halves = modulation.chunk(2, 1)?;
        let shift = &halves[0];
        let scale = &halves[1];

        let normed = xs.apply(&self.norm_final)?;
        let scale_term = (scale.unsqueeze(1)? + 1.0)?;
        let scaled = normed.broadcast_mul(&scale_term)?;
        let shift_term = shift.unsqueeze(1)?;
        let shifted = scaled.broadcast_add(&shift_term)?;

        let out = shifted.apply(&self.linear)?;
        Ok(out)
    }
}
/// Two-layer embedder (`in_layer` -> SiLU -> `out_layer`) built from plain
/// quantized linears without LoRA support.
struct StemMlpEmbedder {
    in_layer: QuantizedLinear,
    out_layer: QuantizedLinear,
}

impl StemMlpEmbedder {
    /// Loads `in_layer` (`in_sz -> h_sz`) and `out_layer` (`h_sz -> h_sz`)
    /// from `vb`.
    fn load(in_sz: usize, h_sz: usize, vb: VarBuilder) -> Result<Self> {
        use candle_transformers::quantized_nn::linear;

        let in_layer = linear(in_sz, h_sz, vb.pp("in_layer"))?;
        let out_layer = linear(h_sz, h_sz, vb.pp("out_layer"))?;
        Ok(Self {
            in_layer,
            out_layer,
        })
    }
}

impl Module for StemMlpEmbedder {
    /// Computes `out_layer(silu(in_layer(xs)))`.
    fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
        let hidden = xs.apply(&self.in_layer)?;
        let activated = hidden.silu()?;
        activated.apply(&self.out_layer)
    }
}
/// Quantized FLUX transformer whose double/single blocks use LoRA-aware
/// linear layers. The stem embedders and the final layer use plain quantized
/// linears.
pub(crate) struct QuantizedFluxTransformer {
    img_in: QuantizedLinear,
    txt_in: QuantizedLinear,
    time_in: StemMlpEmbedder,
    vector_in: StemMlpEmbedder,
    guidance_in: Option<StemMlpEmbedder>,
    pe_embedder: EmbedNd,
    final_layer: FinalLayer,
    double_blocks: Vec<DoubleBlock>,
    single_blocks: Vec<SingleBlock>,
    lora_registry: Option<LoraRegistry>,
}

impl QuantizedFluxTransformer {
    /// Loads every sub-module from `vb` according to `cfg`.
    ///
    /// `registry`, when present, is consulted while loading each block's
    /// LoRA-aware layers and a clone of it is stored on the model. The
    /// guidance embedder is only loaded when `cfg.guidance_embed` is set.
    /// Two informational messages are emitted through `progress`.
    pub fn load(
        cfg: &Config,
        vb: VarBuilder,
        registry: Option<&LoraRegistry>,
        progress: &ProgressReporter,
    ) -> Result<Self> {
        use candle_transformers::quantized_nn::linear;

        progress.info("Loading FLUX quantized transformer (bypass-mode LoRA)");

        let hidden = cfg.hidden_size;
        let img_in = linear(cfg.in_channels, hidden, vb.pp("img_in"))?;
        let txt_in = linear(cfg.context_in_dim, hidden, vb.pp("txt_in"))?;
        let time_in = StemMlpEmbedder::load(256, hidden, vb.pp("time_in"))?;
        let vector_in = StemMlpEmbedder::load(cfg.vec_in_dim, hidden, vb.pp("vector_in"))?;
        let guidance_in = cfg
            .guidance_embed
            .then(|| StemMlpEmbedder::load(256, hidden, vb.pp("guidance_in")))
            .transpose()?;

        let pe_dim = hidden / cfg.num_heads;
        let pe_embedder = EmbedNd::new(pe_dim, cfg.theta, cfg.axes_dim.to_vec());
        let final_layer = FinalLayer::load(hidden, 1, cfg.in_channels, vb.pp("final_layer"))?;

        let vb_double = vb.pp("double_blocks");
        let double_blocks = (0..cfg.depth)
            .map(|idx| DoubleBlock::load(cfg, vb_double.pp(idx), registry, idx))
            .collect::<Result<Vec<_>>>()?;

        let vb_single = vb.pp("single_blocks");
        let single_blocks = (0..cfg.depth_single_blocks)
            .map(|idx| SingleBlock::load(cfg, vb_single.pp(idx), registry, idx))
            .collect::<Result<Vec<_>>>()?;

        progress.info(&format!(
            "Quantized transformer loaded: {} double + {} single blocks (GPU-resident)",
            double_blocks.len(),
            single_blocks.len(),
        ));

        Ok(Self {
            img_in,
            txt_in,
            time_in,
            vector_in,
            guidance_in,
            pe_embedder,
            final_layer,
            double_blocks,
            single_blocks,
            lora_registry: registry.cloned(),
        })
    }

    /// Swaps the active LoRA registry: every block re-resolves its adapters
    /// against `registry` (clearing them when it is `None`), after which the
    /// registry is stored on the model.
    pub(crate) fn set_lora_registry(&mut self, registry: Option<LoraRegistry>) {
        let view = registry.as_ref();
        self.double_blocks
            .iter_mut()
            .enumerate()
            .for_each(|(idx, block)| block.rebind_lora(view, idx));
        self.single_blocks
            .iter_mut()
            .enumerate()
            .for_each(|(idx, block)| block.rebind_lora(view, idx));
        self.lora_registry = registry;
    }

    /// Returns `true` when a registry is stored and it is not empty.
    #[allow(dead_code)]
    pub(crate) fn has_loras(&self) -> bool {
        match self.lora_registry.as_ref() {
            Some(registry) => !registry.is_empty(),
            None => false,
        }
    }

    /// Runs the full transformer.
    ///
    /// * `img` / `txt` must be rank-3 tensors; anything else is rejected.
    /// * `txt_ids` and `img_ids` are concatenated along axis 1 (text first)
    ///   and fed to the positional embedder.
    /// * `timesteps`, `y` and (optionally) `guidance` are combined into the
    ///   conditioning vector. `guidance` only contributes when the model also
    ///   has a guidance embedder.
    ///
    /// After the double blocks, text and image are concatenated along axis 1
    /// for the single blocks; the text positions are dropped again before the
    /// final layer.
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        img: &Tensor,
        img_ids: &Tensor,
        txt: &Tensor,
        txt_ids: &Tensor,
        timesteps: &Tensor,
        y: &Tensor,
        guidance: Option<&Tensor>,
    ) -> Result<Tensor> {
        if txt.rank() != 3 {
            anyhow::bail!("unexpected shape for txt {:?}", txt.shape());
        }
        if img.rank() != 3 {
            anyhow::bail!("unexpected shape for img {:?}", img.shape());
        }
        let dtype = img.dtype();

        let ids = Tensor::cat(&[txt_ids, img_ids], 1)?;
        let pe = ids.apply(&self.pe_embedder)?;

        let mut txt = txt.apply(&self.txt_in)?;
        let mut img = img.apply(&self.img_in)?;

        // Conditioning vector: timestep (+ guidance) + pooled vector.
        let mut cond = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?;
        if let (Some(g_in), Some(guidance)) = (self.guidance_in.as_ref(), guidance) {
            let guidance_emb = timestep_embedding(guidance, 256, dtype)?.apply(g_in)?;
            cond = (cond + guidance_emb)?;
        }
        let vector_emb = y.apply(&self.vector_in)?;
        let vec_ = (cond + vector_emb)?;

        for block in &self.double_blocks {
            let (next_img, next_txt) = block.forward(&img, &txt, &vec_, &pe)?;
            img = next_img;
            txt = next_txt;
        }

        let mut joint = Tensor::cat(&[&txt, &img], 1)?;
        let txt_len = txt.dim(1)?;
        for block in &self.single_blocks {
            joint = block.forward(&joint, &vec_, &pe)?;
        }

        // Keep only the image positions for the output head.
        let img_only = joint.i((.., txt_len..))?;
        self.final_layer.forward(&img_only, &vec_)
    }
}
/// Replaces the adapters on `lin` with a copy of those `registry` holds for
/// `key`. With no registry the layer receives an empty adapter list.
fn rebind(lin: &mut LoraLinear, registry: Option<&LoraRegistry>, key: &str) {
    let adapters = match registry {
        Some(reg) => reg.adapters_for(key).to_vec(),
        None => Default::default(),
    };
    lin.set_adapters(adapters);
}
/// Test helper: quantizes `src` to `dtype` on `device` via
/// `QTensor::quantize_onto`.
#[cfg(test)]
fn quantize_cpu(
    src: &Tensor,
    dtype: candle_core::quantized::GgmlDType,
    device: &candle_core::Device,
) -> Result<QTensor> {
    let quantized = QTensor::quantize_onto(src, dtype, device)?;
    Ok(quantized)
}
/// Test helper: same as `quantize_cpu`, with the result wrapped in an `Arc`.
#[cfg(test)]
fn arc_quantize_cpu(
    src: &Tensor,
    dtype: candle_core::quantized::GgmlDType,
    device: &candle_core::Device,
) -> Result<Arc<QTensor>> {
    let quantized = quantize_cpu(src, dtype, device)?;
    Ok(Arc::new(quantized))
}
// No-op: accepts an `Arc<QTensor>` and discards it.
// NOTE(review): presumably kept so the `Arc` and `QTensor` imports stay
// referenced in non-test builds (their other uses are `#[cfg(test)]`) —
// confirm before removing.
#[allow(dead_code)]
fn _arc_qtensor_unused(_t: Arc<QTensor>) {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::flux::lora_bypass::LinearLoraAdapter;
    use candle_core::quantized::GgmlDType;
    use candle_core::Device;

    /// Returns `count` samples where sample `i` is `f(i as f32)`.
    fn wave(count: usize, f: impl Fn(f32) -> f32) -> Vec<f32> {
        (0..count).map(|i| f(i as f32)).collect()
    }

    /// Quantizes `weight` to Q8_0 on `device`.
    fn q8(weight: &Tensor, device: &Device) -> Arc<QTensor> {
        arc_quantize_cpu(weight, GgmlDType::Q8_0, device).unwrap()
    }

    /// Forward pass through the bare quantized linear (no LoRA wrapper).
    fn plain_forward(lin: &QuantizedLinear, x: &Tensor) -> Tensor {
        <QuantizedLinear as candle_core::Module>::forward(lin, x).unwrap()
    }

    /// Builds an adapter that is not restricted to a fused slice.
    fn adapter(down: Tensor, up: Tensor, scale: f32) -> LinearLoraAdapter {
        LinearLoraAdapter {
            down,
            up,
            scale,
            fused_slice: None,
        }
    }

    /// Largest absolute element-wise difference between `a` and `b`.
    fn max_abs_diff(a: &Tensor, b: &Tensor) -> f32 {
        let delta = (a - b).unwrap();
        let magnitudes = delta.abs().unwrap().flatten_all().unwrap();
        magnitudes.max(0).unwrap().to_scalar::<f32>().unwrap()
    }

    /// A `LoraLinear::Quantized` wrapper must reproduce the inner layer's output.
    #[test]
    fn quantized_linear_with_no_adapter_matches_unwrapped_forward() {
        let device = Device::Cpu;
        let (in_dim, out_dim) = (32, 64);

        let weight = Tensor::from_vec(
            wave(out_dim * in_dim, |i| (i * 0.013).sin()),
            (out_dim, in_dim),
            &device,
        )
        .unwrap();
        let bias = Tensor::from_vec(wave(out_dim, |i| i * 0.01), (out_dim,), &device).unwrap();

        let inner = QuantizedLinear::from_arc(q8(&weight, &device), Some(bias)).unwrap();
        let wrapped = LoraLinear::Quantized(inner.clone());

        let x = Tensor::from_vec(
            wave(2 * 3 * in_dim, |i| (i * 0.017).cos()),
            (2, 3, in_dim),
            &device,
        )
        .unwrap();

        let baseline = plain_forward(&inner, &x);
        let wrapped_out = wrapped.forward(&x).unwrap();
        let max = max_abs_diff(&baseline, &wrapped_out);
        assert!(max < 1e-7, "no-adapter wrapped diverged: {max}");
    }

    /// Bypass output with one adapter must stay close to quantizing the
    /// explicitly merged weight `W + scale * up @ down`.
    #[test]
    fn quantized_linear_with_one_adapter_matches_explicit_dequant_merge() {
        let device = Device::Cpu;
        let (in_dim, out_dim, rank) = (32, 64, 4);
        let scale = 0.5f32;

        let w = Tensor::from_vec(
            wave(out_dim * in_dim, |i| (i * 0.011).sin() * 0.1),
            (out_dim, in_dim),
            &device,
        )
        .unwrap();
        let down = Tensor::from_vec(
            wave(rank * in_dim, |i| (i * 0.019).cos() * 0.05),
            (rank, in_dim),
            &device,
        )
        .unwrap();
        let up = Tensor::from_vec(
            wave(out_dim * rank, |i| (i * 0.023).sin() * 0.05),
            (out_dim, rank),
            &device,
        )
        .unwrap();

        let bypass = LoraLinear::WithAdaptersQuantized {
            inner: QuantizedLinear::from_arc(q8(&w, &device), None).unwrap(),
            adapters: vec![adapter(down.clone(), up.clone(), scale)],
        };

        let merged_delta = up
            .matmul(&down)
            .unwrap()
            .affine(scale as f64, 0.0)
            .unwrap();
        let merged_w = (&w + &merged_delta).unwrap();
        let merged_inner = QuantizedLinear::from_arc(q8(&merged_w, &device), None).unwrap();

        let x = Tensor::from_vec(
            wave(3 * in_dim, |i| (i * 0.029).cos() * 0.5),
            (1, 3, in_dim),
            &device,
        )
        .unwrap();

        let bypass_out = bypass.forward(&x).unwrap();
        let merged_out = plain_forward(&merged_inner, &x);
        let max = max_abs_diff(&bypass_out, &merged_out);
        assert!(max < 5e-2, "Q8_0 bypass vs merge max diff: {max}");
    }

    /// With two adapters stacked: `two + base == one1 + one2`.
    #[test]
    fn quantized_linear_two_adapters_compose() {
        let device = Device::Cpu;
        let (in_dim, out_dim, rank) = (32, 16, 3);

        let w = Tensor::from_vec(
            wave(out_dim * in_dim, |i| (i * 0.011).sin() * 0.1),
            (out_dim, in_dim),
            &device,
        )
        .unwrap();
        let q_w = q8(&w, &device);

        let make_pair = |salt: f32| {
            let down = Tensor::from_vec(
                wave(rank * in_dim, |i| ((i + salt) * 0.013).sin() * 0.03),
                (rank, in_dim),
                &device,
            )
            .unwrap();
            let up = Tensor::from_vec(
                wave(out_dim * rank, |i| ((i + salt) * 0.017).cos() * 0.03),
                (out_dim, rank),
                &device,
            )
            .unwrap();
            (down, up)
        };
        let (d1, u1) = make_pair(1.0);
        let (d2, u2) = make_pair(7.0);
        let first = adapter(d1, u1, 0.4f32);
        let second = adapter(d2, u2, -0.3f32);

        let inner = QuantizedLinear::from_arc(q_w.clone(), None).unwrap();
        let two = LoraLinear::WithAdaptersQuantized {
            inner: inner.clone(),
            adapters: vec![first.clone(), second.clone()],
        };
        let one1 = LoraLinear::WithAdaptersQuantized {
            inner: inner.clone(),
            adapters: vec![first],
        };
        let one2 = LoraLinear::WithAdaptersQuantized {
            inner,
            adapters: vec![second],
        };

        let x = Tensor::from_vec(
            wave(2 * in_dim, |i| (i * 0.029).cos() * 0.5),
            (1, 2, in_dim),
            &device,
        )
        .unwrap();

        let two_out = two.forward(&x).unwrap();
        let one1_out = one1.forward(&x).unwrap();
        let one2_out = one2.forward(&x).unwrap();
        let base_out = plain_forward(&QuantizedLinear::from_arc(q_w, None).unwrap(), &x);

        let lhs = (&two_out + &base_out).unwrap();
        let rhs = (&one1_out + &one2_out).unwrap();
        let max = max_abs_diff(&lhs, &rhs);
        assert!(max < 1e-5, "two-adapter compose != sum of singles: {max}");
    }

    /// An adapter restricted to the first `h` output features must leave the
    /// remaining `2 * h` features identical to the plain forward pass.
    #[test]
    fn quantized_linear_fused_qkv_only_writes_target_slice() {
        let device = Device::Cpu;
        let in_dim = 32;
        let h = 16;
        let out_dim = 3 * h;
        let rank = 2;

        let w = Tensor::from_vec(
            wave(out_dim * in_dim, |i| (i * 0.011).sin() * 0.1),
            (out_dim, in_dim),
            &device,
        )
        .unwrap();
        let inner = QuantizedLinear::from_arc(q8(&w, &device), None).unwrap();

        let down = Tensor::from_vec(
            wave(rank * in_dim, |i| (i * 0.013).sin() * 0.05),
            (rank, in_dim),
            &device,
        )
        .unwrap();
        let up = Tensor::from_vec(
            wave(h * rank, |i| (i * 0.017).cos() * 0.05),
            (h, rank),
            &device,
        )
        .unwrap();

        let q_only = LinearLoraAdapter {
            down,
            up,
            scale: 0.7,
            fused_slice: Some(crate::flux::lora_bypass::FusedSlice {
                offset: 0,
                length: h,
            }),
        };
        let with_q = LoraLinear::WithAdaptersQuantized {
            inner: inner.clone(),
            adapters: vec![q_only],
        };

        let x = Tensor::from_vec(
            wave(3 * in_dim, |i| (i * 0.029).cos() * 0.5),
            (1, 3, in_dim),
            &device,
        )
        .unwrap();

        let plain_out = plain_forward(&inner, &x);
        let bypass_out = with_q.forward(&x).unwrap();
        let kv_plain = plain_out.narrow(2, h, 2 * h).unwrap();
        let kv_bypass = bypass_out.narrow(2, h, 2 * h).unwrap();
        let max = max_abs_diff(&kv_plain, &kv_bypass);
        assert!(max < 1e-7, "K/V rows drifted under Q-only adapter: {max}");
    }

    /// Building the registry twice from identical specs must yield registries
    /// of the same length with the same adapter count for the QKV key.
    #[test]
    fn bypass_skips_rebuild_on_same_fingerprint() {
        use crate::flux::lora::{LoraAdapter, LoraLayer, LoraSpec};
        use crate::flux::lora_bypass::{build_registry, LoraRegistry};
        use std::collections::HashMap as HM;

        let device = Device::Cpu;
        let h = 16;
        let key = "double_blocks.0.img_attn.qkv.weight";

        let a = Tensor::zeros((4, h), DType::F32, &device).unwrap();
        let b = Tensor::zeros((h, 4), DType::F32, &device).unwrap();
        let mut layers = HM::new();
        layers.insert(
            "transformer.transformer_blocks.0.attn.to_q".to_string(),
            LoraLayer { a, b, alpha: None },
        );
        let adapter = LoraAdapter { layers, rank: 4 };
        let specs = [LoraSpec {
            adapter: &adapter,
            scale: 0.5,
            path_hash: 0x1234,
        }];

        let mut linear_out_dims = HM::new();
        linear_out_dims.insert("double_blocks.0.img_attn.qkv.weight".to_string(), 3 * h);

        let r1 = build_registry(&specs, &linear_out_dims, &device, DType::F32).unwrap();
        let r2 = build_registry(&specs, &linear_out_dims, &device, DType::F32).unwrap();

        assert_eq!(r1.len(), r2.len(), "same specs → same registry length");
        assert_eq!(r1.adapters_for(key).len(), r2.adapters_for(key).len());
        drop(LoraRegistry::default());
    }
}