candle_transformers/models/stable_diffusion/
embeddings.rs1use candle::{Result, Tensor, D};
2use candle_nn as nn;
3use candle_nn::Module;
4
5#[derive(Debug)]
6pub struct TimestepEmbedding {
7 linear_1: nn::Linear,
8 linear_2: nn::Linear,
9}
10
11impl TimestepEmbedding {
12 pub fn new(vs: nn::VarBuilder, channel: usize, time_embed_dim: usize) -> Result<Self> {
14 let linear_1 = nn::linear(channel, time_embed_dim, vs.pp("linear_1"))?;
15 let linear_2 = nn::linear(time_embed_dim, time_embed_dim, vs.pp("linear_2"))?;
16 Ok(Self { linear_1, linear_2 })
17 }
18}
19
20impl Module for TimestepEmbedding {
21 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
22 let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?;
23 self.linear_2.forward(&xs)
24 }
25}
26
/// Sinusoidal encoder that turns timesteps into concatenated sine/cosine features.
#[derive(Debug)]
pub struct Timesteps {
    /// Target width of the encoding.
    ///
    /// The sine half and the cosine half each get `num_channels / 2` entries.
    num_channels: usize,
    /// When `true`, the cosine half comes before the sine half.
    flip_sin_to_cos: bool,
    /// Offset subtracted from the half-width in the frequency denominator.
    downscale_freq_shift: f64,
}
33
34impl Timesteps {
35 pub fn new(num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64) -> Self {
36 Self {
37 num_channels,
38 flip_sin_to_cos,
39 downscale_freq_shift,
40 }
41 }
42}
43
44impl Module for Timesteps {
45 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
46 let half_dim = (self.num_channels / 2) as u32;
47 let exponent = (Tensor::arange(0, half_dim, xs.device())?.to_dtype(candle::DType::F32)?
48 * -f64::ln(10000.))?;
49 let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
50 let emb = exponent.exp()?.to_dtype(xs.dtype())?;
51 let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
53 let (cos, sin) = (emb.cos()?, emb.sin()?);
54 let emb = if self.flip_sin_to_cos {
55 Tensor::cat(&[&cos, &sin], D::Minus1)?
56 } else {
57 Tensor::cat(&[&sin, &cos], D::Minus1)?
58 };
59 if self.num_channels % 2 == 1 {
60 emb.pad_with_zeros(D::Minus2, 0, 1)
61 } else {
62 Ok(emb)
63 }
64 }
65}