1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6pub trait GBMVNum: BLASFloat {
9 unsafe fn gbmv(
10 trans: *const c_char,
11 m: *const blas_int,
12 n: *const blas_int,
13 kl: *const blas_int,
14 ku: *const blas_int,
15 alpha: *const Self,
16 a: *const Self,
17 lda: *const blas_int,
18 x: *const Self,
19 incx: *const blas_int,
20 beta: *const Self,
21 y: *mut Self,
22 incy: *const blas_int,
23 );
24}
25
26macro_rules! impl_func {
27 ($type: ty, $func: ident) => {
28 impl GBMVNum for $type {
29 unsafe fn gbmv(
30 trans: *const c_char,
31 m: *const blas_int,
32 n: *const blas_int,
33 kl: *const blas_int,
34 ku: *const blas_int,
35 alpha: *const Self,
36 a: *const Self,
37 lda: *const blas_int,
38 x: *const Self,
39 incx: *const blas_int,
40 beta: *const Self,
41 y: *mut Self,
42 incy: *const blas_int,
43 ) {
44 ffi::$func(trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy);
45 }
46 }
47 };
48}
49
50impl_func!(f32, sgbmv_);
51impl_func!(f64, dgbmv_);
52impl_func!(c32, cgbmv_);
53impl_func!(c64, zgbmv_);
54
55pub struct GBMV_Driver<'a, 'x, 'y, F>
60where
61 F: GBMVNum,
62{
63 trans: c_char,
64 m: blas_int,
65 n: blas_int,
66 kl: blas_int,
67 ku: blas_int,
68 alpha: F,
69 a: ArrayView2<'a, F>,
70 lda: blas_int,
71 x: ArrayView1<'x, F>,
72 incx: blas_int,
73 beta: F,
74 y: ArrayOut1<'y, F>,
75 incy: blas_int,
76}
77
78impl<'a, 'x, 'y, F> BLASDriver<'y, F, Ix1> for GBMV_Driver<'a, 'x, 'y, F>
79where
80 F: GBMVNum,
81{
82 fn run_blas(self) -> Result<ArrayOut1<'y, F>, BLASError> {
83 let Self { trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, mut y, incy } = self;
84 let a_ptr = a.as_ptr();
85 let x_ptr = x.as_ptr();
86 let y_ptr = y.get_data_mut_ptr();
87
88 if n == 0 {
91 return Ok(y);
92 }
93
94 unsafe {
95 F::gbmv(&trans, &m, &n, &kl, &ku, &alpha, a_ptr, &lda, x_ptr, &incx, &beta, y_ptr, &incy);
96 }
97 return Ok(y);
98 }
99}
100
101#[derive(Builder)]
106#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
107pub struct GBMV_<'a, 'x, 'y, F>
108where
109 F: GBMVNum,
110{
111 pub a: ArrayView2<'a, F>,
112 pub x: ArrayView1<'x, F>,
113 pub m: usize,
114 pub kl: usize,
115
116 #[builder(setter(into, strip_option), default = "None")]
117 pub y: Option<ArrayViewMut1<'y, F>>,
118 #[builder(setter(into), default = "F::one()")]
119 pub alpha: F,
120 #[builder(setter(into), default = "F::zero()")]
121 pub beta: F,
122 #[builder(setter(into), default = "BLASNoTrans")]
123 pub trans: BLASTranspose,
124 #[builder(setter(into, strip_option), default = "None")]
125 pub layout: Option<BLASLayout>,
126}
127
128impl<'a, 'x, 'y, F> BLASBuilder_<'y, F, Ix1> for GBMV_<'a, 'x, 'y, F>
129where
130 F: GBMVNum,
131{
132 fn driver(self) -> Result<GBMV_Driver<'a, 'x, 'y, F>, BLASError> {
133 let Self { a, x, m, kl, y, alpha, beta, trans, layout } = self;
134
135 let layout_a = get_layout_array2(&a);
137 assert!(layout_a.is_fpref());
138 assert!(layout == Some(BLASLayout::ColMajor));
139
140 let (k, n) = a.dim();
142 let lda = a.stride_of(Axis(1));
143 let incx = x.stride_of(Axis(0));
144
145 blas_assert!(k > kl, InvalidDim)?;
147 blas_assert!(m >= k, InvalidDim)?;
148 let ku = k - 1 - kl;
149 match trans {
150 BLASNoTrans => blas_assert_eq!(x.len_of(Axis(0)), n, InvalidDim)?,
151 BLASTrans | BLASConjTrans => blas_assert_eq!(x.len_of(Axis(0)), m, InvalidDim)?,
152 _ => blas_invalid!(trans)?,
153 };
154
155 let y = match y {
157 Some(y) => {
158 match trans {
159 BLASNoTrans => blas_assert_eq!(y.len_of(Axis(0)), m, InvalidDim)?,
160 BLASTrans | BLASConjTrans => blas_assert_eq!(y.len_of(Axis(0)), n, InvalidDim)?,
161 _ => blas_invalid!(trans)?,
162 };
163 ArrayOut1::ViewMut(y)
164 },
165 None => ArrayOut1::Owned(Array1::zeros(match trans {
166 BLASNoTrans => m,
167 BLASTrans | BLASConjTrans => n,
168 _ => blas_invalid!(trans)?,
169 })),
170 };
171 let incy = y.view().stride_of(Axis(0));
172
173 let driver = GBMV_Driver {
175 trans: trans.try_into()?,
176 m: m.try_into()?,
177 n: n.try_into()?,
178 kl: kl.try_into()?,
179 ku: ku.try_into()?,
180 alpha,
181 a,
182 lda: lda.try_into()?,
183 x,
184 incx: incx.try_into()?,
185 beta,
186 y,
187 incy: incy.try_into()?,
188 };
189 return Ok(driver);
190 }
191}
192
193pub type GBMV<'a, 'x, 'y, F> = GBMV_Builder<'a, 'x, 'y, F>;
198pub type SGBMV<'a, 'x, 'y> = GBMV<'a, 'x, 'y, f32>;
199pub type DGBMV<'a, 'x, 'y> = GBMV<'a, 'x, 'y, f64>;
200pub type CGBMV<'a, 'x, 'y> = GBMV<'a, 'x, 'y, c32>;
201pub type ZGBMV<'a, 'x, 'y> = GBMV<'a, 'x, 'y, c64>;
202
203impl<'a, 'x, 'y, F> BLASBuilder<'y, F, Ix1> for GBMV_Builder<'a, 'x, 'y, F>
204where
205 F: GBMVNum,
206{
207 fn run(self) -> Result<ArrayOut1<'y, F>, BLASError> {
208 let GBMV_ { a, x, m, kl, y, alpha, beta, trans, layout } = self.build()?;
210
211 let layout_a = get_layout_array2(&a);
212 let layout = match layout {
213 Some(layout) => layout,
214 None => match layout_a {
215 BLASLayout::Sequential => BLASColMajor,
216 BLASRowMajor => BLASRowMajor,
217 BLASColMajor => BLASColMajor,
218 _ => blas_raise!(InvalidFlag, "Without defining layout, this function checks layout of input matrix `a` but it is not contiguous.")?,
219 }
220 };
221
222 if layout == BLASColMajor {
223 let a_cow = a.to_col_layout()?;
225 let obj = GBMV_ { a: a_cow.view(), x, m, kl, y, alpha, beta, trans, layout: Some(BLASColMajor) };
226 return obj.driver()?.run_blas();
227 } else {
228 let a_cow = a.to_row_layout()?;
230 let k = a_cow.len_of(Axis(1));
231 blas_assert!(k > kl, InvalidDim)?;
232 let ku = k - kl - 1;
233 match trans {
234 BLASNoTrans => {
235 let obj = GBMV_ {
237 a: a_cow.t(),
238 x,
239 m,
240 kl: ku,
241 y,
242 alpha,
243 beta,
244 trans: BLASTrans,
245 layout: Some(BLASColMajor),
246 };
247 return obj.driver()?.run_blas();
248 },
249 BLASTrans => {
250 let obj = GBMV_ {
252 a: a_cow.t(),
253 x,
254 m,
255 kl: ku,
256 y,
257 alpha,
258 beta,
259 trans: BLASNoTrans,
260 layout: Some(BLASColMajor),
261 };
262 return obj.driver()?.run_blas();
263 },
264 BLASConjTrans => {
265 let x = x.mapv(F::conj);
267 let y = y.map(|mut y| {
268 y.mapv_inplace(F::conj);
269 y
270 });
271 let obj = GBMV_ {
272 a: a_cow.t(),
273 x: x.view(),
274 m,
275 kl: ku,
276 y,
277 alpha: F::conj(alpha),
278 beta: F::conj(beta),
279 trans: BLASNoTrans,
280 layout: Some(BLASColMajor),
281 };
282 let mut y = obj.driver()?.run_blas()?;
283 y.view_mut().mapv_inplace(F::conj);
284 return Ok(y);
285 },
286 _ => return blas_invalid!(trans)?,
287 }
288 }
289 }
290}
291
292