blas_array2/blas3/
trmm.rs

1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6/* #region BLAS func */
7
8pub trait TRMMNum: BLASFloat {
9    unsafe fn trmm(
10        side: *const c_char,
11        uplo: *const c_char,
12        transa: *const c_char,
13        diag: *const c_char,
14        m: *const blas_int,
15        n: *const blas_int,
16        alpha: *const Self,
17        a: *const Self,
18        lda: *const blas_int,
19        b: *mut Self,
20        ldb: *const blas_int,
21    );
22}
23
24macro_rules! impl_func {
25    ($type: ty, $func: ident) => {
26        impl TRMMNum for $type {
27            unsafe fn trmm(
28                side: *const c_char,
29                uplo: *const c_char,
30                transa: *const c_char,
31                diag: *const c_char,
32                m: *const blas_int,
33                n: *const blas_int,
34                alpha: *const Self,
35                a: *const Self,
36                lda: *const blas_int,
37                b: *mut Self,
38                ldb: *const blas_int,
39            ) {
40                ffi::$func(side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb);
41            }
42        }
43    };
44}
45
46impl_func!(f32, strmm_);
47impl_func!(f64, dtrmm_);
48impl_func!(c32, ctrmm_);
49impl_func!(c64, ztrmm_);
50
51/* #endregion */
52
53/* #region BLAS driver */
54
55pub struct TRMM_Driver<'a, 'b, F>
56where
57    F: BLASFloat,
58{
59    side: c_char,
60    uplo: c_char,
61    transa: c_char,
62    diag: c_char,
63    m: blas_int,
64    n: blas_int,
65    alpha: F,
66    a: ArrayView2<'a, F>,
67    lda: blas_int,
68    b: ArrayOut2<'b, F>,
69    ldb: blas_int,
70}
71
72impl<'a, 'b, F> BLASDriver<'b, F, Ix2> for TRMM_Driver<'a, 'b, F>
73where
74    F: TRMMNum,
75{
76    fn run_blas(self) -> Result<ArrayOut2<'b, F>, BLASError> {
77        let Self { side, uplo, transa, diag, m, n, alpha, a, lda, mut b, ldb } = self;
78        let a_ptr = a.as_ptr();
79        let b_ptr = b.get_data_mut_ptr();
80
81        // assuming dimension checks has been performed
82        // unconditionally return Ok if output does not contain anything
83        if m == 0 || n == 0 {
84            return Ok(b.clone_to_view_mut());
85        }
86
87        unsafe {
88            F::trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha, a_ptr, &lda, b_ptr, &ldb);
89        }
90        return Ok(b.clone_to_view_mut());
91    }
92}
93
94/* #endregion */
95
96/* #region BLAS builder */
97
98#[derive(Builder)]
99#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
100pub struct TRMM_<'a, 'b, F>
101where
102    F: TRMMNum,
103{
104    pub a: ArrayView2<'a, F>,
105    pub b: ArrayViewMut2<'b, F>,
106
107    #[builder(setter(into), default = "F::one()")]
108    pub alpha: F,
109    #[builder(setter(into), default = "BLASLeft")]
110    pub side: BLASSide,
111    #[builder(setter(into), default = "BLASUpper")]
112    pub uplo: BLASUpLo,
113    #[builder(setter(into), default = "BLASNoTrans")]
114    pub transa: BLASTranspose,
115    #[builder(setter(into), default = "BLASNonUnit")]
116    pub diag: BLASDiag,
117    #[builder(setter(into, strip_option), default = "None")]
118    pub layout: Option<BLASLayout>,
119}
120
121impl<'a, 'b, F> BLASBuilder_<'b, F, Ix2> for TRMM_<'a, 'b, F>
122where
123    F: TRMMNum,
124{
125    fn driver(self) -> Result<TRMM_Driver<'a, 'b, F>, BLASError> {
126        let Self { a, b, alpha, side, uplo, transa, diag, layout } = self;
127
128        // only fortran-preferred (col-major) is accepted in inner wrapper
129        assert_eq!(layout, Some(BLASColMajor));
130        assert!(a.is_fpref());
131
132        // initialize intent(hide)
133        let (m, n) = b.dim();
134        let lda = a.stride_of(Axis(1));
135
136        // perform check
137        match side {
138            BLASLeft => blas_assert_eq!(a.dim(), (m, m), InvalidDim)?,
139            BLASRight => blas_assert_eq!(a.dim(), (n, n), InvalidDim)?,
140            _ => blas_invalid!(side)?,
141        };
142
143        // prepare output
144        let b = if b.view().is_fpref() {
145            ArrayOut2::ViewMut(b)
146        } else {
147            let b_buffer = b.view().to_col_layout()?.into_owned();
148            ArrayOut2::ToBeCloned(b, b_buffer)
149        };
150        let ldb = b.view().stride_of(Axis(1));
151
152        // finalize
153        let driver = TRMM_Driver {
154            side: side.try_into()?,
155            uplo: uplo.try_into()?,
156            transa: transa.try_into()?,
157            diag: diag.try_into()?,
158            m: m.try_into()?,
159            n: n.try_into()?,
160            alpha,
161            a,
162            lda: lda.try_into()?,
163            b,
164            ldb: ldb.try_into()?,
165        };
166        return Ok(driver);
167    }
168}
169
170/* #endregion */
171
172/* #region BLAS wrapper */
173
174pub type TRMM<'a, 'b, F> = TRMM_Builder<'a, 'b, F>;
175pub type STRMM<'a, 'b> = TRMM<'a, 'b, f32>;
176pub type DTRMM<'a, 'b> = TRMM<'a, 'b, f64>;
177pub type CTRMM<'a, 'b> = TRMM<'a, 'b, c32>;
178pub type ZTRMM<'a, 'b> = TRMM<'a, 'b, c64>;
179
180impl<'a, 'b, F> BLASBuilder<'b, F, Ix2> for TRMM_Builder<'a, 'b, F>
181where
182    F: TRMMNum,
183{
184    fn run(self) -> Result<ArrayOut2<'b, F>, BLASError> {
185        // initialize
186        let TRMM_ { a, b, alpha, side, uplo, transa, diag, layout } = self.build()?;
187        let at = a.t();
188
189        let layout_a = get_layout_array2(&a);
190        let layout_b = get_layout_array2(&b.view());
191
192        let layout = get_layout_row_preferred(&[layout, Some(layout_b)], &[layout_a]);
193        if layout == BLASColMajor {
194            // F-contiguous: B = op(A) B (if side = L)
195            let (transa_new, a_cow) = flip_trans_fpref(transa, &a, &at, false)?;
196            let uplo = if transa_new != transa { uplo.flip()? } else { uplo };
197            let obj = TRMM_ {
198                a: a_cow.view(),
199                b,
200                alpha,
201                side,
202                uplo,
203                transa: transa_new,
204                diag,
205                layout: Some(BLASColMajor),
206            };
207            return obj.driver()?.run_blas();
208        } else {
209            // C-contiguous: B' = B' op(A') (if side = L)
210            let (transa_new, a_cow) = flip_trans_cpref(transa, &a, &at, false)?;
211            let uplo = if transa_new != transa { uplo.flip()? } else { uplo };
212            let obj = TRMM_ {
213                a: a_cow.t(),
214                b: b.reversed_axes(),
215                alpha,
216                side: side.flip()?,
217                uplo: uplo.flip()?,
218                transa: transa_new,
219                diag,
220                layout: Some(BLASColMajor),
221            };
222            return Ok(obj.driver()?.run_blas()?.reversed_axes());
223        }
224    }
225}
226
227/* #endregion */