use candle_core::{DType, Device, Module, Tensor, D};
use candle_nn::VarBuilder;
use candle_transformers::models::with_tracing::RmsNorm;
use candle_transformers::models::z_image::transformer::apply_rotary_emb;
use super::quantized_transformer::{
build_edit_modulation_index, select_modulation_params, QwenRopeEmbedder,
};
/// Linear projection that is either a regular `candle_nn::Linear` or a weight
/// kept in `F8E4M3` together with an optional scale and bias.
#[derive(Debug, Clone)]
enum QwenLinear {
/// Regular linear layer; `forward` delegates to `candle_nn::Linear`.
Standard(candle_nn::Linear),
/// FP8 weight; `forward` converts it to the input dtype on every call.
Fp8 {
weight: Tensor,
/// Optional multiplier loaded from the `scale_weight` tensor.
scale: Option<Tensor>,
bias: Option<Tensor>,
},
}
impl QwenLinear {
    /// Loads a linear layer whose weight has shape `(out_dim, in_dim)` from `vb`.
    ///
    /// A weight stored as `F8E4M3` produces [`QwenLinear::Fp8`]; its `scale_weight`
    /// and (when `has_bias`) `bias` tensors are fetched with `get_unchecked`, and a
    /// failed lookup is treated as "absent". Any other dtype produces
    /// [`QwenLinear::Standard`], where a requested bias of shape `out_dim` must load.
    ///
    /// # Errors
    /// Fails when the weight cannot be loaded, or when a standard layer is asked
    /// for a bias that cannot be loaded.
    fn load(
        in_dim: usize,
        out_dim: usize,
        has_bias: bool,
        vb: VarBuilder,
    ) -> candle_core::Result<Self> {
        let weight = vb.get((out_dim, in_dim), "weight")?;

        // Non-FP8 weights take the plain candle path.
        if weight.dtype() != DType::F8E4M3 {
            let bias = has_bias.then(|| vb.get(out_dim, "bias")).transpose()?;
            return Ok(Self::Standard(candle_nn::Linear::new(weight, bias)));
        }

        // FP8 path: both side tensors are optional and best-effort.
        let scale = vb.get_unchecked("scale_weight").ok();
        let bias = has_bias
            .then(|| vb.get_unchecked("bias").ok())
            .flatten();
        Ok(Self::Fp8 {
            weight,
            scale,
            bias,
        })
    }
}
impl Module for QwenLinear {
    /// Applies the projection to `x`.
    ///
    /// The standard variant defers to `candle_nn::Linear`. The FP8 variant converts
    /// the weight (and scale / bias, when present) to the dtype of `x`, multiplies
    /// the weight by the scale, and performs the matmul against the transposed
    /// weight. Rank-3 and rank-4 inputs are flattened to 2-D for the matmul and
    /// reshaped back afterwards.
    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        let (weight, scale, bias) = match self {
            Self::Standard(linear) => return linear.forward(x),
            Self::Fp8 {
                weight,
                scale,
                bias,
            } => (weight, scale, bias),
        };

        let compute_dtype = x.dtype();
        let mut dequantized = weight.to_dtype(compute_dtype)?;
        if let Some(s) = scale {
            dequantized = dequantized.broadcast_mul(&s.to_dtype(compute_dtype)?)?;
        }
        let w_t = dequantized.t()?;

        let projected = match *x.dims() {
            [b1, b2, m, k] => {
                let flat = x.reshape((b1 * b2 * m, k))?;
                flat.matmul(&w_t)?.reshape((b1, b2, m, ()))?
            }
            [bsize, m, k] => {
                let flat = x.reshape((bsize * m, k))?;
                flat.matmul(&w_t)?.reshape((bsize, m, ()))?
            }
            _ => x.matmul(&w_t)?,
        };

        if let Some(b) = bias {
            projected.broadcast_add(&b.to_dtype(compute_dtype)?)
        } else {
            Ok(projected)
        }
    }
}
/// Feed-forward sub-layer in one of two weight layouts.
#[derive(Debug, Clone)]
enum FeedForward {
/// Gated variant: `w2(silu(w1(x)) * w3(x))`.
SwiGlu {
w1: QwenLinear,
w2: QwenLinear,
w3: QwenLinear,
},
/// Two-layer variant: `out(gelu_tanh(proj(x)))`.
Gelu {
proj: QwenLinear,
out: QwenLinear,
},
}
impl FeedForward {
    /// Builds the feed-forward layer found under `vb`.
    ///
    /// The presence of `net.0.proj.weight` selects the GELU layout
    /// (`net.0.proj` / `net.2`, bias used when `net.0.proj.bias` exists);
    /// otherwise the bias-free SwiGLU layout (`w1`, `w2`, `w3`) is loaded.
    fn new(dim: usize, hidden_dim: usize, vb: VarBuilder) -> candle_core::Result<Self> {
        if !vb.contains_tensor("net.0.proj.weight") {
            let w1 = QwenLinear::load(dim, hidden_dim, false, vb.pp("w1"))?;
            let w2 = QwenLinear::load(hidden_dim, dim, false, vb.pp("w2"))?;
            let w3 = QwenLinear::load(dim, hidden_dim, false, vb.pp("w3"))?;
            return Ok(Self::SwiGlu { w1, w2, w3 });
        }

        let has_bias = vb.contains_tensor("net.0.proj.bias");
        let net = vb.pp("net");
        let proj = QwenLinear::load(dim, hidden_dim, has_bias, net.pp("0").pp("proj"))?;
        let out = QwenLinear::load(hidden_dim, dim, has_bias, net.pp("2"))?;
        Ok(Self::Gelu { proj, out })
    }
}
impl Module for FeedForward {
    /// Runs the feed-forward layer on `x`.
    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        match self {
            Self::SwiGlu { w1, w2, w3 } => {
                // SiLU-gated branch multiplied element-wise with the `w3` branch.
                let gate = w1.forward(x)?.silu()?;
                let up = w3.forward(x)?;
                let gated = (gate * up)?;
                w2.forward(&gated)
            }
            Self::Gelu { proj, out } => {
                let expanded = proj.forward(x)?;
                let activated = candle_nn::Activation::GeluPytorchTanh.forward(&expanded)?;
                out.forward(&activated)
            }
        }
    }
}
/// Layer normalization over the last dimension with no learned weight or bias.
#[derive(Debug, Clone)]
struct LayerNormNoParams {
/// Value added to the variance before taking the square root.
eps: f64,
}
impl LayerNormNoParams {
/// Creates a normalizer using `eps` as the variance offset.
fn new(eps: f64) -> Self {
Self { eps }
}
}
impl Module for LayerNormNoParams {
    /// Normalizes `x` along its last dimension to zero mean and unit variance.
    ///
    /// `F16` / `BF16` inputs are computed in `F32`; the result is converted back
    /// to the input dtype.
    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        let input_dtype = x.dtype();
        let compute_dtype = if matches!(input_dtype, DType::F16 | DType::BF16) {
            DType::F32
        } else {
            input_dtype
        };

        let width = x.dim(D::Minus1)? as f64;
        let x = x.to_dtype(compute_dtype)?;

        let mean = (x.sum_keepdim(D::Minus1)? / width)?;
        let centered = x.broadcast_sub(&mean)?;
        let variance = (centered.sqr()?.sum_keepdim(D::Minus1)? / width)?;
        let denom = (variance + self.eps)?.sqrt()?;

        centered.broadcast_div(&denom)?.to_dtype(input_dtype)
    }
}
/// Hyper-parameters of the Qwen-Image transformer.
#[derive(Debug, Clone)]
pub(crate) struct QwenImageConfig {
    /// Number of attention heads.
    pub num_attention_heads: usize,
    /// Width of each attention head.
    pub attention_head_dim: usize,
    /// Model width; the presets set it to `num_attention_heads * attention_head_dim`.
    pub inner_dim: usize,
    /// Width of the incoming text-encoder states.
    pub joint_attention_dim: usize,
    /// Number of transformer blocks.
    pub num_layers: usize,
    /// Channel count of each packed input token.
    pub in_channels: usize,
    /// Channel count of the output latent.
    pub out_channels: usize,
    /// Side length of the square patch.
    pub patch_size: usize,
    /// Per-axis rotary embedding dimensions.
    pub axes_dims_rope: Vec<usize>,
    /// Epsilon used by the normalization layers.
    pub norm_eps: f64,
    /// Enables the extra zero-timestep conditioning used by the edit preset.
    pub zero_cond_t: bool,
}

impl Default for QwenImageConfig {
    /// Same as [`QwenImageConfig::qwen_image_2512`].
    fn default() -> Self {
        QwenImageConfig::qwen_image_2512()
    }
}

impl QwenImageConfig {
    /// Preset for Qwen-Image 2512.
    pub fn qwen_image_2512() -> Self {
        const HEADS: usize = 24;
        const HEAD_DIM: usize = 128;
        Self {
            num_attention_heads: HEADS,
            attention_head_dim: HEAD_DIM,
            inner_dim: HEADS * HEAD_DIM,
            joint_attention_dim: 3584,
            num_layers: 60,
            in_channels: 64,
            out_channels: 16,
            patch_size: 2,
            axes_dims_rope: vec![16, 56, 56],
            norm_eps: 1e-6,
            zero_cond_t: false,
        }
    }

    /// Preset for Qwen-Image-Edit 2511: the 2512 preset with `zero_cond_t` enabled.
    pub fn qwen_image_edit_2511() -> Self {
        Self {
            zero_cond_t: true,
            ..Self::qwen_image_2512()
        }
    }

    /// Feed-forward width derived from `inner_dim`: `(inner_dim / 3) * 8`
    /// (integer division first).
    pub fn hidden_dim(&self) -> usize {
        let third = self.inner_dim / 3;
        third * 8
    }
}
/// Sinusoidal timestep features followed by a two-layer projection
/// (`linear1`, SiLU, `linear2`).
#[derive(Debug, Clone)]
struct TimestepProjEmbeddings {
linear1: QwenLinear,
linear2: QwenLinear,
/// Width of the sinusoidal feature vector; half is cosine, half is sine.
frequency_embedding_size: usize,
}
/// Input width of `linear_1` and the value stored in `frequency_embedding_size`.
const FREQUENCY_EMBEDDING_SIZE: usize = 256;
/// Base of the geometric frequency schedule used for the timestep features.
pub(crate) const MAX_PERIOD: f64 = 10000.0;
impl TimestepProjEmbeddings {
    /// Loads `linear_1` / `linear_2` from `vb`, descending into `timestep_embedder`
    /// first when the weights are nested under that prefix.
    fn new(inner_dim: usize, vb: VarBuilder) -> candle_core::Result<Self> {
        let nested = vb.contains_tensor("timestep_embedder.linear_1.weight");
        let vb = if nested {
            vb.pp("timestep_embedder")
        } else {
            vb
        };

        // Both layers share the bias decision made from `linear_1`.
        let has_bias = vb.contains_tensor("linear_1.bias");
        let linear1 = QwenLinear::load(
            FREQUENCY_EMBEDDING_SIZE,
            inner_dim,
            has_bias,
            vb.pp("linear_1"),
        )?;
        let linear2 = QwenLinear::load(inner_dim, inner_dim, has_bias, vb.pp("linear_2"))?;

        Ok(Self {
            linear1,
            linear2,
            frequency_embedding_size: FREQUENCY_EMBEDDING_SIZE,
        })
    }

    /// Builds `[cos(t * f), sin(t * f)]` features for each entry of `t`, where
    /// `f_i = MAX_PERIOD^(-i / half)`. Computed in `F32`, returned as `dtype`.
    fn timestep_embedding(
        &self,
        t: &Tensor,
        device: &Device,
        dtype: DType,
    ) -> candle_core::Result<Tensor> {
        let half = self.frequency_embedding_size / 2;
        let log_step = -MAX_PERIOD.ln() / half as f64;

        let positions = Tensor::arange(0u32, half as u32, device)?.to_dtype(DType::F32)?;
        let freqs = (positions * log_step)?.exp()?;

        let t_column = t.unsqueeze(1)?.to_dtype(DType::F32)?;
        let args = t_column.broadcast_mul(&freqs.unsqueeze(0)?)?;

        let cos_part = args.cos()?;
        let sin_part = args.sin()?;
        Tensor::cat(&[cos_part, sin_part], D::Minus1)?.to_dtype(dtype)
    }

    /// Embeds timesteps `t`: sinusoidal features, `linear1`, SiLU, `linear2`.
    fn forward(&self, t: &Tensor, dtype: DType) -> candle_core::Result<Tensor> {
        let features = self.timestep_embedding(t, t.device(), dtype)?;
        let hidden = self.linear1.forward(&features)?.silu()?;
        self.linear2.forward(&hidden)
    }
}
/// Attention over the concatenation of text and image tokens, with separate
/// projections and query/key norms for each stream.
#[derive(Debug, Clone)]
struct JointAttention {
// Image-stream projections.
to_q: QwenLinear,
to_k: QwenLinear,
to_v: QwenLinear,
to_out: QwenLinear,
// Text-stream projections.
add_q_proj: QwenLinear,
add_k_proj: QwenLinear,
add_v_proj: QwenLinear,
add_out_proj: QwenLinear,
// Per-head RMS norms applied to queries and keys before the rotary embedding.
norm_q: RmsNorm,
norm_k: RmsNorm,
norm_added_q: RmsNorm,
norm_added_k: RmsNorm,
n_heads: usize,
head_dim: usize,
}
impl JointAttention {
/// Loads all projections and query/key norms from `vb`.
///
/// Every projection uses a bias iff `to_q.bias` exists. The image output
/// projection is read from `to_out.0` when present, otherwise from `to_out_0`.
fn new(cfg: &QwenImageConfig, vb: VarBuilder) -> candle_core::Result<Self> {
let dim = cfg.inner_dim;
let text_dim = cfg.joint_attention_dim;
let n_heads = cfg.num_attention_heads;
let head_dim = cfg.attention_head_dim;
let qkv_dim = n_heads * head_dim;
let has_bias = vb.contains_tensor("to_q.bias");
let to_q = QwenLinear::load(dim, qkv_dim, has_bias, vb.pp("to_q"))?;
let to_k = QwenLinear::load(dim, qkv_dim, has_bias, vb.pp("to_k"))?;
let to_v = QwenLinear::load(dim, qkv_dim, has_bias, vb.pp("to_v"))?;
// Two spellings of the output-projection key are accepted.
let to_out_key = if vb.contains_tensor("to_out.0.weight") {
"to_out.0"
} else {
"to_out_0"
};
let to_out = QwenLinear::load(qkv_dim, dim, has_bias, vb.pp(to_out_key))?;
let add_q_proj = QwenLinear::load(text_dim, qkv_dim, has_bias, vb.pp("add_q_proj"))?;
let add_k_proj = QwenLinear::load(text_dim, qkv_dim, has_bias, vb.pp("add_k_proj"))?;
let add_v_proj = QwenLinear::load(text_dim, qkv_dim, has_bias, vb.pp("add_v_proj"))?;
let add_out_proj = QwenLinear::load(qkv_dim, text_dim, has_bias, vb.pp("to_add_out"))?;
// The query/key norms use a fixed epsilon of 1e-6 rather than `cfg.norm_eps`.
let norm_q = RmsNorm::new(head_dim, 1e-6, vb.pp("norm_q"))?;
let norm_k = RmsNorm::new(head_dim, 1e-6, vb.pp("norm_k"))?;
let norm_added_q = RmsNorm::new(head_dim, 1e-6, vb.pp("norm_added_q"))?;
let norm_added_k = RmsNorm::new(head_dim, 1e-6, vb.pp("norm_added_k"))?;
Ok(Self {
to_q,
to_k,
to_v,
to_out,
add_q_proj,
add_k_proj,
add_v_proj,
add_out_proj,
norm_q,
norm_k,
norm_added_q,
norm_added_k,
n_heads,
head_dim,
})
}
/// Runs joint attention and returns `(img_out, txt_out)`.
///
/// `img_hidden` must be rank 3 with `img_seq_len` tokens on dim 1; `txt_hidden`
/// is reshaped using its own dim 1. `txt_mask` marks valid text tokens
/// (non-zero = keep). It is concatenated with a `U8` all-ones image mask, so it
/// presumably has to be `U8` of shape `(batch, txt_seq_len)` — TODO confirm
/// against callers.
#[allow(clippy::too_many_arguments)]
fn forward(
&self,
img_hidden: &Tensor,
txt_hidden: &Tensor,
txt_mask: &Tensor,
img_cos: &Tensor,
img_sin: &Tensor,
txt_cos: &Tensor,
txt_sin: &Tensor,
img_seq_len: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
let (b, _, _) = img_hidden.dims3()?;
// Per-stream Q/K/V projections.
let q_img = img_hidden.apply(&self.to_q)?;
let k_img = img_hidden.apply(&self.to_k)?;
let v_img = img_hidden.apply(&self.to_v)?;
let q_txt = txt_hidden.apply(&self.add_q_proj)?;
let k_txt = txt_hidden.apply(&self.add_k_proj)?;
let v_txt = txt_hidden.apply(&self.add_v_proj)?;
let txt_seq_len = txt_hidden.dim(1)?;
// Split the channel dim into (heads, head_dim).
let q_img = q_img.reshape((b, img_seq_len, self.n_heads, self.head_dim))?;
let k_img = k_img.reshape((b, img_seq_len, self.n_heads, self.head_dim))?;
let v_img = v_img.reshape((b, img_seq_len, self.n_heads, self.head_dim))?;
let q_txt = q_txt.reshape((b, txt_seq_len, self.n_heads, self.head_dim))?;
let k_txt = k_txt.reshape((b, txt_seq_len, self.n_heads, self.head_dim))?;
let v_txt = v_txt.reshape((b, txt_seq_len, self.n_heads, self.head_dim))?;
// Queries and keys are RMS-normed per head, then rotated; values are left untouched.
let q_img = self.apply_qk_norm(&q_img, &self.norm_q)?;
let k_img = self.apply_qk_norm(&k_img, &self.norm_k)?;
let q_txt = self.apply_qk_norm(&q_txt, &self.norm_added_q)?;
let k_txt = self.apply_qk_norm(&k_txt, &self.norm_added_k)?;
let q_img = apply_rotary_emb(&q_img, img_cos, img_sin)?;
let k_img = apply_rotary_emb(&k_img, img_cos, img_sin)?;
let q_txt = apply_rotary_emb(&q_txt, txt_cos, txt_sin)?;
let k_txt = apply_rotary_emb(&k_txt, txt_cos, txt_sin)?;
// Joint sequence layout along dim 1: text tokens first, then image tokens.
let q = Tensor::cat(&[&q_txt, &q_img], 1)?;
let k = Tensor::cat(&[&k_txt, &k_img], 1)?;
let v = Tensor::cat(&[&v_txt, &v_img], 1)?;
// (b, seq, heads, head_dim) -> (b, heads, seq, head_dim).
let q = q.transpose(1, 2)?.contiguous()?;
let k = k.transpose(1, 2)?.contiguous()?;
let v = v.transpose(1, 2)?.contiguous()?;
let scale = 1.0 / (self.head_dim as f64).sqrt();
// Additive key mask of shape (b, 1, 1, total_seq): 0 where the mask is set,
// -inf elsewhere. Image tokens are always kept.
let img_mask = Tensor::ones((b, img_seq_len), DType::U8, q.device())?;
let key_mask = Tensor::cat(&[txt_mask, &img_mask], 1)?
.unsqueeze(1)?
.unsqueeze(1)?;
let on_true = key_mask.zeros_like()?.to_dtype(q.dtype())?;
let on_false = Tensor::new(f32::NEG_INFINITY, q.device())?
.broadcast_as(key_mask.shape())?
.to_dtype(q.dtype())?;
let key_mask = key_mask.where_cond(&on_true, &on_false)?;
let attn = self.attention_dispatch(&q, &k, &v, scale, q.device(), Some(&key_mask))?;
let total_seq = img_seq_len + txt_seq_len;
// Back to (b, seq, heads * head_dim), then split along the same text-first layout.
let attn = attn.transpose(1, 2)?.reshape((b, total_seq, ()))?;
let txt_attn = attn.narrow(1, 0, txt_seq_len)?;
let img_attn = attn.narrow(1, txt_seq_len, img_seq_len)?;
let img_out = img_attn.apply(&self.to_out)?;
// Outputs at masked-out text positions are multiplied by zero.
let txt_out = txt_attn.apply(&self.add_out_proj)?.broadcast_mul(
&txt_mask
.unsqueeze(D::Minus1)?
.to_dtype(txt_hidden.dtype())?,
)?;
Ok((img_out, txt_out))
}
/// Applies `norm` over the last (`head_dim`) axis of a rank-4 tensor by
/// flattening the leading three axes, normalizing, and restoring the shape.
fn apply_qk_norm(&self, x: &Tensor, norm: &RmsNorm) -> candle_core::Result<Tensor> {
let (b, seq, heads, head_dim) = x.dims4()?;
let flat = x.reshape((b * seq * heads, head_dim))?;
let normed = norm.forward(&flat)?;
normed.reshape((b, seq, heads, head_dim))
}
/// Computes scaled dot-product attention, choosing the implementation by device.
///
/// On non-Metal devices this is `softmax(q k^T * scale + key_mask) v`.
fn attention_dispatch(
&self,
q: &Tensor,
k: &Tensor,
v: &Tensor,
scale: f64,
device: &Device,
key_mask: Option<&Tensor>,
) -> candle_core::Result<Tensor> {
if device.is_metal() {
// NOTE(review): this branch passes `None` as the mask, so `key_mask` is NOT
// applied on Metal and masked text keys can still be attended to there.
// Confirm whether this is intentional or whether the mask should be forwarded.
candle_nn::ops::sdpa(q, k, v, None, false, scale as f32, 1.0)
} else {
let mut attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
if let Some(mask) = key_mask {
attn_weights = attn_weights.broadcast_add(mask)?;
}
attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
attn_weights.matmul(v)
}
}
}
/// One dual-stream transformer block: the image and text streams each have their
/// own norms, feed-forward layer and modulation projection, and share one
/// `JointAttention`.
#[derive(Debug, Clone)]
struct QwenImageTransformerBlock {
norm1: LayerNormNoParams,
norm1_context: LayerNormNoParams,
attn: JointAttention,
ff: FeedForward,
ff_context: FeedForward,
norm2: LayerNormNoParams,
norm2_context: LayerNormNoParams,
// Project the timestep embedding to six modulation vectors per stream.
adaln_modulation: QwenLinear,
adaln_context_modulation: QwenLinear,
}
impl QwenImageTransformerBlock {
/// Loads one block from `vb`, accepting two tensor-key layouts.
///
/// The presence of `img_mlp.net.0.proj.weight` selects the `img_mlp` / `txt_mlp`
/// feed-forward keys with an image hidden width of `4 * dim`; otherwise
/// `ff` / `ff_context` are used with `cfg.hidden_dim()`. Modulation weights are
/// read from `img_mod.1` / `txt_mod.1` when present, else from
/// `norm1.linear` / `norm1_context.linear`.
fn new(cfg: &QwenImageConfig, vb: VarBuilder) -> candle_core::Result<Self> {
let dim = cfg.inner_dim;
let text_dim = cfg.joint_attention_dim;
let is_comfyui = vb.contains_tensor("img_mlp.net.0.proj.weight");
let hidden_dim = if is_comfyui {
dim * 4
} else {
cfg.hidden_dim()
};
let norm1 = LayerNormNoParams::new(cfg.norm_eps);
let norm1_context = LayerNormNoParams::new(cfg.norm_eps);
let attn = JointAttention::new(cfg, vb.pp("attn"))?;
let ff_key = if is_comfyui { "img_mlp" } else { "ff" };
let ff_ctx_key = if is_comfyui { "txt_mlp" } else { "ff_context" };
let ff = FeedForward::new(dim, hidden_dim, vb.pp(ff_key))?;
// The text feed-forward always uses a hidden width of 4 * text_dim.
let ff_context = FeedForward::new(text_dim, text_dim * 4, vb.pp(ff_ctx_key))?;
let norm2 = LayerNormNoParams::new(cfg.norm_eps);
let norm2_context = LayerNormNoParams::new(cfg.norm_eps);
// A bias under either key layout enables the bias for both modulation layers.
let has_bias =
vb.contains_tensor("img_mod.1.bias") || vb.contains_tensor("norm1.linear.bias");
let (adaln_modulation, adaln_context_modulation) = if vb.contains_tensor("img_mod.1.weight")
{
(
QwenLinear::load(dim, 6 * dim, has_bias, vb.pp("img_mod").pp("1"))?,
QwenLinear::load(dim, 6 * text_dim, has_bias, vb.pp("txt_mod").pp("1"))?,
)
} else {
(
QwenLinear::load(dim, 6 * dim, has_bias, vb.pp("norm1").pp("linear"))?,
QwenLinear::load(
dim,
6 * text_dim,
has_bias,
vb.pp("norm1_context").pp("linear"),
)?,
)
};
Ok(Self {
norm1,
norm1_context,
attn,
ff,
ff_context,
norm2,
norm2_context,
adaln_modulation,
adaln_context_modulation,
})
}
/// Applies the block and returns the updated `(img_hidden, txt_hidden)`.
///
/// Each stream goes through: norm, scale/shift modulation, joint attention,
/// gated residual; then norm, scale/shift modulation, feed-forward, gated
/// residual. The six modulation vectors per stream come from `silu(temb)`.
///
/// When `modulate_index` is given, the image modulation is chosen through
/// `select_modulation_params` (semantics defined outside this file — assumed to
/// pick a `temb` row per token; TODO confirm), and the text stream uses only the
/// first `txt_hidden.dim(0)` rows of `temb`.
#[allow(clippy::too_many_arguments)]
fn forward(
&self,
img_hidden: &Tensor,
txt_hidden: &Tensor,
txt_mask: &Tensor,
temb: &Tensor,
img_cos: &Tensor,
img_sin: &Tensor,
txt_cos: &Tensor,
txt_sin: &Tensor,
modulate_index: Option<&Tensor>,
) -> candle_core::Result<(Tensor, Tensor)> {
let img_seq_len = img_hidden.dim(1)?;
let img_mod = temb.silu()?.apply(&self.adaln_modulation)?;
let img_mod = if let Some(modulate_index) = modulate_index {
select_modulation_params(&img_mod, modulate_index)?
} else {
// Insert a sequence axis so the parameters broadcast over tokens.
img_mod.unsqueeze(1)?
};
// Chunk order: attention (shift, scale, gate), then MLP (shift, scale, gate).
let img_chunks = img_mod.chunk(6, D::Minus1)?;
let (
img_shift_msa,
img_scale_msa,
img_gate_msa,
img_shift_mlp,
img_scale_mlp,
img_gate_mlp,
) = (
&img_chunks[0],
&img_chunks[1],
&img_chunks[2],
&img_chunks[3],
&img_chunks[4],
&img_chunks[5],
);
// With a modulation index, only the leading rows of `temb` drive the text stream.
let txt_temb = if modulate_index.is_some() {
temb.narrow(0, 0, txt_hidden.dim(0)?)?
} else {
temb.clone()
};
let txt_mod = txt_temb
.silu()?
.apply(&self.adaln_context_modulation)?
.unsqueeze(1)?;
let txt_chunks = txt_mod.chunk(6, D::Minus1)?;
let (
txt_shift_msa,
txt_scale_msa,
txt_gate_msa,
txt_shift_mlp,
txt_scale_mlp,
txt_gate_mlp,
) = (
&txt_chunks[0],
&txt_chunks[1],
&txt_chunks[2],
&txt_chunks[3],
&txt_chunks[4],
&txt_chunks[5],
);
// Pre-attention modulation: norm(x) * (1 + scale) + shift.
let img_attn_in = self
.norm1
.forward(img_hidden)?
.broadcast_mul(&(img_scale_msa + 1.0)?)?
.broadcast_add(img_shift_msa)?;
let txt_attn_in = self
.norm1_context
.forward(txt_hidden)?
.broadcast_mul(&(txt_scale_msa + 1.0)?)?
.broadcast_add(txt_shift_msa)?;
let (img_attn, txt_attn) = self.attn.forward(
&img_attn_in,
&txt_attn_in,
txt_mask,
img_cos,
img_sin,
txt_cos,
txt_sin,
img_seq_len,
)?;
// Gated residual connections around attention.
let img_hidden = (img_hidden + img_gate_msa.broadcast_mul(&img_attn)?)?;
let txt_hidden = (txt_hidden + txt_gate_msa.broadcast_mul(&txt_attn)?)?;
// Image feed-forward with its own modulation and gated residual.
let img_mlp_in = self
.norm2
.forward(&img_hidden)?
.broadcast_mul(&(img_scale_mlp + 1.0)?)?
.broadcast_add(img_shift_mlp)?;
let img_ff = self.ff.forward(&img_mlp_in)?;
let img_hidden = (img_hidden + img_gate_mlp.broadcast_mul(&img_ff)?)?;
// Text feed-forward with its own modulation and gated residual.
let txt_mlp_in = self
.norm2_context
.forward(&txt_hidden)?
.broadcast_mul(&(txt_scale_mlp + 1.0)?)?
.broadcast_add(txt_shift_mlp)?;
let txt_ff = self.ff_context.forward(&txt_mlp_in)?;
let txt_hidden = (txt_hidden + txt_gate_mlp.broadcast_mul(&txt_ff)?)?;
Ok((img_hidden, txt_hidden))
}
}
/// Final layer: parameter-free layer norm, timestep-conditioned scale/shift,
/// then a linear projection.
#[derive(Debug, Clone)]
struct OutputLayer {
norm_final: LayerNormNoParams,
// Output projection, loaded from `proj_out`.
linear: QwenLinear,
// Produces the scale and shift halves from the timestep embedding.
adaln_linear: QwenLinear,
}
impl OutputLayer {
    /// Loads the output head from `vb`.
    ///
    /// `proj_out` maps `inner_dim` to `patch_size * patch_size * out_channels`, and
    /// `norm_out.linear` maps `inner_dim` to `2 * inner_dim`. Both layers use a
    /// bias iff `proj_out.bias` exists. The final norm uses a fixed epsilon of `1e-6`.
    fn new(
        inner_dim: usize,
        out_channels: usize,
        patch_size: usize,
        vb: VarBuilder,
    ) -> candle_core::Result<Self> {
        let has_bias = vb.contains_tensor("proj_out.bias");
        let patch_values = patch_size * patch_size * out_channels;

        let linear = QwenLinear::load(inner_dim, patch_values, has_bias, vb.pp("proj_out"))?;
        let adaln_linear = QwenLinear::load(
            inner_dim,
            2 * inner_dim,
            has_bias,
            vb.pp("norm_out").pp("linear"),
        )?;

        Ok(Self {
            norm_final: LayerNormNoParams::new(1e-6),
            linear,
            adaln_linear,
        })
    }

    /// Computes `linear(norm(x) * (1 + scale) + shift)`, where `scale` and `shift`
    /// are the first and second halves of `adaln_linear(silu(temb))`.
    fn forward(&self, x: &Tensor, temb: &Tensor) -> candle_core::Result<Tensor> {
        let activated = temb.silu()?;
        let modulation = self.adaln_linear.forward(&activated)?;
        let halves = modulation.chunk(2, D::Minus1)?;

        // Add a sequence axis so the parameters broadcast over tokens.
        let scale = halves[0].unsqueeze(1)?;
        let shift = halves[1].unsqueeze(1)?;

        let normed = self.norm_final.forward(x)?;
        let modulated = normed
            .broadcast_mul(&(scale + 1.0)?)?
            .broadcast_add(&shift)?;
        self.linear.forward(&modulated)
    }
}
/// Full Qwen-Image transformer: input embedders, a stack of dual-stream blocks,
/// rotary embedder and output head.
#[derive(Debug, Clone)]
pub(crate) struct QwenImageTransformer2DModel {
time_embed: TimestepProjEmbeddings,
// Embeds packed image tokens into the model width.
img_in: QwenLinear,
// Embeds the RMS-normed text-encoder states.
txt_in: QwenLinear,
txt_norm: RmsNorm,
blocks: Vec<QwenImageTransformerBlock>,
rope_embedder: QwenRopeEmbedder,
output_layer: OutputLayer,
cfg: QwenImageConfig,
}
impl QwenImageTransformer2DModel {
pub fn new(cfg: &QwenImageConfig, vb: VarBuilder) -> candle_core::Result<Self> {
let device = vb.device();
let dtype = vb.dtype();
let is_comfyui = vb.contains_tensor("img_in.weight");
let block_text_dim = if is_comfyui {
cfg.inner_dim
} else {
cfg.joint_attention_dim
};
let time_embed = TimestepProjEmbeddings::new(cfg.inner_dim, vb.pp("time_text_embed"))?;
let img_in_key = if is_comfyui { "img_in" } else { "x_embedder" };
let has_stem_bias = vb.contains_tensor(&format!("{img_in_key}.bias"));
let img_in = QwenLinear::load(
cfg.in_channels,
cfg.inner_dim,
has_stem_bias,
vb.pp(img_in_key),
)?;
let (txt_in_key, txt_in_in) = if is_comfyui {
("txt_in", cfg.joint_attention_dim) } else {
("context_embedder", cfg.joint_attention_dim) };
let txt_in = QwenLinear::load(txt_in_in, block_text_dim, has_stem_bias, vb.pp(txt_in_key))?;
let txt_norm = RmsNorm::new(cfg.joint_attention_dim, cfg.norm_eps, vb.pp("txt_norm"))?;
let mut block_cfg = cfg.clone();
block_cfg.joint_attention_dim = block_text_dim;
let mut blocks = Vec::with_capacity(cfg.num_layers);
let vb_blocks = vb.pp("transformer_blocks");
for i in 0..cfg.num_layers {
blocks.push(QwenImageTransformerBlock::new(&block_cfg, vb_blocks.pp(i))?);
}
let rope_embedder =
QwenRopeEmbedder::new(10000.0, cfg.axes_dims_rope.clone(), device, dtype)?;
let output_layer =
OutputLayer::new(cfg.inner_dim, cfg.out_channels, cfg.patch_size, vb.clone())?;
Ok(Self {
time_embed,
img_in,
txt_in,
txt_norm,
blocks,
rope_embedder,
output_layer,
cfg: cfg.clone(),
})
}
pub fn forward(
&self,
x: &Tensor,
t: &Tensor,
encoder_hidden_states: &Tensor,
encoder_attention_mask: &Tensor,
) -> candle_core::Result<Tensor> {
let device = x.device();
let (_b, _c, h, w) = x.dims4()?;
let patch_size = self.cfg.patch_size;
let temb = self
.time_embed
.forward(t, crate::engine::gpu_dtype(device))?;
let hp = h / patch_size;
let wp = w / patch_size;
let x_packed = x
.reshape((_b, _c, hp, patch_size, wp, patch_size))?
.permute((0, 2, 4, 1, 3, 5))?
.reshape((_b, hp * wp, _c * patch_size * patch_size))?
.contiguous()?;
let img_hidden = x_packed.apply(&self.img_in)?;
let txt_normed = self.txt_norm.forward(encoder_hidden_states)?;
let txt_hidden = txt_normed.apply(&self.txt_in)?;
let h_tokens = h / patch_size;
let w_tokens = w / patch_size;
let txt_seq_len = encoder_hidden_states.dim(1)?;
let (img_cos, img_sin, txt_cos, txt_sin) =
self.rope_embedder
.forward(1, h_tokens, w_tokens, txt_seq_len, device)?;
let mut img = img_hidden;
let mut txt = txt_hidden;
for block in &self.blocks {
let (new_img, new_txt) = block.forward(
&img,
&txt,
encoder_attention_mask,
&temb,
&img_cos,
&img_sin,
&txt_cos,
&txt_sin,
None,
)?;
img = new_img;
txt = new_txt;
}
let img_out = self.output_layer.forward(&img, &temb)?;
let x_out = img_out
.reshape((_b, hp, wp, self.cfg.out_channels, patch_size, patch_size))?
.permute((0, 3, 1, 4, 2, 5))?
.reshape((_b, self.cfg.out_channels, h, w))?
.contiguous()?;
Ok(x_out)
}
pub fn forward_packed(
&self,
packed_hidden_states: &Tensor,
t: &Tensor,
encoder_hidden_states: &Tensor,
encoder_attention_mask: &Tensor,
img_shapes: &[(usize, usize, usize)],
) -> candle_core::Result<Tensor> {
let device = packed_hidden_states.device();
let batch = packed_hidden_states.dim(0)?;
let mut timestep = t.clone();
let modulate_index = if self.cfg.zero_cond_t {
timestep = Tensor::cat(&[×tep, &(timestep.zeros_like()?)], 0)?;
Some(build_edit_modulation_index(img_shapes, batch, device)?)
} else {
None
};
let temb = self
.time_embed
.forward(×tep, crate::engine::gpu_dtype(device))?;
let mut img = packed_hidden_states.apply(&self.img_in)?;
let txt_normed = self.txt_norm.forward(encoder_hidden_states)?;
let mut txt = txt_normed.apply(&self.txt_in)?;
let txt_seq_len = encoder_hidden_states.dim(1)?;
let (img_cos, img_sin, txt_cos, txt_sin) =
self.rope_embedder
.forward_shapes(img_shapes, txt_seq_len, device)?;
for block in &self.blocks {
let (new_img, new_txt) = block.forward(
&img,
&txt,
encoder_attention_mask,
&temb,
&img_cos,
&img_sin,
&txt_cos,
&txt_sin,
modulate_index.as_ref(),
)?;
img = new_img;
txt = new_txt;
}
let out_temb = if self.cfg.zero_cond_t {
temb.narrow(0, 0, batch)?
} else {
temb
};
self.output_layer.forward(&img, &out_temb)
}
}