// fsrs 5.2.0
//
// FSRS for Rust, including Optimizer and Scheduler
// Documentation
use burn::{
    backend::autodiff::Autodiff,
    tensor::{Element, TensorData, Tolerance},
};
/// Burn's `NdArray` backend wrapped in the `Autodiff` backend decorator.
pub type NdArrayAutodiff = Autodiff<burn::backend::NdArray>;
/// The crate's `Model` instantiated with the [`NdArrayAutodiff`] backend.
pub type Model = crate::model::Model<NdArrayAutodiff>;
/// A burn `Tensor` fixed to the [`NdArrayAutodiff`] backend.
///
/// `D` and `K` are forwarded unchanged to burn's `Tensor`; `K` defaults to
/// `burn::tensor::Float`.
pub type Tensor<const D: usize, K = burn::tensor::Float> =
    burn::tensor::Tensor<NdArrayAutodiff, D, K>;

/// Asserts that two equally sized arrays are approximately equal.
///
/// Both arrays are converted into burn `TensorData` values, which are then
/// compared through `TensorData::assert_approx_eq::<f32>` using an absolute
/// tolerance of `1e-4`.
///
/// `#[track_caller]` is applied so that a failure is attributed to the call
/// site of this helper where the underlying assertion supports it.
#[track_caller]
fn assert_approx_eq<const N: usize, T>(a: [T; N], b: [T; N])
where
    T: Copy + std::fmt::Debug + PartialEq + Element,
    f64: From<T>,
{
    let lhs = TensorData::from(a);
    let rhs = TensorData::from(b);
    // The tolerance's float type is pinned by the `::<f32>` turbofish below.
    let tolerance = Tolerance::absolute(1e-4);
    lhs.assert_approx_eq::<f32>(&rhs, tolerance);
}

/// Method-call form of an approximate-equality assertion against a fixed-size
/// expected array.
///
/// `N` is the length of the expected array and `T` is its element type.
pub trait TestHelper<const N: usize, T> {
    /// Asserts that `self` is approximately equal to the expected array `b`.
    ///
    /// How `self` is converted and compared is defined by each implementation.
    fn assert_approx_eq(&self, b: [T; N])
    where
        T: Copy + std::fmt::Debug + PartialEq + Element,
        f64: From<T>;
}

/// Fixed-size arrays can be compared against an expected array of the same
/// length directly.
impl<T, const N: usize> TestHelper<N, T> for [T; N] {
    /// Copies `self` out of the reference and hands both arrays to the
    /// free-standing `assert_approx_eq` helper.
    #[track_caller]
    fn assert_approx_eq(&self, b: [T; N])
    where
        T: Copy + std::fmt::Debug + PartialEq + Element,
        f64: From<T>,
    {
        // `T: Copy` makes the whole array `Copy`, so this is a plain bitwise copy.
        let actual: [T; N] = *self;
        assert_approx_eq(actual, b);
    }
}

/// Vectors are compared by first viewing their contents as an array of the
/// expected length `N`.
impl<T, const N: usize> TestHelper<N, T> for Vec<T> {
    /// Asserts that the vector's elements are approximately equal to `b`.
    ///
    /// # Panics
    ///
    /// Panics with a "length mismatch" message when `self.len() != N`, and
    /// otherwise whenever the free-standing `assert_approx_eq` helper rejects
    /// the two arrays.
    #[track_caller]
    fn assert_approx_eq(&self, b: [T; N])
    where
        T: Copy + std::fmt::Debug + PartialEq + Element,
        f64: From<T>,
    {
        // For `T: Copy` an array can be built straight from the borrowed slice,
        // so the vector does not need to be cloned first.
        // A `match` (rather than a closure) keeps the panic inside this
        // `#[track_caller]` function so it is reported at the caller's location.
        let a: [T; N] = match self.as_slice().try_into() {
            Ok(array) => array,
            Err(_) => panic!(
                "length mismatch: actual has {} elements but expected has {}",
                self.len(),
                N
            ),
        };
        assert_approx_eq(a, b);
    }
}
impl<T, const N: usize> TestHelper<N, T> for [T] {
    #[track_caller]
    fn assert_approx_eq(&self, b: [T; N])
    where
        T: Copy + std::fmt::Debug + PartialEq + Element,
        f64: From<T>,
    {
        self.to_vec().assert_approx_eq(b);
    }
}