1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
/// Scalar types for which a BLAS packed triangular matrix-vector product
/// (`?tpmv`) is available through the `ffi` bindings.
pub trait TPMVNum: BLASFloat {
    /// Raw call into the type-specific `?tpmv_` routine.
    ///
    /// Every argument, including scalars, is passed by pointer, matching the
    /// `ffi::?tpmv_` signatures this forwards to.
    ///
    /// # Safety
    ///
    /// All pointers must be valid for the accesses the BLAS routine performs.
    /// Under reference-BLAS semantics (the routine itself is not visible
    /// here), `ap` must hold the packed triangle of an order-`*n` matrix.
    /// `x` must address `*n` elements spaced `*incx` apart, and must point at
    /// the lowest-addressed element when `*incx < 0`.
    unsafe fn tpmv(
        uplo: *const c_char,
        trans: *const c_char,
        diag: *const c_char,
        n: *const blas_int,
        ap: *const Self,
        x: *mut Self,
        incx: *const blas_int,
    );
}
19
20macro_rules! impl_func {
21 ($type: ty, $func: ident) => {
22 impl TPMVNum for $type {
23 unsafe fn tpmv(
24 uplo: *const c_char,
25 trans: *const c_char,
26 diag: *const c_char,
27 n: *const blas_int,
28 ap: *const Self,
29 x: *mut Self,
30 incx: *const blas_int,
31 ) {
32 ffi::$func(uplo, trans, diag, n, ap, x, incx);
33 }
34 }
35 };
36}
37
38impl_func!(f32, stpmv_);
39impl_func!(f64, dtpmv_);
40impl_func!(c32, ctpmv_);
41impl_func!(c64, ztpmv_);
42
/// Fully-validated, FFI-ready arguments for one `?tpmv` call.
///
/// Instances are produced by `TPMV_::driver`, which performs the dimension
/// checks. `run_blas` relies on those checks having been done.
pub struct TPMV_Driver<'a, 'x, F>
where
    F: TPMVNum,
{
    // BLAS character flags converted from `BLASUpLo` / `BLASTranspose` / `BLASDiag`.
    uplo: c_char,
    trans: c_char,
    diag: c_char,
    // Matrix order, taken from the length of `x`.
    n: blas_int,
    // Packed triangular matrix; `driver` asserts its stride is at most 1.
    ap: ArrayView1<'a, F>,
    // Vector that is overwritten in place by the product.
    x: ArrayOut1<'x, F>,
    // Element stride of `x`, in elements.
    incx: blas_int,
}
59
60impl<'a, 'x, F> BLASDriver<'x, F, Ix1> for TPMV_Driver<'a, 'x, F>
61where
62 F: TPMVNum,
63{
64 fn run_blas(self) -> Result<ArrayOut1<'x, F>, BLASError> {
65 let Self { uplo, trans, diag, n, ap, mut x, incx } = self;
66 let ap_ptr = ap.as_ptr();
67 let x_ptr = x.get_data_mut_ptr();
68
69 if n == 0 {
72 return Ok(x);
73 }
74
75 unsafe {
76 F::tpmv(&uplo, &trans, &diag, &n, ap_ptr, x_ptr, &incx);
77 }
78 return Ok(x);
79 }
80}
81
/// Inner argument set for packed triangular matrix-vector multiplication.
///
/// `derive_builder` generates `TPMV_Builder` from this struct; that builder is
/// the public entry point (see `run`).
#[derive(Builder)]
#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
pub struct TPMV_<'a, 'x, F>
where
    F: TPMVNum,
{
    /// Packed triangular matrix (length must be `n * (n + 1) / 2`).
    pub ap: ArrayView1<'a, F>,
    /// Input vector, overwritten in place with the product.
    pub x: ArrayViewMut1<'x, F>,

    /// Which triangle `ap` stores; defaults to `BLASUpper`.
    #[builder(setter(into), default = "BLASUpper")]
    pub uplo: BLASUpLo,
    /// Operation applied to the matrix; defaults to `BLASNoTrans`.
    #[builder(setter(into), default = "BLASNoTrans")]
    pub trans: BLASTranspose,
    /// Whether the diagonal is implicit unit; defaults to `BLASNonUnit`.
    #[builder(setter(into), default = "BLASNonUnit")]
    pub diag: BLASDiag,
    /// Storage layout of `ap`; `None` is treated as `BLASRowMajor` by `run`.
    #[builder(setter(into, strip_option), default = "None")]
    pub layout: Option<BLASLayout>,
}
104
105impl<'a, 'x, F> BLASBuilder_<'x, F, Ix1> for TPMV_<'a, 'x, F>
106where
107 F: TPMVNum,
108{
109 fn driver(self) -> Result<TPMV_Driver<'a, 'x, F>, BLASError> {
110 let Self { ap, x, uplo, trans, diag, layout } = self;
111
112 let incap = ap.stride_of(Axis(0));
114 assert!(incap <= 1);
115 assert_eq!(layout, Some(BLASColMajor));
116
117 let np = ap.len_of(Axis(0));
119 let n = x.len_of(Axis(0));
120 let incx = x.stride_of(Axis(0));
121
122 blas_assert_eq!(np, n * (n + 1) / 2, InvalidDim)?;
124
125 let x = ArrayOut1::ViewMut(x);
127
128 let driver = TPMV_Driver {
130 uplo: uplo.try_into()?,
131 trans: trans.try_into()?,
132 diag: diag.try_into()?,
133 n: n.try_into()?,
134 ap,
135 x,
136 incx: incx.try_into()?,
137 };
138 return Ok(driver);
139 }
140}
141
142pub type TPMV<'a, 'x, F> = TPMV_Builder<'a, 'x, F>;
147pub type STPMV<'a, 'x> = TPMV<'a, 'x, f32>;
148pub type DTPMV<'a, 'x> = TPMV<'a, 'x, f64>;
149pub type CTPMV<'a, 'x> = TPMV<'a, 'x, c32>;
150pub type ZTPMV<'a, 'x> = TPMV<'a, 'x, c64>;
151
152impl<'a, 'x, F> BLASBuilder<'x, F, Ix1> for TPMV_Builder<'a, 'x, F>
153where
154 F: TPMVNum,
155{
156 fn run(self) -> Result<ArrayOut1<'x, F>, BLASError> {
157 let obj = self.build()?;
159
160 let layout = obj.layout.unwrap_or(BLASRowMajor);
161
162 if layout == BLASColMajor {
163 let ap_cow = obj.ap.to_seq_layout()?;
165 let obj = TPMV_ { ap: ap_cow.view(), layout: Some(BLASColMajor), ..obj };
166 return obj.driver()?.run_blas();
167 } else {
168 let ap_cow = obj.ap.to_seq_layout()?;
170 match obj.trans {
171 BLASNoTrans => {
172 let obj = TPMV_ {
174 ap: ap_cow.view(),
175 trans: BLASTrans,
176 uplo: obj.uplo.flip()?,
177 layout: Some(BLASColMajor),
178 ..obj
179 };
180 return obj.driver()?.run_blas();
181 },
182 BLASTrans => {
183 let obj = TPMV_ {
185 ap: ap_cow.view(),
186 trans: BLASNoTrans,
187 uplo: obj.uplo.flip()?,
188 layout: Some(BLASColMajor),
189 ..obj
190 };
191 return obj.driver()?.run_blas();
192 },
193 BLASConjTrans => {
194 let mut x = obj.x;
196 x.mapv_inplace(F::conj);
197 let obj = TPMV_ {
198 ap: ap_cow.view(),
199 x,
200 trans: BLASNoTrans,
201 uplo: obj.uplo.flip()?,
202 layout: Some(BLASColMajor),
203 ..obj
204 };
205 let mut x = obj.driver()?.run_blas()?;
206 x.view_mut().mapv_inplace(F::conj);
207 return Ok(x);
208 },
209 _ => return blas_invalid!(obj.trans)?,
210 }
211 }
212 }
213}
214
215