use candle_core::{Result, Tensor};
use candle_nn::Dropout;
use candle_nn::{embedding, Embedding, Module, ModuleT, VarBuilder};
use candle_transformers::models::with_tracing::{layer_norm, LayerNorm};
/// Input embedding stage: sums word, token-type and (optionally) position
/// embeddings, then applies layer normalisation and dropout.
pub(crate) struct BertEmbeddings {
    /// Lookup table indexed by the `input_ids` passed to `forward`.
    word_embeddings: Embedding,
    /// Lookup table indexed by position `0..seq_len`; when `None`, `forward`
    /// adds no positional term. `load` always populates it.
    position_embeddings: Option<Embedding>,
    /// Lookup table indexed by the `token_type_ids` passed to `forward`.
    token_type_embeddings: Embedding,
    /// Normalisation applied to the summed embeddings.
    layer_norm: LayerNorm,
    /// Dropout applied after normalisation; driven by the `train` flag of `forward`.
    dropout: Dropout,
    /// TRACE-level span named `"embeddings"`, entered for the duration of `forward`.
    span: tracing::Span,
}
impl BertEmbeddings {
pub(crate) fn load(vb: VarBuilder, config: &super::Config) -> Result<Self> {
let word_embeddings = embedding(
config.vocab_size,
config.hidden_size,
vb.pp("word_embeddings"),
)?;
let position_embeddings = embedding(
config.max_position_embeddings,
config.hidden_size,
vb.pp("position_embeddings"),
)?;
let token_type_embeddings = embedding(
config.type_vocab_size,
config.hidden_size,
vb.pp("token_type_embeddings"),
)?;
let layer_norm = layer_norm(
config.hidden_size,
config.layer_norm_eps,
vb.pp("LayerNorm"),
)?;
Ok(Self {
word_embeddings,
position_embeddings: Some(position_embeddings),
token_type_embeddings,
layer_norm,
dropout: Dropout::new(config.hidden_dropout_prob),
span: tracing::span!(tracing::Level::TRACE, "embeddings"),
})
}
pub(crate) fn forward(
&self,
input_ids: &Tensor,
token_type_ids: &Tensor,
train: bool,
) -> Result<Tensor> {
let _enter = self.span.enter();
let (_bsize, seq_len) = input_ids.dims2()?;
let input_embeddings = self.word_embeddings.forward(input_ids)?;
let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
let mut embeddings = (&input_embeddings + token_type_embeddings)?;
if let Some(position_embeddings) = &self.position_embeddings {
let position_ids = Tensor::arange(0, seq_len as u32, input_ids.device())?;
embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
}
let embeddings = self.layer_norm.forward(&embeddings)?;
let embeddings = self.dropout.forward_t(&embeddings, train)?;
Ok(embeddings)
}
pub(crate) fn embedding_dim(&self) -> usize {
self.word_embeddings.hidden_size()
}
pub(crate) fn max_seq_len(&self) -> usize {
self.position_embeddings
.as_ref()
.map(|p| p.embeddings().dims()[0])
.unwrap_or(0)
}
}