1use burn::prelude::*;
15use burn::module::Ignored;
16use burn::nn::{
17 Linear,
18 conv::{Conv2d, Conv2dConfig},
19 GroupNorm, GroupNormConfig,
20};
21#[allow(unused_imports)]
22use rayon::prelude::*;
23#[allow(unused_imports)]
24use rustfft::{FftPlanner, num_complex::Complex64};
25
26use crate::config::ModelConfig;
27use super::linear_zeros;
28
/// A Conv2d → GroupNorm pair; `forward` additionally applies GELU after the norm.
#[derive(Module, Debug)]
pub struct ConvNormBlock<B: Backend> {
    // 2-D convolution applied first.
    pub conv: Conv2d<B>,
    // Group normalization applied to the conv output.
    pub norm: GroupNorm<B>,
}
36
37impl<B: Backend> ConvNormBlock<B> {
38 fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
39 burn::tensor::activation::gelu(self.norm.forward(self.conv.forward(x)))
40 }
41}
42
/// Device-resident buffers precomputed once and reused by
/// `PatchEmbedding::forward_cached`, avoiding per-call host→device uploads.
///
/// NOTE(review): materializes the same data as `DftBasis`/`ChannelOneHot`,
/// but already as tensors on the target device.
pub struct EmbeddingCache<B: Backend> {
    // Real part of the DFT basis, shape [spectral_bins, patch_size].
    pub dft_cos: Tensor<B, 2>,
    // Imaginary part of the DFT basis, shape [spectral_bins, patch_size].
    pub dft_sin: Tensor<B, 2>,
    // Identity matrix [num_channels, num_channels] used as one-hot channel ids.
    pub channel_one_hot: Tensor<B, 2>,
    // Number of frequency bins (rows of the DFT matrices).
    pub spectral_bins: usize,
    // Samples per patch (columns of the DFT matrices).
    pub patch_size: usize,
}
59
60impl<B: Backend> EmbeddingCache<B> {
61 pub fn new(cfg: &ModelConfig, device: &B::Device) -> Self {
63 let n = cfg.patch_size;
64 let k = cfg.spectral_bins();
65 let c = cfg.num_channels;
66
67 let two_pi_over_n = 2.0 * std::f64::consts::PI / n as f64;
69 let mut cos_data = Vec::with_capacity(k * n);
70 let mut sin_data = Vec::with_capacity(k * n);
71 for ki in 0..k {
72 for ni in 0..n {
73 let angle = two_pi_over_n * (ki as f64) * (ni as f64);
74 cos_data.push(angle.cos() as f32);
75 sin_data.push(angle.sin() as f32);
76 }
77 }
78 let dft_cos = Tensor::<B, 1>::from_floats(cos_data.as_slice(), device).reshape([k, n]);
79 let dft_sin = Tensor::<B, 1>::from_floats(sin_data.as_slice(), device).reshape([k, n]);
80
81 let mut oh = vec![0.0f32; c * c];
83 for i in 0..c {
84 oh[i * c + i] = 1.0;
85 }
86 let channel_one_hot = Tensor::<B, 1>::from_floats(oh.as_slice(), device).reshape([c, c]);
87
88 Self { dft_cos, dft_sin, channel_one_hot, spectral_bins: k, patch_size: n }
89 }
90}
91
/// Patch embedding: a temporal conv stack fused with a spectral (DFT magnitude)
/// projection, a per-channel identity embedding, and a depthwise temporal conv
/// positional encoding.
#[derive(Module, Debug)]
pub struct PatchEmbedding<B: Backend> {
    // Three-stage conv pathway over the raw patch samples.
    pub conv_block1: ConvNormBlock<B>,
    pub conv_block2: ConvNormBlock<B>,
    pub conv_block3: ConvNormBlock<B>,
    // Projects DFT magnitudes [.., spectral_bins] -> [.., d_model]; zero-initialized.
    pub spectral_proj: Linear<B>,
    // Maps channel one-hot rows [num_channels] -> [d_model]; zero-initialized.
    pub channel_embedding: Linear<B>,
    // Depthwise (groups = d_model) conv over the patch axis, used as a temporal encoding.
    pub time_encoding: Conv2d<B>,
    // Host-side DFT basis; `Ignored` keeps it out of the module/parameter tree.
    pub dft_basis: Ignored<DftBasis>,
    // Host-side channel one-hot data; also excluded from the module tree.
    pub channel_one_hot: Ignored<ChannelOneHot>,
    // Embedding width (cfg.feature_size).
    pub d_model: usize,
    pub num_channels: usize,
    pub patch_size: usize,
}
114
115impl<B: Backend> PatchEmbedding<B> {
116 pub fn new(cfg: &ModelConfig, device: &B::Device) -> Self {
117 let [c1, c2, c3] = cfg.conv_channels;
118 let [g1, g2, g3] = cfg.norm_groups;
119 let d = cfg.feature_size;
120
121 let conv1 = Conv2dConfig::new([1, c1], [1, 49])
122 .with_stride([1, 25])
123 .with_padding(burn::nn::PaddingConfig2d::Valid)
124 .init(device);
125 let norm1 = GroupNormConfig::new(g1, c1).init(device);
126 let conv2 = Conv2dConfig::new([c1, c2], [1, 3])
127 .with_padding(burn::nn::PaddingConfig2d::Explicit(0, 1))
128 .init(device);
129 let norm2 = GroupNormConfig::new(g2, c2).init(device);
130 let conv3 = Conv2dConfig::new([c2, c3], [1, 3])
131 .with_padding(burn::nn::PaddingConfig2d::Explicit(0, 1))
132 .init(device);
133 let norm3 = GroupNormConfig::new(g3, c3).init(device);
134
135 Self {
136 conv_block1: ConvNormBlock { conv: conv1, norm: norm1 },
137 conv_block2: ConvNormBlock { conv: conv2, norm: norm2 },
138 conv_block3: ConvNormBlock { conv: conv3, norm: norm3 },
139 spectral_proj: linear_zeros::<B>(cfg.spectral_bins(), d, true, device),
140 channel_embedding: linear_zeros::<B>(cfg.num_channels, d, true, device),
141 time_encoding: Conv2dConfig::new([d, d], [1, 5])
142 .with_padding(burn::nn::PaddingConfig2d::Explicit(0, 2))
143 .with_groups(d)
144 .init(device),
145 dft_basis: Ignored(DftBasis::new(cfg.patch_size)),
146 channel_one_hot: Ignored(ChannelOneHot::new(cfg.num_channels)),
147 d_model: d,
148 num_channels: cfg.num_channels,
149 patch_size: cfg.patch_size,
150 }
151 }
152
153 pub fn forward_cached(&self, x: Tensor<B, 4>, cache: &EmbeddingCache<B>) -> Tensor<B, 4> {
155 let [bz, ch_num, patch_num, patch_size] = x.dims();
156 let device = x.device();
157
158 let x_conv = x.clone().reshape([bz, 1, ch_num * patch_num, patch_size]);
160 let pad_w = 24;
161 let zeros = Tensor::<B, 4>::zeros([bz, 1, ch_num * patch_num, pad_w], &device);
162 let x_padded = Tensor::cat(vec![zeros.clone(), x_conv, zeros], 3);
163 let patch_emb = self.conv_block1.forward(x_padded);
164 let patch_emb = self.conv_block2.forward(patch_emb);
165 let patch_emb = self.conv_block3.forward(patch_emb);
166 let patch_emb = patch_emb
167 .permute([0, 2, 1, 3])
168 .reshape([bz, ch_num, patch_num, self.d_model]);
169
170 let total = bz * ch_num * patch_num;
172 let k = cache.spectral_bins;
173 let inv_n = 1.0 / patch_size as f32;
174 let flat = x.reshape([total, patch_size]);
175 let real = flat.clone().matmul(cache.dft_cos.clone().transpose());
176 let imag = flat.matmul(cache.dft_sin.clone().transpose());
177 let spectral = (real.clone() * real + imag.clone() * imag).sqrt() * inv_n;
178 let spectral_emb = self.spectral_proj
179 .forward(spectral.reshape([bz, ch_num, patch_num, k]));
180
181 let mut patch_emb = patch_emb + spectral_emb;
182
183 let chan_emb = self.channel_embedding
185 .forward(cache.channel_one_hot.clone())
186 .unsqueeze::<3>()
187 .unsqueeze_dim::<4>(2)
188 .expand([bz, ch_num, patch_num, self.d_model]);
189 patch_emb = patch_emb + chan_emb;
190
191 let time_emb = self.time_encoding
193 .forward(patch_emb.clone().permute([0, 3, 1, 2]))
194 .permute([0, 2, 3, 1]);
195 patch_emb + time_emb
196 }
197
198 pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
200 let [bz, ch_num, patch_num, patch_size] = x.dims();
201 let device = x.device();
202
203 let x_conv = x.clone().reshape([bz, 1, ch_num * patch_num, patch_size]);
204 let pad_w = 24;
205 let zeros = Tensor::<B, 4>::zeros([bz, 1, ch_num * patch_num, pad_w], &device);
206 let x_padded = Tensor::cat(vec![zeros.clone(), x_conv, zeros], 3);
207 let patch_emb = self.conv_block1.forward(x_padded);
208 let patch_emb = self.conv_block2.forward(patch_emb);
209 let patch_emb = self.conv_block3.forward(patch_emb);
210 let patch_emb = patch_emb
211 .permute([0, 2, 1, 3])
212 .reshape([bz, ch_num, patch_num, self.d_model]);
213
214 let spectral_emb = self.spectral_proj
215 .forward(self.dft_basis.0.apply::<B>(&x, &device));
216 let mut patch_emb = patch_emb + spectral_emb;
217
218 let chan_emb = self.channel_embedding
219 .forward(self.channel_one_hot.0.to_tensor::<B>(&device))
220 .unsqueeze::<3>()
221 .unsqueeze_dim::<4>(2)
222 .expand([bz, ch_num, patch_num, self.d_model]);
223 patch_emb = patch_emb + chan_emb;
224
225 let time_emb = self.time_encoding
226 .forward(patch_emb.clone().permute([0, 3, 1, 2]))
227 .permute([0, 2, 3, 1]);
228 patch_emb + time_emb
229 }
230}
231
/// Host-side real-input DFT basis (cos/sin matrices stored row-major as flat
/// Vecs, `spectral_bins` rows × patch_size columns). Kept outside the module
/// tree via `Ignored` and uploaded to the device on demand in `apply`.
#[derive(Debug, Clone)]
pub struct DftBasis {
    // cos(2π·k·n / N) for k in 0..spectral_bins, n in 0..patch_size.
    cos_data: Vec<f32>,
    // sin(2π·k·n / N), same layout as `cos_data`.
    sin_data: Vec<f32>,
    // Number of frequency bins: patch_size / 2 + 1.
    spectral_bins: usize,
}
241
242impl DftBasis {
243 pub fn new(patch_size: usize) -> Self {
244 let k = patch_size / 2 + 1;
245 let two_pi_over_n = 2.0 * std::f64::consts::PI / patch_size as f64;
246 let mut cos_data = Vec::with_capacity(k * patch_size);
247 let mut sin_data = Vec::with_capacity(k * patch_size);
248 for ki in 0..k {
249 for ni in 0..patch_size {
250 let angle = two_pi_over_n * (ki as f64) * (ni as f64);
251 cos_data.push(angle.cos() as f32);
252 sin_data.push(angle.sin() as f32);
253 }
254 }
255 Self { cos_data, sin_data, spectral_bins: k }
256 }
257
258 fn apply<B: Backend>(&self, x: &Tensor<B, 4>, device: &B::Device) -> Tensor<B, 4> {
259 let [bz, ch, p, n] = x.dims();
260 let total = bz * ch * p;
261 let k = self.spectral_bins;
262 let inv_n = 1.0 / n as f32;
263 let cos_basis = Tensor::<B, 1>::from_floats(self.cos_data.as_slice(), device).reshape([k, n]);
264 let sin_basis = Tensor::<B, 1>::from_floats(self.sin_data.as_slice(), device).reshape([k, n]);
265 let flat = x.clone().reshape([total, n]);
266 let real = flat.clone().matmul(cos_basis.transpose());
267 let imag = flat.matmul(sin_basis.transpose());
268 let mag = (real.clone() * real + imag.clone() * imag).sqrt() * inv_n;
269 mag.reshape([bz, ch, p, k])
270 }
271}
272
/// Host-side flattened identity matrix used as one-hot channel ids; kept
/// outside the module tree via `Ignored` and uploaded on demand in `to_tensor`.
#[derive(Debug, Clone)]
pub struct ChannelOneHot {
    // Row-major [num_channels, num_channels] identity, flattened.
    data: Vec<f32>,
    num_channels: usize,
}
279
280impl ChannelOneHot {
281 pub fn new(num_channels: usize) -> Self {
282 let mut data = vec![0.0f32; num_channels * num_channels];
283 for i in 0..num_channels {
284 data[i * num_channels + i] = 1.0;
285 }
286 Self { data, num_channels }
287 }
288
289 fn to_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 2> {
290 let n = self.num_channels;
291 Tensor::<B, 1>::from_floats(self.data.as_slice(), device).reshape([n, n])
292 }
293}