// rblas 0.0.13
//
// BLAS bindings and wrappers
// Documentation
// Copyright 2015 Michael Yang. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

use std::ops::{
    Add,
    Mul,
};
use num::complex::{Complex32, Complex64};
use attribute::Transpose;
use default::Default;
use math::Trans;
use matrix::ops::*;
use matrix::Matrix;
use math::Mat;
use vector::ops::*;

/// Sum of two matrices of identical shape (`&a + &b`).
///
/// The left operand is copied into a fresh `Mat` with `Mat::from`, and the
/// right operand is then accumulated into that copy through
/// `Axpy::axpy_mat`, using `Default::one()` as the scale factor.
///
/// # Panics
///
/// Panics with "Dimension mismatch" when the operands differ in their
/// column count or their row count.
impl<'a, T> Add for &'a Matrix<T>
    where T: Axpy + Copy + Default
{
    type Output = Mat<T>;

    fn add(self, b: &Matrix<T>) -> Mat<T> {
        let same_shape = self.cols() == b.cols() && self.rows() == b.rows();
        if !same_shape {
            panic!("Dimension mismatch")
        }

        let unit: T = Default::one();
        // Start from a copy of the left operand, then fold `b` into it.
        let mut sum = Mat::from(self);
        Axpy::axpy_mat(&unit, b, &mut sum);
        sum
    }
}

/// Scaling of a matrix by a scalar on the right (`&matrix * alpha`).
///
/// The matrix is copied into a new `Mat` with `Mat::from`, and that copy is
/// then scaled in place through `Scal::scal_mat` before being returned.
impl<'a, T> Mul<T> for &'a Matrix<T>
    where T: Sized + Copy + Scal
{
    type Output = Mat<T>;

    fn mul(self, alpha: T) -> Mat<T> {
        let mut scaled = Mat::from(self);
        <T as Scal>::scal_mat(&alpha, &mut scaled);
        scaled
    }
}

// Generates `scalar * &matrix` for every scalar type listed in the
// invocation. Each generated impl copies the matrix into a new `Mat` and
// scales that copy in place with `Scal::scal_mat`.
//
// NOTE(review): presumably a macro is used because a blanket
// `impl<T> Mul<&Matrix<T>> for T` is not accepted by the compiler — confirm.
macro_rules! left_scale {
    ($($scalar:ident),+) => {
        $(
            impl<'a> Mul<&'a Matrix<$scalar>> for $scalar {
                type Output = Mat<$scalar>;

                fn mul(self, x: &Matrix<$scalar>) -> Mat<$scalar> {
                    let mut scaled = Mat::from(x);
                    Scal::scal_mat(&self, &mut scaled);
                    scaled
                }
            }
        )+
    };
}

left_scale!(f32, f64, Complex32, Complex64);

/// Matrix product `&a * &b` with neither operand transposed.
///
/// Allocates a `self.rows()` by `b.cols()` result with `Mat::new` and hands
/// the work to `Gemm::gemm`, passing `Default::one()` as the scale on the
/// product and `Default::zero()` as the scale on the output matrix.
///
/// # Panics
///
/// Panics with "Dimension mismatch" when `self.cols()` differs from
/// `b.rows()`.
impl<'a, T> Mul<&'a Matrix<T>> for &'a Matrix<T>
    where T: Default + Gemm,
{
    type Output = Mat<T>;

    fn mul(self, b: &Matrix<T>) -> Mat<T> {
        let inner_matches = self.cols() == b.rows();
        if !inner_matches {
            panic!("Dimension mismatch");
        }

        // Output shape: rows of the left operand by columns of the right.
        let out_rows = self.rows() as usize;
        let out_cols = b.cols() as usize;
        let mut product = Mat::new(out_rows, out_cols);

        let (alpha, beta): (T, T) = (Default::one(), Default::zero());
        Gemm::gemm(
            &alpha,
            Transpose::NoTrans,
            self,
            Transpose::NoTrans,
            b,
            &beta,
            &mut product,
        );
        product
    }
}

/// Matrix product with a transposed left operand (`(a ^ T) * b`).
///
/// The `Trans` wrapper is unpacked into the wrapped matrix plus the matching
/// `Transpose` flag: `Trans::T` maps to `Transpose::Trans` and `Trans::H` to
/// `Transpose::ConjTrans`. The right operand is passed with
/// `Transpose::NoTrans`. The result has `a.cols()` rows and `b.cols()`
/// columns.
///
/// # Panics
///
/// Panics with "Dimension mismatch" when the wrapped matrix and `b` have
/// different row counts.
impl<'a, T> Mul<&'a Matrix<T>> for Trans<&'a Matrix<T>>
    where T: Default + Gemm,
{
    type Output = Mat<T>;

    fn mul(self, b: &Matrix<T>) -> Mat<T> {
        let (lhs, lhs_op) = match self {
            Trans::H(inner) => (inner, Transpose::ConjTrans),
            Trans::T(inner) => (inner, Transpose::Trans),
        };

        // With the left side transposed, its rows pair up with the rows of `b`.
        let inner_matches = lhs.rows() == b.rows();
        if !inner_matches {
            panic!("Dimension mismatch");
        }

        let out_rows = lhs.cols() as usize;
        let out_cols = b.cols() as usize;
        let mut product = Mat::new(out_rows, out_cols);

        let (alpha, beta): (T, T) = (Default::one(), Default::zero());
        Gemm::gemm(
            &alpha,
            lhs_op,
            lhs,
            Transpose::NoTrans,
            b,
            &beta,
            &mut product,
        );
        product
    }
}

/// Matrix product with a transposed right operand (`a * (b ^ T)`).
///
/// The `Trans` wrapper on the right is unpacked into the wrapped matrix plus
/// the matching `Transpose` flag: `Trans::T` maps to `Transpose::Trans` and
/// `Trans::H` to `Transpose::ConjTrans`. The left operand is passed with
/// `Transpose::NoTrans`. The result has `self.rows()` rows and as many
/// columns as the wrapped matrix has rows.
///
/// # Panics
///
/// Panics with "Dimension mismatch" when `self` and the wrapped matrix have
/// different column counts.
impl<'a, T> Mul<Trans<&'a Matrix<T>>> for &'a Matrix<T>
    where T: Default + Gemm,
{
    type Output = Mat<T>;

    fn mul(self, rhs: Trans<&Matrix<T>>) -> Mat<T> {
        let (right, right_op) = match rhs {
            Trans::H(inner) => (inner, Transpose::ConjTrans),
            Trans::T(inner) => (inner, Transpose::Trans),
        };

        // With the right side transposed, its columns pair up with ours.
        let inner_matches = self.cols() == right.cols();
        if !inner_matches {
            panic!("Dimension mismatch");
        }

        let out_rows = self.rows() as usize;
        let out_cols = right.rows() as usize;
        let mut product = Mat::new(out_rows, out_cols);

        let (alpha, beta): (T, T) = (Default::one(), Default::zero());
        Gemm::gemm(
            &alpha,
            Transpose::NoTrans,
            self,
            right_op,
            right,
            &beta,
            &mut product,
        );
        product
    }
}

/// Matrix product with both operands transposed (`(a ^ T) * (b ^ T)`).
///
/// Each `Trans` wrapper is unpacked into the wrapped matrix plus the matching
/// `Transpose` flag: `Trans::T` maps to `Transpose::Trans` and `Trans::H` to
/// `Transpose::ConjTrans`. The product is delegated to `Gemm::gemm` with
/// `Default::one()` as the scale on the product and `Default::zero()` as the
/// scale on the output. The result has `a.cols()` rows and `b.rows()`
/// columns, where `a` and `b` are the wrapped matrices.
///
/// # Panics
///
/// Panics with "Dimension mismatch" when the row count of the left wrapped
/// matrix differs from the column count of the right wrapped matrix.
impl<'a, T> Mul<Trans<&'a Matrix<T>>> for Trans<&'a Matrix<T>>
    where T: Default + Gemm,
{
    type Output = Mat<T>;

    fn mul(self, rhs: Trans<&Matrix<T>>) -> Mat<T> {
        let (a, at) = match self {
            Trans::T(a) => (a, Transpose::Trans),
            Trans::H(a) => (a, Transpose::ConjTrans),
        };

        let (b, bt) = match rhs {
            Trans::T(b) => (b, Transpose::Trans),
            Trans::H(b) => (b, Transpose::ConjTrans),
        };

        // Dimensions are read from the unwrapped matrix `a`, not from the
        // `Trans` wrapper, which `match self` above has taken by value.
        if a.rows() != b.cols() {
            panic!("Dimension mismatch");
        }

        let n = a.cols() as usize;
        let m = b.rows() as usize;
        let mut result = Mat::new(n, m);

        Gemm::gemm(&Default::one(), at, a, bt, b, &Default::zero(), &mut result);
        result
    }
}

#[cfg(test)]
mod tests {
    use Matrix;
    use math::Mat;
    use math::Marker::T;

    // `+` on two matrix references yields the element-wise sum.
    #[test]
    fn add() {
        let lhs = mat![1.0, 2.0; 3.0, 4.0];
        let rhs = mat![-1.0, 3.0; 1.0, 1.0];

        let lhs_ref = &lhs as &Matrix<_>;
        let rhs_ref = &rhs as &Matrix<_>;
        let sum = lhs_ref + rhs_ref;

        assert_eq!(sum, mat![0.0, 5.0; 4.0, 5.0]);
    }

    // Scalar multiplication gives the same result from either side.
    #[test]
    fn scale() {
        let source = mat![1f32, 2f32; 3f32, 4f32];
        let source_ref = &source as &Matrix<_>;

        let scaled_right = source_ref * 3.0;
        let scaled_left = 3.0 * source_ref;

        assert_eq!(scaled_right, mat![3f32, 6f32; 9f32, 12f32]);
        assert_eq!(scaled_left, scaled_right);
    }

    // Plain matrix product, neither operand transposed.
    #[test]
    fn mul() {
        let lhs = mat![1.0, 2.0; 3.0, 4.0];
        let rhs = mat![-1.0, 3.0; 1.0, 1.0];

        let lhs_ref = &lhs as &Matrix<_>;
        let rhs_ref = &rhs as &Matrix<_>;
        let product = lhs_ref * rhs_ref;

        assert_eq!(product, mat![1.0, 5.0; 1.0, 13.0]);
    }

    // The left operand is stored transposed and flipped back with `^ T`.
    #[test]
    fn left_mul_trans() {
        let lhs = mat![1.0, 3.0; 2.0, 4.0];
        let rhs = mat![-1.0, 3.0; 1.0, 1.0];

        let lhs_ref = &lhs as &Matrix<_>;
        let rhs_ref = &rhs as &Matrix<_>;
        let product = (lhs_ref ^ T) * rhs_ref;

        assert_eq!(product, mat![1.0, 5.0; 1.0, 13.0]);
    }

    // The right operand is stored transposed and flipped back with `^ T`.
    #[test]
    fn right_mul_trans() {
        let lhs = mat![1.0, 2.0; 3.0, 4.0];
        let rhs = mat![-1.0, 1.0; 3.0, 1.0];

        let lhs_ref = &lhs as &Matrix<_>;
        let rhs_ref = &rhs as &Matrix<_>;
        let product = lhs_ref * (rhs_ref ^ T);

        assert_eq!(product, mat![1.0, 5.0; 1.0, 13.0]);
    }

    // Both operands are stored transposed and flipped back with `^ T`.
    #[test]
    fn mul_trans() {
        let lhs = mat![1.0, 3.0; 2.0, 4.0];
        let rhs = mat![-1.0, 1.0; 3.0, 1.0];

        let lhs_ref = &lhs as &Matrix<_>;
        let rhs_ref = &rhs as &Matrix<_>;
        let product = (lhs_ref ^ T) * (rhs_ref ^ T);

        assert_eq!(product, mat![1.0, 5.0; 1.0, 13.0]);
    }
}