Skip to main content

candle_transformers/models/stable_diffusion/
embeddings.rs

1use candle::{Result, Tensor, D};
2use candle_nn as nn;
3use candle_nn::Module;
4
5#[derive(Debug)]
6pub struct TimestepEmbedding {
7    linear_1: nn::Linear,
8    linear_2: nn::Linear,
9}
10
11impl TimestepEmbedding {
12    // act_fn: "silu"
13    pub fn new(vs: nn::VarBuilder, channel: usize, time_embed_dim: usize) -> Result<Self> {
14        let linear_1 = nn::linear(channel, time_embed_dim, vs.pp("linear_1"))?;
15        let linear_2 = nn::linear(time_embed_dim, time_embed_dim, vs.pp("linear_2"))?;
16        Ok(Self { linear_1, linear_2 })
17    }
18}
19
20impl Module for TimestepEmbedding {
21    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
22        let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?;
23        self.linear_2.forward(&xs)
24    }
25}
26
27#[derive(Debug)]
28pub struct Timesteps {
29    num_channels: usize,
30    flip_sin_to_cos: bool,
31    downscale_freq_shift: f64,
32}
33
34impl Timesteps {
35    pub fn new(num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64) -> Self {
36        Self {
37            num_channels,
38            flip_sin_to_cos,
39            downscale_freq_shift,
40        }
41    }
42}
43
44impl Module for Timesteps {
45    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
46        let half_dim = (self.num_channels / 2) as u32;
47        let exponent = (Tensor::arange(0, half_dim, xs.device())?.to_dtype(candle::DType::F32)?
48            * -f64::ln(10000.))?;
49        let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
50        let emb = exponent.exp()?.to_dtype(xs.dtype())?;
51        // emb = timesteps[:, None].float() * emb[None, :]
52        let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
53        let (cos, sin) = (emb.cos()?, emb.sin()?);
54        let emb = if self.flip_sin_to_cos {
55            Tensor::cat(&[&cos, &sin], D::Minus1)?
56        } else {
57            Tensor::cat(&[&sin, &cos], D::Minus1)?
58        };
59        if self.num_channels % 2 == 1 {
60            emb.pad_with_zeros(D::Minus2, 0, 1)
61        } else {
62            Ok(emb)
63        }
64    }
65}