//! Pooling of piece embeddings into per-token embeddings.

use serde::{Deserialize, Serialize};
use syntaxdot_tch_ext::tensor::SumDim;
use syntaxdot_transformers::models::LayerOutput;
use syntaxdot_transformers::TransformerError;
use tch::{Kind, Tensor};
use crate::error::SyntaxDotError;
use crate::tensor::{TokenSpans, TokenSpansWithRoot};
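
/// Strategy for pooling the piece embeddings of a token into a single embedding.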
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum PiecePooler {
    /// Keep only the embedding of the first piece of a token, discard the rest.
    Discard,
    /// Use the mean of the embeddings of a token's pieces.
    Mean,
}
impl PiecePooler {
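    /// Pool the piece embeddings of each token in every layer.
    ///
    /// The token spans are extended with a root token before pooling, so the
    /// pooled output contains one additional token at position 0.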
pub fn pool(
&self,
token_spans: &TokenSpans,
layer_outputs: &[LayerOutput],
) -> Result<Vec<LayerOutput>, SyntaxDotError> {
let mut new_layer_outputs = Vec::with_capacity(layer_outputs.len());
for layer_output in layer_outputs {
let new_layer_output = layer_output
.map_output(|output| self.pool_layer(&token_spans.with_root()?, output))
.map_err(SyntaxDotError::BertError)?;
new_layer_outputs.push(new_layer_output);
}
Ok(new_layer_outputs)
}
fn pool_layer(
&self,
token_spans: &TokenSpansWithRoot,
layer: &Tensor,
) -> Result<Tensor, TransformerError> {
let token_embeddings = token_spans.embeddings_per_token(layer)?;
let pooled_layer = match self {
PiecePooler::Discard => Self::pool_discard(&token_embeddings)?,
PiecePooler::Mean => Self::pool_mean(&token_embeddings)?,
};
Ok(pooled_layer)
}
fn pool_discard(token_embeddings: &TokenEmbeddings) -> Result<Tensor, TransformerError> {
Ok(token_embeddings
.embeddings
.f_slice(2, 0, 1, 1)?
.f_squeeze_dim(2)?)
}
fn pool_mean(token_embeddings: &TokenEmbeddings) -> Result<Tensor, TransformerError> {
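        // Count the pieces of each token; clamp to 1 to avoid division by
        // zero for padding tokens without any pieces.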
let pieces_per_token = token_embeddings
.mask
.f_sum_dim(2, false, Kind::Float)?
.f_clamp_min(1)?;
Ok(token_embeddings
.embeddings
.f_sum_dim(2, false, Kind::Float)?
.f_div(&pieces_per_token.f_unsqueeze(-1)?)?)
}
}
#[derive(Debug)]
struct TokenEmbeddings {
    /// Piece embeddings with shape `[batch_size, tokens_len, max_token_len, embed_size]`.
    /// Embeddings of pieces that do not belong to a token are zeroed.
    embeddings: Tensor,
    /// Boolean mask with shape `[batch_size, tokens_len, max_token_len]` marking
    /// which pieces belong to each token.
    mask: Tensor,
}
trait EmbeddingsPerToken {
fn embeddings_per_token(
&self,
embeddings: &Tensor,
) -> Result<TokenEmbeddings, TransformerError>;
}
impl EmbeddingsPerToken for TokenSpansWithRoot {
fn embeddings_per_token(
&self,
embeddings: &Tensor,
) -> Result<TokenEmbeddings, TransformerError> {
let (batch_size, _pieces_len, embed_size) = embeddings.size3()?;
let (_batch_size, tokens_len) = self.offsets().size2()?;
let max_token_len = i64::from(self.lens().max());
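        // Relative piece positions within a token, shape [1, 1, max_token_len].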
let piece_range = Tensor::f_arange(max_token_len, (Kind::Int64, self.lens().device()))?
.f_view([1, 1, max_token_len])?;
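        // Mark which relative positions are actual pieces of each token.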
let mask = piece_range.less_tensor(&self.lens().f_unsqueeze(-1)?);
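        // Absolute piece indices; masked positions are zeroed so that the
        // gather below stays in bounds.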
let piece_indices = (piece_range + self.offsets().unsqueeze(-1)).f_mul(&mask)?;
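        // Gather the piece embeddings of each token and zero the embeddings
        // of masked positions.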
let piece_embeddings = embeddings
.f_gather(
1,
&piece_indices
.f_view([batch_size, -1, 1])?
.f_expand(&[-1, -1, embed_size], true)?,
false,
)?
.f_view([batch_size, tokens_len, max_token_len, embed_size])?
.f_mul(&mask.f_unsqueeze(-1)?)?;
Ok(TokenEmbeddings {
embeddings: piece_embeddings,
mask,
})
}
}
#[cfg(test)]
mod tests {
use tch::{Device, Kind, Tensor};
use crate::model::pooling::{EmbeddingsPerToken, PiecePooler};
use crate::tensor::{TokenSpans, TokenSpansWithRoot};
#[test]
fn discard_pooler_works_correctly() {
let spans = TokenSpans::new(
Tensor::of_slice2(&[[1, 3, 4, -1, -1], [1, 3, 4, 6, 7]]),
Tensor::of_slice2(&[[2, 1, 1, -1, -1], [2, 1, 2, 1, 1]]),
);
let hidden = Tensor::arange_start_step(36, 0, -1, (Kind::Int64, Device::Cpu))
.view([2, 9, 2])
.to_kind(Kind::Float);
let pooler = PiecePooler::Discard;
let token_embeddings = pooler
.pool_layer(&spans.with_root().unwrap(), &hidden)
.unwrap();
assert_eq!(
token_embeddings,
Tensor::of_slice2(&[
&[36, 35, 34, 33, 30, 29, 28, 27, 0, 0, 0, 0],
&[18, 17, 16, 15, 12, 11, 10, 9, 6, 5, 4, 3]
])
.view([2, 6, 2])
);
}
#[test]
fn embeddings_are_returned_per_token() {
let spans = TokenSpansWithRoot::new(
Tensor::of_slice2(&[[1, 3, 4, -1, -1], [1, 3, 4, 6, 7]]),
Tensor::of_slice2(&[[2, 1, 1, -1, -1], [2, 1, 2, 1, 1]]),
);
let hidden =
Tensor::arange_start_step(32, 0, -1, (Kind::Int64, Device::Cpu)).view([2, 8, 2]);
let token_embeddings = spans.embeddings_per_token(&hidden).unwrap();
assert_eq!(
token_embeddings.embeddings,
Tensor::of_slice(&[
30, 29, 28, 27, 26, 25, 0, 0, 24, 23, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 13, 12, 11,
10, 9, 0, 0, 8, 7, 6, 5, 4, 3, 0, 0, 2, 1, 0, 0
])
.view([2, 5, 2, 2])
);
assert_eq!(
token_embeddings.mask,
Tensor::of_slice(&[
true, true, true, false, true, false, false, false, false, false, true, true, true,
false, true, true, true, false, true, false
])
.view([2, 5, 2])
);
}
#[test]
fn mean_pooler_works_correctly() {
let spans = TokenSpans::new(
Tensor::of_slice2(&[[1, 3, 4, -1, -1], [1, 3, 4, 6, 7]]),
Tensor::of_slice2(&[[2, 1, 1, -1, -1], [2, 1, 2, 1, 1]]),
);
let hidden = Tensor::arange_start_step(36, 0, -1, (Kind::Int64, Device::Cpu))
.view([2, 9, 2])
.to_kind(Kind::Float);
let pooler = PiecePooler::Mean;
let token_embeddings = pooler
.pool_layer(&spans.with_root().unwrap(), &hidden)
.unwrap();
assert_eq!(
token_embeddings,
Tensor::of_slice2(&[
&[36, 35, 33, 32, 30, 29, 28, 27, 0, 0, 0, 0],
&[18, 17, 15, 14, 12, 11, 9, 8, 6, 5, 4, 3]
])
.view([2, 6, 2])
);
}
}