rust-bert 0.7.2

Ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)
Documentation
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use tch::{Tensor, nn};
use crate::distilbert::distilbert::{DistilBertConfig, Activation};
use crate::distilbert::attention::MultiHeadSelfAttention;
use tch::nn::LayerNorm;
use std::borrow::BorrowMut;
use crate::common::dropout::Dropout;
use crate::common::activations::{_gelu, _relu};

pub struct FeedForwardNetwork {
    lin1: nn::Linear,
    lin2: nn::Linear,
    dropout: Dropout,
    activation: Box<dyn Fn(&Tensor) -> Tensor>,
}

impl FeedForwardNetwork {
    pub fn new(p: nn::Path, config: &DistilBertConfig) -> FeedForwardNetwork {
        let lin1 = nn::linear(&p / "lin1", config.dim, config.hidden_dim, Default::default());
        let lin2 = nn::linear(&p / "lin2", config.hidden_dim, config.dim, Default::default());
        let dropout = Dropout::new(config.dropout);
        let activation = Box::new(match &config.activation {
            Activation::gelu => _gelu,
            Activation::relu => _relu
        });
        FeedForwardNetwork { lin1, lin2, dropout, activation }
    }

    pub fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
        (self.activation)(&input.apply(&self.lin1)).apply(&self.lin2).apply_t(&self.dropout, train)
    }
}

pub struct TransformerBlock {
    attention: MultiHeadSelfAttention,
    sa_layer_norm: LayerNorm,
    ffn: FeedForwardNetwork,
    output_layer_norm: LayerNorm,
}

impl TransformerBlock {
    pub fn new(p: &nn::Path, config: &DistilBertConfig) -> TransformerBlock {
        let attention = MultiHeadSelfAttention::new(p / "attention", &config);
        let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
        let sa_layer_norm = nn::layer_norm(p / "sa_layer_norm", vec![config.dim], layer_norm_config);
        let ffn = FeedForwardNetwork::new(p / "ffn", &config);
        let output_layer_norm = nn::layer_norm(p / "output_layer_norm", vec![config.dim], layer_norm_config);

        TransformerBlock {
            attention,
            sa_layer_norm,
            ffn,
            output_layer_norm,
        }
    }

    pub fn forward_t(&self, input: &Tensor, mask: &Option<Tensor>, train: bool) -> (Tensor, Option<Tensor>) {
        let (output, sa_weights) = self.attention.forward_t(&input, &input, &input, mask, train);
        let output = (input + &output).apply(&self.sa_layer_norm);
        let output = (&output + self.ffn.forward_t(&output, train)).apply(&self.output_layer_norm);
        (output, sa_weights)
    }
}

pub struct Transformer {
    output_attentions: bool,
    output_hidden_states: bool,
    layers: Vec<TransformerBlock>,
}

impl Transformer {
    pub fn new(p: &nn::Path, config: &DistilBertConfig) -> Transformer {
        let p = &(p / "layer");
        let output_attentions = match config.output_attentions {
            Some(value) => value,
            None => false
        };
        let output_hidden_states = match config.output_hidden_states {
            Some(value) => value,
            None => false
        };

        let mut layers: Vec<TransformerBlock> = vec!();
        for layer_index in 0..config.n_layers {
            layers.push(TransformerBlock::new(&(p / layer_index), config));
        };

        Transformer { output_attentions, output_hidden_states, layers }
    }

    pub fn forward_t(&self, input: &Tensor, mask: Option<Tensor>, train: bool)
                     -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
        let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };

        let mut hidden_state = input.copy();
        let mut attention_weights: Option<Tensor>;
        let mut layers = self.layers.iter();
        loop {
            match layers.next() {
                Some(layer) => {
                    if let Some(hidden_states) = all_hidden_states.borrow_mut() {
                        hidden_states.push(hidden_state.as_ref().copy());
                    };

                    let temp = layer.forward_t(&hidden_state, &mask, train);
                    hidden_state = temp.0;
                    attention_weights = temp.1;
                    if let Some(attentions) = all_attentions.borrow_mut() {
                        attentions.push(attention_weights.as_ref().unwrap().copy());
                    };
                }
                None => break
            };
        };

        (hidden_state, all_hidden_states, all_attentions)
    }
}