#![cfg(feature = "rblas")]
#![allow(deprecated)]
extern crate rblas;
extern crate num;
#[macro_use] extern crate ndarray;
use rblas::Gemm;
use rblas::attribute::Transpose;
use num::Float;
use ndarray::{
OwnedArray,
ArrayView,
ArrayViewMut,
arr2,
Ix,
ShapeError
};
use ndarray::blas::AsBlas;
#[test]
fn strided_matrix() {
let (n, m) = (45, 33);
let mut a = OwnedArray::linspace(0., ((n * m) - 1) as f32, n as usize * m as usize ).into_shape((n, m)).unwrap();
let mut b = OwnedArray::eye(m);
let mut res = OwnedArray::zeros(a.dim());
Gemm::gemm(&1., Transpose::NoTrans, &a.blas(), Transpose::NoTrans, &b.blas(),
&0., &mut res.blas());
assert_eq!(res, a);
let mut aprim = a.to_shared();
aprim.islice(s![0..12, 0..11]);
println!("{:?}", aprim.shape());
println!("{:?}", aprim.strides());
let mut b = OwnedArray::eye(aprim.shape()[1]);
let mut res = OwnedArray::zeros(aprim.dim());
Gemm::gemm(&1., Transpose::NoTrans, &aprim.blas(), Transpose::NoTrans, &b.blas(),
&0., &mut res.blas());
assert_eq!(res, aprim);
let (np, mp) = aprim.dim();
let mut res = OwnedArray::zeros((mp, np));
let mut b = OwnedArray::eye(np);
Gemm::gemm(&1., Transpose::Trans, &aprim.blas(), Transpose::NoTrans, &b.blas(),
&0., &mut res.blas());
let mut at = aprim.clone();
at.swap_axes(0, 1);
assert_eq!(at, res);
let mut abis = a.to_shared();
abis.islice(s![0..12, ..;2]);
println!("{:?}", abis.shape());
println!("{:?}", abis.strides());
let mut b = OwnedArray::eye(abis.shape()[1]);
let mut res = OwnedArray::zeros(abis.dim());
Gemm::gemm(&1., Transpose::NoTrans, &abis.blas(), Transpose::NoTrans, &b.blas(),
&0., &mut res.blas());
assert_eq!(res, abis);
}
#[test]
fn strided_view() {
let (n, m) = (45, 33);
let mut a = OwnedArray::linspace(0., ((n * m) - 1) as f32, n as usize * m as usize ).into_shape((n, m)).unwrap();
let mut b = OwnedArray::eye(m);
let mut res = OwnedArray::zeros(a.dim());
Gemm::gemm(&1.,
Transpose::NoTrans, &a.blas_view_mut_checked().unwrap(),
Transpose::NoTrans, &b.blas_view_mut_checked().unwrap(),
&0., &mut res.blas_view_mut_checked().unwrap());
assert_eq!(res, a);
let aprim = a.slice(s![0..12, 0..11]);
let mut b = OwnedArray::eye(aprim.shape()[1]);
let mut res = OwnedArray::zeros(aprim.dim());
Gemm::gemm(&1.,
Transpose::NoTrans, &aprim.bv(),
Transpose::NoTrans, &b.blas(),
&0., &mut res.blas());
assert_eq!(res, aprim);
let mut a3 = arr2(&[[1., 2., 3.]]);
a3.swap_axes(0, 1);
let mut b = OwnedArray::eye(a3.shape()[1]);
let mut res = arr2(&[[0., 0., 0.]]);
res.swap_axes(0, 1);
Gemm::gemm(&1.,
Transpose::NoTrans, &a3.bvm(),
Transpose::NoTrans, &b.blas(),
&0., &mut res.bvm());
assert_eq!(res, a3);
}
#[test]
fn as_blas() {
let mut a = OwnedArray::<f32, _>::zeros((4, 4));
assert!(a.blas_view_mut_checked().is_ok());
a.swap_axes(0, 1);
assert!(a.blas_view_mut_checked().is_err());
a.swap_axes(0, 1);
{
let mut b = a.slice_mut(s![..;2, ..]);
assert!(b.blas_view_mut_checked().is_ok());
b.bvm(); }
{
let mut b = a.slice_mut(s![.., ..;2]);
assert!(b.blas_view_mut_checked().is_err());
}
{
let mut b = a.slice_mut(s![.., ..;4]);
assert!(b.blas_view_mut_checked().is_ok());
b.bvm();
}
}
type Ix2 = (Ix, Ix);
fn dot<F>(a: ArrayView<F, Ix2>, b: ArrayView<F, Ix2>,
c: &mut ArrayViewMut<F, Ix2>) -> Result<(), ShapeError>
where F: Gemm + Float,
{
let at = Transpose::NoTrans;
let bt = Transpose::NoTrans;
let ba = try!(a.blas_view_checked());
let bb = try!(b.blas_view_checked());
let mut bc = try!(c.blas_view_mut_checked());
F::gemm(&F::one(),
at, &ba,
bt, &bb,
&F::zero(), &mut bc);
Ok(())
}
#[test]
fn test_dot() {
let mut a = arr2(&[[1., 2.],
[0., 3.]]);
let b = a.clone();
let mut res = a.clone();
res.assign_scalar(&0.);
dot(a.view(), b.view(), &mut res.view_mut()).unwrap();
println!("{:?}", res);
a.swap_axes(0, 1);
res.assign_scalar(&0.);
let result = dot(a.view(), b.view(), &mut res.view_mut());
assert!(result.is_err());
}