blas_array2/blas2/tpmv.rs

1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6/* #region BLAS func */
7
8pub trait TPMVNum: BLASFloat {
9    unsafe fn tpmv(
10        uplo: *const c_char,
11        trans: *const c_char,
12        diag: *const c_char,
13        n: *const blas_int,
14        ap: *const Self,
15        x: *mut Self,
16        incx: *const blas_int,
17    );
18}
19
20macro_rules! impl_func {
21    ($type: ty, $func: ident) => {
22        impl TPMVNum for $type {
23            unsafe fn tpmv(
24                uplo: *const c_char,
25                trans: *const c_char,
26                diag: *const c_char,
27                n: *const blas_int,
28                ap: *const Self,
29                x: *mut Self,
30                incx: *const blas_int,
31            ) {
32                ffi::$func(uplo, trans, diag, n, ap, x, incx);
33            }
34        }
35    };
36}
37
38impl_func!(f32, stpmv_);
39impl_func!(f64, dtpmv_);
40impl_func!(c32, ctpmv_);
41impl_func!(c64, ztpmv_);
42
43/* #endregion */
44
45/* #region BLAS driver */
46
47pub struct TPMV_Driver<'a, 'x, F>
48where
49    F: TPMVNum,
50{
51    uplo: c_char,
52    trans: c_char,
53    diag: c_char,
54    n: blas_int,
55    ap: ArrayView1<'a, F>,
56    x: ArrayOut1<'x, F>,
57    incx: blas_int,
58}
59
60impl<'a, 'x, F> BLASDriver<'x, F, Ix1> for TPMV_Driver<'a, 'x, F>
61where
62    F: TPMVNum,
63{
64    fn run_blas(self) -> Result<ArrayOut1<'x, F>, BLASError> {
65        let Self { uplo, trans, diag, n, ap, mut x, incx } = self;
66        let ap_ptr = ap.as_ptr();
67        let x_ptr = x.get_data_mut_ptr();
68
69        // assuming dimension checks has been performed
70        // unconditionally return Ok if output does not contain anything
71        if n == 0 {
72            return Ok(x);
73        }
74
75        unsafe {
76            F::tpmv(&uplo, &trans, &diag, &n, ap_ptr, x_ptr, &incx);
77        }
78        return Ok(x);
79    }
80}
81
82/* #endregion */
83
84/* #region BLAS builder */
85
86#[derive(Builder)]
87#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
88pub struct TPMV_<'a, 'x, F>
89where
90    F: TPMVNum,
91{
92    pub ap: ArrayView1<'a, F>,
93    pub x: ArrayViewMut1<'x, F>,
94
95    #[builder(setter(into), default = "BLASUpper")]
96    pub uplo: BLASUpLo,
97    #[builder(setter(into), default = "BLASNoTrans")]
98    pub trans: BLASTranspose,
99    #[builder(setter(into), default = "BLASNonUnit")]
100    pub diag: BLASDiag,
101    #[builder(setter(into, strip_option), default = "None")]
102    pub layout: Option<BLASLayout>,
103}
104
105impl<'a, 'x, F> BLASBuilder_<'x, F, Ix1> for TPMV_<'a, 'x, F>
106where
107    F: TPMVNum,
108{
109    fn driver(self) -> Result<TPMV_Driver<'a, 'x, F>, BLASError> {
110        let Self { ap, x, uplo, trans, diag, layout } = self;
111
112        // only fortran-preferred (col-major) is accepted in inner wrapper
113        let incap = ap.stride_of(Axis(0));
114        assert!(incap <= 1);
115        assert_eq!(layout, Some(BLASColMajor));
116
117        // initialize intent(hide)
118        let np = ap.len_of(Axis(0));
119        let n = x.len_of(Axis(0));
120        let incx = x.stride_of(Axis(0));
121
122        // perform check
123        blas_assert_eq!(np, n * (n + 1) / 2, InvalidDim)?;
124
125        // prepare output
126        let x = ArrayOut1::ViewMut(x);
127
128        // finalize
129        let driver = TPMV_Driver {
130            uplo: uplo.try_into()?,
131            trans: trans.try_into()?,
132            diag: diag.try_into()?,
133            n: n.try_into()?,
134            ap,
135            x,
136            incx: incx.try_into()?,
137        };
138        return Ok(driver);
139    }
140}
141
142/* #endregion */
143
144/* #region BLAS wrapper */
145
146pub type TPMV<'a, 'x, F> = TPMV_Builder<'a, 'x, F>;
147pub type STPMV<'a, 'x> = TPMV<'a, 'x, f32>;
148pub type DTPMV<'a, 'x> = TPMV<'a, 'x, f64>;
149pub type CTPMV<'a, 'x> = TPMV<'a, 'x, c32>;
150pub type ZTPMV<'a, 'x> = TPMV<'a, 'x, c64>;
151
152impl<'a, 'x, F> BLASBuilder<'x, F, Ix1> for TPMV_Builder<'a, 'x, F>
153where
154    F: TPMVNum,
155{
156    fn run(self) -> Result<ArrayOut1<'x, F>, BLASError> {
157        // initialize
158        let obj = self.build()?;
159
160        let layout = obj.layout.unwrap_or(BLASRowMajor);
161
162        if layout == BLASColMajor {
163            // F-contiguous
164            let ap_cow = obj.ap.to_seq_layout()?;
165            let obj = TPMV_ { ap: ap_cow.view(), layout: Some(BLASColMajor), ..obj };
166            return obj.driver()?.run_blas();
167        } else {
168            // C-contiguous
169            let ap_cow = obj.ap.to_seq_layout()?;
170            match obj.trans {
171                BLASNoTrans => {
172                    // N -> T
173                    let obj = TPMV_ {
174                        ap: ap_cow.view(),
175                        trans: BLASTrans,
176                        uplo: obj.uplo.flip()?,
177                        layout: Some(BLASColMajor),
178                        ..obj
179                    };
180                    return obj.driver()?.run_blas();
181                },
182                BLASTrans => {
183                    // T -> N
184                    let obj = TPMV_ {
185                        ap: ap_cow.view(),
186                        trans: BLASNoTrans,
187                        uplo: obj.uplo.flip()?,
188                        layout: Some(BLASColMajor),
189                        ..obj
190                    };
191                    return obj.driver()?.run_blas();
192                },
193                BLASConjTrans => {
194                    // C -> N
195                    let mut x = obj.x;
196                    x.mapv_inplace(F::conj);
197                    let obj = TPMV_ {
198                        ap: ap_cow.view(),
199                        x,
200                        trans: BLASNoTrans,
201                        uplo: obj.uplo.flip()?,
202                        layout: Some(BLASColMajor),
203                        ..obj
204                    };
205                    let mut x = obj.driver()?.run_blas()?;
206                    x.view_mut().mapv_inplace(F::conj);
207                    return Ok(x);
208                },
209                _ => return blas_invalid!(obj.trans)?,
210            }
211        }
212    }
213}
214
215/* #endregion */