osf-rs 0.0.1

OSF Sleep Foundation Model — inference in Rust with Burn ML

osf-rs — OSF Sleep Foundation Model in Rust

Pure-Rust inference for the OSF (Open Sleep Foundation) model, built on Burn 0.20.

OSF is a Vision Transformer (ViT-Base) trained with DINO+iBOT self-distillation on 166,500 hours of polysomnography data from 9 public sleep cohorts. It produces general-purpose representations from 12-channel PSG signals.

Features

  • Numerical parity with the Python OSF-Base backbone (vit1d_cls.py): max error < 8e-6 across CLS and patch embeddings
  • CPU (NdArray + Rayon/Accelerate/OpenBLAS) and GPU (wgpu: Metal/Vulkan/WGSL)
  • 7.1× GPU speedup over CPU on Apple M4 Pro (Metal)
  • Loads weights from .safetensors format
  • Outputs both the CLS token embedding and the per-patch embeddings

Benchmarks

Benchmarked on Apple M4 Pro — CPU (NdArray + Apple Accelerate) vs GPU (wgpu Metal/MSL).

Inference Latency

| Backend | Mean | Min | Max | Std | Throughput |
|---|---|---|---|---|---|
| CPU (Accelerate) | 103.5 ms | 102.2 ms | 104.6 ms | 0.7 ms | 9.7 epochs/s |
| GPU (Metal) | 14.6 ms | 14.0 ms | 15.4 ms | 0.5 ms | 68.3 epochs/s |

GPU is 7.1× faster than CPU for single-epoch inference.
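
As a quick check on the headline figure, the speedup is the ratio of the two mean latencies in the table. The helper names below are illustrative and not part of the crate; throughput derived from the mean is only an estimate and need not match the measured column exactly.

```rust
/// Speedup of the faster backend over the slower one, as a ratio of mean latencies.
fn speedup(slow_ms: f64, fast_ms: f64) -> f64 {
    slow_ms / fast_ms
}

/// Approximate epochs per second implied by a mean per-epoch latency in milliseconds.
fn epochs_per_second(mean_ms: f64) -> f64 {
    1000.0 / mean_ms
}
```

With the table's means, `speedup(103.5, 14.6)` rounds to the 7.1× quoted above.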

[Figure: Inference Latency]

Batch Size Scaling

Per-epoch cost drops as the batch grows on both backends: from 105.2 ms to 29.1 ms on CPU and from 13.1 ms to 3.8 ms on GPU when going from batch size 1 to 16:

| Batch Size | CPU Total | CPU/Epoch | GPU Total | GPU/Epoch |
|---|---|---|---|---|
| 1 | 105.2 ms | 105.2 ms | 13.1 ms | 13.1 ms |
| 2 | 126.2 ms | 63.1 ms | 13.6 ms | 6.8 ms |
| 4 | 176.5 ms | 44.1 ms | 19.8 ms | 5.0 ms |
| 8 | 272.4 ms | 34.1 ms | 38.3 ms | 4.8 ms |
| 16 | 465.4 ms | 29.1 ms | 61.4 ms | 3.8 ms |
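
The per-epoch columns are the batch totals divided by the batch size. A minimal sketch of that arithmetic, using the GPU totals from the table (the function and constant names are hypothetical):

```rust
/// Per-epoch latency for a batch: total batch latency divided by batch size.
fn per_epoch_ms(total_ms: f64, batch_size: u32) -> f64 {
    total_ms / f64::from(batch_size)
}

/// GPU totals copied from the batch-scaling table: (batch size, total ms).
const GPU_TOTALS: [(u32, f64); 5] = [(1, 13.1), (2, 13.6), (4, 19.8), (8, 38.3), (16, 61.4)];

/// Per-epoch cost for every GPU row, in table order.
fn gpu_per_epoch() -> Vec<f64> {
    GPU_TOTALS
        .iter()
        .map(|&(batch, total)| per_epoch_ms(total, batch))
        .collect()
}
```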

[Figure: Batch Scaling]

Latency Distribution

Latency is stable on both backends, with a standard deviation of 0.7 ms on CPU and 0.5 ms on GPU for single-epoch inference:

[Figure: Latency Distribution]

Channel Scaling

Latency does not depend on how many channels carry an active signal, because the input always has 12 channels and missing ones are zero-padded:

[Figure: Channel Scaling]
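
Zero-padding a recording that has only some of the 12 channels could look like the sketch below. It assumes a flat, channel-major buffer and channel indices that follow the order in the Input Format table; the layout that `build_batch` actually expects is not documented here, so treat both as assumptions. `zero_pad_epoch` is a hypothetical helper, not part of the crate.

```rust
const NUM_CHANNELS: usize = 12;
const SAMPLES_PER_EPOCH: usize = 1920; // 64 Hz × 30 s

/// Builds a flat, channel-major `[12 × 1920]` buffer from the channels that
/// were actually recorded. Each entry pairs a channel index (0..12) with its
/// samples; channels that are not listed stay all-zero.
/// Returns `None` if an index is out of range or a channel has the wrong length.
fn zero_pad_epoch(recorded: &[(usize, &[f32])]) -> Option<Vec<f32>> {
    let mut out = vec![0.0f32; NUM_CHANNELS * SAMPLES_PER_EPOCH];
    for &(ch, samples) in recorded {
        if ch >= NUM_CHANNELS || samples.len() != SAMPLES_PER_EPOCH {
            return None;
        }
        let start = ch * SAMPLES_PER_EPOCH;
        out[start..start + SAMPLES_PER_EPOCH].copy_from_slice(samples);
    }
    Some(out)
}
```

The output always has 12 × 1920 values, which is why the model does the same amount of work whether one channel or all twelve are present.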

Model Load Time

Weight loading from safetensors takes ~400 ms on both backends (I/O bound):

[Figure: Load Time]

Run Benchmarks Yourself

# Full CPU + GPU benchmark with charts
./bench.sh

# Custom warmup/runs
./bench.sh 5 20

# Manual
cargo run --example benchmark --release --features ndarray,blas-accelerate -- \
    --weights data/osf_backbone.safetensors --json
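
For readers who want to time their own code the same way, here is a minimal sketch of a warmup-then-measure loop that reports the statistics shown in the latency table (mean, min, max, std). It is not the code in examples/benchmark.rs; the names are hypothetical, and the use of population standard deviation is an assumption.

```rust
use std::time::Instant;

/// Summary statistics over latency samples, in milliseconds.
#[derive(Debug, Clone, PartialEq)]
struct LatencyStats {
    mean: f64,
    min: f64,
    max: f64,
    std: f64,
}

/// Computes mean, min, max and population standard deviation.
/// Returns `None` for an empty slice.
fn summarize(samples_ms: &[f64]) -> Option<LatencyStats> {
    if samples_ms.is_empty() {
        return None;
    }
    let n = samples_ms.len() as f64;
    let mean = samples_ms.iter().sum::<f64>() / n;
    let min = samples_ms.iter().copied().fold(f64::INFINITY, f64::min);
    let max = samples_ms.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let var = samples_ms.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
    Some(LatencyStats { mean, min, max, std: var.sqrt() })
}

/// Calls `f` `warmup` times untimed, then `runs` times with timing.
fn bench<F: FnMut()>(warmup: usize, runs: usize, mut f: F) -> Option<LatencyStats> {
    for _ in 0..warmup {
        f();
    }
    let mut samples = Vec::with_capacity(runs);
    for _ in 0..runs {
        let t0 = Instant::now();
        f();
        samples.push(t0.elapsed().as_secs_f64() * 1000.0);
    }
    summarize(&samples)
}
```

Calling `bench(5, 20, || { /* one inference */ })` mirrors the warmup/runs pair passed to `./bench.sh 5 20`.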

Quick Start

1. Export weights

# Install safetensors and PyTorch
pip install safetensors torch

# Convert .pth → .safetensors
python scripts/export_safetensors.py \
    --input ../OSF-Base/osf_backbone.pth \
    --output data/osf_backbone.safetensors

2. Build & run

# CPU (default)
cargo run --release --bin infer -- --weights data/osf_backbone.safetensors -v

# macOS GPU (Metal)
cargo run --release --no-default-features --features metal --bin infer -- \
    --weights data/osf_backbone.safetensors -v

# Linux GPU (Vulkan)
cargo run --release --no-default-features --features vulkan --bin infer -- \
    --weights data/osf_backbone.safetensors -v

3. Use as a library

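use std::path::Path;

use osf_rs::{OsfEncoder, ModelConfig, build_batch};
use burn::backend::NdArray as B;

// Note: the `?` operators below assume an enclosing function that returns a compatible `Result`.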

let device = burn::backend::ndarray::NdArrayDevice::Cpu;

// Load model
let (encoder, _ms) = OsfEncoder::<B>::load_with_config(
    ModelConfig::default(),
    Path::new("data/osf_backbone.safetensors"),
    device.clone(),
)?;

// Build input: 12 PSG channels × 1920 samples (64 Hz × 30 s)
let signal = vec![0.0f32; 12 * 1920];
let batch = build_batch::<B>(signal, 12, 1920, &device);

// Extract embeddings
let emb = encoder.run_batch(&batch)?;
println!("CLS embedding: {} dims", emb.embed_dim);       // 768
println!("Patch embeddings: {} patches", emb.num_patches); // 90

Architecture

PSG signal [B, 12, 1920]
    ↓
Conv2d(1, 768, (4, 64), stride=(4, 64))  — 2D patchification
    ↓
[B, 90, 768] patch tokens
    ↓
Prepend CLS token → [B, 91, 768]
    ↓
+ Positional embedding [1, 91, 768]
    ↓
12× TransformerBlock (PreNorm → Attention → Residual → PreNorm → FFN → Residual)
    ↓
LayerNorm
    ↓
CLS: [B, 768]  |  Patches: [B, 90, 768]
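
The 90 patch tokens follow from the patch size: a (4, 64) kernel with equal stride tiles the 12 × 1920 input into 3 × 30 patches, and the CLS token brings the sequence length to 91. A small sketch of that arithmetic follows. `num_patches` is a hypothetical helper, and rejecting inputs that do not tile exactly is a choice made for this sketch, not documented crate behaviour.

```rust
/// Number of patches from a non-overlapping 2D patchification
/// (kernel size equal to stride) over a `[channels, samples]` grid.
/// Returns `None` if the patch does not tile the input exactly.
fn num_patches(
    channels: usize,
    samples: usize,
    patch_channels: usize,
    patch_samples: usize,
) -> Option<usize> {
    if patch_channels == 0
        || patch_samples == 0
        || channels % patch_channels != 0
        || samples % patch_samples != 0
    {
        return None;
    }
    Some((channels / patch_channels) * (samples / patch_samples))
}

/// Sequence length seen by the transformer: patch tokens plus one CLS token.
fn sequence_len(patches: usize) -> usize {
    patches + 1
}
```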

Input Format

| Property | Value |
|---|---|
| Channels | 12: ECG, EMG_Chin, EMG_LLeg, EMG_RLeg, ABD, THX, NP, SN, EOG_E1_A2, EOG_E2_A1, EEG_C3_A2, EEG_C4_A1 |
| Sample Rate | 64 Hz |
| Epoch Length | 30 seconds |
| Input Shape | [B, 12, 1920] |
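
A longer recording has to be cut into 30-second epochs of 64 × 30 = 1920 samples per channel before inference. The sketch below shows one way to do that, assuming the recording is already resampled to 64 Hz and stored one `Vec<f32>` per channel; `split_into_epochs` and the channel-major output layout are hypothetical, not part of the crate.

```rust
const NUM_CHANNELS: usize = 12;
const SAMPLE_RATE_HZ: usize = 64;
const EPOCH_SECONDS: usize = 30;
const SAMPLES_PER_EPOCH: usize = SAMPLE_RATE_HZ * EPOCH_SECONDS; // 1920

/// Splits a 12-channel recording (`channels[c][t]`) into flat, channel-major
/// `[12 × 1920]` epoch buffers. A trailing partial epoch is dropped.
/// Returns `None` if there are not exactly 12 channels of equal length.
fn split_into_epochs(channels: &[Vec<f32>]) -> Option<Vec<Vec<f32>>> {
    if channels.len() != NUM_CHANNELS {
        return None;
    }
    let total = channels[0].len();
    if channels.iter().any(|c| c.len() != total) {
        return None;
    }
    let n_epochs = total / SAMPLES_PER_EPOCH;
    let mut epochs = Vec::with_capacity(n_epochs);
    for e in 0..n_epochs {
        let start = e * SAMPLES_PER_EPOCH;
        let mut flat = Vec::with_capacity(NUM_CHANNELS * SAMPLES_PER_EPOCH);
        for ch in channels {
            flat.extend_from_slice(&ch[start..start + SAMPLES_PER_EPOCH]);
        }
        epochs.push(flat);
    }
    Some(epochs)
}
```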

Model Variants

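| Variant | Width | Depth | Heads | MLP | Params |
|---|---|---|---|---|---|
| vit_nano | 128 | 6 | 4 | 512 | ~1M |
| vit_tiny | 192 | 12 | 3 | 768 | ~4M |
| vit_small | 384 | 12 | 6 | 1536 | ~15M |
| vit_middle | 512 | 12 | 8 | 2048 | ~27M |
| vit_base | 768 | 12 | 12 | 3072 | ~86M |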
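
Two regularities in the table are easy to verify: every variant uses an MLP hidden size of 4× the width, and the width divides evenly across heads (32 per head for vit_nano, 64 for the rest). The sketch below encodes the rows in a hypothetical `VariantSpec` struct for illustration. It is not the crate's `ModelConfig`, whose fields are not shown in this README.

```rust
/// One row of the variants table (hypothetical type, for illustration only).
struct VariantSpec {
    name: &'static str,
    width: usize,
    depth: usize,
    heads: usize,
    mlp: usize,
}

/// Rows copied from the Model Variants table.
const VARIANTS: [VariantSpec; 5] = [
    VariantSpec { name: "vit_nano", width: 128, depth: 6, heads: 4, mlp: 512 },
    VariantSpec { name: "vit_tiny", width: 192, depth: 12, heads: 3, mlp: 768 },
    VariantSpec { name: "vit_small", width: 384, depth: 12, heads: 6, mlp: 1536 },
    VariantSpec { name: "vit_middle", width: 512, depth: 12, heads: 8, mlp: 2048 },
    VariantSpec { name: "vit_base", width: 768, depth: 12, heads: 12, mlp: 3072 },
];

impl VariantSpec {
    /// Per-head dimension: the embedding width split evenly across heads.
    fn head_dim(&self) -> usize {
        self.width / self.heads
    }

    /// Ratio of the MLP hidden size to the embedding width.
    fn mlp_ratio(&self) -> usize {
        self.mlp / self.width
    }
}
```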

Backend Features

| Feature | Description |
|---|---|
| ndarray (default) | CPU via NdArray + Rayon |
| blas-accelerate | CPU + Apple Accelerate BLAS |
| openblas-system | CPU + system OpenBLAS |
| wgpu | GPU via wgpu (auto-detect Metal/Vulkan/DX12) |
| metal | GPU with native Metal shaders (macOS) |
| vulkan | GPU with native Vulkan shaders (Linux/Windows) |

Parity Testing

Python ↔ Rust numerical parity verified to near floating-point precision:

| Output | Max Error | Mean Error |
|---|---|---|
| CLS embedding (768-d) | 3.04e-6 | 7.05e-7 |
| Patch embeddings (90×768) | 7.81e-6 | 1.08e-6 |

# Generate test vectors
python scripts/export_parity_vectors.py

# Run parity test
cargo test --release test_forward_parity_with_python -- --nocapture
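
The error columns can be reproduced from two flattened output vectors. The sketch below assumes "error" means the element-wise absolute difference between the Python reference and the Rust output. This is a common convention, but the README does not state it, and `abs_error_stats` is a hypothetical helper rather than the code in tests/python_parity.rs.

```rust
/// Maximum and mean absolute element-wise difference between two slices.
/// Returns `None` if the lengths differ or the slices are empty.
fn abs_error_stats(reference: &[f32], actual: &[f32]) -> Option<(f32, f32)> {
    if reference.len() != actual.len() || reference.is_empty() {
        return None;
    }
    let mut max_err = 0.0f32;
    let mut sum_err = 0.0f64; // accumulate in f64 to limit rounding drift
    for (r, a) in reference.iter().zip(actual) {
        let err = (r - a).abs();
        max_err = max_err.max(err);
        sum_err += f64::from(err);
    }
    let mean_err = (sum_err / reference.len() as f64) as f32;
    Some((max_err, mean_err))
}
```

A parity check would then assert that the returned maximum stays below a tolerance such as the 8e-6 bound quoted in Features.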

Project Structure

osf-rs/
├── src/
│   ├── lib.rs              # Public API
│   ├── config.rs           # Model configuration
│   ├── data.rs             # Input batch construction
│   ├── encoder.rs          # High-level encoder API
│   ├── weights.rs          # Safetensors weight loading
│   ├── model/
│   │   ├── vit.rs          # ViT with CLS token
│   │   ├── patch_embed.rs  # 1D/2D patch embedding
│   │   ├── transformer_block.rs
│   │   ├── attention.rs    # Multi-head self-attention
│   │   ├── feedforward.rs  # FFN (GELU)
│   │   └── norm.rs         # LayerNorm
│   └── bin/
│       └── infer.rs        # CLI inference
├── examples/
│   ├── embed.rs            # Embedding extraction
│   └── benchmark.rs        # Latency benchmark
├── scripts/
│   ├── export_safetensors.py
│   ├── export_parity_vectors.py
│   └── generate_figures.py
├── bench.sh                # Full CPU+GPU benchmark runner
├── figures/                # Generated benchmark charts
└── tests/
    ├── forward_pass.rs     # Shape tests
    └── python_parity.rs    # Numerical parity

Citation

Cite this repo:

@article{kosmyna2026osf,
  title={OSF Sleep Foundation Model — inference in Rust with Burn ML},
  author={Nataliya Kosmyna},
  year={2026}
}

and the original work:

@article{shuai2026osf,
  title={OSF: On Pre-training and Scaling of Sleep Foundation Models},
  author={Shuai, Zitao and Xu, Zongzhe and Yang, David and Wang, Wei and Yang, Yuzhe},
  journal={arXiv preprint arXiv:2603.00190},
  year={2026}
}

License

MIT — matching the original OSF-Base license.