muonts 0.1.0

Time series models in Rust
Documentation
use burn::config::Config;
use burn::module::Module;
use burn::nn::attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig};
use burn::nn::{LayerNorm, LayerNormConfig, Linear, LinearConfig};
use burn::tensor::{backend::Backend, Tensor};

use super::glu::{GatedLinearUnit, GatedLinearUnitConfig};
use super::grn::{GatedResidualNetwork, GatedResidualNetworkConfig};

/// Decoder block combining static enrichment, multi-head attention over the
/// enriched sequence, and a gated position-wise feed-forward stage.
///
/// Built via [`TemportalFusionDecoderConfig::init`]; see `forward` for the
/// exact data flow.
#[derive(Module, Debug)]
pub struct TemportalFusionDecoder<B: Backend> {
    /// Index along the time axis at which the decoder output window starts;
    /// `forward` keeps only time steps from this index onward.
    context_length: usize,
    /// Number of always-valid steps appended to the caller's mask in `forward`.
    prediction_length: usize,
    /// Gated residual network applied to the whole input sequence, with the
    /// repeated static features passed as its optional second argument.
    enrich: GatedResidualNetwork<B>,
    /// Attention layer: queries are the enriched steps from `context_length`
    /// onward, keys and values are the full enriched sequence.
    attention: MultiHeadAttention<B>,
    /// Linear layer applied to the attention context
    /// (`d_hidden -> 2 * d_hidden`, as configured in `init`).
    att_net_linear: Linear<B>,
    /// Gate applied after `att_net_linear`.
    // NOTE(review): presumably halves the last dimension back to `d_hidden`
    // so it can be added to the residual — confirm in `glu.rs`.
    att_net_glu: GatedLinearUnit,
    /// Layer norm over (enriched residual + gated attention output).
    att_lnorm: LayerNorm<B>,
    /// Gated residual network of the feed-forward stage (called without
    /// static context).
    ff_net_grn: GatedResidualNetwork<B>,
    /// Linear layer of the feed-forward stage
    /// (`d_hidden -> 2 * d_hidden`, as configured in `init`).
    ff_net_linear: Linear<B>,
    /// Gate applied after `ff_net_linear`.
    ff_net_glu: GatedLinearUnit,
    /// Final layer norm over (feed-forward output + raw-input skip slice).
    ff_lnorm: LayerNorm<B>,
}

impl<B: Backend> TemportalFusionDecoder<B> {
    /// Runs the decoder and returns one output row per time step from
    /// `context_length` onward.
    ///
    /// # Arguments
    ///
    /// * `x` - Input sequence. Assumed to be laid out as
    ///   `[batch, context_length + prediction_length, d_hidden]` — TODO confirm
    ///   against the caller.
    /// * `statics` - Static features; repeated along dimension 1
    ///   `context_length + prediction_length` times before being handed to the
    ///   enrichment network.
    /// * `mask` - Per-step mask for the context window. Presumably `1` marks
    ///   an observed step and `0` a padded one — verify against the caller.
    ///
    /// # Returns
    ///
    /// The layer-normalised sum of the feed-forward output and the slice of the
    /// raw input `x` taken from `context_length` onward.
    pub fn forward(
        &self,
        x: Tensor<B, 3>,
        statics: Tensor<B, 3>,
        mask: Tensor<B, 2>,
    ) -> Tensor<B, 3> {
        let total_steps = self.context_length + self.prediction_length;
        let static_context = statics.repeat(1, total_steps);

        // Skip connection is taken from the raw input, before enrichment.
        let skip = self.steps_after_context(x.clone());
        let enriched = self.enrich.forward(x, Some(static_context));

        // Extend the caller's mask with `prediction_length` columns of ones.
        let [batch_size, _] = mask.dims();
        let horizon_ones = mask
            .ones_like()
            .slice([0..batch_size, 0..1])
            .repeat(1, self.prediction_length);
        let full_mask = Tensor::cat(vec![mask, horizon_ones], 1);

        // `1 - m == 0` holds exactly where `m == 1`; negating flags every
        // position whose mask value is not 1.
        let key_padding_mask = (full_mask.neg() + 1.0).equal_elem(0).bool_not();

        // Queries cover only the steps after the context window, while keys and
        // values span the entire enriched sequence.
        let query = self.steps_after_context(enriched.clone());
        let attention_input =
            MhaInput::new(query, enriched.clone(), enriched.clone()).mask_pad(key_padding_mask);
        let attention_context = self.attention.forward(attention_input).context;

        let gated_attention = self
            .att_net_glu
            .forward(self.att_net_linear.forward(attention_context));

        // Residual connection around the attention stage.
        let residual = self.steps_after_context(enriched);
        let attended = self.att_lnorm.forward(residual + gated_attention);

        // Gated feed-forward stage, closed by the raw-input skip connection.
        let hidden = self.ff_net_grn.forward(attended, None);
        let gated_hidden = self.ff_net_glu.forward(self.ff_net_linear.forward(hidden));

        self.ff_lnorm.forward(gated_hidden + skip)
    }

    /// Keeps every batch entry and feature, but only the entries of dimension 1
    /// from `context_length` to the end.
    fn steps_after_context(&self, sequence: Tensor<B, 3>) -> Tensor<B, 3> {
        let [batch_size, step_count, feature_count] = sequence.dims();
        sequence.slice([
            0..batch_size,
            self.context_length..step_count,
            0..feature_count,
        ])
    }
}

/// Configuration used to build a [`TemportalFusionDecoder`].
#[derive(Config, Debug)]
pub struct TemportalFusionDecoderConfig {
    /// Number of leading time steps treated as context; the decoder output
    /// starts at this index along the time axis.
    context_length: usize,
    /// Number of steps appended (as ones) to the attention padding mask.
    prediction_length: usize,
    /// Width used for the gated residual networks, the attention layer, the
    /// input side of both linear layers, and both layer norms.
    d_hidden: usize,
    /// Value passed to the enrichment network's `with_d_static`.
    // NOTE(review): presumably the feature width of the static inputs —
    // confirm in `grn.rs`.
    d_var: usize,
    /// Number of attention heads.
    num_heads: usize,

    /// Dropout rate forwarded to both gated residual networks and the
    /// attention layer. Defaults to `0.0`.
    #[config(default = 0.0)]
    dropout: f64,
}

impl TemportalFusionDecoderConfig {
    /// Builds a [`TemportalFusionDecoder`] from this configuration.
    ///
    /// Sub-modules are constructed in field order: enrichment network,
    /// attention, the attention gate (linear, GLU, layer norm) and finally the
    /// feed-forward stage (GRN, linear, GLU, layer norm).
    pub fn init<B: Backend>(&self) -> TemportalFusionDecoder<B> {
        let d_hidden = self.d_hidden;
        // Both linear layers project to twice the hidden width ahead of
        // their gated linear unit.
        let gate_input_width = d_hidden * 2;

        TemportalFusionDecoder {
            context_length: self.context_length,
            prediction_length: self.prediction_length,
            enrich: GatedResidualNetworkConfig::new(d_hidden)
                .with_d_static(self.d_var)
                .with_dropout(self.dropout)
                .init(),
            attention: MultiHeadAttentionConfig::new(d_hidden, self.num_heads)
                .with_dropout(self.dropout)
                .init(),
            att_net_linear: LinearConfig::new(d_hidden, gate_input_width).init(),
            att_net_glu: GatedLinearUnitConfig::new().init(),
            att_lnorm: LayerNormConfig::new(d_hidden).init(),
            ff_net_grn: GatedResidualNetworkConfig::new(d_hidden)
                .with_dropout(self.dropout)
                .init(),
            ff_net_linear: LinearConfig::new(d_hidden, gate_input_width).init(),
            ff_net_glu: GatedLinearUnitConfig::new().init(),
            ff_lnorm: LayerNormConfig::new(d_hidden).init(),
        }
    }
}