blas_array2/blas3/
hemm.rs

1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6/* #region BLAS func */
7
8pub trait HEMMNum: BLASFloat {
9    unsafe fn hemm(
10        side: *const c_char,
11        uplo: *const c_char,
12        m: *const blas_int,
13        n: *const blas_int,
14        alpha: *const Self,
15        a: *const Self,
16        lda: *const blas_int,
17        b: *const Self,
18        ldb: *const blas_int,
19        beta: *const Self,
20        c: *mut Self,
21        ldc: *const blas_int,
22    );
23}
24
25macro_rules! impl_func {
26    ($type: ty, $func: ident) => {
27        impl HEMMNum for $type {
28            unsafe fn hemm(
29                side: *const c_char,
30                uplo: *const c_char,
31                m: *const blas_int,
32                n: *const blas_int,
33                alpha: *const Self,
34                a: *const Self,
35                lda: *const blas_int,
36                b: *const Self,
37                ldb: *const blas_int,
38                beta: *const Self,
39                c: *mut Self,
40                ldc: *const blas_int,
41            ) {
42                ffi::$func(side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc);
43            }
44        }
45    };
46}
47
48impl_func!(c32, chemm_);
49impl_func!(c64, zhemm_);
50
51/* #endregion */
52
53/* #region BLAS driver */
54
55pub struct HEMM_Driver<'a, 'b, 'c, F>
56where
57    F: HEMMNum,
58{
59    side: c_char,
60    uplo: c_char,
61    m: blas_int,
62    n: blas_int,
63    alpha: F,
64    a: ArrayView2<'a, F>,
65    lda: blas_int,
66    b: ArrayView2<'b, F>,
67    ldb: blas_int,
68    beta: F,
69    c: ArrayOut2<'c, F>,
70    ldc: blas_int,
71}
72
73impl<'a, 'b, 'c, F> BLASDriver<'c, F, Ix2> for HEMM_Driver<'a, 'b, 'c, F>
74where
75    F: HEMMNum,
76{
77    fn run_blas(self) -> Result<ArrayOut2<'c, F>, BLASError> {
78        let Self { side, uplo, m, n, alpha, a, lda, b, ldb, beta, mut c, ldc, .. } = self;
79        let a_ptr = a.as_ptr();
80        let b_ptr = b.as_ptr();
81        let c_ptr = c.get_data_mut_ptr();
82
83        // assuming dimension checks has been performed
84        // unconditionally return Ok if output does not contain anything
85        if m == 0 || n == 0 {
86            return Ok(c.clone_to_view_mut());
87        }
88
89        unsafe {
90            F::hemm(&side, &uplo, &m, &n, &alpha, a_ptr, &lda, b_ptr, &ldb, &beta, c_ptr, &ldc);
91        }
92        return Ok(c.clone_to_view_mut());
93    }
94}
95
96/* #endregion */
97
98/* #region BLAS builder */
99
100#[derive(Builder)]
101#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
102pub struct HEMM_<'a, 'b, 'c, F>
103where
104    F: HEMMNum,
105{
106    pub a: ArrayView2<'a, F>,
107    pub b: ArrayView2<'b, F>,
108
109    #[builder(setter(into, strip_option), default = "None")]
110    pub c: Option<ArrayViewMut2<'c, F>>,
111    #[builder(setter(into), default = "F::one()")]
112    pub alpha: F,
113    #[builder(setter(into), default = "F::zero()")]
114    pub beta: F,
115    #[builder(setter(into), default = "BLASLeft")]
116    pub side: BLASSide,
117    #[builder(setter(into), default = "BLASLower")]
118    pub uplo: BLASUpLo,
119    #[builder(setter(into, strip_option), default = "None")]
120    pub layout: Option<BLASLayout>,
121}
122
123impl<'a, 'b, 'c, F> BLASBuilder_<'c, F, Ix2> for HEMM_<'a, 'b, 'c, F>
124where
125    F: HEMMNum,
126{
127    fn driver(self) -> Result<HEMM_Driver<'a, 'b, 'c, F>, BLASError> {
128        let Self { a, b, c, alpha, beta, side, uplo, layout, .. } = self;
129
130        // only fortran-preferred (col-major) is accepted in inner wrapper
131        assert_eq!(layout, Some(BLASColMajor));
132        assert!(a.is_fpref() && a.is_fpref());
133
134        // initialize intent(hide)
135        let m = b.len_of(Axis(0));
136        let n = b.len_of(Axis(1));
137        let lda = a.stride_of(Axis(1));
138        let ldb = b.stride_of(Axis(1));
139
140        // perform check
141        match side {
142            BLASLeft => blas_assert_eq!(a.dim(), (m, m), InvalidDim)?,
143            BLASRight => blas_assert_eq!(a.dim(), (n, n), InvalidDim)?,
144            _ => blas_invalid!(side)?,
145        }
146
147        // optional intent(out)
148        let c = match c {
149            Some(c) => {
150                blas_assert_eq!(c.dim(), (m, n), InvalidDim)?;
151                if c.view().is_fpref() {
152                    ArrayOut2::ViewMut(c)
153                } else {
154                    let c_buffer = c.view().to_col_layout()?.into_owned();
155                    ArrayOut2::ToBeCloned(c, c_buffer)
156                }
157            },
158            None => ArrayOut2::Owned(Array2::zeros((m, n).f())),
159        };
160        let ldc = c.view().stride_of(Axis(1));
161
162        // finalize
163        let driver = HEMM_Driver::<'a, 'b, 'c, F> {
164            side: side.try_into()?,
165            uplo: uplo.try_into()?,
166            m: m.try_into()?,
167            n: n.try_into()?,
168            alpha,
169            a,
170            lda: lda.try_into()?,
171            b,
172            ldb: ldb.try_into()?,
173            beta,
174            c,
175            ldc: ldc.try_into()?,
176        };
177        return Ok(driver);
178    }
179}
180
181/* #endregion */
182
183/* #region BLAS wrapper */
184
185pub type HEMM<'a, 'b, 'c, F> = HEMM_Builder<'a, 'b, 'c, F>;
186pub type CHEMM<'a, 'b, 'c> = HEMM<'a, 'b, 'c, c32>;
187pub type ZHEMM<'a, 'b, 'c> = HEMM<'a, 'b, 'c, c64>;
188
189impl<'a, 'b, 'c, F> BLASBuilder<'c, F, Ix2> for HEMM_Builder<'a, 'b, 'c, F>
190where
191    F: HEMMNum,
192{
193    fn run(self) -> Result<ArrayOut2<'c, F>, BLASError> {
194        // initialize
195        let HEMM_ { a, b, c, alpha, beta, side, uplo, layout, .. } = self.build()?;
196
197        let layout_a = get_layout_array2(&a);
198        let layout_b = get_layout_array2(&b);
199        let layout_c = c.as_ref().map(|c| get_layout_array2(&c.view()));
200
201        let layout = get_layout_row_preferred(&[layout, layout_c], &[layout_a, layout_b]);
202        if layout == BLASColMajor {
203            // F-contiguous: C = op(A) op(B)
204            let a_cow = a.to_col_layout()?;
205            let b_cow = b.to_col_layout()?;
206            let obj = HEMM_ {
207                a: a_cow.view(),
208                b: b_cow.view(),
209                c,
210                alpha,
211                beta,
212                side,
213                uplo,
214                layout: Some(BLASColMajor),
215            };
216            return obj.driver()?.run_blas();
217        } else {
218            // C-contiguous: C' = op(B') op(A')
219            let a_cow = a.to_row_layout()?;
220            let b_cow = b.to_row_layout()?;
221            let obj = HEMM_ {
222                a: a_cow.t(),
223                b: b_cow.t(),
224                c: c.map(|c| c.reversed_axes()),
225                alpha,
226                beta,
227                side: side.flip()?,
228                uplo: uplo.flip()?,
229                layout: Some(BLASColMajor),
230            };
231            let c = obj.driver()?.run_blas()?.reversed_axes();
232            return Ok(c);
233        }
234    }
235}
236
237/* #endregion */