seizuretransformer 0.0.1

SeizureTransformer EEG model in Rust (Burn + wgpu)

seizuretransformer-rs

Crate name on crates.io: seizuretransformer.

Rust + Burn port of the time-step SeizureTransformer model from SeizureTransformer/time_step_level/model.py.

What is included

  • U-shaped CNN encoder/decoder
  • Residual CNN stack
  • Transformer encoder (fused-QKV MHA, post-norm)
  • Time-step sigmoid output [B, T]
  • Optional weight loading from .safetensors
  • CPU (ndarray) and GPU (wgpu) backends
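The model emits a per-time-step probability in [0, 1] for each of the T input samples; turning those probabilities into discrete seizure events is left to the caller. A minimal, hypothetical post-processing sketch (the threshold and minimum run length below are illustrative defaults, not values prescribed by the model or the crate):

```python
def probs_to_events(probs, threshold=0.8, min_len=3):
    """Convert per-time-step seizure probabilities into (start, end) index
    pairs. `threshold` and `min_len` are illustrative, not prescribed by
    the model; tune them for your data."""
    events, start = [], None
    for i, p in enumerate(probs):
        if p >= threshold and start is None:
            start = i  # run of supra-threshold probabilities begins
        elif p < threshold and start is not None:
            if i - start >= min_len:  # drop runs shorter than min_len
                events.append((start, i))
            start = None
    if start is not None and len(probs) - start >= min_len:
        events.append((start, len(probs)))  # run extends to the end
    return events
```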

Build

# CPU
cargo build --release

# GPU (wgpu)
cargo build --release --no-default-features --features wgpu

# GPU Metal (macOS)
cargo build --release --no-default-features --features metal
# or alias
cargo build --release --no-default-features --features wgpu-metal

# GPU Vulkan (Linux/Windows)
cargo build --release --no-default-features --features vulkan
# or alias
cargo build --release --no-default-features --features wgpu-vulkan

Run

# random-init model + dummy input
cargo run --release --bin st-infer -- --batch-size 1

# with config + weights
cargo run --release --bin st-infer -- \
  --config ./config.json \
  --weights ./model.safetensors \
  --batch-size 1

Config JSON example

{
  "in_channels": 18,
  "in_samples": 15360,
  "dim_feedforward": 2048,
  "num_layers": 8,
  "num_heads": 4,
  "drop_rate": 0.1,
  "max_pos_len": 6000
}
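The field names below mirror the example config above. A small stdlib-only validation sketch; the type checks are illustrative and not something the crate itself enforces:

```python
import json

# Field names and types taken from the example config; purely illustrative.
REQUIRED = {
    "in_channels": int,
    "in_samples": int,
    "dim_feedforward": int,
    "num_layers": int,
    "num_heads": int,
    "drop_rate": float,
    "max_pos_len": int,
}

def load_config(path):
    """Load and sanity-check a config.json before handing it to st-infer."""
    with open(path) as f:
        cfg = json.load(f)
    for key, ty in REQUIRED.items():
        if key not in cfg:
            raise ValueError(f"missing config key: {key}")
        if not isinstance(cfg[key], ty):
            raise TypeError(
                f"{key}: expected {ty.__name__}, got {type(cfg[key]).__name__}"
            )
    return cfg
```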

Export PyTorch .pth weights to safetensors

Use the helper script:

python3 scripts/export_weights_to_safetensors.py \
  --pth ../SeizureTransformer/time_step_level/ckp/model.pth \
  --out ./model.safetensors

This exports every floating-point tensor from the PyTorch state dict, keeping the original key names.

Bulk conversion (all .pth/.pt in data/):

python3 scripts/convert_all_weights.py --data-dir data --recursive
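For reference, the .safetensors container these scripts produce is simple: an 8-byte little-endian header length, a JSON header mapping each tensor name to its dtype, shape, and byte offsets, then the raw tensor data. A stdlib-only sketch of a float32 writer (this is not the helper script, just an illustration of the layout):

```python
import json
import struct

def write_safetensors(path, tensors):
    """Write float32 tensors to a minimal .safetensors file.

    `tensors` maps name -> (shape tuple, flat list of floats). Layout:
    8-byte little-endian header length, JSON header with each tensor's
    dtype/shape/data_offsets, then the concatenated raw tensor bytes.
    """
    header, data, offset = {}, b"", 0
    for name, (shape, values) in tensors.items():
        raw = struct.pack(f"<{len(values)}f", *values)  # little-endian f32
        header[name] = {
            "dtype": "F32",
            "shape": list(shape),
            "data_offsets": [offset, offset + len(raw)],
        }
        offset += len(raw)
        data += raw
    head = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(head)))  # u64 header length
        f.write(head)
        f.write(data)
```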

Architecture & numerical parity status

For the public competition checkpoint extracted from yujjio/seizure_transformer (wu_2025/model.pth), this Rust port matches the Python time-step model architecture and outputs.

Implemented/model-matched path:

  • SeizureTransformer in time_step_level/model.py (time-step inference)
  • encoder + ResCNN + positional encoding + Transformer encoder + decoder + sigmoid head

Measured parity (same input tensor, same weights):

  • Rust CPU (NdArray) vs Python
    • MAE: 1.2e-9
    • RMSE: 2.0e-9
    • Max abs: 7.45e-8
    • Pearson: 1.0
  • Rust GPU (wgpu Metal) vs Python
    • MAE: 2.3e-9
    • RMSE: 4.2e-9
    • Max abs: 1.94e-7
    • Pearson: 1.0

Note: the parity claim above applies only to the implemented time-step inference path. Window-level variants and training-loop parity are out of scope for now.

Python ↔ Rust parity workflow

scripts/run_parity.sh \
  /Users/Shared/SeizureTransformer \
  /Users/Shared/SeizureTransformer/time_step_level/ckp/model.pth \
  ./config.json \
  ./parity

What this does:

  1. Converts .pth → .safetensors
  2. Runs Python model on deterministic random input, saves output_py_f32.bin
  3. Runs Rust model on the exact same input, saves output_rs_f32.bin
  4. Compares MAE / RMSE / max abs / Pearson and fails if drift exceeds the thresholds

Thresholds currently enforced in scripts/compare_parity.py:

  • max_abs <= 1e-5
  • rmse <= 1e-6
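The comparison step boils down to loading the two raw little-endian f32 dumps and computing the four statistics. A stdlib-only sketch of that computation (not the actual scripts/compare_parity.py):

```python
import math
import struct

def parity_metrics(path_a, path_b):
    """Compare two raw little-endian f32 dumps (e.g. output_py_f32.bin vs
    output_rs_f32.bin) and return (mae, rmse, max_abs, pearson)."""
    def load(path):
        with open(path, "rb") as f:
            raw = f.read()
        return struct.unpack(f"<{len(raw) // 4}f", raw)
    a, b = load(path_a), load(path_b)
    assert len(a) == len(b), "outputs differ in length"
    n = len(a)
    diffs = [x - y for x, y in zip(a, b)]
    mae = sum(abs(d) for d in diffs) / n
    rmse = math.sqrt(sum(d * d for d in diffs) / n)
    max_abs = max(abs(d) for d in diffs)
    # Pearson correlation between the two output vectors
    ma, mb = sum(a) / n, sum(b) / n
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    var_a = sum((x - ma) ** 2 for x in a)
    var_b = sum((y - mb) ** 2 for y in b)
    pearson = cov / math.sqrt(var_a * var_b)
    return mae, rmse, max_abs, pearson
```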

Benchmark backends

Run all main backends (CPU ndarray, CPU+Accelerate, GPU wgpu, GPU Metal):

scripts/bench_backends.sh data/model_config.json data/model.safetensors 1 2 20

Latest measured results (batch=1, warmup=2, iters=20, model=data/model.safetensors):

Backend                          Avg inference latency
Rust CPU (NdArray)               1311.3 ms
Rust CPU (NdArray + Accelerate)  964.3 ms
Rust GPU (wgpu WGSL)             89.5 ms
Rust GPU (wgpu Metal)            24.1 ms

Notes:

  • blas-accelerate improves CPU by ~1.36x vs plain NdArray.
  • wgpu-metal is the fastest path on macOS (~54x faster than plain NdArray CPU in this setup).