use crate::bert::bert_model::BertConfig;
use crate::common::activations::TensorFunction;
use crate::common::dropout::Dropout;
use std::borrow::Borrow;
use tch::{nn, Tensor};
/// Multi-head self-attention layer for BERT.
///
/// Holds the query/key/value projections (each `hidden_size -> hidden_size`);
/// heads are split out of the projected tensors at forward time.
#[derive(Debug)]
pub struct BertSelfAttention {
    // Number of attention heads; `new` asserts hidden_size is divisible by this.
    num_attention_heads: i64,
    // Per-head dimension: hidden_size / num_attention_heads.
    attention_head_size: i64,
    // Dropout applied to the attention probabilities (after softmax).
    dropout: Dropout,
    // When true, `forward_t` also returns the attention weights.
    output_attentions: bool,
    query: nn::Linear,
    key: nn::Linear,
    value: nn::Linear,
}
impl BertSelfAttention {
    /// Builds a `BertSelfAttention` layer from the given configuration.
    ///
    /// # Arguments
    /// * `p` - variable-store path where the layer weights are created/loaded
    /// * `config` - BERT model configuration
    ///
    /// # Panics
    /// Panics when `config.hidden_size` is not a multiple of
    /// `config.num_attention_heads`.
    pub fn new<'p, P>(p: P, config: &BertConfig) -> BertSelfAttention
    where
        P: Borrow<nn::Path<'p>>,
    {
        assert_eq!(
            config.hidden_size % config.num_attention_heads,
            0,
            "Hidden size not a multiple of the number of attention heads"
        );
        let p = p.borrow();

        // Query, key and value projections are identical in shape
        // (hidden_size -> hidden_size); only the weight path differs.
        let projection = |name: &str| {
            nn::linear(
                p / name,
                config.hidden_size,
                config.hidden_size,
                Default::default(),
            )
        };

        BertSelfAttention {
            num_attention_heads: config.num_attention_heads,
            attention_head_size: config.hidden_size / config.num_attention_heads,
            dropout: Dropout::new(config.attention_probs_dropout_prob),
            output_attentions: config.output_attentions.unwrap_or(false),
            query: projection("query"),
            key: projection("key"),
            value: projection("value"),
        }
    }

    // Reshapes (bs, seq, hidden) into (bs, heads, seq, dim_per_head).
    fn split_heads(&self, tensor: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
        tensor
            .view((bs, -1, self.num_attention_heads, dim_per_head))
            .transpose(1, 2)
    }

    // Inverse of `split_heads`: merges the heads back into the hidden dimension.
    fn flatten(&self, tensor: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
        tensor
            .transpose(1, 2)
            .contiguous()
            .view((bs, -1, self.num_attention_heads * dim_per_head))
    }

    /// Scaled dot-product attention forward pass.
    ///
    /// When `encoder_hidden_states` is provided, keys and values are computed
    /// from it (cross-attention) and `encoder_mask` is used; otherwise keys and
    /// values come from `hidden_states` (self-attention) with `mask`.
    ///
    /// Returns the attention output and, if `output_attentions` is set, the
    /// (post-dropout) attention weights.
    pub fn forward_t(
        &self,
        hidden_states: &Tensor,
        mask: Option<&Tensor>,
        encoder_hidden_states: Option<&Tensor>,
        encoder_mask: Option<&Tensor>,
        train: bool,
    ) -> (Tensor, Option<Tensor>) {
        // Select the source of keys/values and the matching attention mask.
        let (key_value_source, attention_mask) = match encoder_hidden_states {
            Some(encoder_states) => (encoder_states, encoder_mask),
            None => (hidden_states, mask),
        };

        let bs = hidden_states.size()[0];
        let dim_per_head = self.attention_head_size;

        // Queries are scaled by sqrt(head dim) before the dot product.
        let query_layer = self.split_heads(hidden_states.apply(&self.query), bs, dim_per_head)
            / (dim_per_head as f64).sqrt();
        let key_layer = self.split_heads(key_value_source.apply(&self.key), bs, dim_per_head);
        let value_layer = self.split_heads(key_value_source.apply(&self.value), bs, dim_per_head);

        let mut scores = query_layer.matmul(&key_layer.transpose(-1, -2));
        if let Some(attention_mask) = attention_mask {
            // Additive mask (e.g. large negative values on padded positions).
            scores = scores + attention_mask;
        }

        let weights = scores
            .softmax(-1, scores.kind())
            .apply_t(&self.dropout, train);
        let context = self.flatten(weights.matmul(&value_layer), bs, dim_per_head);

        if self.output_attentions {
            (context, Some(weights))
        } else {
            (context, None)
        }
    }
}
/// Output projection applied after BERT self-attention:
/// dense -> dropout -> residual add -> layer norm.
#[derive(Debug)]
pub struct BertSelfOutput {
    linear: nn::Linear,
    layer_norm: nn::LayerNorm,
    dropout: Dropout,
}
impl BertSelfOutput {
    /// Builds the post-attention output block from the given configuration.
    ///
    /// # Arguments
    /// * `p` - variable-store path where the layer weights are created/loaded
    /// * `config` - BERT model configuration
    pub fn new<'p, P>(p: P, config: &BertConfig) -> BertSelfOutput
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();
        BertSelfOutput {
            linear: nn::linear(
                p / "dense",
                config.hidden_size,
                config.hidden_size,
                Default::default(),
            ),
            layer_norm: nn::layer_norm(
                p / "LayerNorm",
                vec![config.hidden_size],
                nn::LayerNormConfig {
                    // Epsilon matching the reference BERT implementation.
                    eps: 1e-12,
                    ..Default::default()
                },
            ),
            dropout: Dropout::new(config.hidden_dropout_prob),
        }
    }

    /// Projects `hidden_states`, applies dropout, adds the residual
    /// `input_tensor`, then layer-normalizes the sum.
    pub fn forward_t(&self, hidden_states: &Tensor, input_tensor: &Tensor, train: bool) -> Tensor {
        let projected = hidden_states
            .apply(&self.linear)
            .apply_t(&self.dropout, train);
        (input_tensor + projected).apply(&self.layer_norm)
    }
}
/// Full BERT attention block: self-attention followed by its output
/// projection/residual/layer-norm stage.
#[derive(Debug)]
pub struct BertAttention {
    // `self` is a keyword, hence the leading underscore.
    _self: BertSelfAttention,
    output: BertSelfOutput,
}
impl BertAttention {
    /// Builds the attention block ("self" and "output" sub-layers) from the
    /// given configuration.
    ///
    /// # Arguments
    /// * `p` - variable-store path where the layer weights are created/loaded
    /// * `config` - BERT model configuration
    pub fn new<'p, P>(p: P, config: &BertConfig) -> BertAttention
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();
        BertAttention {
            _self: BertSelfAttention::new(p / "self", config),
            output: BertSelfOutput::new(p / "output", config),
        }
    }

    /// Runs self- (or cross-) attention, then the output stage which adds the
    /// `hidden_states` residual and normalizes.
    ///
    /// Returns the block output and the optional attention weights forwarded
    /// from the self-attention layer.
    pub fn forward_t(
        &self,
        hidden_states: &Tensor,
        mask: Option<&Tensor>,
        encoder_hidden_states: Option<&Tensor>,
        encoder_mask: Option<&Tensor>,
        train: bool,
    ) -> (Tensor, Option<Tensor>) {
        let (attention_output, attention_weights) = self._self.forward_t(
            hidden_states,
            mask,
            encoder_hidden_states,
            encoder_mask,
            train,
        );
        (
            self.output.forward_t(&attention_output, hidden_states, train),
            attention_weights,
        )
    }
}
/// Feed-forward expansion layer: dense (`hidden_size -> intermediate_size`)
/// followed by the configured activation function.
// Derive Debug for consistency with the other layers in this file
// (BertSelfAttention, BertSelfOutput, BertAttention all derive it).
// NOTE(review): requires `TensorFunction: Debug` — confirm in common/activations.
#[derive(Debug)]
pub struct BertIntermediate {
    lin: nn::Linear,
    activation: TensorFunction,
}
impl BertIntermediate {
    /// Builds the intermediate (feed-forward expansion) layer from the given
    /// configuration.
    ///
    /// # Arguments
    /// * `p` - variable-store path where the layer weights are created/loaded
    /// * `config` - BERT model configuration
    pub fn new<'p, P>(p: P, config: &BertConfig) -> BertIntermediate
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();
        BertIntermediate {
            lin: nn::linear(
                p / "dense",
                config.hidden_size,
                config.intermediate_size,
                Default::default(),
            ),
            activation: config.hidden_act.get_function(),
        }
    }

    /// Applies the dense projection, then the configured activation.
    pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
        let activation = self.activation.get_fn();
        activation(&hidden_states.apply(&self.lin))
    }
}
/// Feed-forward output layer: dense (`intermediate_size -> hidden_size`),
/// dropout, residual add and layer norm.
// Derive Debug for consistency with the other layers in this file; all three
// field types are already used in Debug-deriving structs above (BertSelfOutput).
#[derive(Debug)]
pub struct BertOutput {
    lin: nn::Linear,
    layer_norm: nn::LayerNorm,
    dropout: Dropout,
}
impl BertOutput {
    /// Builds the feed-forward output block from the given configuration.
    ///
    /// # Arguments
    /// * `p` - variable-store path where the layer weights are created/loaded
    /// * `config` - BERT model configuration
    pub fn new<'p, P>(p: P, config: &BertConfig) -> BertOutput
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();
        BertOutput {
            lin: nn::linear(
                p / "dense",
                config.intermediate_size,
                config.hidden_size,
                Default::default(),
            ),
            layer_norm: nn::layer_norm(
                p / "LayerNorm",
                vec![config.hidden_size],
                nn::LayerNormConfig {
                    // Epsilon matching the reference BERT implementation.
                    eps: 1e-12,
                    ..Default::default()
                },
            ),
            dropout: Dropout::new(config.hidden_dropout_prob),
        }
    }

    /// Projects `hidden_states` back to the hidden size, applies dropout, adds
    /// the residual `input_tensor`, then layer-normalizes the sum.
    pub fn forward_t(&self, hidden_states: &Tensor, input_tensor: &Tensor, train: bool) -> Tensor {
        let projected = hidden_states.apply(&self.lin).apply_t(&self.dropout, train);
        (input_tensor + projected).apply(&self.layer_norm)
    }
}