// luna_rs/lib.rs
1//! # luna-rs — LUNA EEG Foundation Model inference in Rust
2//!
3//! Pure-Rust inference for the LUNA (Latent Unified Network Architecture)
4//! EEG foundation model, built on [Burn 0.20](https://burn.dev).
5//!
6//! LUNA is a topology-agnostic EEG model that uses cross-attention with
7//! learned queries to compress variable-channel inputs into a fixed-size
8//! latent space, then processes them with a Rotary Transformer encoder.
9//!
10//! ## Quick start
11//!
12//! ```rust,ignore
13//! use luna_rs::LunaEncoder;
14//!
15//! let (model, _ms) = LunaEncoder::<B>::load(
16//!     Path::new("config.json"),
17//!     Path::new("model.safetensors"),
18//!     device,
19//! )?;
20//! ```
21
// Public modules making up the LUNA inference pipeline.
pub mod channel_positions; // electrode 3-D coordinates, montage layouts, nearest-channel lookup
pub mod channel_vocab;     // channel-name vocabulary (CHANNEL_VOCAB) and index lookups
pub mod config;            // model/data configuration (ModelConfig, DataConfig)
pub mod csv_loader;        // CSV signal loading (load_from_csv, CsvInfo)
pub mod data;              // input batch construction (InputBatch, build_batch_named)
pub mod encoder;           // LunaEncoder: top-level model loading and inference entry point
pub mod model;             // network building blocks (e.g. patch_embed::PatchEmbedNetwork)
pub mod preprocessing;     // EDF/FIF/CSV loading and preprocessing helpers
pub mod quantize;          // quantization utilities — presumably for weights/activations; confirm
pub mod weights;           // weight loading — presumably the safetensors reader used by load(); confirm
32
33// Flat re-exports
34pub use encoder::{LunaEncoder, EpochEmbedding, EncodingResult};
35pub use config::{ModelConfig, DataConfig};
36pub use data::{InputBatch, FifInfo, build_batch_named};
37pub use channel_positions::{channel_xyz, bipolar_channel_xyz, MontageLayout, montage_channels, nearest_channel, normalise};
38pub use channel_vocab::{CHANNEL_VOCAB, VOCAB_SIZE, channel_index, channel_indices, channel_indices_unwrap, TUEG_CHANNELS, SIENA_CHANNELS, SEED_CHANNELS};
39pub use csv_loader::{load_from_csv, CsvInfo};
40pub use preprocessing::{load_edf, load_fif, load_luna_epochs, load_csv_and_preprocess, PreprocInfo};
41
#[cfg(test)]
mod repeat_dim_test {
    use burn::backend::NdArray as B;
    use burn::prelude::*;

    /// Verifies that Burn's `repeat_dim` tiles whole blocks along the chosen
    /// dimension, matching the semantics of PyTorch's `Tensor.repeat()`.
    #[test]
    fn test_repeat_dim_matches_pytorch() {
        let device = burn::backend::ndarray::NdArrayDevice::Cpu;

        // Build a [2, 2, 2] tensor holding the values 1..=8.
        let source: Vec<f32> = (1..=8).map(|v| v as f32).collect();
        let tensor = Tensor::<B, 3>::from_data(TensorData::new(source, vec![2, 2, 2]), &device);

        // Tile the batch dimension three times: [2, 2, 2] -> [6, 2, 2].
        let tiled = tensor.repeat_dim(0, 3);
        assert_eq!(tiled.dims(), [6, 2, 2]);

        // PyTorch .repeat(3, 1, 1) lays the copies out as [b0, b1, b0, b1, b0, b1],
        // i.e. the full 8-element block repeated three times.
        let got = tiled.into_data().to_vec::<f32>().unwrap();
        let want: Vec<f32> = (0..3).flat_map(|_| (1..=8).map(|v| v as f32)).collect();
        assert_eq!(got, want, "repeat_dim should match PyTorch .repeat()");
    }
}
63
#[cfg(test)]
mod trace_forward_test {
    use burn::backend::NdArray as B;
    use burn::prelude::*;

    /// Smoke-tests Conv2d shape handling on the NdArray backend: first a
    /// plain padded convolution, then the stride-10 configuration with its
    /// padding applied by hand (Valid conv over a pre-padded input).
    #[test]
    fn test_conv2d_basic() {
        use burn::nn::conv::Conv2dConfig;
        let device = burn::backend::ndarray::NdArrayDevice::Cpu;

        // Straightforward configuration: 1x3 kernel, explicit (0, 1) padding.
        let simple = Conv2dConfig::new([1, 16], [1, 3])
            .with_stride([1, 1])
            .with_padding(burn::nn::PaddingConfig2d::Explicit(0, 1))
            .with_bias(true)
            .init::<B>(&device);
        let input = Tensor::<B, 4>::ones([1, 1, 8, 40], &device);
        eprintln!("Simple Conv2d input: {:?}", input.dims());
        let out = simple.forward(input);
        eprintln!("Simple Conv2d output: {:?}", out.dims());

        // The problematic configuration — emulate
        // Conv2d(1, 16, (1, 19), stride=(1, 10), padding=(0, 9))
        // by zero-padding the W axis manually and running a no-padding conv.
        let strided = Conv2dConfig::new([1, 16], [1, 19])
            .with_stride([1, 10])
            .with_padding(burn::nn::PaddingConfig2d::Valid)
            .with_bias(true)
            .init::<B>(&device);
        let raw = Tensor::<B, 4>::ones([1, 1, 8, 40], &device);
        // Same 9-wide zero slab on each side of the W dimension.
        let pad = || Tensor::<B, 4>::zeros([1, 1, 8, 9], &device);
        let padded = Tensor::cat(vec![pad(), raw, pad()], 3); // [1, 1, 8, 58]
        eprintln!("Manual-padded Conv2d input: {:?}", padded.dims());
        let out2 = strided.forward(padded);
        eprintln!("Manual-padded Conv2d output: {:?}", out2.dims());
    }

    /// Runs only the patch-embedding network forward pass to confirm it
    /// accepts a rank-3 input of shape [1, 4, 80] and produces an output.
    #[test]
    fn test_patch_embed_only() {
        let device = burn::backend::ndarray::NdArrayDevice::Cpu;
        let embed = crate::model::patch_embed::PatchEmbedNetwork::<B>::new(64, 40, &device);

        let input = Tensor::<B, 3>::ones([1, 4, 80], &device).mul_scalar(0.1f32);
        eprintln!("PatchEmbed input: {:?}", input.dims());
        let output = embed.forward(input);
        eprintln!("PatchEmbed output: {:?}", output.dims());
    }
}
113}