1use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
23use ferrotorch_nn::module::{Module, StateDict};
24use ferrotorch_nn::parameter::Parameter;
25use ferrotorch_nn::{Linear, SiLU};
26
/// Settings for a parameter-free sin/cos projection of scalar timesteps.
///
/// The fields are consumed by [`Timesteps::forward_t`], which turns a 1-D
/// tensor of timesteps into a `[batch, num_channels]` tensor of sine and
/// cosine features.
#[derive(Debug, Clone)]
pub struct Timesteps {
    /// Width of each output row. [`Timesteps::new`] only accepts a positive,
    /// even value; half of the columns hold sines and half hold cosines.
    pub num_channels: usize,
    /// When `true` the cosine values occupy the first half of each row and the
    /// sine values the second half; when `false` the order is sine, cosine.
    pub flip_sin_to_cos: bool,
    /// Value subtracted from `num_channels / 2` to form the denominator of the
    /// frequency exponent. `forward_t` rejects a non-positive denominator.
    pub downscale_freq_shift: f64,
    /// Base of the frequency schedule: the `i`-th frequency is
    /// `exp(-ln(max_period) * i / denominator)`. `new` sets this to `10_000.0`.
    pub max_period: f64,
}
49
50impl Timesteps {
51 pub fn new(
58 num_channels: usize,
59 flip_sin_to_cos: bool,
60 downscale_freq_shift: f64,
61 ) -> FerrotorchResult<Self> {
62 if num_channels == 0 || num_channels % 2 != 0 {
63 return Err(FerrotorchError::InvalidArgument {
64 message: format!(
65 "Timesteps::new: num_channels must be a positive even integer, got {num_channels}"
66 ),
67 });
68 }
69 Ok(Self {
70 num_channels,
71 flip_sin_to_cos,
72 downscale_freq_shift,
73 max_period: 10_000.0,
74 })
75 }
76
77 pub fn forward_t<T: Float>(&self, timesteps: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
88 if timesteps.ndim() != 1 {
89 return Err(FerrotorchError::ShapeMismatch {
90 message: format!(
91 "Timesteps::forward_t: expected 1-D timesteps [B], got {:?}",
92 timesteps.shape()
93 ),
94 });
95 }
96 let batch = timesteps.shape()[0];
97 let half = self.num_channels / 2;
98 let denom = (half as f64) - self.downscale_freq_shift;
102 if denom <= 0.0 {
103 return Err(FerrotorchError::InvalidArgument {
104 message: format!(
105 "Timesteps::forward_t: invalid denominator {denom} (half={half}, \
106 downscale_freq_shift={})",
107 self.downscale_freq_shift,
108 ),
109 });
110 }
111 let log_max = self.max_period.ln();
112 let mut freqs = Vec::with_capacity(half);
113 for i in 0..half {
114 let exponent = -log_max * (i as f64) / denom;
115 freqs.push(exponent.exp());
116 }
117 let ts_data = timesteps.data()?;
120 let zero_t = T::from(0.0).ok_or_else(|| FerrotorchError::InvalidArgument {
121 message: "Timesteps::forward_t: failed to cast 0.0 into Float".into(),
122 })?;
123 let mut out = vec![zero_t; batch * self.num_channels];
124 for (b, &t) in ts_data.iter().enumerate() {
125 let t_f64: f64 = t
126 .to_f64()
127 .ok_or_else(|| FerrotorchError::InvalidArgument {
128 message: "Timesteps::forward_t: failed to cast timestep into f64".into(),
129 })?;
130 for (i, &freq) in freqs.iter().enumerate() {
131 let arg = t_f64 * freq;
132 let cos_v = arg.cos();
133 let sin_v = arg.sin();
134 let (left, right) = if self.flip_sin_to_cos {
135 (cos_v, sin_v)
136 } else {
137 (sin_v, cos_v)
138 };
139 out[b * self.num_channels + i] = T::from(left).ok_or_else(|| {
140 FerrotorchError::InvalidArgument {
141 message: "Timesteps: cast left value to T failed".into(),
142 }
143 })?;
144 out[b * self.num_channels + half + i] = T::from(right).ok_or_else(|| {
145 FerrotorchError::InvalidArgument {
146 message: "Timesteps: cast right value to T failed".into(),
147 }
148 })?;
149 }
150 }
151 Tensor::from_storage(
152 TensorStorage::cpu(out),
153 vec![batch, self.num_channels],
154 false,
155 )
156 }
157}
158
159impl<T: Float> Module<T> for Timesteps {
161 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
162 self.forward_t(input)
163 }
164 fn parameters(&self) -> Vec<&Parameter<T>> {
165 Vec::new()
166 }
167 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
168 Vec::new()
169 }
170 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
171 Vec::new()
172 }
173 fn train(&mut self) {
174 }
179 fn eval(&mut self) {
180 }
182 fn is_training(&self) -> bool {
183 false
186 }
187 fn load_state_dict(&mut self, _state: &StateDict<T>, _strict: bool) -> FerrotorchResult<()> {
188 Ok(())
189 }
190}
191
/// Two-layer MLP applied to a timestep feature tensor.
///
/// `forward` runs `linear_1`, then the `SiLU` activation, then `linear_2`.
#[derive(Debug)]
pub struct TimestepEmbedding<T: Float> {
    /// First layer; `new` builds it as `Linear::new(in_channels, time_emb_dim, true)`.
    pub linear_1: Linear<T>,
    /// Second layer; `new` builds it as `Linear::new(time_emb_dim, time_emb_dim, true)`.
    pub linear_2: Linear<T>,
    /// Activation applied between the two layers.
    activation: SiLU,
    /// Flag set by `train`, cleared by `eval`, and reported by `is_training`.
    /// `forward` does not read it.
    training: bool,
}
214
215impl<T: Float> TimestepEmbedding<T> {
216 pub fn new(in_channels: usize, time_emb_dim: usize) -> FerrotorchResult<Self> {
222 let linear_1 = Linear::<T>::new(in_channels, time_emb_dim, true)?;
223 let linear_2 = Linear::<T>::new(time_emb_dim, time_emb_dim, true)?;
224 Ok(Self {
225 linear_1,
226 linear_2,
227 activation: SiLU::new(),
228 training: false,
229 })
230 }
231}
232
233impl<T: Float> Module<T> for TimestepEmbedding<T> {
234 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
235 let h = self.linear_1.forward(input)?;
236 let h = self.activation.forward(&h)?;
237 self.linear_2.forward(&h)
238 }
239
240 fn parameters(&self) -> Vec<&Parameter<T>> {
241 let mut o = Vec::new();
242 o.extend(self.linear_1.parameters());
243 o.extend(self.linear_2.parameters());
244 o
245 }
246 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
247 let mut o = Vec::new();
248 o.extend(self.linear_1.parameters_mut());
249 o.extend(self.linear_2.parameters_mut());
250 o
251 }
252 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
253 let mut o = Vec::new();
254 for (n, p) in self.linear_1.named_parameters() {
255 o.push((format!("linear_1.{n}"), p));
256 }
257 for (n, p) in self.linear_2.named_parameters() {
258 o.push((format!("linear_2.{n}"), p));
259 }
260 o
261 }
262
263 fn train(&mut self) {
264 self.training = true;
265 }
266 fn eval(&mut self) {
267 self.training = false;
268 }
269 fn is_training(&self) -> bool {
270 self.training
271 }
272
273 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
274 let extract = |prefix: &str| -> StateDict<T> {
275 let p = format!("{prefix}.");
276 state
277 .iter()
278 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
279 .collect()
280 };
281 if strict {
282 for k in state.keys() {
283 if !(k.starts_with("linear_1.") || k.starts_with("linear_2.")) {
284 return Err(FerrotorchError::InvalidArgument {
285 message: format!(
286 "unexpected key in TimestepEmbedding state_dict: \"{k}\""
287 ),
288 });
289 }
290 }
291 }
292 self.linear_1
293 .load_state_dict(&extract("linear_1"), strict)?;
294 self.linear_2
295 .load_state_dict(&extract("linear_2"), strict)?;
296 Ok(())
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `values` in a CPU storage and builds an `f32` tensor of the given
    /// shape, passing `false` as the last `from_storage` argument.
    fn cpu_tensor(values: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(values), shape, false).unwrap()
    }

    /// With `flip_sin_to_cos = true`, the row for `t = 0` must be
    /// `cos(0) = 1` in its first half and `sin(0) = 0` in its second half.
    #[test]
    fn timesteps_shape_flip_true() {
        let proj = Timesteps::new(8, true, 0.0).unwrap();
        let steps = cpu_tensor(vec![0.0, 50.0, 100.0], vec![3]);
        let emb = proj.forward_t(&steps).unwrap();
        assert_eq!(emb.shape(), &[3, 8]);

        let values = emb.data().unwrap();
        assert!(values[..4].iter().all(|&v| (v - 1.0).abs() < 1e-6));
        assert!(values[4..8].iter().all(|&v| v.abs() < 1e-6));
    }

    /// An odd channel count cannot be split into two equal halves.
    #[test]
    fn timesteps_rejects_odd_channels() {
        let built = Timesteps::new(7, true, 0.0);
        assert!(built.is_err());
    }

    /// A `[1, 8]` input through an `8 -> 16` MLP yields a `[1, 16]` output.
    #[test]
    fn timestep_embedding_shapes() {
        let mlp = TimestepEmbedding::<f32>::new(8, 16).unwrap();
        let out = mlp.forward(&cpu_tensor(vec![0.5; 8], vec![1, 8])).unwrap();
        assert_eq!(out.shape(), &[1, 16]);
    }

    /// Both layers expose their weight and bias under prefixed names.
    #[test]
    fn timestep_embedding_named_parameters() {
        let mlp = TimestepEmbedding::<f32>::new(8, 16).unwrap();
        let names: Vec<String> = mlp
            .named_parameters()
            .into_iter()
            .map(|(name, _)| name)
            .collect();
        let expected = [
            "linear_1.weight",
            "linear_1.bias",
            "linear_2.weight",
            "linear_2.bias",
        ];
        for k in expected {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }
}