blas_array2/blas2/gemv.rs

1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6/* #region BLAS func */
7
8pub trait GEMVNum: BLASFloat {
9    unsafe fn gemv(
10        trans: *const c_char,
11        m: *const blas_int,
12        n: *const blas_int,
13        alpha: *const Self,
14        a: *const Self,
15        lda: *const blas_int,
16        x: *const Self,
17        incx: *const blas_int,
18        beta: *const Self,
19        y: *mut Self,
20        incy: *const blas_int,
21    );
22}
23
/// Implements [`GEMVNum`] for the scalar type `$float` by forwarding every call
/// verbatim to the Fortran symbol `$symbol` in `ffi`.
macro_rules! impl_func {
    ($float: ty, $symbol: ident) => {
        impl GEMVNum for $float {
            unsafe fn gemv(
                trans: *const c_char,
                m: *const blas_int,
                n: *const blas_int,
                alpha: *const Self,
                a: *const Self,
                lda: *const blas_int,
                x: *const Self,
                incx: *const blas_int,
                beta: *const Self,
                y: *mut Self,
                incy: *const blas_int,
            ) {
                // Pure pass-through: the caller of `gemv` upholds the BLAS contract.
                ffi::$symbol(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
            }
        }
    };
}
45
46impl_func!(f32, sgemv_);
47impl_func!(f64, dgemv_);
48impl_func!(c32, cgemv_);
49impl_func!(c64, zgemv_);
50
51/* #endregion */
52
53/* #region BLAS driver */
54
55pub struct GEMV_Driver<'a, 'x, 'y, F>
56where
57    F: BLASFloat,
58{
59    trans: c_char,
60    m: blas_int,
61    n: blas_int,
62    alpha: F,
63    a: ArrayView2<'a, F>,
64    lda: blas_int,
65    x: ArrayView1<'x, F>,
66    incx: blas_int,
67    beta: F,
68    y: ArrayOut1<'y, F>,
69    incy: blas_int,
70}
71
72impl<'a, 'x, 'y, F> BLASDriver<'y, F, Ix1> for GEMV_Driver<'a, 'x, 'y, F>
73where
74    F: GEMVNum,
75{
76    fn run_blas(self) -> Result<ArrayOut1<'y, F>, BLASError> {
77        let Self { trans, m, n, alpha, a, lda, x, incx, beta, mut y, incy } = self;
78        let a_ptr = a.as_ptr();
79        let x_ptr = x.as_ptr();
80        let y_ptr = y.get_data_mut_ptr();
81
82        // assuming dimension checks has been performed
83        // unconditionally return Ok if output does not contain anything
84        if m == 0 || n == 0 {
85            return Ok(y);
86        }
87
88        unsafe {
89            F::gemv(&trans, &m, &n, &alpha, a_ptr, &lda, x_ptr, &incx, &beta, y_ptr, &incy);
90        }
91        return Ok(y);
92    }
93}
94
95/* #endregion */
96
97/* #region BLAS builder */
98
99#[derive(Builder)]
100#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
101pub struct GEMV_<'a, 'x, 'y, F>
102where
103    F: GEMVNum,
104{
105    pub a: ArrayView2<'a, F>,
106    pub x: ArrayView1<'x, F>,
107
108    #[builder(setter(into, strip_option), default = "None")]
109    pub y: Option<ArrayViewMut1<'y, F>>,
110    #[builder(setter(into), default = "F::one()")]
111    pub alpha: F,
112    #[builder(setter(into), default = "F::zero()")]
113    pub beta: F,
114    #[builder(setter(into), default = "BLASNoTrans")]
115    pub trans: BLASTranspose,
116}
117
118impl<'a, 'x, 'y, F> BLASBuilder_<'y, F, Ix1> for GEMV_<'a, 'x, 'y, F>
119where
120    F: GEMVNum,
121{
122    fn driver(self) -> Result<GEMV_Driver<'a, 'x, 'y, F>, BLASError> {
123        let Self { a, x, y, alpha, beta, trans } = self;
124
125        // only fortran-preferred (col-major) is accepted in inner wrapper
126        let layout_a = get_layout_array2(&a);
127        assert!(layout_a.is_fpref());
128
129        // initialize intent(hide)
130        let (m, n) = a.dim();
131        let lda = a.stride_of(Axis(1));
132        let incx = x.stride_of(Axis(0));
133
134        // perform check
135        match trans {
136            BLASNoTrans => blas_assert_eq!(x.len_of(Axis(0)), n, InvalidDim)?,
137            BLASTrans | BLASConjTrans => blas_assert_eq!(x.len_of(Axis(0)), m, InvalidDim)?,
138            _ => blas_invalid!(trans)?,
139        };
140
141        // prepare output
142        let y = match y {
143            Some(y) => {
144                match trans {
145                    BLASNoTrans => blas_assert_eq!(y.len_of(Axis(0)), m, InvalidDim)?,
146                    BLASTrans | BLASConjTrans => blas_assert_eq!(y.len_of(Axis(0)), n, InvalidDim)?,
147                    _ => blas_invalid!(trans)?,
148                };
149                ArrayOut1::ViewMut(y)
150            },
151            None => ArrayOut1::Owned(Array1::zeros(match trans {
152                BLASNoTrans => m,
153                BLASTrans | BLASConjTrans => n,
154                _ => blas_invalid!(trans)?,
155            })),
156        };
157        let incy = y.view().stride_of(Axis(0));
158
159        // finalize
160        let driver = GEMV_Driver {
161            trans: trans.try_into()?,
162            m: m.try_into()?,
163            n: n.try_into()?,
164            alpha,
165            a,
166            lda: lda.try_into()?,
167            x,
168            incx: incx.try_into()?,
169            beta,
170            y,
171            incy: incy.try_into()?,
172        };
173        return Ok(driver);
174    }
175}
176
177/* #endregion */
178
179/* #region BLAS wrapper */
180
181pub type GEMV<'a, 'x, 'y, F> = GEMV_Builder<'a, 'x, 'y, F>;
182pub type SGEMV<'a, 'x, 'y> = GEMV<'a, 'x, 'y, f32>;
183pub type DGEMV<'a, 'x, 'y> = GEMV<'a, 'x, 'y, f64>;
184pub type CGEMV<'a, 'x, 'y> = GEMV<'a, 'x, 'y, c32>;
185pub type ZGEMV<'a, 'x, 'y> = GEMV<'a, 'x, 'y, c64>;
186
187impl<'a, 'x, 'y, F> BLASBuilder<'y, F, Ix1> for GEMV_Builder<'a, 'x, 'y, F>
188where
189    F: GEMVNum,
190{
191    fn run(self) -> Result<ArrayOut1<'y, F>, BLASError> {
192        // initialize
193        let obj = self.build()?;
194
195        let layout_a = get_layout_array2(&obj.a);
196
197        if layout_a.is_fpref() {
198            // F-contiguous: y = alpha op(A) x + beta y
199            return obj.driver()?.run_blas();
200        } else {
201            // C-contiguous
202            let a_cow = obj.a.to_row_layout()?;
203            match obj.trans {
204                BLASNoTrans => {
205                    // N -> T: y = alpha (A')' x + beta y
206                    let obj = GEMV_ { a: a_cow.t(), trans: BLASTrans, ..obj };
207                    return obj.driver()?.run_blas();
208                },
209                BLASTrans => {
210                    // T -> N: y = alpha (A') x + beta y
211                    let obj = GEMV_ { a: a_cow.t(), trans: BLASNoTrans, ..obj };
212                    return obj.driver()?.run_blas();
213                },
214                BLASConjTrans => {
215                    // C -> N: y* = alpha* (A') x* + beta* y*; y = y*
216                    let x = obj.x.mapv(F::conj);
217                    let y = obj.y.map(|mut y| {
218                        y.mapv_inplace(F::conj);
219                        y
220                    });
221                    let obj = GEMV_ {
222                        a: a_cow.t(),
223                        trans: BLASNoTrans,
224                        x: x.view(),
225                        y,
226                        alpha: F::conj(obj.alpha),
227                        beta: F::conj(obj.beta),
228                    };
229                    let mut y = obj.driver()?.run_blas()?;
230                    y.view_mut().mapv_inplace(F::conj);
231                    return Ok(y);
232                },
233                _ => return blas_invalid!(&obj.trans)?,
234            };
235        }
236    }
237}
238
239/* #endregion */