neuronika 0.2.0

Tensors and dynamic neural networks.
use super::{Data, Forward, Gradient, Input, InputBackward, Overwrite, Tensor};
use std::cell::{Cell, RefCell};

mod forward {

    use super::{Cell, Data, Forward, Input, RefCell, Tensor};

    #[test]
    fn creation() {
        let input = Input {
            data: RefCell::new(Tensor::zeros((3, 3))),
            computed: Cell::new(false),
        };
        assert!(!input.was_computed());
        assert_eq!(*input.data(), Tensor::from_elem((3, 3), 0.));
        assert_eq!(*input.data_mut(), Tensor::from_elem((3, 3), 0.));
    }

    #[test]
    fn computation_was_computed_transition() {
        let input = Input {
            data: RefCell::new(Tensor::zeros((3, 3))),
            computed: Cell::new(false),
        };

        input.forward();
        assert!(input.was_computed());

        input.forward();
        assert!(input.was_computed());

        input.reset_computation();
        assert!(!input.was_computed());

        input.reset_computation();
        assert!(!input.was_computed());
    }

    #[test]
    fn debug() {
        let node = Input {
            data: RefCell::new(Tensor::zeros(1)),
            computed: Cell::new(false),
        };
        let output =
            "Input { data: [0.0], shape=[1], strides=[1], layout=CFcf (0xf), const ndim=1, computed: false }";

        assert_eq!(output, format!("{:?}", node));
    }

    #[test]
    fn display() {
        let node = Input {
            data: RefCell::new(Tensor::zeros(1)),
            computed: Cell::new(false),
        };

        assert_eq!(format!("{}", node.data()), format!("{}", node));
    }
}

mod backward {
    use super::{Cell, Gradient, InputBackward, Overwrite, RefCell, Tensor};

    #[test]
    fn creation() {
        let input = InputBackward {
            gradient: RefCell::new(Some(Tensor::zeros((3, 3)))),
            overwrite: Cell::new(true),
        };
        assert!(input.can_overwrite());
        assert_eq!(*input.gradient(), Tensor::from_elem((3, 3), 0.));
        assert_eq!(*input.gradient_mut(), Tensor::from_elem((3, 3), 0.));
    }

    #[test]
    fn computation_state_transition() {
        let input = InputBackward {
            gradient: RefCell::new(Some(Tensor::zeros((3, 3)))),
            overwrite: Cell::new(true),
        };

        assert!(input.can_overwrite());
        input.set_overwrite(false);
        assert!(!input.can_overwrite());
        input.set_overwrite(true);
        assert!(input.can_overwrite());
    }

    #[test]
    fn zero_grad() {
        let input = InputBackward {
            gradient: RefCell::new(Some(Tensor::ones((3, 3)))),
            overwrite: Cell::new(true),
        };

        input.zero_grad();
        assert_eq!(*input.gradient(), Tensor::zeros((3, 3)));
    }

    #[test]
    fn debug() {
        let node = InputBackward {
            gradient: RefCell::new(Some(Tensor::zeros(1))),
            overwrite: Cell::new(false),
        };
        let output =
            "InputBackward { gradient: Some([0.0], shape=[1], strides=[1], layout=CFcf (0xf), const ndim=1), overwrite: false }";

        assert_eq!(output, format!("{:?}", node));
    }

    #[test]
    fn display() {
        let node = InputBackward {
            gradient: RefCell::new(Some(Tensor::zeros(1))),
            overwrite: Cell::new(false),
        };

        assert_eq!(format!("{}", node.gradient()), format!("{}", node));
    }
}