use std::borrow::Borrow;

use serde::Deserialize;
use syntaxdot_tch_ext::PathExt;
use tch::nn::Module;
use tch::Tensor;

use crate::activations::Activation;
use crate::error::TransformerError;
use crate::models::albert::{AlbertConfig, AlbertEmbeddingProjection};
use crate::models::bert::BertConfig;
use crate::models::layer_output::LayerOutput;
use crate::models::squeeze_bert::{SqueezeBertConfig, SqueezeBertLayer};
use crate::models::traits::WordEmbeddingsConfig;
use crate::models::Encoder;
use crate::util::LogitsMask;
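
/// SqueezeALBERT model configuration.
///
/// SqueezeALBERT combines ALBERT's factorized embeddings and cross-layer
/// parameter sharing with SqueezeBERT's grouped convolutions.
///
/// Every field has a default, so a configuration can be deserialized from a
/// partial specification. A minimal sketch, assuming `serde_json` is
/// available (any self-describing serde format works):
///
/// ```ignore
/// let config: SqueezeAlbertConfig = serde_json::from_str(
///     r#"{"hidden_size": 512, "num_attention_heads": 8}"#,
/// )
/// .unwrap();
///
/// // Fields that are absent from the JSON fall back to their defaults.
/// assert_eq!(config.vocab_size, 30000);
/// ```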
#[derive(Debug, Deserialize)]
#[serde(default)]
pub struct SqueezeAlbertConfig {
pub attention_probs_dropout_prob: f64,
pub embedding_size: i64,
pub hidden_act: Activation,
pub hidden_dropout_prob: f64,
pub hidden_size: i64,
pub initializer_range: f64,
pub inner_group_num: i64,
pub intermediate_size: i64,
pub max_position_embeddings: i64,
pub num_attention_heads: i64,
pub num_hidden_groups: i64,
pub num_hidden_layers: i64,
pub type_vocab_size: i64,
pub vocab_size: i64,
pub q_groups: i64,
pub k_groups: i64,
pub v_groups: i64,
pub post_attention_groups: i64,
pub intermediate_groups: i64,
pub output_groups: i64,
}
impl Default for SqueezeAlbertConfig {
fn default() -> Self {
SqueezeAlbertConfig {
attention_probs_dropout_prob: 0.,
embedding_size: 128,
hidden_act: Activation::GeluNew,
hidden_dropout_prob: 0.,
hidden_size: 768,
initializer_range: 0.02,
inner_group_num: 1,
intermediate_size: 3072,
max_position_embeddings: 512,
num_attention_heads: 12,
num_hidden_groups: 1,
num_hidden_layers: 12,
type_vocab_size: 2,
vocab_size: 30000,
q_groups: 4,
k_groups: 4,
v_groups: 4,
post_attention_groups: 1,
intermediate_groups: 4,
output_groups: 4,
}
}
}
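
/// Extract the ALBERT-compatible part of a SqueezeALBERT configuration.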
impl From<&SqueezeAlbertConfig> for AlbertConfig {
    fn from(config: &SqueezeAlbertConfig) -> Self {
        AlbertConfig {
            attention_probs_dropout_prob: config.attention_probs_dropout_prob,
            embedding_size: config.embedding_size,
            hidden_act: config.hidden_act,
            hidden_dropout_prob: config.hidden_dropout_prob,
            hidden_size: config.hidden_size,
            initializer_range: config.initializer_range,
            inner_group_num: config.inner_group_num,
            intermediate_size: config.intermediate_size,
            max_position_embeddings: config.max_position_embeddings,
            num_attention_heads: config.num_attention_heads,
            num_hidden_groups: config.num_hidden_groups,
            num_hidden_layers: config.num_hidden_layers,
            type_vocab_size: config.type_vocab_size,
            vocab_size: config.vocab_size,
        }
}
}
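
/// Extract the BERT-compatible part of a SqueezeALBERT configuration.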
impl From<&SqueezeAlbertConfig> for BertConfig {
    fn from(config: &SqueezeAlbertConfig) -> Self {
        BertConfig {
            attention_probs_dropout_prob: config.attention_probs_dropout_prob,
            hidden_act: config.hidden_act,
            hidden_dropout_prob: config.hidden_dropout_prob,
            hidden_size: config.hidden_size,
            initializer_range: config.initializer_range,
            intermediate_size: config.intermediate_size,
            layer_norm_eps: config.layer_norm_eps(),
            max_position_embeddings: config.max_position_embeddings,
            num_attention_heads: config.num_attention_heads,
            num_hidden_layers: config.num_hidden_layers,
            type_vocab_size: config.type_vocab_size,
            vocab_size: config.vocab_size,
        }
}
}
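
/// Extract the SqueezeBERT configuration that is used to construct the
/// transformer layers of SqueezeALBERT.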
impl From<&SqueezeAlbertConfig> for SqueezeBertConfig {
fn from(config: &SqueezeAlbertConfig) -> Self {
SqueezeBertConfig {
attention_probs_dropout_prob: config.attention_probs_dropout_prob,
embedding_size: config.embedding_size,
hidden_act: config.hidden_act,
hidden_dropout_prob: config.hidden_dropout_prob,
hidden_size: config.hidden_size,
initializer_range: config.initializer_range,
intermediate_size: config.intermediate_size,
layer_norm_eps: config.layer_norm_eps(),
max_position_embeddings: config.max_position_embeddings,
num_attention_heads: config.num_attention_heads,
num_hidden_layers: config.num_hidden_layers,
type_vocab_size: config.type_vocab_size,
vocab_size: config.vocab_size,
q_groups: config.q_groups,
k_groups: config.k_groups,
v_groups: config.v_groups,
post_attention_groups: config.post_attention_groups,
intermediate_groups: config.intermediate_groups,
output_groups: config.output_groups,
}
}
}
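
/// Word piece embedding hyperparameters of SqueezeALBERT.
///
/// Note that the embedding dimensionality is `embedding_size` rather than
/// `hidden_size`: like ALBERT, SqueezeALBERT factorizes the embedding matrix
/// and projects the embeddings to the hidden size inside the encoder.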
impl WordEmbeddingsConfig for SqueezeAlbertConfig {
fn dims(&self) -> i64 {
self.embedding_size
}
fn dropout(&self) -> f64 {
self.hidden_dropout_prob
}
fn initializer_range(&self) -> f64 {
self.initializer_range
}
fn layer_norm_eps(&self) -> f64 {
1e-12
}
fn vocab_size(&self) -> i64 {
self.vocab_size
}
}
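
/// SqueezeALBERT encoder.
///
/// The encoder projects embeddings to the hidden size and then applies
/// SqueezeBERT layers with ALBERT-style parameter sharing: layers are
/// divided into groups, and all layers within a group share parameters.
///
/// A construction and encoding sketch. It is marked `ignore` because
/// obtaining a `PathExt` through `RootExt::root_ext` and the exact import
/// paths are assumptions, not verified API:
///
/// ```ignore
/// use syntaxdot_tch_ext::RootExt;
/// use tch::nn::VarStore;
/// use tch::{Device, Kind, Tensor};
///
/// let vs = VarStore::new(Device::Cpu);
/// let config = SqueezeAlbertConfig::default();
/// let encoder = SqueezeAlbertEncoder::new(vs.root_ext(|_| 0), &config).unwrap();
///
/// // Embedded pieces for a batch with one sequence of length 10.
/// let embeddings = Tensor::zeros(
///     &[1, 10, config.embedding_size],
///     (Kind::Float, Device::Cpu),
/// );
/// let layer_outputs = encoder.encode(&embeddings, None, false).unwrap();
/// assert_eq!(layer_outputs.len() as i64, config.num_hidden_layers + 1);
/// ```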
#[derive(Debug)]
pub struct SqueezeAlbertEncoder {
groups: Vec<SqueezeBertLayer>,
n_layers: i64,
projection: AlbertEmbeddingProjection,
}
impl SqueezeAlbertEncoder {
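    /// Construct a SqueezeALBERT encoder.
    ///
    /// One `SqueezeBertLayer` is created per hidden group; its parameters
    /// are stored under the variable store path `group_{idx}/inner_group_0`.
    ///
    /// Panics when `num_hidden_groups` is not positive.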
pub fn new<'a>(
vs: impl Borrow<PathExt<'a>>,
config: &SqueezeAlbertConfig,
) -> Result<Self, TransformerError> {
assert!(
config.num_hidden_groups > 0,
"Need at least 1 hidden group, got: {}",
config.num_hidden_groups
);
let vs = vs.borrow();
        // ALBERT-style parameter sharing: construct one layer per hidden
        // group, which is then reused by all layers assigned to the group.
        let squeeze_bert_config: SqueezeBertConfig = config.into();
        let mut groups = Vec::with_capacity(config.num_hidden_groups as usize);
        for group_idx in 0..config.num_hidden_groups {
            groups.push(SqueezeBertLayer::new(
                vs.sub(format!("group_{}", group_idx)).sub("inner_group_0"),
                &squeeze_bert_config,
            )?);
        }
let albert_config: AlbertConfig = config.into();
let projection = AlbertEmbeddingProjection::new(vs, &albert_config)?;
Ok(SqueezeAlbertEncoder {
groups,
n_layers: config.num_hidden_layers,
projection,
})
}
}
impl Encoder for SqueezeAlbertEncoder {
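    /// Apply the encoder to a batch of embedded inputs.
    ///
    /// `input` must have shape `[batch_size, seq_len, embedding_size]`.
    /// Returns the output of the embedding projection, followed by one
    /// `LayerOutput` per transformer layer, each with shape
    /// `[batch_size, seq_len, hidden_size]`.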
fn encode(
&self,
input: &Tensor,
attention_mask: Option<&Tensor>,
train: bool,
) -> Result<Vec<LayerOutput>, TransformerError> {
        // Project the embeddings from the embedding size to the hidden size.
        let hidden_states = self.projection.forward(input);

        // SqueezeBERT layers use the convolutional data layout
        // [batch_size, hidden_size, seq_len], so transpose the last two
        // dimensions.
        let input = hidden_states.f_permute(&[0, 2, 1])?;

        let mut all_layer_outputs = Vec::with_capacity(self.n_layers as usize + 1);

        // Store the embedding output in the permuted layout as well, so that
        // the final permutation below restores *every* output to
        // [batch_size, seq_len, hidden_size].
        all_layer_outputs.push(LayerOutput::Embedding(input.shallow_clone()));

        let attention_mask = attention_mask.map(LogitsMask::from_bool_mask).transpose()?;

        let mut hidden_states = input;
        for idx in 0..self.n_layers {
            // Distribute the layers evenly over the groups. Computing the
            // group as (idx * n_groups) / n_layers never indexes out of
            // bounds, even when the number of layers is not a multiple of
            // the number of groups.
            let group = idx as usize * self.groups.len() / self.n_layers as usize;
            let layer_output = self.groups[group].forward_t(
&hidden_states,
attention_mask.as_ref(),
train,
)?;
hidden_states = layer_output.output().shallow_clone();
all_layer_outputs.push(layer_output);
}
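
        // Convert all outputs back to [batch_size, seq_len, hidden_size].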
for layer_output in &mut all_layer_outputs {
*layer_output.output_mut() = layer_output.output().f_permute(&[0, 2, 1])?;
}
Ok(all_layer_outputs)
}
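
    /// The number of layer outputs produced by `encode`: the transformer
    /// layers plus the embedding projection output.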
fn n_layers(&self) -> i64 {
self.n_layers + 1
}
}