use burn::prelude::*;
use burn::module::{Param, ParamId};
use burn::nn::Linear;
use crate::model::linear_zeros;
/// Fourier-feature conditioner: maps a scalar-per-position input to
/// cosine/sine features and passes them through a linear projection.
///
/// See `forward` for the exact computation.
#[derive(Module, Debug)]
pub struct FourierConditioner<B: Backend> {
/// Frequency parameter; created by `new` with shape `[half_dim, 1]`.
pub weight: Param<Tensor<B, 2>>,
/// Linear layer applied to the concatenated cos/sin features.
/// Built via `crate::model::linear_zeros(output_dim, output_dim, true, device)`;
/// NOTE(review): presumably zero-initialised with a bias — confirm against that helper.
pub proj: Linear<B>,
/// Number of frequencies, set to `output_dim / 2` (integer division) in `new`.
pub half_dim: usize,
}
impl<B: Backend> FourierConditioner<B> {
    /// Builds a conditioner for the given `output_dim` on `device`.
    ///
    /// The frequency parameter is created as an all-zero tensor of shape
    /// `[output_dim / 2, 1]`, and the projection comes from
    /// `linear_zeros(output_dim, output_dim, true, device)`.
    ///
    /// NOTE(review): `output_dim / 2` truncates, so for an odd `output_dim`
    /// the feature width produced by `forward` (`2 * half_dim`) is
    /// `output_dim - 1`, which differs from the `output_dim` handed to
    /// `linear_zeros`. Callers presumably always pass an even value — confirm.
    pub fn new(output_dim: usize, device: &B::Device) -> Self {
        let half_dim = output_dim / 2;

        // Frequencies start at zero; one row per frequency.
        let frequencies: Tensor<B, 2> = Tensor::zeros([half_dim, 1], device);
        let weight = Param::initialized(ParamId::new(), frequencies);

        let proj = linear_zeros(output_dim, output_dim, true, device);

        Self {
            weight,
            proj,
            half_dim,
        }
    }

    /// Computes projected Fourier features of `t`.
    ///
    /// Steps:
    /// 1. Flatten `t` from `[b, s, _]` to `[b * s, 1]`. This keeps the
    ///    element count only when the last dimension has size 1.
    /// 2. Multiply by the transposed frequency parameter (`[1, half_dim]`)
    ///    and scale by `2π`, giving angles of shape `[b, s, half_dim]`.
    /// 3. Concatenate `cos(angles)` and `sin(angles)` along the last axis
    ///    (width `2 * half_dim`) and feed the result to `proj`.
    ///
    /// Returns whatever `proj.forward` yields for those features.
    pub fn forward(&self, t: Tensor<B, 3>) -> Tensor<B, 3> {
        const TWO_PI: f32 = 2.0_f32 * std::f32::consts::PI;

        let [batch, seq, _] = t.dims();
        let rows = batch * seq;

        // Column vector of inputs times a row vector of frequencies.
        let column = t.reshape([rows, 1]);
        let frequency_row = self.weight.val().transpose();
        let angles = column
            .matmul(frequency_row)
            .mul_scalar(TWO_PI)
            .reshape([batch, seq, self.half_dim]);

        // Cosine half first, sine half second, joined on the feature axis.
        let cos_part = angles.clone().cos();
        let sin_part = angles.sin();
        let features = Tensor::cat(vec![cos_part, sin_part], 2);

        self.proj.forward(features)
    }
}