1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
/// Scalar types that have a BLAS band matrix-vector product routine:
/// Hermitian (`?hbmv`) for complex types, symmetric (`?sbmv`) for real types.
pub trait HBMVNum: BLASFloat {
    /// Raw FFI entry point taking the Fortran BLAS `?hbmv` argument list
    /// (`UPLO, N, K, ALPHA, A, LDA, X, INCX, BETA, Y, INCY`), all by pointer.
    ///
    /// # Safety
    ///
    /// Every pointer is forwarded unchanged to the underlying BLAS routine, so
    /// each one must be valid for that routine given the pointed-to `n`, `k`,
    /// `lda`, `incx` and `incy`; `y` must be writable.
    unsafe fn hbmv(
        uplo: *const c_char,
        n: *const blas_int,
        k: *const blas_int,
        alpha: *const Self,
        a: *const Self,
        lda: *const blas_int,
        x: *const Self,
        incx: *const blas_int,
        beta: *const Self,
        y: *mut Self,
        incy: *const blas_int,
    );
}
23
/// Implements `HBMVNum` for the scalar type `$scalar` by forwarding every
/// argument, unchanged and in order, to `crate::ffi::$blas_routine`.
macro_rules! impl_func {
    ($scalar: ty, $blas_routine: ident) => {
        impl HBMVNum for $scalar {
            unsafe fn hbmv(
                uplo: *const c_char,
                n: *const blas_int,
                k: *const blas_int,
                alpha: *const Self,
                a: *const Self,
                lda: *const blas_int,
                x: *const Self,
                incx: *const blas_int,
                beta: *const Self,
                y: *mut Self,
                incy: *const blas_int,
            ) {
                // Thin pass-through: the caller upholds the BLAS pointer contract.
                ffi::$blas_routine(uplo, n, k, alpha, a, lda, x, incx, beta, y, incy);
            }
        }
    };
}
45
// Real scalars use the symmetric band routines (`?sbmv`); complex scalars use
// the Hermitian band routines (`?hbmv`).
impl_func!(f32, ssbmv_);
impl_func!(f64, dsbmv_);
impl_func!(c32, chbmv_);
impl_func!(c64, zhbmv_);
50
/// Argument bundle for one BLAS `?hbmv`/`?sbmv` call.
///
/// Produced by `HBMV_::driver` and consumed by `run_blas`.
pub struct HBMV_Driver<'a, 'x, 'y, F>
where
    F: HBMVNum,
{
    /// Triangle flag passed as BLAS `UPLO` (converted from `BLASUpLo`).
    uplo: c_char,
    /// Matrix order: the number of columns of the band storage `a`.
    n: blas_int,
    /// Number of off-diagonals: the number of rows of `a` minus one.
    k: blas_int,
    alpha: F,
    /// Column-major band storage of shape `(k + 1, n)`.
    a: ArrayView2<'a, F>,
    /// Stride of `a` along axis 1 (its column stride).
    lda: blas_int,
    x: ArrayView1<'x, F>,
    /// Stride of `x` along its only axis.
    incx: blas_int,
    beta: F,
    /// Output vector: either the caller's view or a freshly allocated array.
    y: ArrayOut1<'y, F>,
    /// Stride of `y` along its only axis.
    incy: blas_int,
}
71
72impl<'a, 'x, 'y, F> BLASDriver<'y, F, Ix1> for HBMV_Driver<'a, 'x, 'y, F>
73where
74 F: HBMVNum,
75{
76 fn run_blas(self) -> Result<ArrayOut1<'y, F>, BLASError> {
77 let Self { uplo, n, k, alpha, a, lda, x, incx, beta, mut y, incy, .. } = self;
78 let a_ptr = a.as_ptr();
79 let x_ptr = x.as_ptr();
80 let y_ptr = y.get_data_mut_ptr();
81
82 if n == 0 {
85 return Ok(y);
86 }
87
88 unsafe {
89 F::hbmv(&uplo, &n, &k, &alpha, a_ptr, &lda, x_ptr, &incx, &beta, y_ptr, &incy);
90 }
91 return Ok(y);
92 }
93}
94
/// Task description for the band matrix-vector product
/// `y = alpha * A * x + beta * y`, with `A` Hermitian (complex scalars) or
/// symmetric (real scalars) and stored in band form.
#[derive(Builder)]
#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
pub struct HBMV_<'a, 'x, 'y, F>
where
    F: HBMVNum,
{
    /// Band storage of `A`: shape `(k + 1, n)` when handled as column-major,
    /// `(n, k + 1)` when handled as row-major (it is transposed in that case).
    pub a: ArrayView2<'a, F>,
    /// Input vector; its length must equal `n`.
    pub x: ArrayView1<'x, F>,

    /// Optional output vector of length `n`; a zero vector is allocated when omitted.
    #[builder(setter(into, strip_option), default = "None")]
    pub y: Option<ArrayViewMut1<'y, F>>,
    /// Scale applied to `A * x` (defaults to one).
    #[builder(setter(into), default = "F::one()")]
    pub alpha: F,
    /// Scale applied to the incoming `y` (defaults to zero).
    #[builder(setter(into), default = "F::zero()")]
    pub beta: F,
    /// Triangle of `A` held in the band storage (defaults to upper).
    #[builder(setter(into), default = "BLASUpper")]
    pub uplo: BLASUpLo,
    /// Requested layout for `a`. It is passed, together with the layout
    /// detected from `a`, to `get_layout_row_preferred`.
    #[builder(setter(into, strip_option), default = "None")]
    pub layout: Option<BLASLayout>,
}
119
120impl<'a, 'x, 'y, F> BLASBuilder_<'y, F, Ix1> for HBMV_<'a, 'x, 'y, F>
121where
122 F: HBMVNum,
123{
124 fn driver(self) -> Result<HBMV_Driver<'a, 'x, 'y, F>, BLASError> {
125 let Self { a, x, y, alpha, beta, uplo, layout, .. } = self;
126
127 let layout_a = get_layout_array2(&a);
129 assert!(layout_a.is_fpref());
130 assert!(layout == Some(BLASLayout::ColMajor));
131
132 let (k_, n) = a.dim();
134 blas_assert!(k_ > 0, InvalidDim, "Rows of input `a` must larger than zero.")?;
135 let k = k_ - 1;
136 let lda = a.stride_of(Axis(1));
137 let incx = x.stride_of(Axis(0));
138
139 blas_assert_eq!(x.len_of(Axis(0)), n, InvalidDim)?;
141
142 let y = match y {
144 Some(y) => {
145 blas_assert_eq!(y.len_of(Axis(0)), n, InvalidDim)?;
146 ArrayOut1::ViewMut(y)
147 },
148 None => ArrayOut1::Owned(Array1::zeros(n)),
149 };
150 let incy = y.view().stride_of(Axis(0));
151
152 let driver = HBMV_Driver {
154 uplo: uplo.try_into()?,
155 n: n.try_into()?,
156 k: k.try_into()?,
157 alpha,
158 a,
159 lda: lda.try_into()?,
160 x,
161 incx: incx.try_into()?,
162 beta,
163 y,
164 incy: incy.try_into()?,
165 };
166 return Ok(driver);
167 }
168}
169
/// Builder for the band matrix-vector product `y = alpha * A * x + beta * y`
/// (Hermitian `A` for complex scalars, symmetric `A` for real scalars).
pub type HBMV<'a, 'x, 'y, F> = HBMV_Builder<'a, 'x, 'y, F>;
/// `HBMV` for `f32`, backed by BLAS `ssbmv`.
pub type SSBMV<'a, 'x, 'y> = HBMV<'a, 'x, 'y, f32>;
/// `HBMV` for `f64`, backed by BLAS `dsbmv`.
pub type DSBMV<'a, 'x, 'y> = HBMV<'a, 'x, 'y, f64>;
/// `HBMV` for `c32`, backed by BLAS `chbmv`.
pub type CHBMV<'a, 'x, 'y> = HBMV<'a, 'x, 'y, c32>;
/// `HBMV` for `c64`, backed by BLAS `zhbmv`.
pub type ZHBMV<'a, 'x, 'y> = HBMV<'a, 'x, 'y, c64>;
179
180impl<'a, 'x, 'y, F> BLASBuilder<'y, F, Ix1> for HBMV_Builder<'a, 'x, 'y, F>
181where
182 F: HBMVNum,
183{
184 fn run(self) -> Result<ArrayOut1<'y, F>, BLASError> {
185 let obj = self.build()?;
187
188 let layout_a = get_layout_array2(&obj.a);
189 let layout = get_layout_row_preferred(&[obj.layout, Some(layout_a)], &[]);
190
191 if layout == BLASColMajor {
192 let a_cow = obj.a.to_col_layout()?;
194 let obj = HBMV_ { a: a_cow.view(), layout: Some(BLASColMajor), ..obj };
195 return obj.driver()?.run_blas();
196 } else {
197 let a_cow = obj.a.to_row_layout()?;
199 if F::is_complex() {
200 let x = obj.x.mapv(F::conj);
201 let y = obj.y.map(|mut y| {
202 y.mapv_inplace(F::conj);
203 y
204 });
205 let obj = HBMV_ {
206 a: a_cow.t(),
207 x: x.view(),
208 y,
209 uplo: obj.uplo.flip()?,
210 alpha: F::conj(obj.alpha),
211 beta: F::conj(obj.beta),
212 layout: Some(BLASColMajor),
213 ..obj
214 };
215 let mut y = obj.driver()?.run_blas()?;
216 y.view_mut().mapv_inplace(F::conj);
217 return Ok(y);
218 } else {
219 let obj = HBMV_ { a: a_cow.t(), uplo: obj.uplo.flip()?, layout: Some(BLASColMajor), ..obj };
220 return obj.driver()?.run_blas();
221 }
222 }
223 }
224}
225
226