use crate::ffi::{self, blas_int, c_char};
use crate::util::*;
use derive_builder::Builder;
use ndarray::prelude::*;
/// Low-level binding to the packed symmetric (`?spmv`) / hermitian (`?hpmv`)
/// matrix-vector BLAS routines, selected by the float type `F` and the
/// symmetry marker `S`.
///
/// Implemented for `BLASFunc` by the `impl_func!` invocations in this module.
pub trait SPMVFunc<F, S>
where
    F: BLASFloat,
    S: BLASSymmetric,
{
    /// Forwards every argument unchanged to the matching `ffi` symbol.
    ///
    /// # Safety
    ///
    /// The pointers go straight to the Fortran-style BLAS routine, so the caller
    /// must uphold that routine's contract. All pointers must be valid. `ap` must
    /// hold the packed triangle of an `n x n` matrix (`n * (n + 1) / 2` elements).
    /// `x` and `y` must be addressable for `n` elements with strides `incx` and
    /// `incy`. This presumably follows the reference BLAS `?spmv`/`?hpmv`
    /// interface; confirm against the linked BLAS library.
    unsafe fn spmv(
        uplo: *const c_char,
        n: *const blas_int,
        alpha: *const F,
        ap: *const F,
        x: *const F,
        incx: *const blas_int,
        beta: *const F,
        y: *mut F,
        incy: *const blas_int,
    );
}
// Implements `SPMVFunc<$type, $symm>` for `BLASFunc` by forwarding to the FFI
// symbol `ffi::$func`. There is one invocation per supported
// (float type, symmetry marker) pair.
macro_rules! impl_func {
    ($type: ty, $symm: ty, $func: ident) => {
        impl SPMVFunc<$type, $symm> for BLASFunc
        where
            $type: BLASFloat,
        {
            unsafe fn spmv(
                uplo: *const c_char,
                n: *const blas_int,
                alpha: *const $type,
                ap: *const $type,
                x: *const $type,
                incx: *const blas_int,
                beta: *const $type,
                y: *mut $type,
                incy: *const blas_int,
            ) {
                // Pure pass-through: no validation happens here; the driver is
                // responsible for consistent dimensions and strides.
                ffi::$func(uplo, n, alpha, ap, x, incx, beta, y, incy);
            }
        }
    };
}
// Real types map to the symmetric routines (`sspmv_`, `dspmv_`); complex types
// map to the hermitian ones (`chpmv_`, `zhpmv_`).
impl_func!(f32, BLASSymm<f32>, sspmv_);
impl_func!(f64, BLASSymm<f64>, dspmv_);
impl_func!(c32, BLASHermi<c32>, chpmv_);
impl_func!(c64, BLASHermi<c64>, zhpmv_);
/// Validated, FFI-ready arguments for a single `?spmv` / `?hpmv` call.
///
/// Produced by `SPMV_::driver` and consumed by `run_blas`. Sizes and strides are
/// already converted to `blas_int`, and the triangle flag to `c_char`.
pub struct SPMV_Driver<'a, 'x, 'y, F, S>
where
    F: BLASFloat,
    S: BLASSymmetric,
{
    /// Triangle flag, converted from `BLASUpLo`.
    uplo: c_char,
    /// Order of the matrix; equals the length of `x`.
    n: blas_int,
    /// Scalar applied to `A * x`.
    alpha: F,
    /// Packed triangle of the matrix (`n * (n + 1) / 2` elements).
    ap: ArrayView1<'a, F>,
    /// Input vector.
    x: ArrayView1<'x, F>,
    /// Element stride of `x`.
    incx: blas_int,
    /// Scalar applied to the incoming `y`.
    beta: F,
    /// Output vector: either the caller's view or a freshly allocated array.
    y: ArrayOut1<'y, F>,
    /// Element stride of `y`.
    incy: blas_int,
    /// Type-level selector between the symmetric and hermitian routine.
    _phantom: core::marker::PhantomData<S>,
}
impl<'a, 'x, 'y, F, S> BLASDriver<'y, F, Ix1> for SPMV_Driver<'a, 'x, 'y, F, S>
where
F: BLASFloat,
S: BLASSymmetric,
BLASFunc: SPMVFunc<F, S>,
{
fn run_blas(self) -> Result<ArrayOut1<'y, F>, BLASError> {
let Self { uplo, n, alpha, ap, x, incx, beta, mut y, incy, .. } = self;
let ap_ptr = ap.as_ptr();
let x_ptr = x.as_ptr();
let y_ptr = y.get_data_mut_ptr();
if n == 0 {
return Ok(y);
}
unsafe {
BLASFunc::spmv(&uplo, &n, &alpha, ap_ptr, x_ptr, &incx, &beta, y_ptr, &incy);
}
return Ok(y);
}
}
/// Parameters for the packed symmetric / hermitian matrix-vector product
/// `y = alpha * A * x + beta * y`, where `A` is given as one packed triangle.
///
/// Construct this through the generated `SPMV_Builder`, which is exposed via
/// the `SPMV` / `HPMV` type aliases in this module.
#[derive(Builder)]
#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
pub struct SPMV_<'a, 'x, 'y, F, S>
where
    F: BLASFloat,
{
    /// Packed triangle of `A`. Its length must be `n * (n + 1) / 2`, where
    /// `n = x.len()`.
    pub ap: ArrayView1<'a, F>,
    /// Input vector; its length defines the matrix order `n`.
    pub x: ArrayView1<'x, F>,
    /// Optional output vector of length `n`, written by BLAS.
    /// When omitted, a zero-initialised vector is allocated and returned.
    #[builder(setter(into, strip_option), default = "None")]
    pub y: Option<ArrayViewMut1<'y, F>>,
    /// Scalar applied to `A * x`; defaults to one.
    #[builder(setter(into), default = "F::one()")]
    pub alpha: F,
    /// Scalar applied to the incoming `y`; defaults to zero.
    #[builder(setter(into), default = "F::zero()")]
    pub beta: F,
    /// Which triangle of `A` is stored in `ap`; defaults to `BLASUpper`.
    #[builder(setter(into), default = "BLASUpper")]
    pub uplo: BLASUpLo,
    /// Storage order used to interpret `ap`. When unset, the builder's `run`
    /// treats it as row-major.
    #[builder(setter(into, strip_option), default = "None")]
    pub layout: Option<BLASLayout>,
    /// Symmetry marker (`BLASSymm` / `BLASHermi`) that selects the BLAS routine.
    #[builder(private, default = "core::marker::PhantomData {}")]
    _phantom: core::marker::PhantomData<S>,
}
impl<'a, 'x, 'y, F, S> BLASBuilder_<'y, F, Ix1> for SPMV_<'a, 'x, 'y, F, S>
where
F: BLASFloat,
S: BLASSymmetric,
BLASFunc: SPMVFunc<F, S>,
{
fn driver(self) -> Result<SPMV_Driver<'a, 'x, 'y, F, S>, BLASError> {
let Self { ap, x, y, alpha, beta, uplo, layout, .. } = self;
let incap = ap.stride_of(Axis(0));
assert!(incap <= 1);
assert_eq!(layout, Some(BLASColMajor));
let np = ap.len_of(Axis(0));
let n = x.len_of(Axis(0));
let incx = x.stride_of(Axis(0));
blas_assert_eq!(np, n * (n + 1) / 2, InvalidDim)?;
let y = match y {
Some(y) => {
blas_assert_eq!(y.len_of(Axis(0)), n, InvalidDim)?;
ArrayOut1::ViewMut(y)
},
None => ArrayOut1::Owned(Array1::zeros(n)),
};
let incy = y.view().stride_of(Axis(0));
let driver = SPMV_Driver {
uplo: uplo.into(),
n: n.try_into()?,
alpha,
ap,
x,
incx: incx.try_into()?,
beta,
y,
incy: incy.try_into()?,
_phantom: core::marker::PhantomData {},
};
return Ok(driver);
}
}
/// Builder for the symmetric packed matrix-vector product (`?spmv`).
pub type SPMV<'a, 'x, 'y, F> = SPMV_Builder<'a, 'x, 'y, F, BLASSymm<F>>;
/// `SPMV` for `f32` (dispatches to `sspmv_`).
pub type SSPMV<'a, 'x, 'y> = SPMV<'a, 'x, 'y, f32>;
/// `SPMV` for `f64` (dispatches to `dspmv_`).
pub type DSPMV<'a, 'x, 'y> = SPMV<'a, 'x, 'y, f64>;
/// Builder for the hermitian packed matrix-vector product (`?hpmv`).
pub type HPMV<'a, 'x, 'y, F> = SPMV_Builder<'a, 'x, 'y, F, BLASHermi<F>>;
/// `HPMV` for `c32` (dispatches to `chpmv_`).
pub type CHPMV<'a, 'x, 'y> = HPMV<'a, 'x, 'y, c32>;
/// `HPMV` for `c64` (dispatches to `zhpmv_`).
pub type ZHPMV<'a, 'x, 'y> = HPMV<'a, 'x, 'y, c64>;
impl<'a, 'x, 'y, F, S> BLASBuilder<'y, F, Ix1> for SPMV_Builder<'a, 'x, 'y, F, S>
where
F: BLASFloat,
S: BLASSymmetric,
BLASFunc: SPMVFunc<F, S>,
{
fn run(self) -> Result<ArrayOut1<'y, F>, BLASError> {
let obj = self.build()?;
let layout = obj.layout.unwrap_or(BLASRowMajor);
if layout == BLASColMajor {
let ap_cow = obj.ap.to_seq_layout()?;
let obj = SPMV_ { ap: ap_cow.view(), layout: Some(BLASColMajor), ..obj };
return obj.driver()?.run_blas();
} else {
let ap_cow = obj.ap.to_seq_layout()?;
if S::is_hermitian() {
let x = obj.x.mapv(F::conj);
let y = obj.y.map(|mut y| {
y.mapv_inplace(F::conj);
y
});
let obj = SPMV_ {
ap: ap_cow.view(),
x: x.view(),
y,
uplo: obj.uplo.flip(),
alpha: F::conj(obj.alpha),
beta: F::conj(obj.beta),
layout: Some(BLASColMajor),
..obj
};
let mut y = obj.driver()?.run_blas()?;
y.view_mut().mapv_inplace(F::conj);
return Ok(y);
} else {
let obj =
SPMV_ { ap: ap_cow.view(), uplo: obj.uplo.flip(), layout: Some(BLASColMajor), ..obj };
return obj.driver()?.run_blas();
}
}
}
}