brainy 0.2.3

A library for neural networks.
Documentation
//! Use to create matrices.

use std::ops::{Add, Mul, Index, IndexMut};
use rand;

/// Performs random fill.
///
/// Every slot of `vec` is overwritten with `2.0 * r - 1.0`, where `r` is a fresh draw of
/// `rand::random::<f64>()`. rand samples `f64` from the half-open interval `[0, 1)`, so the
/// written values lie in `[-1, 1)`. Slots are filled front to back, one draw per slot.
pub fn fill(vec: &mut Vec<f64>)
{
    vec.iter_mut()
        .for_each(|slot| *slot = 2.0 * rand::random::<f64>() - 1.0);
}

/// Add vectors element wise.
///
/// Returns a new vector whose entry `i` is `lhs[i] + rhs[i]`.
///
/// # Panics
/// Panics if `lhs` and `rhs` have different lengths.
pub fn add(lhs: Vec<f64>, rhs: Vec<f64>) -> Vec<f64>
{
    assert!(lhs.len() == rhs.len(), "Can't add vectors of different lengths.");
    lhs.iter()
        .zip(rhs.iter())
        .map(|(left, right)| left + right)
        .collect()
}

/// Multiply vectors element wise.
///
/// Returns a new vector whose entry `i` is `lhs[i] * rhs[i]`.
///
/// # Panics
/// Panics if `lhs` and `rhs` have different lengths.
pub fn mul(lhs: Vec<f64>, rhs: Vec<f64>) -> Vec<f64>
{
    assert!(lhs.len() == rhs.len(), "Can't multiply vectors of different lengths.");
    lhs.into_iter()
        .zip(rhs)
        .map(|(x, y)| x * y)
        .collect()
}

/// Dense matrix whose entries are stored in a flat `Vec<f64>`.
///
/// Storage is row-major: entry `(r, c)` lives at `elements[r * cols + c]`.
/// `Matrix::new` allocates exactly `rows * cols` elements; code that builds a `Matrix`
/// through its public fields should keep `elements.len() == rows * cols`.
#[derive(Clone, Debug, PartialEq)]
pub struct Matrix
{
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
    /// Entries in row-major order.
    pub elements: Vec<f64>
}

/// Square diagonal matrix that stores only its diagonal.
///
/// Use this instead of a `Matrix` padded with zeros off the diagonal: entry `i` of
/// `elements` is the `(i, i)` entry, and every other entry is implicitly zero.
#[derive(Clone, Debug, PartialEq)]
pub struct DiagonalMatrix
{
    /// Diagonal entries, top-left to bottom-right.
    pub elements: Vec<f64>
}

impl Matrix
{
    pub fn new(rows: usize, cols: usize) -> Self 
    {
        let elements: Vec<f64> = vec![0_f64; rows * cols];
        Self { rows, cols, elements }
    }  

    /// Wraps column vector into a matrix.
    pub fn wrap(c: Matrix, cols: usize) -> Matrix
    {
        assert!(c.cols == 1, "Input must be a column vector.");
        let rows = c.elements.len() / cols;

        let mut mat = Matrix::new(rows, cols);
        for i in 0..rows
        {
            for j in 0..cols
            {
                mat[(i, j)] = c[(i * cols + j, 0)];
            }
        }

        mat
    }

    // print matrix
    pub fn print(&self)
    {
        for r in 0..self.rows
        {
            for c in 0..self.cols
            {
                print!("{} ", self.elements[r * self.cols + c]);
            }
            print!("\n");
        }
        print!("\n");
    }
}

impl DiagonalMatrix
{
    /// Creates an `n` x `n` diagonal matrix whose diagonal entries are all zero.
    pub fn new(n: usize) -> Self
    {
        Self { elements: vec![0_f64; n] }
    }

    /// Prints the diagonal entries to stdout on one line, each followed by a space.
    pub fn print(&self)
    {
        for value in &self.elements
        {
            print!("{} ", value);
        }
        println!();
    }
}


impl Index<(usize, usize)> for Matrix
{
    type Output = f64;

    /// Returns a reference to entry `(row, col)` of the row-major storage.
    ///
    /// # Panics
    /// Panics if `row >= self.rows` or `col >= self.cols`.
    fn index(&self, index: (usize, usize)) -> &Self::Output
    {
        let (row, col) = index;
        let in_bounds = row < self.rows && col < self.cols;
        assert!(in_bounds, "Index out of bounds.");
        let offset = row * self.cols + col;
        &self.elements[offset]
    }
}

impl IndexMut<(usize, usize)> for Matrix
{
    /// Returns a mutable reference to entry `(row, col)` of the row-major storage.
    ///
    /// # Panics
    /// Panics if `row >= self.rows` or `col >= self.cols`.
    fn index_mut(&mut self, index: (usize, usize)) -> &mut f64
    {
        let (row, col) = index;
        let in_bounds = row < self.rows && col < self.cols;
        assert!(in_bounds, "Index out of bounds.");
        let offset = row * self.cols + col;
        &mut self.elements[offset]
    }
}

impl Add<Matrix> for Matrix
{
    type Output = Matrix;

    /// Adds matrices element wise.
    ///
    /// # Panics
    /// Panics if the two matrices differ in number of rows or columns.
    fn add(self, rhs: Matrix) -> Self::Output
    {
        assert!((self.rows == rhs.rows) && (self.cols == rhs.cols), "Can't add matrices of different shape.");
        let (rows, cols) = (self.rows, self.cols);
        // Both operands are owned, so their buffers are moved into the element-wise
        // helper rather than cloned.
        let elements = add(self.elements, rhs.elements);
        Matrix { rows, cols, elements }
    }
}

impl Add<DiagonalMatrix> for DiagonalMatrix
{
    type Output = DiagonalMatrix;

    /// Adds diagonal matrices element wise (entry by entry along the diagonal).
    ///
    /// # Panics
    /// Panics if the two diagonals have different lengths, mirroring the shape check
    /// performed when adding two `Matrix` values.
    fn add(self, rhs: DiagonalMatrix) -> Self::Output
    {
        assert!(self.elements.len() == rhs.elements.len(), "Can't add matrices of different shape.");
        // Both operands are owned, so their buffers are moved rather than cloned.
        DiagonalMatrix { elements: add(self.elements, rhs.elements) }
    }
}

impl Mul<Matrix> for Matrix
{
    type Output = Matrix;

    /// Performs standard multiplication of two matrices.
    ///
    /// Entry `(i, j)` of the result is the sum over `k` of `self[(i, k)] * rhs[(k, j)]`,
    /// accumulated in increasing `k` starting from zero.
    ///
    /// # Panics
    /// Panics if `self.cols != rhs.rows`.
    fn mul(self, rhs: Matrix) -> Self::Output
    {
        assert!(self.cols == rhs.rows, "Can't multiply matrices of conflicting shape.");
        let (n, m, p) = (self.rows, self.cols, rhs.cols);
        let mut result = Matrix::new(n, p);

        // i-k-j loop order: the innermost loop walks one contiguous row of `rhs` and one
        // contiguous row of `result`, instead of striding down a column of `rhs` through
        // the bounds-asserting `Index` impl on every access. Each output entry still starts
        // at 0.0 and accumulates its products in increasing `k`, so the results are
        // bit-identical to the naive i-j-k formulation.
        for i in 0..n
        {
            let lhs_row = &self.elements[i * m..(i + 1) * m];
            let out_row = &mut result.elements[i * p..(i + 1) * p];
            for (k, &a_ik) in lhs_row.iter().enumerate()
            {
                let rhs_row = &rhs.elements[k * p..(k + 1) * p];
                for (out, &b_kj) in out_row.iter_mut().zip(rhs_row)
                {
                    *out += a_ik * b_kj;
                }
            }
        }

        result
    }
}

impl Mul<DiagonalMatrix> for DiagonalMatrix
{
    type Output = DiagonalMatrix;

    /// Performs standard multiplication of two diagonal matrices, which reduces to
    /// multiplying their diagonals entry by entry.
    ///
    /// # Panics
    /// Panics if the two diagonals have different lengths.
    fn mul(self, rhs: DiagonalMatrix) -> Self::Output
    {
        DiagonalMatrix { elements: mul(self.elements, rhs.elements) }
    }
}

impl Mul<DiagonalMatrix> for Matrix
{
    type Output = Matrix;

    /// Multiplies a matrix (left) by a diagonal matrix (right).
    ///
    /// Column `j` of `self` is scaled by diagonal entry `j`.
    ///
    /// # Panics
    /// Panics if `self.cols` differs from the length of the diagonal.
    fn mul(self, rhs: DiagonalMatrix) -> Self::Output
    {
        assert!(self.cols == rhs.elements.len(), "Can't multiply matrices of conflicting shape.");
        let mut scaled = Matrix::new(self.rows, self.cols);

        for row in 0..self.rows
        {
            for (col, &factor) in rhs.elements.iter().enumerate()
            {
                scaled[(row, col)] = self[(row, col)] * factor;
            }
        }

        scaled
    }
}

impl Mul<Matrix> for DiagonalMatrix
{
    type Output = Matrix;

    /// Multiplies a diagonal matrix (left) by a matrix (right).
    ///
    /// Row `i` of `rhs` is scaled by diagonal entry `i`.
    ///
    /// # Panics
    /// Panics if the length of the diagonal differs from `rhs.rows`.
    fn mul(self, rhs: Matrix) -> Self::Output
    {
        assert!(self.elements.len() == rhs.rows, "Can't multiply matrices of conflicting shape.");
        let mut scaled = Matrix::new(rhs.rows, rhs.cols);

        for (row, &factor) in self.elements.iter().enumerate()
        {
            for col in 0..rhs.cols
            {
                scaled[(row, col)] = factor * rhs[(row, col)];
            }
        }

        scaled
    }
}