blas_array2/blas2/
hemv.rs

1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6/* #region BLAS func */
7
8pub trait HEMVNum: BLASFloat {
9    unsafe fn hemv(
10        uplo: *const c_char,
11        n: *const blas_int,
12        alpha: *const Self,
13        a: *const Self,
14        lda: *const blas_int,
15        x: *const Self,
16        incx: *const blas_int,
17        beta: *const Self,
18        y: *mut Self,
19        incy: *const blas_int,
20    );
21}
22
23macro_rules! impl_func {
24    ($type: ty, $func: ident) => {
25        impl HEMVNum for $type {
26            unsafe fn hemv(
27                uplo: *const c_char,
28                n: *const blas_int,
29                alpha: *const $type,
30                a: *const $type,
31                lda: *const blas_int,
32                x: *const $type,
33                incx: *const blas_int,
34                beta: *const $type,
35                y: *mut $type,
36                incy: *const blas_int,
37            ) {
38                ffi::$func(uplo, n, alpha, a, lda, x, incx, beta, y, incy);
39            }
40        }
41    };
42}
43
44impl_func!(f32, ssymv_);
45impl_func!(f64, dsymv_);
46impl_func!(c32, chemv_);
47impl_func!(c64, zhemv_);
48
49/* #endregion */
50
51/* #region BLAS driver */
52
53pub struct HEMV_Driver<'a, 'x, 'y, F>
54where
55    F: HEMVNum,
56{
57    uplo: c_char,
58    n: blas_int,
59    alpha: F,
60    a: ArrayView2<'a, F>,
61    lda: blas_int,
62    x: ArrayView1<'x, F>,
63    incx: blas_int,
64    beta: F,
65    y: ArrayOut1<'y, F>,
66    incy: blas_int,
67}
68
69impl<'a, 'x, 'y, F> BLASDriver<'y, F, Ix1> for HEMV_Driver<'a, 'x, 'y, F>
70where
71    F: HEMVNum,
72{
73    fn run_blas(self) -> Result<ArrayOut1<'y, F>, BLASError> {
74        let Self { uplo, n, alpha, a, lda, x, incx, beta, mut y, incy, .. } = self;
75        let a_ptr = a.as_ptr();
76        let x_ptr = x.as_ptr();
77        let y_ptr = y.get_data_mut_ptr();
78
79        // assuming dimension checks has been performed
80        // unconditionally return Ok if output does not contain anything
81        if n == 0 {
82            return Ok(y);
83        }
84
85        unsafe {
86            F::hemv(&uplo, &n, &alpha, a_ptr, &lda, x_ptr, &incx, &beta, y_ptr, &incy);
87        }
88        return Ok(y);
89    }
90}
91
92/* #endregion */
93
94/* #region BLAS builder */
95
96#[derive(Builder)]
97#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
98pub struct HEMV_<'a, 'x, 'y, F>
99where
100    F: BLASFloat,
101{
102    pub a: ArrayView2<'a, F>,
103    pub x: ArrayView1<'x, F>,
104
105    #[builder(setter(into, strip_option), default = "None")]
106    pub y: Option<ArrayViewMut1<'y, F>>,
107    #[builder(setter(into), default = "F::one()")]
108    pub alpha: F,
109    #[builder(setter(into), default = "F::zero()")]
110    pub beta: F,
111    #[builder(setter(into), default = "BLASUpper")]
112    pub uplo: BLASUpLo,
113}
114
115impl<'a, 'x, 'y, F> BLASBuilder_<'y, F, Ix1> for HEMV_<'a, 'x, 'y, F>
116where
117    F: HEMVNum,
118{
119    fn driver(self) -> Result<HEMV_Driver<'a, 'x, 'y, F>, BLASError> {
120        let Self { a, x, y, alpha, beta, uplo, .. } = self;
121
122        // only fortran-preferred (col-major) is accepted in inner wrapper
123        let layout_a = get_layout_array2(&a);
124        assert!(layout_a.is_fpref());
125
126        // initialize intent(hide)
127        let (n_, n) = a.dim();
128        let lda = a.stride_of(Axis(1));
129        let incx = x.stride_of(Axis(0));
130
131        // perform check
132        blas_assert_eq!(n, n_, InvalidDim)?;
133        blas_assert_eq!(x.len_of(Axis(0)), n, InvalidDim)?;
134
135        // prepare output
136        let y = match y {
137            Some(y) => {
138                blas_assert_eq!(y.len_of(Axis(0)), n, InvalidDim)?;
139                ArrayOut1::ViewMut(y)
140            },
141            None => ArrayOut1::Owned(Array1::zeros(n)),
142        };
143        let incy = y.view().stride_of(Axis(0));
144
145        // finalize
146        let driver = HEMV_Driver {
147            uplo: uplo.try_into()?,
148            n: n.try_into()?,
149            alpha,
150            a,
151            lda: lda.try_into()?,
152            x,
153            incx: incx.try_into()?,
154            beta,
155            y,
156            incy: incy.try_into()?,
157        };
158        return Ok(driver);
159    }
160}
161
162/* #endregion */
163
164/* #region BLAS wrapper */
165
166pub type HEMV<'a, 'x, 'y, F> = HEMV_Builder<'a, 'x, 'y, F>;
167pub type SSYMV<'a, 'x, 'y> = HEMV<'a, 'x, 'y, f32>;
168pub type DSYMV<'a, 'x, 'y> = HEMV<'a, 'x, 'y, f64>;
169pub type CHEMV<'a, 'x, 'y> = HEMV<'a, 'x, 'y, c32>;
170pub type ZHEMV<'a, 'x, 'y> = HEMV<'a, 'x, 'y, c64>;
171
172impl<'a, 'x, 'y, F> BLASBuilder<'y, F, Ix1> for HEMV_Builder<'a, 'x, 'y, F>
173where
174    F: HEMVNum,
175{
176    fn run(self) -> Result<ArrayOut1<'y, F>, BLASError> {
177        // initialize
178        let obj = self.build()?;
179
180        let layout_a = get_layout_array2(&obj.a);
181
182        if layout_a.is_fpref() {
183            // F-contiguous
184            return obj.driver()?.run_blas();
185        } else {
186            // C-contiguous
187            let a_cow = obj.a.to_row_layout()?;
188            if F::is_complex() {
189                let x = obj.x.mapv(F::conj);
190                let y = obj.y.map(|mut y| {
191                    y.mapv_inplace(F::conj);
192                    y
193                });
194                let obj = HEMV_ {
195                    a: a_cow.t(),
196                    x: x.view(),
197                    y,
198                    uplo: obj.uplo.flip()?,
199                    alpha: F::conj(obj.alpha),
200                    beta: F::conj(obj.beta),
201                    ..obj
202                };
203                let mut y = obj.driver()?.run_blas()?;
204                y.view_mut().mapv_inplace(F::conj);
205                return Ok(y);
206            } else {
207                let obj = HEMV_ { a: a_cow.t(), uplo: obj.uplo.flip()?, ..obj };
208                return obj.driver()?.run_blas();
209            }
210        }
211    }
212}
213
214/* #endregion */