osf-rs 0.0.1

OSF Sleep Foundation Model — inference in Rust with Burn ML
Documentation
/// Multi-Head Self-Attention for OSF ViT.
///
/// Python: `Attention` in vit1d_cls.py:
///   to_qkv = nn.Linear(input_dim, inner_dim * 3, bias=qkv_bias)
///   to_out = nn.Sequential(nn.Linear(inner_dim, output_dim), nn.Dropout)
///
/// qkv = to_qkv(x).chunk(3, dim=-1)
/// q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d'), qkv)
/// attn = softmax(q @ k^T / sqrt(d))   # d = dim_head
/// out = attn @ v
/// out = rearrange(out, 'b h n d -> b n (h d)')
/// out = to_out(out)

use burn::prelude::*;
use burn::nn::{Linear, LinearConfig};
use burn::tensor::activation::softmax;

/// Multi-head self-attention layer: a fused Q/K/V projection, scaled
/// dot-product attention per head, and an output projection.
///
/// NOTE(review): the field names are presumably the keys used when loading a
/// weight record/checkpoint (e.g. one converted from the Python model) —
/// confirm before renaming or reordering them.
#[derive(Module, Debug)]
pub struct Attention<B: Backend> {
    /// Fused query/key/value projection, `input_dim -> 3 * n_heads * dim_head`
    /// (bias controlled by `qkv_bias` in `Attention::new`).
    pub to_qkv: Linear<B>,
    /// Output projection, `n_heads * dim_head -> output_dim`; `Attention::new`
    /// always builds it with a bias.
    pub to_out: Linear<B>,
    /// Number of attention heads.
    pub n_heads: usize,
    /// Width of a single head; attention scores are scaled by `dim_head^-0.5`.
    pub dim_head: usize,
}

impl<B: Backend> Attention<B> {
    /// Creates the layer with freshly initialised projection weights.
    ///
    /// * `input_dim` – width of each incoming token embedding.
    /// * `output_dim` – width produced by the output projection.
    /// * `heads` – number of attention heads.
    /// * `dim_head` – width of a single head; the inner width is `heads * dim_head`.
    /// * `qkv_bias` – whether the fused Q/K/V projection has a bias term.
    /// * `device` – device on which the parameters are allocated.
    pub fn new(
        input_dim: usize,
        output_dim: usize,
        heads: usize,
        dim_head: usize,
        qkv_bias: bool,
        device: &B::Device,
    ) -> Self {
        let width = heads * dim_head;
        let qkv_config = LinearConfig::new(input_dim, 3 * width).with_bias(qkv_bias);
        let out_config = LinearConfig::new(width, output_dim).with_bias(true);

        // Struct-literal fields are evaluated top to bottom, so `to_qkv` is still
        // initialised before `to_out` (keeps any seeded RNG draws in the same order).
        Self {
            to_qkv: qkv_config.init(device),
            to_out: out_config.init(device),
            n_heads: heads,
            dim_head,
        }
    }

    /// Applies multi-head self-attention to a batch of token sequences.
    ///
    /// `x` is `[batch, tokens, input_dim]`; the returned tensor is
    /// `[batch, tokens, output_dim]`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let [batch, tokens, _] = x.dims();
        let heads = self.n_heads;
        let head_dim = self.dim_head;
        let width = heads * head_dim;

        // [B, N, H*D] -> [B, N, H, D] -> [B, H, N, D]. The head index is the outer
        // factor of the last axis, matching einops 'b n (h d) -> b h n d'.
        let split_heads = |t: Tensor<B, 3>| -> Tensor<B, 4> {
            t.reshape([batch, tokens, heads, head_dim]).swap_dims(1, 2)
        };

        // The fused projection is split as [Q | K | V] along the last axis.
        let fused = self.to_qkv.forward(x);
        let query = split_heads(fused.clone().narrow(2, 0, width));
        let key = split_heads(fused.clone().narrow(2, width, width));
        let value = split_heads(fused.narrow(2, 2 * width, width));

        // softmax(Q Kᵀ / sqrt(D)) over the key axis, then weight the values.
        let scale = (head_dim as f64).powf(-0.5) as f32;
        let scores = query.matmul(key.transpose()).mul_scalar(scale);
        let context = softmax(scores, 3).matmul(value); // [B, H, N, D]

        // [B, H, N, D] -> [B, N, H, D] -> [B, N, H*D], then project to output_dim.
        let merged = context.swap_dims(1, 2).reshape([batch, tokens, width]);
        self.to_out.forward(merged)
    }
}