blas_array2/blas3/
her2k.rs

1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5use num_traits::*;
6
7/* #region BLAS func */
8
9pub trait HER2KNum: BLASFloat {
10    unsafe fn her2k(
11        uplo: *const c_char,
12        trans: *const c_char,
13        n: *const blas_int,
14        k: *const blas_int,
15        alpha: *const Self,
16        a: *const Self,
17        lda: *const blas_int,
18        b: *const Self,
19        ldb: *const blas_int,
20        beta: *const Self::RealFloat,
21        c: *mut Self,
22        ldc: *const blas_int,
23    );
24}
25
26macro_rules! impl_her2k {
27    ($type: ty, $func: ident) => {
28        impl HER2KNum for $type {
29            unsafe fn her2k(
30                uplo: *const c_char,
31                trans: *const c_char,
32                n: *const blas_int,
33                k: *const blas_int,
34                alpha: *const Self,
35                a: *const Self,
36                lda: *const blas_int,
37                b: *const Self,
38                ldb: *const blas_int,
39                beta: *const Self::RealFloat,
40                c: *mut Self,
41                ldc: *const blas_int,
42            ) {
43                ffi::$func(uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
44            }
45        }
46    };
47}
48
49impl_her2k!(c32, cher2k_);
50impl_her2k!(c64, zher2k_);
51
52/* #endregion */
53
54/* #region BLAS driver */
55
56pub struct HER2K_Driver<'a, 'b, 'c, F>
57where
58    F: HER2KNum,
59{
60    uplo: c_char,
61    trans: c_char,
62    n: blas_int,
63    k: blas_int,
64    alpha: F,
65    a: ArrayView2<'a, F>,
66    lda: blas_int,
67    b: ArrayView2<'b, F>,
68    ldb: blas_int,
69    beta: F::RealFloat,
70    c: ArrayOut2<'c, F>,
71    ldc: blas_int,
72}
73
74impl<'a, 'b, 'c, F> BLASDriver<'c, F, Ix2> for HER2K_Driver<'a, 'b, 'c, F>
75where
76    F: HER2KNum,
77{
78    fn run_blas(self) -> Result<ArrayOut2<'c, F>, BLASError> {
79        let Self { uplo, trans, n, k, alpha, a, lda, b, ldb, beta, mut c, ldc } = self;
80        let a_ptr = a.as_ptr();
81        let b_ptr = b.as_ptr();
82        let c_ptr = c.get_data_mut_ptr();
83
84        // assuming dimension checks has been performed
85        // unconditionally return Ok if output does not contain anything
86        if n == 0 {
87            return Ok(c.clone_to_view_mut());
88        } else if k == 0 {
89            let beta_f = F::RealFloat::from(beta);
90            if uplo == BLASLower.try_into()? {
91                for i in 0..n {
92                    c.view_mut().slice_mut(s![i.., i]).mapv_inplace(|v| v * F::from_real(beta_f));
93                }
94            } else if uplo == BLASUpper.try_into()? {
95                for i in 0..n {
96                    c.view_mut().slice_mut(s![..=i, i]).mapv_inplace(|v| v * F::from_real(beta_f));
97                }
98            } else {
99                blas_invalid!(uplo)?
100            }
101            return Ok(c.clone_to_view_mut());
102        }
103
104        unsafe {
105            F::her2k(&uplo, &trans, &n, &k, &alpha, a_ptr, &lda, b_ptr, &ldb, &beta, c_ptr, &ldc);
106        }
107        return Ok(c.clone_to_view_mut());
108    }
109}
110
111/* #endregion */
112
113/* #region BLAS builder */
114
115#[derive(Builder)]
116#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
117pub struct HER2K_<'a, 'b, 'c, F>
118where
119    F: HER2KNum,
120{
121    pub a: ArrayView2<'a, F>,
122    pub b: ArrayView2<'b, F>,
123
124    #[builder(setter(into, strip_option), default = "None")]
125    pub c: Option<ArrayViewMut2<'c, F>>,
126    #[builder(setter(into), default = "F::one()")]
127    pub alpha: F,
128    #[builder(setter(into), default = "F::RealFloat::zero()")]
129    pub beta: F::RealFloat,
130    #[builder(setter(into), default = "BLASLower")]
131    pub uplo: BLASUpLo,
132    #[builder(setter(into), default = "BLASNoTrans")]
133    pub trans: BLASTranspose,
134    #[builder(setter(into, strip_option), default = "None")]
135    pub layout: Option<BLASLayout>,
136}
137
138impl<'a, 'b, 'c, F> BLASBuilder_<'c, F, Ix2> for HER2K_<'a, 'b, 'c, F>
139where
140    F: HER2KNum,
141{
142    fn driver(self) -> Result<HER2K_Driver<'a, 'b, 'c, F>, BLASError> {
143        let Self { a, b, c, alpha, beta, uplo, trans, layout } = self;
144
145        // only fortran-preferred (col-major) is accepted in inner wrapper
146        assert_eq!(layout, Some(BLASColMajor));
147        assert!(a.is_fpref() && a.is_fpref());
148
149        // initialize intent(hide)
150        let (n, k) = match trans {
151            BLASNoTrans => (a.len_of(Axis(0)), a.len_of(Axis(1))),
152            BLASConjTrans => (a.len_of(Axis(1)), a.len_of(Axis(0))),
153            _ => blas_invalid!(trans)?,
154        };
155        let lda = a.stride_of(Axis(1));
156        let ldb = b.stride_of(Axis(1));
157
158        // perform check (cher2k: NC accepted)
159        // dimension of b
160        match trans {
161            BLASNoTrans => blas_assert_eq!(b.dim(), (n, k), InvalidDim)?,
162            BLASConjTrans => blas_assert_eq!(b.dim(), (k, n), InvalidDim)?,
163            _ => blas_invalid!(trans)?,
164        };
165
166        // optional intent(out)
167        let c = match c {
168            Some(c) => {
169                blas_assert_eq!(c.dim(), (n, n), InvalidDim)?;
170                if c.view().is_fpref() {
171                    ArrayOut2::ViewMut(c)
172                } else {
173                    let c_buffer = c.view().to_col_layout()?.into_owned();
174                    ArrayOut2::ToBeCloned(c, c_buffer)
175                }
176            },
177            None => ArrayOut2::Owned(Array2::zeros((n, n).f())),
178        };
179        let ldc = c.view().stride_of(Axis(1));
180
181        // finalize
182        let driver = HER2K_Driver {
183            uplo: uplo.try_into()?,
184            trans: trans.try_into()?,
185            n: n.try_into()?,
186            k: k.try_into()?,
187            alpha,
188            a,
189            lda: lda.try_into()?,
190            b,
191            ldb: ldb.try_into()?,
192            beta,
193            c,
194            ldc: ldc.try_into()?,
195        };
196        return Ok(driver);
197    }
198}
199
200/* #endregion */
201
202/* #region BLAS wrapper */
203
204pub type HER2K<'a, 'b, 'c, F> = HER2K_Builder<'a, 'b, 'c, F>;
205pub type CHER2K<'a, 'b, 'c> = HER2K<'a, 'b, 'c, c32>;
206pub type ZHER2K<'a, 'b, 'c> = HER2K<'a, 'b, 'c, c64>;
207
208impl<'a, 'b, 'c, F> BLASBuilder<'c, F, Ix2> for HER2K_Builder<'a, 'b, 'c, F>
209where
210    F: HER2KNum,
211{
212    fn run(self) -> Result<ArrayOut2<'c, F>, BLASError> {
213        // initialize
214        let HER2K_ { a, b, c, alpha, beta, uplo, trans, layout } = self.build()?;
215
216        // Note that since we will change `trans` in outer wrapper to utilize mix-contiguous
217        // additional check to this parameter is required
218        match trans {
219            // cherk, zherk: NT accepted
220            BLASNoTrans | BLASConjTrans => (),
221            _ => blas_invalid!(trans)?,
222        };
223
224        let layout_a = get_layout_array2(&a);
225        let layout_b = get_layout_array2(&b);
226        let layout_c = c.as_ref().map(|c| get_layout_array2(&c.view()));
227
228        // her2k is difficult to gain any improvement when input matrices layouts are mixed
229        let layout = get_layout_row_preferred(&[layout, layout_c], &[layout_a, layout_b]);
230        if layout == BLASColMajor {
231            // F-contiguous: C = A op(B) + B op(A)
232            let a_cow = a.to_col_layout()?;
233            let b_cow = b.to_col_layout()?;
234            let obj = HER2K_ {
235                a: a_cow.view(),
236                b: b_cow.view(),
237                c,
238                alpha,
239                beta,
240                uplo,
241                trans,
242                layout: Some(BLASColMajor),
243            };
244            return obj.driver()?.run_blas();
245        } else if layout == BLASRowMajor {
246            // C-contiguous: C' = op(B') A' + op(A') B'
247            let a_cow = a.to_row_layout()?;
248            let b_cow = b.to_row_layout()?;
249            let obj = HER2K_ {
250                a: b_cow.t(),
251                b: a_cow.t(),
252                c: c.map(|c| c.reversed_axes()),
253                alpha,
254                beta,
255                uplo: uplo.flip()?,
256                trans: trans.flip(true)?,
257                layout: Some(BLASColMajor),
258            };
259            return Ok(obj.driver()?.run_blas()?.reversed_axes());
260        } else {
261            return blas_raise!(RuntimeError, "This is designed not to execuate this line.");
262        }
263    }
264}
265
266/* #endregion */