1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5use num_traits::*;
6
/// Scalar types with a BLAS packed rank-1 update routine.
///
/// Real types (`f32`, `f64`) bind the symmetric routines `sspr_` / `dspr_`.
/// Complex types (`c32`, `c64`) bind the Hermitian routines `chpr_` / `zhpr_`.
/// See the `impl_func!` invocations in this module for the mapping.
pub trait HPRNum: BLASFloat {
    /// Raw FFI call computing `A := alpha * x * x^H + A`, with `A` held in packed
    /// triangular storage in `ap`.
    ///
    /// # Safety
    ///
    /// Every pointer is forwarded unchanged to the Fortran routine, so the caller
    /// must uphold the reference BLAS `?spr` / `?hpr` contract:
    /// - `uplo` points to a valid `'U'` / `'L'` flag, and `n`, `alpha`, `incx`
    ///   point to readable scalars, with `*n >= 0` and `*incx != 0`;
    /// - `x` addresses at least `1 + (n - 1) * |incx|` readable elements. When `incx`
    ///   is negative it must point at the *lowest* address of that span;
    /// - `ap` addresses at least `n * (n + 1) / 2` writable elements that no other
    ///   reference aliases for the duration of the call.
    unsafe fn hpr(
        uplo: *const c_char,
        n: *const blas_int,
        alpha: *const Self::RealFloat,
        x: *const Self,
        incx: *const blas_int,
        ap: *mut Self,
    );
}
19
20macro_rules! impl_func {
21 ($type: ty, $func: ident) => {
22 impl HPRNum for $type {
23 unsafe fn hpr(
24 uplo: *const c_char,
25 n: *const blas_int,
26 alpha: *const Self::RealFloat,
27 x: *const Self,
28 incx: *const blas_int,
29 ap: *mut Self,
30 ) {
31 ffi::$func(uplo, n, alpha, x, incx, ap);
32 }
33 }
34 };
35}
36
37impl_func!(f32, sspr_);
38impl_func!(f64, dspr_);
39impl_func!(c32, chpr_);
40impl_func!(c64, zhpr_);
41
/// Lowered, BLAS-ready arguments for a single packed rank-1 update call.
///
/// Constructed by `HPR_::driver`, which converts the integer/flag arguments to
/// their FFI types and prepares the output buffer `ap`.
pub struct HPR_Driver<'x, 'a, F>
where
    F: HPRNum,
{
    /// BLAS triangle flag, converted from a `BLASUpLo`.
    uplo: c_char,
    /// Matrix order; equals the length of `x`.
    n: blas_int,
    /// Real scaling factor applied to `x * x^H`.
    alpha: F::RealFloat,
    /// Input vector.
    x: ArrayView1<'x, F>,
    /// Element stride of `x`, taken from `x.stride_of(Axis(0))`, so it may be
    /// zero or negative.
    incx: blas_int,
    /// Packed output matrix of `n * (n + 1) / 2` elements (user-provided or freshly
    /// zero-allocated).
    ap: ArrayOut1<'a, F>,
}
57
58impl<'x, 'a, F> BLASDriver<'a, F, Ix1> for HPR_Driver<'x, 'a, F>
59where
60 F: HPRNum,
61{
62 fn run_blas(self) -> Result<ArrayOut1<'a, F>, BLASError> {
63 let Self { uplo, n, alpha, x, incx, mut ap, .. } = self;
64 let x_ptr = x.as_ptr();
65 let ap_ptr = ap.get_data_mut_ptr();
66
67 if n == 0 {
70 return Ok(ap.clone_to_view_mut());
71 }
72
73 unsafe {
74 F::hpr(&uplo, &n, &alpha, x_ptr, &incx, ap_ptr);
75 }
76 return Ok(ap.clone_to_view_mut());
77 }
78}
79
/// Arguments for the packed rank-1 update `A := alpha * x * x^H + A`, with `A`
/// stored packed in `ap`. The update is Hermitian for complex `F` and symmetric
/// for real `F`.
///
/// Usually filled in through the generated `HPR_Builder` (aliased as `HPR`).
#[derive(Builder)]
#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
pub struct HPR_<'x, 'a, F>
where
    F: HPRNum,
{
    /// Input vector `x`. Its length fixes the matrix order `n`.
    pub x: ArrayView1<'x, F>,

    /// Optional packed matrix, updated in place. It must hold `n * (n + 1) / 2`
    /// elements. When omitted, a zero-filled array of that length is allocated.
    #[builder(setter(into, strip_option), default = "None")]
    pub ap: Option<ArrayViewMut1<'a, F>>,
    /// Real scaling factor (defaults to one).
    #[builder(setter(into), default = "F::RealFloat::one()")]
    pub alpha: F::RealFloat,
    /// Which triangle of `A` is stored in `ap` (defaults to `BLASUpper`).
    #[builder(setter(into), default = "BLASUpper")]
    pub uplo: BLASUpLo,
    /// Storage layout of `ap`. `run` calls BLAS directly only for
    /// `Some(BLASColMajor)`. Any other value, including the default `None`, is
    /// handled as row-major.
    #[builder(setter(into, strip_option), default = "None")]
    pub layout: Option<BLASLayout>,
}
101
102impl<'x, 'a, F> BLASBuilder_<'a, F, Ix1> for HPR_<'x, 'a, F>
103where
104 F: HPRNum,
105{
106 fn driver(self) -> Result<HPR_Driver<'x, 'a, F>, BLASError> {
107 let Self { x, ap, alpha, uplo, layout, .. } = self;
108
109 let incx = x.stride_of(Axis(0));
111 let n = x.len_of(Axis(0));
112
113 assert_eq!(layout, Some(BLASColMajor));
115
116 let ap = match ap {
118 Some(ap) => {
119 blas_assert_eq!(ap.len_of(Axis(0)), n * (n + 1) / 2, InvalidDim)?;
120 if ap.is_standard_layout() {
121 ArrayOut1::ViewMut(ap)
122 } else {
123 let ap_buffer = ap.view().to_seq_layout()?.into_owned();
124 ArrayOut1::ToBeCloned(ap, ap_buffer)
125 }
126 },
127 None => ArrayOut1::Owned(Array1::zeros(n * (n + 1) / 2)),
128 };
129
130 let driver =
132 HPR_Driver { uplo: uplo.try_into()?, n: n.try_into()?, alpha, x, incx: incx.try_into()?, ap };
133 return Ok(driver);
134 }
135}
136
137pub type HPR<'x, 'a, F> = HPR_Builder<'x, 'a, F>;
142pub type SSPR<'x, 'a> = HPR<'x, 'a, f32>;
143pub type DSPR<'x, 'a> = HPR<'x, 'a, f64>;
144pub type CHPR<'x, 'a> = HPR<'x, 'a, c32>;
145pub type ZHPR<'x, 'a> = HPR<'x, 'a, c64>;
146
147impl<'x, 'a, F> BLASBuilder<'a, F, Ix1> for HPR_Builder<'x, 'a, F>
148where
149 F: HPRNum,
150{
151 fn run(self) -> Result<ArrayOut1<'a, F>, BLASError> {
152 let obj = self.build()?;
154
155 if obj.layout == Some(BLASColMajor) {
156 return obj.driver()?.run_blas();
158 } else {
159 let uplo = obj.uplo.flip()?;
161 if F::is_complex() {
162 let x = obj.x.mapv(F::conj);
163 let obj = HPR_ { x: x.view(), uplo, layout: Some(BLASColMajor), ..obj };
164 return obj.driver()?.run_blas();
165 } else {
166 let obj = HPR_ { uplo, layout: Some(BLASColMajor), ..obj };
167 return obj.driver()?.run_blas();
168 };
169 }
170 }
171}
172
173