osf-rs 0.0.1

OSF Sleep Foundation Model — inference in Rust with Burn ML
# osf-rs — OSF Sleep Foundation Model in Rust

Pure-Rust inference for the [OSF (Open Sleep Foundation)](https://huggingface.co/yang-ai-lab/OSF-Base) model, built on [Burn 0.20](https://burn.dev).

OSF is a Vision Transformer (ViT-Base) trained with DINO+iBOT self-distillation on 166,500 hours of polysomnography data from 9 public sleep cohorts. It produces general-purpose representations from 12-channel PSG signals.

## Features

- **Numerical parity** with the Python OSF-Base backbone (`vit1d_cls.py`): max error < 8e-6
- **CPU** (NdArray + Rayon/Accelerate/OpenBLAS) and **GPU** (wgpu: Metal/Vulkan/DX12)
- **7.1× GPU speedup** over CPU for single-epoch inference on Apple M4 Pro (Metal)
- Loads weights from `.safetensors` format
- Outputs both the CLS embedding and per-patch embeddings

## Benchmarks

Benchmarked on Apple M4 Pro — CPU (NdArray + Apple Accelerate) vs GPU (wgpu Metal/MSL).

### Inference Latency

| Backend | Mean | Min | Max | Std | Throughput |
|---------|------|-----|-----|-----|------------|
| **CPU** (Accelerate) | 103.5 ms | 102.2 ms | 104.6 ms | 0.7 ms | 9.7 epochs/s |
| **GPU** (Metal) | 14.6 ms | 14.0 ms | 15.4 ms | 0.5 ms | 68.3 epochs/s |

> **GPU is 7.1× faster** than CPU for single-epoch inference.

![Inference Latency](./figures/inference_latency.png)
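The Mean, Min, Max, Std, and Throughput columns can be derived from raw per-run timings. Below is a minimal std-only sketch, not the code in `examples/benchmark.rs`. `LatencyStats` and `latency_stats` are hypothetical names, and the sketch assumes a population standard deviation and a throughput of 1000 / mean latency (one epoch per run), so its figures may not match the benchmark's own to the last digit.

```rust
/// Summary statistics over per-run latencies in milliseconds,
/// mirroring the columns of the latency table (hypothetical type).
#[derive(Debug)]
pub struct LatencyStats {
    pub mean_ms: f64,
    pub min_ms: f64,
    pub max_ms: f64,
    /// Assumption: population standard deviation (divides by n).
    pub std_ms: f64,
    /// Assumption: epochs per second = 1000 / mean latency, one epoch per run.
    pub throughput: f64,
}

/// Computes the statistics; returns `None` for an empty slice.
pub fn latency_stats(samples_ms: &[f64]) -> Option<LatencyStats> {
    if samples_ms.is_empty() {
        return None;
    }
    let n = samples_ms.len() as f64;
    let mean = samples_ms.iter().sum::<f64>() / n;
    let min = samples_ms.iter().cloned().fold(f64::INFINITY, f64::min);
    let max = samples_ms.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    // Mean squared deviation from the mean.
    let var = samples_ms.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
    Some(LatencyStats {
        mean_ms: mean,
        min_ms: min,
        max_ms: max,
        std_ms: var.sqrt(),
        throughput: 1000.0 / mean,
    })
}
```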

### Batch Size Scaling

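Per-epoch cost falls as the batch size grows on both backends: between batch sizes 1 and 16 it drops from 105.2 ms to 29.1 ms on CPU and from 13.1 ms to 3.8 ms on GPU. On GPU, a batch of 2 takes almost the same total time as a single epoch: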

| Batch Size | CPU Total | CPU/Epoch | GPU Total | GPU/Epoch |
|------------|-----------|-----------|-----------|-----------|
| 1 | 105.2 ms | 105.2 ms | 13.1 ms | 13.1 ms |
| 2 | 126.2 ms | 63.1 ms | 13.6 ms | 6.8 ms |
| 4 | 176.5 ms | 44.1 ms | 19.8 ms | 5.0 ms |
| 8 | 272.4 ms | 34.1 ms | 38.3 ms | 4.8 ms |
| 16 | 465.4 ms | 29.1 ms | 61.4 ms | 3.8 ms |

![Batch Scaling](./figures/batch_scaling.png)
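The per-epoch columns are the total batch latency divided by the batch size, and the GPU advantage is the ratio of CPU time to GPU time. A small sketch with hypothetical helper names:

```rust
/// Per-epoch cost of a batched run: total batch latency / batch size.
pub fn per_epoch_ms(total_ms: f64, batch_size: usize) -> f64 {
    assert!(batch_size > 0, "batch size must be positive");
    total_ms / batch_size as f64
}

/// How many times faster the GPU is than the CPU for the same workload.
pub fn speedup(cpu_ms: f64, gpu_ms: f64) -> f64 {
    cpu_ms / gpu_ms
}
```

At batch size 16, `speedup(465.4, 61.4)` is about 7.6, slightly above the single-epoch 7.1×.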

### Latency Distribution

Latency is stable on both backends, with a standard deviation of 0.7 ms on CPU and 0.5 ms on GPU:

![Latency Distribution](./figures/latency_distribution.png)

### Channel Scaling

Latency does not depend on how many channels carry signal: the model always takes 12 input channels, and missing channels are zero-padded:

![Channel Scaling](./figures/channel_scaling.png)

### Model Load Time

Weight loading from safetensors takes ~400 ms on both backends (I/O bound):

![Load Time](./figures/load_time.png)

### Run Benchmarks Yourself

```bash
# Full CPU + GPU benchmark with charts
./bench.sh

# Custom warmup/runs
./bench.sh 5 20

# Manual
cargo run --example benchmark --release --features ndarray,blas-accelerate -- \
    --weights data/osf_backbone.safetensors --json
```

## Quick Start

### 1. Export weights

```bash
# Install dependencies (torch is needed to read the .pth checkpoint)
pip install safetensors torch

# Convert .pth → .safetensors
python scripts/export_safetensors.py \
    --input ../OSF-Base/osf_backbone.pth \
    --output data/osf_backbone.safetensors
```
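To sanity-check the exported file without loading the model, you can read its header. The safetensors format begins with an 8-byte little-endian `u64` giving the length N of a UTF-8 JSON header, followed by those N bytes and then the raw tensor data. The std-only sketch below reads just that header; `read_safetensors_header` and the 100 MiB cap are hypothetical choices, not part of osf-rs.

```rust
use std::io::{self, Read};

/// Hypothetical sanity cap on the header size, to avoid huge allocations
/// when the file is not actually a safetensors file.
const MAX_HEADER_BYTES: u64 = 100 * 1024 * 1024;

/// Reads the JSON header of a safetensors stream: an 8-byte little-endian
/// length N followed by N bytes of UTF-8 JSON.
pub fn read_safetensors_header<R: Read>(reader: &mut R) -> io::Result<String> {
    let mut len_bytes = [0u8; 8];
    reader.read_exact(&mut len_bytes)?;
    let len = u64::from_le_bytes(len_bytes);
    if len > MAX_HEADER_BYTES {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "header too large"));
    }
    let mut buf = vec![0u8; len as usize];
    reader.read_exact(&mut buf)?;
    let header = String::from_utf8(buf)
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
    // The header must be a JSON object.
    if !header.trim_start().starts_with('{') {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "header is not a JSON object"));
    }
    Ok(header)
}
```

Pass it a `std::fs::File` opened on `data/osf_backbone.safetensors`; the returned JSON lists each tensor's name, dtype, shape, and byte offsets.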

### 2. Build & run

```bash
# CPU (default)
cargo run --release --bin infer -- --weights data/osf_backbone.safetensors -v

# macOS GPU (Metal)
cargo run --release --no-default-features --features metal --bin infer -- \
    --weights data/osf_backbone.safetensors -v

# Linux GPU (Vulkan)
cargo run --release --no-default-features --features vulkan --bin infer -- \
    --weights data/osf_backbone.safetensors -v
```

### 3. Use as a library

```rust
use std::path::Path;

use burn::backend::ndarray::NdArrayDevice;
use burn::backend::NdArray as B;
use osf_rs::{build_batch, ModelConfig, OsfEncoder};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let device = NdArrayDevice::Cpu;

    // Load model; the second tuple element is not used here
    let (encoder, _ms) = OsfEncoder::<B>::load_with_config(
        ModelConfig::default(),
        Path::new("data/osf_backbone.safetensors"),
        device.clone(),
    )?;

    // Build input: 12 PSG channels × 1920 samples (64 Hz × 30 s)
    let signal = vec![0.0f32; 12 * 1920];
    let batch = build_batch::<B>(signal, 12, 1920, &device);

    // Extract embeddings
    let emb = encoder.run_batch(&batch)?;
    println!("CLS embedding: {} dims", emb.embed_dim); // 768
    println!("Patch embeddings: {} patches", emb.num_patches); // 90
    Ok(())
}
```

## Architecture

```
PSG signal [B, 12, 1920]  (viewed as a 1-channel 12 × 1920 image)
    ↓ Conv2d(1, 768, (4, 64), stride=(4, 64))  (2D patchification)
[B, 90, 768] patch tokens  (12/4 × 1920/64 = 3 × 30 patches)
    ↓ prepend CLS token
[B, 91, 768]
    ↓ + positional embedding [1, 91, 768]
    ↓ 12× TransformerBlock (PreNorm → Attention → Residual → PreNorm → FFN → Residual)
    ↓ LayerNorm
CLS: [B, 768]  |  Patches: [B, 90, 768]
```
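The Conv2d patchification step is equivalent to cutting the 12 × 1920 signal into non-overlapping 4 × 64 tiles and applying the same linear projection (plus bias) to each flattened tile. The std-only sketch below shows this with plain loops; it is an illustration, not the osf-rs code in `patch_embed.rs`. It assumes the signal is row-major `[channels][samples]`, the kernel is laid out like a PyTorch Conv2d weight `[embed_dim, 1, 4, 64]`, and patches are ordered channel-group-major (patch index = group × 30 + time step), which is the usual ViT convention for flattening a `[B, 768, 3, 30]` conv output. `patch_embed` is a hypothetical name.

```rust
/// Non-overlapping patch embedding (Conv2d with kernel == stride).
/// `signal`: row-major [channels][samples].
/// `weight`: row-major [embed_dim][ph][pw]; `bias`: [embed_dim].
/// Returns row-major [num_patches][embed_dim], patch = group * n_time + t.
pub fn patch_embed(
    signal: &[f32],
    channels: usize,
    samples: usize,
    ph: usize,
    pw: usize,
    weight: &[f32],
    bias: &[f32],
) -> Vec<f32> {
    assert_eq!(signal.len(), channels * samples);
    assert!(channels % ph == 0 && samples % pw == 0, "input must tile exactly");
    let embed_dim = bias.len();
    assert_eq!(weight.len(), embed_dim * ph * pw);
    let (n_groups, n_time) = (channels / ph, samples / pw);
    let mut out = vec![0.0f32; n_groups * n_time * embed_dim];
    for g in 0..n_groups {
        for t in 0..n_time {
            let patch = g * n_time + t;
            for e in 0..embed_dim {
                // Dot product of the 4 × 64 tile with kernel row `e`.
                let mut acc = bias[e];
                for i in 0..ph {
                    for j in 0..pw {
                        let x = signal[(g * ph + i) * samples + t * pw + j];
                        acc += weight[(e * ph + i) * pw + j] * x;
                    }
                }
                out[patch * embed_dim + e] = acc;
            }
        }
    }
    out
}
```

With `embed_dim = 768` and a 12 × 1920 input, the result holds 90 × 768 values, matching the `[B, 90, 768]` patch tokens. The loop order favors clarity; a batched matrix multiply accumulates in a different order, so results can differ by rounding.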

### Input Format

| Property | Value |
|----------|-------|
| Channels | 12: ECG, EMG_Chin, EMG_LLeg, EMG_RLeg, ABD, THX, NP, SN, EOG_E1_A2, EOG_E2_A1, EEG_C3_A2, EEG_C4_A1 |
| Sample Rate | 64 Hz |
| Epoch Length | 30 seconds |
| Input Shape | `[B, 12, 1920]` |
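When a recording lacks some of these channels, the missing rows are zero-filled, as noted under Channel Scaling. The sketch below packs whatever channels are available into a flat 12 × 1920 buffer in table order. It assumes `build_batch` expects channel-major data (each channel's 1920 samples contiguous), which fits the `[B, 12, 1920]` shape but is not confirmed here. `pack_epoch` and the constants are hypothetical names, and resampling to 64 Hz is out of scope.

```rust
/// Channel order from the Input Format table.
pub const CHANNELS: [&str; 12] = [
    "ECG", "EMG_Chin", "EMG_LLeg", "EMG_RLeg", "ABD", "THX", "NP", "SN",
    "EOG_E1_A2", "EOG_E2_A1", "EEG_C3_A2", "EEG_C4_A1",
];
pub const SAMPLE_RATE_HZ: usize = 64;
pub const EPOCH_SECONDS: usize = 30;
pub const SAMPLES_PER_EPOCH: usize = SAMPLE_RATE_HZ * EPOCH_SECONDS; // 1920

/// Packs the available channels into a flat [12][1920] row-major buffer,
/// zero-filling any channel that is not provided. Each provided channel
/// must already be at 64 Hz and hold exactly one 30 s epoch.
pub fn pack_epoch(available: &[(&str, &[f32])]) -> Result<Vec<f32>, String> {
    let mut out = vec![0.0f32; CHANNELS.len() * SAMPLES_PER_EPOCH];
    for (name, data) in available {
        let idx = CHANNELS
            .iter()
            .position(|c| c == name)
            .ok_or_else(|| format!("unknown channel: {}", name))?;
        if data.len() != SAMPLES_PER_EPOCH {
            return Err(format!(
                "{}: expected {} samples, got {}",
                name, SAMPLES_PER_EPOCH, data.len()
            ));
        }
        let start = idx * SAMPLES_PER_EPOCH;
        out[start..start + SAMPLES_PER_EPOCH].copy_from_slice(data);
    }
    Ok(out)
}
```

The resulting `Vec<f32>` has the length expected by the `signal` argument in the library example above.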

### Model Variants

| Variant | Width | Depth | Heads | MLP | Params |
|---------|-------|-------|-------|-----|--------|
| vit_nano | 128 | 6 | 4 | 512 | ~1M |
| vit_tiny | 192 | 12 | 3 | 768 | ~5M |
| vit_small | 384 | 12 | 6 | 1536 | ~21M |
| vit_middle | 512 | 12 | 8 | 2048 | ~38M |
| vit_base | 768 | 12 | 12 | 3072 | ~86M |
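The Params column can be sanity-checked with a back-of-the-envelope count. The sketch assumes standard pre-norm blocks with biased QKV, output-projection, and two-layer FFN weights, two LayerNorms per block, a biased 4 × 64 patch-embedding kernel, a CLS token, a 91-token positional embedding (90 patches + CLS), and a final LayerNorm. Any of these may differ from the real checkpoints, and `vit_params` is a hypothetical helper.

```rust
/// Rough parameter count for a ViT variant under the assumptions above.
/// d: width, depth: number of blocks, mlp: FFN hidden size,
/// ph/pw: patch kernel size, tokens: positional-embedding length.
pub fn vit_params(d: u64, depth: u64, mlp: u64, ph: u64, pw: u64, tokens: u64) -> u64 {
    let qkv = 3 * d * d + 3 * d; // fused Q/K/V projection
    let proj = d * d + d; // attention output projection
    let ffn = (d * mlp + mlp) + (mlp * d + d); // two linear layers
    let norms = 4 * d; // two LayerNorms (weight + bias each)
    let block = qkv + proj + ffn + norms;
    let patch_embed = ph * pw * d + d; // single-input-channel conv kernel + bias
    let cls = d;
    let pos = tokens * d;
    let final_norm = 2 * d;
    depth * block + patch_embed + cls + pos + final_norm
}
```

For vit_base, `vit_params(768, 12, 3072, 4, 64, 91)` gives 85,324,032, close to the ~86M listed.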

## Backend Features

| Feature | Description |
|---------|-------------|
| `ndarray` (default) | CPU via NdArray + Rayon |
| `blas-accelerate` | CPU + Apple Accelerate BLAS |
| `openblas-system` | CPU + system OpenBLAS |
| `wgpu` | GPU via wgpu (auto-detect Metal/Vulkan/DX12) |
| `metal` | GPU with native Metal shaders (macOS) |
| `vulkan` | GPU with native Vulkan shaders (Linux/Windows) |

## Parity Testing

Python ↔ Rust numerical parity, measured on the exported test vectors, is close to single-precision (f32) round-off:

| Output | Max Error | Mean Error |
|--------|-----------|------------|
| CLS embedding (768-d) | 3.04e-6 | 7.05e-7 |
| Patch embeddings (90×768) | 7.81e-6 | 1.08e-6 |

```bash
# Generate test vectors
python scripts/export_parity_vectors.py

# Run parity test
cargo test --release test_forward_parity_with_python -- --nocapture
```
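The error columns can be read as element-wise absolute differences over the flattened outputs. A minimal sketch of that computation follows; whether `tests/python_parity.rs` uses exactly this definition is an assumption, and `abs_error_stats` is a hypothetical name. NaN differences are treated as infinite so that a NaN cannot silently pass.

```rust
/// Max and mean absolute element-wise difference between a reference
/// output (e.g. Python) and an actual output (e.g. Rust).
pub fn abs_error_stats(reference: &[f32], actual: &[f32]) -> (f32, f32) {
    assert_eq!(reference.len(), actual.len(), "length mismatch");
    assert!(!reference.is_empty(), "empty input");
    let mut max = 0.0f32;
    let mut sum = 0.0f64; // accumulate in f64 to limit summation error
    for (r, a) in reference.iter().zip(actual) {
        let d = (r - a).abs();
        // f32::max ignores NaN, so map NaN to infinity explicitly.
        let d = if d.is_nan() { f32::INFINITY } else { d };
        max = max.max(d);
        sum += d as f64;
    }
    (max, (sum / reference.len() as f64) as f32)
}
```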

## Project Structure

```
osf-rs/
├── src/
│   ├── lib.rs              # Public API
│   ├── config.rs           # Model configuration
│   ├── data.rs             # Input batch construction
│   ├── encoder.rs          # High-level encoder API
│   ├── weights.rs          # Safetensors weight loading
│   ├── model/
│   │   ├── vit.rs          # ViT with CLS token
│   │   ├── patch_embed.rs  # 1D/2D patch embedding
│   │   ├── transformer_block.rs
│   │   ├── attention.rs    # Multi-head self-attention
│   │   ├── feedforward.rs  # FFN (GELU)
│   │   └── norm.rs         # LayerNorm
│   └── bin/
│       └── infer.rs        # CLI inference
├── examples/
│   ├── embed.rs            # Embedding extraction
│   └── benchmark.rs        # Latency benchmark
├── scripts/
│   ├── export_safetensors.py
│   ├── export_parity_vectors.py
│   └── generate_figures.py
├── bench.sh                # Full CPU+GPU benchmark runner
├── figures/                # Generated benchmark charts
└── tests/
    ├── forward_pass.rs     # Shape tests
    └── python_parity.rs    # Numerical parity
```

## Citation

Cite this repo:

```bibtex
@misc{kosmyna2026osf,
  title={OSF Sleep Foundation Model — inference in Rust with Burn ML},
  author={Nataliya Kosmyna},
  year={2026}
}
```

and the original OSF paper:

```bibtex
@article{shuai2026osf,
  title={OSF: On Pre-training and Scaling of Sleep Foundation Models},
  author={Shuai, Zitao and Xu, Zongzhe and Yang, David and Wang, Wei and Yang, Yuzhe},
  journal={arXiv preprint arXiv:2603.00190},
  year={2026}
}
```

## License

MIT — matching the original OSF-Base license.