1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6pub trait SYMMNum: BLASFloat {
9 unsafe fn symm(
10 side: *const c_char,
11 uplo: *const c_char,
12 m: *const blas_int,
13 n: *const blas_int,
14 alpha: *const Self,
15 a: *const Self,
16 lda: *const blas_int,
17 b: *const Self,
18 ldb: *const blas_int,
19 beta: *const Self,
20 c: *mut Self,
21 ldc: *const blas_int,
22 );
23}
24
25macro_rules! impl_func {
26 ($type: ty, $func: ident) => {
27 impl SYMMNum for $type {
28 unsafe fn symm(
29 side: *const c_char,
30 uplo: *const c_char,
31 m: *const blas_int,
32 n: *const blas_int,
33 alpha: *const Self,
34 a: *const Self,
35 lda: *const blas_int,
36 b: *const Self,
37 ldb: *const blas_int,
38 beta: *const Self,
39 c: *mut Self,
40 ldc: *const blas_int,
41 ) {
42 ffi::$func(side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc);
43 }
44 }
45 };
46}
47
48impl_func!(f32, ssymm_);
49impl_func!(f64, dsymm_);
50impl_func!(c32, csymm_);
51impl_func!(c64, zsymm_);
52
53pub struct SYMM_Driver<'a, 'b, 'c, F>
58where
59 F: SYMMNum,
60{
61 side: c_char,
62 uplo: c_char,
63 m: blas_int,
64 n: blas_int,
65 alpha: F,
66 a: ArrayView2<'a, F>,
67 lda: blas_int,
68 b: ArrayView2<'b, F>,
69 ldb: blas_int,
70 beta: F,
71 c: ArrayOut2<'c, F>,
72 ldc: blas_int,
73}
74
75impl<'a, 'b, 'c, F> BLASDriver<'c, F, Ix2> for SYMM_Driver<'a, 'b, 'c, F>
76where
77 F: SYMMNum,
78{
79 fn run_blas(self) -> Result<ArrayOut2<'c, F>, BLASError> {
80 let Self { side, uplo, m, n, alpha, a, lda, b, ldb, beta, mut c, ldc, .. } = self;
81 let a_ptr = a.as_ptr();
82 let b_ptr = b.as_ptr();
83 let c_ptr = c.get_data_mut_ptr();
84
85 if m == 0 || n == 0 {
88 return Ok(c.clone_to_view_mut());
89 }
90
91 unsafe {
92 F::symm(&side, &uplo, &m, &n, &alpha, a_ptr, &lda, b_ptr, &ldb, &beta, c_ptr, &ldc);
93 }
94 return Ok(c.clone_to_view_mut());
95 }
96}
97
98#[derive(Builder)]
103#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
104pub struct SYMM_<'a, 'b, 'c, F>
105where
106 F: BLASFloat,
107{
108 pub a: ArrayView2<'a, F>,
109 pub b: ArrayView2<'b, F>,
110
111 #[builder(setter(into, strip_option), default = "None")]
112 pub c: Option<ArrayViewMut2<'c, F>>,
113 #[builder(setter(into), default = "F::one()")]
114 pub alpha: F,
115 #[builder(setter(into), default = "F::zero()")]
116 pub beta: F,
117 #[builder(setter(into), default = "BLASLeft")]
118 pub side: BLASSide,
119 #[builder(setter(into), default = "BLASLower")]
120 pub uplo: BLASUpLo,
121 #[builder(setter(into, strip_option), default = "None")]
122 pub layout: Option<BLASLayout>,
123}
124
125impl<'a, 'b, 'c, F> BLASBuilder_<'c, F, Ix2> for SYMM_<'a, 'b, 'c, F>
126where
127 F: SYMMNum,
128{
129 fn driver(self) -> Result<SYMM_Driver<'a, 'b, 'c, F>, BLASError> {
130 let Self { a, b, c, alpha, beta, side, uplo, layout, .. } = self;
131
132 assert_eq!(layout, Some(BLASColMajor));
134 assert!(a.is_fpref() && a.is_fpref());
135
136 let m = b.len_of(Axis(0));
138 let n = b.len_of(Axis(1));
139 let lda = a.stride_of(Axis(1));
140 let ldb = b.stride_of(Axis(1));
141
142 match side {
144 BLASLeft => blas_assert_eq!(a.dim(), (m, m), InvalidDim)?,
145 BLASRight => blas_assert_eq!(a.dim(), (n, n), InvalidDim)?,
146 _ => blas_invalid!(side)?,
147 }
148
149 let c = match c {
151 Some(c) => {
152 blas_assert_eq!(c.dim(), (m, n), InvalidDim)?;
153 if c.view().is_fpref() {
154 ArrayOut2::ViewMut(c)
155 } else {
156 let c_buffer = c.view().to_col_layout()?.into_owned();
157 ArrayOut2::ToBeCloned(c, c_buffer)
158 }
159 },
160 None => ArrayOut2::Owned(Array2::zeros((m, n).f())),
161 };
162 let ldc = c.view().stride_of(Axis(1));
163
164 let driver = SYMM_Driver::<'a, 'b, 'c, F> {
166 side: side.try_into()?,
167 uplo: uplo.try_into()?,
168 m: m.try_into()?,
169 n: n.try_into()?,
170 alpha,
171 a,
172 lda: lda.try_into()?,
173 b,
174 ldb: ldb.try_into()?,
175 beta,
176 c,
177 ldc: ldc.try_into()?,
178 };
179 return Ok(driver);
180 }
181}
182
183pub type SYMM<'a, 'b, 'c, F> = SYMM_Builder<'a, 'b, 'c, F>;
188pub type SSYMM<'a, 'b, 'c> = SYMM<'a, 'b, 'c, f32>;
189pub type DSYMM<'a, 'b, 'c> = SYMM<'a, 'b, 'c, f64>;
190pub type CSYMM<'a, 'b, 'c> = SYMM<'a, 'b, 'c, c32>;
191pub type ZSYMM<'a, 'b, 'c> = SYMM<'a, 'b, 'c, c64>;
192
193impl<'a, 'b, 'c, F> BLASBuilder<'c, F, Ix2> for SYMM_Builder<'a, 'b, 'c, F>
194where
195 F: SYMMNum,
196{
197 fn run(self) -> Result<ArrayOut2<'c, F>, BLASError> {
198 let SYMM_ { a, b, c, alpha, beta, side, uplo, layout, .. } = self.build()?;
200 let at = a.t();
201
202 let layout_a = get_layout_array2(&a);
203 let layout_b = get_layout_array2(&b);
204 let layout_c = c.as_ref().map(|c| get_layout_array2(&c.view()));
205
206 let layout = get_layout_row_preferred(&[layout, layout_c], &[layout_a, layout_b]);
207 if layout == BLASColMajor {
208 let (uplo, a_cow) = match layout_a.is_fpref() {
210 true => (uplo, a.to_col_layout()?),
211 false => (uplo.flip()?, at.to_col_layout()?),
212 };
213 let b_cow = b.to_col_layout()?;
214 let obj = SYMM_ {
215 a: a_cow.view(),
216 b: b_cow.view(),
217 c,
218 alpha,
219 beta,
220 side,
221 uplo,
222 layout: Some(BLASColMajor),
223 };
224 return obj.driver()?.run_blas();
225 } else {
226 let (uplo, a_cow) = match layout_a.is_cpref() {
228 true => (uplo, a.to_row_layout()?),
229 false => (uplo.flip()?, at.to_row_layout()?),
230 };
231 let b_cow = b.to_row_layout()?;
232 let obj = SYMM_ {
233 a: a_cow.t(),
234 b: b_cow.t(),
235 c: c.map(|c| c.reversed_axes()),
236 alpha,
237 beta,
238 side: side.flip()?,
239 uplo: uplo.flip()?,
240 layout: Some(BLASColMajor),
241 };
242 let c = obj.driver()?.run_blas()?.reversed_axes();
243 return Ok(c);
244 }
245 }
246}
247
248