1use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
23use ferrotorch_nn::module::{Module, StateDict};
24use ferrotorch_nn::parameter::Parameter;
25use ferrotorch_nn::{Linear, SiLU};
26
/// Parameter-free projection of scalar timesteps onto sinusoidal features.
///
/// [`Timesteps::forward_t`] maps a `[B]` tensor to `[B, num_channels]`. One half of each row
/// holds `sin(t * freq_i)` and the other half `cos(t * freq_i)`, where
/// `freq_i = exp(-ln(max_period) * i / (num_channels / 2 - downscale_freq_shift))`.
#[derive(Debug, Clone, PartialEq)]
pub struct Timesteps {
    /// Width of each output row; [`Timesteps::new`] requires it to be positive and even.
    pub num_channels: usize,
    /// When `true` the cosine half precedes the sine half; otherwise the sine half comes first.
    pub flip_sin_to_cos: bool,
    /// Subtracted from `num_channels / 2` to form the frequency-exponent denominator.
    pub downscale_freq_shift: f64,
    /// Base whose natural log scales the frequency exponents; [`Timesteps::new`] uses `10_000.0`.
    pub max_period: f64,
}
49
50impl Timesteps {
51 pub fn new(
58 num_channels: usize,
59 flip_sin_to_cos: bool,
60 downscale_freq_shift: f64,
61 ) -> FerrotorchResult<Self> {
62 if num_channels == 0 || num_channels % 2 != 0 {
63 return Err(FerrotorchError::InvalidArgument {
64 message: format!(
65 "Timesteps::new: num_channels must be a positive even integer, got {num_channels}"
66 ),
67 });
68 }
69 Ok(Self {
70 num_channels,
71 flip_sin_to_cos,
72 downscale_freq_shift,
73 max_period: 10_000.0,
74 })
75 }
76
77 pub fn forward_t<T: Float>(&self, timesteps: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
88 if timesteps.ndim() != 1 {
89 return Err(FerrotorchError::ShapeMismatch {
90 message: format!(
91 "Timesteps::forward_t: expected 1-D timesteps [B], got {:?}",
92 timesteps.shape()
93 ),
94 });
95 }
96 let batch = timesteps.shape()[0];
97 let half = self.num_channels / 2;
98 let denom = (half as f64) - self.downscale_freq_shift;
102 if denom <= 0.0 {
103 return Err(FerrotorchError::InvalidArgument {
104 message: format!(
105 "Timesteps::forward_t: invalid denominator {denom} (half={half}, \
106 downscale_freq_shift={})",
107 self.downscale_freq_shift,
108 ),
109 });
110 }
111 let log_max = self.max_period.ln();
112 let mut freqs = Vec::with_capacity(half);
113 for i in 0..half {
114 let exponent = -log_max * (i as f64) / denom;
115 freqs.push(exponent.exp());
116 }
117 let ts_data = timesteps.data()?;
120 let zero_t = T::from(0.0).ok_or_else(|| FerrotorchError::InvalidArgument {
121 message: "Timesteps::forward_t: failed to cast 0.0 into Float".into(),
122 })?;
123 let mut out = vec![zero_t; batch * self.num_channels];
124 for (b, &t) in ts_data.iter().enumerate() {
125 let t_f64: f64 = t.to_f64().ok_or_else(|| FerrotorchError::InvalidArgument {
126 message: "Timesteps::forward_t: failed to cast timestep into f64".into(),
127 })?;
128 for (i, &freq) in freqs.iter().enumerate() {
129 let arg = t_f64 * freq;
130 let cos_v = arg.cos();
131 let sin_v = arg.sin();
132 let (left, right) = if self.flip_sin_to_cos {
133 (cos_v, sin_v)
134 } else {
135 (sin_v, cos_v)
136 };
137 out[b * self.num_channels + i] =
138 T::from(left).ok_or_else(|| FerrotorchError::InvalidArgument {
139 message: "Timesteps: cast left value to T failed".into(),
140 })?;
141 out[b * self.num_channels + half + i] =
142 T::from(right).ok_or_else(|| FerrotorchError::InvalidArgument {
143 message: "Timesteps: cast right value to T failed".into(),
144 })?;
145 }
146 }
147 Tensor::from_storage(
148 TensorStorage::cpu(out),
149 vec![batch, self.num_channels],
150 false,
151 )
152 }
153}
154
155impl<T: Float> Module<T> for Timesteps {
157 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
158 self.forward_t(input)
159 }
160 fn parameters(&self) -> Vec<&Parameter<T>> {
161 Vec::new()
162 }
163 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
164 Vec::new()
165 }
166 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
167 Vec::new()
168 }
169 fn train(&mut self) {
170 }
175 fn eval(&mut self) {
176 }
178 fn is_training(&self) -> bool {
179 false
182 }
183 fn load_state_dict(&mut self, _state: &StateDict<T>, _strict: bool) -> FerrotorchResult<()> {
184 Ok(())
185 }
186}
187
188#[derive(Debug)]
202pub struct TimestepEmbedding<T: Float> {
203 pub linear_1: Linear<T>,
205 pub linear_2: Linear<T>,
207 activation: SiLU,
208 training: bool,
209}
210
211impl<T: Float> TimestepEmbedding<T> {
212 pub fn new(in_channels: usize, time_emb_dim: usize) -> FerrotorchResult<Self> {
218 let linear_1 = Linear::<T>::new(in_channels, time_emb_dim, true)?;
219 let linear_2 = Linear::<T>::new(time_emb_dim, time_emb_dim, true)?;
220 Ok(Self {
221 linear_1,
222 linear_2,
223 activation: SiLU::new(),
224 training: false,
225 })
226 }
227}
228
229impl<T: Float> Module<T> for TimestepEmbedding<T> {
230 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
231 let h = self.linear_1.forward(input)?;
232 let h = self.activation.forward(&h)?;
233 self.linear_2.forward(&h)
234 }
235
236 fn parameters(&self) -> Vec<&Parameter<T>> {
237 let mut o = Vec::new();
238 o.extend(self.linear_1.parameters());
239 o.extend(self.linear_2.parameters());
240 o
241 }
242 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
243 let mut o = Vec::new();
244 o.extend(self.linear_1.parameters_mut());
245 o.extend(self.linear_2.parameters_mut());
246 o
247 }
248 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
249 let mut o = Vec::new();
250 for (n, p) in self.linear_1.named_parameters() {
251 o.push((format!("linear_1.{n}"), p));
252 }
253 for (n, p) in self.linear_2.named_parameters() {
254 o.push((format!("linear_2.{n}"), p));
255 }
256 o
257 }
258
259 fn train(&mut self) {
260 self.training = true;
261 }
262 fn eval(&mut self) {
263 self.training = false;
264 }
265 fn is_training(&self) -> bool {
266 self.training
267 }
268
269 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
270 let extract = |prefix: &str| -> StateDict<T> {
271 let p = format!("{prefix}.");
272 state
273 .iter()
274 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
275 .collect()
276 };
277 if strict {
278 for k in state.keys() {
279 if !(k.starts_with("linear_1.") || k.starts_with("linear_2.")) {
280 return Err(FerrotorchError::InvalidArgument {
281 message: format!("unexpected key in TimestepEmbedding state_dict: \"{k}\""),
282 });
283 }
284 }
285 }
286 self.linear_1
287 .load_state_dict(&extract("linear_1"), strict)?;
288 self.linear_2
289 .load_state_dict(&extract("linear_2"), strict)?;
290 Ok(())
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 1-D `f32` timestep tensor of shape `[values.len()]`.
    fn timesteps_1d(values: Vec<f32>) -> Tensor<f32> {
        let n = values.len();
        Tensor::from_storage(TensorStorage::cpu(values), vec![n], false).unwrap()
    }

    #[test]
    fn timesteps_shape_flip_true() {
        let t = Timesteps::new(8, true, 0.0).unwrap();
        let e = t.forward_t(&timesteps_1d(vec![0.0, 50.0, 100.0])).unwrap();
        assert_eq!(e.shape(), &[3, 8]);
        let d = e.data().unwrap();
        // Row 0 is t = 0: cos(0) = 1 fills the leading (cosine) half, sin(0) = 0 the rest.
        for (i, v) in d.iter().enumerate().take(8) {
            let expected = if i < 4 { 1.0 } else { 0.0 };
            assert!((v - expected).abs() < 1e-6, "d[{i}] = {v}, expected {expected}");
        }
    }

    #[test]
    fn timesteps_flip_false_puts_sin_first() {
        let t = Timesteps::new(4, false, 0.0).unwrap();
        let e = t.forward_t(&timesteps_1d(vec![0.0])).unwrap();
        assert_eq!(e.shape(), &[1, 4]);
        let d = e.data().unwrap();
        for (i, v) in d.iter().enumerate() {
            let expected = if i < 2 { 0.0 } else { 1.0 };
            assert!((v - expected).abs() < 1e-6, "d[{i}] = {v}, expected {expected}");
        }
    }

    #[test]
    fn timesteps_first_frequency_is_one() {
        // freq_0 = exp(0) = 1, so column 0 of each half is cos(t) / sin(t) of the raw timestep.
        let t = Timesteps::new(4, true, 0.0).unwrap();
        let e = t.forward_t(&timesteps_1d(vec![50.0])).unwrap();
        let d = e.data().unwrap();
        assert!((d[0] - 50.0f64.cos() as f32).abs() < 1e-6);
        assert!((d[2] - 50.0f64.sin() as f32).abs() < 1e-6);
    }

    #[test]
    fn timesteps_rejects_odd_channels() {
        assert!(Timesteps::new(7, true, 0.0).is_err());
    }

    #[test]
    fn timesteps_rejects_zero_channels() {
        assert!(Timesteps::new(0, true, 0.0).is_err());
    }

    #[test]
    fn timesteps_rejects_non_1d_input() {
        let t = Timesteps::new(4, true, 0.0).unwrap();
        let x =
            Tensor::from_storage(TensorStorage::cpu(vec![0.0f32; 3]), vec![1, 3], false).unwrap();
        assert!(t.forward_t(&x).is_err());
    }

    #[test]
    fn timesteps_rejects_nonpositive_denominator() {
        // half = 1 and downscale_freq_shift = 1 give a zero denominator.
        let t = Timesteps::new(2, true, 1.0).unwrap();
        assert!(t.forward_t(&timesteps_1d(vec![1.0])).is_err());
    }

    #[test]
    fn timestep_embedding_shapes() {
        let mlp = TimestepEmbedding::<f32>::new(8, 16).unwrap();
        let x =
            Tensor::from_storage(TensorStorage::cpu(vec![0.5f32; 8]), vec![1, 8], false).unwrap();
        let y = mlp.forward(&x).unwrap();
        assert_eq!(y.shape(), &[1, 16]);
    }

    #[test]
    fn timestep_embedding_named_parameters() {
        let mlp = TimestepEmbedding::<f32>::new(8, 16).unwrap();
        let names: Vec<String> = mlp.named_parameters().into_iter().map(|(n, _)| n).collect();
        for k in [
            "linear_1.weight",
            "linear_1.bias",
            "linear_2.weight",
            "linear_2.bias",
        ] {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }
}