blas_array2/blas3/
syr2k.rs

1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6/* #region BLAS func */
7
8pub trait SYR2KNum: BLASFloat {
9    unsafe fn syr2k(
10        uplo: *const c_char,
11        trans: *const c_char,
12        n: *const blas_int,
13        k: *const blas_int,
14        alpha: *const Self,
15        a: *const Self,
16        lda: *const blas_int,
17        b: *const Self,
18        ldb: *const blas_int,
19        beta: *const Self,
20        c: *mut Self,
21        ldc: *const blas_int,
22    );
23}
24
25macro_rules! impl_syr2k {
26    ($type: ty, $func: ident) => {
27        impl SYR2KNum for $type {
28            unsafe fn syr2k(
29                uplo: *const c_char,
30                trans: *const c_char,
31                n: *const blas_int,
32                k: *const blas_int,
33                alpha: *const Self,
34                a: *const Self,
35                lda: *const blas_int,
36                b: *const Self,
37                ldb: *const blas_int,
38                beta: *const Self,
39                c: *mut Self,
40                ldc: *const blas_int,
41            ) {
42                ffi::$func(uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
43            }
44        }
45    };
46}
47
48impl_syr2k!(f32, ssyr2k_);
49impl_syr2k!(f64, dsyr2k_);
50impl_syr2k!(c32, csyr2k_);
51impl_syr2k!(c64, zsyr2k_);
52
53/* #endregion */
54
55/* #region BLAS driver */
56
57pub struct SYR2K_Driver<'a, 'b, 'c, F>
58where
59    F: SYR2KNum,
60{
61    uplo: c_char,
62    trans: c_char,
63    n: blas_int,
64    k: blas_int,
65    alpha: F,
66    a: ArrayView2<'a, F>,
67    lda: blas_int,
68    b: ArrayView2<'b, F>,
69    ldb: blas_int,
70    beta: F,
71    c: ArrayOut2<'c, F>,
72    ldc: blas_int,
73}
74
75impl<'a, 'b, 'c, F> BLASDriver<'c, F, Ix2> for SYR2K_Driver<'a, 'b, 'c, F>
76where
77    F: SYR2KNum,
78{
79    fn run_blas(self) -> Result<ArrayOut2<'c, F>, BLASError> {
80        let Self { uplo, trans, n, k, alpha, a, lda, b, ldb, beta, mut c, ldc } = self;
81        let a_ptr = a.as_ptr();
82        let b_ptr = b.as_ptr();
83        let c_ptr = c.get_data_mut_ptr();
84
85        // assuming dimension checks has been performed
86        // unconditionally return Ok if output does not contain anything
87        if n == 0 {
88            return Ok(c.clone_to_view_mut());
89        } else if k == 0 {
90            let beta_f = F::from(beta);
91            if uplo == BLASLower.try_into()? {
92                for i in 0..n {
93                    c.view_mut().slice_mut(s![i.., i]).mapv_inplace(|v| v * beta_f);
94                }
95            } else if uplo == BLASUpper.try_into()? {
96                for i in 0..n {
97                    c.view_mut().slice_mut(s![..=i, i]).mapv_inplace(|v| v * beta_f);
98                }
99            } else {
100                blas_invalid!(uplo)?
101            }
102            return Ok(c.clone_to_view_mut());
103        }
104
105        unsafe {
106            F::syr2k(&uplo, &trans, &n, &k, &alpha, a_ptr, &lda, b_ptr, &ldb, &beta, c_ptr, &ldc);
107        }
108        return Ok(c.clone_to_view_mut());
109    }
110}
111
112/* #endregion */
113
114/* #region BLAS builder */
115
116#[derive(Builder)]
117#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
118pub struct SYR2K_<'a, 'b, 'c, F>
119where
120    F: SYR2KNum,
121{
122    pub a: ArrayView2<'a, F>,
123    pub b: ArrayView2<'b, F>,
124
125    #[builder(setter(into, strip_option), default = "None")]
126    pub c: Option<ArrayViewMut2<'c, F>>,
127    #[builder(setter(into), default = "F::one()")]
128    pub alpha: F,
129    #[builder(setter(into), default = "F::zero()")]
130    pub beta: F,
131    #[builder(setter(into), default = "BLASLower")]
132    pub uplo: BLASUpLo,
133    #[builder(setter(into), default = "BLASNoTrans")]
134    pub trans: BLASTranspose,
135    #[builder(setter(into, strip_option), default = "None")]
136    pub layout: Option<BLASLayout>,
137}
138
139impl<'a, 'b, 'c, F> BLASBuilder_<'c, F, Ix2> for SYR2K_<'a, 'b, 'c, F>
140where
141    F: SYR2KNum,
142{
143    fn driver(self) -> Result<SYR2K_Driver<'a, 'b, 'c, F>, BLASError> {
144        let Self { a, b, c, alpha, beta, uplo, trans, layout } = self;
145
146        // only fortran-preferred (col-major) is accepted in inner wrapper
147        assert_eq!(layout, Some(BLASColMajor));
148        assert!(a.is_fpref() && a.is_fpref());
149
150        // initialize intent(hide)
151        let (n, k) = match trans {
152            BLASNoTrans => (a.len_of(Axis(0)), a.len_of(Axis(1))),
153            BLASTrans | BLASConjTrans => (a.len_of(Axis(1)), a.len_of(Axis(0))),
154            _ => blas_invalid!(trans)?,
155        };
156        let lda = a.stride_of(Axis(1));
157        let ldb = b.stride_of(Axis(1));
158
159        // perform check
160        // dimension of b
161        match trans {
162            BLASNoTrans => blas_assert_eq!(b.dim(), (n, k), InvalidDim)?,
163            BLASTrans | BLASConjTrans => blas_assert_eq!(b.dim(), (k, n), InvalidDim)?,
164            _ => blas_invalid!(trans)?,
165        };
166        // trans keyword
167        match F::is_complex() {
168            false => match trans {
169                // ssyrk, dsyrk: NTC accepted
170                BLASNoTrans | BLASTrans | BLASConjTrans => (),
171                _ => blas_invalid!(trans)?,
172            },
173            true => match trans {
174                // csyrk, zsyrk: NT accepted
175                BLASNoTrans | BLASTrans => (),
176                _ => blas_invalid!(trans)?,
177            },
178        };
179
180        // optional intent(out)
181        let c = match c {
182            Some(c) => {
183                blas_assert_eq!(c.dim(), (n, n), InvalidDim)?;
184                if c.view().is_fpref() {
185                    ArrayOut2::ViewMut(c)
186                } else {
187                    let c_buffer = c.view().to_col_layout()?.into_owned();
188                    ArrayOut2::ToBeCloned(c, c_buffer)
189                }
190            },
191            None => ArrayOut2::Owned(Array2::zeros((n, n).f())),
192        };
193        let ldc = c.view().stride_of(Axis(1));
194
195        // finalize
196        let driver = SYR2K_Driver {
197            uplo: uplo.try_into()?,
198            trans: trans.try_into()?,
199            n: n.try_into()?,
200            k: k.try_into()?,
201            alpha,
202            a,
203            lda: lda.try_into()?,
204            b,
205            ldb: ldb.try_into()?,
206            beta,
207            c,
208            ldc: ldc.try_into()?,
209        };
210        return Ok(driver);
211    }
212}
213
214/* #endregion */
215
216/* #region BLAS wrapper */
217
218pub type SYR2K<'a, 'b, 'c, F> = SYR2K_Builder<'a, 'b, 'c, F>;
219pub type SSYR2K<'a, 'b, 'c> = SYR2K<'a, 'b, 'c, f32>;
220pub type DSYR2K<'a, 'b, 'c> = SYR2K<'a, 'b, 'c, f64>;
221pub type CSYR2K<'a, 'b, 'c> = SYR2K<'a, 'b, 'c, c32>;
222pub type ZSYR2K<'a, 'b, 'c> = SYR2K<'a, 'b, 'c, c64>;
223
224impl<'a, 'b, 'c, F> BLASBuilder<'c, F, Ix2> for SYR2K_Builder<'a, 'b, 'c, F>
225where
226    F: SYR2KNum,
227{
228    fn run(self) -> Result<ArrayOut2<'c, F>, BLASError> {
229        // initialize
230        let SYR2K_ { a, b, c, alpha, beta, uplo, trans, layout } = self.build()?;
231
232        // Note that since we will change `trans` in outer wrapper to utilize mix-contiguous
233        // additional check to this parameter is required
234        match F::is_complex() {
235            false => match trans {
236                // ssyrk, dsyrk: NTC accepted
237                BLASNoTrans | BLASTrans | BLASConjTrans => (),
238                _ => blas_invalid!(trans)?,
239            },
240            true => match trans {
241                // csyrk, zsyrk: NT accepted
242                BLASNoTrans | BLASTrans => (),
243                _ => blas_invalid!(trans)?,
244            },
245        };
246
247        let layout_a = get_layout_array2(&a);
248        let layout_b = get_layout_array2(&b);
249        let layout_c = c.as_ref().map(|c| get_layout_array2(&c.view()));
250
251        // syr2k is difficult to gain any improvement when input matrices layouts are mixed
252        let layout = get_layout_row_preferred(&[layout, layout_c], &[layout_a, layout_b]);
253        if layout == BLASColMajor {
254            // F-contiguous: C = A op(B) + B op(A)
255            let a_cow = a.to_col_layout()?;
256            let b_cow = b.to_col_layout()?;
257            let obj = SYR2K_ {
258                a: a_cow.view(),
259                b: b_cow.view(),
260                c,
261                alpha,
262                beta,
263                uplo,
264                trans,
265                layout: Some(BLASColMajor),
266            };
267            return obj.driver()?.run_blas();
268        } else if layout == BLASRowMajor {
269            // C-contiguous: C' = op(B') A' + op(A') B'
270            let a_cow = a.to_row_layout()?;
271            let b_cow = b.to_row_layout()?;
272            let obj = SYR2K_ {
273                a: b_cow.t(),
274                b: a_cow.t(),
275                c: c.map(|c| c.reversed_axes()),
276                alpha,
277                beta,
278                uplo: uplo.flip()?,
279                trans: trans.flip(false)?,
280                layout: Some(BLASColMajor),
281            };
282            return Ok(obj.driver()?.run_blas()?.reversed_axes());
283        } else {
284            return blas_raise!(RuntimeError, "This is designed not to execuate this line.");
285        }
286    }
287}
288
289/* #endregion */