osf-rs — OSF Sleep Foundation Model in Rust
Pure-Rust inference for the OSF (Open Sleep Foundation) model, built on Burn 0.20.
OSF is a Vision Transformer (ViT-Base) trained with DINO+iBOT self-distillation on 166,500 hours of polysomnography (PSG) data from 9 public sleep cohorts. It produces general-purpose representations from 12-channel PSG signals.
Features
- Numerical parity with the Python OSF-Base backbone (`vit1d_cls.py`): max error < 8e-6
- CPU (NdArray + Rayon/Accelerate/OpenBLAS) and GPU (wgpu: Metal/Vulkan/WGSL) backends
- 7.1× GPU speedup over CPU on Apple M4 Pro (Metal)
- Loads weights from `.safetensors` files
- Outputs both the CLS token embedding and the patch embeddings
Benchmarks
Benchmarked on Apple M4 Pro — CPU (NdArray + Apple Accelerate) vs GPU (wgpu Metal/MSL).
Inference Latency
| Backend | Mean | Min | Max | Std | Throughput |
|---|---|---|---|---|---|
| CPU (Accelerate) | 103.5 ms | 102.2 ms | 104.6 ms | 0.7 ms | 9.7 epochs/s |
| GPU (Metal) | 14.6 ms | 14.0 ms | 15.4 ms | 0.5 ms | 68.3 epochs/s |
GPU is 7.1× faster than CPU for single-epoch inference.
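The Mean, Min, Max, Std and Throughput columns can be derived from raw per-run timings. The sketch below shows one way to do that with only the standard library; it assumes population standard deviation and throughput = 1000 / mean latency (ms) for single-epoch runs. `time_runs`, `summarize` and `LatencyStats` are hypothetical names, not part of osf-rs, and the closure stands in for one encoder forward pass.

```rust
use std::time::Instant;

/// Summary statistics over per-run latencies in milliseconds.
#[derive(Debug)]
struct LatencyStats {
    mean: f64,
    min: f64,
    max: f64,
    /// Population standard deviation.
    std_dev: f64,
    /// Epochs per second, assuming one epoch per timed run.
    throughput: f64,
}

/// Reduce raw per-run latencies (ms) to the columns of the latency table.
fn summarize(samples_ms: &[f64]) -> LatencyStats {
    assert!(!samples_ms.is_empty(), "need at least one run");
    let n = samples_ms.len() as f64;
    let mean = samples_ms.iter().sum::<f64>() / n;
    let min = samples_ms.iter().copied().fold(f64::INFINITY, f64::min);
    let max = samples_ms.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let var = samples_ms.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
    LatencyStats { mean, min, max, std_dev: var.sqrt(), throughput: 1000.0 / mean }
}

/// Call `f` `warmup` times without timing, then `runs` times, returning
/// each timed run's wall-clock latency in milliseconds.
fn time_runs<F: FnMut()>(mut f: F, warmup: usize, runs: usize) -> Vec<f64> {
    for _ in 0..warmup {
        f();
    }
    (0..runs)
        .map(|_| {
            let start = Instant::now();
            f();
            start.elapsed().as_secs_f64() * 1000.0
        })
        .collect()
}
```

On the wgpu backends the timed closure has to wait for the GPU to finish (for example by reading the output back to the host); otherwise the measurement covers little more than command submission.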

Batch Size Scaling
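Per-epoch cost drops with larger batches on both backends. On GPU, going from batch size 1 to 2 barely changes total latency (13.1 ms → 13.6 ms), so per-epoch cost nearly halves; at batch size 16, per-epoch cost reaches 3.8 ms on GPU and 29.1 ms on CPU: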
| Batch Size | CPU Total | CPU/Epoch | GPU Total | GPU/Epoch |
|---|---|---|---|---|
| 1 | 105.2 ms | 105.2 ms | 13.1 ms | 13.1 ms |
| 2 | 126.2 ms | 63.1 ms | 13.6 ms | 6.8 ms |
| 4 | 176.5 ms | 44.1 ms | 19.8 ms | 5.0 ms |
| 8 | 272.4 ms | 34.1 ms | 38.3 ms | 4.8 ms |
| 16 | 465.4 ms | 29.1 ms | 61.4 ms | 3.8 ms |
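Per-epoch cost in the table is the batch total divided by the batch size, and epochs per second is 1000 divided by that figure. A small sketch of the arithmetic, with the GPU rows copied from the table; `per_epoch_ms`, `throughput_eps` and `GPU_ROWS` are illustrative names.

```rust
/// Per-epoch latency in ms for a batch that took `total_ms` end to end.
fn per_epoch_ms(total_ms: f64, batch_size: usize) -> f64 {
    total_ms / batch_size as f64
}

/// Epochs per second implied by a per-epoch latency in ms.
fn throughput_eps(per_epoch_ms: f64) -> f64 {
    1000.0 / per_epoch_ms
}

/// (batch size, GPU total ms) rows from the Batch Size Scaling table.
const GPU_ROWS: [(usize, f64); 5] = [(1, 13.1), (2, 13.6), (4, 19.8), (8, 38.3), (16, 61.4)];
```

For batch size 16 on GPU this gives 61.4 / 16 ≈ 3.84 ms per epoch, or roughly 260 epochs/s, compared with 68.3 epochs/s for single-epoch inference.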

Latency Distribution
Latency is stable on both backends, with a standard deviation of 0.7 ms on CPU and 0.5 ms on GPU.

Channel Scaling
Latency does not depend on how many channels carry an active signal: the model always receives 12 input channels, and missing channels are zero-padded.
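Because missing channels are filled with zeros rather than dropped, the tensor shape (and therefore the amount of compute) never changes. A minimal std-only sketch of that padding, assuming the channel order from the Input Format table and one 30 s epoch at 64 Hz per channel; `pad_to_12_channels` is a hypothetical helper. osf-rs builds input batches in `src/data.rs`, which this sketch does not reproduce.

```rust
use std::collections::HashMap;

/// Channel order from the Input Format table.
const CHANNELS: [&str; 12] = [
    "ECG", "EMG_Chin", "EMG_LLeg", "EMG_RLeg", "ABD", "THX",
    "NP", "SN", "EOG_E1_A2", "EOG_E2_A1", "EEG_C3_A2", "EEG_C4_A1",
];
/// 64 Hz × 30 s = 1920 samples per channel.
const SAMPLES: usize = 64 * 30;

/// Build a row-major [12 × 1920] buffer. Channels absent from `available`
/// stay all-zero; names not in CHANNELS are ignored. Returns an error if a
/// provided channel does not have exactly SAMPLES values.
fn pad_to_12_channels(available: &HashMap<&str, Vec<f32>>) -> Result<Vec<f32>, String> {
    let mut out = vec![0.0f32; CHANNELS.len() * SAMPLES];
    for (c, name) in CHANNELS.iter().enumerate() {
        if let Some(samples) = available.get(name) {
            if samples.len() != SAMPLES {
                return Err(format!(
                    "{}: expected {} samples, got {}",
                    name,
                    SAMPLES,
                    samples.len()
                ));
            }
            out[c * SAMPLES..(c + 1) * SAMPLES].copy_from_slice(samples);
        }
    }
    Ok(out)
}
```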

Model Load Time
Loading weights from safetensors takes about 400 ms on both backends, since it is I/O-bound.

Run Benchmarks Yourself
# Full CPU + GPU benchmark with charts
# Custom warmup/runs
# Manual
Quick Start
1. Export weights
# Install safetensors
# Convert .pth → .safetensors
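To sanity-check an exported file without Python, note that the safetensors layout is simple: an 8-byte little-endian `u64` header length N, then N bytes of UTF-8 JSON describing each tensor (dtype, shape, data offsets), then the raw tensor bytes. A std-only sketch that reads just the JSON header; `read_safetensors_header` and `header_of_file` are hypothetical helpers, and the size guard is an arbitrary limit chosen for this example.

```rust
use std::fs::File;
use std::io::{self, Read};

/// Arbitrary sanity limit so a corrupt length field cannot trigger a huge allocation.
const MAX_HEADER: u64 = 100_000_000;

/// Read the JSON header from the start of a safetensors byte stream.
/// Layout: [u64 little-endian header length N][N bytes UTF-8 JSON][tensor data].
fn read_safetensors_header<R: Read>(mut r: R) -> io::Result<String> {
    let mut len_bytes = [0u8; 8];
    r.read_exact(&mut len_bytes)?;
    let n = u64::from_le_bytes(len_bytes);
    if n > MAX_HEADER {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "header too large"));
    }
    let mut header = vec![0u8; n as usize];
    r.read_exact(&mut header)?;
    String::from_utf8(header).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
}

/// Convenience wrapper for an exported file on disk.
fn header_of_file(path: &str) -> io::Result<String> {
    read_safetensors_header(File::open(path)?)
}
```

Printing the header of the converted file lists every tensor name and shape, which makes it easy to spot a missing or mis-shaped weight before loading it in Rust.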
2. Build & run
# CPU (default)
# macOS GPU (Metal)
# Linux GPU (Vulkan)
3. Use as a library
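```rust
use burn::backend::ndarray::NdArrayDevice;
use burn::backend::NdArray as B;
// The osf-rs imports and the argument lists below were lost in this copy of
// the README; see examples/embed.rs for a complete program.

let device = NdArrayDevice::Cpu;

// Load model (the binding provides `encoder`, used below)
let /* … */ = load_with_config(/* … */)?;

// Build input: 12 PSG channels × 1920 samples (64 Hz × 30 s)
let signal = vec![/* … */];
let batch = /* … */;

// Extract embeddings
let emb = encoder.run_batch(/* … */)?;
println!(/* … */); // 768 (CLS embedding width)
println!(/* … */); // 90 (number of patch tokens)
```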
Architecture
PSG signal [B, 12, 1920]
↓
Conv2d(1, 768, (4, 64), stride=(4, 64)) — 2D patchification
↓
[B, 90, 768] patch tokens
↓
Prepend CLS token → [B, 91, 768]
↓
+ Positional embedding [1, 91, 768]
↓
12× TransformerBlock (PreNorm → Attention → Residual → PreNorm → FFN → Residual)
↓
LayerNorm
↓
CLS: [B, 768] | Patches: [B, 90, 768]
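The Conv2d patchification with kernel = stride = (4, 64) is equivalent to cutting the [12, 1920] signal into non-overlapping tiles of 4 channels × 64 samples (1 s at 64 Hz) and projecting each 256-value tile to 768 dimensions, which yields (12 / 4) × (1920 / 64) = 3 × 30 = 90 patch tokens. A std-only sketch of the tiling step only, assuming tokens are ordered by channel group first and then by time window (what flattening the conv output's spatial dimensions row-major would give); `patchify` is a hypothetical helper, not code from `patch_embed.rs`.

```rust
const CHANNELS: usize = 12;
const SAMPLES: usize = 1920;
const PATCH_C: usize = 4; // kernel/stride along channels
const PATCH_T: usize = 64; // kernel/stride along time

/// Split a row-major [CHANNELS × SAMPLES] signal into non-overlapping
/// PATCH_C × PATCH_T tiles. Returns 90 patches of 256 values each,
/// ordered by channel group first, then by time window. Within a patch,
/// values are laid out channel by channel (64 samples each).
fn patchify(signal: &[f32]) -> Vec<Vec<f32>> {
    assert_eq!(signal.len(), CHANNELS * SAMPLES);
    let mut patches = Vec::with_capacity((CHANNELS / PATCH_C) * (SAMPLES / PATCH_T));
    for g in 0..CHANNELS / PATCH_C {
        for w in 0..SAMPLES / PATCH_T {
            let mut patch = Vec::with_capacity(PATCH_C * PATCH_T);
            for c in g * PATCH_C..(g + 1) * PATCH_C {
                let start = c * SAMPLES + w * PATCH_T;
                patch.extend_from_slice(&signal[start..start + PATCH_T]);
            }
            patches.push(patch);
        }
    }
    patches
}
```

Multiplying each 256-value patch by the conv weight reshaped from [768, 1, 4, 64] to [768, 256] and adding the bias gives one 768-dimensional token, so the 90 patches become the [B, 90, 768] patch tokens in the diagram.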
Input Format
| Property | Value |
|---|---|
| Channels | 12: ECG, EMG_Chin, EMG_LLeg, EMG_RLeg, ABD, THX, NP, SN, EOG_E1_A2, EOG_E2_A1, EEG_C3_A2, EEG_C4_A1 |
| Sample Rate | 64 Hz |
| Epoch Length | 30 seconds |
| Input Shape | [B, 12, 1920] |
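A batch of shape [B, 12, 1920] is commonly held as one contiguous row-major buffer, so sample t of channel c in epoch b sits at index (b × 12 + c) × 1920 + t. A sketch of stacking per-epoch buffers into that layout; `stack_epochs` and `flat_index` are hypothetical helpers, and it is an assumption that the buffer is later turned into a tensor with row-major (C-order) semantics.

```rust
const C: usize = 12; // channels
const T: usize = 1920; // samples per 30 s epoch at 64 Hz

/// Offset of (epoch b, channel c, sample t) in a row-major [B, C, T] buffer.
fn flat_index(b: usize, c: usize, t: usize) -> usize {
    (b * C + c) * T + t
}

/// Concatenate per-epoch [C × T] buffers into one [B, C, T] buffer plus its shape.
/// Returns None if any epoch has the wrong length.
fn stack_epochs(epochs: &[Vec<f32>]) -> Option<(Vec<f32>, [usize; 3])> {
    let mut flat = Vec::with_capacity(epochs.len() * C * T);
    for e in epochs {
        if e.len() != C * T {
            return None;
        }
        flat.extend_from_slice(e);
    }
    Some((flat, [epochs.len(), C, T]))
}
```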
Model Variants
| Variant | Width | Depth | Heads | MLP | Params |
|---|---|---|---|---|---|
| vit_nano | 128 | 6 | 4 | 512 | ~1M |
| vit_tiny | 192 | 12 | 3 | 768 | ~5M |
| vit_small | 384 | 12 | 6 | 1536 | ~21M |
| vit_middle | 512 | 12 | 8 | 2048 | ~38M |
| vit_base | 768 | 12 | 12 | 3072 | ~86M |
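Most parameters sit in the transformer blocks: with MLP = 4 × width, each block holds about 12 × width² weights (4 × width² in attention, 8 × width² in the FFN). A rough counting sketch, assuming biases on every linear layer, two LayerNorms per block, a learned CLS token, 91 learned positional embeddings, a final LayerNorm, and the same (4, 64) patch embedding for every variant; `approx_params` is a hypothetical helper, and real checkpoints may differ slightly.

```rust
/// Approximate parameter count of a pre-norm ViT encoder.
/// `patch_len` = values per patch (4 × 64 = 256), `tokens` = CLS + patches (91).
fn approx_params(width: u64, depth: u64, mlp: u64, patch_len: u64, tokens: u64) -> u64 {
    let attn = 3 * (width * width + width) + (width * width + width); // QKV + output projection
    let ffn = (width * mlp + mlp) + (mlp * width + width); // two linear layers
    let norms = 2 * 2 * width; // two LayerNorms (weight + bias)
    let block = attn + ffn + norms;
    let patch_embed = patch_len * width + width; // Conv2d(1, width, (4, 64)) weight + bias
    let extras = patch_embed + tokens * width + width + 2 * width; // + pos emb + CLS + final LN
    depth * block + extras
}
```

Under these assumptions the five variants come to about 1.2M, 5.4M, 21.4M, 38.0M and 85.3M parameters, from vit_nano to vit_base.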
Backend Features
| Feature | Description |
|---|---|
| `ndarray` (default) | CPU via NdArray + Rayon |
| `blas-accelerate` | CPU + Apple Accelerate BLAS |
| `openblas-system` | CPU + system OpenBLAS |
| `wgpu` | GPU via wgpu (auto-detect Metal/Vulkan/DX12) |
| `metal` | GPU with native Metal shaders (macOS) |
| `vulkan` | GPU with native Vulkan shaders (Linux/Windows) |
Parity Testing
Python ↔ Rust numerical parity verified to near floating-point precision:
| Output | Max Error | Mean Error |
|---|---|---|
| CLS embedding (768-d) | 3.04e-6 | 7.05e-7 |
| Patch embeddings (90×768) | 7.81e-6 | 1.08e-6 |
# Generate test vectors
# Run parity test
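The Max Error and Mean Error columns compare the Rust and Python outputs element by element. A sketch of that comparison, assuming absolute differences over flattened f32 outputs; `parity_errors` is a hypothetical helper, and loading the vectors written by `export_parity_vectors.py` is not shown.

```rust
/// Max and mean absolute element-wise difference between two flattened outputs.
/// Panics if the lengths differ or the slices are empty. Any NaN difference
/// is reported as (NaN, NaN) so it cannot be hidden by `f32::max`.
fn parity_errors(rust: &[f32], python: &[f32]) -> (f32, f32) {
    assert_eq!(rust.len(), python.len(), "output shapes differ");
    assert!(!rust.is_empty(), "empty output");
    let mut max = 0.0f32;
    let mut sum = 0.0f64; // accumulate in f64 to limit rounding over 69,120 patch values
    for (a, b) in rust.iter().zip(python) {
        let d = (a - b).abs();
        if d.is_nan() {
            return (f32::NAN, f32::NAN);
        }
        max = max.max(d);
        sum += d as f64;
    }
    (max, (sum / rust.len() as f64) as f32)
}
```

A parity test would then assert something like `max < 1e-5` on the CLS embedding; that threshold is an illustrative choice, not necessarily the one used in `tests/python_parity.rs`.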
Project Structure
osf-rs/
├── src/
│ ├── lib.rs # Public API
│ ├── config.rs # Model configuration
│ ├── data.rs # Input batch construction
│ ├── encoder.rs # High-level encoder API
│ ├── weights.rs # Safetensors weight loading
│ ├── model/
│ │ ├── vit.rs # ViT with CLS token
│ │ ├── patch_embed.rs # 1D/2D patch embedding
│ │ ├── transformer_block.rs
│ │ ├── attention.rs # Multi-head self-attention
│ │ ├── feedforward.rs # FFN (GELU)
│ │ └── norm.rs # LayerNorm
│ └── bin/
│ └── infer.rs # CLI inference
├── examples/
│ ├── embed.rs # Embedding extraction
│ └── benchmark.rs # Latency benchmark
├── scripts/
│ ├── export_safetensors.py
│ ├── export_parity_vectors.py
│ └── generate_figures.py
├── bench.sh # Full CPU+GPU benchmark runner
├── figures/ # Generated benchmark charts
└── tests/
├── forward_pass.rs # Shape tests
└── python_parity.rs # Numerical parity
Citation
Cite this repo:
and the original work:
License
MIT — matching the original OSF-Base license.