//! muonts 0.1.0
//!
//! Timeseries models in Rust.
use burn::config::Config;
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::tensor::{backend::Backend, Tensor};

use crate::utils::split;

/// Projects several groups of input features into embeddings, using one
/// [`Linear`] layer per feature group.
#[derive(Module, Debug)]
pub struct FeatureProjector<B: Backend> {
    /// Width of each feature group. `forward` uses these widths to split the
    /// input along its last dimension when there is more than one projector.
    feature_dims: Vec<usize>,
    /// One linear layer per feature group, in the same order as `feature_dims`.
    projectors: Vec<Linear<B>>,
}

impl<B: Backend> FeatureProjector<B> {
    /// Runs every feature group through its own linear projector.
    ///
    /// With two or more projectors, `features` is first cut along its last
    /// dimension into slices whose widths are `feature_dims` (via
    /// `crate::utils::split`). With a single projector the whole tensor is fed
    /// to it unchanged.
    ///
    /// Returns one projected tensor per (projector, slice) pair, in projector
    /// order.
    ///
    /// NOTE(review): presumably the entries of `feature_dims` must sum to the
    /// size of the last dimension of `features`. That depends on `split`, so
    /// TODO confirm.
    ///
    /// NOTE(review): pairing uses `zip`, which stops at the shorter side. If
    /// fewer slices than projectors are produced, the trailing projectors are
    /// silently skipped.
    pub fn forward<const D: usize>(&self, features: Tensor<B, D>) -> Vec<Tensor<B, D>> {
        // A lone projector consumes the unsplit input; only split when needed.
        let slices = match self.projectors.len() {
            0 | 1 => vec![features],
            _ => split(features, self.feature_dims.clone(), -1),
        };

        let mut outputs = Vec::with_capacity(self.projectors.len());
        for (layer, slice) in self.projectors.iter().zip(slices) {
            outputs.push(layer.forward(slice));
        }
        outputs
    }
}

/// Configuration for [`FeatureProjector`].
///
/// Entry `i` of `feature_dims` is the input width of feature group `i`.
/// Entry `i` of `embedding_dims` is the output width of that group's
/// projection. `init` requires `feature_dims` to be non-empty and both
/// vectors to have the same length.
#[derive(Config, Debug)]
pub struct FeatureProjectorConfig {
    /// Input width of each feature group.
    feature_dims: Vec<usize>,
    /// Output (embedding) width for each feature group.
    embedding_dims: Vec<usize>,
}

impl FeatureProjectorConfig {
    /// Builds a [`FeatureProjector`] with one [`Linear`] layer per feature group.
    ///
    /// Layer `i` maps `feature_dims[i]` input features to `embedding_dims[i]`
    /// outputs.
    ///
    /// # Panics
    ///
    /// Panics if `feature_dims` is empty, or if `feature_dims` and
    /// `embedding_dims` have different lengths.
    pub fn init<B: Backend>(&self) -> FeatureProjector<B> {
        assert!(
            !self.feature_dims.is_empty(),
            "FeatureProjectorConfig: feature_dims must not be empty"
        );
        // assert_eq! reports both lengths on failure, which a bare `==` does not.
        assert_eq!(
            self.feature_dims.len(),
            self.embedding_dims.len(),
            "FeatureProjectorConfig: feature_dims and embedding_dims must have the same length"
        );

        let projectors: Vec<Linear<B>> = self
            .feature_dims
            .iter()
            .zip(&self.embedding_dims)
            .map(|(&in_dim, &out_dim)| LinearConfig::new(in_dim, out_dim).init())
            .collect();

        FeatureProjector {
            feature_dims: self.feature_dims.clone(),
            projectors,
        }
    }
}