1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5use num_traits::*;
6
7pub trait HERKNum: BLASFloat {
10 unsafe fn herk(
11 uplo: *const c_char,
12 trans: *const c_char,
13 n: *const blas_int,
14 k: *const blas_int,
15 alpha: *const Self::RealFloat,
16 a: *const Self,
17 lda: *const blas_int,
18 beta: *const Self::RealFloat,
19 c: *mut Self,
20 ldc: *const blas_int,
21 );
22}
23
/// Implements [`HERKNum`] for the scalar type `$scalar` by forwarding to the
/// FFI symbol `ffi::$ffi_fn`.
macro_rules! impl_herk {
    ($scalar: ty, $ffi_fn: ident) => {
        impl HERKNum for $scalar {
            unsafe fn herk(
                uplo: *const c_char,
                trans: *const c_char,
                n: *const blas_int,
                k: *const blas_int,
                alpha: *const Self::RealFloat,
                a: *const Self,
                lda: *const blas_int,
                beta: *const Self::RealFloat,
                c: *mut Self,
                ldc: *const blas_int,
            ) {
                // Pure pass-through: every pointer is forwarded unchanged and in the same order.
                ffi::$ffi_fn(uplo, trans, n, k, alpha, a, lda, beta, c, ldc);
            }
        }
    };
}
44
45impl_herk!(c32, cherk_);
46impl_herk!(c64, zherk_);
47
48pub struct HERK_Driver<'a, 'c, F>
53where
54 F: BLASFloat,
55{
56 uplo: c_char,
57 trans: c_char,
58 n: blas_int,
59 k: blas_int,
60 alpha: F::RealFloat,
61 a: ArrayView2<'a, F>,
62 lda: blas_int,
63 beta: F::RealFloat,
64 c: ArrayOut2<'c, F>,
65 ldc: blas_int,
66}
67
68impl<'a, 'c, F> BLASDriver<'c, F, Ix2> for HERK_Driver<'a, 'c, F>
69where
70 F: HERKNum,
71{
72 fn run_blas(self) -> Result<ArrayOut2<'c, F>, BLASError> {
73 let Self { uplo, trans, n, k, alpha, a, lda, beta, mut c, ldc } = self;
74 let a_ptr = a.as_ptr();
75 let c_ptr = c.get_data_mut_ptr();
76
77 if n == 0 {
80 return Ok(c.clone_to_view_mut());
81 } else if k == 0 {
82 let beta_f = F::from_real(beta);
83 if uplo == BLASLower.try_into()? {
84 for i in 0..n {
85 c.view_mut().slice_mut(s![i.., i]).mapv_inplace(|v| v * beta_f);
86 }
87 } else if uplo == BLASUpper.try_into()? {
88 for i in 0..n {
89 c.view_mut().slice_mut(s![..=i, i]).mapv_inplace(|v| v * beta_f);
90 }
91 } else {
92 blas_invalid!(uplo)?
93 }
94 return Ok(c.clone_to_view_mut());
95 }
96
97 unsafe {
98 F::herk(&uplo, &trans, &n, &k, &alpha, a_ptr, &lda, &beta, c_ptr, &ldc);
99 }
100 return Ok(c.clone_to_view_mut());
101 }
102}
103
104#[derive(Builder)]
109#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
110pub struct HERK_<'a, 'c, F>
111where
112 F: HERKNum,
113{
114 pub a: ArrayView2<'a, F>,
115
116 #[builder(setter(into, strip_option), default = "None")]
117 pub c: Option<ArrayViewMut2<'c, F>>,
118 #[builder(setter(into), default = "F::RealFloat::one()")]
119 pub alpha: F::RealFloat,
120 #[builder(setter(into), default = "F::RealFloat::zero()")]
121 pub beta: F::RealFloat,
122 #[builder(setter(into), default = "BLASLower")]
123 pub uplo: BLASUpLo,
124 #[builder(setter(into), default = "BLASNoTrans")]
125 pub trans: BLASTranspose,
126 #[builder(setter(into, strip_option), default = "None")]
127 pub layout: Option<BLASLayout>,
128}
129
130impl<'a, 'c, F> BLASBuilder_<'c, F, Ix2> for HERK_<'a, 'c, F>
131where
132 F: HERKNum,
133{
134 fn driver(self) -> Result<HERK_Driver<'a, 'c, F>, BLASError> {
135 let Self { a, c, alpha, beta, uplo, trans, layout } = self;
136
137 assert_eq!(layout, Some(BLASColMajor));
139 assert!(a.is_fpref());
140
141 let (n, k) = match trans {
143 BLASNoTrans => (a.len_of(Axis(0)), a.len_of(Axis(1))),
144 BLASConjTrans => (a.len_of(Axis(1)), a.len_of(Axis(0))),
145 _ => blas_invalid!(trans)?,
146 };
147 let lda = a.stride_of(Axis(1));
148
149 let c = match c {
151 Some(c) => {
152 blas_assert_eq!(c.dim(), (n, n), InvalidDim)?;
153 if c.view().is_fpref() {
154 ArrayOut2::ViewMut(c)
155 } else {
156 let c_buffer = c.view().to_col_layout()?.into_owned();
157 ArrayOut2::ToBeCloned(c, c_buffer)
158 }
159 },
160 None => ArrayOut2::Owned(Array2::zeros((n, n).f())),
161 };
162 let ldc = c.view().stride_of(Axis(1));
163
164 let driver = HERK_Driver {
166 uplo: uplo.try_into()?,
167 trans: trans.try_into()?,
168 n: n.try_into()?,
169 k: k.try_into()?,
170 alpha,
171 a,
172 lda: lda.try_into()?,
173 beta,
174 c,
175 ldc: ldc.try_into()?,
176 };
177 return Ok(driver);
178 }
179}
180
181pub type HERK<'a, 'c, F> = HERK_Builder<'a, 'c, F>;
186pub type CHERK<'a, 'c> = HERK<'a, 'c, c32>;
187pub type ZHERK<'a, 'c> = HERK<'a, 'c, c64>;
188
189impl<'a, 'c, F> BLASBuilder<'c, F, Ix2> for HERK_Builder<'a, 'c, F>
190where
191 F: HERKNum,
192{
193 fn run(self) -> Result<ArrayOut2<'c, F>, BLASError> {
194 let HERK_ { a, c, alpha, beta, uplo, trans, layout } = self.build()?;
196 let at = a.t();
197
198 match trans {
201 BLASNoTrans | BLASConjTrans => (),
203 _ => blas_invalid!(trans)?,
204 };
205
206 let layout_a = get_layout_array2(&a);
207 let layout_c = c.as_ref().map(|c| get_layout_array2(&c.view()));
208
209 let layout = get_layout_row_preferred(&[layout, layout_c], &[layout_a]);
210 if layout == BLASColMajor {
211 let (trans, a_cow) = flip_trans_fpref(trans, &a, &at, true)?;
213 let obj = HERK_ { a: a_cow.view(), c, alpha, beta, uplo, trans, layout: Some(BLASColMajor) };
214 return obj.driver()?.run_blas();
215 } else if layout == BLASRowMajor {
216 let (trans, a_cow) = flip_trans_cpref(trans, &a, &at, true)?;
217 let obj = HERK_ {
218 a: a_cow.t(),
219 c: c.map(|c| c.reversed_axes()),
220 alpha,
221 beta,
222 uplo: uplo.flip()?,
223 trans: trans.flip(true)?,
224 layout: Some(BLASColMajor),
225 };
226 return Ok(obj.driver()?.run_blas()?.reversed_axes());
227 } else {
228 return blas_raise!(RuntimeError, "This is designed not to execuate this line.");
229 }
230 }
231}
232
233