meuron 0.4.0

Meuron is a modular neural network library, written in Rust, for training simple neural networks.
Documentation
use crate::backend::Backend;
use ndarray::{Array, ArrayBase, Axis, Dimension, Ix2, OwnedRepr, RemoveAxis};
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::{Normal, Uniform};

/// CPU implementation of [`Backend`]: tensors are plain owned `ndarray` arrays of `f32`
/// and every operation runs on the calling thread's CPU.
#[derive(Debug, Clone, Default)]
pub struct CPUBackend;

impl Backend for CPUBackend {
    type Tensor<D: Dimension> = ArrayBase<OwnedRepr<f32>, D>;

    fn zeros<D: Dimension>(shape: D) -> Self::Tensor<D> {
        Array::zeros(shape)
    }

    fn random_uniform<D: Dimension>(shape: D, low: f32, high: f32) -> Self::Tensor<D> {
        Array::random(shape, Uniform::new(low, high).unwrap())
    }

    fn random_normal<D: Dimension>(shape: D, mean: f32, std: f32) -> Self::Tensor<D> {
        Array::random(shape, Normal::new(mean, std).unwrap())
    }

    fn from_array<D: Dimension>(array: ndarray::Array<f32, D>) -> Self::Tensor<D> {
        array
    }

    fn to_array<D: Dimension>(tensor: &Self::Tensor<D>) -> ndarray::Array<f32, D> {
        tensor.clone()
    }

    fn unary<D: Dimension>(tensor: &Self::Tensor<D>, op: u32) -> Self::Tensor<D> {
        tensor.mapv(|x| match op {
            0 => x.tanh(),
            1 => 1.0 / (1.0 + (-x).exp()),
            2 => x.max(0.0),
            3 => {
                let t = x.tanh();
                1.0 - t * t
            }
            4 => {
                let s = 1.0 / (1.0 + (-x).exp());
                s * (1.0 - s)
            }
            5 => {
                if x > 0.0 {
                    1.0
                } else {
                    0.0
                }
            }
            6 => x.exp(),
            7 => x.ln(),
            8 => x.abs(),
            9 => -x,
            10 => x.sqrt(),
            _ => panic!("unknown unary op {op}"),
        })
    }

    fn add<D: Dimension>(a: &Self::Tensor<D>, b: &Self::Tensor<D>) -> Self::Tensor<D> {
        a + b
    }
    fn sub<D: Dimension>(a: &Self::Tensor<D>, b: &Self::Tensor<D>) -> Self::Tensor<D> {
        a - b
    }
    fn mul<D: Dimension>(a: &Self::Tensor<D>, b: &Self::Tensor<D>) -> Self::Tensor<D> {
        a * b
    }
    fn div<D: Dimension>(a: &Self::Tensor<D>, b: &Self::Tensor<D>) -> Self::Tensor<D> {
        a / b
    }

    fn scale<D: Dimension>(tensor: &Self::Tensor<D>, scalar: f32) -> Self::Tensor<D> {
        tensor * scalar
    }

    fn scalar_sub<D: Dimension>(scalar: f32, tensor: &Self::Tensor<D>) -> Self::Tensor<D> {
        tensor.mapv(|v| scalar - v)
    }
    fn scalar_max<D: Dimension>(tensor: &Array<f32, D>, s: f32) -> Array<f32, D> {
        tensor.mapv(|x| x.max(s))
    }
    fn scalar_min<D: Dimension>(tensor: &Array<f32, D>, s: f32) -> Array<f32, D> {
        tensor.mapv(|x| x.min(s))
    }

    fn mean<D: Dimension>(tensor: &Self::Tensor<D>) -> Option<f32> {
        tensor.mean()
    }

    fn sum_axis<D: Dimension + RemoveAxis>(
        tensor: &Self::Tensor<D>,
        axis: usize,
    ) -> Self::Tensor<D::Smaller> {
        tensor.sum_axis(Axis(axis))
    }

    fn matmul<D1: Dimension, D2: Dimension>(
        a: &Self::Tensor<D1>,
        b: &Self::Tensor<D2>,
    ) -> Self::Tensor<D1> {
        let a_dyn = a.view().into_dyn();
        let b_dyn = b.view().into_dyn();
        let out = match (a_dyn.ndim(), b_dyn.ndim()) {
            (2, 2) => {
                let a2 = a_dyn.view().into_dimensionality::<Ix2>().unwrap();
                let b2 = b_dyn.view().into_dimensionality::<Ix2>().unwrap();
                a2.dot(&b2).into_dyn()
            }
            (3, 2) => {
                let (batch, m) = (a_dyn.shape()[0], a_dyn.shape()[1]);
                let n = b_dyn.shape()[1];
                let b2 = b_dyn.view().into_dimensionality::<Ix2>().unwrap();
                let mut result = Array::zeros((batch, m, n)).into_dyn();
                for i in 0..batch {
                    let ai = a_dyn
                        .index_axis(Axis(0), i)
                        .into_dimensionality::<Ix2>()
                        .unwrap();
                    result.index_axis_mut(Axis(0), i).assign(&ai.dot(&b2));
                }
                result
            }
            (3, 3) => {
                let (batch, m) = (a_dyn.shape()[0], a_dyn.shape()[1]);
                let n = b_dyn.shape()[2];
                let mut result = Array::zeros((batch, m, n)).into_dyn();
                for i in 0..batch {
                    let ai = a_dyn
                        .index_axis(Axis(0), i)
                        .into_dimensionality::<Ix2>()
                        .unwrap();
                    let bi = b_dyn
                        .index_axis(Axis(0), i)
                        .into_dimensionality::<Ix2>()
                        .unwrap();
                    result.index_axis_mut(Axis(0), i).assign(&ai.dot(&bi));
                }
                result
            }
            (4, 2) => {
                let (b1, b2, m, _k) = (
                    a_dyn.shape()[0],
                    a_dyn.shape()[1],
                    a_dyn.shape()[2],
                    a_dyn.shape()[3],
                );
                let n = b_dyn.shape()[1];
                let bw = b_dyn.view().into_dimensionality::<Ix2>().unwrap();
                let mut out = Array::zeros((b1, b2, m, n)).into_dyn();
                for i in 0..b1 {
                    let ai = a_dyn.index_axis(Axis(0), i);
                    let mut oi = out.index_axis_mut(Axis(0), i);
                    for j in 0..b2 {
                        let aij = ai
                            .index_axis(Axis(0), j)
                            .into_dimensionality::<Ix2>()
                            .unwrap();
                        oi.index_axis_mut(Axis(0), j).assign(&aij.dot(&bw));
                    }
                }
                out
            }
            (4, 4) => {
                let (b1, b2, m) = (a_dyn.shape()[0], a_dyn.shape()[1], a_dyn.shape()[2]);
                let n = b_dyn.shape()[3];
                let mut out = Array::zeros((b1, b2, m, n)).into_dyn();
                for i in 0..b1 {
                    let ai = a_dyn.index_axis(Axis(0), i);
                    let bi = b_dyn.index_axis(Axis(0), i);
                    let mut oi = out.index_axis_mut(Axis(0), i);
                    for j in 0..b2 {
                        let aij = ai
                            .index_axis(Axis(0), j)
                            .into_dimensionality::<Ix2>()
                            .unwrap();
                        let bij = bi
                            .index_axis(Axis(0), j)
                            .into_dimensionality::<Ix2>()
                            .unwrap();
                        oi.index_axis_mut(Axis(0), j).assign(&aij.dot(&bij));
                    }
                }
                out
            }
            _ => panic!(
                "matmul: unsupported shapes {:?} x {:?}",
                a_dyn.shape(),
                b_dyn.shape()
            ),
        };
        out.into_dimensionality::<D1>()
            .expect("matmul output rank must match left operand")
            .to_owned()
    }

    fn transpose<D: Dimension>(
        tensor: &Self::Tensor<D>,
        axis1: usize,
        axis2: usize,
    ) -> Self::Tensor<D> {
        let mut t = tensor.clone().into_dyn();
        t.swap_axes(axis1, axis2);
        t.as_standard_layout()
            .into_dimensionality::<D>()
            .expect("transpose must preserve rank")
            .to_owned()
    }

    fn broadcast_add<D1: Dimension, D2: Dimension>(
        a: &Self::Tensor<D1>,
        b: &Self::Tensor<D2>,
    ) -> Self::Tensor<D1> {
        let a_dyn = a.view().into_dyn().to_owned();
        let b_dyn = b.view().into_dyn().to_owned();
        (a_dyn + b_dyn)
            .into_dimensionality::<D1>()
            .expect("broadcast_add output rank must match left operand")
    }

    fn softmax<D: Dimension>(tensor: &Self::Tensor<D>) -> Self::Tensor<D> {
        let shape = tensor.shape().to_vec();
        let last_dim = shape[shape.len() - 1];
        let batch = shape[..shape.len() - 1].iter().product::<usize>().max(1);
        let x_c = tensor.as_standard_layout();
        let raw = x_c.as_slice().unwrap();
        let mut out = vec![0.0f32; raw.len()];
        for b in 0..batch {
            let s = b * last_dim;
            let row = &raw[s..s + last_dim];
            let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let sum: f32 = row.iter().map(|&v| (v - max).exp()).sum();
            for (i, &v) in row.iter().enumerate() {
                out[s + i] = (v - max).exp() / sum;
            }
        }
        Array::from_shape_vec(tensor.raw_dim(), out).unwrap()
    }

    fn softmax_vjp<D: Dimension>(
        z: &Self::Tensor<D>,
        grad_output: &Self::Tensor<D>,
    ) -> Self::Tensor<D> {
        let shape = z.shape().to_vec();
        let last_dim = shape[shape.len() - 1];
        let batch = shape[..shape.len() - 1].iter().product::<usize>().max(1);
        let s = Self::softmax(z);
        let s_c = s.as_standard_layout();
        let s_raw = s_c.as_slice().unwrap();
        let g_c = grad_output.as_standard_layout();
        let g_raw = g_c.as_slice().unwrap();
        let mut out = vec![0.0f32; s_raw.len()];
        for b in 0..batch {
            let start = b * last_dim;
            let sr = &s_raw[start..start + last_dim];
            let gr = &g_raw[start..start + last_dim];
            let dot: f32 = sr.iter().zip(gr).map(|(&si, &gi)| si * gi).sum();
            for i in 0..last_dim {
                out[start + i] = sr[i] * (gr[i] - dot);
            }
        }
        Array::from_shape_vec(z.raw_dim(), out).unwrap()
    }

    fn assign<D: Dimension>(dst: &mut Self::Tensor<D>, src: Self::Tensor<D>) {
        *dst = src;
    }

    fn shape<D: Dimension>(tensor: &Self::Tensor<D>) -> Vec<usize> {
        tensor.shape().to_vec()
    }

    fn len_of<D: Dimension>(tensor: &Self::Tensor<D>, axis: usize) -> usize {
        tensor.len_of(Axis(axis))
    }

    fn select<D: Dimension + RemoveAxis>(
        tensor: &Self::Tensor<D>,
        axis: usize,
        indices: &[usize],
    ) -> Self::Tensor<D> {
        tensor.select(Axis(axis), indices)
    }
}