blas_array2/blas2/tbsv.rs

1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6/* #region BLAS func */
7
8pub trait TBSVNum: BLASFloat {
9    unsafe fn tbsv(
10        uplo: *const c_char,
11        trans: *const c_char,
12        diag: *const c_char,
13        n: *const blas_int,
14        k: *const blas_int,
15        a: *const Self,
16        lda: *const blas_int,
17        x: *mut Self,
18        incx: *const blas_int,
19    );
20}
21
22macro_rules! impl_func {
23    ($type: ty, $func: ident) => {
24        impl TBSVNum for $type {
25            unsafe fn tbsv(
26                uplo: *const c_char,
27                trans: *const c_char,
28                diag: *const c_char,
29                n: *const blas_int,
30                k: *const blas_int,
31                a: *const Self,
32                lda: *const blas_int,
33                x: *mut Self,
34                incx: *const blas_int,
35            ) {
36                ffi::$func(uplo, trans, diag, n, k, a, lda, x, incx);
37            }
38        }
39    };
40}
41
42impl_func!(f32, stbsv_);
43impl_func!(f64, dtbsv_);
44impl_func!(c32, ctbsv_);
45impl_func!(c64, ztbsv_);
46
47/* #endregion */
48
49/* #region BLAS driver */
50
51pub struct TBSV_Driver<'a, 'x, F>
52where
53    F: TBSVNum,
54{
55    uplo: c_char,
56    trans: c_char,
57    diag: c_char,
58    n: blas_int,
59    k: blas_int,
60    a: ArrayView2<'a, F>,
61    lda: blas_int,
62    x: ArrayOut1<'x, F>,
63    incx: blas_int,
64}
65
66impl<'a, 'x, F> BLASDriver<'x, F, Ix1> for TBSV_Driver<'a, 'x, F>
67where
68    F: TBSVNum,
69{
70    fn run_blas(self) -> Result<ArrayOut1<'x, F>, BLASError> {
71        let Self { uplo, trans, diag, n, k, a, lda, mut x, incx } = self;
72        let a_ptr = a.as_ptr();
73        let x_ptr = x.get_data_mut_ptr();
74
75        // assuming dimension checks has been performed
76        // unconditionally return Ok if output does not contain anything
77        if n == 0 {
78            return Ok(x);
79        }
80
81        unsafe {
82            F::tbsv(&uplo, &trans, &diag, &n, &k, a_ptr, &lda, x_ptr, &incx);
83        }
84        return Ok(x);
85    }
86}
87
88/* #endregion */
89
90/* #region BLAS builder */
91
92#[derive(Builder)]
93#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
94pub struct TBSV_<'a, 'x, F>
95where
96    F: TBSVNum,
97{
98    pub a: ArrayView2<'a, F>,
99    pub x: ArrayViewMut1<'x, F>,
100
101    #[builder(setter(into), default = "BLASUpper")]
102    pub uplo: BLASUpLo,
103    #[builder(setter(into), default = "BLASNoTrans")]
104    pub trans: BLASTranspose,
105    #[builder(setter(into), default = "BLASNonUnit")]
106    pub diag: BLASDiag,
107    #[builder(setter(into, strip_option), default = "None")]
108    pub layout: Option<BLASLayout>,
109}
110
111impl<'a, 'x, F> BLASBuilder_<'x, F, Ix1> for TBSV_<'a, 'x, F>
112where
113    F: TBSVNum,
114{
115    fn driver(self) -> Result<TBSV_Driver<'a, 'x, F>, BLASError> {
116        let Self { a, x, uplo, trans, diag, layout } = self;
117
118        // only fortran-preferred (col-major) is accepted in inner wrapper
119        let layout_a = get_layout_array2(&a);
120        assert!(layout_a.is_fpref());
121        assert!(layout == Some(BLASLayout::ColMajor));
122
123        // initialize intent(hide)
124        let (k_, n) = a.dim();
125        blas_assert!(k_ > 0, InvalidDim, "Rows of input `a` must larger than zero.")?;
126        let k = k_ - 1;
127        let lda = a.stride_of(Axis(1));
128        let incx = x.stride_of(Axis(0));
129
130        // perform check
131        blas_assert_eq!(x.len_of(Axis(0)), n, InvalidDim)?;
132
133        // prepare output
134        let x = ArrayOut1::ViewMut(x);
135
136        // finalize
137        let driver = TBSV_Driver {
138            uplo: uplo.try_into()?,
139            trans: trans.try_into()?,
140            diag: diag.try_into()?,
141            n: n.try_into()?,
142            k: k.try_into()?,
143            a,
144            lda: lda.try_into()?,
145            x,
146            incx: incx.try_into()?,
147        };
148        return Ok(driver);
149    }
150}
151
152/* #endregion */
153
154/* #region BLAS wrapper */
155
156pub type TBSV<'a, 'x, F> = TBSV_Builder<'a, 'x, F>;
157pub type STBSV<'a, 'x> = TBSV<'a, 'x, f32>;
158pub type DTBSV<'a, 'x> = TBSV<'a, 'x, f64>;
159pub type CTBSV<'a, 'x> = TBSV<'a, 'x, c32>;
160pub type ZTBSV<'a, 'x> = TBSV<'a, 'x, c64>;
161
162impl<'a, 'x, F> BLASBuilder<'x, F, Ix1> for TBSV_Builder<'a, 'x, F>
163where
164    F: TBSVNum,
165{
166    fn run(self) -> Result<ArrayOut1<'x, F>, BLASError> {
167        // initialize
168        let obj = self.build()?;
169
170        let layout_a = get_layout_array2(&obj.a);
171        let layout = get_layout_row_preferred(&[obj.layout, Some(layout_a)], &[]);
172
173        if layout == BLASColMajor {
174            // F-contiguous
175            let a_cow = obj.a.to_col_layout()?;
176            let obj = TBSV_ { a: a_cow.view(), layout: Some(BLASColMajor), ..obj };
177            return obj.driver()?.run_blas();
178        } else {
179            // C-contiguous
180            let a_cow = obj.a.to_row_layout()?;
181            match obj.trans {
182                BLASNoTrans => {
183                    // N -> T
184                    let obj = TBSV_ {
185                        a: a_cow.t(),
186                        trans: BLASTrans,
187                        uplo: obj.uplo.flip()?,
188                        layout: Some(BLASColMajor),
189                        ..obj
190                    };
191                    return obj.driver()?.run_blas();
192                },
193                BLASTrans => {
194                    // N -> T
195                    let obj = TBSV_ {
196                        a: a_cow.t(),
197                        trans: BLASNoTrans,
198                        uplo: obj.uplo.flip()?,
199                        layout: Some(BLASColMajor),
200                        ..obj
201                    };
202                    return obj.driver()?.run_blas();
203                },
204                BLASConjTrans => {
205                    // C -> N
206                    let mut x = obj.x;
207                    x.mapv_inplace(F::conj);
208                    let obj = TBSV_ {
209                        a: a_cow.t(),
210                        x,
211                        trans: BLASNoTrans,
212                        uplo: obj.uplo.flip()?,
213                        layout: Some(BLASColMajor),
214                        ..obj
215                    };
216                    let mut x = obj.driver()?.run_blas()?;
217                    x.view_mut().mapv_inplace(F::conj);
218                    return Ok(x);
219                },
220                _ => return blas_invalid!(obj.trans)?,
221            }
222        }
223    }
224}
225
226/* #endregion */