1pub mod channel_positions;
23pub mod channel_vocab;
24pub mod config;
25pub mod csv_loader;
26pub mod data;
27pub mod encoder;
28pub mod model;
29pub mod preprocessing;
30pub mod quantize;
31pub mod weights;
32
33pub use encoder::{LunaEncoder, EpochEmbedding, EncodingResult};
35pub use config::{ModelConfig, DataConfig};
36pub use data::{InputBatch, FifInfo, build_batch_named};
37pub use channel_positions::{channel_xyz, bipolar_channel_xyz, MontageLayout, montage_channels, nearest_channel, normalise};
38pub use channel_vocab::{CHANNEL_VOCAB, VOCAB_SIZE, channel_index, channel_indices, channel_indices_unwrap, TUEG_CHANNELS, SIENA_CHANNELS, SEED_CHANNELS};
39pub use csv_loader::{load_from_csv, CsvInfo};
40pub use preprocessing::{load_edf, load_fif, load_luna_epochs, load_csv_and_preprocess, PreprocInfo};
41
#[cfg(test)]
mod repeat_dim_test {
    use burn::backend::NdArray as B;
    use burn::prelude::*;

    /// Verifies that burn's `repeat_dim` tiles the entire tensor along the
    /// chosen dimension, i.e. it matches the semantics of PyTorch's
    /// `.repeat()` rather than NumPy's element-wise `repeat`.
    #[test]
    fn test_repeat_dim_matches_pytorch() {
        let device = burn::backend::ndarray::NdArrayDevice::Cpu;

        // A 2x2x2 tensor holding 1..=8 in row-major order.
        let source = vec![1f32, 2., 3., 4., 5., 6., 7., 8.];
        let tensor =
            Tensor::<B, 3>::from_data(TensorData::new(source, vec![2, 2, 2]), &device);

        // Repeating dim 0 three times must triple the leading dimension...
        let repeated = tensor.repeat_dim(0, 3);
        assert_eq!(repeated.dims(), [6, 2, 2]);

        // ...and lay out three full copies of the original data back to back.
        let actual = repeated.into_data().to_vec::<f32>().unwrap();
        let mut expected: Vec<f32> = Vec::with_capacity(24);
        for _ in 0..3 {
            expected.extend((1..=8).map(|v| v as f32));
        }
        assert_eq!(actual, expected, "repeat_dim should match PyTorch .repeat()");
    }
}
63
#[cfg(test)]
mod trace_forward_test {
    use burn::backend::NdArray as B;
    use burn::prelude::*;

    /// Traces two Conv2d configurations and pins their output shapes.
    ///
    /// Previously this test only `eprintln!`-ed the shapes, so it could never
    /// fail on a regression; it now asserts the expected dimensions, which
    /// follow from the standard conv formula `(in + 2*pad - kernel)/stride + 1`.
    ///
    /// The second conv emulates explicit padding manually: zero tensors are
    /// concatenated on both sides of the width axis, then a Valid-padded
    /// (no-padding) conv is applied.
    #[test]
    fn test_conv2d_basic() {
        let device = burn::backend::ndarray::NdArrayDevice::Cpu;
        use burn::nn::conv::Conv2dConfig;

        // "Same"-width conv: kernel (1,3), stride 1, padding (0,1).
        let conv = Conv2dConfig::new([1, 16], [1, 3])
            .with_stride([1, 1])
            .with_padding(burn::nn::PaddingConfig2d::Explicit(0, 1))
            .with_bias(true)
            .init::<B>(&device);
        let x = Tensor::<B, 4>::ones([1, 1, 8, 40], &device);
        eprintln!("Simple Conv2d input: {:?}", x.dims());
        let y = conv.forward(x);
        eprintln!("Simple Conv2d output: {:?}", y.dims());
        // Width preserved: (40 + 2*1 - 3) / 1 + 1 = 40; height untouched.
        assert_eq!(y.dims(), [1, 16, 8, 40]);

        // Strided conv with manual zero padding (9 columns on each side).
        let conv2 = Conv2dConfig::new([1, 16], [1, 19])
            .with_stride([1, 10])
            .with_padding(burn::nn::PaddingConfig2d::Valid)
            .with_bias(true)
            .init::<B>(&device);
        let x2 = Tensor::<B, 4>::ones([1, 1, 8, 40], &device);
        let pad_left = Tensor::<B, 4>::zeros([1, 1, 8, 9], &device);
        let pad_right = Tensor::<B, 4>::zeros([1, 1, 8, 9], &device);
        let x2_padded = Tensor::cat(vec![pad_left, x2, pad_right], 3);
        eprintln!("Manual-padded Conv2d input: {:?}", x2_padded.dims());
        let y2 = conv2.forward(x2_padded);
        eprintln!("Manual-padded Conv2d output: {:?}", y2.dims());
        // Padded width 58; valid conv: (58 - 19) / 10 + 1 = 4 output columns.
        assert_eq!(y2.dims(), [1, 16, 8, 4]);
    }

    /// Smoke-tests the patch-embed network forward pass on a small input.
    ///
    /// The output shape depends on `PatchEmbedNetwork`'s internals, so this
    /// remains a trace-only test (it fails only on a panic in `forward`).
    #[test]
    fn test_patch_embed_only() {
        let device = burn::backend::ndarray::NdArrayDevice::Cpu;
        // NOTE(review): the constructor arguments (64, 40) look like an embed
        // dim and a patch/window size — confirm against
        // `model::patch_embed::PatchEmbedNetwork::new`.
        let pe = crate::model::patch_embed::PatchEmbedNetwork::<B>::new(64, 40, &device);

        let x = Tensor::<B, 3>::ones([1, 4, 80], &device).mul_scalar(0.1f32);
        eprintln!("PatchEmbed input: {:?}", x.dims());
        let y = pe.forward(x);
        eprintln!("PatchEmbed output: {:?}", y.dims());
    }
}