1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6pub trait SYRKNum: BLASFloat {
9 unsafe fn syrk(
10 uplo: *const c_char,
11 trans: *const c_char,
12 n: *const blas_int,
13 k: *const blas_int,
14 alpha: *const Self,
15 a: *const Self,
16 lda: *const blas_int,
17 beta: *const Self,
18 c: *mut Self,
19 ldc: *const blas_int,
20 );
21}
22
23macro_rules! impl_syrk {
24 ($type: ty, $func: ident) => {
25 impl SYRKNum for $type {
26 unsafe fn syrk(
27 uplo: *const c_char,
28 trans: *const c_char,
29 n: *const blas_int,
30 k: *const blas_int,
31 alpha: *const Self,
32 a: *const Self,
33 lda: *const blas_int,
34 beta: *const Self,
35 c: *mut Self,
36 ldc: *const blas_int,
37 ) {
38 ffi::$func(uplo, trans, n, k, alpha, a, lda, beta, c, ldc);
39 }
40 }
41 };
42}
43
44impl_syrk!(f32, ssyrk_);
45impl_syrk!(f64, dsyrk_);
46impl_syrk!(c32, csyrk_);
47impl_syrk!(c64, zsyrk_);
48
49pub struct SYRK_Driver<'a, 'c, F>
54where
55 F: BLASFloat,
56{
57 uplo: c_char,
58 trans: c_char,
59 n: blas_int,
60 k: blas_int,
61 alpha: F,
62 a: ArrayView2<'a, F>,
63 lda: blas_int,
64 beta: F,
65 c: ArrayOut2<'c, F>,
66 ldc: blas_int,
67}
68
69impl<'a, 'c, F> BLASDriver<'c, F, Ix2> for SYRK_Driver<'a, 'c, F>
70where
71 F: SYRKNum,
72{
73 fn run_blas(self) -> Result<ArrayOut2<'c, F>, BLASError> {
74 let Self { uplo, trans, n, k, alpha, a, lda, beta, mut c, ldc } = self;
75 let a_ptr = a.as_ptr();
76 let c_ptr = c.get_data_mut_ptr();
77
78 if n == 0 {
81 return Ok(c.clone_to_view_mut());
82 } else if k == 0 {
83 let beta_f = F::from(beta);
84 if uplo == BLASLower.try_into()? {
85 for i in 0..n {
86 c.view_mut().slice_mut(s![i.., i]).mapv_inplace(|v| v * beta_f);
87 }
88 } else if uplo == BLASUpper.try_into()? {
89 for i in 0..n {
90 c.view_mut().slice_mut(s![..=i, i]).mapv_inplace(|v| v * beta_f);
91 }
92 } else {
93 blas_invalid!(uplo)?
94 }
95 return Ok(c.clone_to_view_mut());
96 }
97
98 unsafe {
99 F::syrk(&uplo, &trans, &n, &k, &alpha, a_ptr, &lda, &beta, c_ptr, &ldc);
100 }
101 return Ok(c.clone_to_view_mut());
102 }
103}
104
105#[derive(Builder)]
110#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
111pub struct SYRK_<'a, 'c, F>
112where
113 F: SYRKNum,
114{
115 pub a: ArrayView2<'a, F>,
116
117 #[builder(setter(into, strip_option), default = "None")]
118 pub c: Option<ArrayViewMut2<'c, F>>,
119 #[builder(setter(into), default = "F::one()")]
120 pub alpha: F,
121 #[builder(setter(into), default = "F::zero()")]
122 pub beta: F,
123 #[builder(setter(into), default = "BLASLower")]
124 pub uplo: BLASUpLo,
125 #[builder(setter(into), default = "BLASNoTrans")]
126 pub trans: BLASTranspose,
127 #[builder(setter(into, strip_option), default = "None")]
128 pub layout: Option<BLASLayout>,
129}
130
131impl<'a, 'c, F> BLASBuilder_<'c, F, Ix2> for SYRK_<'a, 'c, F>
132where
133 F: SYRKNum,
134{
135 fn driver(self) -> Result<SYRK_Driver<'a, 'c, F>, BLASError> {
136 let Self { a, c, alpha, beta, uplo, trans, layout } = self;
137
138 assert_eq!(layout, Some(BLASColMajor));
140 assert!(a.is_fpref());
141
142 let (n, k) = match trans {
144 BLASNoTrans => (a.len_of(Axis(0)), a.len_of(Axis(1))),
145 BLASTrans | BLASConjTrans => (a.len_of(Axis(1)), a.len_of(Axis(0))),
146 _ => blas_invalid!(trans)?,
147 };
148 let lda = a.stride_of(Axis(1));
149
150 match F::is_complex() {
152 false => match trans {
153 BLASNoTrans | BLASTrans | BLASConjTrans => (),
155 _ => blas_invalid!(trans)?,
156 },
157 true => match trans {
158 BLASNoTrans | BLASTrans => (),
160 _ => blas_invalid!(trans)?,
161 },
162 };
163
164 let c = match c {
166 Some(c) => {
167 blas_assert_eq!(c.dim(), (n, n), InvalidDim)?;
168 if c.view().is_fpref() {
169 ArrayOut2::ViewMut(c)
170 } else {
171 let c_buffer = c.view().to_col_layout()?.into_owned();
172 ArrayOut2::ToBeCloned(c, c_buffer)
173 }
174 },
175 None => ArrayOut2::Owned(Array2::zeros((n, n).f())),
176 };
177 let ldc = c.view().stride_of(Axis(1));
178
179 let driver = SYRK_Driver {
181 uplo: uplo.try_into()?,
182 trans: trans.try_into()?,
183 n: n.try_into()?,
184 k: k.try_into()?,
185 alpha,
186 a,
187 lda: lda.try_into()?,
188 beta,
189 c,
190 ldc: ldc.try_into()?,
191 };
192 return Ok(driver);
193 }
194}
195
196pub type SYRK<'a, 'c, F> = SYRK_Builder<'a, 'c, F>;
201pub type SSYRK<'a, 'c> = SYRK<'a, 'c, f32>;
202pub type DSYRK<'a, 'c> = SYRK<'a, 'c, f64>;
203pub type CSYRK<'a, 'c> = SYRK<'a, 'c, c32>;
204pub type ZSYRK<'a, 'c> = SYRK<'a, 'c, c64>;
205
206impl<'a, 'c, F> BLASBuilder<'c, F, Ix2> for SYRK_Builder<'a, 'c, F>
207where
208 F: SYRKNum,
209{
210 fn run(self) -> Result<ArrayOut2<'c, F>, BLASError> {
211 let SYRK_ { a, c, alpha, beta, uplo, trans, layout } = self.build()?;
213 let at = a.t();
214
215 match F::is_complex() {
218 false => match trans {
219 BLASNoTrans | BLASTrans | BLASConjTrans => (),
221 _ => blas_invalid!(trans)?,
222 },
223 true => match trans {
224 BLASNoTrans | BLASTrans => (),
226 _ => blas_invalid!(trans)?,
227 },
228 };
229
230 let layout_a = get_layout_array2(&a);
231 let layout_c = c.as_ref().map(|c| get_layout_array2(&c.view()));
232
233 let layout = get_layout_row_preferred(&[layout, layout_c], &[layout_a]);
234 if layout == BLASColMajor {
235 let (trans, a_cow) = flip_trans_fpref(trans, &a, &at, false)?;
237 let obj = SYRK_ { a: a_cow.view(), c, alpha, beta, uplo, trans, layout: Some(BLASColMajor) };
238 return obj.driver()?.run_blas();
239 } else if layout == BLASRowMajor {
240 let (trans, a_cow) = flip_trans_cpref(trans, &a, &at, false)?;
241 let obj = SYRK_ {
242 a: a_cow.t(),
243 c: c.map(|c| c.reversed_axes()),
244 alpha,
245 beta,
246 uplo: uplo.flip()?,
247 trans: trans.flip(false)?,
248 layout: Some(BLASColMajor),
249 };
250 return Ok(obj.driver()?.run_blas()?.reversed_axes());
251 } else {
252 return blas_raise!(RuntimeError, "This is designed not to execuate this line.");
253 }
254 }
255}
256
257