use candle_core::{Device, Module, Result, Tensor};
use candle_nn::VarBuilder;

use super::config::PlbertConfig;

/// Build a rank-0 tensor holding `value` on the same device and with the same dtype as `tensor`.
fn scalar_like(tensor: &Tensor, value: f32) -> Result<Tensor> {
    Tensor::new(value, tensor.device())?.to_dtype(tensor.dtype())
}

/// Multiply every element of `tensor` by the scalar `value`.
fn scale_tensor(tensor: &Tensor, value: f32) -> Result<Tensor> {
    tensor.broadcast_mul(&scalar_like(tensor, value)?)
}
/// ALBERT encoder used by PL-BERT: embedding tables plus a single transformer layer whose
/// weights are shared across all `num_hidden_layers` iterations.
pub struct Albert {
    word_embeddings: candle_nn::Embedding,
    position_embeddings: candle_nn::Embedding,
    token_type_embeddings: candle_nn::Embedding,
    embed_layer_norm: candle_nn::LayerNorm,
    /// Maps `embedding_size` to `hidden_size`; only present when the two differ.
    embedding_projection: Option<candle_nn::Linear>,
    /// The one layer that is reused for every encoder step (ALBERT parameter sharing).
    shared_layer: AlbertLayer,
    num_hidden_layers: usize,
    _num_hidden_groups: usize,
    hidden_size: usize,
}

/// One transformer block: self-attention and a feed-forward network, each followed by a
/// residual connection and layer norm.
struct AlbertLayer {
    attention: AlbertAttention,
    ffn: candle_nn::Linear,
    ffn_output: candle_nn::Linear,
    attention_norm: candle_nn::LayerNorm,
    output_norm: candle_nn::LayerNorm,
}

/// Multi-head self-attention with an output projection.
struct AlbertAttention {
    query: candle_nn::Linear,
    key: candle_nn::Linear,
    value: candle_nn::Linear,
    dense: candle_nn::Linear,
    num_heads: usize,
    head_dim: usize,
}
impl Albert {
    /// Load the embeddings and the shared encoder layer from `vb`, following the
    /// Hugging Face ALBERT weight layout (`embeddings.*`, `encoder.*`).
    pub fn load(config: &PlbertConfig, vb: VarBuilder, _device: &Device) -> Result<Self> {
        let emb_vb = vb.pp("embeddings");
        let word_embeddings = candle_nn::embedding(
            config.vocab_size,
            config.embedding_size,
            emb_vb.pp("word_embeddings"),
        )?;
        let position_embeddings = candle_nn::embedding(
            config.max_position_embeddings,
            config.embedding_size,
            emb_vb.pp("position_embeddings"),
        )?;
        let token_type_embeddings = candle_nn::embedding(
            config.type_vocab_size,
            config.embedding_size,
            emb_vb.pp("token_type_embeddings"),
        )?;
        let embed_layer_norm = candle_nn::layer_norm(
            config.embedding_size,
            candle_nn::LayerNormConfig::default(),
            emb_vb.pp("LayerNorm"),
        )?;
        let enc_vb = vb.pp("encoder");
        // ALBERT factorizes the embeddings: a projection into the hidden size is only
        // needed when the embedding size differs from it.
        let embedding_projection = if config.embedding_size != config.hidden_size {
            Some(candle_nn::linear(
                config.embedding_size,
                config.hidden_size,
                enc_vb.pp("embedding_hidden_mapping_in"),
            )?)
        } else {
            None
        };
        // A single layer group containing a single layer is assumed; its weights are
        // shared across every encoder iteration.
        let layer_vb = enc_vb
            .pp("albert_layer_groups")
            .pp("0")
            .pp("albert_layers")
            .pp("0");
        let shared_layer = AlbertLayer::load(config, layer_vb)?;
        Ok(Self {
            word_embeddings,
            position_embeddings,
            token_type_embeddings,
            embed_layer_norm,
            embedding_projection,
            shared_layer,
            num_hidden_layers: config.num_hidden_layers,
            _num_hidden_groups: config.num_hidden_groups,
            hidden_size: config.hidden_size,
        })
    }
    /// Encode `input_ids` of shape `(batch, seq_len)`. `attention_mask` holds 1 for real
    /// tokens and 0 for padding; it is turned into an additive bias so padded positions
    /// receive essentially no attention weight after the softmax.
    pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let (_batch, seq_len) = input_ids.dims2()?;
        let device = input_ids.device();
        // Position ids 0..seq_len, broadcast over the batch dimension.
        let position_ids: Vec<u32> = (0..seq_len as u32).collect();
        let position_ids = Tensor::new(position_ids.as_slice(), device)?
            .unsqueeze(0)?
            .broadcast_as(input_ids.shape())?;
        // Single-segment inputs: every token type id is zero.
        let token_type_ids = Tensor::zeros(input_ids.shape(), candle_core::DType::U32, device)?;
        let word_emb = self.word_embeddings.forward(input_ids)?;
        let pos_emb = self.position_embeddings.forward(&position_ids)?;
        let type_emb = self.token_type_embeddings.forward(&token_type_ids)?;
        let embeddings = word_emb.add(&pos_emb)?.add(&type_emb)?;
        let embeddings = self.embed_layer_norm.forward(&embeddings)?;
        let mut hidden = match &self.embedding_projection {
            Some(proj) => proj.forward(&embeddings)?,
            None => embeddings,
        };
        // Convert the 0/1 mask of shape (batch, seq_len) into an additive bias of shape
        // (batch, 1, 1, seq_len): 0.0 for kept positions, -10000.0 for masked ones.
        let attn_mask = attention_mask
            .to_dtype(hidden.dtype())?
            .unsqueeze(1)?
            .unsqueeze(2)?;
        let inv_mask = attn_mask.neg()?.add(&Tensor::ones_like(&attn_mask)?)?;
        let attn_bias = inv_mask.broadcast_mul(&scalar_like(&inv_mask, -10000.0)?)?;
        // Run the same shared layer `num_hidden_layers` times.
        for _ in 0..self.num_hidden_layers {
            hidden = self.shared_layer.forward(&hidden, &attn_bias)?;
        }
        Ok(hidden)
    }

    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }
}
impl AlbertLayer {
    fn load(config: &PlbertConfig, vb: VarBuilder) -> Result<Self> {
        let attn_vb = vb.pp("attention");
        let attention = AlbertAttention::load(config, attn_vb)?;
        let ffn = candle_nn::linear(config.hidden_size, config.intermediate_size, vb.pp("ffn"))?;
        let ffn_output = candle_nn::linear(
            config.intermediate_size,
            config.hidden_size,
            vb.pp("ffn_output"),
        )?;
        let attention_norm = candle_nn::layer_norm(
            config.hidden_size,
            candle_nn::LayerNormConfig::default(),
            vb.pp("attention").pp("LayerNorm"),
        )?;
        let output_norm = candle_nn::layer_norm(
            config.hidden_size,
            candle_nn::LayerNormConfig::default(),
            vb.pp("full_layer_layer_norm"),
        )?;
        Ok(Self {
            attention,
            ffn,
            ffn_output,
            attention_norm,
            output_norm,
        })
    }

    /// Post-norm transformer block: `LayerNorm(x + Attention(x))` followed by
    /// `LayerNorm(y + FFN(y))`, with a GELU activation inside the feed-forward network.
    fn forward(&self, hidden: &Tensor, attn_bias: &Tensor) -> Result<Tensor> {
        let attn_out = self.attention.forward(hidden, attn_bias)?;
        let hidden = self.attention_norm.forward(&hidden.add(&attn_out)?)?;
        let ffn_out = self.ffn.forward(&hidden)?;
        let ffn_out = ffn_out.gelu_erf()?;
        let ffn_out = self.ffn_output.forward(&ffn_out)?;
        let hidden = self.output_norm.forward(&hidden.add(&ffn_out)?)?;
        Ok(hidden)
    }
}
impl AlbertAttention {
    fn load(config: &PlbertConfig, vb: VarBuilder) -> Result<Self> {
        let head_dim = config.hidden_size / config.num_attention_heads;
        let query = candle_nn::linear(config.hidden_size, config.hidden_size, vb.pp("query"))?;
        let key = candle_nn::linear(config.hidden_size, config.hidden_size, vb.pp("key"))?;
        let value = candle_nn::linear(config.hidden_size, config.hidden_size, vb.pp("value"))?;
        let dense = candle_nn::linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
        Ok(Self {
            query,
            key,
            value,
            dense,
            num_heads: config.num_attention_heads,
            head_dim,
        })
    }

    /// Scaled dot-product attention over `num_heads` heads. `attn_bias` has shape
    /// `(batch, 1, 1, seq_len)` and is added to the logits before the softmax.
    fn forward(&self, hidden: &Tensor, attn_bias: &Tensor) -> Result<Tensor> {
        let (batch, seq_len, _) = hidden.dims3()?;
        let q = self.query.forward(hidden)?;
        let k = self.key.forward(hidden)?;
        let v = self.value.forward(hidden)?;
        // Split the hidden dimension into heads: (batch, num_heads, seq_len, head_dim).
        let q = q
            .reshape((batch, seq_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let k = k
            .reshape((batch, seq_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let v = v
            .reshape((batch, seq_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        // Attention logits: q @ k^T / sqrt(head_dim), plus the padding bias.
        let scale = 1.0f32 / (self.head_dim as f32).sqrt();
        let k_t = k.transpose(2, 3)?.contiguous()?;
        let attn_weights = q.matmul(&k_t)?;
        let attn_weights = scale_tensor(&attn_weights, scale)?;
        let attn_weights = attn_weights.broadcast_add(attn_bias)?;
        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        let attn_out = attn_weights.matmul(&v)?;
        // Merge the heads back into a single hidden dimension and apply the output projection.
        let attn_out = attn_out
            .transpose(1, 2)?
            .reshape((batch, seq_len, self.num_heads * self.head_dim))?;
        self.dense.forward(&attn_out)
    }
}
#[cfg(test)]
mod tests {
    #[test]
    fn test_albert_config_defaults() {
        use super::super::config::PlbertConfig;

        let config: PlbertConfig = serde_json::from_str("{}").unwrap();
        assert_eq!(config.hidden_size, 768);
        assert_eq!(config.num_attention_heads, 12);
        assert_eq!(config.embedding_size, 128);
    }
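
    // A minimal sketch of a unit test for the scaling helpers defined above. It only
    // assumes a CPU device and f32 inputs, so it runs without any model weights.
    #[test]
    fn test_scale_tensor_multiplies_elementwise() {
        use candle_core::{Device, Tensor};

        let t = Tensor::new(&[1.0f32, 2.0, 3.0], &Device::Cpu).unwrap();
        let scaled = super::scale_tensor(&t, 0.5).unwrap();
        assert_eq!(scaled.to_vec1::<f32>().unwrap(), vec![0.5, 1.0, 1.5]);
    }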
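
    // A sketch that replays the mask-to-bias computation from `Albert::forward` inline:
    // a 0/1 attention mask becomes an additive bias of 0.0 for kept tokens and -10000.0
    // for padded ones. The mask values here are made-up test data.
    #[test]
    fn test_attention_mask_to_bias() {
        use candle_core::{Device, Tensor};

        let mask = Tensor::new(&[[1.0f32, 1.0, 0.0]], &Device::Cpu).unwrap();
        let attn_mask = mask.unsqueeze(1).unwrap().unsqueeze(2).unwrap();
        let inv_mask = attn_mask
            .neg()
            .unwrap()
            .add(&Tensor::ones_like(&attn_mask).unwrap())
            .unwrap();
        let bias = inv_mask
            .broadcast_mul(&super::scalar_like(&inv_mask, -10000.0).unwrap())
            .unwrap();
        let bias = bias.flatten_all().unwrap().to_vec1::<f32>().unwrap();
        assert_eq!(bias, vec![0.0, 0.0, -10000.0]);
    }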
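
    // A sketch of the head split/merge used in `AlbertAttention::forward`: the hidden
    // dimension is split into (num_heads, head_dim), heads are moved next to the batch
    // dimension for the matmuls, then merged back. The sizes below are arbitrary test
    // values, not model settings.
    #[test]
    fn test_head_split_and_merge_shapes() {
        use candle_core::{DType, Device, Tensor};

        let (batch, seq_len, num_heads, head_dim): (usize, usize, usize, usize) = (2, 4, 2, 3);
        let hidden =
            Tensor::zeros((batch, seq_len, num_heads * head_dim), DType::F32, &Device::Cpu)
                .unwrap();
        let split = hidden
            .reshape((batch, seq_len, num_heads, head_dim))
            .unwrap()
            .transpose(1, 2)
            .unwrap();
        assert_eq!(split.dims(), &[batch, num_heads, seq_len, head_dim]);
        let merged = split
            .transpose(1, 2)
            .unwrap()
            .reshape((batch, seq_len, num_heads * head_dim))
            .unwrap();
        assert_eq!(merged.dims(), hidden.dims());
    }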
}