extern crate rblas;
extern crate ndarray;
use std::os::raw::{c_int};
use rblas::{
Matrix,
Vector,
};
use ndarray::{
ShapeError,
ErrorKind,
ArrayView,
ArrayViewMut,
Data,
DataOwned,
DataMut,
Dimension,
ArrayBase,
Ix, Ixs,
};
/// A read-only `ndarray` view that has passed this module's BLAS checks and
/// can be handed to `rblas` (constructed by `Priv::into_blas_view`).
pub struct BlasArrayView<'a, A, D>(ArrayView<'a, A, D>) where A: 'a;

// Implemented by hand (not derived) so that only the dimension type needs to
// be `Copy`/`Clone`; a derive would also demand it of the element type `A`.
impl<'a, A, D> Copy for BlasArrayView<'a, A, D>
    where D: Copy
{ }

impl<'a, A, D> Clone for BlasArrayView<'a, A, D>
    where D: Clone
{
    fn clone(&self) -> Self {
        let BlasArrayView(ref inner) = *self;
        BlasArrayView(inner.clone())
    }
}
/// A mutable `ndarray` view that has passed this module's BLAS checks and
/// can be handed to `rblas` (constructed by `Priv::into_blas_view_mut`).
pub struct BlasArrayViewMut<'a, A, D>(ArrayViewMut<'a, A, D>) where A: 'a;

/// Private newtype used to hang internal helper methods (size/layout checks
/// and conversion into BLAS views) on `ndarray` types without making them
/// part of the public API.
struct Priv<T>(T);
/// Returns `true` when the last axis of `a` is contiguous: its stride is 1,
/// or it holds at most one element (so its stride never matters).
///
/// A zero-dimensional array has no last axis and is reported as contiguous.
fn is_inner_contiguous<S, D>(a: &ArrayBase<S, D>) -> bool
    where S: Data,
          D: Dimension,
{
    match (a.shape().last(), a.strides().last()) {
        (Some(&len), Some(&stride)) => len <= 1 || stride == 1,
        _ => true,
    }
}
/// Replaces the owned array `a` with a copy of itself in standard layout
/// (as defined by `is_standard_layout`), unless it is already in that layout.
///
/// The elements are cloned in logical iteration order, so the values seen
/// through indexing are unchanged; only the memory layout differs.
// NOTE(review): the lint allowance is presumably for `ArrayBase::from_vec_dim`
// being deprecated in the targeted ndarray version; confirm and consider its
// replacement constructor.
#[allow(deprecated)]
fn ensure_standard_layout<A, S, D>(a: &mut ArrayBase<S, D>)
    where S: DataOwned<Elem=A>,
          D: Dimension,
          A: Clone
{
    if a.is_standard_layout() {
        return;
    }
    let shape = a.dim();
    let elements = a.iter().cloned().collect::<Vec<A>>();
    // `elements` holds exactly one value per position of `shape`, so
    // rebuilding from it cannot fail.
    *a = ArrayBase::from_vec_dim(shape, elements).unwrap();
}
impl<S, D> Priv<ArrayBase<S, D>>
    where S: Data,
          D: Dimension
{
    /// Checks that every axis length and every stride fits in a BLAS `c_int`.
    ///
    /// Strides can be negative (e.g. reversed axes), so they are checked
    /// against both `c_int::min_value()` and `c_int::max_value()`. Without the
    /// lower bound, a large negative stride would pass here and then be
    /// silently truncated by the `as c_int` casts used when talking to BLAS.
    ///
    /// # Errors
    /// `ErrorKind::RangeLimited` if any length or stride is out of range.
    fn size_check(&self) -> Result<(), ShapeError> {
        let max = c_int::max_value();
        let min = c_int::min_value();
        let arr = &self.0;
        let out_of_range = arr.shape().iter().zip(arr.strides()).any(|(&dim, &stride)| {
            dim > max as Ix || stride > max as Ixs || stride < min as Ixs
        });
        if out_of_range {
            Err(ShapeError::from_kind(ErrorKind::RangeLimited))
        } else {
            Ok(())
        }
    }

    /// Checks that the last axis is contiguous (unit stride or length <= 1).
    ///
    /// # Errors
    /// `ErrorKind::IncompatibleLayout` if it is not.
    fn contiguous_check(&self) -> Result<(), ShapeError> {
        if is_inner_contiguous(&self.0) {
            Ok(())
        } else {
            Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout))
        }
    }
}
impl<'a, A, D> Priv<ArrayView<'a, A, D>>
where D: Dimension
{
pub fn into_blas_view(self) -> Result<BlasArrayView<'a, A, D>, ShapeError> {
if self.0.ndim() > 1 {
try!(self.contiguous_check());
}
try!(self.size_check());
Ok(BlasArrayView(self.0))
}
}
impl<'a, A, D> Priv<ArrayViewMut<'a, A, D>>
where D: Dimension
{
fn into_blas_view_mut(self) -> Result<BlasArrayViewMut<'a, A, D>, ShapeError> {
if self.0.ndim() > 1 {
try!(self.contiguous_check());
}
try!(self.size_check());
Ok(BlasArrayViewMut(self.0))
}
}
/// Conversion of `ndarray` arrays into views that can be passed to `rblas`.
///
/// Every accessor comes in two flavours: a `*_checked` method that reports an
/// unsuitable array as a `ShapeError`, and a short-hand (`blas`, `bv`, `bvm`)
/// that unwraps that result and therefore panics instead.
pub trait AsBlas<A, S, D> {
    /// Returns a mutable BLAS view of `self`. Because the array owns its data
    /// (`DataOwned`), an implementation may first replace that data with a
    /// copy in a BLAS-friendly layout.
    ///
    /// # Errors
    /// A `ShapeError` if the array still cannot be presented to BLAS.
    fn blas_checked(&mut self) -> Result<BlasArrayViewMut<A, D>, ShapeError>
        where S: DataOwned + DataMut, A: Clone;

    /// Same as `blas_checked`.
    ///
    /// # Panics
    /// If `blas_checked` returns an error.
    fn blas(&mut self) -> BlasArrayViewMut<A, D>
        where S: DataOwned<Elem=A> + DataMut, A: Clone
    {
        self.blas_checked().unwrap()
    }

    /// Returns a read-only BLAS view of `self` as it currently is.
    ///
    /// # Errors
    /// A `ShapeError` if the current layout or sizes are unsuitable.
    fn blas_view_checked(&self) -> Result<BlasArrayView<A, D>, ShapeError>
        where S: Data;

    /// Same as `blas_view_checked`.
    ///
    /// # Panics
    /// If `blas_view_checked` returns an error.
    fn bv(&self) -> BlasArrayView<A, D>
        where S: Data,
    {
        self.blas_view_checked().unwrap()
    }

    /// Returns a mutable BLAS view of `self` as it currently is.
    ///
    /// # Errors
    /// A `ShapeError` if the current layout or sizes are unsuitable.
    fn blas_view_mut_checked(&mut self) -> Result<BlasArrayViewMut<A, D>, ShapeError>
        where S: DataMut;

    /// Same as `blas_view_mut_checked`.
    ///
    /// # Panics
    /// If `blas_view_mut_checked` returns an error.
    fn bvm(&mut self) -> BlasArrayViewMut<A, D>
        where S: DataMut,
    {
        self.blas_view_mut_checked().unwrap()
    }
}
impl<A, S, D> AsBlas<A, S, D> for ArrayBase<S, D>
    where S: Data<Elem=A>,
          D: Dimension,
{
    /// Size-checks the array, relays it out if needed, then wraps a mutable
    /// view of it.
    ///
    /// 0-D and 1-D arrays are left untouched. 2-D arrays are copied only when
    /// their last axis is not contiguous. Higher-dimensional arrays always go
    /// through `ensure_standard_layout`, which is a no-op when they are
    /// already in standard layout.
    fn blas_checked(&mut self) -> Result<BlasArrayViewMut<A, D>, ShapeError>
        where S: DataOwned + DataMut,
              A: Clone,
    {
        // Fail on over-large lengths/strides before spending time on a copy.
        try!(Priv(self.view()).size_check());
        let relayout = match self.ndim() {
            0 | 1 => false,
            2 => !is_inner_contiguous(self),
            _ => true,
        };
        if relayout {
            ensure_standard_layout(self);
        }
        // Re-validate: the final view is what BLAS will actually see.
        Priv(self.view_mut()).into_blas_view_mut()
    }

    /// Validates the current layout without copying.
    fn blas_view_checked(&self) -> Result<BlasArrayView<A, D>, ShapeError>
        where S: Data
    {
        let view = self.view();
        Priv(view).into_blas_view()
    }

    /// Validates the current layout without copying.
    fn blas_view_mut_checked(&mut self) -> Result<BlasArrayViewMut<A, D>, ShapeError>
        where S: DataMut,
    {
        let view = self.view_mut();
        Priv(view).into_blas_view_mut()
    }
}
impl<'a, A> Vector<A> for BlasArrayView<'a, A, Ix> {
    /// Number of elements in the vector.
    fn len(&self) -> c_int {
        self.0.len() as c_int
    }

    /// Base pointer in the form BLAS expects.
    ///
    /// For a negative increment, BLAS starts at `ptr + (n - 1) * |inc|` and
    /// walks toward lower addresses, so `ptr` must be the lowest-addressed
    /// element (the logically *last* one). ndarray's `as_ptr` points at the
    /// logically *first* element, which for a negative stride is the highest
    /// address. Passing that through unchanged would make BLAS read past the
    /// end of the view, so the pointer is shifted back here.
    fn as_ptr(&self) -> *const A {
        let ptr = self.0.as_ptr();
        let len = self.0.len();
        let stride = self.0.strides()[0] as isize;
        if stride < 0 && len > 0 {
            // SAFETY: element `len - 1` of this view lives at
            // `ptr + (len - 1) * stride`, so the result stays inside the
            // allocation the view borrows.
            unsafe { ptr.offset((len - 1) as isize * stride) }
        } else {
            ptr
        }
    }

    fn as_mut_ptr(&mut self) -> *mut A {
        panic!("ndarray: as_mut_ptr called on BlasArrayView (not mutable)");
    }

    /// Element stride, possibly negative; its range is checked by `size_check`.
    fn inc(&self) -> c_int {
        self.0.strides()[0] as c_int
    }
}
impl<'a, A> Vector<A> for BlasArrayViewMut<'a, A, Ix> {
    /// Number of elements in the vector.
    fn len(&self) -> c_int {
        self.0.len() as c_int
    }

    /// Base pointer in the form BLAS expects.
    ///
    /// For a negative increment, BLAS starts at `ptr + (n - 1) * |inc|` and
    /// walks toward lower addresses, so `ptr` must be the lowest-addressed
    /// element. ndarray's `as_ptr` points at the logically first element
    /// (the highest address when the stride is negative), so it is shifted
    /// back here. Otherwise BLAS would access memory past the end of the view.
    fn as_ptr(&self) -> *const A {
        let ptr = self.0.as_ptr();
        let len = self.0.len();
        let stride = self.0.strides()[0] as isize;
        if stride < 0 && len > 0 {
            // SAFETY: element `len - 1` of this view lives at
            // `ptr + (len - 1) * stride`, inside the borrowed allocation.
            unsafe { ptr.offset((len - 1) as isize * stride) }
        } else {
            ptr
        }
    }

    /// Mutable counterpart of `as_ptr`, with the same negative-stride
    /// adjustment (BLAS writes through this pointer).
    fn as_mut_ptr(&mut self) -> *mut A {
        let len = self.0.len();
        let stride = self.0.strides()[0] as isize;
        let ptr = self.0.as_mut_ptr();
        if stride < 0 && len > 0 {
            // SAFETY: element `len - 1` of this view lives at
            // `ptr + (len - 1) * stride`, inside the borrowed allocation.
            unsafe { ptr.offset((len - 1) as isize * stride) }
        } else {
            ptr
        }
    }

    /// Element stride, possibly negative; its range is checked by `size_check`.
    fn inc(&self) -> c_int {
        self.0.strides()[0] as c_int
    }
}
// Exposes a checked 2-D view to rblas. The row stride (`strides()[0]`) is
// reported as the leading dimension and the last axis must be contiguous,
// which presumes the matrix is interpreted as row-major.
// NOTE(review): negative or zero row strides are not rejected before they
// reach BLAS as `lead_dim`; confirm callers never produce them.
impl<'a, A> Matrix<A> for BlasArrayView<'a, A, (Ix, Ix)> {
    fn rows(&self) -> c_int {
        let (nrows, _) = self.0.dim();
        nrows as c_int
    }

    fn cols(&self) -> c_int {
        let (_, ncols) = self.0.dim();
        ncols as c_int
    }

    fn lead_dim(&self) -> c_int {
        // Only meaningful when each row is contiguous (or has <= 1 column).
        debug_assert!(self.cols() <= 1 || self.0.strides()[1] == 1);
        let row_stride = self.0.strides()[0];
        row_stride as c_int
    }

    fn as_ptr(&self) -> *const A {
        self.0.as_ptr()
    }

    fn as_mut_ptr(&mut self) -> *mut A {
        panic!("ndarray: as_mut_ptr called on BlasArrayView (not mutable)");
    }
}
// Exposes a checked mutable 2-D view to rblas. The row stride
// (`strides()[0]`) is reported as the leading dimension and the last axis
// must be contiguous, which presumes a row-major interpretation.
// NOTE(review): negative or zero row strides are not rejected before they
// reach BLAS as `lead_dim`; confirm callers never produce them.
impl<'a, A> Matrix<A> for BlasArrayViewMut<'a, A, (Ix, Ix)> {
    fn rows(&self) -> c_int {
        let (nrows, _) = self.0.dim();
        nrows as c_int
    }

    fn cols(&self) -> c_int {
        let (_, ncols) = self.0.dim();
        ncols as c_int
    }

    fn lead_dim(&self) -> c_int {
        // Only meaningful when each row is contiguous (or has <= 1 column).
        debug_assert!(self.cols() <= 1 || self.0.strides()[1] == 1);
        let row_stride = self.0.strides()[0];
        row_stride as c_int
    }

    fn as_ptr(&self) -> *const A {
        self.0.as_ptr()
    }

    fn as_mut_ptr(&mut self) -> *mut A {
        self.0.as_mut_ptr()
    }
}