#![cfg_attr(has_deprecated, deprecated(note="`rblas` integration has moved to crate `ndarray-rblas`, use it instead."))]
#![allow(deprecated)]
use std::os::raw::{c_int};
use rblas::{
Matrix,
Vector,
};
use super::{
ShapeError,
zipsl,
};
use error::{from_kind, ErrorKind};
use imp_prelude::*;
/// Read-only BLAS adaptor: a newtype wrapping an `ArrayView`.
///
/// The wrapped view is a private field. In this module a value is produced
/// by `Priv::into_blas_view`, which runs the size check (and, for views with
/// more than one axis, the inner-contiguity check) before wrapping.
/// `rblas::Vector` is implemented for the `Ix` (one axis) dimension and
/// `rblas::Matrix` for the `(Ix, Ix)` (two axes) dimension.
pub struct BlasArrayView<'a, A: 'a, D>(ArrayView<'a, A, D>);
// The read-only wrapper is `Copy` whenever the dimension type `D` is `Copy`;
// no bound is placed on the element type `A`.
impl<'a, A, D> Copy for BlasArrayView<'a, A, D>
    where D: Copy
{
}
impl<'a, A, D: Clone> Clone for BlasArrayView<'a, A, D> {
fn clone(&self) -> Self {
BlasArrayView(self.0.clone())
}
}
/// Mutable BLAS adaptor: a newtype wrapping an `ArrayViewMut`.
///
/// The wrapped view is a private field. In this module a value is produced
/// by `ArrayViewMut::into_blas_view_mut`, which runs the size check (and, for
/// views with more than one axis, the inner-contiguity check) before wrapping.
/// `rblas::Vector` is implemented for the `Ix` (one axis) dimension and
/// `rblas::Matrix` for the `(Ix, Ix)` (two axes) dimension.
pub struct BlasArrayViewMut<'a, A: 'a, D>(ArrayViewMut<'a, A, D>);
impl<S, D> ArrayBase<S, D>
    where S: Data,
          D: Dimension
{
    /// Check that every axis length and every stride fits in a `c_int`.
    ///
    /// Returns `Ok(())` when all values are in range.
    ///
    /// **Errors** with `ErrorKind::RangeLimited` if any axis length is larger
    /// than `c_int::max_value()`, or if any stride lies outside the range
    /// `c_int::min_value()` to `c_int::max_value()` (inclusive).
    fn size_check(&self) -> Result<(), ShapeError> {
        let max = c_int::max_value();
        let min = c_int::min_value();
        for (&dim, &stride) in zipsl(self.shape(), self.strides()) {
            // Strides are signed (`Ixs`), so both ends of the range must be
            // checked: a stride below `c_int::min_value()` would otherwise be
            // silently truncated by the `as c_int` casts in the BLAS wrappers.
            let dim_too_large = dim > max as Ix;
            let stride_out_of_range = stride > max as Ixs || stride < min as Ixs;
            if dim_too_large || stride_out_of_range {
                return Err(from_kind(ErrorKind::RangeLimited));
            }
        }
        Ok(())
    }

    /// Check that the array's layout is usable for the BLAS wrappers.
    ///
    /// Returns `Ok(())` when `is_inner_contiguous()` reports true.
    ///
    /// **Errors** with `ErrorKind::IncompatibleLayout` otherwise.
    fn contiguous_check(&self) -> Result<(), ShapeError> {
        if self.is_inner_contiguous() {
            Ok(())
        } else {
            Err(from_kind(ErrorKind::IncompatibleLayout))
        }
    }
}
impl<'a, A, D> Priv<ArrayView<'a, A, D>>
where D: Dimension
{
pub fn into_blas_view(self) -> Result<BlasArrayView<'a, A, D>, ShapeError> {
let self_ = self.0;
if self_.dim.ndim() > 1 {
try!(self_.contiguous_check());
}
try!(self_.size_check());
Ok(BlasArrayView(self_))
}
}
impl<'a, A, D> ArrayViewMut<'a, A, D>
    where D: Dimension
{
    /// Consume this mutable view and turn it into a `BlasArrayViewMut`.
    ///
    /// Views with more than one axis must pass `contiguous_check`; every view
    /// must pass `size_check`. The first failing check's error is returned.
    fn into_blas_view_mut(self) -> Result<BlasArrayViewMut<'a, A, D>, ShapeError> {
        let view = self;
        // Zero- and one-axis views skip the layout check.
        let has_multiple_axes = view.dim.ndim() > 1;
        if has_multiple_axes {
            try!(view.contiguous_check());
        }
        try!(view.size_check());
        Ok(BlasArrayViewMut(view))
    }
}
/// Conversion of an array into the BLAS wrapper types `BlasArrayView` and
/// `BlasArrayViewMut`.
///
/// The three `*_checked` methods are required and report failure through
/// `ShapeError`; `blas`, `bv` and `bvm` are provided shorthands that unwrap
/// the corresponding checked result and therefore panic on error.
pub trait AsBlas<A, S, D> {
/// Return a mutable BLAS wrapper, or a `ShapeError` on failure.
///
/// Requires owned, mutable storage and `A: Clone`. In the `ArrayBase`
/// implementation in this file, the size check runs first; then arrays with
/// more than two axes, and two-axis arrays that are not inner-contiguous,
/// have `ensure_standard_layout` called on them before being wrapped.
fn blas_checked(&mut self) -> Result<BlasArrayViewMut<A, D>, ShapeError>
where S: DataOwned + DataMut,
A: Clone;
/// Equivalent to `.blas_checked().unwrap()`.
///
/// **Panics** if `.blas_checked()` returns an error.
fn blas(&mut self) -> BlasArrayViewMut<A, D>
where S: DataOwned<Elem=A> + DataMut,
A: Clone
{
self.blas_checked().unwrap()
}
/// Return a read-only BLAS wrapper, or a `ShapeError` on failure.
///
/// In the `ArrayBase` implementation in this file the array is not
/// modified: a view is taken and must pass the layout check (when it has
/// more than one axis) and the size check as it is.
fn blas_view_checked(&self) -> Result<BlasArrayView<A, D>, ShapeError>
where S: Data;
/// Equivalent to `.blas_view_checked().unwrap()`.
///
/// **Panics** if `.blas_view_checked()` returns an error.
fn bv(&self) -> BlasArrayView<A, D>
where S: Data,
{
self.blas_view_checked().unwrap()
}
/// Return a mutable BLAS wrapper, or a `ShapeError` on failure.
///
/// In the `ArrayBase` implementation in this file no layout change is
/// attempted: a mutable view is taken and must pass the layout check (when
/// it has more than one axis) and the size check as it is.
fn blas_view_mut_checked(&mut self) -> Result<BlasArrayViewMut<A, D>, ShapeError>
where S: DataMut;
/// Equivalent to `.blas_view_mut_checked().unwrap()`.
///
/// **Panics** if `.blas_view_mut_checked()` returns an error.
fn bvm(&mut self) -> BlasArrayViewMut<A, D>
where S: DataMut,
{
self.blas_view_mut_checked().unwrap()
}
}
impl<A, S, D> AsBlas<A, S, D> for ArrayBase<S, D>
    where S: Data<Elem=A>,
          D: Dimension,
{
    /// Return a mutable BLAS wrapper, adjusting the layout first if needed.
    ///
    /// The size check runs before anything else. Arrays with zero or one axis
    /// are left as they are; two-axis arrays get `ensure_standard_layout`
    /// only when they are not inner-contiguous; arrays with more axes always
    /// get it. The resulting mutable view is then checked and wrapped by
    /// `into_blas_view_mut`.
    fn blas_checked(&mut self) -> Result<BlasArrayViewMut<A, D>, ShapeError>
        where S: DataOwned + DataMut,
              A: Clone,
    {
        try!(self.size_check());
        let ndim = self.dim.ndim();
        // `||` / `&&` short-circuit, so contiguity is only queried for
        // exactly two axes.
        let needs_relayout = ndim > 2 || (ndim == 2 && !self.is_inner_contiguous());
        if needs_relayout {
            self.ensure_standard_layout();
        }
        let view = self.view_mut();
        view.into_blas_view_mut()
    }

    /// Return a read-only BLAS wrapper over a view of this array.
    ///
    /// All checking is delegated to `Priv::into_blas_view`.
    fn blas_view_checked(&self) -> Result<BlasArrayView<A, D>, ShapeError>
        where S: Data
    {
        let view = self.view();
        Priv(view).into_blas_view()
    }

    /// Return a mutable BLAS wrapper over a mutable view of this array.
    ///
    /// All checking is delegated to `into_blas_view_mut`; the layout is not
    /// changed here.
    fn blas_view_mut_checked(&mut self) -> Result<BlasArrayViewMut<A, D>, ShapeError>
        where S: DataMut,
    {
        let view = self.view_mut();
        view.into_blas_view_mut()
    }
}
impl<'a, A> Vector<A> for BlasArrayView<'a, A, Ix> {
    /// Length of the wrapped view, cast to `c_int`.
    fn len(&self) -> c_int {
        let count = self.0.len();
        count as c_int
    }

    /// The wrapped view's `ptr` field, as a const pointer.
    fn as_ptr(&self) -> *const A {
        let first: *const A = self.0.ptr;
        first
    }

    /// Not supported on the read-only wrapper.
    ///
    /// **Panics** always.
    fn as_mut_ptr(&mut self) -> *mut A {
        panic!("ndarray: as_mut_ptr called on BlasArrayView (not mutable)");
    }

    /// The wrapped view's `strides` field, cast to `c_int`.
    fn inc(&self) -> c_int {
        let step = self.0.strides;
        step as c_int
    }
}
impl<'a, A> Vector<A> for BlasArrayViewMut<'a, A, Ix> {
    /// Length of the wrapped view, cast to `c_int`.
    fn len(&self) -> c_int {
        let count = self.0.len();
        count as c_int
    }

    /// The wrapped view's `ptr` field, as a const pointer.
    fn as_ptr(&self) -> *const A {
        let first: *const A = self.0.ptr;
        first
    }

    /// The wrapped view's `ptr` field, as a mutable pointer.
    fn as_mut_ptr(&mut self) -> *mut A {
        let first: *mut A = self.0.ptr;
        first
    }

    /// The wrapped view's `strides` field, cast to `c_int`.
    fn inc(&self) -> c_int {
        let step = self.0.strides;
        step as c_int
    }
}
impl<'a, A> Matrix<A> for BlasArrayView<'a, A, (Ix, Ix)> {
    /// First component of the view's dimension, cast to `c_int`.
    fn rows(&self) -> c_int {
        let (n_rows, _) = self.0.dim();
        n_rows as c_int
    }

    /// Second component of the view's dimension, cast to `c_int`.
    fn cols(&self) -> c_int {
        let (_, n_cols) = self.0.dim();
        n_cols as c_int
    }

    /// Stride of the first axis, cast to `c_int`.
    ///
    /// In debug builds this asserts that the second axis has stride 1
    /// whenever there is more than one column.
    fn lead_dim(&self) -> c_int {
        let strides = self.0.strides();
        debug_assert!(self.cols() <= 1 || strides[1] == 1);
        strides[0] as c_int
    }

    /// The wrapped view's `ptr` field, as a const pointer.
    fn as_ptr(&self) -> *const A {
        let first: *const A = self.0.ptr;
        first
    }

    /// Not supported on the read-only wrapper.
    ///
    /// **Panics** always.
    fn as_mut_ptr(&mut self) -> *mut A {
        panic!("ndarray: as_mut_ptr called on BlasArrayView (not mutable)");
    }
}
impl<'a, A> Matrix<A> for BlasArrayViewMut<'a, A, (Ix, Ix)> {
    /// First component of the view's dimension, cast to `c_int`.
    fn rows(&self) -> c_int {
        let (n_rows, _) = self.0.dim();
        n_rows as c_int
    }

    /// Second component of the view's dimension, cast to `c_int`.
    fn cols(&self) -> c_int {
        let (_, n_cols) = self.0.dim();
        n_cols as c_int
    }

    /// Stride of the first axis, cast to `c_int`.
    ///
    /// In debug builds this asserts that the second axis has stride 1
    /// whenever there is more than one column.
    fn lead_dim(&self) -> c_int {
        let strides = self.0.strides();
        debug_assert!(self.cols() <= 1 || strides[1] == 1);
        strides[0] as c_int
    }

    /// The wrapped view's `ptr` field, as a const pointer.
    fn as_ptr(&self) -> *const A {
        let first: *const A = self.0.ptr;
        first
    }

    /// The wrapped view's `ptr` field, as a mutable pointer.
    fn as_mut_ptr(&mut self) -> *mut A {
        let first: *mut A = self.0.ptr;
        first
    }
}