blas_array2/blas3/
herk.rs

1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5use num_traits::*;
6
7/* #region BLAS func */
8
9pub trait HERKNum: BLASFloat {
10    unsafe fn herk(
11        uplo: *const c_char,
12        trans: *const c_char,
13        n: *const blas_int,
14        k: *const blas_int,
15        alpha: *const Self::RealFloat,
16        a: *const Self,
17        lda: *const blas_int,
18        beta: *const Self::RealFloat,
19        c: *mut Self,
20        ldc: *const blas_int,
21    );
22}
23
24macro_rules! impl_herk {
25    ($type: ty, $func: ident) => {
26        impl HERKNum for $type {
27            unsafe fn herk(
28                uplo: *const c_char,
29                trans: *const c_char,
30                n: *const blas_int,
31                k: *const blas_int,
32                alpha: *const Self::RealFloat,
33                a: *const Self,
34                lda: *const blas_int,
35                beta: *const Self::RealFloat,
36                c: *mut Self,
37                ldc: *const blas_int,
38            ) {
39                ffi::$func(uplo, trans, n, k, alpha, a, lda, beta, c, ldc);
40            }
41        }
42    };
43}
44
45impl_herk!(c32, cherk_);
46impl_herk!(c64, zherk_);
47
48/* #endregion */
49
50/* #region BLAS driver */
51
52pub struct HERK_Driver<'a, 'c, F>
53where
54    F: BLASFloat,
55{
56    uplo: c_char,
57    trans: c_char,
58    n: blas_int,
59    k: blas_int,
60    alpha: F::RealFloat,
61    a: ArrayView2<'a, F>,
62    lda: blas_int,
63    beta: F::RealFloat,
64    c: ArrayOut2<'c, F>,
65    ldc: blas_int,
66}
67
68impl<'a, 'c, F> BLASDriver<'c, F, Ix2> for HERK_Driver<'a, 'c, F>
69where
70    F: HERKNum,
71{
72    fn run_blas(self) -> Result<ArrayOut2<'c, F>, BLASError> {
73        let Self { uplo, trans, n, k, alpha, a, lda, beta, mut c, ldc } = self;
74        let a_ptr = a.as_ptr();
75        let c_ptr = c.get_data_mut_ptr();
76
77        // assuming dimension checks has been performed
78        // unconditionally return Ok if output does not contain anything
79        if n == 0 {
80            return Ok(c.clone_to_view_mut());
81        } else if k == 0 {
82            let beta_f = F::from_real(beta);
83            if uplo == BLASLower.try_into()? {
84                for i in 0..n {
85                    c.view_mut().slice_mut(s![i.., i]).mapv_inplace(|v| v * beta_f);
86                }
87            } else if uplo == BLASUpper.try_into()? {
88                for i in 0..n {
89                    c.view_mut().slice_mut(s![..=i, i]).mapv_inplace(|v| v * beta_f);
90                }
91            } else {
92                blas_invalid!(uplo)?
93            }
94            return Ok(c.clone_to_view_mut());
95        }
96
97        unsafe {
98            F::herk(&uplo, &trans, &n, &k, &alpha, a_ptr, &lda, &beta, c_ptr, &ldc);
99        }
100        return Ok(c.clone_to_view_mut());
101    }
102}
103
104/* #endregion */
105
106/* #region BLAS builder */
107
108#[derive(Builder)]
109#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
110pub struct HERK_<'a, 'c, F>
111where
112    F: HERKNum,
113{
114    pub a: ArrayView2<'a, F>,
115
116    #[builder(setter(into, strip_option), default = "None")]
117    pub c: Option<ArrayViewMut2<'c, F>>,
118    #[builder(setter(into), default = "F::RealFloat::one()")]
119    pub alpha: F::RealFloat,
120    #[builder(setter(into), default = "F::RealFloat::zero()")]
121    pub beta: F::RealFloat,
122    #[builder(setter(into), default = "BLASLower")]
123    pub uplo: BLASUpLo,
124    #[builder(setter(into), default = "BLASNoTrans")]
125    pub trans: BLASTranspose,
126    #[builder(setter(into, strip_option), default = "None")]
127    pub layout: Option<BLASLayout>,
128}
129
130impl<'a, 'c, F> BLASBuilder_<'c, F, Ix2> for HERK_<'a, 'c, F>
131where
132    F: HERKNum,
133{
134    fn driver(self) -> Result<HERK_Driver<'a, 'c, F>, BLASError> {
135        let Self { a, c, alpha, beta, uplo, trans, layout } = self;
136
137        // only fortran-preferred (col-major) is accepted in inner wrapper
138        assert_eq!(layout, Some(BLASColMajor));
139        assert!(a.is_fpref());
140
141        // initialize intent(hide) (cherk, zherk: NC accepted)
142        let (n, k) = match trans {
143            BLASNoTrans => (a.len_of(Axis(0)), a.len_of(Axis(1))),
144            BLASConjTrans => (a.len_of(Axis(1)), a.len_of(Axis(0))),
145            _ => blas_invalid!(trans)?,
146        };
147        let lda = a.stride_of(Axis(1));
148
149        // optional intent(out)
150        let c = match c {
151            Some(c) => {
152                blas_assert_eq!(c.dim(), (n, n), InvalidDim)?;
153                if c.view().is_fpref() {
154                    ArrayOut2::ViewMut(c)
155                } else {
156                    let c_buffer = c.view().to_col_layout()?.into_owned();
157                    ArrayOut2::ToBeCloned(c, c_buffer)
158                }
159            },
160            None => ArrayOut2::Owned(Array2::zeros((n, n).f())),
161        };
162        let ldc = c.view().stride_of(Axis(1));
163
164        // finalize
165        let driver = HERK_Driver {
166            uplo: uplo.try_into()?,
167            trans: trans.try_into()?,
168            n: n.try_into()?,
169            k: k.try_into()?,
170            alpha,
171            a,
172            lda: lda.try_into()?,
173            beta,
174            c,
175            ldc: ldc.try_into()?,
176        };
177        return Ok(driver);
178    }
179}
180
181/* #endregion */
182
183/* #region BLAS wrapper */
184
185pub type HERK<'a, 'c, F> = HERK_Builder<'a, 'c, F>;
186pub type CHERK<'a, 'c> = HERK<'a, 'c, c32>;
187pub type ZHERK<'a, 'c> = HERK<'a, 'c, c64>;
188
189impl<'a, 'c, F> BLASBuilder<'c, F, Ix2> for HERK_Builder<'a, 'c, F>
190where
191    F: HERKNum,
192{
193    fn run(self) -> Result<ArrayOut2<'c, F>, BLASError> {
194        // initialize
195        let HERK_ { a, c, alpha, beta, uplo, trans, layout } = self.build()?;
196        let at = a.t();
197
198        // Note that since we will change `trans` in outer wrapper to utilize mix-contiguous
199        // additional check to this parameter is required
200        match trans {
201            // cherk, zherk: NC accepted
202            BLASNoTrans | BLASConjTrans => (),
203            _ => blas_invalid!(trans)?,
204        };
205
206        let layout_a = get_layout_array2(&a);
207        let layout_c = c.as_ref().map(|c| get_layout_array2(&c.view()));
208
209        let layout = get_layout_row_preferred(&[layout, layout_c], &[layout_a]);
210        if layout == BLASColMajor {
211            // F-contiguous: C = A op(A) or C = op(A) A
212            let (trans, a_cow) = flip_trans_fpref(trans, &a, &at, true)?;
213            let obj = HERK_ { a: a_cow.view(), c, alpha, beta, uplo, trans, layout: Some(BLASColMajor) };
214            return obj.driver()?.run_blas();
215        } else if layout == BLASRowMajor {
216            let (trans, a_cow) = flip_trans_cpref(trans, &a, &at, true)?;
217            let obj = HERK_ {
218                a: a_cow.t(),
219                c: c.map(|c| c.reversed_axes()),
220                alpha,
221                beta,
222                uplo: uplo.flip()?,
223                trans: trans.flip(true)?,
224                layout: Some(BLASColMajor),
225            };
226            return Ok(obj.driver()?.run_blas()?.reversed_axes());
227        } else {
228            return blas_raise!(RuntimeError, "This is designed not to execuate this line.");
229        }
230    }
231}
232
233/* #endregion */