1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6pub trait TBSVNum: BLASFloat {
9 unsafe fn tbsv(
10 uplo: *const c_char,
11 trans: *const c_char,
12 diag: *const c_char,
13 n: *const blas_int,
14 k: *const blas_int,
15 a: *const Self,
16 lda: *const blas_int,
17 x: *mut Self,
18 incx: *const blas_int,
19 );
20}
21
// Generates a `TBSVNum` impl for one scalar type `$scalar`, wiring its
// `tbsv` method to the FFI symbol `ffi::$ffi_name`.
macro_rules! impl_func {
    ($scalar:ty, $ffi_name:ident) => {
        impl TBSVNum for $scalar {
            unsafe fn tbsv(
                uplo: *const c_char,
                trans: *const c_char,
                diag: *const c_char,
                n: *const blas_int,
                k: *const blas_int,
                a: *const Self,
                lda: *const blas_int,
                x: *mut Self,
                incx: *const blas_int,
            ) {
                // Straight forward: argument order is the Fortran routine's order.
                ffi::$ffi_name(uplo, trans, diag, n, k, a, lda, x, incx);
            }
        }
    };
}
41
42impl_func!(f32, stbsv_);
43impl_func!(f64, dtbsv_);
44impl_func!(c32, ctbsv_);
45impl_func!(c64, ztbsv_);
46
47pub struct TBSV_Driver<'a, 'x, F>
52where
53 F: TBSVNum,
54{
55 uplo: c_char,
56 trans: c_char,
57 diag: c_char,
58 n: blas_int,
59 k: blas_int,
60 a: ArrayView2<'a, F>,
61 lda: blas_int,
62 x: ArrayOut1<'x, F>,
63 incx: blas_int,
64}
65
66impl<'a, 'x, F> BLASDriver<'x, F, Ix1> for TBSV_Driver<'a, 'x, F>
67where
68 F: TBSVNum,
69{
70 fn run_blas(self) -> Result<ArrayOut1<'x, F>, BLASError> {
71 let Self { uplo, trans, diag, n, k, a, lda, mut x, incx } = self;
72 let a_ptr = a.as_ptr();
73 let x_ptr = x.get_data_mut_ptr();
74
75 if n == 0 {
78 return Ok(x);
79 }
80
81 unsafe {
82 F::tbsv(&uplo, &trans, &diag, &n, &k, a_ptr, &lda, x_ptr, &incx);
83 }
84 return Ok(x);
85 }
86}
87
88#[derive(Builder)]
93#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
94pub struct TBSV_<'a, 'x, F>
95where
96 F: TBSVNum,
97{
98 pub a: ArrayView2<'a, F>,
99 pub x: ArrayViewMut1<'x, F>,
100
101 #[builder(setter(into), default = "BLASUpper")]
102 pub uplo: BLASUpLo,
103 #[builder(setter(into), default = "BLASNoTrans")]
104 pub trans: BLASTranspose,
105 #[builder(setter(into), default = "BLASNonUnit")]
106 pub diag: BLASDiag,
107 #[builder(setter(into, strip_option), default = "None")]
108 pub layout: Option<BLASLayout>,
109}
110
111impl<'a, 'x, F> BLASBuilder_<'x, F, Ix1> for TBSV_<'a, 'x, F>
112where
113 F: TBSVNum,
114{
115 fn driver(self) -> Result<TBSV_Driver<'a, 'x, F>, BLASError> {
116 let Self { a, x, uplo, trans, diag, layout } = self;
117
118 let layout_a = get_layout_array2(&a);
120 assert!(layout_a.is_fpref());
121 assert!(layout == Some(BLASLayout::ColMajor));
122
123 let (k_, n) = a.dim();
125 blas_assert!(k_ > 0, InvalidDim, "Rows of input `a` must larger than zero.")?;
126 let k = k_ - 1;
127 let lda = a.stride_of(Axis(1));
128 let incx = x.stride_of(Axis(0));
129
130 blas_assert_eq!(x.len_of(Axis(0)), n, InvalidDim)?;
132
133 let x = ArrayOut1::ViewMut(x);
135
136 let driver = TBSV_Driver {
138 uplo: uplo.try_into()?,
139 trans: trans.try_into()?,
140 diag: diag.try_into()?,
141 n: n.try_into()?,
142 k: k.try_into()?,
143 a,
144 lda: lda.try_into()?,
145 x,
146 incx: incx.try_into()?,
147 };
148 return Ok(driver);
149 }
150}
151
152pub type TBSV<'a, 'x, F> = TBSV_Builder<'a, 'x, F>;
157pub type STBSV<'a, 'x> = TBSV<'a, 'x, f32>;
158pub type DTBSV<'a, 'x> = TBSV<'a, 'x, f64>;
159pub type CTBSV<'a, 'x> = TBSV<'a, 'x, c32>;
160pub type ZTBSV<'a, 'x> = TBSV<'a, 'x, c64>;
161
162impl<'a, 'x, F> BLASBuilder<'x, F, Ix1> for TBSV_Builder<'a, 'x, F>
163where
164 F: TBSVNum,
165{
166 fn run(self) -> Result<ArrayOut1<'x, F>, BLASError> {
167 let obj = self.build()?;
169
170 let layout_a = get_layout_array2(&obj.a);
171 let layout = get_layout_row_preferred(&[obj.layout, Some(layout_a)], &[]);
172
173 if layout == BLASColMajor {
174 let a_cow = obj.a.to_col_layout()?;
176 let obj = TBSV_ { a: a_cow.view(), layout: Some(BLASColMajor), ..obj };
177 return obj.driver()?.run_blas();
178 } else {
179 let a_cow = obj.a.to_row_layout()?;
181 match obj.trans {
182 BLASNoTrans => {
183 let obj = TBSV_ {
185 a: a_cow.t(),
186 trans: BLASTrans,
187 uplo: obj.uplo.flip()?,
188 layout: Some(BLASColMajor),
189 ..obj
190 };
191 return obj.driver()?.run_blas();
192 },
193 BLASTrans => {
194 let obj = TBSV_ {
196 a: a_cow.t(),
197 trans: BLASNoTrans,
198 uplo: obj.uplo.flip()?,
199 layout: Some(BLASColMajor),
200 ..obj
201 };
202 return obj.driver()?.run_blas();
203 },
204 BLASConjTrans => {
205 let mut x = obj.x;
207 x.mapv_inplace(F::conj);
208 let obj = TBSV_ {
209 a: a_cow.t(),
210 x,
211 trans: BLASNoTrans,
212 uplo: obj.uplo.flip()?,
213 layout: Some(BLASColMajor),
214 ..obj
215 };
216 let mut x = obj.driver()?.run_blas()?;
217 x.view_mut().mapv_inplace(F::conj);
218 return Ok(x);
219 },
220 _ => return blas_invalid!(obj.trans)?,
221 }
222 }
223 }
224}
225
226