1use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
33use ferrotorch_nn::module::{Module, StateDict};
34use ferrotorch_nn::parameter::Parameter;
35use ferrotorch_nn::{Linear, SiLU};
36
/// Configuration for a parameter-free sinusoidal timestep encoder.
///
/// The struct holds no learnable state; it only carries the settings that
/// `Timesteps::forward_t` reads when building sine/cosine features.
#[derive(Clone, Debug)]
pub struct Timesteps {
    /// Width of the embedding produced for each timestep. `Timesteps::new`
    /// rejects zero and odd values.
    pub num_channels: usize,
    /// When `true`, each output row stores the cosine half first and the sine
    /// half second; when `false`, the order is sine then cosine.
    pub flip_sin_to_cos: bool,
    /// Subtracted from `num_channels / 2` to form the denominator of the
    /// frequency exponent.
    pub downscale_freq_shift: f64,
    /// Period whose natural logarithm scales the frequency exponents.
    /// `Timesteps::new` sets it to `10_000.0`.
    pub max_period: f64,
}
59
60impl Timesteps {
61 pub fn new(
68 num_channels: usize,
69 flip_sin_to_cos: bool,
70 downscale_freq_shift: f64,
71 ) -> FerrotorchResult<Self> {
72 if num_channels == 0 || num_channels % 2 != 0 {
73 return Err(FerrotorchError::InvalidArgument {
74 message: format!(
75 "Timesteps::new: num_channels must be a positive even integer, got {num_channels}"
76 ),
77 });
78 }
79 Ok(Self {
80 num_channels,
81 flip_sin_to_cos,
82 downscale_freq_shift,
83 max_period: 10_000.0,
84 })
85 }
86
87 pub fn forward_t<T: Float>(&self, timesteps: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
98 if timesteps.ndim() != 1 {
99 return Err(FerrotorchError::ShapeMismatch {
100 message: format!(
101 "Timesteps::forward_t: expected 1-D timesteps [B], got {:?}",
102 timesteps.shape()
103 ),
104 });
105 }
106 let batch = timesteps.shape()[0];
107 let half = self.num_channels / 2;
108 let denom = (half as f64) - self.downscale_freq_shift;
112 if denom <= 0.0 {
113 return Err(FerrotorchError::InvalidArgument {
114 message: format!(
115 "Timesteps::forward_t: invalid denominator {denom} (half={half}, \
116 downscale_freq_shift={})",
117 self.downscale_freq_shift,
118 ),
119 });
120 }
121 let log_max = self.max_period.ln();
122 let mut freqs = Vec::with_capacity(half);
123 for i in 0..half {
124 let exponent = -log_max * (i as f64) / denom;
125 freqs.push(exponent.exp());
126 }
127 let ts_data = timesteps.data()?;
130 let zero_t = T::from(0.0).ok_or_else(|| FerrotorchError::InvalidArgument {
131 message: "Timesteps::forward_t: failed to cast 0.0 into Float".into(),
132 })?;
133 let mut out = vec![zero_t; batch * self.num_channels];
134 for (b, &t) in ts_data.iter().enumerate() {
135 let t_f64: f64 = t.to_f64().ok_or_else(|| FerrotorchError::InvalidArgument {
136 message: "Timesteps::forward_t: failed to cast timestep into f64".into(),
137 })?;
138 for (i, &freq) in freqs.iter().enumerate() {
139 let arg = t_f64 * freq;
140 let cos_v = arg.cos();
141 let sin_v = arg.sin();
142 let (left, right) = if self.flip_sin_to_cos {
143 (cos_v, sin_v)
144 } else {
145 (sin_v, cos_v)
146 };
147 out[b * self.num_channels + i] =
148 T::from(left).ok_or_else(|| FerrotorchError::InvalidArgument {
149 message: "Timesteps: cast left value to T failed".into(),
150 })?;
151 out[b * self.num_channels + half + i] =
152 T::from(right).ok_or_else(|| FerrotorchError::InvalidArgument {
153 message: "Timesteps: cast right value to T failed".into(),
154 })?;
155 }
156 }
157 Tensor::from_storage(
158 TensorStorage::cpu(out),
159 vec![batch, self.num_channels],
160 false,
161 )
162 }
163}
164
165impl<T: Float> Module<T> for Timesteps {
167 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
168 self.forward_t(input)
169 }
170 fn parameters(&self) -> Vec<&Parameter<T>> {
171 Vec::new()
172 }
173 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
174 Vec::new()
175 }
176 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
177 Vec::new()
178 }
179 fn train(&mut self) {
180 }
185 fn eval(&mut self) {
186 }
188 fn is_training(&self) -> bool {
189 false
192 }
193 fn load_state_dict(&mut self, _state: &StateDict<T>, _strict: bool) -> FerrotorchResult<()> {
194 Ok(())
195 }
196}
197
198#[derive(Debug)]
212pub struct TimestepEmbedding<T: Float> {
213 pub linear_1: Linear<T>,
215 pub linear_2: Linear<T>,
217 activation: SiLU,
218 training: bool,
219}
220
221impl<T: Float> TimestepEmbedding<T> {
222 pub fn new(in_channels: usize, time_emb_dim: usize) -> FerrotorchResult<Self> {
228 let linear_1 = Linear::<T>::new(in_channels, time_emb_dim, true)?;
229 let linear_2 = Linear::<T>::new(time_emb_dim, time_emb_dim, true)?;
230 Ok(Self {
231 linear_1,
232 linear_2,
233 activation: SiLU::new(),
234 training: false,
235 })
236 }
237}
238
239impl<T: Float> Module<T> for TimestepEmbedding<T> {
240 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
241 let h = self.linear_1.forward(input)?;
242 let h = self.activation.forward(&h)?;
243 self.linear_2.forward(&h)
244 }
245
246 fn parameters(&self) -> Vec<&Parameter<T>> {
247 let mut o = Vec::new();
248 o.extend(self.linear_1.parameters());
249 o.extend(self.linear_2.parameters());
250 o
251 }
252 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
253 let mut o = Vec::new();
254 o.extend(self.linear_1.parameters_mut());
255 o.extend(self.linear_2.parameters_mut());
256 o
257 }
258 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
259 let mut o = Vec::new();
260 for (n, p) in self.linear_1.named_parameters() {
261 o.push((format!("linear_1.{n}"), p));
262 }
263 for (n, p) in self.linear_2.named_parameters() {
264 o.push((format!("linear_2.{n}"), p));
265 }
266 o
267 }
268
269 fn train(&mut self) {
270 self.training = true;
271 }
272 fn eval(&mut self) {
273 self.training = false;
274 }
275 fn is_training(&self) -> bool {
276 self.training
277 }
278
279 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
280 let extract = |prefix: &str| -> StateDict<T> {
281 let p = format!("{prefix}.");
282 state
283 .iter()
284 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
285 .collect()
286 };
287 if strict {
288 for k in state.keys() {
289 if !(k.starts_with("linear_1.") || k.starts_with("linear_2.")) {
290 return Err(FerrotorchError::InvalidArgument {
291 message: format!("unexpected key in TimestepEmbedding state_dict: \"{k}\""),
292 });
293 }
294 }
295 }
296 self.linear_1
297 .load_state_dict(&extract("linear_1"), strict)?;
298 self.linear_2
299 .load_state_dict(&extract("linear_2"), strict)?;
300 Ok(())
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a CPU `f32` tensor with the given shape, passing `false` as the
    /// last `from_storage` argument like the rest of this file.
    fn cpu_tensor(values: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(values), shape, false).unwrap()
    }

    /// With `flip_sin_to_cos = true`, the row for timestep 0 must be
    /// `cos(0) = 1` in the first half and `sin(0) = 0` in the second half.
    #[test]
    fn timesteps_shape_flip_true() {
        let encoder = Timesteps::new(8, true, 0.0).unwrap();
        let ts = cpu_tensor(vec![0.0f32, 50.0, 100.0], vec![3]);

        let encoded = encoder.forward_t(&ts).unwrap();
        assert_eq!(encoded.shape(), &[3, 8]);

        let values = encoded.data().unwrap();
        for v in values.iter().take(4) {
            assert!((*v - 1.0).abs() < 1e-6);
        }
        for v in values.iter().skip(4).take(4) {
            assert!(v.abs() < 1e-6);
        }
    }

    /// An odd channel count cannot be split into equal sine/cosine halves.
    #[test]
    fn timesteps_rejects_odd_channels() {
        let built = Timesteps::new(7, true, 0.0);
        assert!(built.is_err());
    }

    /// The MLP maps `[1, 8]` to `[1, 16]`.
    #[test]
    fn timestep_embedding_shapes() {
        let mlp = TimestepEmbedding::<f32>::new(8, 16).unwrap();
        let x = cpu_tensor(vec![0.5f32; 8], vec![1, 8]);
        let y = mlp.forward(&x).unwrap();
        assert_eq!(y.shape(), &[1, 16]);
    }

    /// Both layers expose `weight` and `bias` under their field-name prefix.
    #[test]
    fn timestep_embedding_named_parameters() {
        let mlp = TimestepEmbedding::<f32>::new(8, 16).unwrap();
        let names: Vec<String> = mlp
            .named_parameters()
            .into_iter()
            .map(|(name, _)| name)
            .collect();
        let expected = [
            "linear_1.weight",
            "linear_1.bias",
            "linear_2.weight",
            "linear_2.bias",
        ];
        for k in expected {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }
}