blas_array2/blas3/symm.rs

1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6/* #region BLAS func */
7
8pub trait SYMMNum: BLASFloat {
9    unsafe fn symm(
10        side: *const c_char,
11        uplo: *const c_char,
12        m: *const blas_int,
13        n: *const blas_int,
14        alpha: *const Self,
15        a: *const Self,
16        lda: *const blas_int,
17        b: *const Self,
18        ldb: *const blas_int,
19        beta: *const Self,
20        c: *mut Self,
21        ldc: *const blas_int,
22    );
23}
24
// Generates `impl SYMMNum for <scalar>` that forwards straight to the given FFI symbol.
// Usage: `impl_func!(<scalar type>, <ffi function name>);`
macro_rules! impl_func {
    ($ty:ty, $sym:ident) => {
        impl SYMMNum for $ty {
            unsafe fn symm(
                side: *const c_char,
                uplo: *const c_char,
                m: *const blas_int,
                n: *const blas_int,
                alpha: *const Self,
                a: *const Self,
                lda: *const blas_int,
                b: *const Self,
                ldb: *const blas_int,
                beta: *const Self,
                c: *mut Self,
                ldc: *const blas_int,
            ) {
                // Pure pass-through: the validity of every pointer and dimension is the
                // responsibility of whoever calls `SYMMNum::symm`.
                ffi::$sym(side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc);
            }
        }
    };
}
47
48impl_func!(f32, ssymm_);
49impl_func!(f64, dsymm_);
50impl_func!(c32, csymm_);
51impl_func!(c64, zsymm_);
52
53/* #endregion */
54
55/* #region BLAS driver */
56
57pub struct SYMM_Driver<'a, 'b, 'c, F>
58where
59    F: SYMMNum,
60{
61    side: c_char,
62    uplo: c_char,
63    m: blas_int,
64    n: blas_int,
65    alpha: F,
66    a: ArrayView2<'a, F>,
67    lda: blas_int,
68    b: ArrayView2<'b, F>,
69    ldb: blas_int,
70    beta: F,
71    c: ArrayOut2<'c, F>,
72    ldc: blas_int,
73}
74
75impl<'a, 'b, 'c, F> BLASDriver<'c, F, Ix2> for SYMM_Driver<'a, 'b, 'c, F>
76where
77    F: SYMMNum,
78{
79    fn run_blas(self) -> Result<ArrayOut2<'c, F>, BLASError> {
80        let Self { side, uplo, m, n, alpha, a, lda, b, ldb, beta, mut c, ldc, .. } = self;
81        let a_ptr = a.as_ptr();
82        let b_ptr = b.as_ptr();
83        let c_ptr = c.get_data_mut_ptr();
84
85        // assuming dimension checks has been performed
86        // unconditionally return Ok if output does not contain anything
87        if m == 0 || n == 0 {
88            return Ok(c.clone_to_view_mut());
89        }
90
91        unsafe {
92            F::symm(&side, &uplo, &m, &n, &alpha, a_ptr, &lda, b_ptr, &ldb, &beta, c_ptr, &ldc);
93        }
94        return Ok(c.clone_to_view_mut());
95    }
96}
97
98/* #endregion */
99
100/* #region BLAS builder */
101
102#[derive(Builder)]
103#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
104pub struct SYMM_<'a, 'b, 'c, F>
105where
106    F: BLASFloat,
107{
108    pub a: ArrayView2<'a, F>,
109    pub b: ArrayView2<'b, F>,
110
111    #[builder(setter(into, strip_option), default = "None")]
112    pub c: Option<ArrayViewMut2<'c, F>>,
113    #[builder(setter(into), default = "F::one()")]
114    pub alpha: F,
115    #[builder(setter(into), default = "F::zero()")]
116    pub beta: F,
117    #[builder(setter(into), default = "BLASLeft")]
118    pub side: BLASSide,
119    #[builder(setter(into), default = "BLASLower")]
120    pub uplo: BLASUpLo,
121    #[builder(setter(into, strip_option), default = "None")]
122    pub layout: Option<BLASLayout>,
123}
124
125impl<'a, 'b, 'c, F> BLASBuilder_<'c, F, Ix2> for SYMM_<'a, 'b, 'c, F>
126where
127    F: SYMMNum,
128{
129    fn driver(self) -> Result<SYMM_Driver<'a, 'b, 'c, F>, BLASError> {
130        let Self { a, b, c, alpha, beta, side, uplo, layout, .. } = self;
131
132        // only fortran-preferred (col-major) is accepted in inner wrapper
133        assert_eq!(layout, Some(BLASColMajor));
134        assert!(a.is_fpref() && a.is_fpref());
135
136        // initialize intent(hide)
137        let m = b.len_of(Axis(0));
138        let n = b.len_of(Axis(1));
139        let lda = a.stride_of(Axis(1));
140        let ldb = b.stride_of(Axis(1));
141
142        // perform check
143        match side {
144            BLASLeft => blas_assert_eq!(a.dim(), (m, m), InvalidDim)?,
145            BLASRight => blas_assert_eq!(a.dim(), (n, n), InvalidDim)?,
146            _ => blas_invalid!(side)?,
147        }
148
149        // optional intent(out)
150        let c = match c {
151            Some(c) => {
152                blas_assert_eq!(c.dim(), (m, n), InvalidDim)?;
153                if c.view().is_fpref() {
154                    ArrayOut2::ViewMut(c)
155                } else {
156                    let c_buffer = c.view().to_col_layout()?.into_owned();
157                    ArrayOut2::ToBeCloned(c, c_buffer)
158                }
159            },
160            None => ArrayOut2::Owned(Array2::zeros((m, n).f())),
161        };
162        let ldc = c.view().stride_of(Axis(1));
163
164        // finalize
165        let driver = SYMM_Driver::<'a, 'b, 'c, F> {
166            side: side.try_into()?,
167            uplo: uplo.try_into()?,
168            m: m.try_into()?,
169            n: n.try_into()?,
170            alpha,
171            a,
172            lda: lda.try_into()?,
173            b,
174            ldb: ldb.try_into()?,
175            beta,
176            c,
177            ldc: ldc.try_into()?,
178        };
179        return Ok(driver);
180    }
181}
182
183/* #endregion */
184
185/* #region BLAS wrapper */
186
187pub type SYMM<'a, 'b, 'c, F> = SYMM_Builder<'a, 'b, 'c, F>;
188pub type SSYMM<'a, 'b, 'c> = SYMM<'a, 'b, 'c, f32>;
189pub type DSYMM<'a, 'b, 'c> = SYMM<'a, 'b, 'c, f64>;
190pub type CSYMM<'a, 'b, 'c> = SYMM<'a, 'b, 'c, c32>;
191pub type ZSYMM<'a, 'b, 'c> = SYMM<'a, 'b, 'c, c64>;
192
193impl<'a, 'b, 'c, F> BLASBuilder<'c, F, Ix2> for SYMM_Builder<'a, 'b, 'c, F>
194where
195    F: SYMMNum,
196{
197    fn run(self) -> Result<ArrayOut2<'c, F>, BLASError> {
198        // initialize
199        let SYMM_ { a, b, c, alpha, beta, side, uplo, layout, .. } = self.build()?;
200        let at = a.t();
201
202        let layout_a = get_layout_array2(&a);
203        let layout_b = get_layout_array2(&b);
204        let layout_c = c.as_ref().map(|c| get_layout_array2(&c.view()));
205
206        let layout = get_layout_row_preferred(&[layout, layout_c], &[layout_a, layout_b]);
207        if layout == BLASColMajor {
208            // F-contiguous: C = op(A) op(B)
209            let (uplo, a_cow) = match layout_a.is_fpref() {
210                true => (uplo, a.to_col_layout()?),
211                false => (uplo.flip()?, at.to_col_layout()?),
212            };
213            let b_cow = b.to_col_layout()?;
214            let obj = SYMM_ {
215                a: a_cow.view(),
216                b: b_cow.view(),
217                c,
218                alpha,
219                beta,
220                side,
221                uplo,
222                layout: Some(BLASColMajor),
223            };
224            return obj.driver()?.run_blas();
225        } else {
226            // C-contiguous: C' = op(B') op(A')
227            let (uplo, a_cow) = match layout_a.is_cpref() {
228                true => (uplo, a.to_row_layout()?),
229                false => (uplo.flip()?, at.to_row_layout()?),
230            };
231            let b_cow = b.to_row_layout()?;
232            let obj = SYMM_ {
233                a: a_cow.t(),
234                b: b_cow.t(),
235                c: c.map(|c| c.reversed_axes()),
236                alpha,
237                beta,
238                side: side.flip()?,
239                uplo: uplo.flip()?,
240                layout: Some(BLASColMajor),
241            };
242            let c = obj.driver()?.run_blas()?.reversed_axes();
243            return Ok(c);
244        }
245    }
246}
247
248/* #endregion */