extern crate rblas;
use std::os::raw::{c_int, c_uint};
use self::rblas::{
Matrix,
Vector,
};
use super::{
ArrayBase,
ArrayViewMut,
Ix,
ShapeError,
Data,
DataMut,
DataOwned,
Dimension,
zipsl,
};
/// Newtype wrapper around a mutable array view (`ArrayViewMut`).
///
/// Within this file the wrapper is only constructed by `into_matrix_mut`,
/// after the view has passed the layout and size checks. The `rblas`
/// `Vector` and `Matrix` traits are implemented for the one- and
/// two-dimensional instantiations of this type.
pub struct BlasArrayViewMut<'a, A: 'a, D>(ArrayViewMut<'a, A, D>);
impl<S, D> ArrayBase<S, D>
    where S: Data,
          D: Dimension
{
    /// Verify that every axis length and every stride can be represented as
    /// a `c_int`.
    ///
    /// Returns `ShapeError::DimensionTooLarge`, carrying a boxed copy of the
    /// array's shape, as soon as one axis length or stride is greater than
    /// `c_int::max_value()`. Only the upper bound is tested here.
    fn size_check(&self) -> Result<(), ShapeError> {
        let limit = c_int::max_value();
        let oversized = zipsl(self.shape(), self.strides())
            .any(|(&len, &step)| len > limit as c_uint || step > limit);
        if oversized {
            let shape = self.shape().to_vec().into_boxed_slice();
            return Err(ShapeError::DimensionTooLarge(shape));
        }
        Ok(())
    }

    /// Verify that the innermost axis is laid out with unit stride.
    ///
    /// Arrays with at most one dimension always pass. For higher dimensions
    /// the last stride must be exactly `1`, otherwise
    /// `ShapeError::IncompatibleLayout` is returned.
    fn contiguous_check(&self) -> Result<(), ShapeError> {
        if self.dim.ndim() <= 1 {
            return Ok(());
        }
        match self.strides().last() {
            Some(&1) => Ok(()),
            _ => Err(ShapeError::IncompatibleLayout),
        }
    }
}
impl<'a, A, D> ArrayViewMut<'a, A, D>
    where D: Dimension,
{
    /// Consume the view and wrap it as a `BlasArrayViewMut`.
    ///
    /// The layout check runs first (only for views with more than one
    /// dimension), followed by the size check. The first failing check's
    /// `ShapeError` is returned and the view is dropped.
    fn into_matrix_mut(self) -> Result<BlasArrayViewMut<'a, A, D>, ShapeError> {
        let layout = if self.dim.ndim() > 1 {
            self.contiguous_check()
        } else {
            Ok(())
        };
        try!(layout);
        match self.size_check() {
            Ok(()) => Ok(BlasArrayViewMut(self)),
            Err(too_large) => Err(too_large),
        }
    }
}
/// Conversion of an array into a mutable view usable through the `rblas`
/// traits.
pub trait AsBlas<A, S, D> {
    /// Produce a `BlasArrayViewMut` of `self`, or a `ShapeError` when the
    /// array cannot be presented in that form.
    ///
    /// Takes `&mut self`; the implementation in this file may rearrange the
    /// array's storage before creating the view.
    fn blas_checked(&mut self) -> Result<BlasArrayViewMut<A, D>, ShapeError>
        where S: DataOwned + DataMut,
              A: Clone;

    /// Like `blas_checked`, but unwraps the result.
    ///
    /// # Panics
    ///
    /// Panics if `blas_checked` returns an error.
    fn blas(&mut self) -> BlasArrayViewMut<A, D>
        where S: DataOwned<Elem=A> + DataMut,
              A: Clone
    {
        let checked = self.blas_checked();
        checked.unwrap()
    }
}
impl<A, S, D> AsBlas<A, S, D> for ArrayBase<S, D>
    where S: Data<Elem=A>,
          D: Dimension,
{
    /// Produce a BLAS-compatible mutable view of this array.
    ///
    /// The array is first size-checked (every axis length and stride must
    /// fit in a `c_int`). Its storage is then converted to standard layout
    /// via `ensure_standard_layout` whenever the current layout cannot be
    /// described to BLAS:
    ///
    /// * 1-D: the stride is negative;
    /// * 2-D: the column stride is not `1`, or the row stride is negative;
    /// * 3-D and above: always (the call is left to decide whether any work
    ///   is needed).
    ///
    /// # Errors
    ///
    /// Returns the `ShapeError` from the size check or from
    /// `into_matrix_mut`.
    fn blas_checked(&mut self) -> Result<BlasArrayViewMut<A, D>, ShapeError>
        where S: DataOwned + DataMut,
              A: Clone,
    {
        try!(self.size_check());
        let needs_relayout = match self.dim.ndim() {
            0 => false,
            // BLAS addresses a vector with a negative increment starting
            // from the lowest address, but the view's pointer is handed over
            // unchanged together with the raw stride.
            // NOTE(review): assumes `ptr` addresses the first logical
            // element, so a negative stride would make BLAS step outside the
            // array - confirm against the slicing code.
            1 => self.strides()[0] < 0,
            2 => {
                let strides = self.strides();
                // The row stride becomes the leading dimension, which BLAS
                // requires to be positive; the column stride must be 1.
                strides[1] != 1 || strides[0] < 0
            }
            _ => true,
        };
        if needs_relayout {
            self.ensure_standard_layout();
        }
        self.view_mut().into_matrix_mut()
    }
}
impl<'a, A> Vector<A> for BlasArrayViewMut<'a, A, Ix> {
fn len(&self) -> c_int {
self.0.len() as c_int
}
fn as_ptr(&self) -> *const A {
self.0.ptr
}
fn as_mut_ptr(&mut self) -> *mut A {
self.0.ptr
}
fn inc(&self) -> c_int {
self.0.strides as c_int
}
}
impl<'a, A> Matrix<A> for BlasArrayViewMut<'a, A, (Ix, Ix)> {
fn rows(&self) -> c_int {
self.0.dim().0 as c_int
}
fn cols(&self) -> c_int {
self.0.dim().1 as c_int
}
fn lead_dim(&self) -> c_int {
debug_assert_eq!(self.0.strides()[1], 1);
self.0.strides()[0] as c_int
}
fn as_ptr(&self) -> *const A {
self.0.ptr as *const _
}
fn as_mut_ptr(&mut self) -> *mut A {
self.0.ptr
}
}