use std::borrow::Borrow;
use tch::nn::{Init, Linear, Module, ModuleT, Path};
use tch::{Kind, Reduction, Tensor};
use crate::cow::CowTensor;
use crate::layers::{Dropout, LayerNorm};
use crate::traits::LayerOutput;
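/// Feed-forward block: a linear transform followed by a ReLU,
/// layer normalization, and dropout.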
#[derive(Debug)]
struct NonLinearWithLayerNorm {
layer_norm: LayerNorm,
linear: Linear,
dropout: Dropout,
}
impl NonLinearWithLayerNorm {
fn new<'a>(
vs: impl Borrow<Path<'a>>,
in_size: i64,
out_size: i64,
dropout: f64,
layer_norm_eps: f64,
) -> Self {
let vs = vs.borrow();
NonLinearWithLayerNorm {
dropout: Dropout::new(dropout),
layer_norm: LayerNorm::new(vs / "layer_norm", vec![out_size], layer_norm_eps, true),
linear: Linear {
ws: vs.var("weight", &[out_size, in_size], Init::KaimingUniform),
bs: vs.var("bias", &[out_size], Init::Const(0.)),
},
}
}
}
impl ModuleT for NonLinearWithLayerNorm {
fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
let mut hidden = self.linear.forward(input).relu();
hidden = self.layer_norm.forward(&hidden);
self.dropout.forward_t(&hidden, train)
}
}
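/// Scalar weighting of layer outputs.
///
/// Computes a softmax-normalized, learned weighting of the outputs of
/// all layers, scaled by a learned scalar. This is similar to the
/// scalar mixing used by ELMo.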
#[derive(Debug)]
pub struct ScalarWeight {
layer_dropout_prob: f64,
layer_weights: Tensor,
scale: Tensor,
}
impl ScalarWeight {
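/// Construct a scalar weighting of `n_layers` layers.
///
/// During training, each layer is dropped with probability
/// `layer_dropout_prob` before the weights are normalized.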
pub fn new<'a>(vs: impl Borrow<Path<'a>>, n_layers: i64, layer_dropout_prob: f64) -> Self {
assert!(
n_layers > 0,
"Number of layers ({}) should be larger than 0",
n_layers
);
assert!(
layer_dropout_prob >= 0.0 && layer_dropout_prob < 1.0,
"Layer dropout should be in [0,1), was: {}",
layer_dropout_prob
);
let vs = vs.borrow();
ScalarWeight {
layer_dropout_prob,
layer_weights: vs.var("layer_weights", &[n_layers], Init::Const(0.)),
scale: vs.var("scale", &[], Init::Const(1.)),
}
}
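/// Compute the scalar-weighted combination of the given layer outputs.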
pub fn forward(&self, layers: &[impl LayerOutput], train: bool) -> Tensor {
assert_eq!(
self.layer_weights.size()[0],
layers.len() as i64,
"Expected {} layers, got {}",
self.layer_weights.size()[0],
layers.len()
);
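// Stack the layer outputs along a new dimension:
// [batch_size, seq_len, n_layers, hidden_size].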
let layers = layers
.iter()
.map(LayerOutput::layer_output)
.collect::<Vec<_>>();
let layers = Tensor::stack(&layers, 2);
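// During training, drop layers by adding a large negative value to
// their unnormalized weights, so that they receive a near-zero
// weight after the softmax.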
let layer_weights = if train {
let dropout_mask = Tensor::empty_like(&self.layer_weights)
.fill_(1.0 - self.layer_dropout_prob)
.bernoulli();
let softmax_mask = (1.0 - dropout_mask.to_kind(Kind::Float)) * -10_000.;
CowTensor::Owned(&self.layer_weights + softmax_mask)
} else {
CowTensor::Borrowed(&self.layer_weights)
};
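// Normalize the layer weights and reshape them for broadcasting
// over the stacked layer outputs.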
let layer_weights = layer_weights
.softmax(-1, Kind::Float)
.unsqueeze(0)
.unsqueeze(0)
.unsqueeze(-1);
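// Weighted sum over the layer dimension, scaled by the learned scalar.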
let weighted_layers = layers * layer_weights;
&self.scale * weighted_layers.sum1(&[-2], false, Kind::Float)
}
}
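/// Classifier on top of scalar-weighted layer outputs.
///
/// The layer outputs are combined with `ScalarWeight`, passed through a
/// non-linear transform with layer normalization, and projected to the
/// label space with a final linear layer.
///
/// A minimal construction sketch, mirroring the shapes used in the unit
/// tests of this module:
///
/// ```ignore
/// let vs = VarStore::new(Device::Cpu);
/// let classifier = ScalarWeightClassifier::new(
///     vs.root(),
///     &ScalarWeightClassifierConfig {
///         hidden_size: 10,
///         input_size: 8,
///         n_labels: 5,
///         n_layers: 2,
///         dropout_prob: 0.1,
///         layer_dropout_prob: 0.1,
///         layer_norm_eps: 0.01,
///     },
/// );
/// ```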
#[derive(Debug)]
pub struct ScalarWeightClassifier {
dropout: Dropout,
scalar_weight: ScalarWeight,
linear: Linear,
non_linear: NonLinearWithLayerNorm,
}
impl ScalarWeightClassifier {
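/// Construct a classifier from the given configuration.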
pub fn new<'a>(vs: impl Borrow<Path<'a>>, config: &ScalarWeightClassifierConfig) -> Self {
assert!(
config.n_labels > 0,
"The number of labels ({}) should be larger than 0",
config.n_labels
);
assert!(
config.input_size > 0,
"The input size ({}) should be larger than 0",
config.input_size
);
assert!(
config.hidden_size > 0,
"The hidden size ({}) should be larger than 0",
config.hidden_size
);
let vs = vs.borrow();
let ws = vs.var(
"weight",
&[config.n_labels, config.hidden_size],
Init::KaimingUniform,
);
let bs = vs.var("bias", &[config.n_labels], Init::Const(0.));
let non_linear = NonLinearWithLayerNorm::new(
vs / "nonlinear",
config.input_size,
config.hidden_size,
config.dropout_prob,
config.layer_norm_eps,
);
ScalarWeightClassifier {
dropout: Dropout::new(config.dropout_prob),
linear: Linear { ws, bs },
non_linear,
scalar_weight: ScalarWeight::new(
vs / "scalar_weight",
config.n_layers,
config.layer_dropout_prob,
),
}
}
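/// Compute the label probabilities (softmax over the logits).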
pub fn forward(&self, layers: &[impl LayerOutput], train: bool) -> Tensor {
let logits = self.logits(layers, train);
logits.softmax(-1, Kind::Float)
}
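/// Compute the unnormalized label scores.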
pub fn logits(&self, layers: &[impl LayerOutput], train: bool) -> Tensor {
let mut features = self.scalar_weight.forward(layers, train);
features = self.dropout.forward_t(&features, train);
features = self.non_linear.forward_t(&features, train);
self.linear.forward(&features)
}
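/// Compute the losses for the given layer outputs and targets.
///
/// Returns a pair of tensors with shape `[batch_size, seq_len]`: the
/// cross-entropy loss of each target and whether the predicted label
/// is equal to the target label.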
pub fn losses(
&self,
layers: &[impl LayerOutput],
targets: &Tensor,
label_smoothing: Option<f64>,
train: bool,
) -> (Tensor, Tensor) {
let targets_shape = targets.size();
let batch_size = targets_shape[0];
let seq_len = targets_shape[1];
let n_labels = self.linear.ws.size()[0];
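// Flatten the batch and sequence dimensions for the loss computation.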
let logits = self
.logits(layers, train)
.view([batch_size * seq_len, n_labels]);
let targets = targets.view([batch_size * seq_len]);
let predicted = logits.argmax(-1, false);
let losses =
cross_entropy_loss(&logits, &targets, self.linear.bs.size()[0], label_smoothing)
.view([batch_size, seq_len]);
(losses, predicted.eq1(&targets).view([batch_size, seq_len]))
}
}
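/// Configuration of the scalar weight classifier.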
pub struct ScalarWeightClassifierConfig {
pub hidden_size: i64,
pub input_size: i64,
pub n_layers: i64,
pub n_labels: i64,
pub layer_dropout_prob: f64,
pub dropout_prob: f64,
pub layer_norm_eps: f64,
}
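/// Compute the cross-entropy losses of `logits` against `targets`.
///
/// When `label_smoothing` is given, the correct label gets probability
/// `1 - label_smoothing` and the remaining probability mass is
/// distributed uniformly over the other `n_classes - 1` labels. Losses
/// are returned without reduction.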
fn cross_entropy_loss(
logits: &Tensor,
targets: &Tensor,
n_classes: i64,
label_smoothing: Option<f64>,
) -> Tensor {
let probs = logits.log_softmax(-1, Kind::Float);
match label_smoothing {
Some(label_smoothing) => {
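// Construct the smoothed target distribution.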
let smoothed_targets = tch::no_grad(|| {
Tensor::full_like(&probs, label_smoothing / (n_classes - 1) as f64).scatter1(
1,
&targets.unsqueeze(1),
1. - label_smoothing,
)
});
(-smoothed_targets * probs).sum1(&[-1], false, Kind::Float)
}
None => probs.g_nll_loss::<&Tensor>(&targets, None, Reduction::None, -100),
}
}
#[cfg(test)]
mod tests {
use std::collections::BTreeSet;
use std::convert::TryInto;
use std::iter::FromIterator;
use approx::assert_abs_diff_eq;
use ndarray::{array, ArrayD};
use tch::nn::VarStore;
use tch::{Device, Kind, Tensor};
use super::{cross_entropy_loss, ScalarWeightClassifier, ScalarWeightClassifierConfig};
use crate::models::bert::BertLayerOutput;
fn varstore_variables(vs: &VarStore) -> BTreeSet<String> {
vs.variables()
.into_iter()
.map(|(k, _)| k)
.collect::<BTreeSet<_>>()
}
#[test]
fn cross_entropy_loss_without_label_smoothing() {
let logits = Tensor::of_slice(&[-1., -1., 1., -1., -1.]).view([1, 5]);
let targets = Tensor::of_slice(&[2i64]).view([1]);
let loss: ArrayD<f32> = (&cross_entropy_loss(&logits, &targets, 5, None))
.try_into()
.unwrap();
assert_abs_diff_eq!(loss, array![0.432653].into_dyn(), epsilon = 1e-6);
}
#[test]
fn cross_entropy_with_label_smoothing() {
let logits = Tensor::of_slice(&[-1., -1., 1., -1., -1.]).view([1, 5]);
let targets = Tensor::of_slice(&[2i64]).view([1]);
let loss: ArrayD<f32> = (&cross_entropy_loss(&logits, &targets, 5, Some(0.1)))
.try_into()
.unwrap();
assert_abs_diff_eq!(loss, array![0.632653].into_dyn(), epsilon = 1e-6);
}
#[test]
fn scalar_weight_classifier_shapes_forward_works() {
let vs = VarStore::new(Device::Cpu);
let classifier = ScalarWeightClassifier::new(
vs.root(),
&ScalarWeightClassifierConfig {
hidden_size: 10,
input_size: 8,
n_labels: 5,
n_layers: 2,
dropout_prob: 0.1,
layer_dropout_prob: 0.1,
layer_norm_eps: 0.01,
},
);
let layer1 = BertLayerOutput {
attention: Some(Tensor::zeros(&[1, 3, 2], (Kind::Float, Device::Cpu))),
output: Tensor::zeros(&[1, 3, 8], (Kind::Float, Device::Cpu)),
};
let layer2 = BertLayerOutput {
attention: Some(Tensor::zeros(&[1, 3, 2], (Kind::Float, Device::Cpu))),
output: Tensor::zeros(&[1, 3, 8], (Kind::Float, Device::Cpu)),
};
let results = classifier.forward(&[layer1, layer2], false);
assert_eq!(results.size(), &[1, 3, 5]);
}
#[test]
fn scalar_weight_classifier_names() {
let vs = VarStore::new(Device::Cpu);
let _classifier = ScalarWeightClassifier::new(
vs.root(),
&ScalarWeightClassifierConfig {
hidden_size: 10,
input_size: 8,
n_labels: 5,
n_layers: 2,
dropout_prob: 0.1,
layer_dropout_prob: 0.1,
layer_norm_eps: 0.01,
},
);
assert_eq!(
varstore_variables(&vs),
BTreeSet::from_iter(vec![
"bias".to_string(),
"weight".to_string(),
"nonlinear.bias".to_string(),
"nonlinear.weight".to_string(),
"nonlinear.layer_norm.bias".to_string(),
"nonlinear.layer_norm.weight".to_string(),
"scalar_weight.layer_weights".to_string(),
"scalar_weight.scale".to_string()
])
)
}
}