// eegdino_rs/model/embedding.rs
//! Patch embedding layer combining temporal, spectral, and channel embeddings.
//!
//! Matches the Python `PatchEmbedding` class from `embedding_{small,medium,large}.py`.
//!
//! Input:  `[B, C, P, L]`  (batch, channels, patches, patch_length=200)
//! Output: `[B, C, P, D]`  (batch, channels, patches, d_model)
//!
//! Three embedding streams are summed:
//! 1. **Temporal** (`proj_in`): 3-layer Conv2d stack
//! 2. **Spectral**: rfft magnitude via on-device DFT matmul
//! 3. **Channel**: one-hot(channel_idx) → Linear(19, D)
//!
//! A depthwise conv `time_encoding` is added on top.
14use burn::prelude::*;
15use burn::module::Ignored;
16use burn::nn::{
17    Linear,
18    conv::{Conv2d, Conv2dConfig},
19    GroupNorm, GroupNormConfig,
20};
21#[allow(unused_imports)]
22use rayon::prelude::*;
23#[allow(unused_imports)]
24use rustfft::{FftPlanner, num_complex::Complex64};
25
26use crate::config::ModelConfig;
27use super::linear_zeros;
28
29// ── Conv-Norm block ─────────────────────────────────────────────────────────
30
31#[derive(Module, Debug)]
32pub struct ConvNormBlock<B: Backend> {
33    pub conv: Conv2d<B>,
34    pub norm: GroupNorm<B>,
35}
36
37impl<B: Backend> ConvNormBlock<B> {
38    fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
39        burn::tensor::activation::gelu(self.norm.forward(self.conv.forward(x)))
40    }
41}
42
43// ── Embedding cache (on-device, created once) ───────────────────────────────
44
45/// Cached on-device tensors for the patch embedding.
46///
47/// Created once via [`EmbeddingCache::new`] and reused across forward calls,
48/// avoiding repeated CPU→device transfers of constant data.
49pub struct EmbeddingCache<B: Backend> {
50    /// DFT cosine basis `[K, N]` where K = spectral_bins, N = patch_size.
51    pub dft_cos: Tensor<B, 2>,
52    /// DFT sine basis `[K, N]`.
53    pub dft_sin: Tensor<B, 2>,
54    /// Channel one-hot matrix `[C, C]`.
55    pub channel_one_hot: Tensor<B, 2>,
56    pub spectral_bins: usize,
57    pub patch_size: usize,
58}
59
60impl<B: Backend> EmbeddingCache<B> {
61    /// Build cached device tensors for a given config.
62    pub fn new(cfg: &ModelConfig, device: &B::Device) -> Self {
63        let n = cfg.patch_size;
64        let k = cfg.spectral_bins();
65        let c = cfg.num_channels;
66
67        // DFT basis
68        let two_pi_over_n = 2.0 * std::f64::consts::PI / n as f64;
69        let mut cos_data = Vec::with_capacity(k * n);
70        let mut sin_data = Vec::with_capacity(k * n);
71        for ki in 0..k {
72            for ni in 0..n {
73                let angle = two_pi_over_n * (ki as f64) * (ni as f64);
74                cos_data.push(angle.cos() as f32);
75                sin_data.push(angle.sin() as f32);
76            }
77        }
78        let dft_cos = Tensor::<B, 1>::from_floats(cos_data.as_slice(), device).reshape([k, n]);
79        let dft_sin = Tensor::<B, 1>::from_floats(sin_data.as_slice(), device).reshape([k, n]);
80
81        // Channel one-hot
82        let mut oh = vec![0.0f32; c * c];
83        for i in 0..c {
84            oh[i * c + i] = 1.0;
85        }
86        let channel_one_hot = Tensor::<B, 1>::from_floats(oh.as_slice(), device).reshape([c, c]);
87
88        Self { dft_cos, dft_sin, channel_one_hot, spectral_bins: k, patch_size: n }
89    }
90}
91
92// ── PatchEmbedding ──────────────────────────────────────────────────────────
93
#[derive(Module, Debug)]
pub struct PatchEmbedding<B: Backend> {
    /// Temporal conv stack (`proj_in`): 3 x (Conv2d + GroupNorm + GELU)
    pub conv_block1: ConvNormBlock<B>,
    pub conv_block2: ConvNormBlock<B>,
    pub conv_block3: ConvNormBlock<B>,
    /// Spectral projection: Linear(spectral_bins, d_model) — 101 bins when patch_size = 200.
    pub spectral_proj: Linear<B>,
    /// Channel position embedding: Linear(num_channels, d_model)
    pub channel_embedding: Linear<B>,
    /// Depthwise temporal encoding: Conv2d(d_model, d_model, (1,5), groups=d_model)
    pub time_encoding: Conv2d<B>,
    /// Fallback DFT basis (used only when no EmbeddingCache is provided).
    /// `Ignored` keeps it out of the Module parameter record.
    pub dft_basis: Ignored<DftBasis>,
    /// Fallback channel one-hot (CPU-side; uploaded per call in `forward`).
    pub channel_one_hot: Ignored<ChannelOneHot>,
    // Shape bookkeeping copied from the config at construction time.
    pub d_model: usize,
    pub num_channels: usize,
    pub patch_size: usize,
}
114
115impl<B: Backend> PatchEmbedding<B> {
116    pub fn new(cfg: &ModelConfig, device: &B::Device) -> Self {
117        let [c1, c2, c3] = cfg.conv_channels;
118        let [g1, g2, g3] = cfg.norm_groups;
119        let d = cfg.feature_size;
120
121        let conv1 = Conv2dConfig::new([1, c1], [1, 49])
122            .with_stride([1, 25])
123            .with_padding(burn::nn::PaddingConfig2d::Valid)
124            .init(device);
125        let norm1 = GroupNormConfig::new(g1, c1).init(device);
126        let conv2 = Conv2dConfig::new([c1, c2], [1, 3])
127            .with_padding(burn::nn::PaddingConfig2d::Explicit(0, 1))
128            .init(device);
129        let norm2 = GroupNormConfig::new(g2, c2).init(device);
130        let conv3 = Conv2dConfig::new([c2, c3], [1, 3])
131            .with_padding(burn::nn::PaddingConfig2d::Explicit(0, 1))
132            .init(device);
133        let norm3 = GroupNormConfig::new(g3, c3).init(device);
134
135        Self {
136            conv_block1: ConvNormBlock { conv: conv1, norm: norm1 },
137            conv_block2: ConvNormBlock { conv: conv2, norm: norm2 },
138            conv_block3: ConvNormBlock { conv: conv3, norm: norm3 },
139            spectral_proj: linear_zeros::<B>(cfg.spectral_bins(), d, true, device),
140            channel_embedding: linear_zeros::<B>(cfg.num_channels, d, true, device),
141            time_encoding: Conv2dConfig::new([d, d], [1, 5])
142                .with_padding(burn::nn::PaddingConfig2d::Explicit(0, 2))
143                .with_groups(d)
144                .init(device),
145            dft_basis: Ignored(DftBasis::new(cfg.patch_size)),
146            channel_one_hot: Ignored(ChannelOneHot::new(cfg.num_channels)),
147            d_model: d,
148            num_channels: cfg.num_channels,
149            patch_size: cfg.patch_size,
150        }
151    }
152
153    /// Forward pass using a pre-built on-device cache (fast path).
154    pub fn forward_cached(&self, x: Tensor<B, 4>, cache: &EmbeddingCache<B>) -> Tensor<B, 4> {
155        let [bz, ch_num, patch_num, patch_size] = x.dims();
156        let device = x.device();
157
158        // 1. Temporal conv stack
159        let x_conv = x.clone().reshape([bz, 1, ch_num * patch_num, patch_size]);
160        let pad_w = 24;
161        let zeros = Tensor::<B, 4>::zeros([bz, 1, ch_num * patch_num, pad_w], &device);
162        let x_padded = Tensor::cat(vec![zeros.clone(), x_conv, zeros], 3);
163        let patch_emb = self.conv_block1.forward(x_padded);
164        let patch_emb = self.conv_block2.forward(patch_emb);
165        let patch_emb = self.conv_block3.forward(patch_emb);
166        let patch_emb = patch_emb
167            .permute([0, 2, 1, 3])
168            .reshape([bz, ch_num, patch_num, self.d_model]);
169
170        // 2. Spectral (cached DFT basis)
171        let total = bz * ch_num * patch_num;
172        let k = cache.spectral_bins;
173        let inv_n = 1.0 / patch_size as f32;
174        let flat = x.reshape([total, patch_size]);
175        let real = flat.clone().matmul(cache.dft_cos.clone().transpose());
176        let imag = flat.matmul(cache.dft_sin.clone().transpose());
177        let spectral = (real.clone() * real + imag.clone() * imag).sqrt() * inv_n;
178        let spectral_emb = self.spectral_proj
179            .forward(spectral.reshape([bz, ch_num, patch_num, k]));
180
181        let mut patch_emb = patch_emb + spectral_emb;
182
183        // 3. Channel (cached one-hot)
184        let chan_emb = self.channel_embedding
185            .forward(cache.channel_one_hot.clone())
186            .unsqueeze::<3>()
187            .unsqueeze_dim::<4>(2)
188            .expand([bz, ch_num, patch_num, self.d_model]);
189        patch_emb = patch_emb + chan_emb;
190
191        // 4. Time encoding
192        let time_emb = self.time_encoding
193            .forward(patch_emb.clone().permute([0, 3, 1, 2]))
194            .permute([0, 2, 3, 1]);
195        patch_emb + time_emb
196    }
197
198    /// Forward pass without cache (rebuilds DFT/one-hot from CPU each call).
199    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
200        let [bz, ch_num, patch_num, patch_size] = x.dims();
201        let device = x.device();
202
203        let x_conv = x.clone().reshape([bz, 1, ch_num * patch_num, patch_size]);
204        let pad_w = 24;
205        let zeros = Tensor::<B, 4>::zeros([bz, 1, ch_num * patch_num, pad_w], &device);
206        let x_padded = Tensor::cat(vec![zeros.clone(), x_conv, zeros], 3);
207        let patch_emb = self.conv_block1.forward(x_padded);
208        let patch_emb = self.conv_block2.forward(patch_emb);
209        let patch_emb = self.conv_block3.forward(patch_emb);
210        let patch_emb = patch_emb
211            .permute([0, 2, 1, 3])
212            .reshape([bz, ch_num, patch_num, self.d_model]);
213
214        let spectral_emb = self.spectral_proj
215            .forward(self.dft_basis.0.apply::<B>(&x, &device));
216        let mut patch_emb = patch_emb + spectral_emb;
217
218        let chan_emb = self.channel_embedding
219            .forward(self.channel_one_hot.0.to_tensor::<B>(&device))
220            .unsqueeze::<3>()
221            .unsqueeze_dim::<4>(2)
222            .expand([bz, ch_num, patch_num, self.d_model]);
223        patch_emb = patch_emb + chan_emb;
224
225        let time_emb = self.time_encoding
226            .forward(patch_emb.clone().permute([0, 3, 1, 2]))
227            .permute([0, 2, 3, 1]);
228        patch_emb + time_emb
229    }
230}
231
232// ── Fallback types (for uncached path) ──────────────────────────────────────
233
234/// Pre-computed DFT basis stored as `Vec<f32>` (CPU-side fallback).
235#[derive(Debug, Clone)]
236pub struct DftBasis {
237    cos_data: Vec<f32>,
238    sin_data: Vec<f32>,
239    spectral_bins: usize,
240}
241
242impl DftBasis {
243    pub fn new(patch_size: usize) -> Self {
244        let k = patch_size / 2 + 1;
245        let two_pi_over_n = 2.0 * std::f64::consts::PI / patch_size as f64;
246        let mut cos_data = Vec::with_capacity(k * patch_size);
247        let mut sin_data = Vec::with_capacity(k * patch_size);
248        for ki in 0..k {
249            for ni in 0..patch_size {
250                let angle = two_pi_over_n * (ki as f64) * (ni as f64);
251                cos_data.push(angle.cos() as f32);
252                sin_data.push(angle.sin() as f32);
253            }
254        }
255        Self { cos_data, sin_data, spectral_bins: k }
256    }
257
258    fn apply<B: Backend>(&self, x: &Tensor<B, 4>, device: &B::Device) -> Tensor<B, 4> {
259        let [bz, ch, p, n] = x.dims();
260        let total = bz * ch * p;
261        let k = self.spectral_bins;
262        let inv_n = 1.0 / n as f32;
263        let cos_basis = Tensor::<B, 1>::from_floats(self.cos_data.as_slice(), device).reshape([k, n]);
264        let sin_basis = Tensor::<B, 1>::from_floats(self.sin_data.as_slice(), device).reshape([k, n]);
265        let flat = x.clone().reshape([total, n]);
266        let real = flat.clone().matmul(cos_basis.transpose());
267        let imag = flat.matmul(sin_basis.transpose());
268        let mag = (real.clone() * real + imag.clone() * imag).sqrt() * inv_n;
269        mag.reshape([bz, ch, p, k])
270    }
271}
272
273/// Pre-computed one-hot matrix stored as `Vec<f32>` (CPU-side fallback).
274#[derive(Debug, Clone)]
275pub struct ChannelOneHot {
276    data: Vec<f32>,
277    num_channels: usize,
278}
279
280impl ChannelOneHot {
281    pub fn new(num_channels: usize) -> Self {
282        let mut data = vec![0.0f32; num_channels * num_channels];
283        for i in 0..num_channels {
284            data[i * num_channels + i] = 1.0;
285        }
286        Self { data, num_channels }
287    }
288
289    fn to_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 2> {
290        let n = self.num_channels;
291        Tensor::<B, 1>::from_floats(self.data.as_slice(), device).reshape([n, n])
292    }
293}