blas_array2/blas2/
hpmv.rs

1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6/* #region BLAS func */
7
8pub trait HPMVNum: BLASFloat {
9    unsafe fn hpmv(
10        uplo: *const c_char,
11        n: *const blas_int,
12        alpha: *const Self,
13        ap: *const Self,
14        x: *const Self,
15        incx: *const blas_int,
16        beta: *const Self,
17        y: *mut Self,
18        incy: *const blas_int,
19    );
20}
21
22macro_rules! impl_func {
23    ($type: ty, $func: ident) => {
24        impl HPMVNum for $type {
25            unsafe fn hpmv(
26                uplo: *const c_char,
27                n: *const blas_int,
28                alpha: *const Self,
29                ap: *const Self,
30                x: *const Self,
31                incx: *const blas_int,
32                beta: *const Self,
33                y: *mut Self,
34                incy: *const blas_int,
35            ) {
36                ffi::$func(uplo, n, alpha, ap, x, incx, beta, y, incy);
37            }
38        }
39    };
40}
41
42impl_func!(f32, sspmv_);
43impl_func!(f64, dspmv_);
44impl_func!(c32, chpmv_);
45impl_func!(c64, zhpmv_);
46
47/* #endregion */
48
49/* #region BLAS driver */
50
51pub struct HPMV_Driver<'a, 'x, 'y, F>
52where
53    F: HPMVNum,
54{
55    uplo: c_char,
56    n: blas_int,
57    alpha: F,
58    ap: ArrayView1<'a, F>,
59    x: ArrayView1<'x, F>,
60    incx: blas_int,
61    beta: F,
62    y: ArrayOut1<'y, F>,
63    incy: blas_int,
64}
65
66impl<'a, 'x, 'y, F> BLASDriver<'y, F, Ix1> for HPMV_Driver<'a, 'x, 'y, F>
67where
68    F: HPMVNum,
69{
70    fn run_blas(self) -> Result<ArrayOut1<'y, F>, BLASError> {
71        let Self { uplo, n, alpha, ap, x, incx, beta, mut y, incy, .. } = self;
72        let ap_ptr = ap.as_ptr();
73        let x_ptr = x.as_ptr();
74        let y_ptr = y.get_data_mut_ptr();
75
76        // assuming dimension checks has been performed
77        // unconditionally return Ok if output does not contain anything
78        if n == 0 {
79            return Ok(y);
80        }
81
82        unsafe {
83            F::hpmv(&uplo, &n, &alpha, ap_ptr, x_ptr, &incx, &beta, y_ptr, &incy);
84        }
85        return Ok(y);
86    }
87}
88
89/* #endregion */
90
91/* #region BLAS builder */
92
93#[derive(Builder)]
94#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
95pub struct HPMV_<'a, 'x, 'y, F>
96where
97    F: HPMVNum,
98{
99    pub ap: ArrayView1<'a, F>,
100    pub x: ArrayView1<'x, F>,
101
102    #[builder(setter(into, strip_option), default = "None")]
103    pub y: Option<ArrayViewMut1<'y, F>>,
104    #[builder(setter(into), default = "F::one()")]
105    pub alpha: F,
106    #[builder(setter(into), default = "F::zero()")]
107    pub beta: F,
108    #[builder(setter(into), default = "BLASUpper")]
109    pub uplo: BLASUpLo,
110    #[builder(setter(into, strip_option), default = "None")]
111    pub layout: Option<BLASLayout>,
112}
113
114impl<'a, 'x, 'y, F> BLASBuilder_<'y, F, Ix1> for HPMV_<'a, 'x, 'y, F>
115where
116    F: HPMVNum,
117{
118    fn driver(self) -> Result<HPMV_Driver<'a, 'x, 'y, F>, BLASError> {
119        let Self { ap, x, y, alpha, beta, uplo, layout, .. } = self;
120
121        // only fortran-preferred (col-major) is accepted in inner wrapper
122        let incap = ap.stride_of(Axis(0));
123        assert!(incap <= 1);
124        assert_eq!(layout, Some(BLASColMajor));
125
126        // initialize intent(hide)
127        let np = ap.len_of(Axis(0));
128        let n = x.len_of(Axis(0));
129        let incx = x.stride_of(Axis(0));
130
131        // perform check
132        blas_assert_eq!(np, n * (n + 1) / 2, InvalidDim)?;
133
134        // prepare output
135        let y = match y {
136            Some(y) => {
137                blas_assert_eq!(y.len_of(Axis(0)), n, InvalidDim)?;
138                ArrayOut1::ViewMut(y)
139            },
140            None => ArrayOut1::Owned(Array1::zeros(n)),
141        };
142        let incy = y.view().stride_of(Axis(0));
143
144        // finalize
145        let driver = HPMV_Driver {
146            uplo: uplo.try_into()?,
147            n: n.try_into()?,
148            alpha,
149            ap,
150            x,
151            incx: incx.try_into()?,
152            beta,
153            y,
154            incy: incy.try_into()?,
155        };
156        return Ok(driver);
157    }
158}
159
160/* #endregion */
161
162/* #region BLAS wrapper */
163
164pub type HPMV<'a, 'x, 'y, F> = HPMV_Builder<'a, 'x, 'y, F>;
165pub type SSPMV<'a, 'x, 'y> = HPMV<'a, 'x, 'y, f32>;
166pub type DSPMV<'a, 'x, 'y> = HPMV<'a, 'x, 'y, f64>;
167pub type CHPMV<'a, 'x, 'y> = HPMV<'a, 'x, 'y, c32>;
168pub type ZHPMV<'a, 'x, 'y> = HPMV<'a, 'x, 'y, c64>;
169
170impl<'a, 'x, 'y, F> BLASBuilder<'y, F, Ix1> for HPMV_Builder<'a, 'x, 'y, F>
171where
172    F: HPMVNum,
173{
174    fn run(self) -> Result<ArrayOut1<'y, F>, BLASError> {
175        // initialize
176        let obj = self.build()?;
177
178        let layout = obj.layout.unwrap_or(BLASRowMajor);
179
180        if layout == BLASColMajor {
181            // F-contiguous
182            let ap_cow = obj.ap.to_seq_layout()?;
183            let obj = HPMV_ { ap: ap_cow.view(), layout: Some(BLASColMajor), ..obj };
184            return obj.driver()?.run_blas();
185        } else {
186            // C-contiguous
187            let ap_cow = obj.ap.to_seq_layout()?;
188            if F::is_complex() {
189                let x = obj.x.mapv(F::conj);
190                let y = obj.y.map(|mut y| {
191                    y.mapv_inplace(F::conj);
192                    y
193                });
194                let obj = HPMV_ {
195                    ap: ap_cow.view(),
196                    x: x.view(),
197                    y,
198                    uplo: obj.uplo.flip()?,
199                    alpha: F::conj(obj.alpha),
200                    beta: F::conj(obj.beta),
201                    layout: Some(BLASColMajor),
202                    ..obj
203                };
204                let mut y = obj.driver()?.run_blas()?;
205                y.view_mut().mapv_inplace(F::conj);
206                return Ok(y);
207            } else {
208                let obj =
209                    HPMV_ { ap: ap_cow.view(), uplo: obj.uplo.flip()?, layout: Some(BLASColMajor), ..obj };
210                return obj.driver()?.run_blas();
211            }
212        }
213    }
214}
215
216/* #endregion */