blas_array2/blas2/
hpr.rs

1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5use num_traits::*;
6
7/* #region BLAS func */
8
/// Scalar types that can call the BLAS packed `?spr`/`?hpr` routine:
/// `sspr_`/`dspr_` for `f32`/`f64` and `chpr_`/`zhpr_` for `c32`/`c64`.
pub trait HPRNum: BLASFloat {
    /// Raw call into the BLAS `?spr_`/`?hpr_` symbol; every argument is forwarded
    /// by pointer. Note that `alpha` is always the real scalar type, even for
    /// complex `Self`.
    ///
    /// # Safety
    ///
    /// All pointers must be valid for the duration of the call. `x` must be readable
    /// at the `n` positions selected by `incx`, and `ap` must be writable for the
    /// packed triangle of `n * (n + 1) / 2` elements.
    /// NOTE(review): reference BLAS rejects `incx == 0` and, for negative `incx`,
    /// expects `x` to point at the lowest-addressed element — callers must honor this.
    unsafe fn hpr(
        uplo: *const c_char,
        n: *const blas_int,
        alpha: *const Self::RealFloat,
        x: *const Self,
        incx: *const blas_int,
        ap: *mut Self,
    );
}
19
20macro_rules! impl_func {
21    ($type: ty, $func: ident) => {
22        impl HPRNum for $type {
23            unsafe fn hpr(
24                uplo: *const c_char,
25                n: *const blas_int,
26                alpha: *const Self::RealFloat,
27                x: *const Self,
28                incx: *const blas_int,
29                ap: *mut Self,
30            ) {
31                ffi::$func(uplo, n, alpha, x, incx, ap);
32            }
33        }
34    };
35}
36
37impl_func!(f32, sspr_);
38impl_func!(f64, dspr_);
39impl_func!(c32, chpr_);
40impl_func!(c64, zhpr_);
41
42/* #endregion */
43
44/* #region BLAS driver */
45
/// Fully prepared, column-major arguments for a single [`HPRNum::hpr`] call.
///
/// Produced by `HPR_::driver` and consumed by `run_blas`.
pub struct HPR_Driver<'x, 'a, F>
where
    F: HPRNum,
{
    /// Triangle selector, already converted from `BLASUpLo` to its BLAS character.
    uplo: c_char,
    /// Matrix order; equals the length of `x`.
    n: blas_int,
    /// Real scaling factor.
    alpha: F::RealFloat,
    /// Input vector (possibly strided).
    x: ArrayView1<'x, F>,
    /// Stride of `x` in elements, taken directly from the ndarray view.
    incx: blas_int,
    /// Packed-triangle output, written through a mutable pointer by the BLAS call.
    ap: ArrayOut1<'a, F>,
}
57
58impl<'x, 'a, F> BLASDriver<'a, F, Ix1> for HPR_Driver<'x, 'a, F>
59where
60    F: HPRNum,
61{
62    fn run_blas(self) -> Result<ArrayOut1<'a, F>, BLASError> {
63        let Self { uplo, n, alpha, x, incx, mut ap, .. } = self;
64        let x_ptr = x.as_ptr();
65        let ap_ptr = ap.get_data_mut_ptr();
66
67        // assuming dimension checks has been performed
68        // unconditionally return Ok if output does not contain anything
69        if n == 0 {
70            return Ok(ap.clone_to_view_mut());
71        }
72
73        unsafe {
74            F::hpr(&uplo, &n, &alpha, x_ptr, &incx, ap_ptr);
75        }
76        return Ok(ap.clone_to_view_mut());
77    }
78}
79
80/* #endregion */
81
82/* #region BLAS builder */
83
/// Arguments for the packed `?spr`/`?hpr` routine, collected through the generated
/// `HPR_Builder`.
#[derive(Builder)]
#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
pub struct HPR_<'x, 'a, F>
where
    F: HPRNum,
{
    /// Input vector; its length fixes the matrix order `n`.
    pub x: ArrayView1<'x, F>,

    /// Packed triangle of length `n * (n + 1) / 2`, updated in place. When omitted,
    /// a zero-filled array is allocated and returned instead.
    #[builder(setter(into, strip_option), default = "None")]
    pub ap: Option<ArrayViewMut1<'a, F>>,
    /// Real scaling factor; defaults to one.
    #[builder(setter(into), default = "F::RealFloat::one()")]
    pub alpha: F::RealFloat,
    /// Which triangle `ap` stores; defaults to `BLASUpper`.
    #[builder(setter(into), default = "BLASUpper")]
    pub uplo: BLASUpLo,
    /// Storage order used to interpret `ap`. Anything other than `Some(BLASColMajor)`
    /// (including the default `None`) is handled as row-major by the wrapper.
    #[builder(setter(into, strip_option), default = "None")]
    pub layout: Option<BLASLayout>,
}
101
102impl<'x, 'a, F> BLASBuilder_<'a, F, Ix1> for HPR_<'x, 'a, F>
103where
104    F: HPRNum,
105{
106    fn driver(self) -> Result<HPR_Driver<'x, 'a, F>, BLASError> {
107        let Self { x, ap, alpha, uplo, layout, .. } = self;
108
109        // initialize intent(hide)
110        let incx = x.stride_of(Axis(0));
111        let n = x.len_of(Axis(0));
112
113        // only fortran-preferred (col-major) is accepted in inner wrapper
114        assert_eq!(layout, Some(BLASColMajor));
115
116        // prepare output
117        let ap = match ap {
118            Some(ap) => {
119                blas_assert_eq!(ap.len_of(Axis(0)), n * (n + 1) / 2, InvalidDim)?;
120                if ap.is_standard_layout() {
121                    ArrayOut1::ViewMut(ap)
122                } else {
123                    let ap_buffer = ap.view().to_seq_layout()?.into_owned();
124                    ArrayOut1::ToBeCloned(ap, ap_buffer)
125                }
126            },
127            None => ArrayOut1::Owned(Array1::zeros(n * (n + 1) / 2)),
128        };
129
130        // finalize
131        let driver =
132            HPR_Driver { uplo: uplo.try_into()?, n: n.try_into()?, alpha, x, incx: incx.try_into()?, ap };
133        return Ok(driver);
134    }
135}
136
137/* #endregion */
138
139/* #region BLAS wrapper */
140
141pub type HPR<'x, 'a, F> = HPR_Builder<'x, 'a, F>;
142pub type SSPR<'x, 'a> = HPR<'x, 'a, f32>;
143pub type DSPR<'x, 'a> = HPR<'x, 'a, f64>;
144pub type CHPR<'x, 'a> = HPR<'x, 'a, c32>;
145pub type ZHPR<'x, 'a> = HPR<'x, 'a, c64>;
146
147impl<'x, 'a, F> BLASBuilder<'a, F, Ix1> for HPR_Builder<'x, 'a, F>
148where
149    F: HPRNum,
150{
151    fn run(self) -> Result<ArrayOut1<'a, F>, BLASError> {
152        // initialize
153        let obj = self.build()?;
154
155        if obj.layout == Some(BLASColMajor) {
156            // F-contiguous
157            return obj.driver()?.run_blas();
158        } else {
159            // C-contiguous
160            let uplo = obj.uplo.flip()?;
161            if F::is_complex() {
162                let x = obj.x.mapv(F::conj);
163                let obj = HPR_ { x: x.view(), uplo, layout: Some(BLASColMajor), ..obj };
164                return obj.driver()?.run_blas();
165            } else {
166                let obj = HPR_ { uplo, layout: Some(BLASColMajor), ..obj };
167                return obj.driver()?.run_blas();
168            };
169        }
170    }
171}
172
173/* #endregion */