1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5use num_traits::*;
6
/// Scalar types that expose a raw binding to the BLAS `her2k` routine.
///
/// Implemented in this file for `c32` (through `ffi::cher2k_`) and `c64`
/// (through `ffi::zher2k_`).
pub trait HER2KNum: BLASFloat {
    /// Raw entry point to the underlying `her2k` routine.
    ///
    /// Every argument is passed by pointer. `beta` points to a
    /// `Self::RealFloat`, whereas `alpha` and the matrix buffers `a`, `b`, `c`
    /// point to `Self`. `c` is the only mutable pointer.
    ///
    /// # Safety
    ///
    /// Nothing is validated at this level. The caller must guarantee that each
    /// pointer is valid for everything the underlying routine reads through it
    /// (and, for `c`, writes through it) given the supplied `n`, `k`, `lda`,
    /// `ldb` and `ldc`.
    unsafe fn her2k(
        uplo: *const c_char,
        trans: *const c_char,
        n: *const blas_int,
        k: *const blas_int,
        alpha: *const Self,
        a: *const Self,
        lda: *const blas_int,
        b: *const Self,
        ldb: *const blas_int,
        beta: *const Self::RealFloat,
        c: *mut Self,
        ldc: *const blas_int,
    );
}
25
26macro_rules! impl_her2k {
27 ($type: ty, $func: ident) => {
28 impl HER2KNum for $type {
29 unsafe fn her2k(
30 uplo: *const c_char,
31 trans: *const c_char,
32 n: *const blas_int,
33 k: *const blas_int,
34 alpha: *const Self,
35 a: *const Self,
36 lda: *const blas_int,
37 b: *const Self,
38 ldb: *const blas_int,
39 beta: *const Self::RealFloat,
40 c: *mut Self,
41 ldc: *const blas_int,
42 ) {
43 ffi::$func(uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
44 }
45 }
46 };
47}
48
49impl_her2k!(c32, cher2k_);
50impl_her2k!(c64, zher2k_);
51
/// Fully prepared HER2K call.
///
/// Each field corresponds to one argument of [`HER2KNum::her2k`]; the values
/// are stored here and `run_blas` passes pointers to them to that function.
pub struct HER2K_Driver<'a, 'b, 'c, F>
where
    F: HER2KNum,
{
    // Triangle selector, already converted to the character flag form.
    uplo: c_char,
    // Transposition selector, already converted to the character flag form.
    trans: c_char,
    // Order of the square output matrix `c`.
    n: blas_int,
    // The other dimension of `a` and `b` (which axis it is depends on `trans`).
    k: blas_int,
    // Scalar of type `F`.
    alpha: F,
    // First input matrix.
    a: ArrayView2<'a, F>,
    // Leading dimension of `a` (its stride along axis 1).
    lda: blas_int,
    // Second input matrix.
    b: ArrayView2<'b, F>,
    // Leading dimension of `b` (its stride along axis 1).
    ldb: blas_int,
    // Scalar of the real type associated with `F`.
    beta: F::RealFloat,
    // Output matrix wrapper; `run_blas` returns it via `clone_to_view_mut`.
    c: ArrayOut2<'c, F>,
    // Leading dimension of `c` (its stride along axis 1).
    ldc: blas_int,
}
73
74impl<'a, 'b, 'c, F> BLASDriver<'c, F, Ix2> for HER2K_Driver<'a, 'b, 'c, F>
75where
76 F: HER2KNum,
77{
78 fn run_blas(self) -> Result<ArrayOut2<'c, F>, BLASError> {
79 let Self { uplo, trans, n, k, alpha, a, lda, b, ldb, beta, mut c, ldc } = self;
80 let a_ptr = a.as_ptr();
81 let b_ptr = b.as_ptr();
82 let c_ptr = c.get_data_mut_ptr();
83
84 if n == 0 {
87 return Ok(c.clone_to_view_mut());
88 } else if k == 0 {
89 let beta_f = F::RealFloat::from(beta);
90 if uplo == BLASLower.try_into()? {
91 for i in 0..n {
92 c.view_mut().slice_mut(s![i.., i]).mapv_inplace(|v| v * F::from_real(beta_f));
93 }
94 } else if uplo == BLASUpper.try_into()? {
95 for i in 0..n {
96 c.view_mut().slice_mut(s![..=i, i]).mapv_inplace(|v| v * F::from_real(beta_f));
97 }
98 } else {
99 blas_invalid!(uplo)?
100 }
101 return Ok(c.clone_to_view_mut());
102 }
103
104 unsafe {
105 F::her2k(&uplo, &trans, &n, &k, &alpha, a_ptr, &lda, b_ptr, &ldb, &beta, c_ptr, &ldc);
106 }
107 return Ok(c.clone_to_view_mut());
108 }
109}
110
/// Argument set for a HER2K call; `derive(Builder)` generates `HER2K_Builder` from it.
///
/// `a` and `b` have no default and must be supplied. Every other field has a
/// default, as declared by its `#[builder(...)]` attribute.
#[derive(Builder)]
#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
pub struct HER2K_<'a, 'b, 'c, F>
where
    F: HER2KNum,
{
    /// First input matrix.
    pub a: ArrayView2<'a, F>,
    /// Second input matrix; `driver` requires a shape consistent with `a`.
    pub b: ArrayView2<'b, F>,

    /// Optional output matrix. Defaults to `None`, in which case `driver`
    /// allocates a zero-filled `n x n` matrix.
    #[builder(setter(into, strip_option), default = "None")]
    pub c: Option<ArrayViewMut2<'c, F>>,
    /// Scalar of type `F`. Defaults to `F::one()`.
    #[builder(setter(into), default = "F::one()")]
    pub alpha: F,
    /// Scalar of the real type associated with `F`. Defaults to zero.
    #[builder(setter(into), default = "F::RealFloat::zero()")]
    pub beta: F::RealFloat,
    /// Triangle selector. Defaults to `BLASLower`.
    #[builder(setter(into), default = "BLASLower")]
    pub uplo: BLASUpLo,
    /// Transposition selector. Defaults to `BLASNoTrans`; `driver` and `run`
    /// accept only `BLASNoTrans` and `BLASConjTrans`.
    #[builder(setter(into), default = "BLASNoTrans")]
    pub trans: BLASTranspose,
    /// Optional layout. Defaults to `None`. `run` feeds it, together with the
    /// layouts of `c`, `a` and `b`, into `get_layout_row_preferred`; `driver`
    /// itself requires `Some(BLASColMajor)`.
    #[builder(setter(into, strip_option), default = "None")]
    pub layout: Option<BLASLayout>,
}
137
138impl<'a, 'b, 'c, F> BLASBuilder_<'c, F, Ix2> for HER2K_<'a, 'b, 'c, F>
139where
140 F: HER2KNum,
141{
142 fn driver(self) -> Result<HER2K_Driver<'a, 'b, 'c, F>, BLASError> {
143 let Self { a, b, c, alpha, beta, uplo, trans, layout } = self;
144
145 assert_eq!(layout, Some(BLASColMajor));
147 assert!(a.is_fpref() && a.is_fpref());
148
149 let (n, k) = match trans {
151 BLASNoTrans => (a.len_of(Axis(0)), a.len_of(Axis(1))),
152 BLASConjTrans => (a.len_of(Axis(1)), a.len_of(Axis(0))),
153 _ => blas_invalid!(trans)?,
154 };
155 let lda = a.stride_of(Axis(1));
156 let ldb = b.stride_of(Axis(1));
157
158 match trans {
161 BLASNoTrans => blas_assert_eq!(b.dim(), (n, k), InvalidDim)?,
162 BLASConjTrans => blas_assert_eq!(b.dim(), (k, n), InvalidDim)?,
163 _ => blas_invalid!(trans)?,
164 };
165
166 let c = match c {
168 Some(c) => {
169 blas_assert_eq!(c.dim(), (n, n), InvalidDim)?;
170 if c.view().is_fpref() {
171 ArrayOut2::ViewMut(c)
172 } else {
173 let c_buffer = c.view().to_col_layout()?.into_owned();
174 ArrayOut2::ToBeCloned(c, c_buffer)
175 }
176 },
177 None => ArrayOut2::Owned(Array2::zeros((n, n).f())),
178 };
179 let ldc = c.view().stride_of(Axis(1));
180
181 let driver = HER2K_Driver {
183 uplo: uplo.try_into()?,
184 trans: trans.try_into()?,
185 n: n.try_into()?,
186 k: k.try_into()?,
187 alpha,
188 a,
189 lda: lda.try_into()?,
190 b,
191 ldb: ldb.try_into()?,
192 beta,
193 c,
194 ldc: ldc.try_into()?,
195 };
196 return Ok(driver);
197 }
198}
199
/// User-facing name for the builder generated from [`HER2K_`].
pub type HER2K<'a, 'b, 'c, F> = HER2K_Builder<'a, 'b, 'c, F>;
/// [`HER2K`] specialised to `c32`.
pub type CHER2K<'a, 'b, 'c> = HER2K<'a, 'b, 'c, c32>;
/// [`HER2K`] specialised to `c64`.
pub type ZHER2K<'a, 'b, 'c> = HER2K<'a, 'b, 'c, c64>;
207
208impl<'a, 'b, 'c, F> BLASBuilder<'c, F, Ix2> for HER2K_Builder<'a, 'b, 'c, F>
209where
210 F: HER2KNum,
211{
212 fn run(self) -> Result<ArrayOut2<'c, F>, BLASError> {
213 let HER2K_ { a, b, c, alpha, beta, uplo, trans, layout } = self.build()?;
215
216 match trans {
219 BLASNoTrans | BLASConjTrans => (),
221 _ => blas_invalid!(trans)?,
222 };
223
224 let layout_a = get_layout_array2(&a);
225 let layout_b = get_layout_array2(&b);
226 let layout_c = c.as_ref().map(|c| get_layout_array2(&c.view()));
227
228 let layout = get_layout_row_preferred(&[layout, layout_c], &[layout_a, layout_b]);
230 if layout == BLASColMajor {
231 let a_cow = a.to_col_layout()?;
233 let b_cow = b.to_col_layout()?;
234 let obj = HER2K_ {
235 a: a_cow.view(),
236 b: b_cow.view(),
237 c,
238 alpha,
239 beta,
240 uplo,
241 trans,
242 layout: Some(BLASColMajor),
243 };
244 return obj.driver()?.run_blas();
245 } else if layout == BLASRowMajor {
246 let a_cow = a.to_row_layout()?;
248 let b_cow = b.to_row_layout()?;
249 let obj = HER2K_ {
250 a: b_cow.t(),
251 b: a_cow.t(),
252 c: c.map(|c| c.reversed_axes()),
253 alpha,
254 beta,
255 uplo: uplo.flip()?,
256 trans: trans.flip(true)?,
257 layout: Some(BLASColMajor),
258 };
259 return Ok(obj.driver()?.run_blas()?.reversed_axes());
260 } else {
261 return blas_raise!(RuntimeError, "This is designed not to execuate this line.");
262 }
263 }
264}
265
266