#![allow(clippy::similar_names)]
#![allow(clippy::needless_pass_by_value)]
use candle_core::{D, Module, Result, Tensor};
use candle_nn::{Linear, VarBuilder, linear_no_bias};
/// Adaptive layer-norm modulation head.
///
/// Produces six per-feature modulation tensors from a conditioning vector by
/// applying SiLU and then a single bias-free linear projection whose output is
/// `6 * embed_dim` wide.
#[derive(Debug)]
pub struct AdaLNModulation {
    /// Bias-free projection from `embed_dim` to `6 * embed_dim` features.
    proj: Linear,
    /// Width of each of the six slices returned by [`AdaLNModulation::forward`].
    embed_dim: usize,
}

impl AdaLNModulation {
    /// Creates the modulation head.
    ///
    /// Builds the projection with `linear_no_bias(embed_dim, 6 * embed_dim, vb)`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `linear_no_bias`.
    pub fn new(embed_dim: usize, vb: VarBuilder) -> Result<Self> {
        // One projection yields all six modulation vectors at once.
        linear_no_bias(embed_dim, 6 * embed_dim, vb).map(|proj| Self { proj, embed_dim })
    }

    /// Computes the six modulation tensors for `global_cond`.
    ///
    /// The conditioning tensor goes through SiLU and the projection. A new
    /// axis is then inserted at dimension 1, and the last dimension is cut into
    /// six consecutive `embed_dim`-wide slices.
    ///
    /// NOTE(review): presumably `global_cond` is `(batch, embed_dim)`, so each
    /// slice is `(batch, 1, embed_dim)` and broadcasts over a sequence axis.
    /// Confirm with the callers.
    ///
    /// Returns `(scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff)`,
    /// i.e. slices 0 through 5 in that order.
    ///
    /// # Errors
    ///
    /// Propagates any error from `silu`, the linear forward pass, `unsqueeze`,
    /// or `narrow`.
    pub fn forward(
        &self,
        global_cond: &Tensor,
    ) -> Result<(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)> {
        let projected = self
            .proj
            .forward(&candle_nn::ops::silu(global_cond)?)?
            .unsqueeze(1)?;

        // Slice `index` covers features [index * width, (index + 1) * width) of the last axis.
        let width = self.embed_dim;
        let slice = |index: usize| projected.narrow(D::Minus1, index * width, width);

        // Tuple elements are evaluated left to right, so slices are taken in order 0..=5.
        Ok((
            slice(0)?,
            slice(1)?,
            slice(2)?,
            slice(3)?,
            slice(4)?,
            slice(5)?,
        ))
    }
}
/// Applies adaptive modulation: `x * (1 + scale) + shift`.
///
/// Uses broadcasting multiply and add, so `scale` and `shift` only need to be
/// broadcast-compatible with `x`.
///
/// NOTE(review): presumably `scale`/`shift` come from
/// `AdaLNModulation::forward` with shape `(batch, 1, dim)` against `x` of
/// shape `(batch, seq, dim)`. Confirm with the callers.
///
/// # Errors
///
/// Propagates any error from `affine`, `broadcast_mul`, or `broadcast_add`,
/// e.g. on incompatible shapes or dtypes.
pub fn modulate(x: &Tensor, shift: &Tensor, scale: &Tensor) -> Result<Tensor> {
    // `scale * 1.0 + 1.0` == `1 + scale`, computed in one op without
    // materializing a separate ones tensor.
    let one_plus_scale = scale.affine(1.0, 1.0)?;
    x.broadcast_mul(&one_plus_scale)?.broadcast_add(shift)
}