muonts 0.1.0

Timeseries models in rust
Documentation
use burn::config::Config;
use burn::module::Module;
use burn::tensor::activation;
use burn::tensor::{backend::Backend, Tensor};
use super::grn::{GatedResidualNetwork, GatedResidualNetworkConfig};

#[derive(Module, Debug)]
pub struct VariableSelectionNetwork<B: Backend> {
    weight_network: GatedResidualNetwork<B>,
    variable_networks: Vec<GatedResidualNetwork<B>>,
}

impl<B: Backend> VariableSelectionNetwork<B> {
    pub fn forward<const D: usize>(
        &self,
        variables: Vec<Tensor<B, D>>,
        statics: Option<Tensor<B, D>>,
    ) -> (Tensor<B, D>, Tensor<B, D>) {
        let statics: Option<Tensor<B, D>> = match statics {
            None => None,
            Some(tensor) => {
                let dim_size = variables[0].dims()[1];
                Some(tensor.repeat(1, dim_size))
            }
        };

        let flattened = Tensor::cat(variables.clone(), D - 1);
        let weight = self.weight_network.forward(flattened, statics);

        let var_encodings: Vec<Tensor<B, D>> = self
            .variable_networks
            .iter()
            .zip(variables)
            .map(|(net, var)| {
                let enc = net.forward(var, None);
                enc
            })
            .collect();

        match D {
            2 => {
                let weight : Tensor<B, 3> = weight.unsqueeze_dim(1);
                let weight = activation::softmax(weight, 2);

                let var_encodings : Tensor<B, 3> = Tensor::stack(var_encodings, 2);
                let var_encodings = (var_encodings * weight.clone()).sum_dim(2);
                let var_encodings : Tensor<B, D> = var_encodings.squeeze(2);

                (var_encodings, weight.squeeze(1))
            },
            3 => {
                let weight : Tensor<B, 4> = weight.unsqueeze_dim(2);
                let weight = activation::softmax(weight, 3);

                let var_encodings : Tensor<B, 4> = Tensor::stack(var_encodings, 3);
                let var_encodings = (var_encodings * weight.clone()).sum_dim(3);
                let var_encodings : Tensor<B, D> = var_encodings.squeeze(3);

                (var_encodings, weight.squeeze(2))
            },
            _ => {
                panic!("Unsupported dimension")
            }
        }
    }
}

#[derive(Config, Debug)]
pub struct VariableSelectionNetworkConfig {
    d_hidden: usize,
    num_vars: usize,

    #[config(default = false)]
    add_static: bool,

    #[config(default = 0.0)]
    dropout: f64,
}

impl VariableSelectionNetworkConfig {
    pub fn init<B: Backend>(&self) -> VariableSelectionNetwork<B> {
        let d_static = if self.add_static { self.d_hidden } else { 0 };

        let weight_network: GatedResidualNetwork<B> =
            GatedResidualNetworkConfig::new(self.d_hidden)
                .with_d_input(Some(self.d_hidden * self.num_vars))
                .with_d_output(Some(self.num_vars))
                .with_d_static(d_static)
                .with_dropout(self.dropout)
                .init();

        let variable_networks: Vec<GatedResidualNetwork<B>> = (0..self.num_vars)
            .map(|_| {
                GatedResidualNetworkConfig::new(self.d_hidden)
                    .with_dropout(self.dropout)
                    .init()
            })
            .collect();

        VariableSelectionNetwork {
            weight_network,
            variable_networks,
        }
    }
}