rsrl 0.8.1

A fast, extensible reinforcement learning framework in Rust
Documentation
use ndarray::{ArrayBase, DataMut, NdIndex, IntoDimension};

/// Buffer with a fixed shape and at most one active entry.
///
/// `dim` is the shape returned by `raw_dim`. `active` optionally holds the
/// index of the single active entry together with its activation value.
/// When `active` is `None`, adding this buffer into an array changes nothing.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct Tile<D, I>
where
    D: ndarray::Dimension,
    I: NdIndex<D>,
{
    /// Shape of the buffer.
    dim: D,
    /// Position of the active entry and its activation, if any.
    active: Option<(I, f64)>,
}

impl<D, I> Tile<D, I>
where
    D: ndarray::Dimension,
    I: NdIndex<D>,
{
    /// Builds a tile buffer.
    ///
    /// * `dim` - anything convertible into the buffer's dimension type `D`.
    /// * `active` - optional `(index, activation)` pair for the active entry;
    ///   `None` yields a tile with no active entry.
    pub fn new<T: IntoDimension<Dim = D>>(dim: T, active: Option<(I, f64)>) -> Self {
        let shape = dim.into_dimension();

        Tile { dim: shape, active }
    }
}

impl<D, I> crate::params::Buffer for Tile<D, I>
where
    D: ndarray::Dimension,
    I: NdIndex<D> + Clone,
{
    type Dim = D;

    /// Returns a copy of the tile's shape.
    fn raw_dim(&self) -> D {
        Clone::clone(&self.dim)
    }

    /// Adds the active activation (if any) into `arr` at the active index.
    ///
    /// Panics (via ndarray indexing) if the index is out of bounds for `arr`.
    fn addto<E: DataMut<Elem = f64>>(&self, arr: &mut ArrayBase<E, Self::Dim>) {
        match self.active {
            Some((ref index, value)) => arr[index.clone()] += value,
            None => {}
        }
    }

    /// Adds `alpha` times the active activation (if any) into `arr` at the
    /// active index.
    ///
    /// Panics (via ndarray indexing) if the index is out of bounds for `arr`.
    fn scaled_addto<E: DataMut<Elem = f64>>(&self, alpha: f64, arr: &mut ArrayBase<E, Self::Dim>) {
        if let Some((ref index, value)) = self.active {
            let increment = alpha * value;
            arr[index.clone()] += increment;
        }
    }
}

impl<D, I> crate::params::BufferMut for Tile<D, I>
where
    D: ndarray::Dimension,
    I: NdIndex<D> + PartialEq + Clone,
{
    /// Creates a tile of shape `dim` with no active entry.
    fn zeros<T: IntoDimension<Dim = D>>(dim: T) -> Self { Tile::new(dim, None) }

    /// Returns a copy of this tile with `f` applied to the active value.
    ///
    /// Only the active entry (if any) is transformed; an empty tile stays empty.
    fn map(&self, f: impl Fn(f64) -> f64) -> Self {
        self.clone().map_into(f)
    }

    /// Consumes the tile and applies `f` to the active value, keeping its index.
    fn map_into(self, f: impl Fn(f64) -> f64) -> Self {
        Tile {
            dim: self.dim,
            active: self.active.map(|(idx, a)| (idx, f(a))),
        }
    }

    /// Applies `f` to the active value in place; does nothing if there is none.
    fn map_inplace(&mut self, f: impl Fn(f64) -> f64) {
        if let Some((_, x)) = &mut self.active {
            *x = f(*x);
        }
    }

    /// Returns a new tile combining `self` and `other` element-wise via `f`.
    ///
    /// See `merge_inplace` for the compatibility rules and panics.
    fn merge(&self, other: &Self, f: impl Fn(f64, f64) -> f64) -> Self {
        self.clone().merge_into(other, f)
    }

    /// Consumes `self`, merges `other` into it via `f`, and returns the result.
    ///
    /// See `merge_inplace` for the compatibility rules and panics.
    fn merge_into(mut self, other: &Self, f: impl Fn(f64, f64) -> f64) -> Self {
        self.merge_inplace(other, f);
        self
    }

    /// Combines `other` into `self` via `f(self_value, other_value)`.
    ///
    /// As with `map_inplace`, `f` is only applied to active entries, so two
    /// empty tiles of the same shape merge to an empty tile.
    ///
    /// # Panics
    ///
    /// * `"Incompatible buffers shapes."` if the two tiles have different shapes.
    /// * `"Incompatible buffer indices."` if exactly one tile is active, or
    ///   both are active at different indices. A single-entry tile cannot
    ///   represent the merged result in those cases.
    fn merge_inplace(&mut self, other: &Self, f: impl Fn(f64, f64) -> f64) {
        if self.dim != other.dim {
            panic!("Incompatible buffers shapes.")
        }

        match (&mut self.active, &other.active) {
            (Some((i, x)), Some((j, y))) if i == j => *x = f(*x, *y),
            // Two empty tiles of identical shape are structurally compatible;
            // previously this fell through to the panic below.
            (None, None) => {}
            _ => panic!("Incompatible buffer indices."),
        }
    }
}