sensorlm-rs 0.1.0

SensorLM – wearable sensor foundation model in Rust (Burn + WGPU)
Documentation
//! Text transformer encoder.
//!
//! Encodes token sequences into fixed-size L2-normalised embeddings using
//! a multi-layer bidirectional transformer (depth set by
//! `TextEncoderConfig::depth`, e.g. 12 layers) with masked mean-pooling.
//!
//! # Architecture
//!
//! ```text
//! token_ids (B, L)
//!   → TokenEmbedding(vocab_size, D) + PositionalEmbedding(max_len, D)
//!   → Dropout
//!   → [EncoderBlock × depth]
//!   → LayerNorm
//!   → MaskedMeanPool  → (B, D)
//!   → Linear(D, out)  → (B, out)  [only when `out_dim` is set; otherwise stays (B, D)]
//!   → L2-normalise    → (B, out)
//! ```

use burn::{
    module::{Module, Param},
    nn::{
        Dropout, DropoutConfig, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig,
        Linear, LinearConfig,
    },
    tensor::{
        backend::Backend,
        Distribution, Int, Tensor,
    },
};

use crate::config::TextEncoderConfig;
use crate::model::sensor_encoder::{EncoderBlock, l2_normalize};

/// Bidirectional transformer text encoder.
///
/// Turns a batch of token-ID sequences into one L2-normalised embedding per
/// sequence. Instances are built with [`TextEncoder::new`] and run with
/// [`TextEncoder::forward`].
#[derive(Module, Debug)]
pub struct TextEncoder<B: Backend> {
    /// Token embedding table (`vocab_size` entries of width `d_model`).
    tok_embed: Embedding<B>,
    /// Learnable positional table, created with shape `(1, max_seq_len, d_model)`.
    pos_embed: Param<Tensor<B, 3>>,
    /// Stack of `depth` encoder blocks, applied in order.
    blocks:    Vec<EncoderBlock<B>>,
    /// Layer norm applied after the last encoder block, before pooling.
    norm:      LayerNorm<B>,
    /// Optional output projection; `None` when the config has no `out_dim`.
    proj:      Option<Linear<B>>,
    /// Dropout applied to the summed token + positional embeddings.
    dropout:   Dropout,
    /// Embedding width, cached for slicing/expanding in `forward`.
    d_model:   usize,
}

impl<B: Backend> TextEncoder<B> {
    /// Build a text encoder from [`TextEncoderConfig`].
    ///
    /// # Arguments
    ///
    /// * `cfg`    – model hyper-parameters (vocabulary size, width, depth, …).
    /// * `device` – device on which all parameters are allocated.
    ///
    /// The positional table is drawn from a normal distribution with mean `0`
    /// and standard deviation `sqrt(1 / d_model)`. The output projection is
    /// only created when `cfg.out_dim` is `Some`.
    pub fn new(cfg: &TextEncoderConfig, device: &B::Device) -> Self {
        let tok_embed = EmbeddingConfig::new(cfg.vocab_size, cfg.d_model).init(device);

        // Learnable positional table with a leading broadcast axis of size 1.
        let pos = Tensor::<B, 3>::random(
            [1, cfg.max_seq_len, cfg.d_model],
            Distribution::Normal(0.0, (1.0 / cfg.d_model as f64).sqrt()),
            device,
        );

        // NOTE(review): the literal `0` is the fifth argument of
        // `EncoderBlock::new`, whose meaning is defined in `sensor_encoder`
        // and is not visible from this file — confirm before changing it.
        let blocks: Vec<EncoderBlock<B>> = (0..cfg.depth)
            .map(|_| EncoderBlock::new(cfg.d_model, cfg.num_heads, cfg.mlp_dim, cfg.dropout, 0, device))
            .collect();

        let norm = LayerNormConfig::new(cfg.d_model).init(device);
        let proj = cfg.out_dim.map(|out| LinearConfig::new(cfg.d_model, out).init(device));

        Self {
            tok_embed,
            pos_embed: Param::from_tensor(pos),
            blocks,
            norm,
            proj,
            dropout: DropoutConfig::new(cfg.dropout).init(),
            d_model: cfg.d_model,
        }
    }

    /// Encode token sequences to L2-normalised embeddings.
    ///
    /// # Arguments
    ///
    /// * `input_ids`      – `(B, L)` token IDs.
    /// * `attention_mask` – `(B, L)` mask; `1` = real token, `0` = padding.
    ///
    /// # Returns
    ///
    /// A `(B, out_dim)` tensor when an output projection is configured,
    /// otherwise `(B, d_model)`. The result is passed through `l2_normalize`.
    ///
    /// The mask is used only for the mean-pooling step; it is not passed to
    /// the encoder blocks.
    /// NOTE(review): confirm whether padding positions are meant to take part
    /// in attention inside the blocks.
    ///
    /// # Panics
    ///
    /// Panics if `L` is larger than the positional table length
    /// (`max_seq_len` at construction time). The underlying tensor operations
    /// may also panic on incompatible shapes.
    pub fn forward(
        &self,
        input_ids: Tensor<B, 2, Int>,
        attention_mask: Tensor<B, 2, Int>,
    ) -> Tensor<B, 2> {
        let [batch, seq] = input_ids.dims();

        // Fail early with a readable message instead of an opaque
        // out-of-range slice error from the backend.
        let pos_table = self.pos_embed.val();
        let [_, max_seq_len, _] = pos_table.dims();
        assert!(
            seq <= max_seq_len,
            "TextEncoder::forward: sequence length {seq} exceeds max_seq_len {max_seq_len}"
        );

        // Token + positional embeddings.
        let tok = self.tok_embed.forward(input_ids);
        let pos = pos_table
            .slice([0..1, 0..seq, 0..self.d_model])
            .expand([batch, seq, self.d_model]);

        let mut x = tok + pos;
        x = self.dropout.forward(x);

        for block in &self.blocks {
            x = block.forward(x);
        }
        x = self.norm.forward(x);

        // Masked mean pool.
        // unsqueeze_dim::<3>(2) inserts a dimension at index 2: (B,L) → (B,L,1)
        let mask: Tensor<B, 3> = attention_mask
            .float()
            .unsqueeze_dim::<3>(2)
            .expand([batch, seq, self.d_model]);

        let sum    = (x * mask.clone()).sum_dim(1);
        // Clamp so a fully padded row divides by 1 rather than 0.
        let counts = mask.sum_dim(1).clamp_min(1.0f32);
        let pooled: Tensor<B, 2> = (sum / counts).squeeze(1);

        let projected = match &self.proj {
            Some(p) => p.forward(pooled),
            None    => pooled,
        };

        l2_normalize(projected)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::tensor::Tensor;

    type B = NdArray;
    /// Device type of the test backend.
    type Dev = <B as burn::tensor::backend::Backend>::Device;

    /// Small configuration so the tests run quickly: two blocks, width 32,
    /// dropout disabled, and an output projection of width 32.
    fn tiny_cfg() -> TextEncoderConfig {
        TextEncoderConfig {
            vocab_size: 100,
            max_seq_len: 32,
            d_model: 32,
            depth: 2,
            num_heads: 4,
            mlp_dim: 64,
            dropout: 0.0,
            out_dim: Some(32),
        }
    }

    /// Builds an encoder from [`tiny_cfg`] on the default device and returns
    /// both, so each test can create its input tensors on the same device.
    fn tiny_encoder() -> (TextEncoder<B>, Dev) {
        let device: Dev = Default::default();
        let encoder = TextEncoder::<B>::new(&tiny_cfg(), &device);
        (encoder, device)
    }

    /// A padded batch of two sequences yields a `(2, 32)` output.
    #[test]
    fn test_text_encoder_forward() {
        let (encoder, device) = tiny_encoder();

        let token_ids = Tensor::<B, 2, Int>::from_ints([[1, 2, 3, 0, 0], [4, 5, 6, 7, 0]], &device);
        let pad_mask = Tensor::<B, 2, Int>::from_ints([[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]], &device);

        let embeddings = encoder.forward(token_ids, pad_mask);
        assert_eq!(embeddings.dims(), [2, 32]);
    }

    /// Every output row has Euclidean norm 1 (within `1e-5`).
    #[test]
    fn test_output_unit_norm() {
        let (encoder, device) = tiny_encoder();

        let token_ids = Tensor::<B, 2, Int>::from_ints([[1, 2, 3]], &device);
        let pad_mask = Tensor::<B, 2, Int>::from_ints([[1, 1, 1]], &device);

        let embeddings = encoder.forward(token_ids, pad_mask);

        // Row-wise L2 norm: sqrt(sum(x^2)) along the feature axis.
        let squared_sums = embeddings.powf_scalar(2.0).sum_dim(1);
        let row_norms: Vec<f32> = squared_sums
            .sqrt()
            .into_data()
            .to_vec::<f32>()
            .unwrap();

        row_norms.into_iter().for_each(|n| {
            assert!((n - 1.0).abs() < 1e-5, "Expected unit norm, got {n}");
        });
    }
}