use std::borrow::Borrow;

use syntaxdot_tch_ext::tensor::SumDim;
use syntaxdot_tch_ext::PathExt;
use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
use tch::nn::{Init, Linear, Module};
use tch::{Kind, Reduction, Tensor};

use crate::cow::CowTensor;
use crate::layers::{Dropout, LayerNorm};
use crate::loss::CrossEntropyLoss;
use crate::models::LayerOutput;
use crate::module::{FallibleModule, FallibleModuleT};
use crate::TransformerError;
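
/// Linear transformation with a ReLU activation, followed by layer
/// normalization and dropout.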
#[derive(Debug)]
struct NonLinearWithLayerNorm {
layer_norm: LayerNorm,
linear: Linear,
dropout: Dropout,
}
impl NonLinearWithLayerNorm {
fn new<'a>(
vs: impl Borrow<PathExt<'a>>,
in_size: i64,
out_size: i64,
dropout: f64,
layer_norm_eps: f64,
) -> Result<NonLinearWithLayerNorm, TransformerError> {
let vs = vs.borrow();
Ok(NonLinearWithLayerNorm {
dropout: Dropout::new(dropout),
layer_norm: LayerNorm::new(vs / "layer_norm", vec![out_size], layer_norm_eps, true),
linear: Linear {
ws: vs.var("weight", &[out_size, in_size], DEFAULT_KAIMING_UNIFORM)?,
bs: Some(vs.var("bias", &[out_size], Init::Const(0.))?),
},
})
}
}
impl FallibleModuleT for NonLinearWithLayerNorm {
type Error = TransformerError;
fn forward_t(&self, input: &Tensor, train: bool) -> Result<Tensor, Self::Error> {
let mut hidden = self.linear.forward(input).relu();
hidden = self.layer_norm.forward(&hidden)?;
self.dropout.forward_t(&hidden, train)
}
}
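
/// Scalar weighting of the outputs of several transformer layers.
///
/// Each layer output is weighted by a softmax over learned per-layer
/// weights; the weighted sum is scaled by a learned scalar. During
/// training, layers are randomly dropped with probability
/// `layer_dropout_prob`.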
#[derive(Debug)]
pub struct ScalarWeight {
layer_dropout_prob: f64,
layer_weights: Tensor,
scale: Tensor,
}
impl ScalarWeight {
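    /// Construct a scalar weighting over `n_layers` layers.
    ///
    /// `layer_dropout_prob` is the probability of dropping a layer during
    /// training and must lie in `[0, 1)`.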
pub fn new<'a>(
vs: impl Borrow<PathExt<'a>>,
n_layers: i64,
layer_dropout_prob: f64,
) -> Result<Self, TransformerError> {
assert!(
n_layers > 0,
"Number of layers ({}) should be larger than 0",
n_layers
);
assert!(
(0.0..1.0).contains(&layer_dropout_prob),
"Layer dropout should be in [0,1), was: {}",
layer_dropout_prob
);
let vs = vs.borrow();
Ok(ScalarWeight {
layer_dropout_prob,
layer_weights: vs.var("layer_weights", &[n_layers], Init::Const(0.))?,
scale: vs.var("scale", &[], Init::Const(1.))?,
})
}
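    /// Compute the scalar-weighted sum of the given layer outputs.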
pub fn forward(&self, layers: &[LayerOutput], train: bool) -> Result<Tensor, TransformerError> {
assert_eq!(
self.layer_weights.size()[0],
layers.len() as i64,
"Expected {} layers, got {}",
self.layer_weights.size()[0],
layers.len()
);
let layers = layers.iter().map(LayerOutput::output).collect::<Vec<_>>();
let layers = Tensor::f_stack(&layers, 2)?;
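        // During training, randomly drop layers: add a large negative value
        // to the weights of dropped layers, so that they get a (near-)zero
        // weight after the softmax.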
let layer_weights = if train {
let dropout_mask = Tensor::f_empty_like(&self.layer_weights)?
.f_fill_(1.0 - self.layer_dropout_prob)?
.f_bernoulli()?;
let softmax_mask = (Tensor::from(1.0).f_sub(&dropout_mask.to_kind(Kind::Float))?)
.f_mul_scalar(-10_000.)?;
CowTensor::Owned(self.layer_weights.f_add(&softmax_mask)?)
} else {
CowTensor::Borrowed(&self.layer_weights)
};
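        // Normalize the weights with a softmax and reshape them to
        // [1, 1, n_layers, 1] so that they broadcast over the stacked
        // layers of shape [batch_size, seq_len, n_layers, hidden_size].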
let layer_weights = layer_weights
.f_softmax(-1, Kind::Float)?
.f_unsqueeze(0)?
.f_unsqueeze(0)?
.f_unsqueeze(-1)?;
let weighted_layers = layers.f_mul(&layer_weights)?;
Ok(weighted_layers
.f_sum_dim(-2, false, Kind::Float)?
.f_mul(&self.scale)?)
}
}
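
/// Classifier on top of a scalar weighting of transformer layers.
///
/// The weighted layer representation is passed through a non-linear
/// hidden layer, followed by a linear layer that produces the label
/// logits.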
#[derive(Debug)]
pub struct ScalarWeightClassifier {
dropout: Dropout,
scalar_weight: ScalarWeight,
linear: Linear,
non_linear: NonLinearWithLayerNorm,
}
impl ScalarWeightClassifier {
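    /// Construct a classifier from the given configuration.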
pub fn new<'a>(
vs: impl Borrow<PathExt<'a>>,
config: &ScalarWeightClassifierConfig,
) -> Result<ScalarWeightClassifier, TransformerError> {
assert!(
config.n_labels > 0,
"The number of labels should be larger than 0",
);
assert!(
config.input_size > 0,
"The input size should be larger than 0",
);
assert!(
config.hidden_size > 0,
"The hidden size should be larger than 0",
);
let vs = vs.borrow();
let ws = vs.var(
"weight",
&[config.n_labels, config.hidden_size],
DEFAULT_KAIMING_UNIFORM,
)?;
let bs = vs.var("bias", &[config.n_labels], Init::Const(0.))?;
let non_linear = NonLinearWithLayerNorm::new(
vs / "nonlinear",
config.input_size,
config.hidden_size,
config.dropout_prob,
config.layer_norm_eps,
)?;
Ok(ScalarWeightClassifier {
dropout: Dropout::new(config.dropout_prob),
linear: Linear { ws, bs: Some(bs) },
non_linear,
scalar_weight: ScalarWeight::new(
vs / "scalar_weight",
config.n_layers,
config.layer_dropout_prob,
)?,
})
}
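    /// Compute label probabilities (a softmax over the logits).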
pub fn forward(&self, layers: &[LayerOutput], train: bool) -> Result<Tensor, TransformerError> {
let logits = self.logits(layers, train)?;
Ok(logits.f_softmax(-1, Kind::Float)?)
}
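    /// Compute the unnormalized label scores (logits).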
pub fn logits(&self, layers: &[LayerOutput], train: bool) -> Result<Tensor, TransformerError> {
let mut features = self.scalar_weight.forward(layers, train)?;
features = self.dropout.forward_t(&features, train)?;
features = self.non_linear.forward_t(&features, train)?;
Ok(self.linear.forward(&features))
}
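    /// Compute the cross-entropy losses for the given targets.
    ///
    /// Returns the per-token losses and a tensor marking which predictions
    /// were correct, both with shape `[batch_size, seq_len]`.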
pub fn losses(
&self,
layers: &[LayerOutput],
targets: &Tensor,
label_smoothing: Option<f64>,
train: bool,
) -> Result<(Tensor, Tensor), TransformerError> {
assert_eq!(
targets.dim(),
2,
"Targets shoul have dimensionality 2, had {}",
targets.dim()
);
let (batch_size, seq_len) = targets.size2()?;
let n_labels = self.linear.ws.size()[0];
let logits = self
.logits(layers, train)?
.f_view([batch_size * seq_len, n_labels])?;
let targets = targets.f_view([batch_size * seq_len])?;
let predicted = logits.f_argmax(-1, false)?;
let losses = CrossEntropyLoss::new(-1, label_smoothing, Reduction::None)
.forward(&logits, &targets, None)?
.f_view([batch_size, seq_len])?;
Ok((
losses,
predicted
.f_eq_tensor(&targets)?
.f_view([batch_size, seq_len])?,
))
}
}
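
/// Configuration for [`ScalarWeightClassifier`].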
pub struct ScalarWeightClassifierConfig {
    /// Size of the hidden layer.
    pub hidden_size: i64,
    /// Size of the per-layer input representations.
    pub input_size: i64,
    /// Number of transformer layers that are weighted.
    pub n_layers: i64,
    /// Number of output labels.
    pub n_labels: i64,
    /// Probability of dropping out a layer during training.
    pub layer_dropout_prob: f64,
    /// Dropout probability for the hidden representations.
    pub dropout_prob: f64,
    /// Epsilon used for layer normalization.
    pub layer_norm_eps: f64,
}
#[cfg(test)]
mod tests {
use std::collections::BTreeSet;
use std::iter::FromIterator;
use syntaxdot_tch_ext::RootExt;
use tch::nn::VarStore;
use tch::{Device, Kind, Tensor};
use super::{ScalarWeightClassifier, ScalarWeightClassifierConfig};
use crate::models::{HiddenLayer, LayerOutput};
fn varstore_variables(vs: &VarStore) -> BTreeSet<String> {
vs.variables().into_keys().collect::<BTreeSet<_>>()
}
#[test]
fn scalar_weight_classifier_shapes_forward_works() {
let vs = VarStore::new(Device::Cpu);
let classifier = ScalarWeightClassifier::new(
vs.root_ext(|_| 0),
&ScalarWeightClassifierConfig {
hidden_size: 10,
input_size: 8,
n_labels: 5,
n_layers: 2,
dropout_prob: 0.1,
layer_dropout_prob: 0.1,
layer_norm_eps: 0.01,
},
)
.unwrap();
let layer1 = LayerOutput::EncoderWithAttention(HiddenLayer {
attention: Tensor::zeros(&[1, 3, 2], (Kind::Float, Device::Cpu)),
output: Tensor::zeros(&[1, 3, 8], (Kind::Float, Device::Cpu)),
});
let layer2 = LayerOutput::EncoderWithAttention(HiddenLayer {
attention: Tensor::zeros(&[1, 3, 2], (Kind::Float, Device::Cpu)),
output: Tensor::zeros(&[1, 3, 8], (Kind::Float, Device::Cpu)),
});
let results = classifier.forward(&[layer1, layer2], false).unwrap();
assert_eq!(results.size(), &[1, 3, 5]);
}
#[test]
fn scalar_weight_classifier_names() {
let vs = VarStore::new(Device::Cpu);
let _classifier = ScalarWeightClassifier::new(
vs.root_ext(|_| 0),
&ScalarWeightClassifierConfig {
hidden_size: 10,
input_size: 8,
n_labels: 5,
n_layers: 2,
dropout_prob: 0.1,
layer_dropout_prob: 0.1,
layer_norm_eps: 0.01,
},
        )
        .unwrap();
assert_eq!(
varstore_variables(&vs),
BTreeSet::from_iter(vec![
"bias".to_string(),
"weight".to_string(),
"nonlinear.bias".to_string(),
"nonlinear.weight".to_string(),
"nonlinear.layer_norm.bias".to_string(),
"nonlinear.layer_norm.weight".to_string(),
"scalar_weight.layer_weights".to_string(),
"scalar_weight.scale".to_string()
])
)
}
}