blas_array2/blas3/
syrk.rs

1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6/* #region BLAS func */
7
8pub trait SYRKNum: BLASFloat {
9    unsafe fn syrk(
10        uplo: *const c_char,
11        trans: *const c_char,
12        n: *const blas_int,
13        k: *const blas_int,
14        alpha: *const Self,
15        a: *const Self,
16        lda: *const blas_int,
17        beta: *const Self,
18        c: *mut Self,
19        ldc: *const blas_int,
20    );
21}
22
23macro_rules! impl_syrk {
24    ($type: ty, $func: ident) => {
25        impl SYRKNum for $type {
26            unsafe fn syrk(
27                uplo: *const c_char,
28                trans: *const c_char,
29                n: *const blas_int,
30                k: *const blas_int,
31                alpha: *const Self,
32                a: *const Self,
33                lda: *const blas_int,
34                beta: *const Self,
35                c: *mut Self,
36                ldc: *const blas_int,
37            ) {
38                ffi::$func(uplo, trans, n, k, alpha, a, lda, beta, c, ldc);
39            }
40        }
41    };
42}
43
44impl_syrk!(f32, ssyrk_);
45impl_syrk!(f64, dsyrk_);
46impl_syrk!(c32, csyrk_);
47impl_syrk!(c64, zsyrk_);
48
49/* #endregion */
50
51/* #region BLAS driver */
52
53pub struct SYRK_Driver<'a, 'c, F>
54where
55    F: BLASFloat,
56{
57    uplo: c_char,
58    trans: c_char,
59    n: blas_int,
60    k: blas_int,
61    alpha: F,
62    a: ArrayView2<'a, F>,
63    lda: blas_int,
64    beta: F,
65    c: ArrayOut2<'c, F>,
66    ldc: blas_int,
67}
68
69impl<'a, 'c, F> BLASDriver<'c, F, Ix2> for SYRK_Driver<'a, 'c, F>
70where
71    F: SYRKNum,
72{
73    fn run_blas(self) -> Result<ArrayOut2<'c, F>, BLASError> {
74        let Self { uplo, trans, n, k, alpha, a, lda, beta, mut c, ldc } = self;
75        let a_ptr = a.as_ptr();
76        let c_ptr = c.get_data_mut_ptr();
77
78        // assuming dimension checks has been performed
79        // unconditionally return Ok if output does not contain anything
80        if n == 0 {
81            return Ok(c.clone_to_view_mut());
82        } else if k == 0 {
83            let beta_f = F::from(beta);
84            if uplo == BLASLower.try_into()? {
85                for i in 0..n {
86                    c.view_mut().slice_mut(s![i.., i]).mapv_inplace(|v| v * beta_f);
87                }
88            } else if uplo == BLASUpper.try_into()? {
89                for i in 0..n {
90                    c.view_mut().slice_mut(s![..=i, i]).mapv_inplace(|v| v * beta_f);
91                }
92            } else {
93                blas_invalid!(uplo)?
94            }
95            return Ok(c.clone_to_view_mut());
96        }
97
98        unsafe {
99            F::syrk(&uplo, &trans, &n, &k, &alpha, a_ptr, &lda, &beta, c_ptr, &ldc);
100        }
101        return Ok(c.clone_to_view_mut());
102    }
103}
104
105/* #endregion */
106
107/* #region BLAS builder */
108
109#[derive(Builder)]
110#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
111pub struct SYRK_<'a, 'c, F>
112where
113    F: SYRKNum,
114{
115    pub a: ArrayView2<'a, F>,
116
117    #[builder(setter(into, strip_option), default = "None")]
118    pub c: Option<ArrayViewMut2<'c, F>>,
119    #[builder(setter(into), default = "F::one()")]
120    pub alpha: F,
121    #[builder(setter(into), default = "F::zero()")]
122    pub beta: F,
123    #[builder(setter(into), default = "BLASLower")]
124    pub uplo: BLASUpLo,
125    #[builder(setter(into), default = "BLASNoTrans")]
126    pub trans: BLASTranspose,
127    #[builder(setter(into, strip_option), default = "None")]
128    pub layout: Option<BLASLayout>,
129}
130
131impl<'a, 'c, F> BLASBuilder_<'c, F, Ix2> for SYRK_<'a, 'c, F>
132where
133    F: SYRKNum,
134{
135    fn driver(self) -> Result<SYRK_Driver<'a, 'c, F>, BLASError> {
136        let Self { a, c, alpha, beta, uplo, trans, layout } = self;
137
138        // only fortran-preferred (col-major) is accepted in inner wrapper
139        assert_eq!(layout, Some(BLASColMajor));
140        assert!(a.is_fpref());
141
142        // initialize intent(hide)
143        let (n, k) = match trans {
144            BLASNoTrans => (a.len_of(Axis(0)), a.len_of(Axis(1))),
145            BLASTrans | BLASConjTrans => (a.len_of(Axis(1)), a.len_of(Axis(0))),
146            _ => blas_invalid!(trans)?,
147        };
148        let lda = a.stride_of(Axis(1));
149
150        // perform check
151        match F::is_complex() {
152            false => match trans {
153                // ssyrk, dsyrk: NTC accepted
154                BLASNoTrans | BLASTrans | BLASConjTrans => (),
155                _ => blas_invalid!(trans)?,
156            },
157            true => match trans {
158                // csyrk, zsyrk: NT accepted
159                BLASNoTrans | BLASTrans => (),
160                _ => blas_invalid!(trans)?,
161            },
162        };
163
164        // optional intent(out)
165        let c = match c {
166            Some(c) => {
167                blas_assert_eq!(c.dim(), (n, n), InvalidDim)?;
168                if c.view().is_fpref() {
169                    ArrayOut2::ViewMut(c)
170                } else {
171                    let c_buffer = c.view().to_col_layout()?.into_owned();
172                    ArrayOut2::ToBeCloned(c, c_buffer)
173                }
174            },
175            None => ArrayOut2::Owned(Array2::zeros((n, n).f())),
176        };
177        let ldc = c.view().stride_of(Axis(1));
178
179        // finalize
180        let driver = SYRK_Driver {
181            uplo: uplo.try_into()?,
182            trans: trans.try_into()?,
183            n: n.try_into()?,
184            k: k.try_into()?,
185            alpha,
186            a,
187            lda: lda.try_into()?,
188            beta,
189            c,
190            ldc: ldc.try_into()?,
191        };
192        return Ok(driver);
193    }
194}
195
196/* #endregion */
197
198/* #region BLAS wrapper */
199
200pub type SYRK<'a, 'c, F> = SYRK_Builder<'a, 'c, F>;
201pub type SSYRK<'a, 'c> = SYRK<'a, 'c, f32>;
202pub type DSYRK<'a, 'c> = SYRK<'a, 'c, f64>;
203pub type CSYRK<'a, 'c> = SYRK<'a, 'c, c32>;
204pub type ZSYRK<'a, 'c> = SYRK<'a, 'c, c64>;
205
206impl<'a, 'c, F> BLASBuilder<'c, F, Ix2> for SYRK_Builder<'a, 'c, F>
207where
208    F: SYRKNum,
209{
210    fn run(self) -> Result<ArrayOut2<'c, F>, BLASError> {
211        // initialize
212        let SYRK_ { a, c, alpha, beta, uplo, trans, layout } = self.build()?;
213        let at = a.t();
214
215        // Note that since we will change `trans` in outer wrapper to utilize mix-contiguous
216        // additional check to this parameter is required
217        match F::is_complex() {
218            false => match trans {
219                // ssyrk, dsyrk: NTC accepted
220                BLASNoTrans | BLASTrans | BLASConjTrans => (),
221                _ => blas_invalid!(trans)?,
222            },
223            true => match trans {
224                // csyrk, zsyrk: NT accepted
225                BLASNoTrans | BLASTrans => (),
226                _ => blas_invalid!(trans)?,
227            },
228        };
229
230        let layout_a = get_layout_array2(&a);
231        let layout_c = c.as_ref().map(|c| get_layout_array2(&c.view()));
232
233        let layout = get_layout_row_preferred(&[layout, layout_c], &[layout_a]);
234        if layout == BLASColMajor {
235            // F-contiguous: C = A op(A) or C = op(A) A
236            let (trans, a_cow) = flip_trans_fpref(trans, &a, &at, false)?;
237            let obj = SYRK_ { a: a_cow.view(), c, alpha, beta, uplo, trans, layout: Some(BLASColMajor) };
238            return obj.driver()?.run_blas();
239        } else if layout == BLASRowMajor {
240            let (trans, a_cow) = flip_trans_cpref(trans, &a, &at, false)?;
241            let obj = SYRK_ {
242                a: a_cow.t(),
243                c: c.map(|c| c.reversed_axes()),
244                alpha,
245                beta,
246                uplo: uplo.flip()?,
247                trans: trans.flip(false)?,
248                layout: Some(BLASColMajor),
249            };
250            return Ok(obj.driver()?.run_blas()?.reversed_axes());
251        } else {
252            return blas_raise!(RuntimeError, "This is designed not to execuate this line.");
253        }
254    }
255}
256
257/* #endregion */