muonts 0.1.0

Time series models in Rust
Documentation
use burn::config::Config;
use burn::module::Module;
use burn::nn::{LayerNorm, LayerNormConfig, Linear, LinearConfig, Lstm, LstmConfig};
use burn::tensor::{backend::Backend, Tensor};
use super::glu::{GatedLinearUnit, GatedLinearUnitConfig};

/// Sequence encoder that runs an LSTM over a context window (and optionally a
/// second LSTM over a target window), then applies a gated transform, a skip
/// connection from the raw inputs, and layer normalization.
///
/// Instances are built through [`TemporalFusionEncoderConfig::init`].
#[derive(Module, Debug)]
pub struct TemporalFusionEncoder<B: Backend> {
    /// LSTM applied to the context input in `forward`.
    encoder_lstm: Lstm<B>,
    /// LSTM applied to the optional target input, seeded with the last-step
    /// cell/hidden state produced by `encoder_lstm`.
    decoder_lstm: Lstm<B>,
    /// Linear layer applied to the LSTM outputs right before `gate_glu`
    /// (configured as `d_hidden -> 2 * d_hidden`).
    gate_linear: Linear<B>,
    /// Gated linear unit applied after `gate_linear`.
    // NOTE(review): defined in `super::glu`; presumably halves the last
    // dimension back to `d_hidden` — confirm against that module.
    gate_glu: GatedLinearUnit,
    /// Optional projection of the raw inputs used on the skip path; only
    /// present when the configured `d_input` differs from `d_hidden`.
    skip_proj: Option<Linear<B>>,
    /// Layer normalization applied to the sum of the gated encodings and the
    /// skip path.
    lnorm: LayerNorm<B>,
}

impl<B: Backend> TemporalFusionEncoder<B> {
    /// Encodes a context sequence and, optionally, a target sequence.
    ///
    /// # Arguments
    ///
    /// * `ctx_input` - Rank-3 context tensor; axis 1 is treated as the
    ///   sequence axis (it is the axis sliced and concatenated along).
    /// * `tgt_input` - Optional rank-3 target tensor. When present it is run
    ///   through the decoder LSTM, seeded with the encoder's last-step states,
    ///   and its outputs are appended to the context outputs along axis 1.
    /// * `states` - Optional initial state pair forwarded unchanged to the
    ///   encoder LSTM.
    ///
    /// # Returns
    ///
    /// The layer-normalized sum of the gated LSTM outputs and the (optionally
    /// projected) raw inputs.
    ///
    /// # Panics
    ///
    /// NOTE(review): the last-step extraction computes `len - 1` on axis 1, so
    /// an empty sequence axis is presumably unsupported — confirm with callers.
    pub fn forward(
        &self,
        ctx_input: Tensor<B, 3>,
        tgt_input: Option<Tensor<B, 3>>,
        states: Option<(Tensor<B, 2>, Tensor<B, 2>)>,
    ) -> Tensor<B, 3> {
        let (ctx_cells, ctx_hiddens) = self.encoder_lstm.forward(ctx_input.clone(), states);

        // Last-step states of the encoder; these seed the decoder when a
        // target sequence is supplied.
        let final_hidden = Self::last_step(ctx_hiddens.clone());
        let final_cell = Self::last_step(ctx_cells);

        // `sequence` is what gets gated; `residual` is the raw input kept for
        // the skip connection. Both grow along axis 1 when a target is given.
        let (sequence, residual) = if let Some(tgt) = tgt_input {
            let (_, tgt_hiddens) = self
                .decoder_lstm
                .forward(tgt.clone(), Some((final_cell, final_hidden)));

            (
                Tensor::cat(vec![ctx_hiddens, tgt_hiddens], 1),
                Tensor::cat(vec![ctx_input, tgt], 1),
            )
        } else {
            (ctx_hiddens, ctx_input)
        };

        let gated = self.gate_glu.forward(self.gate_linear.forward(sequence));

        // Project the raw inputs only when a projection layer was configured.
        let residual = match self.skip_proj.as_ref() {
            Some(proj) => proj.forward(residual),
            None => residual,
        };

        self.lnorm.forward(gated + residual)
    }

    /// Extracts the final entry along axis 1 of a rank-3 tensor and drops
    /// that axis, yielding a rank-2 tensor.
    fn last_step(sequence: Tensor<B, 3>) -> Tensor<B, 2> {
        let [batch, steps, width] = sequence.dims();
        sequence
            .slice([0..batch, steps - 1..steps, 0..width])
            .squeeze(1)
    }
}

/// Configuration used to build a [`TemporalFusionEncoder`].
#[derive(Config, Debug)]
pub struct TemporalFusionEncoderConfig {
    /// Input feature size passed to both LSTMs and to the optional skip
    /// projection.
    d_input: usize,
    /// Hidden size of both LSTMs; also the width used for the gate layer,
    /// the skip projection output, and layer normalization.
    d_hidden: usize,
}

impl TemporalFusionEncoderConfig {
    /// Builds a [`TemporalFusionEncoder`] from this configuration.
    ///
    /// A skip projection (`d_input -> d_hidden`) is created only when the two
    /// sizes differ; otherwise the raw inputs are added to the gated
    /// encodings as-is.
    pub fn init<B: Backend>(&self) -> TemporalFusionEncoder<B> {
        let (d_in, d_hid) = (self.d_input, self.d_hidden);

        let skip_proj = (d_in != d_hid).then(|| LinearConfig::new(d_in, d_hid).init());

        // Sub-modules are constructed in a fixed order (skip projection first,
        // then the fields as listed below) so parameter initialization order
        // stays stable.
        TemporalFusionEncoder {
            skip_proj,
            // NOTE(review): the third `LstmConfig::new` argument is presumably
            // the bias flag — confirm against the burn version in use.
            encoder_lstm: LstmConfig::new(d_in, d_hid, true).init(),
            decoder_lstm: LstmConfig::new(d_in, d_hid, true).init(),
            // The gate layer doubles the width ahead of the gated linear unit.
            gate_linear: LinearConfig::new(d_hid, d_hid * 2).init(),
            gate_glu: GatedLinearUnitConfig::new().init(),
            lnorm: LayerNormConfig::new(d_hid).init(),
        }
    }
}