muonts 0.1.0

Time series models in Rust
Documentation
use burn::tensor::{backend::Backend, Tensor};

/// Weighted mean of every element of `x`, reduced to a single value.
///
/// The numerator is the sum of `x * weights` taken only over positions whose
/// weight is non-zero; every other position contributes an explicit zero, so
/// the value of `x` at a zero-weight position never enters the sum.
/// The denominator is the sum of all weights, clamped from below at `1.0`,
/// which keeps the division defined when every weight is zero.
///
/// Returns the one-element tensor produced by the full reduction.
pub fn weighted_average<B: Backend>(x: Tensor<B, 2>, weights: Tensor<B, 2>) -> Tensor<B, 1> {
    // `true` wherever the weight is different from zero.
    let has_weight = weights.clone().equal_elem(0.0).bool_not();

    // Zero-filled fallback shaped like `x`; must be built before `x` is consumed below.
    let fallback = x.zeros_like();
    let products = x * weights.clone();

    // Keep the product where a weight exists, the fallback zero elsewhere.
    let numerator = fallback.mask_where(has_weight, products).sum();
    let denominator = weights.sum().clamp_min(1.0);

    numerator / denominator
}

/// Pinball (quantile) loss, summed over the quantile axis.
///
/// For each quantile level `q` and residual `r = y_true - y_pred` the term is
/// `|r * (1[y_true <= y_pred] - q)|`, i.e. `q * r` when the prediction is too
/// low and `(1 - q) * |r|` when it is too high.
///
/// # Arguments
/// * `y_true` - rank-2 targets; a trailing axis of size 1 is appended so they
///   broadcast against every quantile.
/// * `y_pred` - rank-3 predictions whose axis 2 is matched against `quantiles`
///   (presumably one slice per quantile level; confirm against the caller).
/// * `quantiles` - rank-1 quantile levels, expanded with leading unit axes to
///   rank 3 for broadcasting.
///
/// # Returns
/// A rank-2 tensor: the per-quantile losses summed along axis 2, with that axis
/// squeezed away.
pub fn quantile_loss<B: Backend>(
    y_true: Tensor<B, 2>,
    y_pred: Tensor<B, 3>,
    quantiles: Tensor<B, 1>,
) -> Tensor<B, 2> {
    let quantiles: Tensor<B, 3> = quantiles.unsqueeze();
    let y_true: Tensor<B, 3> = y_true.unsqueeze_dim(2);
    let residual = y_true.clone() - y_pred.clone();

    // The indicator must flag over-prediction (`y_true <= y_pred`). Using the
    // opposite comparison swaps the `q` and `1 - q` penalties, which makes the
    // model fit quantile `1 - q` instead of `q`.
    let over_predicted = y_true.lower_equal(y_pred).float();

    (residual * (over_predicted - quantiles))
        .abs()
        .sum_dim(2)
        .squeeze(2)
}

/// Splits `x` into consecutive chunks along `dim`.
///
/// The meaning of `splits` depends on its length:
/// * empty: `x` is returned unchanged as the only element;
/// * one entry: it is a chunk size. The axis is cut into as many chunks of
///   that size as fit, followed by one shorter chunk holding the remainder,
///   if any. A zero-length trailing chunk is never produced;
/// * several entries: they are the explicit chunk lengths and must add up to
///   the size of the axis.
///
/// `dim` may be negative to count from the last axis (`-1` is the last axis,
/// `-2` the one before it, and so on).
///
/// # Panics
/// * if `dim` does not name an axis of a rank-`D` tensor;
/// * if the single chunk size is zero;
/// * if several lengths are given and their sum differs from the axis size.
pub fn split<B: Backend, const D: usize>(
    x: Tensor<B, D>,
    splits: Vec<usize>,
    dim: i32,
) -> Vec<Tensor<B, D>> {
    // Resolve a possibly negative axis index relative to the rank.
    let dim: usize = if dim < 0 {
        let from_end = dim.unsigned_abs() as usize;
        assert!(
            from_end <= D,
            "dim {} is out of range for a tensor of rank {}",
            dim,
            D
        );
        D - from_end
    } else {
        dim as usize
    };
    assert!(
        dim < D,
        "dim {} is out of range for a tensor of rank {}",
        dim,
        D
    );

    let shape = x.dims();
    let dim_size = shape[dim];

    if splits.is_empty() {
        return vec![x];
    }

    let sizes: Vec<usize> = if splits.len() == 1 {
        let chunk = splits[0];
        assert!(chunk > 0, "split size must be greater than zero");
        let remainder = dim_size % chunk;
        // Only append the remainder when it is non-empty: an `n..n` slice range
        // is not a valid chunk.
        std::iter::repeat(chunk)
            .take(dim_size / chunk)
            .chain((remainder > 0).then_some(remainder))
            .collect()
    } else {
        assert_eq!(
            splits.iter().sum::<usize>(),
            dim_size,
            "split sizes must sum to the size of the split dimension"
        );
        splits
    };

    // Walk the axis, taking the full extent of every other axis for each chunk.
    let mut start: usize = 0;
    sizes
        .into_iter()
        .map(|len| {
            let mut ranges = shape.map(|extent| 0..extent);
            ranges[dim] = start..start + len;
            start += len;
            x.clone().slice(ranges)
        })
        .collect()
}