//! osf-rs 0.0.1
//!
//! OSF Sleep Foundation Model — inference in Rust with Burn ML.
//!
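//! Load pretrained OSF weights from a safetensors file.
//!
//! The OSF backbone checkpoint uses PyTorch `state_dict` naming. Tensors are
//! expected as `float32`; `bfloat16` tensors are also accepted and upcast to
//! `f32` on load. The weight keys follow this pattern (shapes shown for a
//! width-768 model):
//!
//! ```text
//! to_patch_embedding.weight         [768, 1, 4, 64]   — Conv2d (lead_wise != 0)
//! cls_token                         [1, 1, 768]
//! pos_embedding                     [1, 91, 768]
//! lead_emb.weight                   [3, 768]          — only if the model has a lead embedding
//! block{i}.attn.norm.weight         [768]
//! block{i}.attn.norm.bias           [768]
//! block{i}.attn.fn.to_qkv.weight    [2304, 768]
//! block{i}.attn.fn.to_qkv.bias      [2304]
//! block{i}.attn.fn.to_out.0.weight  [768, 768]
//! block{i}.attn.fn.to_out.0.bias    [768]
//! block{i}.ff.norm.weight           [768]
//! block{i}.ff.norm.bias             [768]
//! block{i}.ff.fn.net.0.weight       [3072, 768]
//! block{i}.ff.fn.net.0.bias         [3072]
//! block{i}.ff.fn.net.3.weight       [768, 3072]
//! block{i}.ff.fn.net.3.bias         [768]
//! norm.weight                       [768]
//! norm.bias                         [768]
//! ```
//!
//! When `lead_wise == 0` the patch embedding is a Conv1d instead, with weight
//! shape `[width, num_leads, patch_size]`. Checkpoint keys that the loader
//! never consumes are listed on stderr as a warning.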

use std::collections::HashMap;
use burn::prelude::*;
use half::bf16;
use safetensors::SafeTensors;

use crate::config::ModelConfig;
use crate::model::vit::OsfViT;

// ── Raw tensor map ────────────────────────────────────────────────────────────

pub struct WeightMap {
    pub tensors: HashMap<String, (Vec<f32>, Vec<usize>)>,
}

impl WeightMap {
    /// Load all tensors from a safetensors file.
    pub fn from_file(path: &str) -> anyhow::Result<Self> {
        let bytes = std::fs::read(path)?;
        let st = SafeTensors::deserialize(&bytes)?;
        let mut tensors = HashMap::with_capacity(st.len());

        for (raw_key, view) in st.tensors() {
            let key = raw_key.to_string();
            let shape: Vec<usize> = view.shape().to_vec();
            let data = view.data();

            let f32s: Vec<f32> = match view.dtype() {
                safetensors::Dtype::BF16 => data
                    .chunks_exact(2)
                    .map(|b| bf16::from_le_bytes([b[0], b[1]]).to_f32())
                    .collect(),
                safetensors::Dtype::F32 => data
                    .chunks_exact(4)
                    .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
                    .collect(),
                other => anyhow::bail!("unsupported dtype {:?} for key {key}", other),
            };

            tensors.insert(key, (f32s, shape));
        }

        Ok(Self { tensors })
    }

    /// Take a tensor by key (removes from map).
    pub fn take<B: Backend, const N: usize>(
        &mut self,
        key: &str,
        device: &B::Device,
    ) -> anyhow::Result<Tensor<B, N>> {
        let (data, shape) = self.tensors.remove(key)
            .ok_or_else(|| anyhow::anyhow!("weight key not found: {key}"))?;
        if shape.len() != N {
            anyhow::bail!("rank mismatch for {key}: expected {N}, got {}", shape.len());
        }
        Ok(Tensor::<B, N>::from_data(
            TensorData::new(data, shape),
            device,
        ))
    }

    pub fn has(&self, key: &str) -> bool {
        self.tensors.contains_key(key)
    }

    pub fn print_keys(&self) {
        let mut keys: Vec<&str> = self.tensors.keys().map(String::as_str).collect();
        keys.sort();
        for k in keys {
            let (_, s) = &self.tensors[k];
            println!("  {k:60}  {s:?}");
        }
    }

    pub fn remaining_keys(&self) -> Vec<String> {
        let mut keys: Vec<String> = self.tensors.keys().cloned().collect();
        keys.sort();
        keys
    }
}

// ── Weight assignment helpers ─────────────────────────────────────────────────

/// Copy a PyTorch `nn.Linear` weight/bias pair into a burn `Linear` layer.
///
/// PyTorch lays the weight out as `[out, in]` and burn as `[in, out]`, so the
/// weight is transposed first. The bias is only written when the target layer
/// already has a bias parameter; otherwise `b` is dropped.
fn set_linear_wb<B: Backend>(linear: &mut burn::nn::Linear<B>, w: Tensor<B, 2>, b: Tensor<B, 1>) {
    let weight_in_out = w.transpose();
    linear.weight = linear.weight.clone().map(move |_| weight_in_out);
    linear.bias = linear
        .bias
        .as_ref()
        .map(move |old_bias| old_bias.clone().map(move |_| b));
}

/// Overwrite the scale (`gamma`) and shift (`beta`) of an [`OsfLayerNorm`].
///
/// `beta` is only replaced when the wrapped norm already has one; otherwise
/// `b` is dropped.
///
/// [`OsfLayerNorm`]: crate::model::norm::OsfLayerNorm
fn set_layernorm<B: Backend>(norm: &mut crate::model::norm::OsfLayerNorm<B>, w: Tensor<B, 1>, b: Tensor<B, 1>) {
    let inner = &mut norm.inner;
    inner.gamma = inner.gamma.clone().map(move |_| w);
    inner.beta = inner
        .beta
        .as_ref()
        .map(move |old_beta| old_beta.clone().map(move |_| b));
}

/// Replace the kernel of a burn `Conv1d` with `w`, mapping over the existing
/// weight parameter rather than building a new one.
fn set_conv1d_w<B: Backend>(conv: &mut burn::nn::conv::Conv1d<B>, w: Tensor<B, 3>) {
    let previous = conv.weight.clone();
    conv.weight = previous.map(|_discarded| w);
}

/// Replace the kernel of a burn `Conv2d` with `w`, mapping over the existing
/// weight parameter rather than building a new one.
fn set_conv2d_w<B: Backend>(conv: &mut burn::nn::conv::Conv2d<B>, w: Tensor<B, 4>) {
    let previous = conv.weight.clone();
    conv.weight = previous.map(|_discarded| w);
}

// ── Full model loader ─────────────────────────────────────────────────────────

/// Load an OSF ViT model from a safetensors file.
/// Load an OSF ViT model from a safetensors file.
///
/// Reads every tensor from `weights_path`, prints how many were found to
/// stderr, then builds the model described by `cfg` on `device` and assigns
/// the weights via [`load_model_from_wm`].
///
/// # Errors
/// Propagates any error from reading/parsing the file or from weight assignment.
pub fn load_model<B: Backend>(
    cfg: &ModelConfig,
    weights_path: &str,
    device: &B::Device,
) -> anyhow::Result<OsfViT<B>> {
    WeightMap::from_file(weights_path).and_then(|mut weights| {
        eprintln!("Loading {} weight tensors...", weights.tensors.len());
        load_model_from_wm(cfg, &mut weights, device)
    })
}

/// Load an OSF ViT model from a pre-loaded [`WeightMap`].
pub fn load_model_from_wm<B: Backend>(
    cfg: &ModelConfig,
    wm: &mut WeightMap,
    device: &B::Device,
) -> anyhow::Result<OsfViT<B>> {
    let mut model = OsfViT::new(
        cfg.num_leads,
        cfg.seq_len,
        cfg.patch_size_time,
        cfg.patch_size_ch,
        cfg.lead_wise,
        cfg.width,
        cfg.depth,
        cfg.mlp_dim,
        cfg.heads,
        cfg.dim_head,
        device,
    );

    load_vit_weights(wm, &mut model, device)?;
    Ok(model)
}

fn load_vit_weights<B: Backend>(
    wm: &mut WeightMap,
    model: &mut OsfViT<B>,
    device: &B::Device,
) -> anyhow::Result<()> {
    // ── Patch embedding ─────────────────────────────────────────────────────
    if model.lead_wise == 0 {
        // Conv1d weight: [width, num_leads, patch_size]
        if let Ok(w) = wm.take::<B, 3>("to_patch_embedding.weight", device) {
            set_conv1d_w(model.patch_embed.conv1d.as_mut().unwrap(), w);
        }
    } else {
        // Conv2d weight: [width, 1, patch_size_ch, patch_size_time]
        if let Ok(w) = wm.take::<B, 4>("to_patch_embedding.weight", device) {
            set_conv2d_w(model.patch_embed.conv2d.as_mut().unwrap(), w);
        }
    }

    // ── CLS token ───────────────────────────────────────────────────────────
    if let Ok(t) = wm.take::<B, 3>("cls_token", device) {
        model.cls_token = model.cls_token.clone().map(|_| t);
    }

    // ── Positional embedding ────────────────────────────────────────────────
    if let Ok(t) = wm.take::<B, 3>("pos_embedding", device) {
        model.pos_embedding = model.pos_embedding.clone().map(|_| t);
    }

    // ── Lead embedding (lead_wise=1) ────────────────────────────────────────
    if let Some(ref mut emb) = model.lead_emb {
        if let Ok(w) = wm.take::<B, 2>("lead_emb.weight", device) {
            emb.weight = emb.weight.clone().map(|_| w);
        }
    }

    // ── Transformer blocks ──────────────────────────────────────────────────
    for (i, block) in model.blocks.iter_mut().enumerate() {
        let p = format!("block{i}");

        // Attention pre-norm: block{i}.attn.norm
        if let (Ok(w), Ok(b)) = (
            wm.take::<B, 1>(&format!("{p}.attn.norm.weight"), device),
            wm.take::<B, 1>(&format!("{p}.attn.norm.bias"), device),
        ) { set_layernorm(&mut block.attn_norm, w, b); }

        // Attention QKV: block{i}.attn.fn.to_qkv
        if let (Ok(w), Ok(b)) = (
            wm.take::<B, 2>(&format!("{p}.attn.fn.to_qkv.weight"), device),
            wm.take::<B, 1>(&format!("{p}.attn.fn.to_qkv.bias"), device),
        ) { set_linear_wb(&mut block.attn.to_qkv, w, b); }

        // Attention output: block{i}.attn.fn.to_out.0
        if let (Ok(w), Ok(b)) = (
            wm.take::<B, 2>(&format!("{p}.attn.fn.to_out.0.weight"), device),
            wm.take::<B, 1>(&format!("{p}.attn.fn.to_out.0.bias"), device),
        ) { set_linear_wb(&mut block.attn.to_out, w, b); }

        // FeedForward pre-norm: block{i}.ff.norm
        if let (Ok(w), Ok(b)) = (
            wm.take::<B, 1>(&format!("{p}.ff.norm.weight"), device),
            wm.take::<B, 1>(&format!("{p}.ff.norm.bias"), device),
        ) { set_layernorm(&mut block.ff_norm, w, b); }

        // FeedForward fc1: block{i}.ff.fn.net.0
        if let (Ok(w), Ok(b)) = (
            wm.take::<B, 2>(&format!("{p}.ff.fn.net.0.weight"), device),
            wm.take::<B, 1>(&format!("{p}.ff.fn.net.0.bias"), device),
        ) { set_linear_wb(&mut block.ff.fc1, w, b); }

        // FeedForward fc2: block{i}.ff.fn.net.3
        if let (Ok(w), Ok(b)) = (
            wm.take::<B, 2>(&format!("{p}.ff.fn.net.3.weight"), device),
            wm.take::<B, 1>(&format!("{p}.ff.fn.net.3.bias"), device),
        ) { set_linear_wb(&mut block.ff.fc2, w, b); }
    }

    // ── Final norm ──────────────────────────────────────────────────────────
    if let (Ok(w), Ok(b)) = (
        wm.take::<B, 1>("norm.weight", device),
        wm.take::<B, 1>("norm.bias", device),
    ) { set_layernorm(&mut model.norm, w, b); }

    // Report any leftover keys
    let remaining = wm.remaining_keys();
    if !remaining.is_empty() {
        eprintln!("Warning: {} unused weight keys:", remaining.len());
        for k in &remaining {
            eprintln!("  {k}");
        }
    }

    Ok(())
}