// sensorlm-rs 0.1.0
//
// SensorLM – wearable sensor foundation model in Rust (Burn + WGPU)
// Documentation
//! Two-tower SensorLM model and Burn training/validation step wrappers.
//!
//! # Two-tower architecture
//!
//! ```text
//!  sensor_tensor (B,T,C) ──► SensorEncoder ──► z_s (B,D) ─┐
//!                                                           ├─► SigLIP loss
//!  token_ids (B,L)       ──► TextEncoder   ──► z_t (B,D) ─┘
//!
//!  S[i,j] = temperature · dot(z_s[i], z_t[j]) + bias
//!  L = -mean_ij[ log(sigmoid( y[i,j] · S[i,j] )) ]
//!  y[i,j] = +1 if i==j, -1 otherwise
//! ```

use burn::{
    module::{Module, Param},
    tensor::{
        backend::{AutodiffBackend, Backend},
        ElementConversion, Int, Tensor,
    },
    train::{
        metric::{Adaptor, LossInput},
        TrainOutput, TrainStep, ValidStep,
    },
    data::dataloader::batcher::Batcher,
};

use crate::config::SensorLMConfig;
use crate::data::dataset::SensorTextItem;

use crate::loss::siglip_loss;
use crate::model::sensor_encoder::SensorEncoder;
use crate::model::text_encoder::TextEncoder;

// ===========================================================================
// Model
// ===========================================================================

/// The combined SensorLM two-tower model.
///
/// Holds one encoder per modality plus the two scalar parameters that turn
/// the raw embedding dot products into SigLIP logits
/// (`logits = exp(log_temperature) * <z_sensor, z_text> + bias`).
#[derive(Module, Debug)]
pub struct SensorLMModel<B: Backend> {
    /// ViT sensor encoder.
    pub sensor_encoder: SensorEncoder<B>,
    /// Text transformer encoder.
    pub text_encoder: TextEncoder<B>,
    /// Log-temperature scalar (temperature = exp(log_temperature) > 0).
    ///
    /// Stored as a one-element tensor; initialised to `ln(temperature_init)`.
    pub log_temperature: Param<Tensor<B, 1>>,
    /// SigLIP bias scalar.
    ///
    /// Stored as a one-element tensor; initialised to `bias_init`.
    pub bias: Param<Tensor<B, 1>>,
}

impl<B: Backend> SensorLMModel<B> {
    /// Construct from a [`SensorLMConfig`].
    ///
    /// Builds both encoders from their sub-configs and initialises the two
    /// SigLIP scalars. The temperature is stored in log-space
    /// (`ln(cfg.temperature_init)`) so that the effective temperature
    /// `exp(log_temperature)` is always positive.
    ///
    /// `cfg.temperature_init` must be strictly positive: the logarithm of a
    /// non-positive value is NaN or negative infinity.
    pub fn new(cfg: &SensorLMConfig, device: &B::Device) -> Self {
        let log_temp = Tensor::<B, 1>::from_floats([cfg.temperature_init.ln()], device);
        let bias = Tensor::<B, 1>::from_floats([cfg.bias_init], device);

        Self {
            sensor_encoder: SensorEncoder::new(&cfg.sensor_encoder, device),
            text_encoder: TextEncoder::new(&cfg.text_encoder, device),
            log_temperature: Param::from_tensor(log_temp),
            bias: Param::from_tensor(bias),
        }
    }

    /// Encode sensor data → `(B, D)` embeddings.
    ///
    /// Thin wrapper around the sensor encoder's `forward`; the embeddings are
    /// expected to be L2-normalised by the encoder itself.
    pub fn encode_sensor(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
        self.sensor_encoder.forward(x)
    }

    /// Encode text → `(B, D)` embeddings.
    ///
    /// Thin wrapper around the text encoder's `forward`; the embeddings are
    /// expected to be L2-normalised by the encoder itself.
    pub fn encode_text(
        &self,
        input_ids: Tensor<B, 2, Int>,
        attention_mask: Tensor<B, 2, Int>,
    ) -> Tensor<B, 2> {
        self.text_encoder.forward(input_ids, attention_mask)
    }

    /// Compute the `(B, B)` similarity matrix
    /// `S = exp(log_temperature) · z_sensor · z_textᵀ + bias`.
    ///
    /// The temperature and bias are applied as broadcast `[1, 1]` tensors
    /// rather than as host-side scalars. Reading them out with
    /// `into_scalar()` would detach them from the autodiff graph (so neither
    /// parameter would ever receive a gradient) and would force a
    /// device-to-host synchronisation on every call.
    pub fn similarity_matrix(
        &self,
        z_sensor: Tensor<B, 2>,
        z_text: Tensor<B, 2>,
    ) -> Tensor<B, 2> {
        // [1] -> [1, 1] so both scalars broadcast over the (B, B) matrix.
        let temperature: Tensor<B, 2> = self.log_temperature.val().exp().unsqueeze();
        let bias: Tensor<B, 2> = self.bias.val().unsqueeze();

        z_sensor
            .matmul(z_text.transpose())
            .mul(temperature)
            .add(bias)
    }

    /// Full forward pass computing the SigLIP loss.
    ///
    /// Encodes both modalities, builds the similarity logits and feeds them
    /// to [`siglip_loss`]. The logits are returned alongside the loss, hence
    /// the clone passed to the loss function.
    pub fn forward(
        &self,
        sensor: Tensor<B, 3>,
        input_ids: Tensor<B, 2, Int>,
        attention_mask: Tensor<B, 2, Int>,
    ) -> SensorLMOutput<B> {
        let z_sensor = self.encode_sensor(sensor);
        let z_text = self.encode_text(input_ids, attention_mask);
        let logits = self.similarity_matrix(z_sensor, z_text);
        let loss = siglip_loss(logits.clone());
        SensorLMOutput { loss, logits }
    }
}

// ===========================================================================
// Output type
// ===========================================================================

/// Output of a SensorLM forward pass.
///
/// Produced by `SensorLMModel::forward` and consumed by the training and
/// validation steps as well as the loss-metric adaptor.
#[derive(Debug)]
pub struct SensorLMOutput<B: Backend> {
    /// Scalar SigLIP loss `(1,)`.
    pub loss: Tensor<B, 1>,
    /// `(B, B)` similarity logits.
    pub logits: Tensor<B, 2>,
}

// Teach burn's LossMetric how to extract the loss from our output type.
impl<B: Backend> Adaptor<LossInput<B>> for SensorLMOutput<B> {
    /// Hand the loss tensor to the metric; the output keeps its own copy.
    fn adapt(&self) -> LossInput<B> {
        let loss = self.loss.clone();
        LossInput::new(loss)
    }
}

// ===========================================================================
// Batch type
// ===========================================================================

/// A collated training batch.
///
/// All three tensors share the same leading batch dimension `B`.
#[derive(Debug, Clone)]
pub struct SensorLMBatch<B: Backend> {
    /// `(B, T, C)` sensor data.
    pub sensor: Tensor<B, 3>,
    /// `(B, L)` token IDs.
    pub input_ids: Tensor<B, 2, Int>,
    /// `(B, L)` attention mask.
    pub attention_mask: Tensor<B, 2, Int>,
}

// ===========================================================================
// Burn TrainStep / ValidStep
// ===========================================================================

impl<B: AutodiffBackend> TrainStep<SensorLMBatch<B>, SensorLMOutput<B>>
    for SensorLMModel<B>
{
    fn step(&self, batch: SensorLMBatch<B>) -> TrainOutput<SensorLMOutput<B>> {
        let output = self.forward(batch.sensor, batch.input_ids, batch.attention_mask);
        TrainOutput::new(self, output.loss.backward(), output)
    }
}

impl<B: Backend> ValidStep<SensorLMBatch<B>, SensorLMOutput<B>> for SensorLMModel<B> {
    fn step(&self, batch: SensorLMBatch<B>) -> SensorLMOutput<B> {
        self.forward(batch.sensor, batch.input_ids, batch.attention_mask)
    }
}

// ===========================================================================
// Batcher
// ===========================================================================

/// Converts a `Vec<SensorTextItem>` into a device-resident `SensorLMBatch`.
///
/// The tensors are created on whichever device the batcher was constructed
/// with; the stored dimensions are the target shapes used when reshaping the
/// flattened item data.
#[derive(Clone)]
pub struct SensorLMBatcher<B: Backend> {
    // Device on which the batch tensors are created.
    device:       B::Device,
    // `T`: second dimension of the `(B, T, C)` sensor tensor.
    time_steps:   usize,
    // `C`: third dimension of the `(B, T, C)` sensor tensor.
    num_channels: usize,
    // `L`: second dimension of the `(B, L)` token-id and mask tensors.
    max_seq_len:  usize,
}

impl<B: Backend> SensorLMBatcher<B> {
    /// Build a batcher that places tensors on `device` and reshapes sensor
    /// data to `(B, time_steps, num_channels)` and token data to
    /// `(B, max_seq_len)`.
    pub fn new(
        device: B::Device,
        time_steps: usize,
        num_channels: usize,
        max_seq_len: usize,
    ) -> Self {
        Self {
            device: device,
            time_steps: time_steps,
            num_channels: num_channels,
            max_seq_len: max_seq_len,
        }
    }
}

impl<B: Backend> Batcher<SensorTextItem, SensorLMBatch<B>> for SensorLMBatcher<B> {
    /// Collate `items` into a single [`SensorLMBatch`] on the batcher's device.
    ///
    /// Each item's `sensor`, `token_ids` and `attention_mask` values are
    /// appended, in item order, to flat buffers which are then uploaded as
    /// 1-D tensors and reshaped to `(B, time_steps, num_channels)` and
    /// `(B, max_seq_len)` respectively.
    ///
    /// # Panics
    ///
    /// Panics if any item does not contribute exactly
    /// `time_steps * num_channels` sensor values, or exactly `max_seq_len`
    /// token ids and mask values. Without this check, items of differing
    /// lengths whose totals happen to match would be reshaped into silently
    /// misaligned rows.
    fn batch(&self, items: Vec<SensorTextItem>) -> SensorLMBatch<B> {
        let b = items.len();
        let t = self.time_steps;
        let c = self.num_channels;
        let l = self.max_seq_len;

        let mut sensor_flat: Vec<f32> = Vec::with_capacity(b * t * c);
        let mut token_flat: Vec<i32> = Vec::with_capacity(b * l);
        let mut mask_flat: Vec<i32> = Vec::with_capacity(b * l);

        for (idx, it) in items.iter().enumerate() {
            sensor_flat.extend(it.sensor.iter().copied());
            token_flat.extend(it.token_ids.iter().copied());
            mask_flat.extend(it.attention_mask.iter().copied());

            // Every earlier item already passed, so checking the running
            // totals is equivalent to checking this item's own lengths.
            let filled = idx + 1;
            assert_eq!(
                sensor_flat.len(),
                filled * t * c,
                "item {idx}: sensor data must hold time_steps * num_channels = {} values",
                t * c,
            );
            assert_eq!(
                token_flat.len(),
                filled * l,
                "item {idx}: token_ids must hold max_seq_len = {l} values",
            );
            assert_eq!(
                mask_flat.len(),
                filled * l,
                "item {idx}: attention_mask must hold max_seq_len = {l} values",
            );
        }

        let sensor = Tensor::<B, 1>::from_floats(sensor_flat.as_slice(), &self.device)
            .reshape([b, t, c]);
        let input_ids = Tensor::<B, 1, Int>::from_ints(token_flat.as_slice(), &self.device)
            .reshape([b, l]);
        let attention_mask = Tensor::<B, 1, Int>::from_ints(mask_flat.as_slice(), &self.device)
            .reshape([b, l]);

        SensorLMBatch { sensor, input_ids, attention_mask }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use crate::config::{SensorEncoderConfig, TextEncoderConfig, PoolType, SensorLMConfig};

    type B = NdArray;

    /// A deliberately small configuration so the forward pass runs quickly
    /// on the CPU backend.
    fn tiny_config() -> SensorLMConfig {
        let sensor_encoder = SensorEncoderConfig {
            time_steps: 40,
            num_channels: 4,
            patch_h: 10,
            patch_w: 2,
            d_model: 32,
            depth: 2,
            num_heads: 4,
            mlp_dim: 64,
            dropout: 0.0,
            pool_type: PoolType::Gap,
            head_zeroinit: false,
            attn_chunk_size: 0,
        };
        let text_encoder = TextEncoderConfig {
            vocab_size: 100,
            max_seq_len: 16,
            d_model: 32,
            depth: 2,
            num_heads: 4,
            mlp_dim: 64,
            dropout: 0.0,
            out_dim: Some(32),
        };

        SensorLMConfig {
            sensor_encoder,
            text_encoder,
            embed_dim: 32,
            temperature_init: 10.0,
            bias_init: -10.0,
        }
    }

    /// A batch of two pairs must yield 2×2 logits and a loss that is not NaN.
    #[test]
    fn test_sensorlm_forward() {
        let device: <B as burn::tensor::backend::Backend>::Device = Default::default();
        let model = SensorLMModel::<B>::new(&tiny_config(), &device);

        let sensor = Tensor::<B, 3>::zeros([2, 40, 4], &device);
        let ids = Tensor::<B, 2, Int>::from_ints([[1, 2, 3, 0], [4, 5, 6, 7]], &device);
        let mask = Tensor::<B, 2, Int>::from_ints([[1, 1, 1, 0], [1, 1, 1, 1]], &device);

        let SensorLMOutput { loss, logits } = model.forward(sensor, ids, mask);

        assert_eq!(logits.dims(), [2, 2]);

        let loss: f32 = loss.into_scalar();
        assert!(!loss.is_nan(), "Loss must not be NaN");
    }
}