1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6pub trait HEMVNum: BLASFloat {
9 unsafe fn hemv(
10 uplo: *const c_char,
11 n: *const blas_int,
12 alpha: *const Self,
13 a: *const Self,
14 lda: *const blas_int,
15 x: *const Self,
16 incx: *const blas_int,
17 beta: *const Self,
18 y: *mut Self,
19 incy: *const blas_int,
20 );
21}
22
23macro_rules! impl_func {
24 ($type: ty, $func: ident) => {
25 impl HEMVNum for $type {
26 unsafe fn hemv(
27 uplo: *const c_char,
28 n: *const blas_int,
29 alpha: *const $type,
30 a: *const $type,
31 lda: *const blas_int,
32 x: *const $type,
33 incx: *const blas_int,
34 beta: *const $type,
35 y: *mut $type,
36 incy: *const blas_int,
37 ) {
38 ffi::$func(uplo, n, alpha, a, lda, x, incx, beta, y, incy);
39 }
40 }
41 };
42}
43
44impl_func!(f32, ssymv_);
45impl_func!(f64, dsymv_);
46impl_func!(c32, chemv_);
47impl_func!(c64, zhemv_);
48
49pub struct HEMV_Driver<'a, 'x, 'y, F>
54where
55 F: HEMVNum,
56{
57 uplo: c_char,
58 n: blas_int,
59 alpha: F,
60 a: ArrayView2<'a, F>,
61 lda: blas_int,
62 x: ArrayView1<'x, F>,
63 incx: blas_int,
64 beta: F,
65 y: ArrayOut1<'y, F>,
66 incy: blas_int,
67}
68
69impl<'a, 'x, 'y, F> BLASDriver<'y, F, Ix1> for HEMV_Driver<'a, 'x, 'y, F>
70where
71 F: HEMVNum,
72{
73 fn run_blas(self) -> Result<ArrayOut1<'y, F>, BLASError> {
74 let Self { uplo, n, alpha, a, lda, x, incx, beta, mut y, incy, .. } = self;
75 let a_ptr = a.as_ptr();
76 let x_ptr = x.as_ptr();
77 let y_ptr = y.get_data_mut_ptr();
78
79 if n == 0 {
82 return Ok(y);
83 }
84
85 unsafe {
86 F::hemv(&uplo, &n, &alpha, a_ptr, &lda, x_ptr, &incx, &beta, y_ptr, &incy);
87 }
88 return Ok(y);
89 }
90}
91
92#[derive(Builder)]
97#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
98pub struct HEMV_<'a, 'x, 'y, F>
99where
100 F: BLASFloat,
101{
102 pub a: ArrayView2<'a, F>,
103 pub x: ArrayView1<'x, F>,
104
105 #[builder(setter(into, strip_option), default = "None")]
106 pub y: Option<ArrayViewMut1<'y, F>>,
107 #[builder(setter(into), default = "F::one()")]
108 pub alpha: F,
109 #[builder(setter(into), default = "F::zero()")]
110 pub beta: F,
111 #[builder(setter(into), default = "BLASUpper")]
112 pub uplo: BLASUpLo,
113}
114
115impl<'a, 'x, 'y, F> BLASBuilder_<'y, F, Ix1> for HEMV_<'a, 'x, 'y, F>
116where
117 F: HEMVNum,
118{
119 fn driver(self) -> Result<HEMV_Driver<'a, 'x, 'y, F>, BLASError> {
120 let Self { a, x, y, alpha, beta, uplo, .. } = self;
121
122 let layout_a = get_layout_array2(&a);
124 assert!(layout_a.is_fpref());
125
126 let (n_, n) = a.dim();
128 let lda = a.stride_of(Axis(1));
129 let incx = x.stride_of(Axis(0));
130
131 blas_assert_eq!(n, n_, InvalidDim)?;
133 blas_assert_eq!(x.len_of(Axis(0)), n, InvalidDim)?;
134
135 let y = match y {
137 Some(y) => {
138 blas_assert_eq!(y.len_of(Axis(0)), n, InvalidDim)?;
139 ArrayOut1::ViewMut(y)
140 },
141 None => ArrayOut1::Owned(Array1::zeros(n)),
142 };
143 let incy = y.view().stride_of(Axis(0));
144
145 let driver = HEMV_Driver {
147 uplo: uplo.try_into()?,
148 n: n.try_into()?,
149 alpha,
150 a,
151 lda: lda.try_into()?,
152 x,
153 incx: incx.try_into()?,
154 beta,
155 y,
156 incy: incy.try_into()?,
157 };
158 return Ok(driver);
159 }
160}
161
162pub type HEMV<'a, 'x, 'y, F> = HEMV_Builder<'a, 'x, 'y, F>;
167pub type SSYMV<'a, 'x, 'y> = HEMV<'a, 'x, 'y, f32>;
168pub type DSYMV<'a, 'x, 'y> = HEMV<'a, 'x, 'y, f64>;
169pub type CHEMV<'a, 'x, 'y> = HEMV<'a, 'x, 'y, c32>;
170pub type ZHEMV<'a, 'x, 'y> = HEMV<'a, 'x, 'y, c64>;
171
172impl<'a, 'x, 'y, F> BLASBuilder<'y, F, Ix1> for HEMV_Builder<'a, 'x, 'y, F>
173where
174 F: HEMVNum,
175{
176 fn run(self) -> Result<ArrayOut1<'y, F>, BLASError> {
177 let obj = self.build()?;
179
180 let layout_a = get_layout_array2(&obj.a);
181
182 if layout_a.is_fpref() {
183 return obj.driver()?.run_blas();
185 } else {
186 let a_cow = obj.a.to_row_layout()?;
188 if F::is_complex() {
189 let x = obj.x.mapv(F::conj);
190 let y = obj.y.map(|mut y| {
191 y.mapv_inplace(F::conj);
192 y
193 });
194 let obj = HEMV_ {
195 a: a_cow.t(),
196 x: x.view(),
197 y,
198 uplo: obj.uplo.flip()?,
199 alpha: F::conj(obj.alpha),
200 beta: F::conj(obj.beta),
201 ..obj
202 };
203 let mut y = obj.driver()?.run_blas()?;
204 y.view_mut().mapv_inplace(F::conj);
205 return Ok(y);
206 } else {
207 let obj = HEMV_ { a: a_cow.t(), uplo: obj.uplo.flip()?, ..obj };
208 return obj.driver()?.run_blas();
209 }
210 }
211 }
212}
213
214