use burn::{
module::{Module, Param},
tensor::{
backend::{AutodiffBackend, Backend},
ElementConversion, Int, Tensor,
},
train::{
metric::{Adaptor, LossInput},
TrainOutput, TrainStep, ValidStep,
},
data::dataloader::batcher::Batcher,
};
use crate::config::SensorLMConfig;
use crate::data::dataset::SensorTextItem;
use crate::loss::siglip_loss;
use crate::model::sensor_encoder::SensorEncoder;
use crate::model::text_encoder::TextEncoder;
/// Dual-encoder model that embeds sensor tensors and tokenised text and
/// scores every sensor/text pair with a scaled, shifted dot product.
///
/// The two scalar parameters are stored as one-element tensors wrapped in
/// [`Param`] so they are registered as part of the module.
#[derive(Module, Debug)]
pub struct SensorLMModel<B: Backend> {
/// Encoder used by `encode_sensor`: rank-3 sensor tensor in, rank-2 embeddings out.
pub sensor_encoder: SensorEncoder<B>,
/// Encoder used by `encode_text`: token ids plus attention mask in, rank-2 embeddings out.
pub text_encoder: TextEncoder<B>,
/// Natural logarithm of the similarity scale; initialised to
/// `ln(cfg.temperature_init)` and exponentiated in `similarity_matrix`.
pub log_temperature: Param<Tensor<B, 1>>,
/// Additive offset applied to every similarity score; initialised to `cfg.bias_init`.
pub bias: Param<Tensor<B, 1>>,
}
impl<B: Backend> SensorLMModel<B> {
    /// Builds both encoders and the two scalar parameters from `cfg` on `device`.
    ///
    /// The temperature is stored in log-space (`ln(cfg.temperature_init)`) and
    /// the bias as `cfg.bias_init`, each inside a one-element tensor.
    pub fn new(cfg: &SensorLMConfig, device: &B::Device) -> Self {
        let log_temp = Tensor::<B, 1>::from_floats([cfg.temperature_init.ln()], device);
        let bias = Tensor::<B, 1>::from_floats([cfg.bias_init], device);
        Self {
            sensor_encoder: SensorEncoder::new(&cfg.sensor_encoder, device),
            text_encoder: TextEncoder::new(&cfg.text_encoder, device),
            log_temperature: Param::from_tensor(log_temp),
            bias: Param::from_tensor(bias),
        }
    }

    /// Runs the sensor encoder on `x` and returns its rank-2 output.
    pub fn encode_sensor(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
        self.sensor_encoder.forward(x)
    }

    /// Runs the text encoder on `input_ids` / `attention_mask` and returns its
    /// rank-2 output.
    pub fn encode_text(
        &self,
        input_ids: Tensor<B, 2, Int>,
        attention_mask: Tensor<B, 2, Int>,
    ) -> Tensor<B, 2> {
        self.text_encoder.forward(input_ids, attention_mask)
    }

    /// Computes `z_sensor · z_textᵀ * exp(log_temperature) + bias`.
    ///
    /// Rows of the result index sensor embeddings, columns index text
    /// embeddings.
    ///
    /// The scale and bias are applied as tensors rather than as host-side
    /// scalars: reading them out with `into_scalar()` would cut them out of
    /// the autodiff graph, so neither parameter would ever receive a gradient,
    /// and it would also force a device-to-host read on every call.
    ///
    /// NOTE(review): the embeddings are used exactly as the encoders return
    /// them; confirm the encoders already L2-normalise if the loss expects
    /// cosine similarities.
    pub fn similarity_matrix(
        &self,
        z_sensor: Tensor<B, 2>,
        z_text: Tensor<B, 2>,
    ) -> Tensor<B, 2> {
        // [1] -> [1, 1] so both parameters broadcast over the whole matrix.
        let temperature = self.log_temperature.val().exp().unsqueeze::<2>();
        let bias = self.bias.val().unsqueeze::<2>();
        z_sensor
            .matmul(z_text.transpose())
            .mul(temperature)
            .add(bias)
    }

    /// Encodes both modalities, builds the pairwise logits and evaluates
    /// `siglip_loss` on them.
    ///
    /// Returns the loss together with the logits it was computed from.
    pub fn forward(
        &self,
        sensor: Tensor<B, 3>,
        input_ids: Tensor<B, 2, Int>,
        attention_mask: Tensor<B, 2, Int>,
    ) -> SensorLMOutput<B> {
        let z_sensor = self.encode_sensor(sensor);
        let z_text = self.encode_text(input_ids, attention_mask);
        let logits = self.similarity_matrix(z_sensor, z_text);
        // The loss consumes its argument; keep a handle for the output struct.
        let loss = siglip_loss(logits.clone());
        SensorLMOutput { loss, logits }
    }
}
/// Result of one `SensorLMModel::forward` pass.
#[derive(Debug)]
pub struct SensorLMOutput<B: Backend> {
/// Value returned by `siglip_loss` for `logits`.
pub loss: Tensor<B, 1>,
/// Pairwise sensor/text scores produced by `similarity_matrix`.
pub logits: Tensor<B, 2>,
}
/// Exposes the stored loss to Burn's loss metric.
impl<B: Backend> Adaptor<LossInput<B>> for SensorLMOutput<B> {
    /// Wraps a clone of `self.loss` in a [`LossInput`].
    fn adapt(&self) -> LossInput<B> {
        let loss = self.loss.clone();
        LossInput::new(loss)
    }
}
/// One collated mini-batch of paired sensor and text inputs.
///
/// The shapes noted on the fields are those produced by `SensorLMBatcher`;
/// batches built elsewhere are not constrained by this type.
#[derive(Debug, Clone)]
pub struct SensorLMBatch<B: Backend> {
/// Sensor values; `[batch, time_steps, num_channels]` when built by the batcher.
pub sensor: Tensor<B, 3>,
/// Token ids; `[batch, max_seq_len]` when built by the batcher.
pub input_ids: Tensor<B, 2, Int>,
/// Attention mask matching `input_ids`; `[batch, max_seq_len]` when built by the batcher.
pub attention_mask: Tensor<B, 2, Int>,
}
impl<B: AutodiffBackend> TrainStep<SensorLMBatch<B>, SensorLMOutput<B>>
for SensorLMModel<B>
{
fn step(&self, batch: SensorLMBatch<B>) -> TrainOutput<SensorLMOutput<B>> {
let output = self.forward(batch.sensor, batch.input_ids, batch.attention_mask);
TrainOutput::new(self, output.loss.backward(), output)
}
}
impl<B: Backend> ValidStep<SensorLMBatch<B>, SensorLMOutput<B>> for SensorLMModel<B> {
fn step(&self, batch: SensorLMBatch<B>) -> SensorLMOutput<B> {
self.forward(batch.sensor, batch.input_ids, batch.attention_mask)
}
}
/// Collates `SensorTextItem`s into a `SensorLMBatch` on a fixed device.
#[derive(Clone)]
pub struct SensorLMBatcher<B: Backend> {
// Device on which the batch tensors are created.
device: B::Device,
// Second dimension of the reshaped sensor tensor.
time_steps: usize,
// Third dimension of the reshaped sensor tensor.
num_channels: usize,
// Second dimension of the reshaped token-id and attention-mask tensors.
max_seq_len: usize,
}
impl<B: Backend> SensorLMBatcher<B> {
pub fn new(
device: B::Device,
time_steps: usize,
num_channels: usize,
max_seq_len: usize,
) -> Self {
Self { device, time_steps, num_channels, max_seq_len }
}
}
impl<B: Backend> Batcher<SensorTextItem, SensorLMBatch<B>> for SensorLMBatcher<B> {
    /// Concatenates the per-item buffers and reshapes them into batch tensors.
    ///
    /// The sensor values become a `[batch, time_steps, num_channels]` tensor;
    /// token ids and attention masks each become `[batch, max_seq_len]`.
    ///
    /// NOTE(review): each item's `sensor` buffer is assumed to be laid out
    /// time-major (all channels of step 0, then step 1, ...) — confirm against
    /// the dataset code.
    ///
    /// # Panics
    ///
    /// Panics if any item's `sensor` does not hold exactly
    /// `time_steps * num_channels` values, or if its `token_ids` or
    /// `attention_mask` does not hold exactly `max_seq_len` values. Without
    /// this check a wrong-sized item would either fail inside `reshape` with
    /// an unhelpful message or, when the totals happen to line up, silently
    /// shift data across sample boundaries.
    fn batch(&self, items: Vec<SensorTextItem>) -> SensorLMBatch<B> {
        let b = items.len();
        let t = self.time_steps;
        let c = self.num_channels;
        let l = self.max_seq_len;

        // Sizes are known up front, so allocate each flat buffer once.
        let mut sensor_flat: Vec<f32> = Vec::with_capacity(b * t * c);
        let mut token_flat: Vec<i32> = Vec::with_capacity(b * l);
        let mut mask_flat: Vec<i32> = Vec::with_capacity(b * l);

        for (idx, it) in items.iter().enumerate() {
            let sensor_start = sensor_flat.len();
            sensor_flat.extend(it.sensor.iter().copied());
            assert_eq!(
                sensor_flat.len() - sensor_start,
                t * c,
                "batch item {}: sensor length must equal time_steps * num_channels",
                idx
            );

            let token_start = token_flat.len();
            token_flat.extend(it.token_ids.iter().copied());
            assert_eq!(
                token_flat.len() - token_start,
                l,
                "batch item {}: token_ids length must equal max_seq_len",
                idx
            );

            let mask_start = mask_flat.len();
            mask_flat.extend(it.attention_mask.iter().copied());
            assert_eq!(
                mask_flat.len() - mask_start,
                l,
                "batch item {}: attention_mask length must equal max_seq_len",
                idx
            );
        }

        let sensor = Tensor::<B, 1>::from_floats(sensor_flat.as_slice(), &self.device)
            .reshape([b, t, c]);
        let input_ids = Tensor::<B, 1, Int>::from_ints(token_flat.as_slice(), &self.device)
            .reshape([b, l]);
        let attention_mask = Tensor::<B, 1, Int>::from_ints(mask_flat.as_slice(), &self.device)
            .reshape([b, l]);

        SensorLMBatch { sensor, input_ids, attention_mask }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use crate::config::{SensorEncoderConfig, TextEncoderConfig, PoolType, SensorLMConfig};

    type B = NdArray;

    /// Small configuration so the forward-pass test stays cheap.
    fn tiny_config() -> SensorLMConfig {
        let sensor_encoder = SensorEncoderConfig {
            time_steps: 40,
            num_channels: 4,
            patch_h: 10,
            patch_w: 2,
            d_model: 32,
            depth: 2,
            num_heads: 4,
            mlp_dim: 64,
            dropout: 0.0,
            pool_type: PoolType::Gap,
            head_zeroinit: false,
            attn_chunk_size: 0,
        };
        let text_encoder = TextEncoderConfig {
            vocab_size: 100,
            max_seq_len: 16,
            d_model: 32,
            depth: 2,
            num_heads: 4,
            mlp_dim: 64,
            dropout: 0.0,
            out_dim: Some(32),
        };
        SensorLMConfig {
            sensor_encoder,
            text_encoder,
            embed_dim: 32,
            temperature_init: 10.0,
            bias_init: -10.0,
        }
    }

    /// A batch of two pairs must yield a 2x2 logit matrix and a non-NaN loss.
    #[test]
    fn test_sensorlm_forward() {
        let device: <B as burn::tensor::backend::Backend>::Device = Default::default();
        let model = SensorLMModel::<B>::new(&tiny_config(), &device);

        let sensor = Tensor::<B, 3>::zeros([2, 40, 4], &device);
        let ids = Tensor::<B, 2, Int>::from_ints([[1, 2, 3, 0], [4, 5, 6, 7]], &device);
        let mask = Tensor::<B, 2, Int>::from_ints([[1, 1, 1, 0], [1, 1, 1, 1]], &device);

        let out = model.forward(sensor, ids, mask);

        assert_eq!(out.logits.dims(), [2, 2]);

        let loss: f32 = out.loss.into_scalar();
        assert!(!loss.is_nan(), "Loss must not be NaN");
    }
}