Skip to main content

ferrotorch_diffusion/
time_embedding.rs

1//! Time-step sinusoidal positional encoding + the MLP that follows it.
2//!
3//! These match `diffusers.models.embeddings.{Timesteps, TimestepEmbedding}`
4//! 1:1 for SD-1.5's settings (`flip_sin_to_cos = true`, `freq_shift = 0`).
5//!
6//! ```text
7//! Timesteps(C):           t -> [B, C]
8//!   half = C / 2
9//!   exponent = -log(max_period) * arange(half) / half
10//!   freqs = exp(exponent)
11//!   args  = t.float() * freqs
12//!   emb   = cat([cos(args), sin(args)], dim=-1)   (flip_sin_to_cos = true)
13//!
14//! TimestepEmbedding(C, time_emb_dim):
15//!   Linear(C, time_emb_dim) -> SiLU -> Linear(time_emb_dim, time_emb_dim)
16//! ```
17//!
18//! Diffusers' `Timesteps` is parameter-free — it's just an arithmetic
19//! recipe. We keep it as a `Module<T>` for ergonomic composition but
20//! `parameters()` is empty.
21
22use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
23use ferrotorch_nn::module::{Module, StateDict};
24use ferrotorch_nn::parameter::Parameter;
25use ferrotorch_nn::{Linear, SiLU};
26
27// ---------------------------------------------------------------------------
28// Timesteps (sinusoidal positional encoding)
29// ---------------------------------------------------------------------------
30
/// Configuration for the sinusoidal timestep encoding.
///
/// The type owns no learnable parameters; it only stores the knobs of
/// the arithmetic recipe used by
/// `diffusers.models.embeddings.Timesteps(num_channels, flip_sin_to_cos,
/// downscale_freq_shift)`. SD-1.5 uses `flip_sin_to_cos=true` and
/// `downscale_freq_shift=0`.
#[derive(Debug, Clone)]
pub struct Timesteps {
    /// Width of the produced encoding (expected to be even). The SD
    /// UNet uses 320.
    pub num_channels: usize,
    /// Ordering of the two halves of the encoding: `true` yields
    /// `cat([cos, sin])` (SD-style), `false` yields `cat([sin, cos])`.
    pub flip_sin_to_cos: bool,
    /// Diffusers' `downscale_freq_shift`: the amount subtracted from
    /// the half-channel count in the exponent denominator. SD-1.5
    /// always uses 0.
    pub downscale_freq_shift: f64,
    /// Longest period of the sinusoid family (diffusers default: 10000).
    pub max_period: f64,
}
49
50impl Timesteps {
51    /// Build a `Timesteps` module.
52    ///
53    /// # Errors
54    ///
55    /// Returns [`FerrotorchError::InvalidArgument`] when `num_channels`
56    /// is not a positive even integer.
57    pub fn new(
58        num_channels: usize,
59        flip_sin_to_cos: bool,
60        downscale_freq_shift: f64,
61    ) -> FerrotorchResult<Self> {
62        if num_channels == 0 || num_channels % 2 != 0 {
63            return Err(FerrotorchError::InvalidArgument {
64                message: format!(
65                    "Timesteps::new: num_channels must be a positive even integer, got {num_channels}"
66                ),
67            });
68        }
69        Ok(Self {
70            num_channels,
71            flip_sin_to_cos,
72            downscale_freq_shift,
73            max_period: 10_000.0,
74        })
75    }
76
77    /// Compute the sinusoidal encoding for a batch of timesteps.
78    ///
79    /// `timesteps` has shape `[B]` (1-D, dtype `T`). The output has
80    /// shape `[B, num_channels]`.
81    ///
82    /// # Errors
83    ///
84    /// Returns [`FerrotorchError::ShapeMismatch`] when the input is not
85    /// rank-1, [`FerrotorchError::InvalidArgument`] if the half-channel
86    /// math overflows `T`.
87    pub fn forward_t<T: Float>(&self, timesteps: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
88        if timesteps.ndim() != 1 {
89            return Err(FerrotorchError::ShapeMismatch {
90                message: format!(
91                    "Timesteps::forward_t: expected 1-D timesteps [B], got {:?}",
92                    timesteps.shape()
93                ),
94            });
95        }
96        let batch = timesteps.shape()[0];
97        let half = self.num_channels / 2;
98        // exponent = -ln(max_period) * i / (half - downscale_freq_shift),
99        // for i in 0..half. half - downscale_freq_shift defaults to half
100        // (downscale_freq_shift = 0 for SD-1.5).
101        let denom = (half as f64) - self.downscale_freq_shift;
102        if denom <= 0.0 {
103            return Err(FerrotorchError::InvalidArgument {
104                message: format!(
105                    "Timesteps::forward_t: invalid denominator {denom} (half={half}, \
106                     downscale_freq_shift={})",
107                    self.downscale_freq_shift,
108                ),
109            });
110        }
111        let log_max = self.max_period.ln();
112        let mut freqs = Vec::with_capacity(half);
113        for i in 0..half {
114            let exponent = -log_max * (i as f64) / denom;
115            freqs.push(exponent.exp());
116        }
117        // Read the timestep values into f64 for the multiply, then cast
118        // back to T for the sin/cos call.
119        let ts_data = timesteps.data()?;
120        let zero_t = T::from(0.0).ok_or_else(|| FerrotorchError::InvalidArgument {
121            message: "Timesteps::forward_t: failed to cast 0.0 into Float".into(),
122        })?;
123        let mut out = vec![zero_t; batch * self.num_channels];
124        for (b, &t) in ts_data.iter().enumerate() {
125            let t_f64: f64 = t
126                .to_f64()
127                .ok_or_else(|| FerrotorchError::InvalidArgument {
128                    message: "Timesteps::forward_t: failed to cast timestep into f64".into(),
129                })?;
130            for (i, &freq) in freqs.iter().enumerate() {
131                let arg = t_f64 * freq;
132                let cos_v = arg.cos();
133                let sin_v = arg.sin();
134                let (left, right) = if self.flip_sin_to_cos {
135                    (cos_v, sin_v)
136                } else {
137                    (sin_v, cos_v)
138                };
139                out[b * self.num_channels + i] = T::from(left).ok_or_else(|| {
140                    FerrotorchError::InvalidArgument {
141                        message: "Timesteps: cast left value to T failed".into(),
142                    }
143                })?;
144                out[b * self.num_channels + half + i] = T::from(right).ok_or_else(|| {
145                    FerrotorchError::InvalidArgument {
146                        message: "Timesteps: cast right value to T failed".into(),
147                    }
148                })?;
149            }
150        }
151        Tensor::from_storage(
152            TensorStorage::cpu(out),
153            vec![batch, self.num_channels],
154            false,
155        )
156    }
157}
158
159// `Module` impl: forward expects a 1-D timestep tensor. Parameter-free.
160impl<T: Float> Module<T> for Timesteps {
161    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
162        self.forward_t(input)
163    }
164    fn parameters(&self) -> Vec<&Parameter<T>> {
165        Vec::new()
166    }
167    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
168        Vec::new()
169    }
170    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
171        Vec::new()
172    }
173    fn train(&mut self) {
174        // Timesteps is parameter-free and stateless — no training-mode
175        // flag to flip. Mirrors diffusers' `Timesteps` (which has no
176        // submodules / parameters / buffers; PyTorch nn.Module.train()
177        // recurses over children that don't exist here).
178    }
179    fn eval(&mut self) {
180        // Same as `train` — nothing to switch.
181    }
182    fn is_training(&self) -> bool {
183        // Always-inference (deterministic arithmetic, no dropout / norm
184        // running stats).
185        false
186    }
187    fn load_state_dict(&mut self, _state: &StateDict<T>, _strict: bool) -> FerrotorchResult<()> {
188        Ok(())
189    }
190}
191
192// ---------------------------------------------------------------------------
193// TimestepEmbedding (MLP)
194// ---------------------------------------------------------------------------
195
196/// `TimestepEmbedding` — `Linear -> SiLU -> Linear` applied to the
197/// sinusoidal encoding.
198///
199/// State-dict key layout (matches diffusers):
200///
201/// ```text
202/// linear_1.{weight,bias}    [time_emb_dim, in_channels]
203/// linear_2.{weight,bias}    [time_emb_dim, time_emb_dim]
204/// ```
205#[derive(Debug)]
206pub struct TimestepEmbedding<T: Float> {
207    /// First linear `in_channels -> time_emb_dim`.
208    pub linear_1: Linear<T>,
209    /// Output linear `time_emb_dim -> time_emb_dim`.
210    pub linear_2: Linear<T>,
211    activation: SiLU,
212    training: bool,
213}
214
215impl<T: Float> TimestepEmbedding<T> {
216    /// Build a randomly-initialized `TimestepEmbedding`.
217    ///
218    /// # Errors
219    ///
220    /// Returns the underlying [`FerrotorchError`] on bad dims.
221    pub fn new(in_channels: usize, time_emb_dim: usize) -> FerrotorchResult<Self> {
222        let linear_1 = Linear::<T>::new(in_channels, time_emb_dim, true)?;
223        let linear_2 = Linear::<T>::new(time_emb_dim, time_emb_dim, true)?;
224        Ok(Self {
225            linear_1,
226            linear_2,
227            activation: SiLU::new(),
228            training: false,
229        })
230    }
231}
232
233impl<T: Float> Module<T> for TimestepEmbedding<T> {
234    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
235        let h = self.linear_1.forward(input)?;
236        let h = self.activation.forward(&h)?;
237        self.linear_2.forward(&h)
238    }
239
240    fn parameters(&self) -> Vec<&Parameter<T>> {
241        let mut o = Vec::new();
242        o.extend(self.linear_1.parameters());
243        o.extend(self.linear_2.parameters());
244        o
245    }
246    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
247        let mut o = Vec::new();
248        o.extend(self.linear_1.parameters_mut());
249        o.extend(self.linear_2.parameters_mut());
250        o
251    }
252    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
253        let mut o = Vec::new();
254        for (n, p) in self.linear_1.named_parameters() {
255            o.push((format!("linear_1.{n}"), p));
256        }
257        for (n, p) in self.linear_2.named_parameters() {
258            o.push((format!("linear_2.{n}"), p));
259        }
260        o
261    }
262
263    fn train(&mut self) {
264        self.training = true;
265    }
266    fn eval(&mut self) {
267        self.training = false;
268    }
269    fn is_training(&self) -> bool {
270        self.training
271    }
272
273    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
274        let extract = |prefix: &str| -> StateDict<T> {
275            let p = format!("{prefix}.");
276            state
277                .iter()
278                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
279                .collect()
280        };
281        if strict {
282            for k in state.keys() {
283                if !(k.starts_with("linear_1.") || k.starts_with("linear_2.")) {
284                    return Err(FerrotorchError::InvalidArgument {
285                        message: format!(
286                            "unexpected key in TimestepEmbedding state_dict: \"{k}\""
287                        ),
288                    });
289                }
290            }
291        }
292        self.linear_1
293            .load_state_dict(&extract("linear_1"), strict)?;
294        self.linear_2
295            .load_state_dict(&extract("linear_2"), strict)?;
296        Ok(())
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Wrap a flat `f32` buffer in a CPU tensor of the given shape.
    fn cpu_tensor(values: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(values), shape, false).unwrap()
    }

    /// With `flip_sin_to_cos = true`, the row for `t = 0` is all ones in
    /// the cosine half and all zeros in the sine half.
    #[test]
    fn timesteps_shape_flip_true() {
        let encoder = Timesteps::new(8, true, 0.0).unwrap();
        let steps = cpu_tensor(vec![0.0f32, 50.0, 100.0], vec![3]);
        let encoded = encoder.forward_t(&steps).unwrap();
        assert_eq!(encoded.shape(), &[3, 8]);

        // The first eight values are the row for t = 0: cos(0) = 1 in
        // columns 0..4 and sin(0) = 0 in columns 4..8.
        let values = encoded.data().unwrap();
        for (col, v) in values.iter().enumerate().take(8) {
            if col < 4 {
                assert!((*v - 1.0).abs() < 1e-6);
            } else {
                assert!(v.abs() < 1e-6);
            }
        }
    }

    /// An odd channel count is rejected at construction time.
    #[test]
    fn timesteps_rejects_odd_channels() {
        let built = Timesteps::new(7, true, 0.0);
        assert!(built.is_err());
    }

    /// A `[1, 8]` input is mapped to a `[1, 16]` output.
    #[test]
    fn timestep_embedding_shapes() {
        let mlp = TimestepEmbedding::<f32>::new(8, 16).unwrap();
        let input = cpu_tensor(vec![0.5f32; 8], vec![1, 8]);
        let output = mlp.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 16]);
    }

    /// All four diffusers-style parameter names are exposed.
    #[test]
    fn timestep_embedding_named_parameters() {
        let mlp = TimestepEmbedding::<f32>::new(8, 16).unwrap();
        let names: Vec<String> = mlp
            .named_parameters()
            .into_iter()
            .map(|(name, _)| name)
            .collect();
        let expected = [
            "linear_1.weight",
            "linear_1.bias",
            "linear_2.weight",
            "linear_2.bias",
        ];
        for k in expected {
            assert!(
                names.iter().any(|candidate| candidate == k),
                "missing {k} in {names:?}"
            );
        }
    }
}