use crate::autograd::Tensor;
use crate::models::bert::config::BertConfig;
use crate::nn::{LayerNorm, Module};
/// Input embedding layer of a BERT model.
///
/// Holds three lookup tables (word, position, token-type), each with
/// `hidden_dim` columns, plus the layer norm that is applied to their
/// element-wise sum in [`BertEmbeddings::forward`].
pub struct BertEmbeddings {
    /// Word-embedding table; `new` builds it with shape `[vocab_size, hidden_dim]`.
    pub(crate) word_embeddings: Tensor,
    /// Position-embedding table; `new` builds it with shape
    /// `[max_position_embeddings, hidden_dim]`. Row `i` is added at sequence position `i`.
    pub(crate) position_embeddings: Tensor,
    /// Token-type (segment) table; `new` builds it with shape `[type_vocab_size, hidden_dim]`.
    pub(crate) token_type_embeddings: Tensor,
    /// Layer norm over the last (`hidden_dim`) axis, applied after summing the embeddings.
    pub(crate) layer_norm: LayerNorm,
    /// Width of every embedding row.
    hidden_dim: usize,
    /// Row count of the position table, i.e. the longest sequence `forward` accepts.
    max_position_embeddings: usize,
}
impl BertEmbeddings {
    /// Creates an embedding layer whose lookup tables are all zero-initialised.
    ///
    /// Table shapes are taken from `config`:
    /// - word embeddings: `[vocab_size, hidden_dim]`
    /// - position embeddings: `[max_position_embeddings, hidden_dim]`
    /// - token-type embeddings: `[type_vocab_size, hidden_dim]`
    ///
    /// The layer norm normalises over `hidden_dim` using `config.layer_norm_eps`.
    ///
    /// NOTE(review): zero tables carry no learned information; presumably real
    /// weights are written into the `pub(crate)` fields by a loader elsewhere —
    /// TODO confirm.
    #[must_use]
    pub fn new(config: &BertConfig) -> Self {
        let h = config.hidden_dim;
        let we = vec![0.0; config.vocab_size * h];
        let pe = vec![0.0; config.max_position_embeddings * h];
        let te = vec![0.0; config.type_vocab_size * h];
        Self {
            word_embeddings: Tensor::from_vec(we, &[config.vocab_size, h]),
            position_embeddings: Tensor::from_vec(pe, &[config.max_position_embeddings, h]),
            token_type_embeddings: Tensor::from_vec(te, &[config.type_vocab_size, h]),
            layer_norm: LayerNorm::with_eps(&[h], config.layer_norm_eps),
            hidden_dim: h,
            max_position_embeddings: config.max_position_embeddings,
        }
    }

    /// Embeds one token sequence and returns a tensor of shape `[1, seq_len, hidden_dim]`.
    ///
    /// For every position `i`, the output row is the sum of:
    /// - the word embedding of `input_ids[i]`,
    /// - the position embedding for `i`,
    /// - the token-type embedding of `token_type_ids[i]`.
    ///
    /// The summed tensor is then passed through the layer norm.
    ///
    /// # Panics
    ///
    /// - if `input_ids` and `token_type_ids` have different lengths;
    /// - if the sequence is longer than `max_position_embeddings`;
    /// - if any `input_ids` value is not a row index of the word-embedding table;
    /// - if any `token_type_ids` value is not a row index of the token-type table.
    #[must_use]
    pub fn forward(&self, input_ids: &[u32], token_type_ids: &[u32]) -> Tensor {
        assert_eq!(
            input_ids.len(),
            token_type_ids.len(),
            "input_ids and token_type_ids must have the same length"
        );
        assert!(
            input_ids.len() <= self.max_position_embeddings,
            "sequence length {} exceeds max_position_embeddings {}",
            input_ids.len(),
            self.max_position_embeddings
        );
        let seq_len = input_ids.len();
        let h = self.hidden_dim;
        // Row counts of the id-indexed tables. Ids are checked against them
        // explicitly so a bad id yields a descriptive panic rather than an
        // opaque slice-range error.
        let vocab_size = self.word_embeddings.shape()[0];
        let type_vocab_size = self.token_type_embeddings.shape()[0];
        let we_data = self.word_embeddings.data();
        let pe_data = self.position_embeddings.data();
        let te_data = self.token_type_embeddings.data();
        let mut summed = vec![0.0f32; seq_len * h];
        for (i, (&wid, &tid)) in input_ids.iter().zip(token_type_ids).enumerate() {
            let (wid, tid) = (wid as usize, tid as usize);
            assert!(
                wid < vocab_size,
                "input_ids[{}] = {} is out of range for vocab_size {}",
                i,
                wid,
                vocab_size
            );
            assert!(
                tid < type_vocab_size,
                "token_type_ids[{}] = {} is out of range for type_vocab_size {}",
                i,
                tid,
                type_vocab_size
            );
            let w_row = &we_data[wid * h..(wid + 1) * h];
            // `i < max_position_embeddings` is guaranteed by the length assert above.
            let p_row = &pe_data[i * h..(i + 1) * h];
            let t_row = &te_data[tid * h..(tid + 1) * h];
            let dst = &mut summed[i * h..(i + 1) * h];
            // Zipped iteration: no per-element bounds checks; all rows are length `h`.
            for (((out, &w), &p), &t) in dst.iter_mut().zip(w_row).zip(p_row).zip(t_row) {
                *out = w + p + t;
            }
        }
        let summed_tensor = Tensor::from_vec(summed, &[1, seq_len, h]);
        self.layer_norm.forward(&summed_tensor)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn embeddings_shape_correct() {
        let config = BertConfig::minilm_l6();
        let emb = BertEmbeddings::new(&config);
        let input_ids = vec![101u32, 2024, 102];
        let token_type_ids = vec![0u32, 0, 0];
        let out = emb.forward(&input_ids, &token_type_ids);
        assert_eq!(out.shape(), &[1, 3, 384]);
    }

    #[test]
    #[should_panic(expected = "must have the same length")]
    fn embeddings_mismatched_ids_panics() {
        let config = BertConfig::minilm_l6();
        let emb = BertEmbeddings::new(&config);
        // `forward` is `#[must_use]`; bind the result to avoid an unused_must_use warning.
        let _ = emb.forward(&[101u32, 2024], &[0u32]);
    }

    #[test]
    fn embeddings_handles_paired_input() {
        let config = BertConfig::minilm_l6();
        let emb = BertEmbeddings::new(&config);
        let input_ids = vec![101u32, 2024, 102, 3456, 102];
        let token_type_ids = vec![0u32, 0, 0, 1, 1];
        let out = emb.forward(&input_ids, &token_type_ids);
        assert_eq!(out.shape(), &[1, 5, 384]);
    }

    #[test]
    fn embeddings_accepts_max_length_sequence() {
        // Boundary: exactly `max_position_embeddings` tokens must be accepted.
        let config = BertConfig::minilm_l6();
        let emb = BertEmbeddings::new(&config);
        let len = config.max_position_embeddings;
        let input_ids = vec![101u32; len];
        let token_type_ids = vec![0u32; len];
        let out = emb.forward(&input_ids, &token_type_ids);
        assert_eq!(out.shape(), &[1, len, config.hidden_dim]);
    }

    #[test]
    #[should_panic(expected = "exceeds max_position_embeddings")]
    fn embeddings_overlong_sequence_panics() {
        let config = BertConfig::minilm_l6();
        let emb = BertEmbeddings::new(&config);
        let len = config.max_position_embeddings + 1;
        let input_ids = vec![101u32; len];
        let token_type_ids = vec![0u32; len];
        let _ = emb.forward(&input_ids, &token_type_ids);
    }

    #[test]
    #[should_panic]
    fn embeddings_out_of_vocab_word_id_panics() {
        let config = BertConfig::minilm_l6();
        let emb = BertEmbeddings::new(&config);
        // First id past the end of the word table; vocab sizes fit comfortably in u32.
        let bad_id = config.vocab_size as u32;
        let _ = emb.forward(&[bad_id], &[0u32]);
    }

    #[test]
    #[should_panic]
    fn embeddings_out_of_range_token_type_panics() {
        let config = BertConfig::minilm_l6();
        let emb = BertEmbeddings::new(&config);
        // First id past the end of the token-type table.
        let bad_type = config.type_vocab_size as u32;
        let _ = emb.forward(&[101u32], &[bad_type]);
    }
}