use burn::{
module::{Module, Param},
nn::{
Dropout, DropoutConfig, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig,
Linear, LinearConfig,
},
tensor::{
backend::Backend,
Distribution, Int, Tensor,
},
};
use crate::config::TextEncoderConfig;
use crate::model::sensor_encoder::{EncoderBlock, l2_normalize};
/// Transformer-style text encoder that maps a batch of token ids to one
/// embedding vector per sequence.
///
/// The pipeline in `forward` is: token embedding + positional embedding,
/// dropout, the stack of encoder blocks, a final layer norm, masked mean
/// pooling over the sequence axis, an optional linear projection, and
/// finally `l2_normalize`.
#[derive(Module, Debug)]
pub struct TextEncoder<B: Backend> {
/// Token embedding table, built from `cfg.vocab_size` and `cfg.d_model`.
tok_embed: Embedding<B>,
/// Trainable positional embeddings, initialised with shape
/// `[1, cfg.max_seq_len, cfg.d_model]`.
pos_embed: Param<Tensor<B, 3>>,
/// `cfg.depth` encoder blocks, applied in order.
blocks: Vec<EncoderBlock<B>>,
/// Layer norm applied after the last encoder block.
norm: LayerNorm<B>,
/// Optional projection from `d_model` to `cfg.out_dim`; `None` when
/// `cfg.out_dim` is `None`, in which case the pooled vector is used as is.
proj: Option<Linear<B>>,
/// Dropout applied to the sum of token and positional embeddings.
dropout: Dropout,
/// Model width, cached so `forward` can slice and expand tensors.
d_model: usize,
}
impl<B: Backend> TextEncoder<B> {
    /// Builds a text encoder from `cfg`, creating every parameter on `device`.
    ///
    /// * The positional table has shape `[1, cfg.max_seq_len, cfg.d_model]` and is
    ///   sampled from `Normal(0, sqrt(1 / d_model))`.
    /// * `cfg.depth` encoder blocks are created with identical hyper-parameters.
    /// * The output projection exists only when `cfg.out_dim` is `Some`.
    pub fn new(cfg: &TextEncoderConfig, device: &B::Device) -> Self {
        let tok_embed = EmbeddingConfig::new(cfg.vocab_size, cfg.d_model).init(device);

        let pos_std = (1.0 / cfg.d_model as f64).sqrt();
        let pos = Tensor::<B, 3>::random(
            [1, cfg.max_seq_len, cfg.d_model],
            Distribution::Normal(0.0, pos_std),
            device,
        );

        // NOTE(review): the meaning of the literal `0` argument is defined by
        // `EncoderBlock::new` in `sensor_encoder` — confirm it is the intended value
        // for the text encoder.
        let blocks: Vec<EncoderBlock<B>> = (0..cfg.depth)
            .map(|_| {
                EncoderBlock::new(
                    cfg.d_model,
                    cfg.num_heads,
                    cfg.mlp_dim,
                    cfg.dropout,
                    0,
                    device,
                )
            })
            .collect();

        let norm = LayerNormConfig::new(cfg.d_model).init(device);
        let proj = cfg
            .out_dim
            .map(|out| LinearConfig::new(cfg.d_model, out).init(device));

        Self {
            tok_embed,
            pos_embed: Param::from_tensor(pos),
            blocks,
            norm,
            proj,
            dropout: DropoutConfig::new(cfg.dropout).init(),
            d_model: cfg.d_model,
        }
    }

    /// Encodes a batch of token id sequences into one vector per sequence.
    ///
    /// # Arguments
    /// * `input_ids` - token ids of shape `[batch, seq]`.
    /// * `attention_mask` - integer mask expanded to `[batch, seq, d_model]`; each
    ///   entry is cast to float and used as a per-token pooling weight (the tests use
    ///   `1` for real tokens and `0` for padding).
    ///
    /// # Returns
    /// A `[batch, out]` tensor passed through `l2_normalize`, where `out` is the
    /// projection width if a projection is configured and `d_model` otherwise.
    ///
    /// # Panics
    /// Panics if `seq` is larger than the number of positions in the positional
    /// embedding table.
    pub fn forward(
        &self,
        input_ids: Tensor<B, 2, Int>,
        attention_mask: Tensor<B, 2, Int>,
    ) -> Tensor<B, 2> {
        let [batch, seq] = input_ids.dims();

        // Fail early with a readable message instead of an opaque slice/shape error
        // from the backend when the batch is longer than the positional table.
        let pos_table = self.pos_embed.val();
        let [_, max_positions, _] = pos_table.dims();
        assert!(
            seq <= max_positions,
            "sequence length {seq} exceeds the positional embedding table size {max_positions}"
        );

        let tok = self.tok_embed.forward(input_ids);
        let pos = pos_table
            .slice([0..1, 0..seq, 0..self.d_model])
            .expand([batch, seq, self.d_model]);

        let mut x = self.dropout.forward(tok + pos);

        // NOTE(review): the blocks receive no mask, so `attention_mask` only affects
        // the pooling below; padded positions are presumably still attended to.
        for block in &self.blocks {
            x = block.forward(x);
        }
        x = self.norm.forward(x);

        // Masked mean pooling over the sequence axis.
        let mask: Tensor<B, 3> = attention_mask
            .float()
            .unsqueeze_dim::<3>(2)
            .expand([batch, seq, self.d_model]);
        let sum = (x * mask.clone()).sum_dim(1);
        // Clamp so a fully masked row divides by 1 rather than 0.
        let counts = mask.sum_dim(1).clamp_min(1.0f32);
        let pooled: Tensor<B, 2> = (sum / counts).squeeze(1);

        let projected = match &self.proj {
            Some(p) => p.forward(pooled),
            None => pooled,
        };
        l2_normalize(projected)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::tensor::Tensor;

    type B = NdArray;
    type Dev = <B as burn::tensor::backend::Backend>::Device;

    /// Small configuration shared by every test in this module.
    fn tiny_cfg() -> TextEncoderConfig {
        TextEncoderConfig {
            vocab_size: 100,
            max_seq_len: 32,
            d_model: 32,
            depth: 2,
            num_heads: 4,
            mlp_dim: 64,
            dropout: 0.0,
            out_dim: Some(32),
        }
    }

    /// Creates the default device and an encoder built from `tiny_cfg` on it.
    fn tiny_encoder() -> (TextEncoder<B>, Dev) {
        let device: Dev = Default::default();
        let cfg = tiny_cfg();
        let encoder = TextEncoder::<B>::new(&cfg, &device);
        (encoder, device)
    }

    /// The output has one row per sequence and `out_dim` columns.
    #[test]
    fn test_text_encoder_forward() {
        let (encoder, device) = tiny_encoder();
        let input_ids =
            Tensor::<B, 2, Int>::from_ints([[1, 2, 3, 0, 0], [4, 5, 6, 7, 0]], &device);
        let attention_mask =
            Tensor::<B, 2, Int>::from_ints([[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]], &device);

        let embeddings = encoder.forward(input_ids, attention_mask);

        assert_eq!(embeddings.dims(), [2, 32]);
    }

    /// Every output row has an L2 norm of 1 (within float tolerance).
    #[test]
    fn test_output_unit_norm() {
        let (encoder, device) = tiny_encoder();
        let input_ids = Tensor::<B, 2, Int>::from_ints([[1, 2, 3]], &device);
        let attention_mask = Tensor::<B, 2, Int>::from_ints([[1, 1, 1]], &device);

        let embeddings = encoder.forward(input_ids, attention_mask);

        let squared_sums = embeddings.powf_scalar(2.0).sum_dim(1);
        let norms: Vec<f32> = squared_sums.sqrt().into_data().to_vec::<f32>().unwrap();
        norms.into_iter().for_each(|n| {
            assert!((n - 1.0).abs() < 1e-5, "Expected unit norm, got {n}");
        });
    }
}