1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
/// Scalar types with a BLAS packed matrix-vector routine.
///
/// In this file the real types (`f32`, `f64`) are wired to the symmetric routines
/// `sspmv_`/`dspmv_`, and the complex types (`c32`, `c64`) to the Hermitian routines
/// `chpmv_`/`zhpmv_` (see the `impl_func!` invocations).
pub trait HPMVNum: BLASFloat {
    /// Raw call into the Fortran BLAS routine; every argument is passed by pointer,
    /// following the Fortran calling convention.
    ///
    /// # Safety
    ///
    /// The caller must uphold the reference-BLAS contract of `?spmv`/`?hpmv`:
    /// every pointer must be valid for the duration of the call; `ap` must point to at
    /// least `n * (n + 1) / 2` packed elements; `x` and `y` must be addressable for `n`
    /// elements with increments `*incx` and `*incy` (non-zero); `y` must be writable and
    /// must not alias `ap` or `x`; `*uplo` must be a flag the routine accepts.
    unsafe fn hpmv(
        uplo: *const c_char,
        n: *const blas_int,
        alpha: *const Self,
        ap: *const Self,
        x: *const Self,
        incx: *const blas_int,
        beta: *const Self,
        y: *mut Self,
        incy: *const blas_int,
    );
}
21
22macro_rules! impl_func {
23 ($type: ty, $func: ident) => {
24 impl HPMVNum for $type {
25 unsafe fn hpmv(
26 uplo: *const c_char,
27 n: *const blas_int,
28 alpha: *const Self,
29 ap: *const Self,
30 x: *const Self,
31 incx: *const blas_int,
32 beta: *const Self,
33 y: *mut Self,
34 incy: *const blas_int,
35 ) {
36 ffi::$func(uplo, n, alpha, ap, x, incx, beta, y, incy);
37 }
38 }
39 };
40}
41
42impl_func!(f32, sspmv_);
43impl_func!(f64, dspmv_);
44impl_func!(c32, chpmv_);
45impl_func!(c64, zhpmv_);
46
47pub struct HPMV_Driver<'a, 'x, 'y, F>
52where
53 F: HPMVNum,
54{
55 uplo: c_char,
56 n: blas_int,
57 alpha: F,
58 ap: ArrayView1<'a, F>,
59 x: ArrayView1<'x, F>,
60 incx: blas_int,
61 beta: F,
62 y: ArrayOut1<'y, F>,
63 incy: blas_int,
64}
65
66impl<'a, 'x, 'y, F> BLASDriver<'y, F, Ix1> for HPMV_Driver<'a, 'x, 'y, F>
67where
68 F: HPMVNum,
69{
70 fn run_blas(self) -> Result<ArrayOut1<'y, F>, BLASError> {
71 let Self { uplo, n, alpha, ap, x, incx, beta, mut y, incy, .. } = self;
72 let ap_ptr = ap.as_ptr();
73 let x_ptr = x.as_ptr();
74 let y_ptr = y.get_data_mut_ptr();
75
76 if n == 0 {
79 return Ok(y);
80 }
81
82 unsafe {
83 F::hpmv(&uplo, &n, &alpha, ap_ptr, x_ptr, &incx, &beta, y_ptr, &incy);
84 }
85 return Ok(y);
86 }
87}
88
/// Arguments for the packed symmetric (real) / Hermitian (complex) matrix-vector product
/// `y = alpha * A * x + beta * y`, i.e. BLAS `?spmv` / `?hpmv`.
///
/// Constructed through the derive_builder-generated `HPMV_Builder` (exported as `HPMV`),
/// using the owned builder pattern; `build()` reports errors as `BLASError`.
#[derive(Builder)]
#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
pub struct HPMV_<'a, 'x, 'y, F>
where
    F: HPMVNum,
{
    /// Packed triangle of `A`; its length must be `n * (n + 1) / 2`, where `n = x.len()`.
    pub ap: ArrayView1<'a, F>,
    /// Input vector; its length defines the matrix order `n`.
    pub x: ArrayView1<'x, F>,

    /// Optional output vector of length `n`, written in place. When omitted, a zero
    /// vector is allocated and returned.
    #[builder(setter(into, strip_option), default = "None")]
    pub y: Option<ArrayViewMut1<'y, F>>,
    /// Scale factor for `A * x` (defaults to one).
    #[builder(setter(into), default = "F::one()")]
    pub alpha: F,
    /// Scale factor for the incoming `y` (defaults to zero).
    #[builder(setter(into), default = "F::zero()")]
    pub beta: F,
    /// Which triangle of `A` is stored in `ap`, interpreted in the chosen `layout`
    /// (defaults to `BLASUpper`).
    #[builder(setter(into), default = "BLASUpper")]
    pub uplo: BLASUpLo,
    /// Storage convention of `ap`; `None` is treated as row-major by `run`.
    #[builder(setter(into, strip_option), default = "None")]
    pub layout: Option<BLASLayout>,
}
113
114impl<'a, 'x, 'y, F> BLASBuilder_<'y, F, Ix1> for HPMV_<'a, 'x, 'y, F>
115where
116 F: HPMVNum,
117{
118 fn driver(self) -> Result<HPMV_Driver<'a, 'x, 'y, F>, BLASError> {
119 let Self { ap, x, y, alpha, beta, uplo, layout, .. } = self;
120
121 let incap = ap.stride_of(Axis(0));
123 assert!(incap <= 1);
124 assert_eq!(layout, Some(BLASColMajor));
125
126 let np = ap.len_of(Axis(0));
128 let n = x.len_of(Axis(0));
129 let incx = x.stride_of(Axis(0));
130
131 blas_assert_eq!(np, n * (n + 1) / 2, InvalidDim)?;
133
134 let y = match y {
136 Some(y) => {
137 blas_assert_eq!(y.len_of(Axis(0)), n, InvalidDim)?;
138 ArrayOut1::ViewMut(y)
139 },
140 None => ArrayOut1::Owned(Array1::zeros(n)),
141 };
142 let incy = y.view().stride_of(Axis(0));
143
144 let driver = HPMV_Driver {
146 uplo: uplo.try_into()?,
147 n: n.try_into()?,
148 alpha,
149 ap,
150 x,
151 incx: incx.try_into()?,
152 beta,
153 y,
154 incy: incy.try_into()?,
155 };
156 return Ok(driver);
157 }
158}
159
160pub type HPMV<'a, 'x, 'y, F> = HPMV_Builder<'a, 'x, 'y, F>;
165pub type SSPMV<'a, 'x, 'y> = HPMV<'a, 'x, 'y, f32>;
166pub type DSPMV<'a, 'x, 'y> = HPMV<'a, 'x, 'y, f64>;
167pub type CHPMV<'a, 'x, 'y> = HPMV<'a, 'x, 'y, c32>;
168pub type ZHPMV<'a, 'x, 'y> = HPMV<'a, 'x, 'y, c64>;
169
170impl<'a, 'x, 'y, F> BLASBuilder<'y, F, Ix1> for HPMV_Builder<'a, 'x, 'y, F>
171where
172 F: HPMVNum,
173{
174 fn run(self) -> Result<ArrayOut1<'y, F>, BLASError> {
175 let obj = self.build()?;
177
178 let layout = obj.layout.unwrap_or(BLASRowMajor);
179
180 if layout == BLASColMajor {
181 let ap_cow = obj.ap.to_seq_layout()?;
183 let obj = HPMV_ { ap: ap_cow.view(), layout: Some(BLASColMajor), ..obj };
184 return obj.driver()?.run_blas();
185 } else {
186 let ap_cow = obj.ap.to_seq_layout()?;
188 if F::is_complex() {
189 let x = obj.x.mapv(F::conj);
190 let y = obj.y.map(|mut y| {
191 y.mapv_inplace(F::conj);
192 y
193 });
194 let obj = HPMV_ {
195 ap: ap_cow.view(),
196 x: x.view(),
197 y,
198 uplo: obj.uplo.flip()?,
199 alpha: F::conj(obj.alpha),
200 beta: F::conj(obj.beta),
201 layout: Some(BLASColMajor),
202 ..obj
203 };
204 let mut y = obj.driver()?.run_blas()?;
205 y.view_mut().mapv_inplace(F::conj);
206 return Ok(y);
207 } else {
208 let obj =
209 HPMV_ { ap: ap_cow.view(), uplo: obj.uplo.flip()?, layout: Some(BLASColMajor), ..obj };
210 return obj.driver()?.run_blas();
211 }
212 }
213 }
214}
215
216