1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5use num_traits::*;
6
7pub trait HERNum: BLASFloat {
10 unsafe fn her(
11 uplo: *const c_char,
12 n: *const blas_int,
13 alpha: *const Self::RealFloat,
14 x: *const Self,
15 incx: *const blas_int,
16 a: *mut Self,
17 lda: *const blas_int,
18 );
19}
20
// Implements `HERNum` for the scalar type `$ty` by forwarding every argument to the FFI
// symbol `ffi::$ffi_fn`.
macro_rules! impl_func {
    ($ty:ty, $ffi_fn:ident) => {
        impl HERNum for $ty {
            unsafe fn her(
                uplo: *const c_char,
                n: *const blas_int,
                alpha: *const Self::RealFloat,
                x: *const Self,
                incx: *const blas_int,
                a: *mut Self,
                lda: *const blas_int,
            ) {
                // SAFETY: the caller upholds the contract of `HERNum::her`; the arguments
                // are passed through to the Fortran routine untouched.
                unsafe {
                    ffi::$ffi_fn(uplo, n, alpha, x, incx, a, lda);
                }
            }
        }
    };
}
38
39impl_func!(f32, ssyr_);
40impl_func!(f64, dsyr_);
41impl_func!(c32, cher_);
42impl_func!(c64, zher_);
43
44pub struct HER_Driver<'x, 'a, F>
49where
50 F: HERNum,
51{
52 uplo: c_char,
53 n: blas_int,
54 alpha: F::RealFloat,
55 x: ArrayView1<'x, F>,
56 incx: blas_int,
57 a: ArrayOut2<'a, F>,
58 lda: blas_int,
59}
60
61impl<'x, 'a, F> BLASDriver<'a, F, Ix2> for HER_Driver<'x, 'a, F>
62where
63 F: HERNum,
64{
65 fn run_blas(self) -> Result<ArrayOut2<'a, F>, BLASError> {
66 let Self { uplo, n, alpha, x, incx, mut a, lda, .. } = self;
67 let x_ptr = x.as_ptr();
68 let a_ptr = a.get_data_mut_ptr();
69
70 if n == 0 {
73 return Ok(a.clone_to_view_mut());
74 }
75
76 unsafe {
77 F::her(&uplo, &n, &alpha, x_ptr, &incx, a_ptr, &lda);
78 }
79 return Ok(a.clone_to_view_mut());
80 }
81}
82
83#[derive(Builder)]
88#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
89pub struct HER_<'x, 'a, F>
90where
91 F: HERNum,
92{
93 pub x: ArrayView1<'x, F>,
94
95 #[builder(setter(into, strip_option), default = "None")]
96 pub a: Option<ArrayViewMut2<'a, F>>,
97 #[builder(setter(into), default = "F::RealFloat::one()")]
98 pub alpha: F::RealFloat,
99 #[builder(setter(into), default = "BLASUpper")]
100 pub uplo: BLASUpLo,
101}
102
103impl<'x, 'a, F> BLASBuilder_<'a, F, Ix2> for HER_<'x, 'a, F>
104where
105 F: HERNum,
106{
107 fn driver(self) -> Result<HER_Driver<'x, 'a, F>, BLASError> {
108 let Self { x, a, alpha, uplo, .. } = self;
109
110 let incx = x.stride_of(Axis(0));
112 let n = x.len_of(Axis(0));
113
114 let a = match a {
116 Some(a) => {
117 blas_assert_eq!(a.dim(), (n, n), InvalidDim)?;
118 if a.view().is_fpref() {
119 ArrayOut2::ViewMut(a)
120 } else {
121 let a_buffer = a.view().to_col_layout()?.into_owned();
122 ArrayOut2::ToBeCloned(a, a_buffer)
123 }
124 },
125 None => ArrayOut2::Owned(Array2::zeros((n, n).f())),
126 };
127 let lda = a.view().stride_of(Axis(1));
128
129 let driver = HER_Driver {
131 uplo: uplo.try_into()?,
132 n: n.try_into()?,
133 alpha,
134 x,
135 incx: incx.try_into()?,
136 a,
137 lda: lda.try_into()?,
138 };
139 return Ok(driver);
140 }
141}
142
143pub type HER<'x, 'a, F> = HER_Builder<'x, 'a, F>;
148pub type SSYR<'x, 'a> = HER<'x, 'a, f32>;
149pub type DSYR<'x, 'a> = HER<'x, 'a, f64>;
150pub type CHER<'x, 'a> = HER<'x, 'a, c32>;
151pub type ZHER<'x, 'a> = HER<'x, 'a, c64>;
152
153impl<'x, 'a, F> BLASBuilder<'a, F, Ix2> for HER_Builder<'x, 'a, F>
154where
155 F: HERNum,
156{
157 fn run(self) -> Result<ArrayOut2<'a, F>, BLASError> {
158 let obj = self.build()?;
160
161 if obj.a.as_ref().map(|a| a.view().is_fpref()) == Some(true) {
162 return obj.driver()?.run_blas();
164 } else {
165 let uplo = obj.uplo.flip()?;
167 let a = obj.a.map(|a| a.reversed_axes());
168 if F::is_complex() {
169 let x = obj.x.mapv(F::conj);
170 let obj = HER_ { a, x: x.view(), uplo, ..obj };
171 let a = obj.driver()?.run_blas()?;
172 return Ok(a.reversed_axes());
173 } else {
174 let obj = HER_ { a, uplo, ..obj };
175 let a = obj.driver()?.run_blas()?;
176 return Ok(a.reversed_axes());
177 };
178 }
179 }
180}
181
182