# seizuretransformer-rs

Crate name on crates.io: `seizuretransformer`.

Rust + Burn port of the time-step SeizureTransformer model from `SeizureTransformer/time_step_level/model.py`.
## What is included
- U-shaped CNN encoder/decoder
- Residual CNN stack
- Transformer encoder (fused-QKV MHA, post-norm)
- Time-step sigmoid output `[B, T]`
- Optional weight loading from `.safetensors`
- CPU (`ndarray`) and GPU (`wgpu`) backends
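The time-step head listed above maps per-time-step logits to seizure probabilities in `[0, 1]`. A minimal std-only sketch of that final step (the `sigmoid` and `time_step_probs` names are illustrative, not the crate's API):

```rust
/// Element-wise logistic sigmoid: maps a logit to a probability in (0, 1).
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// Apply the time-step head to a flattened [B, T] logit buffer,
/// yielding one probability per time step.
fn time_step_probs(logits: &[f32]) -> Vec<f32> {
    logits.iter().map(|&x| sigmoid(x)).collect()
}

fn main() {
    // batch = 1, T = 4 time steps of raw logits
    let logits = [0.0_f32, 2.0, -2.0, 5.0];
    let probs = time_step_probs(&logits);
    assert!((probs[0] - 0.5).abs() < 1e-6); // sigmoid(0) = 0.5
    assert!(probs.iter().all(|p| (0.0..=1.0).contains(p)));
    println!("{:?}", probs);
}
```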
## Build
Build with the Cargo feature for your target backend:

- CPU
- GPU (wgpu)
- GPU Metal (macOS), or its alias
- GPU Vulkan (Linux/Windows), or its alias
## Run

Two modes:

- random-init model + dummy input
- with config + weights
## Config JSON example
## Export PyTorch `.pth` weights to safetensors
Use the helper script:
This exports all floating-point tensors from the PyTorch state dict with their original key names.

Bulk conversion (all `.pth`/`.pt` files in `data/`):
## Architecture & numerical parity status
For the public competition checkpoint extracted from `yujjio/seizure_transformer` (`wu_2025/model.pth`), this Rust port matches the Python time-step model architecture and outputs.
Implemented/model-matched path:

- `SeizureTransformer` in `time_step_level/model.py` (time-step inference)
- encoder + ResCNN + positional encoding + Transformer encoder + decoder + sigmoid head
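For context on the positional-encoding stage in this pipeline, here is the standard sinusoidal formulation from the original Transformer paper as a std-only sketch. This is the textbook formula, not a claim about what `model.py` actually uses; check the Python source for the exact variant:

```rust
/// Standard sinusoidal positional encoding (Vaswani et al., 2017):
///   pe[t][2i]   = sin(t / 10000^(2i/d_model))
///   pe[t][2i+1] = cos(t / 10000^(2i/d_model))
fn sinusoidal_pe(t_len: usize, d_model: usize) -> Vec<Vec<f32>> {
    (0..t_len)
        .map(|t| {
            (0..d_model)
                .map(|j| {
                    let i = (j / 2) as f32; // pair index shared by sin/cos
                    let angle = t as f32 / 10000f32.powf(2.0 * i / d_model as f32);
                    if j % 2 == 0 { angle.sin() } else { angle.cos() }
                })
                .collect()
        })
        .collect()
}

fn main() {
    let pe = sinusoidal_pe(8, 4);
    assert_eq!(pe.len(), 8);
    assert!((pe[0][0] - 0.0).abs() < 1e-6); // sin(0) = 0
    assert!((pe[0][1] - 1.0).abs() < 1e-6); // cos(0) = 1
    println!("pe[1] = {:?}", pe[1]);
}
```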
Measured parity (same input tensor, same weights):

- Rust CPU (NdArray) vs Python
  - MAE: 1.2e-9
  - RMSE: 2.0e-9
  - Max abs: 7.45e-8
  - Pearson: 1.0
- Rust GPU (wgpu Metal) vs Python
  - MAE: 2.3e-9
  - RMSE: 4.2e-9
  - Max abs: 1.94e-7
  - Pearson: 1.0
Note: parity claim above is for the implemented time-step inference model path. Window-level variants and training loop parity are out of scope for now.
## Python ↔ Rust parity workflow
What this does:
- Converts `.pth` → `.safetensors`
- Runs the Python model on a deterministic random input, saves `output_py_f32.bin`
- Runs the Rust model on the exact same input, saves `output_rs_f32.bin`
- Compares MAE / RMSE / max abs / Pearson and fails if drift exceeds the thresholds
Thresholds currently enforced in `scripts/compare_parity.py`:

- `max_abs <= 1e-5`
- `rmse <= 1e-6`
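The comparison step can be sketched in Rust as follows. The Python script remains the authoritative check; `parity_metrics` is an illustrative name, and the inputs stand in for the two `*_f32.bin` buffers:

```rust
/// MAE, RMSE, max abs error, and Pearson correlation between two
/// equal-length output buffers, accumulated in f64 for stability.
fn parity_metrics(a: &[f32], b: &[f32]) -> (f64, f64, f64, f64) {
    assert_eq!(a.len(), b.len());
    let n = a.len() as f64;
    let (mut mae, mut se, mut max_abs) = (0.0_f64, 0.0_f64, 0.0_f64);
    for (&x, &y) in a.iter().zip(b) {
        let d = (x as f64 - y as f64).abs();
        mae += d;
        se += d * d;
        max_abs = max_abs.max(d);
    }
    mae /= n;
    let rmse = (se / n).sqrt();
    // Pearson correlation coefficient
    let mean = |v: &[f32]| v.iter().map(|&x| x as f64).sum::<f64>() / n;
    let (ma, mb) = (mean(a), mean(b));
    let (mut cov, mut va, mut vb) = (0.0_f64, 0.0_f64, 0.0_f64);
    for (&x, &y) in a.iter().zip(b) {
        let (dx, dy) = (x as f64 - ma, y as f64 - mb);
        cov += dx * dy;
        va += dx * dx;
        vb += dy * dy;
    }
    (mae, rmse, max_abs, cov / (va.sqrt() * vb.sqrt()))
}

fn main() {
    let a = [0.1_f32, 0.5, 0.9, 0.2];
    let b = [0.1_f32, 0.5, 0.9, 0.2];
    let (mae, rmse, max_abs, pearson) = parity_metrics(&a, &b);
    assert!(mae == 0.0 && rmse == 0.0 && max_abs == 0.0);
    assert!((pearson - 1.0).abs() < 1e-12);
    // enforce the same thresholds as the Python script
    assert!(max_abs <= 1e-5 && rmse <= 1e-6);
}
```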
## Benchmark backends
Run all main backends (CPU ndarray, CPU+Accelerate, GPU wgpu, GPU Metal):
Latest measured results (batch=1, warmup=2, iters=20, model=`data/model.safetensors`):
| Backend | Avg inference latency |
|---|---|
| Rust CPU (NdArray) | 1311.3 ms |
| Rust CPU (NdArray + Accelerate) | 964.3 ms |
| Rust GPU (wgpu WGSL) | 89.5 ms |
| Rust GPU (wgpu Metal) | 24.1 ms |
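As a sanity check, the speedup ratios implied by the table follow directly from the latencies:

```rust
fn main() {
    // Average latencies from the benchmark table, in milliseconds.
    let cpu = 1311.3_f64;   // Rust CPU (NdArray)
    let accel = 964.3;      // Rust CPU (NdArray + Accelerate)
    let wgsl = 89.5;        // Rust GPU (wgpu WGSL)
    let metal = 24.1;       // Rust GPU (wgpu Metal)

    assert!((cpu / accel - 1.36).abs() < 0.01); // Accelerate: ~1.36x over plain CPU
    assert!((accel / metal - 40.0).abs() < 0.1); // Metal: ~40x over Accelerate CPU
    assert!((cpu / metal - 54.4).abs() < 0.1);   // Metal: ~54x over plain NdArray CPU
    println!("wgsl vs plain CPU: {:.1}x", cpu / wgsl);
}
```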
Notes:

- `blas-accelerate` improves CPU latency by ~1.36x over plain NdArray.
- `wgpu-metal` is the fastest path on macOS: ~54x faster than plain NdArray CPU and ~40x faster than CPU + Accelerate in this setup.
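A warmup-then-average timing loop matching the settings above (warmup=2, iters=20) can be written with only the standard library. This is a generic sketch, not the crate's actual benchmark code, and the dummy workload stands in for one model forward pass:

```rust
use std::time::Instant;

/// Average latency of `f` over `iters` timed runs,
/// after `warmup` untimed runs to amortize one-off setup costs.
fn avg_latency_ms<F: FnMut()>(mut f: F, warmup: usize, iters: usize) -> f64 {
    for _ in 0..warmup {
        f();
    }
    let start = Instant::now();
    for _ in 0..iters {
        f();
    }
    start.elapsed().as_secs_f64() * 1000.0 / iters as f64
}

fn main() {
    // Dummy workload standing in for a model forward pass.
    let mut acc = 0.0_f64;
    let ms = avg_latency_ms(
        || {
            for i in 0..100_000 {
                acc += (i as f64).sqrt();
            }
        },
        2,
        20,
    );
    assert!(ms >= 0.0);
    println!("avg latency: {ms:.3} ms (acc = {acc:.0})");
}
```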