Skip to main content

ferrotorch_diffusion/
time_embedding.rs

1//! Time-step sinusoidal positional encoding + the MLP that follows it.
2//!
3//! These match `diffusers.models.embeddings.{Timesteps, TimestepEmbedding}`
4//! 1:1 for SD-1.5's settings (`flip_sin_to_cos = true`, `freq_shift = 0`).
5//!
6//! ```text
7//! Timesteps(C):           t -> [B, C]
8//!   half = C / 2
9//!   exponent = -log(max_period) * arange(half) / half
10//!   freqs = exp(exponent)
11//!   args  = t.float() * freqs
12//!   emb   = cat([cos(args), sin(args)], dim=-1)   (flip_sin_to_cos = true)
13//!
14//! TimestepEmbedding(C, time_emb_dim):
15//!   Linear(C, time_emb_dim) -> SiLU -> Linear(time_emb_dim, time_emb_dim)
16//! ```
17//!
18//! Diffusers' `Timesteps` is parameter-free — it's just an arithmetic
19//! recipe. We keep it as a `Module<T>` for ergonomic composition but
20//! `parameters()` is empty.
21//!
22//! ## REQ status (per `.design/ferrotorch-diffusion/time_embedding.md`)
23//!
24//! | REQ | Status | Evidence |
25//! |---|---|---|
26//! | REQ-1 | SHIPPED | `Timesteps::new` at `ferrotorch-diffusion/src/time_embedding.rs:57..75`; consumer: `ferrotorch-diffusion/src/unet.rs:50` imports `Timesteps` for `UNet2DConditionModel::new` |
27//! | REQ-2 | SHIPPED | `Timesteps::forward_t` at `ferrotorch-diffusion/src/time_embedding.rs:87..152`; consumer: `ferrotorch-diffusion/src/unet.rs` invokes the time-projection forward path |
28//! | REQ-3 | SHIPPED | `TimestepEmbedding::new` at `ferrotorch-diffusion/src/time_embedding.rs:217..226`; consumer: `ferrotorch-diffusion/src/unet.rs:50` builds the MLP inside `UNet2DConditionModel::new` |
29//! | REQ-4 | SHIPPED | `named_parameters` at `ferrotorch-diffusion/src/time_embedding.rs:248..257`; consumer: `ferrotorch-diffusion/src/safetensors_loader.rs:151..175` routes HF UNet keys through the layout |
30//! | REQ-5 | SHIPPED | `Module<T> for Timesteps` empty-parameter impl at `ferrotorch-diffusion/src/time_embedding.rs:156..186`; consumer: `ferrotorch-diffusion/src/unet.rs` enumerates parameters without `Timesteps` contributing |
31
32use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
33use ferrotorch_nn::module::{Module, StateDict};
34use ferrotorch_nn::parameter::Parameter;
35use ferrotorch_nn::{Linear, SiLU};
36
37// ---------------------------------------------------------------------------
38// Timesteps (sinusoidal positional encoding)
39// ---------------------------------------------------------------------------
40
/// Configuration for the sinusoidal positional encoding of scalar
/// diffusion timesteps.
///
/// The struct holds no learnable state — only the knobs of the encoding
/// recipe. It mirrors
/// `diffusers.models.embeddings.Timesteps(num_channels, flip_sin_to_cos,
/// downscale_freq_shift)`; SD-1.5 uses `flip_sin_to_cos=true` and
/// `downscale_freq_shift=0`.
#[derive(Clone, Debug)]
pub struct Timesteps {
    /// Width of the produced embedding. Must be even; the SD UNet uses 320.
    pub num_channels: usize,
    /// Ordering of the two halves of the embedding: `true` gives
    /// `cat([cos, sin])` (SD-style), `false` gives `cat([sin, cos])`.
    pub flip_sin_to_cos: bool,
    /// The diffusers `downscale_freq_shift`, subtracted from the
    /// half-channel count in the exponent denominator. SD-1.5 always
    /// uses 0.
    pub downscale_freq_shift: f64,
    /// Longest period of the sinusoids (the diffusers default is 10000).
    pub max_period: f64,
}
59
60impl Timesteps {
61    /// Build a `Timesteps` module.
62    ///
63    /// # Errors
64    ///
65    /// Returns [`FerrotorchError::InvalidArgument`] when `num_channels`
66    /// is not a positive even integer.
67    pub fn new(
68        num_channels: usize,
69        flip_sin_to_cos: bool,
70        downscale_freq_shift: f64,
71    ) -> FerrotorchResult<Self> {
72        if num_channels == 0 || num_channels % 2 != 0 {
73            return Err(FerrotorchError::InvalidArgument {
74                message: format!(
75                    "Timesteps::new: num_channels must be a positive even integer, got {num_channels}"
76                ),
77            });
78        }
79        Ok(Self {
80            num_channels,
81            flip_sin_to_cos,
82            downscale_freq_shift,
83            max_period: 10_000.0,
84        })
85    }
86
87    /// Compute the sinusoidal encoding for a batch of timesteps.
88    ///
89    /// `timesteps` has shape `[B]` (1-D, dtype `T`). The output has
90    /// shape `[B, num_channels]`.
91    ///
92    /// # Errors
93    ///
94    /// Returns [`FerrotorchError::ShapeMismatch`] when the input is not
95    /// rank-1, [`FerrotorchError::InvalidArgument`] if the half-channel
96    /// math overflows `T`.
97    pub fn forward_t<T: Float>(&self, timesteps: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
98        if timesteps.ndim() != 1 {
99            return Err(FerrotorchError::ShapeMismatch {
100                message: format!(
101                    "Timesteps::forward_t: expected 1-D timesteps [B], got {:?}",
102                    timesteps.shape()
103                ),
104            });
105        }
106        let batch = timesteps.shape()[0];
107        let half = self.num_channels / 2;
108        // exponent = -ln(max_period) * i / (half - downscale_freq_shift),
109        // for i in 0..half. half - downscale_freq_shift defaults to half
110        // (downscale_freq_shift = 0 for SD-1.5).
111        let denom = (half as f64) - self.downscale_freq_shift;
112        if denom <= 0.0 {
113            return Err(FerrotorchError::InvalidArgument {
114                message: format!(
115                    "Timesteps::forward_t: invalid denominator {denom} (half={half}, \
116                     downscale_freq_shift={})",
117                    self.downscale_freq_shift,
118                ),
119            });
120        }
121        let log_max = self.max_period.ln();
122        let mut freqs = Vec::with_capacity(half);
123        for i in 0..half {
124            let exponent = -log_max * (i as f64) / denom;
125            freqs.push(exponent.exp());
126        }
127        // Read the timestep values into f64 for the multiply, then cast
128        // back to T for the sin/cos call.
129        let ts_data = timesteps.data()?;
130        let zero_t = T::from(0.0).ok_or_else(|| FerrotorchError::InvalidArgument {
131            message: "Timesteps::forward_t: failed to cast 0.0 into Float".into(),
132        })?;
133        let mut out = vec![zero_t; batch * self.num_channels];
134        for (b, &t) in ts_data.iter().enumerate() {
135            let t_f64: f64 = t.to_f64().ok_or_else(|| FerrotorchError::InvalidArgument {
136                message: "Timesteps::forward_t: failed to cast timestep into f64".into(),
137            })?;
138            for (i, &freq) in freqs.iter().enumerate() {
139                let arg = t_f64 * freq;
140                let cos_v = arg.cos();
141                let sin_v = arg.sin();
142                let (left, right) = if self.flip_sin_to_cos {
143                    (cos_v, sin_v)
144                } else {
145                    (sin_v, cos_v)
146                };
147                out[b * self.num_channels + i] =
148                    T::from(left).ok_or_else(|| FerrotorchError::InvalidArgument {
149                        message: "Timesteps: cast left value to T failed".into(),
150                    })?;
151                out[b * self.num_channels + half + i] =
152                    T::from(right).ok_or_else(|| FerrotorchError::InvalidArgument {
153                        message: "Timesteps: cast right value to T failed".into(),
154                    })?;
155            }
156        }
157        Tensor::from_storage(
158            TensorStorage::cpu(out),
159            vec![batch, self.num_channels],
160            false,
161        )
162    }
163}
164
165// `Module` impl: forward expects a 1-D timestep tensor. Parameter-free.
166impl<T: Float> Module<T> for Timesteps {
167    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
168        self.forward_t(input)
169    }
170    fn parameters(&self) -> Vec<&Parameter<T>> {
171        Vec::new()
172    }
173    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
174        Vec::new()
175    }
176    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
177        Vec::new()
178    }
179    fn train(&mut self) {
180        // Timesteps is parameter-free and stateless — no training-mode
181        // flag to flip. Mirrors diffusers' `Timesteps` (which has no
182        // submodules / parameters / buffers; PyTorch nn.Module.train()
183        // recurses over children that don't exist here).
184    }
185    fn eval(&mut self) {
186        // Same as `train` — nothing to switch.
187    }
188    fn is_training(&self) -> bool {
189        // Always-inference (deterministic arithmetic, no dropout / norm
190        // running stats).
191        false
192    }
193    fn load_state_dict(&mut self, _state: &StateDict<T>, _strict: bool) -> FerrotorchResult<()> {
194        Ok(())
195    }
196}
197
198// ---------------------------------------------------------------------------
199// TimestepEmbedding (MLP)
200// ---------------------------------------------------------------------------
201
202/// `TimestepEmbedding` — `Linear -> SiLU -> Linear` applied to the
203/// sinusoidal encoding.
204///
205/// State-dict key layout (matches diffusers):
206///
207/// ```text
208/// linear_1.{weight,bias}    [time_emb_dim, in_channels]
209/// linear_2.{weight,bias}    [time_emb_dim, time_emb_dim]
210/// ```
211#[derive(Debug)]
212pub struct TimestepEmbedding<T: Float> {
213    /// First linear `in_channels -> time_emb_dim`.
214    pub linear_1: Linear<T>,
215    /// Output linear `time_emb_dim -> time_emb_dim`.
216    pub linear_2: Linear<T>,
217    activation: SiLU,
218    training: bool,
219}
220
221impl<T: Float> TimestepEmbedding<T> {
222    /// Build a randomly-initialized `TimestepEmbedding`.
223    ///
224    /// # Errors
225    ///
226    /// Returns the underlying [`FerrotorchError`] on bad dims.
227    pub fn new(in_channels: usize, time_emb_dim: usize) -> FerrotorchResult<Self> {
228        let linear_1 = Linear::<T>::new(in_channels, time_emb_dim, true)?;
229        let linear_2 = Linear::<T>::new(time_emb_dim, time_emb_dim, true)?;
230        Ok(Self {
231            linear_1,
232            linear_2,
233            activation: SiLU::new(),
234            training: false,
235        })
236    }
237}
238
239impl<T: Float> Module<T> for TimestepEmbedding<T> {
240    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
241        let h = self.linear_1.forward(input)?;
242        let h = self.activation.forward(&h)?;
243        self.linear_2.forward(&h)
244    }
245
246    fn parameters(&self) -> Vec<&Parameter<T>> {
247        let mut o = Vec::new();
248        o.extend(self.linear_1.parameters());
249        o.extend(self.linear_2.parameters());
250        o
251    }
252    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
253        let mut o = Vec::new();
254        o.extend(self.linear_1.parameters_mut());
255        o.extend(self.linear_2.parameters_mut());
256        o
257    }
258    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
259        let mut o = Vec::new();
260        for (n, p) in self.linear_1.named_parameters() {
261            o.push((format!("linear_1.{n}"), p));
262        }
263        for (n, p) in self.linear_2.named_parameters() {
264            o.push((format!("linear_2.{n}"), p));
265        }
266        o
267    }
268
269    fn train(&mut self) {
270        self.training = true;
271    }
272    fn eval(&mut self) {
273        self.training = false;
274    }
275    fn is_training(&self) -> bool {
276        self.training
277    }
278
279    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
280        let extract = |prefix: &str| -> StateDict<T> {
281            let p = format!("{prefix}.");
282            state
283                .iter()
284                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
285                .collect()
286        };
287        if strict {
288            for k in state.keys() {
289                if !(k.starts_with("linear_1.") || k.starts_with("linear_2.")) {
290                    return Err(FerrotorchError::InvalidArgument {
291                        message: format!("unexpected key in TimestepEmbedding state_dict: \"{k}\""),
292                    });
293                }
294            }
295        }
296        self.linear_1
297            .load_state_dict(&extract("linear_1"), strict)?;
298        self.linear_2
299            .load_state_dict(&extract("linear_2"), strict)?;
300        Ok(())
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Shape check plus the t = 0 row: cos(0) = 1 fills the first half and
    /// sin(0) = 0 fills the second half when `flip_sin_to_cos` is set.
    #[test]
    fn timesteps_shape_flip_true() {
        let encoder = Timesteps::new(8, true, 0.0).unwrap();
        let storage = TensorStorage::cpu(vec![0.0f32, 50.0, 100.0]);
        let steps = Tensor::from_storage(storage, vec![3], false).unwrap();

        let encoded = encoder.forward_t(&steps).unwrap();
        assert_eq!(encoded.shape(), &[3, 8]);

        let values = encoded.data().unwrap();
        assert!((0..4).all(|i| (values[i] - 1.0).abs() < 1e-6));
        assert!((4..8).all(|i| values[i].abs() < 1e-6));
    }

    /// An odd channel count cannot be split into equal cos/sin halves.
    #[test]
    fn timesteps_rejects_odd_channels() {
        let outcome = Timesteps::new(7, true, 0.0);
        assert!(outcome.is_err());
    }

    /// An `[1, 8]` input comes out as `[1, 16]`.
    #[test]
    fn timestep_embedding_shapes() {
        let mlp = TimestepEmbedding::<f32>::new(8, 16).unwrap();
        let storage = TensorStorage::cpu(vec![0.5f32; 8]);
        let input = Tensor::from_storage(storage, vec![1, 8], false).unwrap();

        let output = mlp.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 16]);
    }

    /// All four diffusers-style parameter keys are exposed.
    #[test]
    fn timestep_embedding_named_parameters() {
        let mlp = TimestepEmbedding::<f32>::new(8, 16).unwrap();
        let names: Vec<String> = mlp
            .named_parameters()
            .into_iter()
            .map(|(name, _)| name)
            .collect();

        let expected = [
            "linear_1.weight",
            "linear_1.bias",
            "linear_2.weight",
            "linear_2.bias",
        ];
        for k in expected {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }
}