blas_array2/blas2/
trmv.rs

1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6/* #region BLAS func */
7
8pub trait TRMVNum: BLASFloat {
9    unsafe fn trmv(
10        uplo: *const c_char,
11        trans: *const c_char,
12        diag: *const c_char,
13        n: *const blas_int,
14        a: *const Self,
15        lda: *const blas_int,
16        x: *mut Self,
17        incx: *const blas_int,
18    );
19}
20
21macro_rules! impl_func {
22    ($type: ty, $func: ident) => {
23        impl TRMVNum for $type {
24            unsafe fn trmv(
25                uplo: *const c_char,
26                trans: *const c_char,
27                diag: *const c_char,
28                n: *const blas_int,
29                a: *const Self,
30                lda: *const blas_int,
31                x: *mut Self,
32                incx: *const blas_int,
33            ) {
34                ffi::$func(uplo, trans, diag, n, a, lda, x, incx);
35            }
36        }
37    };
38}
39
40impl_func!(f32, strmv_);
41impl_func!(f64, dtrmv_);
42impl_func!(c32, ctrmv_);
43impl_func!(c64, ztrmv_);
44
45/* #endregion */
46
47/* #region BLAS driver */
48
49pub struct TRMV_Driver<'a, 'x, F>
50where
51    F: TRMVNum,
52{
53    uplo: c_char,
54    trans: c_char,
55    diag: c_char,
56    n: blas_int,
57    a: ArrayView2<'a, F>,
58    lda: blas_int,
59    x: ArrayOut1<'x, F>,
60    incx: blas_int,
61}
62
63impl<'a, 'x, F> BLASDriver<'x, F, Ix1> for TRMV_Driver<'a, 'x, F>
64where
65    F: TRMVNum,
66{
67    fn run_blas(self) -> Result<ArrayOut1<'x, F>, BLASError> {
68        let Self { uplo, trans, diag, n, a, lda, mut x, incx } = self;
69        let a_ptr = a.as_ptr();
70        let x_ptr = x.get_data_mut_ptr();
71
72        // assuming dimension checks has been performed
73        // unconditionally return Ok if output does not contain anything
74        if n == 0 {
75            return Ok(x);
76        }
77
78        unsafe {
79            F::trmv(&uplo, &trans, &diag, &n, a_ptr, &lda, x_ptr, &incx);
80        }
81        return Ok(x);
82    }
83}
84
85/* #endregion */
86
87/* #region BLAS builder */
88
89#[derive(Builder)]
90#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
91pub struct TRMV_<'a, 'x, F>
92where
93    F: TRMVNum,
94{
95    pub a: ArrayView2<'a, F>,
96    pub x: ArrayViewMut1<'x, F>,
97
98    #[builder(setter(into), default = "BLASUpper")]
99    pub uplo: BLASUpLo,
100    #[builder(setter(into), default = "BLASNoTrans")]
101    pub trans: BLASTranspose,
102    #[builder(setter(into), default = "BLASNonUnit")]
103    pub diag: BLASDiag,
104}
105
106impl<'a, 'x, F> BLASBuilder_<'x, F, Ix1> for TRMV_<'a, 'x, F>
107where
108    F: TRMVNum,
109{
110    fn driver(self) -> Result<TRMV_Driver<'a, 'x, F>, BLASError> {
111        let Self { a, x, uplo, trans, diag } = self;
112
113        // only fortran-preferred (col-major) is accepted in inner wrapper
114        let layout_a = get_layout_array2(&a);
115        assert!(layout_a.is_fpref());
116
117        // initialize intent(hide)
118        let (n, n_) = a.dim();
119        let lda = a.stride_of(Axis(1));
120        let incx = x.stride_of(Axis(0));
121
122        // perform check
123        blas_assert_eq!(n, n_, InvalidDim)?;
124        blas_assert_eq!(x.len_of(Axis(0)), n, InvalidDim)?;
125
126        // prepare output
127        let x = ArrayOut1::ViewMut(x);
128
129        // finalize
130        let driver = TRMV_Driver {
131            uplo: uplo.try_into()?,
132            trans: trans.try_into()?,
133            diag: diag.try_into()?,
134            n: n.try_into()?,
135            a,
136            lda: lda.try_into()?,
137            x,
138            incx: incx.try_into()?,
139        };
140        return Ok(driver);
141    }
142}
143
144/* #endregion */
145
146/* #region BLAS wrapper */
147
148pub type TRMV<'a, 'x, F> = TRMV_Builder<'a, 'x, F>;
149pub type STRMV<'a, 'x> = TRMV<'a, 'x, f32>;
150pub type DTRMV<'a, 'x> = TRMV<'a, 'x, f64>;
151pub type CTRMV<'a, 'x> = TRMV<'a, 'x, c32>;
152pub type ZTRMV<'a, 'x> = TRMV<'a, 'x, c64>;
153
154impl<'a, 'x, F> BLASBuilder<'x, F, Ix1> for TRMV_Builder<'a, 'x, F>
155where
156    F: TRMVNum,
157{
158    fn run(self) -> Result<ArrayOut1<'x, F>, BLASError> {
159        // initialize
160        let obj = self.build()?;
161
162        let layout_a = get_layout_array2(&obj.a);
163
164        if layout_a.is_fpref() {
165            // F-contiguous: x = op(A) x
166            return obj.driver()?.run_blas();
167        } else {
168            // C-contiguous:
169            let a_cow = obj.a.to_row_layout()?;
170            match obj.trans {
171                BLASNoTrans => {
172                    // N -> T: x = op(A')' x
173                    let obj = TRMV_ { a: a_cow.t(), trans: BLASTrans, uplo: obj.uplo.flip()?, ..obj };
174                    return obj.driver()?.run_blas();
175                },
176                BLASTrans => {
177                    // T -> N: x = op(A') x
178                    let obj = TRMV_ { a: a_cow.t(), trans: BLASNoTrans, uplo: obj.uplo.flip()?, ..obj };
179                    return obj.driver()?.run_blas();
180                },
181                BLASConjTrans => {
182                    // C -> T: x* = op(A') x*; x = x*
183                    let mut x = obj.x;
184                    x.mapv_inplace(F::conj);
185                    let obj = TRMV_ { a: a_cow.t(), x, trans: BLASNoTrans, uplo: obj.uplo.flip()?, ..obj };
186                    let mut x = obj.driver()?.run_blas()?;
187                    x.view_mut().mapv_inplace(F::conj);
188                    return Ok(x);
189                },
190                _ => return blas_invalid!(&obj.trans)?,
191            }
192        }
193    }
194}
195
196/* #endregion */