yarnn 0.1.0

Yet Another Rust Neural Network framework
Documentation
use crate::backend::Backend;
use crate::optimizer::{Optimizer, OptimizerContext};
use crate::tensor::{TensorShape, Tensor};

/// A trainable parameter tensor bundled with its gradient tensor and the
/// optimizer-specific context used when updating it.
///
/// The bounds are required on the struct itself because the field types are
/// associated types of `Backend` and `Optimizer`.
pub struct Params<N, B, O>
where
    B: Backend<N>,
    O: Optimizer<N, B>,
{
    /// The parameter values.
    pub params: B::Tensor,
    /// Gradient buffer paired with `params` (same shape when built via `Params::new`).
    pub grads: B::Tensor,
    /// Optimizer context associated with these parameters (type chosen by `O`).
    pub ctx: O::Context,
}

impl<N, B: Backend<N>, O: Optimizer<N, B>> Params<N, B, O> {
    /// Creates a parameter set for the given `shape`.
    ///
    /// `params` and `grads` are each built with `B::Tensor::new(shape)`, and `ctx`
    /// with `O::Context::new(shape)`. Their initial contents are whatever those
    /// constructors produce. Call [`Params::init_random`] or [`Params::init_zero`]
    /// to set `params` explicitly.
    pub fn new<S: Into<TensorShape>>(shape: S) -> Self {
        let shape = shape.into();

        Self {
            params: B::Tensor::new(shape.clone()),
            grads: B::Tensor::new(shape.clone()),
            // Last use of `shape`, so it is moved instead of cloned.
            ctx: O::Context::new(shape),
        }
    }

    /// Fills `params` with random values between `0.0` and `sqrt(1 / count)`.
    ///
    /// The exact distribution and whether the bounds are inclusive are decided by
    /// the backend's `fill_random`. `count` is presumably the layer's fan-in
    /// (number of inputs). TODO confirm against callers.
    ///
    /// NOTE(review): the lower bound is `0.0`, so the values are never negative.
    /// Symmetric schemes usually sample from `[-b, b]`. Confirm this is intended.
    ///
    /// # Panics
    ///
    /// Panics if `count` is zero. In that case the upper bound would be
    /// `sqrt(1.0 / 0.0) = inf`, i.e. an unbounded sampling range.
    pub fn init_random(&mut self, backend: &B, count: u32) {
        assert!(count > 0, "Params::init_random: `count` must be non-zero");

        let to = backend.scalar_f32((1.0 / (count as f32)).sqrt());

        backend.fill_random(&mut self.params, backend.scalar_f32(0.0), to);
    }

    /// Fills `params` with the scalar `0.0` via the backend's `fill_scalar`.
    ///
    /// Only `params` is written. `grads` and `ctx` are left untouched.
    pub fn init_zero(&mut self, backend: &B) {
        backend.fill_scalar(&mut self.params, backend.scalar_f32(0.0));
    }
}