1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6pub trait GEMVNum: BLASFloat {
9 unsafe fn gemv(
10 trans: *const c_char,
11 m: *const blas_int,
12 n: *const blas_int,
13 alpha: *const Self,
14 a: *const Self,
15 lda: *const blas_int,
16 x: *const Self,
17 incx: *const blas_int,
18 beta: *const Self,
19 y: *mut Self,
20 incy: *const blas_int,
21 );
22}
23
/// Implements [`GEMVNum`] for the scalar type `$float` by forwarding every argument
/// unchanged to the FFI symbol `ffi::$blas_fn`.
macro_rules! impl_func {
    ($float: ty, $blas_fn: ident) => {
        impl GEMVNum for $float {
            unsafe fn gemv(
                trans: *const c_char,
                m: *const blas_int,
                n: *const blas_int,
                alpha: *const Self,
                a: *const Self,
                lda: *const blas_int,
                x: *const Self,
                incx: *const blas_int,
                beta: *const Self,
                y: *mut Self,
                incy: *const blas_int,
            ) {
                // Pure pass-through: the pointer contract is the caller's responsibility.
                ffi::$blas_fn(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
            }
        }
    };
}
45
46impl_func!(f32, sgemv_);
47impl_func!(f64, dgemv_);
48impl_func!(c32, cgemv_);
49impl_func!(c64, zgemv_);
50
51pub struct GEMV_Driver<'a, 'x, 'y, F>
56where
57 F: BLASFloat,
58{
59 trans: c_char,
60 m: blas_int,
61 n: blas_int,
62 alpha: F,
63 a: ArrayView2<'a, F>,
64 lda: blas_int,
65 x: ArrayView1<'x, F>,
66 incx: blas_int,
67 beta: F,
68 y: ArrayOut1<'y, F>,
69 incy: blas_int,
70}
71
72impl<'a, 'x, 'y, F> BLASDriver<'y, F, Ix1> for GEMV_Driver<'a, 'x, 'y, F>
73where
74 F: GEMVNum,
75{
76 fn run_blas(self) -> Result<ArrayOut1<'y, F>, BLASError> {
77 let Self { trans, m, n, alpha, a, lda, x, incx, beta, mut y, incy } = self;
78 let a_ptr = a.as_ptr();
79 let x_ptr = x.as_ptr();
80 let y_ptr = y.get_data_mut_ptr();
81
82 if m == 0 || n == 0 {
85 return Ok(y);
86 }
87
88 unsafe {
89 F::gemv(&trans, &m, &n, &alpha, a_ptr, &lda, x_ptr, &incx, &beta, y_ptr, &incy);
90 }
91 return Ok(y);
92 }
93}
94
/// User-facing arguments of a general matrix-vector product, `y = alpha * op(A) * x + beta * y`
/// in BLAS `?gemv` terms.
///
/// `derive_builder` generates the owned-pattern builder `GEMV_Builder` from this struct. Its
/// `build` errors are reported as `BLASError`.
#[derive(Builder)]
#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
pub struct GEMV_<'a, 'x, 'y, F>
where
    F: GEMVNum,
{
    /// Matrix `A` of shape `(m, n)`.
    pub a: ArrayView2<'a, F>,
    /// Input vector `x`. Its length must be `n` for no-transpose and `m` otherwise.
    pub x: ArrayView1<'x, F>,

    /// Optional output vector `y`, updated in place. When omitted, a zero-filled array of
    /// the required length is allocated and returned.
    #[builder(setter(into, strip_option), default = "None")]
    pub y: Option<ArrayViewMut1<'y, F>>,
    /// Scalar applied to `op(A) * x`. Defaults to one.
    #[builder(setter(into), default = "F::one()")]
    pub alpha: F,
    /// Scalar applied to the incoming `y`. Defaults to zero.
    #[builder(setter(into), default = "F::zero()")]
    pub beta: F,
    /// Which operator `op` to apply to `A`. Defaults to `BLASNoTrans`.
    #[builder(setter(into), default = "BLASNoTrans")]
    pub trans: BLASTranspose,
}
117
118impl<'a, 'x, 'y, F> BLASBuilder_<'y, F, Ix1> for GEMV_<'a, 'x, 'y, F>
119where
120 F: GEMVNum,
121{
122 fn driver(self) -> Result<GEMV_Driver<'a, 'x, 'y, F>, BLASError> {
123 let Self { a, x, y, alpha, beta, trans } = self;
124
125 let layout_a = get_layout_array2(&a);
127 assert!(layout_a.is_fpref());
128
129 let (m, n) = a.dim();
131 let lda = a.stride_of(Axis(1));
132 let incx = x.stride_of(Axis(0));
133
134 match trans {
136 BLASNoTrans => blas_assert_eq!(x.len_of(Axis(0)), n, InvalidDim)?,
137 BLASTrans | BLASConjTrans => blas_assert_eq!(x.len_of(Axis(0)), m, InvalidDim)?,
138 _ => blas_invalid!(trans)?,
139 };
140
141 let y = match y {
143 Some(y) => {
144 match trans {
145 BLASNoTrans => blas_assert_eq!(y.len_of(Axis(0)), m, InvalidDim)?,
146 BLASTrans | BLASConjTrans => blas_assert_eq!(y.len_of(Axis(0)), n, InvalidDim)?,
147 _ => blas_invalid!(trans)?,
148 };
149 ArrayOut1::ViewMut(y)
150 },
151 None => ArrayOut1::Owned(Array1::zeros(match trans {
152 BLASNoTrans => m,
153 BLASTrans | BLASConjTrans => n,
154 _ => blas_invalid!(trans)?,
155 })),
156 };
157 let incy = y.view().stride_of(Axis(0));
158
159 let driver = GEMV_Driver {
161 trans: trans.try_into()?,
162 m: m.try_into()?,
163 n: n.try_into()?,
164 alpha,
165 a,
166 lda: lda.try_into()?,
167 x,
168 incx: incx.try_into()?,
169 beta,
170 y,
171 incy: incy.try_into()?,
172 };
173 return Ok(driver);
174 }
175}
176
177pub type GEMV<'a, 'x, 'y, F> = GEMV_Builder<'a, 'x, 'y, F>;
182pub type SGEMV<'a, 'x, 'y> = GEMV<'a, 'x, 'y, f32>;
183pub type DGEMV<'a, 'x, 'y> = GEMV<'a, 'x, 'y, f64>;
184pub type CGEMV<'a, 'x, 'y> = GEMV<'a, 'x, 'y, c32>;
185pub type ZGEMV<'a, 'x, 'y> = GEMV<'a, 'x, 'y, c64>;
186
187impl<'a, 'x, 'y, F> BLASBuilder<'y, F, Ix1> for GEMV_Builder<'a, 'x, 'y, F>
188where
189 F: GEMVNum,
190{
191 fn run(self) -> Result<ArrayOut1<'y, F>, BLASError> {
192 let obj = self.build()?;
194
195 let layout_a = get_layout_array2(&obj.a);
196
197 if layout_a.is_fpref() {
198 return obj.driver()?.run_blas();
200 } else {
201 let a_cow = obj.a.to_row_layout()?;
203 match obj.trans {
204 BLASNoTrans => {
205 let obj = GEMV_ { a: a_cow.t(), trans: BLASTrans, ..obj };
207 return obj.driver()?.run_blas();
208 },
209 BLASTrans => {
210 let obj = GEMV_ { a: a_cow.t(), trans: BLASNoTrans, ..obj };
212 return obj.driver()?.run_blas();
213 },
214 BLASConjTrans => {
215 let x = obj.x.mapv(F::conj);
217 let y = obj.y.map(|mut y| {
218 y.mapv_inplace(F::conj);
219 y
220 });
221 let obj = GEMV_ {
222 a: a_cow.t(),
223 trans: BLASNoTrans,
224 x: x.view(),
225 y,
226 alpha: F::conj(obj.alpha),
227 beta: F::conj(obj.beta),
228 };
229 let mut y = obj.driver()?.run_blas()?;
230 y.view_mut().mapv_inplace(F::conj);
231 return Ok(y);
232 },
233 _ => return blas_invalid!(&obj.trans)?,
234 };
235 }
236 }
237}
238
239