blas_array2/blas2/
gbmv.rs

1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6/* #region BLAS func */
7
/// Scalar types that have a `gbmv` routine available through `ffi`.
///
/// The trait only declares the raw entry point; concrete implementations are generated by
/// `impl_func!` for `f32`, `f64`, `c32` and `c64`.
pub trait GBMVNum: BLASFloat {
    /// Raw `gbmv` call. Every argument is passed by pointer and `y` is the only mutable one.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably the usual BLAS `?gbmv` pointer contract applies (scalars point
    /// to valid values; `a`, `x`, `y` are large enough for the given `m`, `n`, `kl`, `ku`,
    /// `lda`, `incx`, `incy`) — confirm against the documentation of the bound routine.
    unsafe fn gbmv(
        trans: *const c_char,
        m: *const blas_int,
        n: *const blas_int,
        kl: *const blas_int,
        ku: *const blas_int,
        alpha: *const Self,
        a: *const Self,
        lda: *const blas_int,
        x: *const Self,
        incx: *const blas_int,
        beta: *const Self,
        y: *mut Self,
        incy: *const blas_int,
    );
}
25
26macro_rules! impl_func {
27    ($type: ty, $func: ident) => {
28        impl GBMVNum for $type {
29            unsafe fn gbmv(
30                trans: *const c_char,
31                m: *const blas_int,
32                n: *const blas_int,
33                kl: *const blas_int,
34                ku: *const blas_int,
35                alpha: *const Self,
36                a: *const Self,
37                lda: *const blas_int,
38                x: *const Self,
39                incx: *const blas_int,
40                beta: *const Self,
41                y: *mut Self,
42                incy: *const blas_int,
43            ) {
44                ffi::$func(trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy);
45            }
46        }
47    };
48}
49
50impl_func!(f32, sgbmv_);
51impl_func!(f64, dgbmv_);
52impl_func!(c32, cgbmv_);
53impl_func!(c64, zgbmv_);
54
55/* #endregion */
56
57/* #region BLAS driver */
58
/// Fully resolved argument set for a single `gbmv` call.
///
/// Values of this type are produced by `GBMV_::driver`; `run_blas` passes the fields to
/// `F::gbmv` (scalars by reference, arrays as raw pointers) and returns `y`.
pub struct GBMV_Driver<'a, 'x, 'y, F>
where
    F: GBMVNum,
{
    // Transposition flag, already converted to the character form the routine takes.
    trans: c_char,
    // Dimension / band-width arguments, already converted to `blas_int`.
    m: blas_int,
    n: blas_int,
    kl: blas_int,
    ku: blas_int,
    alpha: F,
    // Matrix operand and the stride of its axis 1 (used as `lda`).
    a: ArrayView2<'a, F>,
    lda: blas_int,
    // Input vector and the stride of its axis 0.
    x: ArrayView1<'x, F>,
    incx: blas_int,
    beta: F,
    // Output vector (either a caller-provided mutable view or an owned array) and its stride.
    y: ArrayOut1<'y, F>,
    incy: blas_int,
}
77
78impl<'a, 'x, 'y, F> BLASDriver<'y, F, Ix1> for GBMV_Driver<'a, 'x, 'y, F>
79where
80    F: GBMVNum,
81{
82    fn run_blas(self) -> Result<ArrayOut1<'y, F>, BLASError> {
83        let Self { trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, mut y, incy } = self;
84        let a_ptr = a.as_ptr();
85        let x_ptr = x.as_ptr();
86        let y_ptr = y.get_data_mut_ptr();
87
88        // assuming dimension checks has been performed
89        // unconditionally return Ok if output does not contain anything
90        if n == 0 {
91            return Ok(y);
92        }
93
94        unsafe {
95            F::gbmv(&trans, &m, &n, &kl, &ku, &alpha, a_ptr, &lda, x_ptr, &incx, &beta, y_ptr, &incy);
96        }
97        return Ok(y);
98    }
99}
100
101/* #endregion */
102
103/* #region BLAS builder */
104
/// Arguments of a `gbmv` operation, as collected by the derived builder (`GBMV_Builder`).
///
/// `a`, `x`, `m` and `kl` carry no builder default; the remaining fields do.
#[derive(Builder)]
#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
pub struct GBMV_<'a, 'x, 'y, F>
where
    F: GBMVNum,
{
    /// Matrix operand. `driver` reads its shape as `(k, n)` and requires `k > kl`.
    pub a: ArrayView2<'a, F>,
    /// Input vector. `driver` requires length `n` for `BLASNoTrans`, and length `m` for
    /// `BLASTrans` / `BLASConjTrans`.
    pub x: ArrayView1<'x, F>,
    /// Dimension supplied by the caller; `driver` requires `m >= k`.
    pub m: usize,
    /// Band-width argument supplied by the caller; `driver` derives `ku = k - 1 - kl` from it.
    pub kl: usize,

    /// Optional output vector. When left unset, `driver` allocates a zero-filled one.
    #[builder(setter(into, strip_option), default = "None")]
    pub y: Option<ArrayViewMut1<'y, F>>,
    /// Scalar `alpha`; defaults to `F::one()`.
    #[builder(setter(into), default = "F::one()")]
    pub alpha: F,
    /// Scalar `beta`; defaults to `F::zero()`.
    #[builder(setter(into), default = "F::zero()")]
    pub beta: F,
    /// Transposition flag; defaults to `BLASNoTrans`.
    #[builder(setter(into), default = "BLASNoTrans")]
    pub trans: BLASTranspose,
    /// Memory layout to assume for `a`. When left unset, `run` derives it from `a` itself.
    #[builder(setter(into, strip_option), default = "None")]
    pub layout: Option<BLASLayout>,
}
127
128impl<'a, 'x, 'y, F> BLASBuilder_<'y, F, Ix1> for GBMV_<'a, 'x, 'y, F>
129where
130    F: GBMVNum,
131{
132    fn driver(self) -> Result<GBMV_Driver<'a, 'x, 'y, F>, BLASError> {
133        let Self { a, x, m, kl, y, alpha, beta, trans, layout } = self;
134
135        // only fortran-preferred (col-major) is accepted in inner wrapper
136        let layout_a = get_layout_array2(&a);
137        assert!(layout_a.is_fpref());
138        assert!(layout == Some(BLASLayout::ColMajor));
139
140        // initialize intent(hide)
141        let (k, n) = a.dim();
142        let lda = a.stride_of(Axis(1));
143        let incx = x.stride_of(Axis(0));
144
145        // perform check
146        blas_assert!(k > kl, InvalidDim)?;
147        blas_assert!(m >= k, InvalidDim)?;
148        let ku = k - 1 - kl;
149        match trans {
150            BLASNoTrans => blas_assert_eq!(x.len_of(Axis(0)), n, InvalidDim)?,
151            BLASTrans | BLASConjTrans => blas_assert_eq!(x.len_of(Axis(0)), m, InvalidDim)?,
152            _ => blas_invalid!(trans)?,
153        };
154
155        // prepare output
156        let y = match y {
157            Some(y) => {
158                match trans {
159                    BLASNoTrans => blas_assert_eq!(y.len_of(Axis(0)), m, InvalidDim)?,
160                    BLASTrans | BLASConjTrans => blas_assert_eq!(y.len_of(Axis(0)), n, InvalidDim)?,
161                    _ => blas_invalid!(trans)?,
162                };
163                ArrayOut1::ViewMut(y)
164            },
165            None => ArrayOut1::Owned(Array1::zeros(match trans {
166                BLASNoTrans => m,
167                BLASTrans | BLASConjTrans => n,
168                _ => blas_invalid!(trans)?,
169            })),
170        };
171        let incy = y.view().stride_of(Axis(0));
172
173        // finalize
174        let driver = GBMV_Driver {
175            trans: trans.try_into()?,
176            m: m.try_into()?,
177            n: n.try_into()?,
178            kl: kl.try_into()?,
179            ku: ku.try_into()?,
180            alpha,
181            a,
182            lda: lda.try_into()?,
183            x,
184            incx: incx.try_into()?,
185            beta,
186            y,
187            incy: incy.try_into()?,
188        };
189        return Ok(driver);
190    }
191}
192
193/* #endregion */
194
195/* #region BLAS wrapper */
196
/// Public name of the builder derived for `GBMV_`, generic over the scalar type `F`.
pub type GBMV<'a, 'x, 'y, F> = GBMV_Builder<'a, 'x, 'y, F>;
/// [`GBMV`] with `F = f32`.
pub type SGBMV<'a, 'x, 'y> = GBMV<'a, 'x, 'y, f32>;
/// [`GBMV`] with `F = f64`.
pub type DGBMV<'a, 'x, 'y> = GBMV<'a, 'x, 'y, f64>;
/// [`GBMV`] with `F = c32`.
pub type CGBMV<'a, 'x, 'y> = GBMV<'a, 'x, 'y, c32>;
/// [`GBMV`] with `F = c64`.
pub type ZGBMV<'a, 'x, 'y> = GBMV<'a, 'x, 'y, c64>;
202
203impl<'a, 'x, 'y, F> BLASBuilder<'y, F, Ix1> for GBMV_Builder<'a, 'x, 'y, F>
204where
205    F: GBMVNum,
206{
207    fn run(self) -> Result<ArrayOut1<'y, F>, BLASError> {
208        // initialize
209        let GBMV_ { a, x, m, kl, y, alpha, beta, trans, layout } = self.build()?;
210
211        let layout_a = get_layout_array2(&a);
212        let layout = match layout {
213            Some(layout) => layout,
214            None => match layout_a {
215                BLASLayout::Sequential => BLASColMajor,
216                BLASRowMajor => BLASRowMajor,
217                BLASColMajor => BLASColMajor,
218                _ => blas_raise!(InvalidFlag, "Without defining layout, this function checks layout of input matrix `a` but it is not contiguous.")?,
219            }
220        };
221
222        if layout == BLASColMajor {
223            // F-contiguous
224            let a_cow = a.to_col_layout()?;
225            let obj = GBMV_ { a: a_cow.view(), x, m, kl, y, alpha, beta, trans, layout: Some(BLASColMajor) };
226            return obj.driver()?.run_blas();
227        } else {
228            // C-contiguous
229            let a_cow = a.to_row_layout()?;
230            let k = a_cow.len_of(Axis(1));
231            blas_assert!(k > kl, InvalidDim)?;
232            let ku = k - kl - 1;
233            match trans {
234                BLASNoTrans => {
235                    // N -> T
236                    let obj = GBMV_ {
237                        a: a_cow.t(),
238                        x,
239                        m,
240                        kl: ku,
241                        y,
242                        alpha,
243                        beta,
244                        trans: BLASTrans,
245                        layout: Some(BLASColMajor),
246                    };
247                    return obj.driver()?.run_blas();
248                },
249                BLASTrans => {
250                    // N -> T
251                    let obj = GBMV_ {
252                        a: a_cow.t(),
253                        x,
254                        m,
255                        kl: ku,
256                        y,
257                        alpha,
258                        beta,
259                        trans: BLASNoTrans,
260                        layout: Some(BLASColMajor),
261                    };
262                    return obj.driver()?.run_blas();
263                },
264                BLASConjTrans => {
265                    // C -> N
266                    let x = x.mapv(F::conj);
267                    let y = y.map(|mut y| {
268                        y.mapv_inplace(F::conj);
269                        y
270                    });
271                    let obj = GBMV_ {
272                        a: a_cow.t(),
273                        x: x.view(),
274                        m,
275                        kl: ku,
276                        y,
277                        alpha: F::conj(alpha),
278                        beta: F::conj(beta),
279                        trans: BLASNoTrans,
280                        layout: Some(BLASColMajor),
281                    };
282                    let mut y = obj.driver()?.run_blas()?;
283                    y.view_mut().mapv_inplace(F::conj);
284                    return Ok(y);
285                },
286                _ => return blas_invalid!(trans)?,
287            }
288        }
289    }
290}
291
292/* #endregion */