# osf-rs — OSF Sleep Foundation Model in Rust
Pure-Rust inference for the [OSF (Open Sleep Foundation)](https://huggingface.co/yang-ai-lab/OSF-Base) model, built on [Burn 0.20](https://burn.dev).
OSF is a Vision Transformer (ViT-Base) trained with DINO+iBOT self-distillation on 166,500 hours of polysomnography data from 9 public sleep cohorts. It produces general-purpose representations from 12-channel PSG signals.
## Features
- **Numerical parity** with the Python OSF-Base backbone (`vit1d_cls.py`): max error < 8e-6
- **CPU** (NdArray + Rayon/Accelerate/OpenBLAS) and **GPU** (wgpu: Metal/Vulkan/WGSL)
- **7.1× GPU speedup** over CPU on Apple M4 Pro (Metal)
- Loads weights from `.safetensors` format
- CLS token + patch embeddings output
## Benchmarks
Benchmarked on Apple M4 Pro — CPU (NdArray + Apple Accelerate) vs GPU (wgpu Metal/MSL).
### Inference Latency
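| Backend | Mean | Min | Max | Std Dev | Throughput |
|---------|------|-----|-----|---------|------------|
| **CPU** (Accelerate) | 103.5 ms | 102.2 ms | 104.6 ms | 0.7 ms | 9.7 epochs/s |
| **GPU** (Metal) | 14.6 ms | 14.0 ms | 15.4 ms | 0.5 ms | 68.3 epochs/s |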
> **GPU is 7.1× faster** than CPU for single-epoch inference.
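
The throughput and speedup figures can be re-derived from the mean latencies. The sketch below assumes throughput is simply `1000 / mean_latency_ms`; the helper names are illustrative and not part of `osf-rs`. The GPU throughput computed this way comes out slightly higher than the reported 68.3 epochs/s, presumably because the reported value was derived from unrounded timings.

```rust
/// Convert a mean per-epoch latency (milliseconds) to throughput (epochs per second).
fn epochs_per_second(mean_latency_ms: f64) -> f64 {
    1000.0 / mean_latency_ms
}

/// How many times faster the backend with latency `fast_ms` is than the one with `slow_ms`.
fn speedup(slow_ms: f64, fast_ms: f64) -> f64 {
    slow_ms / fast_ms
}

fn main() {
    // Mean single-epoch latencies reported in this README.
    let cpu_ms = 103.5;
    let gpu_ms = 14.6;

    println!("CPU throughput: {:.1} epochs/s", epochs_per_second(cpu_ms));
    println!("GPU throughput: {:.1} epochs/s", epochs_per_second(gpu_ms));
    println!("GPU speedup:    {:.1}x", speedup(cpu_ms, gpu_ms));
}
```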

### Batch Size Scaling
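Per-epoch cost drops with larger batches on both backends. At batch size 16 it is roughly 3.6× lower on CPU and 3.4× lower on GPU than at batch size 1:

| Batch Size | CPU Total | CPU per Epoch | GPU Total | GPU per Epoch |
|------------|-----------|---------------|-----------|---------------|
| 1 | 105.2 ms | 105.2 ms | 13.1 ms | 13.1 ms |
| 2 | 126.2 ms | 63.1 ms | 13.6 ms | 6.8 ms |
| 4 | 176.5 ms | 44.1 ms | 19.8 ms | 5.0 ms |
| 8 | 272.4 ms | 34.1 ms | 38.3 ms | 4.8 ms |
| 16 | 465.4 ms | 29.1 ms | 61.4 ms | 3.8 ms |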
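
The per-epoch columns are the batch totals divided by the batch size. A minimal sketch of that arithmetic, using the totals from the table (`per_epoch_ms` is an illustrative helper, not an `osf-rs` function):

```rust
/// Amortized per-epoch latency: total batch latency divided by the batch size.
fn per_epoch_ms(batch_total_ms: f64, batch_size: u32) -> f64 {
    batch_total_ms / f64::from(batch_size)
}

fn main() {
    // (batch size, CPU total ms, GPU total ms), copied from the table.
    let rows = [
        (1u32, 105.2, 13.1),
        (2, 126.2, 13.6),
        (4, 176.5, 19.8),
        (8, 272.4, 38.3),
        (16, 465.4, 61.4),
    ];

    for (batch, cpu_total, gpu_total) in rows {
        println!(
            "batch {:>2}: CPU {:>6.1} ms/epoch, GPU {:>5.1} ms/epoch",
            batch,
            per_epoch_ms(cpu_total, batch),
            per_epoch_ms(gpu_total, batch)
        );
    }
}
```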

### Latency Distribution
Latency is stable on both backends, with little run-to-run variance (standard deviation of 0.7 ms on CPU and 0.5 ms on GPU in the single-epoch runs).
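
For readers who want to summarize their own runs, the following is a rough sketch of how mean, min, max, and standard deviation could be computed from raw timings. It is not the code used by `bench.sh`; the `summarize` helper, the use of population (rather than sample) standard deviation, and the sample values in `main` are all assumptions made for illustration.

```rust
/// Summary statistics for a set of latency samples, in milliseconds.
#[derive(Debug)]
struct LatencyStats {
    mean: f64,
    min: f64,
    max: f64,
    std_dev: f64,
}

/// Returns `None` for an empty slice, otherwise mean/min/max and the
/// population standard deviation (an assumption; the benchmark's choice is not stated).
fn summarize(samples: &[f64]) -> Option<LatencyStats> {
    if samples.is_empty() {
        return None;
    }
    let n = samples.len() as f64;
    let mean = samples.iter().sum::<f64>() / n;
    let min = samples.iter().copied().fold(f64::INFINITY, f64::min);
    let max = samples.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let variance = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
    Some(LatencyStats { mean, min, max, std_dev: variance.sqrt() })
}

fn main() {
    // Made-up timings for illustration only; these are not measurements.
    let samples = [14.2, 14.9, 14.4, 15.1, 14.6];
    if let Some(stats) = summarize(&samples) {
        println!("{:?}", stats);
    }
}
```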

### Channel Scaling
Latency does not depend on how many channels carry active signal: the input always has 12 channels, and channels without signal are zero-padded.

### Model Load Time
Weight loading from safetensors takes ~400 ms on both backends (I/O bound).

### Run Benchmarks Yourself
```bash
# Full CPU + GPU benchmark with charts
./bench.sh
# Custom warmup/runs
./bench.sh 5 20
# Manual
cargo run --example benchmark --release --features ndarray,blas-accelerate -- \
--weights data/osf_backbone.safetensors --json
```
## Quick Start
### 1. Export weights
```bash
# Install safetensors and PyTorch
pip install safetensors torch
# Convert .pth → .safetensors
python scripts/export_safetensors.py \
--input ../OSF-Base/osf_backbone.pth \
--output data/osf_backbone.safetensors
```
### 2. Build & run
```bash
# CPU (default)
cargo run --release --bin infer -- --weights data/osf_backbone.safetensors -v
# macOS GPU (Metal)
cargo run --release --no-default-features --features metal --bin infer -- \
--weights data/osf_backbone.safetensors -v
# Linux GPU (Vulkan)
cargo run --release --no-default-features --features vulkan --bin infer -- \
--weights data/osf_backbone.safetensors -v
```
### 3. Use as a library
```rust
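use std::path::Path;
use osf_rs::{OsfEncoder, ModelConfig, build_batch};
use burn::backend::NdArray as B;
// Note: the `?` operators below assume this code runs inside a function that returns a `Result`.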
let device = burn::backend::ndarray::NdArrayDevice::Cpu;
// Load model
let (encoder, _ms) = OsfEncoder::<B>::load_with_config(
ModelConfig::default(),
Path::new("data/osf_backbone.safetensors"),
device.clone(),
)?;
// Build input: 12 PSG channels × 1920 samples (64 Hz × 30 s)
let signal = vec![0.0f32; 12 * 1920];
let batch = build_batch::<B>(signal, 12, 1920, &device);
// Extract embeddings
let emb = encoder.run_batch(&batch)?;
println!("CLS embedding: {} dims", emb.embed_dim); // 768
println!("Patch embeddings: {} patches", emb.num_patches); // 90
```
## Architecture
```
PSG signal [B, 12, 1920]
↓
Conv2d(1, 768, (4, 64), stride=(4, 64)) — 2D patchification
↓
[B, 90, 768] patch tokens
↓
Prepend CLS token → [B, 91, 768]
↓
+ Positional embedding [1, 91, 768]
↓
12× TransformerBlock (PreNorm → Attention → Residual → PreNorm → FFN → Residual)
↓
LayerNorm
↓
CLS: [B, 768] | Patches: [B, 90, 768]
```
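
The token counts in the diagram follow from the patch size. A small sketch of the arithmetic, assuming the `[B, 12, 1920]` input is treated as a single-channel 12 × 1920 image and the convolution uses no padding (function names are illustrative):

```rust
/// Output length of an unpadded convolution along one axis.
fn conv_out_len(len: usize, kernel: usize, stride: usize) -> usize {
    (len - kernel) / stride + 1
}

/// Patch grid (along channels, along time) for a 2D patchifying convolution
/// whose stride equals its kernel size.
fn patch_grid(channels: usize, samples: usize, patch: (usize, usize)) -> (usize, usize) {
    (
        conv_out_len(channels, patch.0, patch.0),
        conv_out_len(samples, patch.1, patch.1),
    )
}

fn main() {
    // 12 PSG channels x 1920 samples, patch size (4, 64) as in the diagram.
    let (ch_patches, time_patches) = patch_grid(12, 1920, (4, 64));
    let num_patches = ch_patches * time_patches;
    let seq_len = num_patches + 1; // one extra position for the CLS token

    println!("patch grid: {} x {}", ch_patches, time_patches);
    println!("patch tokens: {}, sequence length with CLS: {}", num_patches, seq_len);
}
```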
### Input Format
| Property | Value |
|----------|-------|
| Channels | 12: ECG, EMG_Chin, EMG_LLeg, EMG_RLeg, ABD, THX, NP, SN, EOG_E1_A2, EOG_E2_A1, EEG_C3_A2, EEG_C4_A1 |
| Sample Rate | 64 Hz |
| Epoch Length | 30 seconds |
| Input Shape | `[B, 12, 1920]` |
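
When a recording lacks some of the 12 channels, the missing ones can be left as zeros. The sketch below packs per-channel slices into one flat buffer of `12 * 1920` samples. It assumes a channel-major layout (all samples of channel 0, then channel 1, and so on), which should be checked against `build_batch` before use; `pack_epoch` is a hypothetical helper, not part of `osf-rs`.

```rust
const NUM_CHANNELS: usize = 12;
const SAMPLES_PER_EPOCH: usize = 64 * 30; // 64 Hz x 30 s = 1920

/// Pack optional per-channel recordings into one flat buffer.
/// Layout is channel-major (an assumption). Missing channels stay zero.
/// Returns `None` if a provided channel does not have exactly 1920 samples.
fn pack_epoch(channels: &[Option<&[f32]>; NUM_CHANNELS]) -> Option<Vec<f32>> {
    let mut flat = vec![0.0f32; NUM_CHANNELS * SAMPLES_PER_EPOCH];
    for (i, ch) in channels.iter().enumerate() {
        if let Some(data) = *ch {
            if data.len() != SAMPLES_PER_EPOCH {
                return None;
            }
            let start = i * SAMPLES_PER_EPOCH;
            flat[start..start + SAMPLES_PER_EPOCH].copy_from_slice(data);
        }
    }
    Some(flat)
}

fn main() {
    // Only the first channel (ECG in the channel list) is available here;
    // the constant signal is a placeholder, not real data.
    let ecg = vec![1.0f32; SAMPLES_PER_EPOCH];
    let mut channels: [Option<&[f32]>; NUM_CHANNELS] = [None; NUM_CHANNELS];
    channels[0] = Some(ecg.as_slice());

    let flat = pack_epoch(&channels).expect("each channel needs 1920 samples");
    println!("flat buffer length: {}", flat.len());
}
```

The resulting `Vec<f32>` has the same length as the `signal` buffer in the library example.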
### Model Variants
| Variant | Width | Depth | Heads | MLP | Params |
|---------|-------|-------|-------|-----|--------|
| vit_nano | 128 | 6 | 4 | 512 | ~1M |
| vit_tiny | 192 | 12 | 3 | 768 | ~4M |
| vit_small | 384 | 12 | 6 | 1536 | ~15M |
| vit_middle | 512 | 12 | 8 | 2048 | ~27M |
| vit_base | 768 | 12 | 12 | 3072 | ~86M |
## Backend Features
| Feature | Description |
|---------|-------------|
| `ndarray` (default) | CPU via NdArray + Rayon |
| `blas-accelerate` | CPU + Apple Accelerate BLAS |
| `openblas-system` | CPU + system OpenBLAS |
| `wgpu` | GPU via wgpu (auto-detect Metal/Vulkan/DX12) |
| `metal` | GPU with native Metal shaders (macOS) |
| `vulkan` | GPU with native Vulkan shaders (Linux/Windows) |
## Parity Testing
Python ↔ Rust numerical parity verified to near floating-point precision:
| Output | Max Error | Mean Error |
|--------|-----------|------------|
| CLS embedding (768-d) | 3.04e-6 | 7.05e-7 |
| Patch embeddings (90×768) | 7.81e-6 | 1.08e-6 |
```bash
# Generate test vectors
python scripts/export_parity_vectors.py
# Run parity test
cargo test --release test_forward_parity_with_python -- --nocapture
```
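
The error columns are presumably element-wise absolute differences between the Python and Rust outputs. A minimal sketch of such a comparison, independent of the actual test in `tests/python_parity.rs` (the `abs_error_stats` helper and the `1e-5` tolerance are assumptions for illustration):

```rust
/// Maximum and mean absolute element-wise difference between two slices.
/// Returns `None` if the slices are empty or differ in length.
fn abs_error_stats(reference: &[f32], candidate: &[f32]) -> Option<(f32, f32)> {
    if reference.is_empty() || reference.len() != candidate.len() {
        return None;
    }
    let mut max_err = 0.0f32;
    let mut sum_err = 0.0f64; // accumulate in f64 to limit rounding drift
    for (r, c) in reference.iter().zip(candidate) {
        let err = (r - c).abs();
        max_err = max_err.max(err);
        sum_err += f64::from(err);
    }
    let mean_err = (sum_err / reference.len() as f64) as f32;
    Some((max_err, mean_err))
}

fn main() {
    // Placeholder values standing in for exported Python and Rust embeddings.
    let python_out = [0.125f32, -0.5, 0.75];
    let rust_out = [0.125f32, -0.5, 0.75];

    let (max_err, mean_err) =
        abs_error_stats(&python_out, &rust_out).expect("outputs must have equal, non-zero length");
    println!("max error: {:e}, mean error: {:e}", max_err, mean_err);
    assert!(max_err < 1e-5, "parity tolerance exceeded");
}
```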
## Project Structure
```
osf-rs/
├── src/
│ ├── lib.rs # Public API
│ ├── config.rs # Model configuration
│ ├── data.rs # Input batch construction
│ ├── encoder.rs # High-level encoder API
│ ├── weights.rs # Safetensors weight loading
│ ├── model/
│ │ ├── vit.rs # ViT with CLS token
│ │ ├── patch_embed.rs # 1D/2D patch embedding
│ │ ├── transformer_block.rs
│ │ ├── attention.rs # Multi-head self-attention
│ │ ├── feedforward.rs # FFN (GELU)
│ │ └── norm.rs # LayerNorm
│ └── bin/
│ └── infer.rs # CLI inference
├── examples/
│ ├── embed.rs # Embedding extraction
│ └── benchmark.rs # Latency benchmark
├── scripts/
│ ├── export_safetensors.py
│ ├── export_parity_vectors.py
│ └── generate_figures.py
├── bench.sh # Full CPU+GPU benchmark runner
├── figures/ # Generated benchmark charts
└── tests/
├── forward_pass.rs # Shape tests
└── python_parity.rs # Numerical parity
```
## Citation
Cite this repo:
```bibtex
@article{kosmyna2026osf,
title={OSF Sleep Foundation Model — inference in Rust with Burn ML},
author={Nataliya Kosmyna},
year={2026}
}
```
and the original work:
```bibtex
@article{shuai2026osf,
title={OSF: On Pre-training and Scaling of Sleep Foundation Models},
author={Shuai, Zitao and Xu, Zongzhe and Yang, David and Wang, Wei and Yang, Yuzhe},
journal={arXiv preprint arXiv:2603.00190},
year={2026}
}
```
## License
MIT — matching the original OSF-Base license.