use burn::prelude::*;
use burn::module::Ignored;
use burn::nn::{
Linear,
conv::{Conv2d, Conv2dConfig},
GroupNorm, GroupNormConfig,
};
#[allow(unused_imports)]
use rayon::prelude::*;
#[allow(unused_imports)]
use rustfft::{FftPlanner, num_complex::Complex64};
use crate::config::ModelConfig;
use super::linear_zeros;
#[derive(Module, Debug)]
/// A `Conv2d` followed by a `GroupNorm`; `forward` applies GELU on top,
/// making this the repeated conv → norm → activation unit of the patch
/// embedding tower.
pub struct ConvNormBlock<B: Backend> {
    pub conv: Conv2d<B>,
    pub norm: GroupNorm<B>,
}
impl<B: Backend> ConvNormBlock<B> {
    /// Apply convolution, group normalization, then GELU, in that order.
    fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let convolved = self.conv.forward(input);
        let normalized = self.norm.forward(convolved);
        burn::tensor::activation::gelu(normalized)
    }
}
/// Device-resident tensors precomputed once in `EmbeddingCache::new` and
/// reused by `PatchEmbedding::forward_cached`, avoiding a host→device upload
/// per forward pass.
pub struct EmbeddingCache<B: Backend> {
    // Real part of the DFT basis, shape [spectral_bins, patch_size].
    pub dft_cos: Tensor<B, 2>,
    // Imaginary part of the DFT basis, shape [spectral_bins, patch_size].
    pub dft_sin: Tensor<B, 2>,
    // Identity matrix of shape [num_channels, num_channels] fed to the
    // channel embedding as one-hot rows.
    pub channel_one_hot: Tensor<B, 2>,
    // Number of frequency bins (rows of dft_cos / dft_sin).
    pub spectral_bins: usize,
    // Number of time samples per patch (columns of dft_cos / dft_sin).
    pub patch_size: usize,
}
impl<B: Backend> EmbeddingCache<B> {
    /// Precompute the DFT cos/sin basis ([k, n]) and the channel one-hot
    /// identity ([c, c]) on `device` from the model configuration.
    pub fn new(cfg: &ModelConfig, device: &B::Device) -> Self {
        let n = cfg.patch_size;
        let k = cfg.spectral_bins();
        let c = cfg.num_channels;
        let step = 2.0 * std::f64::consts::PI / n as f64;
        // Entry (ki, ni) holds the angle 2π·ki·ni / n; computed in f64 so the
        // only precision loss is the final narrowing to f32.
        let angles: Vec<f64> = (0..k)
            .flat_map(|ki| (0..n).map(move |ni| step * (ki as f64) * (ni as f64)))
            .collect();
        let cos_data: Vec<f32> = angles.iter().map(|a| a.cos() as f32).collect();
        let sin_data: Vec<f32> = angles.iter().map(|a| a.sin() as f32).collect();
        let dft_cos = Tensor::<B, 1>::from_floats(cos_data.as_slice(), device).reshape([k, n]);
        let dft_sin = Tensor::<B, 1>::from_floats(sin_data.as_slice(), device).reshape([k, n]);
        // Flattened row-major c×c identity: 1.0 on the diagonal, 0.0 elsewhere.
        let identity: Vec<f32> = (0..c)
            .flat_map(|row| (0..c).map(move |col| if row == col { 1.0 } else { 0.0 }))
            .collect();
        let channel_one_hot =
            Tensor::<B, 1>::from_floats(identity.as_slice(), device).reshape([c, c]);
        Self {
            dft_cos,
            dft_sin,
            channel_one_hot,
            spectral_bins: k,
            patch_size: n,
        }
    }
}
#[derive(Module, Debug)]
/// Embeds raw [batch, channels, patches, patch_size] signals into
/// [batch, channels, patches, d_model] by summing three additive terms:
/// conv-tower features, a projected DFT magnitude spectrum, and a channel
/// embedding, then adding a learned temporal encoding on top.
pub struct PatchEmbedding<B: Backend> {
    // Three conv + GroupNorm + GELU stages applied to the padded raw signal.
    pub conv_block1: ConvNormBlock<B>,
    pub conv_block2: ConvNormBlock<B>,
    pub conv_block3: ConvNormBlock<B>,
    // Projects spectral magnitudes (spectral_bins) to d_model.
    pub spectral_proj: Linear<B>,
    // Projects channel one-hot rows (num_channels) to d_model.
    pub channel_embedding: Linear<B>,
    // Depthwise [1, 5] conv (groups = d_model) over the patch axis.
    pub time_encoding: Conv2d<B>,
    // Host-side DFT table; `Ignored` keeps it out of the module's parameters.
    pub dft_basis: Ignored<DftBasis>,
    // Host-side one-hot identity; also excluded from parameters via `Ignored`.
    pub channel_one_hot: Ignored<ChannelOneHot>,
    pub d_model: usize,
    pub num_channels: usize,
    pub patch_size: usize,
}
impl<B: Backend> PatchEmbedding<B> {
    /// Half the first conv kernel width ((49 - 1) / 2): zeros appended on each
    /// side of the time axis before the stride-25 convolution.
    const CONV1_PAD: usize = 24;

    /// Build the embedding stack from `cfg`: a three-stage conv tower, two
    /// linear projections initialized via `linear_zeros`, a depthwise temporal
    /// conv, and host-side DFT / one-hot tables for the cache-free `forward`.
    pub fn new(cfg: &ModelConfig, device: &B::Device) -> Self {
        let [c1, c2, c3] = cfg.conv_channels;
        let [g1, g2, g3] = cfg.norm_groups;
        let d = cfg.feature_size;
        // Stage 1: wide strided conv over raw samples (kernel [1, 49], stride
        // [1, 25], no implicit padding — padding is done explicitly in
        // `conv_features` with CONV1_PAD zeros on each side).
        let conv1 = Conv2dConfig::new([1, c1], [1, 49])
            .with_stride([1, 25])
            .with_padding(burn::nn::PaddingConfig2d::Valid)
            .init(device);
        let norm1 = GroupNormConfig::new(g1, c1).init(device);
        // Stages 2-3: [1, 3] convs with width padding 1 (length-preserving).
        let conv2 = Conv2dConfig::new([c1, c2], [1, 3])
            .with_padding(burn::nn::PaddingConfig2d::Explicit(0, 1))
            .init(device);
        let norm2 = GroupNormConfig::new(g2, c2).init(device);
        let conv3 = Conv2dConfig::new([c2, c3], [1, 3])
            .with_padding(burn::nn::PaddingConfig2d::Explicit(0, 1))
            .init(device);
        let norm3 = GroupNormConfig::new(g3, c3).init(device);
        Self {
            conv_block1: ConvNormBlock { conv: conv1, norm: norm1 },
            conv_block2: ConvNormBlock { conv: conv2, norm: norm2 },
            conv_block3: ConvNormBlock { conv: conv3, norm: norm3 },
            spectral_proj: linear_zeros::<B>(cfg.spectral_bins(), d, true, device),
            channel_embedding: linear_zeros::<B>(cfg.num_channels, d, true, device),
            // Depthwise (groups = d) [1, 5] conv over the patch axis.
            time_encoding: Conv2dConfig::new([d, d], [1, 5])
                .with_padding(burn::nn::PaddingConfig2d::Explicit(0, 2))
                .with_groups(d)
                .init(device),
            // Host-side tables wrapped in `Ignored` so they are not counted as
            // module parameters; only the cache-free `forward` path uses them.
            dft_basis: Ignored(DftBasis::new(cfg.patch_size)),
            channel_one_hot: Ignored(ChannelOneHot::new(cfg.num_channels)),
            d_model: d,
            num_channels: cfg.num_channels,
            patch_size: cfg.patch_size,
        }
    }

    /// Shared conv tower used by both forward paths: zero-pad the time axis,
    /// run the three conv blocks, and reshape to [bz, ch, patches, d_model].
    fn conv_features(&self, x: &Tensor<B, 4>) -> Tensor<B, 4> {
        let [bz, ch_num, patch_num, patch_size] = x.dims();
        let device = x.device();
        // Fold channels and patches into the conv "height" dimension so one
        // conv pass covers every (channel, patch) row.
        let x_conv = x.clone().reshape([bz, 1, ch_num * patch_num, patch_size]);
        let zeros =
            Tensor::<B, 4>::zeros([bz, 1, ch_num * patch_num, Self::CONV1_PAD], &device);
        let x_padded = Tensor::cat(vec![zeros.clone(), x_conv, zeros], 3);
        let features = self.conv_block1.forward(x_padded);
        let features = self.conv_block2.forward(features);
        let features = self.conv_block3.forward(features);
        // NOTE(review): the reshape below requires c3 * conv_output_width ==
        // d_model; presumably guaranteed by ModelConfig — confirm if configs
        // change.
        features
            .permute([0, 2, 1, 3])
            .reshape([bz, ch_num, patch_num, self.d_model])
    }

    /// Add the channel embedding (broadcast over batch and patches) and the
    /// depthwise temporal encoding; returns the final embedding.
    fn add_channel_and_time(
        &self,
        patch_emb: Tensor<B, 4>,
        channel_one_hot: Tensor<B, 2>,
    ) -> Tensor<B, 4> {
        let [bz, ch_num, patch_num, _] = patch_emb.dims();
        let chan_emb = self.channel_embedding
            .forward(channel_one_hot)
            .unsqueeze::<3>()
            .unsqueeze_dim::<4>(2)
            .expand([bz, ch_num, patch_num, self.d_model]);
        let patch_emb = patch_emb + chan_emb;
        // time_encoding convolves along the patch axis; move d_model into the
        // conv channel slot and back.
        let time_emb = self.time_encoding
            .forward(patch_emb.clone().permute([0, 3, 1, 2]))
            .permute([0, 2, 3, 1]);
        patch_emb + time_emb
    }

    /// Embed `x` ([bz, ch, patches, patch_size]) reusing the device-resident
    /// DFT basis and one-hot identity from `cache` instead of re-uploading
    /// them on every call.
    pub fn forward_cached(&self, x: Tensor<B, 4>, cache: &EmbeddingCache<B>) -> Tensor<B, 4> {
        let [bz, ch_num, patch_num, patch_size] = x.dims();
        let patch_emb = self.conv_features(&x);
        // Magnitude spectrum via two matmuls against the cached cos/sin rows,
        // scaled by 1/patch_size (matches DftBasis::apply).
        let total = bz * ch_num * patch_num;
        let k = cache.spectral_bins;
        let inv_n = 1.0 / patch_size as f32;
        let flat = x.reshape([total, patch_size]);
        let real = flat.clone().matmul(cache.dft_cos.clone().transpose());
        let imag = flat.matmul(cache.dft_sin.clone().transpose());
        let spectral = (real.clone() * real + imag.clone() * imag).sqrt() * inv_n;
        let spectral_emb = self.spectral_proj
            .forward(spectral.reshape([bz, ch_num, patch_num, k]));
        self.add_channel_and_time(patch_emb + spectral_emb, cache.channel_one_hot.clone())
    }

    /// Cache-free path: same computation as `forward_cached`, but rebuilds the
    /// DFT basis and one-hot tensors from the host-side tables on every call.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let device = x.device();
        let patch_emb = self.conv_features(&x);
        let spectral_emb = self.spectral_proj
            .forward(self.dft_basis.0.apply::<B>(&x, &device));
        let one_hot = self.channel_one_hot.0.to_tensor::<B>(&device);
        self.add_channel_and_time(patch_emb + spectral_emb, one_hot)
    }
}
#[derive(Debug, Clone)]
/// Host-side real-DFT basis table (row-major, [spectral_bins, patch_size]),
/// uploaded to the device on demand by `apply`. Stored outside the module
/// parameters (wrapped in `Ignored` by `PatchEmbedding`).
pub struct DftBasis {
    // cos(2π·k·n / N) entries, flattened row-major.
    cos_data: Vec<f32>,
    // sin(2π·k·n / N) entries, flattened row-major.
    sin_data: Vec<f32>,
    // Number of frequency bins (rows): patch_size / 2 + 1.
    spectral_bins: usize,
}
impl DftBasis {
    /// Tabulate cos/sin(2π·ki·ni / patch_size) for ki in 0..patch_size/2+1
    /// and ni in 0..patch_size, row-major. Angles are evaluated in f64 so the
    /// only precision loss is the final narrowing to f32.
    pub fn new(patch_size: usize) -> Self {
        let k = patch_size / 2 + 1;
        let two_pi_over_n = 2.0 * std::f64::consts::PI / patch_size as f64;
        let mut cos_data = Vec::with_capacity(k * patch_size);
        let mut sin_data = Vec::with_capacity(k * patch_size);
        for ki in 0..k {
            for ni in 0..patch_size {
                let angle = two_pi_over_n * (ki as f64) * (ni as f64);
                cos_data.push(angle.cos() as f32);
                sin_data.push(angle.sin() as f32);
            }
        }
        Self { cos_data, sin_data, spectral_bins: k }
    }

    /// Magnitude spectrum of every patch in `x` ([bz, ch, p, n]), returned as
    /// [bz, ch, p, spectral_bins] and scaled by 1/n.
    fn apply<B: Backend>(&self, x: &Tensor<B, 4>, device: &B::Device) -> Tensor<B, 4> {
        let [bz, ch, p, n] = x.dims();
        let total = bz * ch * p;
        let k = self.spectral_bins;
        // The table was tabulated for one specific patch size; a mismatched
        // input would otherwise fail deep inside `reshape` with an opaque
        // shape error, so fail fast with a clear message in debug builds.
        debug_assert_eq!(
            self.cos_data.len(),
            k * n,
            "DftBasis built for patch_size {}, but got patches of length {n}",
            self.cos_data.len() / k,
        );
        let inv_n = 1.0 / n as f32;
        let cos_basis =
            Tensor::<B, 1>::from_floats(self.cos_data.as_slice(), device).reshape([k, n]);
        let sin_basis =
            Tensor::<B, 1>::from_floats(self.sin_data.as_slice(), device).reshape([k, n]);
        let flat = x.clone().reshape([total, n]);
        let real = flat.clone().matmul(cos_basis.transpose());
        let imag = flat.matmul(sin_basis.transpose());
        let mag = (real.clone() * real + imag.clone() * imag).sqrt() * inv_n;
        mag.reshape([bz, ch, p, k])
    }
}
#[derive(Debug, Clone)]
/// Host-side flattened identity matrix ([num_channels, num_channels],
/// row-major), uploaded on demand by `to_tensor` as one-hot rows for the
/// channel embedding. Kept out of module parameters via `Ignored`.
pub struct ChannelOneHot {
    // Row-major identity entries: 1.0 on the diagonal, 0.0 elsewhere.
    data: Vec<f32>,
    num_channels: usize,
}
impl ChannelOneHot {
    /// Build the flattened `num_channels × num_channels` identity matrix.
    pub fn new(num_channels: usize) -> Self {
        let data: Vec<f32> = (0..num_channels)
            .flat_map(|row| {
                (0..num_channels).map(move |col| if row == col { 1.0 } else { 0.0 })
            })
            .collect();
        Self { data, num_channels }
    }
    /// Upload the identity to `device` as a [num_channels, num_channels] tensor.
    fn to_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 2> {
        let c = self.num_channels;
        Tensor::<B, 1>::from_floats(self.data.as_slice(), device).reshape([c, c])
    }
}