1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6pub trait TRSMNum: BLASFloat {
9 unsafe fn trsm(
10 side: *const c_char,
11 uplo: *const c_char,
12 transa: *const c_char,
13 diag: *const c_char,
14 m: *const blas_int,
15 n: *const blas_int,
16 alpha: *const Self,
17 a: *const Self,
18 lda: *const blas_int,
19 b: *mut Self,
20 ldb: *const blas_int,
21 );
22}
23
24macro_rules! impl_func {
25 ($type: ty, $func: ident) => {
26 impl TRSMNum for $type {
27 unsafe fn trsm(
28 side: *const c_char,
29 uplo: *const c_char,
30 transa: *const c_char,
31 diag: *const c_char,
32 m: *const blas_int,
33 n: *const blas_int,
34 alpha: *const Self,
35 a: *const Self,
36 lda: *const blas_int,
37 b: *mut Self,
38 ldb: *const blas_int,
39 ) {
40 ffi::$func(side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb);
41 }
42 }
43 };
44}
45
46impl_func!(f32, strsm_);
47impl_func!(f64, dtrsm_);
48impl_func!(c32, ctrsm_);
49impl_func!(c64, ztrsm_);
50
51pub struct TRSM_Driver<'a, 'b, F>
56where
57 F: BLASFloat,
58{
59 side: c_char,
60 uplo: c_char,
61 transa: c_char,
62 diag: c_char,
63 m: blas_int,
64 n: blas_int,
65 alpha: F,
66 a: ArrayView2<'a, F>,
67 lda: blas_int,
68 b: ArrayOut2<'b, F>,
69 ldb: blas_int,
70}
71
72impl<'a, 'b, F> BLASDriver<'b, F, Ix2> for TRSM_Driver<'a, 'b, F>
73where
74 F: TRSMNum,
75{
76 fn run_blas(self) -> Result<ArrayOut2<'b, F>, BLASError> {
77 let Self { side, uplo, transa, diag, m, n, alpha, a, lda, mut b, ldb } = self;
78 let a_ptr = a.as_ptr();
79 let b_ptr = b.get_data_mut_ptr();
80
81 if m == 0 || n == 0 {
84 return Ok(b.clone_to_view_mut());
85 }
86
87 unsafe {
88 F::trsm(&side, &uplo, &transa, &diag, &m, &n, &alpha, a_ptr, &lda, b_ptr, &ldb);
89 }
90 return Ok(b.clone_to_view_mut());
91 }
92}
93
94#[derive(Builder)]
99#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
100pub struct TRSM_<'a, 'b, F>
101where
102 F: TRSMNum,
103{
104 pub a: ArrayView2<'a, F>,
105 pub b: ArrayViewMut2<'b, F>,
106
107 #[builder(setter(into), default = "F::one()")]
108 pub alpha: F,
109 #[builder(setter(into), default = "BLASLeft")]
110 pub side: BLASSide,
111 #[builder(setter(into), default = "BLASUpper")]
112 pub uplo: BLASUpLo,
113 #[builder(setter(into), default = "BLASNoTrans")]
114 pub transa: BLASTranspose,
115 #[builder(setter(into), default = "BLASNonUnit")]
116 pub diag: BLASDiag,
117 #[builder(setter(into, strip_option), default = "None")]
118 pub layout: Option<BLASLayout>,
119}
120
121impl<'a, 'b, F> BLASBuilder_<'b, F, Ix2> for TRSM_<'a, 'b, F>
122where
123 F: TRSMNum,
124{
125 fn driver(self) -> Result<TRSM_Driver<'a, 'b, F>, BLASError> {
126 let Self { a, b, alpha, side, uplo, transa, diag, layout } = self;
127
128 assert_eq!(layout, Some(BLASColMajor));
130 assert!(a.is_fpref());
131
132 let (m, n) = b.dim();
134 let lda = a.stride_of(Axis(1));
135
136 match side {
138 BLASLeft => blas_assert_eq!(a.dim(), (m, m), InvalidDim)?,
139 BLASRight => blas_assert_eq!(a.dim(), (n, n), InvalidDim)?,
140 _ => blas_invalid!(side)?,
141 };
142
143 let b = if b.view().is_fpref() {
145 ArrayOut2::ViewMut(b)
146 } else {
147 let b_buffer = b.view().to_col_layout()?.into_owned();
148 ArrayOut2::ToBeCloned(b, b_buffer)
149 };
150 let ldb = b.view().stride_of(Axis(1));
151
152 let driver = TRSM_Driver {
154 side: side.try_into()?,
155 uplo: uplo.try_into()?,
156 transa: transa.try_into()?,
157 diag: diag.try_into()?,
158 m: m.try_into()?,
159 n: n.try_into()?,
160 alpha,
161 a,
162 lda: lda.try_into()?,
163 b,
164 ldb: ldb.try_into()?,
165 };
166 return Ok(driver);
167 }
168}
169
170pub type TRSM<'a, 'b, F> = TRSM_Builder<'a, 'b, F>;
175pub type STRSM<'a, 'b> = TRSM<'a, 'b, f32>;
176pub type DTRSM<'a, 'b> = TRSM<'a, 'b, f64>;
177pub type CTRSM<'a, 'b> = TRSM<'a, 'b, c32>;
178pub type ZTRSM<'a, 'b> = TRSM<'a, 'b, c64>;
179
180impl<'a, 'b, F> BLASBuilder<'b, F, Ix2> for TRSM_Builder<'a, 'b, F>
181where
182 F: TRSMNum,
183{
184 fn run(self) -> Result<ArrayOut2<'b, F>, BLASError> {
185 let TRSM_ { a, b, alpha, side, uplo, transa, diag, layout } = self.build()?;
187 let at = a.t();
188
189 let layout_a = get_layout_array2(&a);
190 let layout_b = get_layout_array2(&b.view());
191
192 let layout = get_layout_row_preferred(&[layout, Some(layout_b)], &[layout_a]);
193 if layout == BLASColMajor {
194 let (transa_new, a_cow) = flip_trans_fpref(transa, &a, &at, false)?;
196 let uplo = if transa_new != transa { uplo.flip()? } else { uplo };
197 let obj = TRSM_ {
198 a: a_cow.view(),
199 b,
200 alpha,
201 side,
202 uplo,
203 transa: transa_new,
204 diag,
205 layout: Some(BLASColMajor),
206 };
207 return obj.driver()?.run_blas();
208 } else {
209 let (transa_new, a_cow) = flip_trans_cpref(transa, &a, &at, false)?;
211 let uplo = if transa_new != transa { uplo.flip()? } else { uplo };
212 let obj = TRSM_ {
213 a: a_cow.t(),
214 b: b.reversed_axes(),
215 alpha,
216 side: side.flip()?,
217 uplo: uplo.flip()?,
218 transa: transa_new,
219 diag,
220 layout: Some(BLASColMajor),
221 };
222 return Ok(obj.driver()?.run_blas()?.reversed_axes());
223 }
224 }
225}
226
227