// blas_array2/blas2/hbmv.rs

1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6/* #region BLAS func */
7
8pub trait HBMVNum: BLASFloat {
9    unsafe fn hbmv(
10        uplo: *const c_char,
11        n: *const blas_int,
12        k: *const blas_int,
13        alpha: *const Self,
14        a: *const Self,
15        lda: *const blas_int,
16        x: *const Self,
17        incx: *const blas_int,
18        beta: *const Self,
19        y: *mut Self,
20        incy: *const blas_int,
21    );
22}
23
24macro_rules! impl_func {
25    ($type: ty, $func: ident) => {
26        impl HBMVNum for $type {
27            unsafe fn hbmv(
28                uplo: *const c_char,
29                n: *const blas_int,
30                k: *const blas_int,
31                alpha: *const Self,
32                a: *const Self,
33                lda: *const blas_int,
34                x: *const Self,
35                incx: *const blas_int,
36                beta: *const Self,
37                y: *mut Self,
38                incy: *const blas_int,
39            ) {
40                ffi::$func(uplo, n, k, alpha, a, lda, x, incx, beta, y, incy);
41            }
42        }
43    };
44}
45
46impl_func!(f32, ssbmv_);
47impl_func!(f64, dsbmv_);
48impl_func!(c32, chbmv_);
49impl_func!(c64, zhbmv_);
50
51/* #endregion */
52
53/* #region BLAS driver */
54
55pub struct HBMV_Driver<'a, 'x, 'y, F>
56where
57    F: HBMVNum,
58{
59    uplo: c_char,
60    n: blas_int,
61    k: blas_int,
62    alpha: F,
63    a: ArrayView2<'a, F>,
64    lda: blas_int,
65    x: ArrayView1<'x, F>,
66    incx: blas_int,
67    beta: F,
68    y: ArrayOut1<'y, F>,
69    incy: blas_int,
70}
71
72impl<'a, 'x, 'y, F> BLASDriver<'y, F, Ix1> for HBMV_Driver<'a, 'x, 'y, F>
73where
74    F: HBMVNum,
75{
76    fn run_blas(self) -> Result<ArrayOut1<'y, F>, BLASError> {
77        let Self { uplo, n, k, alpha, a, lda, x, incx, beta, mut y, incy, .. } = self;
78        let a_ptr = a.as_ptr();
79        let x_ptr = x.as_ptr();
80        let y_ptr = y.get_data_mut_ptr();
81
82        // assuming dimension checks has been performed
83        // unconditionally return Ok if output does not contain anything
84        if n == 0 {
85            return Ok(y);
86        }
87
88        unsafe {
89            F::hbmv(&uplo, &n, &k, &alpha, a_ptr, &lda, x_ptr, &incx, &beta, y_ptr, &incy);
90        }
91        return Ok(y);
92    }
93}
94
95/* #endregion */
96
97/* #region BLAS builder */
98
99#[derive(Builder)]
100#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
101pub struct HBMV_<'a, 'x, 'y, F>
102where
103    F: HBMVNum,
104{
105    pub a: ArrayView2<'a, F>,
106    pub x: ArrayView1<'x, F>,
107
108    #[builder(setter(into, strip_option), default = "None")]
109    pub y: Option<ArrayViewMut1<'y, F>>,
110    #[builder(setter(into), default = "F::one()")]
111    pub alpha: F,
112    #[builder(setter(into), default = "F::zero()")]
113    pub beta: F,
114    #[builder(setter(into), default = "BLASUpper")]
115    pub uplo: BLASUpLo,
116    #[builder(setter(into, strip_option), default = "None")]
117    pub layout: Option<BLASLayout>,
118}
119
120impl<'a, 'x, 'y, F> BLASBuilder_<'y, F, Ix1> for HBMV_<'a, 'x, 'y, F>
121where
122    F: HBMVNum,
123{
124    fn driver(self) -> Result<HBMV_Driver<'a, 'x, 'y, F>, BLASError> {
125        let Self { a, x, y, alpha, beta, uplo, layout, .. } = self;
126
127        // only fortran-preferred (col-major) is accepted in inner wrapper
128        let layout_a = get_layout_array2(&a);
129        assert!(layout_a.is_fpref());
130        assert!(layout == Some(BLASLayout::ColMajor));
131
132        // initialize intent(hide)
133        let (k_, n) = a.dim();
134        blas_assert!(k_ > 0, InvalidDim, "Rows of input `a` must larger than zero.")?;
135        let k = k_ - 1;
136        let lda = a.stride_of(Axis(1));
137        let incx = x.stride_of(Axis(0));
138
139        // perform check
140        blas_assert_eq!(x.len_of(Axis(0)), n, InvalidDim)?;
141
142        // prepare output
143        let y = match y {
144            Some(y) => {
145                blas_assert_eq!(y.len_of(Axis(0)), n, InvalidDim)?;
146                ArrayOut1::ViewMut(y)
147            },
148            None => ArrayOut1::Owned(Array1::zeros(n)),
149        };
150        let incy = y.view().stride_of(Axis(0));
151
152        // finalize
153        let driver = HBMV_Driver {
154            uplo: uplo.try_into()?,
155            n: n.try_into()?,
156            k: k.try_into()?,
157            alpha,
158            a,
159            lda: lda.try_into()?,
160            x,
161            incx: incx.try_into()?,
162            beta,
163            y,
164            incy: incy.try_into()?,
165        };
166        return Ok(driver);
167    }
168}
169
170/* #endregion */
171
172/* #region BLAS wrapper */
173
174pub type HBMV<'a, 'x, 'y, F> = HBMV_Builder<'a, 'x, 'y, F>;
175pub type SSBMV<'a, 'x, 'y> = HBMV<'a, 'x, 'y, f32>;
176pub type DSBMV<'a, 'x, 'y> = HBMV<'a, 'x, 'y, f64>;
177pub type CHBMV<'a, 'x, 'y> = HBMV<'a, 'x, 'y, c32>;
178pub type ZHBMV<'a, 'x, 'y> = HBMV<'a, 'x, 'y, c64>;
179
180impl<'a, 'x, 'y, F> BLASBuilder<'y, F, Ix1> for HBMV_Builder<'a, 'x, 'y, F>
181where
182    F: HBMVNum,
183{
184    fn run(self) -> Result<ArrayOut1<'y, F>, BLASError> {
185        // initialize
186        let obj = self.build()?;
187
188        let layout_a = get_layout_array2(&obj.a);
189        let layout = get_layout_row_preferred(&[obj.layout, Some(layout_a)], &[]);
190
191        if layout == BLASColMajor {
192            // F-contiguous
193            let a_cow = obj.a.to_col_layout()?;
194            let obj = HBMV_ { a: a_cow.view(), layout: Some(BLASColMajor), ..obj };
195            return obj.driver()?.run_blas();
196        } else {
197            // C-contiguous
198            let a_cow = obj.a.to_row_layout()?;
199            if F::is_complex() {
200                let x = obj.x.mapv(F::conj);
201                let y = obj.y.map(|mut y| {
202                    y.mapv_inplace(F::conj);
203                    y
204                });
205                let obj = HBMV_ {
206                    a: a_cow.t(),
207                    x: x.view(),
208                    y,
209                    uplo: obj.uplo.flip()?,
210                    alpha: F::conj(obj.alpha),
211                    beta: F::conj(obj.beta),
212                    layout: Some(BLASColMajor),
213                    ..obj
214                };
215                let mut y = obj.driver()?.run_blas()?;
216                y.view_mut().mapv_inplace(F::conj);
217                return Ok(y);
218            } else {
219                let obj = HBMV_ { a: a_cow.t(), uplo: obj.uplo.flip()?, layout: Some(BLASColMajor), ..obj };
220                return obj.driver()?.run_blas();
221            }
222        }
223    }
224}
225
226/* #endregion */