1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6pub trait TRMVNum: BLASFloat {
9 unsafe fn trmv(
10 uplo: *const c_char,
11 trans: *const c_char,
12 diag: *const c_char,
13 n: *const blas_int,
14 a: *const Self,
15 lda: *const blas_int,
16 x: *mut Self,
17 incx: *const blas_int,
18 );
19}
20
21macro_rules! impl_func {
22 ($type: ty, $func: ident) => {
23 impl TRMVNum for $type {
24 unsafe fn trmv(
25 uplo: *const c_char,
26 trans: *const c_char,
27 diag: *const c_char,
28 n: *const blas_int,
29 a: *const Self,
30 lda: *const blas_int,
31 x: *mut Self,
32 incx: *const blas_int,
33 ) {
34 ffi::$func(uplo, trans, diag, n, a, lda, x, incx);
35 }
36 }
37 };
38}
39
40impl_func!(f32, strmv_);
41impl_func!(f64, dtrmv_);
42impl_func!(c32, ctrmv_);
43impl_func!(c64, ztrmv_);
44
45pub struct TRMV_Driver<'a, 'x, F>
50where
51 F: TRMVNum,
52{
53 uplo: c_char,
54 trans: c_char,
55 diag: c_char,
56 n: blas_int,
57 a: ArrayView2<'a, F>,
58 lda: blas_int,
59 x: ArrayOut1<'x, F>,
60 incx: blas_int,
61}
62
63impl<'a, 'x, F> BLASDriver<'x, F, Ix1> for TRMV_Driver<'a, 'x, F>
64where
65 F: TRMVNum,
66{
67 fn run_blas(self) -> Result<ArrayOut1<'x, F>, BLASError> {
68 let Self { uplo, trans, diag, n, a, lda, mut x, incx } = self;
69 let a_ptr = a.as_ptr();
70 let x_ptr = x.get_data_mut_ptr();
71
72 if n == 0 {
75 return Ok(x);
76 }
77
78 unsafe {
79 F::trmv(&uplo, &trans, &diag, &n, a_ptr, &lda, x_ptr, &incx);
80 }
81 return Ok(x);
82 }
83}
84
/// User-facing arguments of the triangular matrix-vector product `?trmv`,
/// i.e. `x := op(A) x` with `op(A)` one of `A`, `A^T`, `A^H`.
///
/// A builder (`TRMV_Builder`, owned pattern) is generated by `derive_builder`.
/// `a` and `x` have no default and must be set before `build()`. `uplo`,
/// `trans` and `diag` have the defaults given in their `#[builder]` attributes
/// and their setters accept anything convertible via `Into`.
#[derive(Builder)]
#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
pub struct TRMV_<'a, 'x, F>
where
    F: TRMVNum,
{
    /// Square triangular matrix `A`. Per BLAS `?trmv`, only the triangle
    /// selected by `uplo` (and, for `BLASUnit`, not the diagonal) is referenced.
    pub a: ArrayView2<'a, F>,
    /// Vector `x`; overwritten in place with `op(A) x`.
    pub x: ArrayViewMut1<'x, F>,

    /// Which triangle of `a` holds the matrix (default `BLASUpper`).
    #[builder(setter(into), default = "BLASUpper")]
    pub uplo: BLASUpLo,
    /// Which operation `op(A)` to apply (default `BLASNoTrans`).
    #[builder(setter(into), default = "BLASNoTrans")]
    pub trans: BLASTranspose,
    /// Whether `a` is treated as unit-diagonal (default `BLASNonUnit`).
    #[builder(setter(into), default = "BLASNonUnit")]
    pub diag: BLASDiag,
}
105
106impl<'a, 'x, F> BLASBuilder_<'x, F, Ix1> for TRMV_<'a, 'x, F>
107where
108 F: TRMVNum,
109{
110 fn driver(self) -> Result<TRMV_Driver<'a, 'x, F>, BLASError> {
111 let Self { a, x, uplo, trans, diag } = self;
112
113 let layout_a = get_layout_array2(&a);
115 assert!(layout_a.is_fpref());
116
117 let (n, n_) = a.dim();
119 let lda = a.stride_of(Axis(1));
120 let incx = x.stride_of(Axis(0));
121
122 blas_assert_eq!(n, n_, InvalidDim)?;
124 blas_assert_eq!(x.len_of(Axis(0)), n, InvalidDim)?;
125
126 let x = ArrayOut1::ViewMut(x);
128
129 let driver = TRMV_Driver {
131 uplo: uplo.try_into()?,
132 trans: trans.try_into()?,
133 diag: diag.try_into()?,
134 n: n.try_into()?,
135 a,
136 lda: lda.try_into()?,
137 x,
138 incx: incx.try_into()?,
139 };
140 return Ok(driver);
141 }
142}
143
144pub type TRMV<'a, 'x, F> = TRMV_Builder<'a, 'x, F>;
149pub type STRMV<'a, 'x> = TRMV<'a, 'x, f32>;
150pub type DTRMV<'a, 'x> = TRMV<'a, 'x, f64>;
151pub type CTRMV<'a, 'x> = TRMV<'a, 'x, c32>;
152pub type ZTRMV<'a, 'x> = TRMV<'a, 'x, c64>;
153
154impl<'a, 'x, F> BLASBuilder<'x, F, Ix1> for TRMV_Builder<'a, 'x, F>
155where
156 F: TRMVNum,
157{
158 fn run(self) -> Result<ArrayOut1<'x, F>, BLASError> {
159 let obj = self.build()?;
161
162 let layout_a = get_layout_array2(&obj.a);
163
164 if layout_a.is_fpref() {
165 return obj.driver()?.run_blas();
167 } else {
168 let a_cow = obj.a.to_row_layout()?;
170 match obj.trans {
171 BLASNoTrans => {
172 let obj = TRMV_ { a: a_cow.t(), trans: BLASTrans, uplo: obj.uplo.flip()?, ..obj };
174 return obj.driver()?.run_blas();
175 },
176 BLASTrans => {
177 let obj = TRMV_ { a: a_cow.t(), trans: BLASNoTrans, uplo: obj.uplo.flip()?, ..obj };
179 return obj.driver()?.run_blas();
180 },
181 BLASConjTrans => {
182 let mut x = obj.x;
184 x.mapv_inplace(F::conj);
185 let obj = TRMV_ { a: a_cow.t(), x, trans: BLASNoTrans, uplo: obj.uplo.flip()?, ..obj };
186 let mut x = obj.driver()?.run_blas()?;
187 x.view_mut().mapv_inplace(F::conj);
188 return Ok(x);
189 },
190 _ => return blas_invalid!(&obj.trans)?,
191 }
192 }
193 }
194}
195
196