# luna-rs

**LUNA** (Latent Unified Network Architecture) EEG Foundation Model — inference in Rust with [Burn ML](https://burn.dev).

A pure-Rust implementation of the [LUNA](https://huggingface.co/thorir/LUNA) model from [BioFoundation](https://github.com/pulp-bio/BioFoundation) (ETH Zurich), a topology-agnostic EEG foundation model that uses cross-attention with learned queries to handle variable-channel EEG recordings.

Weights are downloaded automatically from HuggingFace. Numerical parity with the Python implementation is verified to **RMSE 0.000002** (Pearson r = 1.000000).

## Architecture

LUNA's key innovation is **channel unification via cross-attention**: regardless of whether the input has 20, 22, or 62 EEG channels, it compresses them into a fixed number of learned queries per time patch.

```
EEG signal (B, C, T)
    ├─→ PatchEmbedNetwork (3-layer CNN)  ──┐
    │                                       ├─→ sum → (B, C×S, D)
    └─→ FrequencyFeatureEmbedder (FFT+MLP)─┘
                              + NeRF positional encoding of 3D electrode locations
                              + channel location MLP
                              + mask tokens (if pre-training)
                              rearrange: (B, C×S, D) → (B×S, C, D)
                              CrossAttentionBlock
                              Q learned queries attend to C channels
                              → FFN → 3-layer query self-attention
                              (B×S, Q, D) → reshape → (B, S, Q×D)
                              N × RotaryTransformerBlock (RoPE self-attention + FFN)
                              LayerNorm → (B, S, Q×D)
                    ┌───────────────────────┴───────────────────────┐
                    │ Reconstruction (pretrain)                      │ Classification (finetune)
                    │                                                │
        TransformerDecoderLayer                           Learned aggregation query
        (channel queries reconstruct patches)             → cross-attention → MLP
                    │                                                │
              (B, C, T) signal                              (B, num_classes) logits
```
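
The shape bookkeeping in the diagram can be traced in plain Rust. This is a standalone illustration of the tensor shapes, not the crate's API; the 40-sample patch length is inferred from the Quick Start output (1280 samples / 32 time patches), so treat it as an assumption:

```rust
/// Trace tensor shapes through LUNA's channel-unification pipeline.
/// Standalone sketch — shapes follow the diagram above, not the crate's API.
fn shape_flow(b: usize, c: usize, t: usize, patch_len: usize, q: usize, d: usize)
    -> Vec<(&'static str, Vec<usize>)>
{
    let s = t / patch_len; // number of time patches per channel
    vec![
        ("input signal", vec![b, c, t]),
        ("patch + frequency embeddings (summed)", vec![b, c * s, d]),
        ("rearranged for cross-attention", vec![b * s, c, d]),
        ("Q learned queries after attending to C channels", vec![b * s, q, d]),
        ("reshaped for the transformer trunk", vec![b, s, q * d]),
    ]
}

fn main() {
    // LUNA-Base (Q=4, D=64) on 22 channels × 1280 samples, assuming 40-sample patches.
    for (name, shape) in shape_flow(1, 22, 1280, 40, 4, 64) {
        println!("{name}: {shape:?}");
    }
}
```

Whatever the channel count C, the trunk always sees width Q×D, which is what makes the model topology-agnostic.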

### Model Variants

| Variant    | Params | Layers | Queries (Q) | embed_dim (D) | Q×D |
|------------|--------|--------|-------------|---------------|-----|
| LUNA-Base  | 7M     | 8      | 4           | 64            | 256 |
| LUNA-Large | 43M    | 10     | 6           | 96            | 576 |
| LUNA-Huge  | 311M   | 24     | 8           | 128           | 1024 |

Weights hosted at [`thorir/LUNA`](https://huggingface.co/thorir/LUNA) on HuggingFace.

---

## Benchmarks

Inference benchmarks across two platforms: **Linux aarch64 VM** (16C/16T, 46GB RAM, Virtio/Vulkan GPU) and **Apple M3 Max** (12C/16T, 48GB RAM, Metal GPU). All runs use 22 EEG channels × 1280 samples (5s @ 256Hz), 3 warmup + 10 timed runs.
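
The warmup-then-timed-runs protocol is easy to reproduce. A minimal standalone sketch (illustrative only, not the `bench.sh` implementation; the closure here is a stand-in workload, not a LUNA forward pass):

```rust
use std::time::Instant;

/// Time `f` with `warmup` untimed calls followed by `runs` timed calls,
/// returning per-run latencies in milliseconds.
fn bench<F: FnMut()>(mut f: F, warmup: usize, runs: usize) -> Vec<f64> {
    for _ in 0..warmup {
        f();
    }
    (0..runs)
        .map(|_| {
            let t0 = Instant::now();
            f();
            t0.elapsed().as_secs_f64() * 1e3
        })
        .collect()
}

/// Mean and sample standard deviation of the timed runs.
fn summarize(ms: &[f64]) -> (f64, f64) {
    let mean = ms.iter().sum::<f64>() / ms.len() as f64;
    let var = ms.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (ms.len() - 1) as f64;
    (mean, var.sqrt())
}

fn main() {
    // Stand-in workload; the real benchmark times a model forward pass.
    let latencies = bench(|| { std::hint::black_box((0..10_000u64).sum::<u64>()); }, 3, 10);
    let (mean, std) = summarize(&latencies);
    println!("{mean:.3} ms ± {std:.3} ms over {} runs", latencies.len());
}
```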

### Inference Latency

![Inference Latency](./figures/inference_latency.png)

| Variant | Linux CPU | Linux GPU (Vulkan) | M3 Max CPU (Accelerate) | M3 Max GPU (Metal) |
|---------|-----------|-------------------|------------------------|-------------------|
| **Base** (7M) | 82.3 ms | 226.1 ms | 26.5 ms | **13.2 ms** |
| **Large** (43M) | 181.1 ms | 328.2 ms | 64.2 ms | **13.0 ms** |
| **Huge** (311M) | 2550.7 ms | 771.2 ms | 602.7 ms | **23.6 ms** |

### Speedup vs Linux CPU Baseline

![Speedup](./figures/speedup.png)

| Variant | M3 Max CPU | M3 Max GPU (Metal) |
|---------|-----------|-------------------|
| **Base** | 3.1× | **6.2×** |
| **Large** | 2.8× | **13.9×** |
| **Huge** | 4.2× | **108.1×** |

### Model Load Time

![Load Time](./figures/load_time.png)

### Latency Distribution

![Latency Distribution](./figures/latency_distribution.png)

### Channel Scaling

![Channel Scaling](./figures/channel_scaling.png)

M3 Max Metal GPU latency is nearly flat across channel counts (12–25 ms whether the input has 4 or 32 channels), suggesting that at these sizes latency is dominated by fixed per-dispatch overhead rather than by the amount of data moved or computed.

### Run Benchmarks

```sh
# All variants, CPU vs GPU
./bench.sh base,large,huge

# Custom warmup/runs
./bench.sh base,large,huge 5 20
```

---

## Quick Start

```sh
# Download weights and run reconstruction on synthetic EEG
cargo run --example reconstruct --release --features hf-download -- -v
```

Output:
```
▸ Input: 22 channels × 1280 samples (5s @ 256Hz)
▸ Forward pass: 83 ms

▸ Outputs:
  x_reconstructed: [1, 22, 1280]
  attention_scores: [32, 4, 22]

▸ Query → Channel attention (first time patch):
    Q0: top-3 = P3-O1(0.565), P4-O2(0.193), T3-C3(0.177)
    Q1: top-3 = CZ-C4(0.242), C3-CZ(0.239), F3-C3(0.212)
    Q2: top-3 = C4-P4(0.371), T4-A2(0.336), A1-T3(0.126)
    Q3: top-3 = F7-T3(0.454), FP2-F8(0.231), T4-T6(0.112)
```

---

## Build

```sh
# CPU (default — Rayon multi-threading + SIMD)
cargo build --release

# CPU — macOS with Apple Accelerate BLAS
cargo build --release --features blas-accelerate

# CPU — Linux with OpenBLAS
cargo build --release --features openblas-system

# GPU — cross-platform WGSL shaders (Metal on macOS, Vulkan on Linux, DX12 on Windows)
cargo build --release --no-default-features --features wgpu

# GPU — macOS: native Metal shaders (MSL) — fastest on Apple Silicon
cargo build --release --no-default-features --features metal

# GPU — Linux/Windows: native Vulkan shaders (SPIR-V) — fastest on NVIDIA/AMD
cargo build --release --no-default-features --features vulkan
```

### GPU Backend Details

| Platform | Runtime | Shader pipeline | Feature flag |
|----------|---------|----------------|--------------|
| macOS    | Metal   | WGSL (generic) | `--features wgpu` |
| macOS    | Metal   | MSL (native, faster) | `--features metal` |
| Linux    | Vulkan  | WGSL (generic) | `--features wgpu` |
| Linux    | Vulkan  | SPIR-V (native, faster) | `--features vulkan` |
| Windows  | Vulkan/DX12 | WGSL (generic) | `--features wgpu` |
| Windows  | Vulkan  | SPIR-V (native, faster) | `--features vulkan` |

---

## API

### High-level: `LunaEncoder`

```rust
use luna_rs::{LunaEncoder, EncodingResult, build_batch_named, TUEG_CHANNELS};
use std::path::Path;

// Load model
let (encoder, _ms) = LunaEncoder::<B>::load(
    Path::new("config.json"),
    Path::new("model.safetensors"),
    device,
)?;

// Build input from channel names (auto-resolves positions + vocab indices)
let batch = build_batch_named::<B>(signal_vec, TUEG_CHANNELS, 1280, &device);

// Run inference
let result = encoder.run_batch(&batch)?;
println!("Output shape: {:?}", result.shape);

// Save / load results
result.save_safetensors("output.safetensors")?;
let loaded = EncodingResult::load_safetensors("output.safetensors")?;
```

### Low-level: direct model access

```rust
use luna_rs::model::luna::{Luna, LunaOutput};
use luna_rs::model::rope::RotaryEmbedding;

let model = luna_rs::weights::load_model::<B>(&cfg, "weights.safetensors", 90, &device)?; // 90 = channel vocabulary size
let rope = RotaryEmbedding::new(head_dim, 1024, 10_000.0, &device);

let output = model.forward(signal, channel_locations, None, Some(channel_names), &rope);

match output {
    LunaOutput::Reconstruction { x_reconstructed, x_original, attention_scores } => { ... }
    LunaOutput::Classification { logits, x_original } => { ... }
}
```
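
RoPE applies a position-dependent rotation to each 2-D slice of the query/key vectors before attention. A minimal standalone sketch of that rotation (illustrative only — the crate's `RotaryEmbedding` precomputes these rotations, and the interleaved pairing convention used here is one common choice that may differ from the crate's):

```rust
/// Rotate each consecutive (even, odd) pair of `x` by the RoPE angle for `pos`.
/// theta_i = base^(-2i/d), the standard RoPE frequency schedule.
fn apply_rope(x: &[f32], pos: usize, base: f32) -> Vec<f32> {
    let d = x.len();
    let mut out = vec![0.0; d];
    for i in 0..d / 2 {
        let theta = base.powf(-((2 * i) as f32) / d as f32);
        let (sin, cos) = (pos as f32 * theta).sin_cos();
        let (a, b) = (x[2 * i], x[2 * i + 1]);
        out[2 * i] = a * cos - b * sin;     // 2-D rotation of the pair
        out[2 * i + 1] = a * sin + b * cos;
    }
    out
}

fn main() {
    let q = [1.0, 0.0, 1.0, 0.0];
    // Position 0 is the identity rotation; later positions rotate each pair.
    println!("{:?}", apply_rope(&q, 0, 10_000.0));
    println!("{:?}", apply_rope(&q, 5, 10_000.0));
}
```

Because each pair is rotated (not scaled), vector norms are preserved and relative positions appear as angle differences in the attention dot products.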

### CSV input

```rust
use luna_rs::load_from_csv;

let (batches, info) = load_from_csv::<B>(Path::new("recording.csv"), 256.0, 1280, &device)?;
println!("{} epochs from {} channels", info.n_epochs, info.ch_names.len());
```

---

## Examples

All examples auto-download LUNA-Base weights from HuggingFace.

| Example | What it demonstrates | Command |
|---------|---------------------|---------|
| **`load_and_inspect`** | Download weights, print architecture summary and parameter breakdown | `cargo run --example load_and_inspect --release --features hf-download` |
| **`reconstruct`** | Full reconstruction forward pass, per-channel RMSE, query→channel attention patterns | `cargo run --example reconstruct --release --features hf-download -- -v` |
| **`channel_invariance`** | Same model on 4 different channel counts (8, 10, 16, 22) — all work | `cargo run --example channel_invariance --release --features hf-download` |
| **`benchmark`** | Inference latency, channel-scaling benchmark (4→32 channels) | `cargo run --example benchmark --release --features hf-download` |
| **`embed`** | High-level `LunaEncoder` API, multi-epoch processing, save to safetensors | `cargo run --example embed --release --features hf-download -- -v` |

Use `--variant large` or `--variant huge` to switch model sizes.

---

## Binaries

| Binary | Purpose | Command |
|--------|---------|---------|
| **`infer`** | Run inference on dummy input, print timing | `cargo run --release -- --weights W --config C --output O` |
| **`download_weights`** | Download weights from HuggingFace | `cargo run --bin download_weights --release --features hf-download -- --variant base` |

---

## Python Parity

Numerically verified against the Python [BioFoundation](https://github.com/pulp-bio/BioFoundation) LUNA implementation. Test vectors are exported from Python with `mask=None` (inference mode) and compared in Rust with strict assertions.

### Per-component accuracy

| Component | Max error | Test file |
|-----------|-----------|-----------|
| `PatchEmbedNetwork` (3-layer CNN) | 0.000008 | `intermediate_parity.rs` |
| `FrequencyFeatureEmbedder` (rustfft f64 + MLP) | 0.000055 | `intermediate_parity.rs` |
| `nerf_positional_encoding` | 0.000000 | `intermediate_parity.rs` |
| `channel_location_embedder` (MLP) | 0.000001 | `intermediate_parity.rs` |
| `CrossAttentionBlock` output | 0.000019 | `intermediate_parity.rs` |
| `CrossAttentionBlock` attention scores | 0.000005 | `intermediate_parity.rs` |
| Transformer blocks 0–7 (each) | ≤ 0.000008 | `block_parity.rs` |
| `ReconstructionHead` (TransformerDecoder) | 0.000003 | `decoder_parity.rs` |
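
The `nerf_positional_encoding` row refers to the NeRF-style frequency encoding of 3-D electrode coordinates: each coordinate p is expanded into sin(2^k π p), cos(2^k π p) for k = 0..L. A standalone sketch (the band count `l` and the example coordinates here are illustrative, not the crate's values):

```rust
use std::f64::consts::PI;

/// NeRF-style positional encoding of a 3-D point: for each coordinate,
/// emit sin(2^k * pi * p) and cos(2^k * pi * p) for k = 0..l.
fn nerf_encode(xyz: [f64; 3], l: u32) -> Vec<f64> {
    let mut out = Vec::with_capacity(3 * 2 * l as usize);
    for p in xyz {
        for k in 0..l {
            let arg = (1u64 << k) as f64 * PI * p;
            out.push(arg.sin());
            out.push(arg.cos());
        }
    }
    out
}

fn main() {
    // Hypothetical normalized electrode position, 4 frequency bands.
    let enc = nerf_encode([0.1, -0.3, 0.8], 4);
    println!("encoding length = {}", enc.len()); // 3 coords × 2 fns × 4 bands = 24
}
```

Since the encoding is a fixed deterministic function with no learned weights, its 0.000000 parity above is expected.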

### End-to-end accuracy

| Metric | Value |
|--------|-------|
| **RMSE** | **0.000002** |
| **Max absolute error** | **0.000046** |
| **Relative RMSE** | **0.000005 (0.0005%)** |
| **Pearson correlation** | **1.000000** |

### Reproducing parity tests

```sh
# 1. Export Python reference vectors (requires PyTorch + BioFoundation repo)
python scripts/export_parity_vectors.py
python scripts/export_intermediates.py

# 2. Run all 24 tests
cargo test --release
```

### What enables exact parity

| Technique | Why it matters |
|-----------|---------------|
| `rustfft` in **f64** for FFT | Matches `torch.fft.rfft`'s internal f64 promotion on CPU |
| `f32::atan2` on CPU | Bit-identical to PyTorch's `torch.angle()` (same libc `atan2f`) |
| `FusedMultiheadAttention` with single `in_proj` Linear | Matches `nn.MultiheadAttention`'s fused `in_proj_weight [3D, D]` layout |
| `TransformerEncoderLayer` with `norm_first` | Matches `nn.TransformerEncoderLayer(norm_first=True)` structure |
| 3-sublayer `TransformerDecoderLayer` | Self-attn → cross-attn → FFN, matches `nn.TransformerDecoderLayer(norm_first=True)` |
| `mask=None` at inference | Avoids Python's training-time `randn * 0.02` noise on channel locations |
| Correct `(D E)` flatten in `PatchEmbedNetwork` | Matches `einops.rearrange('B E CS D -> B CS (D E)')` — D-inner, E-outer |
| `repeat_dim(0, n)` for channel embeddings | Matches PyTorch `.repeat(n, 1, 1)` tile semantics |
| DC/Nyquist bin clamping in FFT | Forces `imag=0` at k=0 and k=N/2, matching `rfft` guarantees |
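
The DC/Nyquist clamping row can be illustrated with a toy real-input DFT in f64 (the crate uses `rustfft`, but the invariant is the same: for real input of even length N, bins k = 0 and k = N/2 are mathematically real, so their imaginary parts are forced to exactly 0 rather than left to accumulated rounding error, matching `torch.fft.rfft`):

```rust
use std::f64::consts::PI;

/// Naive real-input DFT returning (re, im) for bins 0..=n/2, with the
/// DC (k = 0) and Nyquist (k = n/2) imaginary parts clamped to 0.
fn rdft_clamped(x: &[f64]) -> Vec<(f64, f64)> {
    let n = x.len();
    (0..=n / 2)
        .map(|k| {
            let (mut re, mut im) = (0.0, 0.0);
            for (t, &v) in x.iter().enumerate() {
                let arg = -2.0 * PI * (k * t) as f64 / n as f64;
                re += v * arg.cos();
                im += v * arg.sin();
            }
            // Clamp: these bins are purely real for real input.
            if k == 0 || (n % 2 == 0 && k == n / 2) {
                im = 0.0;
            }
            (re, im)
        })
        .collect()
}

fn main() {
    for (k, (re, im)) in rdft_clamped(&[1.0, 2.0, 3.0, 4.0]).iter().enumerate() {
        println!("bin {k}: {re:.6} + {im:.6}i");
    }
}
```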

---

## Test Suite

24 tests across 8 test files, all passing with zero warnings.

| File | Tests | What it verifies |
|------|-------|------------------|
| `tests/python_parity.rs` | 1 | End-to-end: RMSE < 0.0001, correlation > 0.9999 |
| `tests/intermediate_parity.rs` | 1 | Per-component: patch, freq, nerf, loc, cross-attn (all < 0.000055) |
| `tests/block_parity.rs` | 1 | Per-transformer-block: 8 blocks + norm (all < 0.000008) |
| `tests/decoder_parity.rs` | 1 | Decoder head in isolation (max_err = 0.000003) |
| `tests/f64_parity.rs` | 1 | f64 backend gives same parity (RMSE = 0.000002) |
| `tests/forward_pass.rs` | 4 | Output shapes, value ranges, variable channels (4–29), channel vocab |
| `src/lib.rs` (unit) | 15 | Channel vocab (7), positions (3), CSV (2), conv2d (1), patch_embed (1), repeat_dim (1) |

---

## Project Structure

```
luna-rs/
├── src/
│   ├── lib.rs                  # Public API, re-exports
│   ├── config.rs               # ModelConfig, DataConfig
│   ├── data.rs                 # InputBatch, build_batch, build_batch_named, channel_wise_normalize
│   ├── encoder.rs              # LunaEncoder (high-level API), EncodingResult (save/load safetensors)
│   ├── weights.rs              # WeightMap, load_model (safetensors → Burn tensors)
│   ├── channel_positions.rs    # 6 embedded ELC montage files, bipolar_channel_xyz
│   ├── channel_vocab.rs        # 90-channel vocabulary (TUEG + Siena + SEED)
│   ├── csv_loader.rs           # load_from_csv (CSV → InputBatch epochs)
│   ├── model/
│   │   ├── luna.rs             # Full LUNA model, nerf_positional_encoding, LunaOutput enum
│   │   ├── patch_embed.rs      # PatchEmbedNetwork (3-layer CNN)
│   │   ├── freq_embed.rs       # FrequencyFeatureEmbedder (rustfft f64 + MLP)
│   │   ├── cross_attention.rs  # CrossAttentionBlock, FusedMultiheadAttention, TransformerEncoderLayer
│   │   ├── attention.rs        # RotarySelfAttention (1-D RoPE)
│   │   ├── encoder_block.rs    # RotaryEncoderBlock (norm → attn → norm → FFN)
│   │   ├── feedforward.rs      # FeedForward (fc1 → GELU → LayerNorm → fc2)
│   │   ├── rope.rs             # RotaryEmbedding (precomputed rotation matrices)
│   │   ├── norm.rs             # LunaLayerNorm wrapper
│   │   ├── reconstruction_head.rs  # PatchReconstructionHead (TransformerDecoderLayer + MLP)
│   │   └── classification_head.rs  # ClassificationHead (aggregation query + MLP)
│   ├── bin/
│   │   ├── infer.rs            # CLI inference binary
│   │   └── download_weights.rs # HuggingFace weight downloader
│   └── montages/               # 6 ASA .elc montage files (standard_1005, 1020, etc.)
├── examples/
│   ├── common/mod.rs           # Shared utilities, HF weight resolution, synthetic EEG generation
│   ├── load_and_inspect.rs     # Architecture inspection
│   ├── reconstruct.rs          # Masked reconstruction with attention analysis
│   ├── channel_invariance.rs   # Variable channel count demonstration
│   ├── benchmark.rs            # Latency benchmarking
│   └── embed.rs                # High-level embedding extraction
├── tests/
│   ├── python_parity.rs        # End-to-end numerical parity (RMSE = 0.000002)
│   ├── intermediate_parity.rs  # Per-component numerical parity
│   ├── block_parity.rs         # Per-transformer-block parity
│   ├── decoder_parity.rs       # Decoder head parity
│   ├── f64_parity.rs           # f64 backend parity
│   ├── forward_pass.rs         # Integration tests with real weights
│   └── vectors/                # Exported Python reference tensors (safetensors)
├── scripts/
│   ├── export_parity_vectors.py     # Export Python LUNA output for Rust comparison
│   └── export_intermediates.py      # Export per-component intermediate tensors
├── Cargo.toml
├── README.md
└── PLAN.md                     # Development roadmap
```

---

## Dependencies

### Core (always compiled)
- [`burn`](https://burn.dev) 0.20.1 — ML framework (tensor ops, nn modules)
- [`rustfft`](https://crates.io/crates/rustfft) 6 — FFT for the frequency embedder (exact parity with `torch.fft.rfft`)
- [`exg`](https://github.com/eugenehp/exg) — EEG preprocessing (FIF/EDF reader, filtering, resampling, montage)
- `safetensors` — weight loading and result I/O
- `serde` + `serde_json` — config parsing
- `half` — bf16→f32 weight conversion
- `anyhow` — error handling

### Optional
- `burn-ndarray` — CPU backend (default)
- `burn-wgpu` — GPU backend
- `hf-hub` — HuggingFace weight download (`--features hf-download`)
- `clap` — CLI argument parsing (binaries only)

---

## Citation

If you use LUNA, please cite the original paper:

```bibtex
@inproceedings{
  doner2025luna,
  title={{LUNA}: Efficient and Topology-Agnostic Foundation Model for {EEG} Signal Analysis},
  author={Berkay D{\"o}ner and Thorir Mar Ingolfsson and Luca Benini and Yawei Li},
  booktitle={The Thirty-ninth Annual Conference on Neural Information Processing Systems},
  year={2025},
  url={https://openreview.net/forum?id=uazfjnFL0G}
}
```

## License

Apache-2.0