osf-rs 0.0.1

OSF Sleep Foundation Model — inference in Rust with Burn ML
Documentation
//! Feed-forward network (MLP) for the OSF ViT transformer blocks.
//!
//! Python reference: `FeedForward` in vit1d_cls.py:
//!   nn.Linear(input_dim, hidden_dim) → GELU → Dropout →
//!   nn.Linear(hidden_dim, output_dim) → Dropout
//!
//! Dropout is the identity at inference time, so both Dropout layers are omitted here.

use burn::prelude::*;
use burn::nn::{Linear, LinearConfig};
use burn::tensor::activation::gelu;

/// Two-layer MLP used inside the OSF ViT transformer blocks:
/// `fc1 → GELU → fc2` (see [`FeedForward::forward`]).
///
/// NOTE(review): with `#[derive(Module)]` the field names (and, for
/// order-sensitive recorders, their declaration order) form the saved record
/// layout — presumably they must match the converted pretrained checkpoint;
/// confirm before renaming or reordering.
#[derive(Module, Debug)]
pub struct FeedForward<B: Backend> {
    /// First projection. [`FeedForward::new`] builds it as `input_dim → hidden_dim` with bias;
    /// its output is passed through GELU in `forward`.
    pub fc1: Linear<B>,
    /// Second projection. [`FeedForward::new`] builds it as `hidden_dim → output_dim` with bias.
    pub fc2: Linear<B>,
}

impl<B: Backend> FeedForward<B> {
    /// Builds the two-layer MLP `input_dim → hidden_dim → output_dim`.
    ///
    /// Both linear layers carry a bias, mirroring the `nn.Linear` default used by
    /// the Python reference. Parameters are freshly initialised by
    /// [`LinearConfig::init`]; presumably they are overwritten by pretrained
    /// weights when the model is loaded — confirm against the weight loader.
    ///
    /// Note the argument order: `output_dim` comes *before* `hidden_dim`.
    ///
    /// # Arguments
    /// * `input_dim` - size of the last axis of the input tensor.
    /// * `output_dim` - size of the last axis of the output tensor.
    /// * `hidden_dim` - width of the intermediate (GELU) layer.
    /// * `device` - device on which the parameters are allocated.
    pub fn new(input_dim: usize, output_dim: usize, hidden_dim: usize, device: &B::Device) -> Self {
        Self {
            fc1: LinearConfig::new(input_dim, hidden_dim).with_bias(true).init(device),
            fc2: LinearConfig::new(hidden_dim, output_dim).with_bias(true).init(device),
        }
    }

    /// Applies `fc2(gelu(fc1(x)))`.
    ///
    /// Shapes: `x: [B, S, input_dim]` → `[B, S, output_dim]`. The linear layers
    /// act only on the last axis, so the first two axes pass through unchanged.
    /// When the block is built with `output_dim == input_dim` this is the usual
    /// `[B, S, dim] → [B, S, dim]` transformer MLP.
    ///
    /// The Python reference's two Dropout layers are intentionally absent:
    /// dropout is the identity at inference time.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let hidden = gelu(self.fc1.forward(x));
        self.fc2.forward(hidden)
    }
}