Skip to main content

ferrotorch_diffusion/
time_embedding.rs

1//! Time-step sinusoidal positional encoding + the MLP that follows it.
2//!
3//! These match `diffusers.models.embeddings.{Timesteps, TimestepEmbedding}`
4//! 1:1 for SD-1.5's settings (`flip_sin_to_cos = true`, `freq_shift = 0`).
5//!
6//! ```text
7//! Timesteps(C):           t -> [B, C]
8//!   half = C / 2
9//!   exponent = -log(max_period) * arange(half) / half
10//!   freqs = exp(exponent)
11//!   args  = t.float() * freqs
12//!   emb   = cat([cos(args), sin(args)], dim=-1)   (flip_sin_to_cos = true)
13//!
14//! TimestepEmbedding(C, time_emb_dim):
15//!   Linear(C, time_emb_dim) -> SiLU -> Linear(time_emb_dim, time_emb_dim)
16//! ```
17//!
18//! Diffusers' `Timesteps` is parameter-free — it's just an arithmetic
19//! recipe. We keep it as a `Module<T>` for ergonomic composition but
20//! `parameters()` is empty.
21
22use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
23use ferrotorch_nn::module::{Module, StateDict};
24use ferrotorch_nn::parameter::Parameter;
25use ferrotorch_nn::{Linear, SiLU};
26
27// ---------------------------------------------------------------------------
28// Timesteps (sinusoidal positional encoding)
29// ---------------------------------------------------------------------------
30
/// Sinusoidal positional encoding of scalar diffusion timesteps.
///
/// Holds no learnable parameters. Reproduces
/// `diffusers.models.embeddings.Timesteps(num_channels, flip_sin_to_cos,
/// downscale_freq_shift)`; the SD-1.5 configuration uses
/// `flip_sin_to_cos=true` and `downscale_freq_shift=0`.
#[derive(Clone, Debug)]
pub struct Timesteps {
    /// Width of the produced embedding; expected to be even
    /// (320 for the SD UNet).
    pub num_channels: usize,
    /// Layout selector: `true` emits `[cos | sin]` (SD style),
    /// `false` emits `[sin | cos]`.
    pub flip_sin_to_cos: bool,
    /// diffusers' `downscale_freq_shift`, subtracted from `num_channels / 2`
    /// in the exponent denominator. Zero for SD-1.5.
    pub downscale_freq_shift: f64,
    /// Longest sinusoid period (diffusers default: 10 000).
    pub max_period: f64,
}
49
50impl Timesteps {
51    /// Build a `Timesteps` module.
52    ///
53    /// # Errors
54    ///
55    /// Returns [`FerrotorchError::InvalidArgument`] when `num_channels`
56    /// is not a positive even integer.
57    pub fn new(
58        num_channels: usize,
59        flip_sin_to_cos: bool,
60        downscale_freq_shift: f64,
61    ) -> FerrotorchResult<Self> {
62        if num_channels == 0 || num_channels % 2 != 0 {
63            return Err(FerrotorchError::InvalidArgument {
64                message: format!(
65                    "Timesteps::new: num_channels must be a positive even integer, got {num_channels}"
66                ),
67            });
68        }
69        Ok(Self {
70            num_channels,
71            flip_sin_to_cos,
72            downscale_freq_shift,
73            max_period: 10_000.0,
74        })
75    }
76
77    /// Compute the sinusoidal encoding for a batch of timesteps.
78    ///
79    /// `timesteps` has shape `[B]` (1-D, dtype `T`). The output has
80    /// shape `[B, num_channels]`.
81    ///
82    /// # Errors
83    ///
84    /// Returns [`FerrotorchError::ShapeMismatch`] when the input is not
85    /// rank-1, [`FerrotorchError::InvalidArgument`] if the half-channel
86    /// math overflows `T`.
87    pub fn forward_t<T: Float>(&self, timesteps: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
88        if timesteps.ndim() != 1 {
89            return Err(FerrotorchError::ShapeMismatch {
90                message: format!(
91                    "Timesteps::forward_t: expected 1-D timesteps [B], got {:?}",
92                    timesteps.shape()
93                ),
94            });
95        }
96        let batch = timesteps.shape()[0];
97        let half = self.num_channels / 2;
98        // exponent = -ln(max_period) * i / (half - downscale_freq_shift),
99        // for i in 0..half. half - downscale_freq_shift defaults to half
100        // (downscale_freq_shift = 0 for SD-1.5).
101        let denom = (half as f64) - self.downscale_freq_shift;
102        if denom <= 0.0 {
103            return Err(FerrotorchError::InvalidArgument {
104                message: format!(
105                    "Timesteps::forward_t: invalid denominator {denom} (half={half}, \
106                     downscale_freq_shift={})",
107                    self.downscale_freq_shift,
108                ),
109            });
110        }
111        let log_max = self.max_period.ln();
112        let mut freqs = Vec::with_capacity(half);
113        for i in 0..half {
114            let exponent = -log_max * (i as f64) / denom;
115            freqs.push(exponent.exp());
116        }
117        // Read the timestep values into f64 for the multiply, then cast
118        // back to T for the sin/cos call.
119        let ts_data = timesteps.data()?;
120        let zero_t = T::from(0.0).ok_or_else(|| FerrotorchError::InvalidArgument {
121            message: "Timesteps::forward_t: failed to cast 0.0 into Float".into(),
122        })?;
123        let mut out = vec![zero_t; batch * self.num_channels];
124        for (b, &t) in ts_data.iter().enumerate() {
125            let t_f64: f64 = t.to_f64().ok_or_else(|| FerrotorchError::InvalidArgument {
126                message: "Timesteps::forward_t: failed to cast timestep into f64".into(),
127            })?;
128            for (i, &freq) in freqs.iter().enumerate() {
129                let arg = t_f64 * freq;
130                let cos_v = arg.cos();
131                let sin_v = arg.sin();
132                let (left, right) = if self.flip_sin_to_cos {
133                    (cos_v, sin_v)
134                } else {
135                    (sin_v, cos_v)
136                };
137                out[b * self.num_channels + i] =
138                    T::from(left).ok_or_else(|| FerrotorchError::InvalidArgument {
139                        message: "Timesteps: cast left value to T failed".into(),
140                    })?;
141                out[b * self.num_channels + half + i] =
142                    T::from(right).ok_or_else(|| FerrotorchError::InvalidArgument {
143                        message: "Timesteps: cast right value to T failed".into(),
144                    })?;
145            }
146        }
147        Tensor::from_storage(
148            TensorStorage::cpu(out),
149            vec![batch, self.num_channels],
150            false,
151        )
152    }
153}
154
155// `Module` impl: forward expects a 1-D timestep tensor. Parameter-free.
156impl<T: Float> Module<T> for Timesteps {
157    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
158        self.forward_t(input)
159    }
160    fn parameters(&self) -> Vec<&Parameter<T>> {
161        Vec::new()
162    }
163    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
164        Vec::new()
165    }
166    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
167        Vec::new()
168    }
169    fn train(&mut self) {
170        // Timesteps is parameter-free and stateless — no training-mode
171        // flag to flip. Mirrors diffusers' `Timesteps` (which has no
172        // submodules / parameters / buffers; PyTorch nn.Module.train()
173        // recurses over children that don't exist here).
174    }
175    fn eval(&mut self) {
176        // Same as `train` — nothing to switch.
177    }
178    fn is_training(&self) -> bool {
179        // Always-inference (deterministic arithmetic, no dropout / norm
180        // running stats).
181        false
182    }
183    fn load_state_dict(&mut self, _state: &StateDict<T>, _strict: bool) -> FerrotorchResult<()> {
184        Ok(())
185    }
186}
187
188// ---------------------------------------------------------------------------
189// TimestepEmbedding (MLP)
190// ---------------------------------------------------------------------------
191
192/// `TimestepEmbedding` — `Linear -> SiLU -> Linear` applied to the
193/// sinusoidal encoding.
194///
195/// State-dict key layout (matches diffusers):
196///
197/// ```text
198/// linear_1.{weight,bias}    [time_emb_dim, in_channels]
199/// linear_2.{weight,bias}    [time_emb_dim, time_emb_dim]
200/// ```
201#[derive(Debug)]
202pub struct TimestepEmbedding<T: Float> {
203    /// First linear `in_channels -> time_emb_dim`.
204    pub linear_1: Linear<T>,
205    /// Output linear `time_emb_dim -> time_emb_dim`.
206    pub linear_2: Linear<T>,
207    activation: SiLU,
208    training: bool,
209}
210
211impl<T: Float> TimestepEmbedding<T> {
212    /// Build a randomly-initialized `TimestepEmbedding`.
213    ///
214    /// # Errors
215    ///
216    /// Returns the underlying [`FerrotorchError`] on bad dims.
217    pub fn new(in_channels: usize, time_emb_dim: usize) -> FerrotorchResult<Self> {
218        let linear_1 = Linear::<T>::new(in_channels, time_emb_dim, true)?;
219        let linear_2 = Linear::<T>::new(time_emb_dim, time_emb_dim, true)?;
220        Ok(Self {
221            linear_1,
222            linear_2,
223            activation: SiLU::new(),
224            training: false,
225        })
226    }
227}
228
229impl<T: Float> Module<T> for TimestepEmbedding<T> {
230    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
231        let h = self.linear_1.forward(input)?;
232        let h = self.activation.forward(&h)?;
233        self.linear_2.forward(&h)
234    }
235
236    fn parameters(&self) -> Vec<&Parameter<T>> {
237        let mut o = Vec::new();
238        o.extend(self.linear_1.parameters());
239        o.extend(self.linear_2.parameters());
240        o
241    }
242    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
243        let mut o = Vec::new();
244        o.extend(self.linear_1.parameters_mut());
245        o.extend(self.linear_2.parameters_mut());
246        o
247    }
248    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
249        let mut o = Vec::new();
250        for (n, p) in self.linear_1.named_parameters() {
251            o.push((format!("linear_1.{n}"), p));
252        }
253        for (n, p) in self.linear_2.named_parameters() {
254            o.push((format!("linear_2.{n}"), p));
255        }
256        o
257    }
258
259    fn train(&mut self) {
260        self.training = true;
261    }
262    fn eval(&mut self) {
263        self.training = false;
264    }
265    fn is_training(&self) -> bool {
266        self.training
267    }
268
269    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
270        let extract = |prefix: &str| -> StateDict<T> {
271            let p = format!("{prefix}.");
272            state
273                .iter()
274                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
275                .collect()
276        };
277        if strict {
278            for k in state.keys() {
279                if !(k.starts_with("linear_1.") || k.starts_with("linear_2.")) {
280                    return Err(FerrotorchError::InvalidArgument {
281                        message: format!("unexpected key in TimestepEmbedding state_dict: \"{k}\""),
282                    });
283                }
284            }
285        }
286        self.linear_1
287            .load_state_dict(&extract("linear_1"), strict)?;
288        self.linear_2
289            .load_state_dict(&extract("linear_2"), strict)?;
290        Ok(())
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a CPU `f32` tensor for tests.
    fn cpu_f32(data: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(data), shape, false).unwrap()
    }

    #[test]
    fn timesteps_shape_flip_true() {
        let t = Timesteps::new(8, true, 0.0).unwrap();
        let e = t.forward_t(&cpu_f32(vec![0.0, 50.0, 100.0], vec![3])).unwrap();
        assert_eq!(e.shape(), &[3, 8]);
        // For t=0, cos=1 and sin=0 for all freqs => first half ones, second half zeros.
        let d = e.data().unwrap();
        assert!(d.iter().take(4).all(|v| (v - 1.0).abs() < 1e-6));
        assert!(d.iter().skip(4).take(4).all(|v| v.abs() < 1e-6));
    }

    #[test]
    fn timesteps_matches_reference_values() {
        // C = 8 => half = 4, freqs = 10000^(-i/4) = [1, 0.1, 0.01, 0.001].
        let t = Timesteps::new(8, true, 0.0).unwrap();
        let e = t.forward_t(&cpu_f32(vec![50.0], vec![1])).unwrap();
        let got: Vec<f32> = e.data().unwrap().iter().copied().collect();
        let args = [50.0f64, 5.0, 0.5, 0.05];
        let expected: Vec<f32> = args
            .iter()
            .map(|a| a.cos() as f32)
            .chain(args.iter().map(|a| a.sin() as f32))
            .collect();
        assert_eq!(got.len(), expected.len());
        for (g, w) in got.iter().zip(&expected) {
            assert!((g - w).abs() < 1e-5, "got {got:?}, expected {expected:?}");
        }
    }

    #[test]
    fn timesteps_flip_false_puts_sin_first() {
        let t = Timesteps::new(8, false, 0.0).unwrap();
        let e = t.forward_t(&cpu_f32(vec![0.0], vec![1])).unwrap();
        let d = e.data().unwrap();
        // t=0: sin half is all zeros and comes first, cos half all ones.
        assert!(d.iter().take(4).all(|v| v.abs() < 1e-6));
        assert!(d.iter().skip(4).all(|v| (v - 1.0).abs() < 1e-6));
    }

    #[test]
    fn timesteps_rejects_odd_channels() {
        assert!(Timesteps::new(7, true, 0.0).is_err());
    }

    #[test]
    fn timesteps_rejects_zero_channels() {
        assert!(Timesteps::new(0, true, 0.0).is_err());
    }

    #[test]
    fn timesteps_rejects_non_1d_input() {
        let t = Timesteps::new(8, true, 0.0).unwrap();
        assert!(t.forward_t(&cpu_f32(vec![0.0; 4], vec![2, 2])).is_err());
    }

    #[test]
    fn timesteps_rejects_non_positive_denominator() {
        // half - downscale_freq_shift = 4 - 4 = 0.
        let t = Timesteps::new(8, true, 4.0).unwrap();
        assert!(t.forward_t(&cpu_f32(vec![1.0], vec![1])).is_err());
    }

    #[test]
    fn timestep_embedding_shapes() {
        let mlp = TimestepEmbedding::<f32>::new(8, 16).unwrap();
        let y = mlp.forward(&cpu_f32(vec![0.5; 8], vec![1, 8])).unwrap();
        assert_eq!(y.shape(), &[1, 16]);
    }

    #[test]
    fn timesteps_feed_timestep_embedding() {
        let enc = Timesteps::new(8, true, 0.0).unwrap();
        let mlp = TimestepEmbedding::<f32>::new(8, 16).unwrap();
        let e = enc.forward_t(&cpu_f32(vec![0.0, 10.0, 999.0], vec![3])).unwrap();
        let y = mlp.forward(&e).unwrap();
        assert_eq!(y.shape(), &[3, 16]);
    }

    #[test]
    fn timestep_embedding_named_parameters() {
        let mlp = TimestepEmbedding::<f32>::new(8, 16).unwrap();
        let names: Vec<String> = mlp.named_parameters().into_iter().map(|(n, _)| n).collect();
        for k in [
            "linear_1.weight",
            "linear_1.bias",
            "linear_2.weight",
            "linear_2.bias",
        ] {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
        assert_eq!(mlp.parameters().len(), names.len());
    }

    #[test]
    fn timestep_embedding_train_eval_toggle() {
        let mut mlp = TimestepEmbedding::<f32>::new(8, 16).unwrap();
        assert!(!mlp.is_training());
        mlp.train();
        assert!(mlp.is_training());
        mlp.eval();
        assert!(!mlp.is_training());
    }
}