1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
/// Scalar types for which a raw rank-2 update routine is available.
///
/// In this file the trait is implemented for `f32`, `f64`, `c32` and `c64` by
/// forwarding to the FFI symbols `ssyr2_`, `dsyr2_`, `cher2_` and `zher2_`.
pub trait HER2Num: BLASFloat {
    /// Raw rank-2 update entry point with an all-pointer argument list.
    ///
    /// Despite the method name, the complex implementations in this file bind it
    /// to the `?her2_` symbols rather than a `?syr2_` symbol.
    ///
    /// # Safety
    ///
    /// The implementations in this file hand every pointer unchanged to a
    /// foreign function, so the caller must guarantee that each pointer is
    /// valid for whatever that routine reads or writes.
    /// NOTE(review): how many elements of `x`, `y` and `a` are touched is
    /// defined by the BLAS specification, not by anything checked here.
    unsafe fn syr2(
        uplo: *const c_char,
        n: *const blas_int,
        alpha: *const Self,
        x: *const Self,
        incx: *const blas_int,
        y: *const Self,
        incy: *const blas_int,
        a: *mut Self,
        lda: *const blas_int,
    );
}
21
22macro_rules! impl_func {
23 ($type: ty, $func: ident) => {
24 impl HER2Num for $type {
25 unsafe fn syr2(
26 uplo: *const c_char,
27 n: *const blas_int,
28 alpha: *const Self,
29 x: *const Self,
30 incx: *const blas_int,
31 y: *const Self,
32 incy: *const blas_int,
33 a: *mut Self,
34 lda: *const blas_int,
35 ) {
36 ffi::$func(uplo, n, alpha, x, incx, y, incy, a, lda);
37 }
38 }
39 };
40}
41
42impl_func!(f32, ssyr2_);
43impl_func!(f64, dsyr2_);
44impl_func!(c32, cher2_);
45impl_func!(c64, zher2_);
46
/// Fully resolved argument pack for a single rank-2 update call.
///
/// Values of this type are produced by `HER2_::driver` and consumed by
/// `run_blas`, which passes every field on to `HER2Num::syr2`.
pub struct HER2_Driver<'x, 'y, 'a, F>
where
    F: HER2Num,
{
    /// Triangle selector, already converted to the character form that is
    /// handed to the routine.
    uplo: c_char,
    /// Problem order; `driver` takes it from the length of `x`.
    n: blas_int,
    /// Scalar multiplier forwarded to the routine.
    alpha: F,
    /// First input vector.
    x: ArrayView1<'x, F>,
    /// Stride of `x` along its only axis.
    incx: blas_int,
    /// Second input vector; `driver` requires it to have the same length as `x`.
    y: ArrayView1<'y, F>,
    /// Stride of `y` along its only axis.
    incy: blas_int,
    /// Output matrix wrapper; its mutable data pointer is handed to the routine.
    a: ArrayOut2<'a, F>,
    /// Stride of `a` along axis 1, passed as the leading dimension.
    lda: blas_int,
}
65
66impl<'x, 'y, 'a, F> BLASDriver<'a, F, Ix2> for HER2_Driver<'x, 'y, 'a, F>
67where
68 F: HER2Num,
69{
70 fn run_blas(self) -> Result<ArrayOut2<'a, F>, BLASError> {
71 let Self { uplo, n, alpha, x, incx, y, incy, mut a, lda } = self;
72 let x_ptr = x.as_ptr();
73 let y_ptr = y.as_ptr();
74 let a_ptr = a.get_data_mut_ptr();
75
76 if n == 0 {
79 return Ok(a.clone_to_view_mut());
80 }
81
82 unsafe {
83 F::syr2(&uplo, &n, &alpha, x_ptr, &incx, y_ptr, &incy, a_ptr, &lda);
84 }
85 return Ok(a.clone_to_view_mut());
86 }
87}
88
/// User-facing parameter set for the rank-2 update, populated through the
/// derived `HER2_Builder`.
///
/// The builder uses the owned pattern (setters take the builder by value) and
/// its build step reports failures as `BLASError`.
#[derive(Builder)]
#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
pub struct HER2_<'x, 'y, 'a, F>
where
    F: BLASFloat,
{
    /// First input vector (no default is declared, so it must be set).
    pub x: ArrayView1<'x, F>,
    /// Second input vector (no default is declared, so it must be set).
    pub y: ArrayView1<'y, F>,

    /// Optional matrix to update. Defaults to `None`; in that case `driver`
    /// allocates a zero-filled `n x n` matrix to receive the result.
    #[builder(setter(into, strip_option), default = "None")]
    pub a: Option<ArrayViewMut2<'a, F>>,
    /// Scalar multiplier; defaults to `F::one()`.
    #[builder(setter(into), default = "F::one()")]
    pub alpha: F,
    /// Triangle selector; defaults to `BLASUpper`.
    #[builder(setter(into), default = "BLASUpper")]
    pub uplo: BLASUpLo,
}
109
110impl<'x, 'y, 'a, F> BLASBuilder_<'a, F, Ix2> for HER2_<'x, 'y, 'a, F>
111where
112 F: HER2Num,
113{
114 fn driver(self) -> Result<HER2_Driver<'x, 'y, 'a, F>, BLASError> {
115 let Self { x, y, a, alpha, uplo, .. } = self;
116
117 let incx = x.stride_of(Axis(0));
119 let incy = y.stride_of(Axis(0));
120 let n = x.len_of(Axis(0));
121
122 blas_assert_eq!(y.len_of(Axis(0)), n, InvalidDim)?;
124
125 let a = match a {
127 Some(a) => {
128 blas_assert_eq!(a.dim(), (n, n), InvalidDim)?;
129 if a.view().is_fpref() {
130 ArrayOut2::ViewMut(a)
131 } else {
132 let a_buffer = a.view().to_col_layout()?.into_owned();
133 ArrayOut2::ToBeCloned(a, a_buffer)
134 }
135 },
136 None => ArrayOut2::Owned(Array2::zeros((n, n).f())),
137 };
138 let lda = a.view().stride_of(Axis(1));
139
140 let driver = HER2_Driver {
142 uplo: uplo.try_into()?,
143 n: n.try_into()?,
144 alpha,
145 x,
146 incx: incx.try_into()?,
147 y,
148 incy: incy.try_into()?,
149 a,
150 lda: lda.try_into()?,
151 };
152 return Ok(driver);
153 }
154}
155
/// Builder for the rank-2 update under its `SYR2` name; this is the same type
/// as [`HER2`].
pub type SYR2<'x, 'y, 'a, F> = HER2_Builder<'x, 'y, 'a, F>;
/// [`SYR2`] specialised to `f32`.
pub type SSYR2<'x, 'y, 'a> = SYR2<'x, 'y, 'a, f32>;
/// [`SYR2`] specialised to `f64`.
pub type DSYR2<'x, 'y, 'a> = SYR2<'x, 'y, 'a, f64>;

/// Builder for the rank-2 update under its `HER2` name; this is the same type
/// as [`SYR2`].
pub type HER2<'x, 'y, 'a, F> = HER2_Builder<'x, 'y, 'a, F>;
/// [`HER2`] specialised to `c32`.
pub type CHER2<'x, 'y, 'a> = HER2<'x, 'y, 'a, c32>;
/// [`HER2`] specialised to `c64`.
pub type ZHER2<'x, 'y, 'a> = HER2<'x, 'y, 'a, c64>;
167
168impl<'x, 'y, 'a, F> BLASBuilder<'a, F, Ix2> for HER2_Builder<'x, 'y, 'a, F>
169where
170 F: HER2Num,
171{
172 fn run(self) -> Result<ArrayOut2<'a, F>, BLASError> {
173 let obj = self.build()?;
175
176 if obj.a.as_ref().map(|a| a.view().is_fpref()) == Some(true) {
177 return obj.driver()?.run_blas();
179 } else {
180 let uplo = obj.uplo.flip()?;
182 let a = obj.a.map(|a| a.reversed_axes());
183 if F::is_complex() {
184 let x = obj.x.mapv(F::conj);
185 let y = obj.y.mapv(F::conj);
186 let obj = HER2_ { a, y: x.view(), x: y.view(), uplo, ..obj };
187 let a = obj.driver()?.run_blas()?;
188 return Ok(a.reversed_axes());
189 } else {
190 let obj = HER2_ { a, uplo, x: obj.y, y: obj.x, ..obj };
191 let a = obj.driver()?.run_blas()?;
192 return Ok(a.reversed_axes());
193 };
194 }
195 }
196}
197
198