use crate::bert::attention::{BertAttention, BertIntermediate, BertOutput};
use crate::bert::bert_model::BertConfig;
use std::borrow::{Borrow, BorrowMut};
use tch::{nn, Tensor};
/// A single transformer layer of a BERT model: self-attention, optional
/// cross-attention (decoder configurations only), and a feed-forward block.
pub struct BertLayer {
    // Self-attention over the layer input.
    attention: BertAttention,
    // True when `config.is_decoder` is `Some(true)`; gates the cross-attention path.
    is_decoder: bool,
    // Present if and only if `is_decoder` is true (see `BertLayer::new`).
    cross_attention: Option<BertAttention>,
    // Feed-forward expansion applied after attention.
    intermediate: BertIntermediate,
    // Feed-forward projection + residual combination with the attention output.
    output: BertOutput,
}
impl BertLayer {
    /// Builds a `BertLayer` under the variable-store path `p` using `config`.
    ///
    /// When `config.is_decoder` is `Some(true)`, a cross-attention module is
    /// created under `"cross_attention"` so the layer can attend over encoder
    /// hidden states; otherwise `cross_attention` is `None`.
    pub fn new<'p, P>(p: P, config: &BertConfig) -> BertLayer
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();
        let attention = BertAttention::new(p / "attention", config);
        // `is_decoder` is optional in the config and defaults to false.
        let is_decoder = config.is_decoder.unwrap_or(false);
        let cross_attention = if is_decoder {
            Some(BertAttention::new(p / "cross_attention", config))
        } else {
            None
        };
        let intermediate = BertIntermediate::new(p / "intermediate", config);
        let output = BertOutput::new(p / "output", config);
        BertLayer {
            attention,
            is_decoder,
            cross_attention,
            intermediate,
            output,
        }
    }

    /// Runs one transformer layer.
    ///
    /// * `hidden_states` — layer input.
    /// * `mask` — optional attention mask applied in self-attention (and
    ///   passed through to cross-attention as in the original code).
    /// * `encoder_hidden_states` / `encoder_mask` — encoder outputs and their
    ///   mask; cross-attention runs only when this layer is a decoder layer
    ///   AND `encoder_hidden_states` is provided.
    /// * `train` — enables dropout inside the sub-modules.
    ///
    /// Returns the layer output together with the (optional) self- and
    /// cross-attention weights produced by the attention modules.
    pub fn forward_t(
        &self,
        hidden_states: &Tensor,
        mask: Option<&Tensor>,
        encoder_hidden_states: Option<&Tensor>,
        encoder_mask: Option<&Tensor>,
        train: bool,
    ) -> BertLayerOutput {
        let (attention_output, attention_weights) =
            self.attention
                .forward_t(hidden_states, mask, None, None, train);
        // Use short-circuiting `&&` (the original `&` was a bitwise AND on
        // bools — same result here, but non-idiomatic and flagged by clippy).
        let (attention_output, attention_weights, cross_attention_weights) =
            if self.is_decoder && encoder_hidden_states.is_some() {
                let cross_attention = self
                    .cross_attention
                    .as_ref()
                    .expect("decoder layers are always built with a cross-attention module");
                let (attention_output, cross_attention_weights) = cross_attention.forward_t(
                    &attention_output,
                    mask,
                    encoder_hidden_states,
                    encoder_mask,
                    train,
                );
                (attention_output, attention_weights, cross_attention_weights)
            } else {
                (attention_output, attention_weights, None)
            };
        let output = self.intermediate.forward(&attention_output);
        // Feed-forward output is combined with the attention output (residual
        // handling lives inside `BertOutput`).
        let output = self.output.forward_t(&output, &attention_output, train);
        BertLayerOutput {
            hidden_state: output,
            attention_weights,
            cross_attention_weights,
        }
    }
}
/// Stack of `BertLayer`s forming the BERT encoder.
pub struct BertEncoder {
    // When true, `forward_t` collects per-layer attention weights.
    output_attentions: bool,
    // When true, `forward_t` collects a copy of each layer's hidden state.
    output_hidden_states: bool,
    // `config.num_hidden_layers` transformer layers, applied in order.
    layers: Vec<BertLayer>,
}
impl BertEncoder {
    /// Builds the encoder under `p / "layer"`, creating
    /// `config.num_hidden_layers` layers indexed `0..n`.
    ///
    /// `output_attentions` / `output_hidden_states` default to `false` when
    /// absent from the config.
    pub fn new<'p, P>(p: P, config: &BertConfig) -> BertEncoder
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow() / "layer";
        let output_attentions = config.output_attentions.unwrap_or(false);
        let output_hidden_states = config.output_hidden_states.unwrap_or(false);
        let layers: Vec<BertLayer> = (0..config.num_hidden_layers)
            .map(|layer_index| BertLayer::new(&p / layer_index, config))
            .collect();
        BertEncoder {
            output_attentions,
            output_hidden_states,
            layers,
        }
    }

    /// Runs the full layer stack.
    ///
    /// * `input` — embedding output fed to the first layer.
    /// * `mask` — optional attention mask forwarded to every layer.
    /// * `encoder_hidden_states` / `encoder_mask` — cross-attention inputs for
    ///   decoder configurations (ignored by non-decoder layers).
    /// * `train` — enables dropout.
    ///
    /// Returns the last layer's hidden state, plus (when enabled in the
    /// config) the per-layer hidden states and attention weights.
    ///
    /// # Panics
    /// Panics if the encoder has zero layers, or if attention collection is
    /// enabled but a layer returns no attention weights.
    pub fn forward_t(
        &self,
        input: &Tensor,
        mask: Option<&Tensor>,
        encoder_hidden_states: Option<&Tensor>,
        encoder_mask: Option<&Tensor>,
        train: bool,
    ) -> BertEncoderOutput {
        let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
            Some(vec![])
        } else {
            None
        };
        let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
            Some(vec![])
        } else {
            None
        };
        let mut hidden_state: Option<Tensor> = None;
        for layer in &self.layers {
            // The first layer consumes `input`; later layers consume the
            // previous layer's output.
            let layer_input = hidden_state.as_ref().unwrap_or(input);
            let layer_output = layer.forward_t(
                layer_input,
                mask,
                encoder_hidden_states,
                encoder_mask,
                train,
            );
            if let Some(attentions) = all_attentions.as_mut() {
                // `unwrap` replaced by `expect` stating the invariant; the
                // original also built a throwaway default tensor via
                // `mem::take` on a temporary, which is unnecessary.
                attentions.push(
                    layer_output
                        .attention_weights
                        .expect("output_attentions is set but the layer produced no weights"),
                );
            }
            hidden_state = Some(layer_output.hidden_state);
            if let Some(hidden_states) = all_hidden_states.as_mut() {
                // Copy so the collected state survives being fed to the next layer.
                hidden_states.push(hidden_state.as_ref().unwrap().copy());
            }
        }
        BertEncoderOutput {
            hidden_state: hidden_state.expect("BertEncoder must contain at least one layer"),
            all_hidden_states,
            all_attentions,
        }
    }
}
/// Pools a sequence of hidden states into a single vector by projecting the
/// first token's hidden state through a dense layer with tanh activation.
pub struct BertPooler {
    // Dense projection `hidden_size -> hidden_size` applied to the first token.
    lin: nn::Linear,
}
impl BertPooler {
    /// Builds the pooler under the variable-store path `p`: a single
    /// `hidden_size -> hidden_size` dense layer registered as `"dense"`.
    pub fn new<'p, P>(p: P, config: &BertConfig) -> BertPooler
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();
        BertPooler {
            lin: nn::linear(
                path / "dense",
                config.hidden_size,
                config.hidden_size,
                Default::default(),
            ),
        }
    }

    /// Pools `hidden_states` by selecting index 0 along dimension 1 (the
    /// first token of each sequence), applying the dense layer, then tanh.
    pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
        let first_token = hidden_states.select(1, 0);
        first_token.apply(&self.lin).tanh()
    }
}
/// Output of a single `BertLayer` forward pass.
pub struct BertLayerOutput {
    /// Hidden state produced by the layer's feed-forward block.
    pub hidden_state: Tensor,
    /// Self-attention weights, when the attention module returns them.
    pub attention_weights: Option<Tensor>,
    /// Cross-attention weights; populated only on the decoder path.
    pub cross_attention_weights: Option<Tensor>,
}
/// Output of a `BertEncoder` forward pass.
pub struct BertEncoderOutput {
    /// Hidden state of the last encoder layer.
    pub hidden_state: Tensor,
    /// Per-layer hidden states, collected only when `output_hidden_states` is set.
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Per-layer attention weights, collected only when `output_attentions` is set.
    pub all_attentions: Option<Vec<Tensor>>,
}