1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
/// Scalar types that have a BLAS `?syr2k` (symmetric rank-2k update) binding.
///
/// Implemented in this module for `f32`, `f64`, `c32` and `c64` through `impl_syr2k!`,
/// which forwards to `ffi::ssyr2k_`, `ffi::dsyr2k_`, `ffi::csyr2k_` and `ffi::zsyr2k_`.
pub trait SYR2KNum: BLASFloat {
    /// Raw call into the BLAS `?syr2k` routine for `Self`.
    ///
    /// # Safety
    ///
    /// Every argument is passed by pointer, unchanged, to a foreign routine. The caller must
    /// guarantee that all pointers are valid for the duration of the call, and that `a`, `b`
    /// and `c` address buffers whose extents are consistent with `uplo`, `trans`, `n`, `k`
    /// and the leading dimensions `lda`, `ldb`, `ldc` as specified by the BLAS `?syr2k`
    /// interface. `c` is written to.
    unsafe fn syr2k(
        uplo: *const c_char,
        trans: *const c_char,
        n: *const blas_int,
        k: *const blas_int,
        alpha: *const Self,
        a: *const Self,
        lda: *const blas_int,
        b: *const Self,
        ldb: *const blas_int,
        beta: *const Self,
        c: *mut Self,
        ldc: *const blas_int,
    );
}
24
25macro_rules! impl_syr2k {
26 ($type: ty, $func: ident) => {
27 impl SYR2KNum for $type {
28 unsafe fn syr2k(
29 uplo: *const c_char,
30 trans: *const c_char,
31 n: *const blas_int,
32 k: *const blas_int,
33 alpha: *const Self,
34 a: *const Self,
35 lda: *const blas_int,
36 b: *const Self,
37 ldb: *const blas_int,
38 beta: *const Self,
39 c: *mut Self,
40 ldc: *const blas_int,
41 ) {
42 ffi::$func(uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
43 }
44 }
45 };
46}
47
48impl_syr2k!(f32, ssyr2k_);
49impl_syr2k!(f64, dsyr2k_);
50impl_syr2k!(c32, csyr2k_);
51impl_syr2k!(c64, zsyr2k_);
52
/// Validated, FFI-ready arguments for a single `?syr2k` call.
///
/// Produced by `SYR2K_::driver`. Every field is already in the form handed to BLAS:
/// character flags, `blas_int` sizes and leading dimensions.
pub struct SYR2K_Driver<'a, 'b, 'c, F>
where
    F: SYR2KNum,
{
    /// Triangle flag (`BLASUpLo` converted to `c_char`).
    uplo: c_char,
    /// Transpose flag (`BLASTranspose` converted to `c_char`).
    trans: c_char,
    /// Order of the square output `c`.
    n: blas_int,
    /// Inner dimension of the rank-2k update.
    k: blas_int,
    /// Scalar multiplying both product terms.
    alpha: F,
    /// Operand `A`, expected in Fortran-preferred (`is_fpref`) layout.
    a: ArrayView2<'a, F>,
    /// Leading dimension of `a` (its stride along axis 1).
    lda: blas_int,
    /// Operand `B`, same shape as `a`.
    b: ArrayView2<'b, F>,
    /// Leading dimension of `b` (its stride along axis 1).
    ldb: blas_int,
    /// Scalar multiplying the incoming `c`.
    beta: F,
    /// Output: an owned array, a mutable view, or a column-major staging buffer plus the
    /// original view (`ArrayOut2::ToBeCloned`), as chosen in `driver`.
    c: ArrayOut2<'c, F>,
    /// Leading dimension of `c` (its stride along axis 1).
    ldc: blas_int,
}
74
75impl<'a, 'b, 'c, F> BLASDriver<'c, F, Ix2> for SYR2K_Driver<'a, 'b, 'c, F>
76where
77 F: SYR2KNum,
78{
79 fn run_blas(self) -> Result<ArrayOut2<'c, F>, BLASError> {
80 let Self { uplo, trans, n, k, alpha, a, lda, b, ldb, beta, mut c, ldc } = self;
81 let a_ptr = a.as_ptr();
82 let b_ptr = b.as_ptr();
83 let c_ptr = c.get_data_mut_ptr();
84
85 if n == 0 {
88 return Ok(c.clone_to_view_mut());
89 } else if k == 0 {
90 let beta_f = F::from(beta);
91 if uplo == BLASLower.try_into()? {
92 for i in 0..n {
93 c.view_mut().slice_mut(s![i.., i]).mapv_inplace(|v| v * beta_f);
94 }
95 } else if uplo == BLASUpper.try_into()? {
96 for i in 0..n {
97 c.view_mut().slice_mut(s![..=i, i]).mapv_inplace(|v| v * beta_f);
98 }
99 } else {
100 blas_invalid!(uplo)?
101 }
102 return Ok(c.clone_to_view_mut());
103 }
104
105 unsafe {
106 F::syr2k(&uplo, &trans, &n, &k, &alpha, a_ptr, &lda, b_ptr, &ldb, &beta, c_ptr, &ldc);
107 }
108 return Ok(c.clone_to_view_mut());
109 }
110}
111
/// Parameters of a symmetric rank-2k update (`?syr2k`), assembled by the derived
/// `SYR2K_Builder`. The builder uses the owned pattern, and `build` reports errors as
/// `BLASError`.
///
/// `a` and `b` are required; every other field has a default.
#[derive(Builder)]
#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
pub struct SYR2K_<'a, 'b, 'c, F>
where
    F: SYR2KNum,
{
    /// Operand `A`: shape `(n, k)` for `BLASNoTrans`, `(k, n)` for `BLASTrans`/`BLASConjTrans`.
    pub a: ArrayView2<'a, F>,
    /// Operand `B`: must have the same shape as `a`.
    pub b: ArrayView2<'b, F>,

    /// Optional `(n, n)` output updated in place. When omitted, a zero-initialised
    /// `(n, n)` array is allocated.
    #[builder(setter(into, strip_option), default = "None")]
    pub c: Option<ArrayViewMut2<'c, F>>,
    /// Scalar multiplying both product terms (default `F::one()`).
    #[builder(setter(into), default = "F::one()")]
    pub alpha: F,
    /// Scalar multiplying the incoming `c` (default `F::zero()`).
    #[builder(setter(into), default = "F::zero()")]
    pub beta: F,
    /// Triangle of `c` that is referenced and updated (default `BLASLower`).
    #[builder(setter(into), default = "BLASLower")]
    pub uplo: BLASUpLo,
    /// Transposition applied to `a`/`b` (default `BLASNoTrans`). `BLASConjTrans` is rejected
    /// for complex `F`.
    #[builder(setter(into), default = "BLASNoTrans")]
    pub trans: BLASTranspose,
    /// Requested memory layout. `None` lets `run` choose from the layouts of `a`, `b` and `c`.
    #[builder(setter(into, strip_option), default = "None")]
    pub layout: Option<BLASLayout>,
}
138
139impl<'a, 'b, 'c, F> BLASBuilder_<'c, F, Ix2> for SYR2K_<'a, 'b, 'c, F>
140where
141 F: SYR2KNum,
142{
143 fn driver(self) -> Result<SYR2K_Driver<'a, 'b, 'c, F>, BLASError> {
144 let Self { a, b, c, alpha, beta, uplo, trans, layout } = self;
145
146 assert_eq!(layout, Some(BLASColMajor));
148 assert!(a.is_fpref() && a.is_fpref());
149
150 let (n, k) = match trans {
152 BLASNoTrans => (a.len_of(Axis(0)), a.len_of(Axis(1))),
153 BLASTrans | BLASConjTrans => (a.len_of(Axis(1)), a.len_of(Axis(0))),
154 _ => blas_invalid!(trans)?,
155 };
156 let lda = a.stride_of(Axis(1));
157 let ldb = b.stride_of(Axis(1));
158
159 match trans {
162 BLASNoTrans => blas_assert_eq!(b.dim(), (n, k), InvalidDim)?,
163 BLASTrans | BLASConjTrans => blas_assert_eq!(b.dim(), (k, n), InvalidDim)?,
164 _ => blas_invalid!(trans)?,
165 };
166 match F::is_complex() {
168 false => match trans {
169 BLASNoTrans | BLASTrans | BLASConjTrans => (),
171 _ => blas_invalid!(trans)?,
172 },
173 true => match trans {
174 BLASNoTrans | BLASTrans => (),
176 _ => blas_invalid!(trans)?,
177 },
178 };
179
180 let c = match c {
182 Some(c) => {
183 blas_assert_eq!(c.dim(), (n, n), InvalidDim)?;
184 if c.view().is_fpref() {
185 ArrayOut2::ViewMut(c)
186 } else {
187 let c_buffer = c.view().to_col_layout()?.into_owned();
188 ArrayOut2::ToBeCloned(c, c_buffer)
189 }
190 },
191 None => ArrayOut2::Owned(Array2::zeros((n, n).f())),
192 };
193 let ldc = c.view().stride_of(Axis(1));
194
195 let driver = SYR2K_Driver {
197 uplo: uplo.try_into()?,
198 trans: trans.try_into()?,
199 n: n.try_into()?,
200 k: k.try_into()?,
201 alpha,
202 a,
203 lda: lda.try_into()?,
204 b,
205 ldb: ldb.try_into()?,
206 beta,
207 c,
208 ldc: ldc.try_into()?,
209 };
210 return Ok(driver);
211 }
212}
213
214pub type SYR2K<'a, 'b, 'c, F> = SYR2K_Builder<'a, 'b, 'c, F>;
219pub type SSYR2K<'a, 'b, 'c> = SYR2K<'a, 'b, 'c, f32>;
220pub type DSYR2K<'a, 'b, 'c> = SYR2K<'a, 'b, 'c, f64>;
221pub type CSYR2K<'a, 'b, 'c> = SYR2K<'a, 'b, 'c, c32>;
222pub type ZSYR2K<'a, 'b, 'c> = SYR2K<'a, 'b, 'c, c64>;
223
224impl<'a, 'b, 'c, F> BLASBuilder<'c, F, Ix2> for SYR2K_Builder<'a, 'b, 'c, F>
225where
226 F: SYR2KNum,
227{
228 fn run(self) -> Result<ArrayOut2<'c, F>, BLASError> {
229 let SYR2K_ { a, b, c, alpha, beta, uplo, trans, layout } = self.build()?;
231
232 match F::is_complex() {
235 false => match trans {
236 BLASNoTrans | BLASTrans | BLASConjTrans => (),
238 _ => blas_invalid!(trans)?,
239 },
240 true => match trans {
241 BLASNoTrans | BLASTrans => (),
243 _ => blas_invalid!(trans)?,
244 },
245 };
246
247 let layout_a = get_layout_array2(&a);
248 let layout_b = get_layout_array2(&b);
249 let layout_c = c.as_ref().map(|c| get_layout_array2(&c.view()));
250
251 let layout = get_layout_row_preferred(&[layout, layout_c], &[layout_a, layout_b]);
253 if layout == BLASColMajor {
254 let a_cow = a.to_col_layout()?;
256 let b_cow = b.to_col_layout()?;
257 let obj = SYR2K_ {
258 a: a_cow.view(),
259 b: b_cow.view(),
260 c,
261 alpha,
262 beta,
263 uplo,
264 trans,
265 layout: Some(BLASColMajor),
266 };
267 return obj.driver()?.run_blas();
268 } else if layout == BLASRowMajor {
269 let a_cow = a.to_row_layout()?;
271 let b_cow = b.to_row_layout()?;
272 let obj = SYR2K_ {
273 a: b_cow.t(),
274 b: a_cow.t(),
275 c: c.map(|c| c.reversed_axes()),
276 alpha,
277 beta,
278 uplo: uplo.flip()?,
279 trans: trans.flip(false)?,
280 layout: Some(BLASColMajor),
281 };
282 return Ok(obj.driver()?.run_blas()?.reversed_axes());
283 } else {
284 return blas_raise!(RuntimeError, "This is designed not to execuate this line.");
285 }
286 }
287}
288
289