use crate::albert::AlbertConfig;
use crate::common::dropout::Dropout;
use tch::kind::Kind::Float;
use tch::{nn, Tensor};
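
/// Self-attention module used in the ALBERT encoder layers.
///
/// Computes multi-head scaled dot-product attention over the input hidden states,
/// followed by an output projection, dropout, a residual connection and layer
/// normalization.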
#[derive(Debug)]
pub struct AlbertSelfAttention {
num_attention_heads: i64,
attention_head_size: i64,
hidden_size: i64,
dropout: Dropout,
output_attentions: bool,
query: nn::Linear,
key: nn::Linear,
value: nn::Linear,
dense: nn::Linear,
layer_norm: nn::LayerNorm,
}
impl AlbertSelfAttention {
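    /// Builds a new `AlbertSelfAttention` layer.
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the attention sub-model
    /// * `config` - `AlbertConfig` providing the hidden size, number of attention heads,
    ///   dropout probability, layer norm epsilon and attention output flag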
pub fn new(p: nn::Path, config: &AlbertConfig) -> AlbertSelfAttention {
        assert_eq!(
            config.hidden_size % config.num_attention_heads,
            0,
            "Hidden size is not a multiple of the number of attention heads"
        );
let query = nn::linear(
&p / "query",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let key = nn::linear(
&p / "key",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let value = nn::linear(
&p / "value",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let dense = nn::linear(
&p / "dense",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let dropout = Dropout::new(config.attention_probs_dropout_prob);
let attention_head_size = config.hidden_size / config.num_attention_heads;
        let output_attentions = config.output_attentions.unwrap_or(false);
        let layer_norm_eps = config.layer_norm_eps.unwrap_or(1e-12);
let layer_norm_config = nn::LayerNormConfig {
eps: layer_norm_eps,
..Default::default()
};
let layer_norm = nn::layer_norm(
&p / "LayerNorm",
vec![config.hidden_size],
layer_norm_config,
);
AlbertSelfAttention {
num_attention_heads: config.num_attention_heads,
attention_head_size,
hidden_size: config.hidden_size,
dropout,
output_attentions,
query,
key,
value,
dense,
layer_norm,
}
}
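
    /// Splits the hidden dimension into attention heads, reshaping the input from
    /// (batch size, sequence length, hidden size) to
    /// (batch size, number of heads, sequence length, head dimension).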
fn split_heads(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
x.view((bs, -1, self.num_attention_heads, dim_per_head))
.transpose(1, 2)
}
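
    /// Forward pass through the self-attention layer.
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Hidden states of shape (batch size, sequence length, hidden size)
    /// * `mask` - Optional additive attention mask broadcastable to the attention scores
    /// * `train` - If `true`, dropout is applied
    ///
    /// # Returns
    ///
    /// * Output hidden states of shape (batch size, sequence length, hidden size)
    /// * Attention weights, returned only when `output_attentions` is set in the configuration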
pub fn forward_t(
&self,
input_ids: &Tensor,
mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let bs = *input_ids.size().first().unwrap();
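
        // Project the input into key, value and query spaces and split each
        // projection into attention heads.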
let key_layer = self.split_heads(input_ids.apply(&self.key), bs, self.attention_head_size);
let value_layer =
self.split_heads(input_ids.apply(&self.value), bs, self.attention_head_size);
let query_layer =
self.split_heads(input_ids.apply(&self.query), bs, self.attention_head_size);
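
        // Scale the query by the square root of the head dimension (scaled dot-product attention).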
let query_layer: Tensor = query_layer / (self.attention_head_size as f64).sqrt();
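
        // Raw attention scores, with the additive attention mask applied when provided.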
let scores = if let Some(mask) = mask {
query_layer.matmul(&key_layer.transpose(-1, -2)) + mask
} else {
query_layer.matmul(&key_layer.transpose(-1, -2))
};
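
        // Normalize the scores into attention probabilities and apply dropout.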
let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train);
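
        // Weighted sum over the value vectors, reshaped back to
        // (batch size, sequence length, number of heads, head dimension).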
let context = weights.matmul(&value_layer).transpose(1, 2).contiguous();
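
        // Merge the attention heads and apply the output projection in a single einsum,
        // viewing the dense weight as (number of heads, head dimension, hidden size).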
let w = self.dense.ws.transpose(0, 1).view((
self.num_attention_heads,
self.attention_head_size,
self.hidden_size,
));
let context: Tensor = Tensor::einsum("bfnd,ndh->bfh", &[context, w]) + &self.dense.bs;
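
        // Dropout on the projected context, residual connection with the input,
        // then layer normalization.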
let context = (input_ids + context.apply_t(&self.dropout, train)).apply(&self.layer_norm);
        if self.output_attentions {
            (context, Some(weights))
        } else {
            (context, None)
        }
}
}
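
// Example usage (sketch): assumes an `AlbertConfig` value `config` and a float tensor
// `hidden_states` of shape (batch size, sequence length, hidden size) are available.
//
//     let vs = tch::nn::VarStore::new(tch::Device::Cpu);
//     let root = vs.root();
//     let attention = AlbertSelfAttention::new(&root / "attention", &config);
//     let (output, attention_weights) = attention.forward_t(&hidden_states, &None, false);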