blas_array2/blas2/
hpr2.rs

1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6/* #region BLAS func */
7
8pub trait HPR2Num: BLASFloat {
9    unsafe fn hpr2(
10        uplo: *const c_char,
11        n: *const blas_int,
12        alpha: *const Self,
13        x: *const Self,
14        incx: *const blas_int,
15        y: *const Self,
16        incy: *const blas_int,
17        ap: *mut Self,
18    );
19}
20
21macro_rules! impl_func {
22    ($type: ty, $func: ident) => {
23        impl HPR2Num for $type {
24            unsafe fn hpr2(
25                uplo: *const c_char,
26                n: *const blas_int,
27                alpha: *const Self,
28                x: *const Self,
29                incx: *const blas_int,
30                y: *const Self,
31                incy: *const blas_int,
32                ap: *mut Self,
33            ) {
34                ffi::$func(uplo, n, alpha, x, incx, y, incy, ap);
35            }
36        }
37    };
38}
39
40impl_func!(f32, sspr2_);
41impl_func!(f64, dspr2_);
42impl_func!(c32, chpr2_);
43impl_func!(c64, zhpr2_);
44
45/* #endregion */
46
47/* #region BLAS driver */
48
49pub struct HPR2_Driver<'x, 'y, 'a, F>
50where
51    F: HPR2Num,
52{
53    uplo: c_char,
54    n: blas_int,
55    alpha: F,
56    x: ArrayView1<'x, F>,
57    incx: blas_int,
58    y: ArrayView1<'y, F>,
59    incy: blas_int,
60    ap: ArrayOut1<'a, F>,
61}
62
63impl<'x, 'y, 'a, F> BLASDriver<'a, F, Ix1> for HPR2_Driver<'x, 'y, 'a, F>
64where
65    F: HPR2Num,
66{
67    fn run_blas(self) -> Result<ArrayOut1<'a, F>, BLASError> {
68        let Self { uplo, n, alpha, x, incx, y, incy, mut ap } = self;
69        let x_ptr = x.as_ptr();
70        let y_ptr = y.as_ptr();
71        let ap_ptr = ap.get_data_mut_ptr();
72
73        // assuming dimension checks has been performed
74        // unconditionally return Ok if output does not contain anything
75        if n == 0 {
76            return Ok(ap.clone_to_view_mut());
77        }
78
79        unsafe {
80            F::hpr2(&uplo, &n, &alpha, x_ptr, &incx, y_ptr, &incy, ap_ptr);
81        }
82        return Ok(ap.clone_to_view_mut());
83    }
84}
85
86/* #endregion */
87
88/* #region BLAS builder */
89
90#[derive(Builder)]
91#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
92pub struct HPR2_<'x, 'y, 'a, F>
93where
94    F: HPR2Num,
95{
96    pub x: ArrayView1<'x, F>,
97    pub y: ArrayView1<'y, F>,
98
99    #[builder(setter(into, strip_option), default = "None")]
100    pub ap: Option<ArrayViewMut1<'a, F>>,
101    #[builder(setter(into), default = "F::one()")]
102    pub alpha: F,
103    #[builder(setter(into), default = "BLASUpper")]
104    pub uplo: BLASUpLo,
105    #[builder(setter(into, strip_option), default = "None")]
106    pub layout: Option<BLASLayout>,
107}
108
109impl<'x, 'y, 'a, F> BLASBuilder_<'a, F, Ix1> for HPR2_<'x, 'y, 'a, F>
110where
111    F: HPR2Num,
112{
113    fn driver(self) -> Result<HPR2_Driver<'x, 'y, 'a, F>, BLASError> {
114        let Self { x, y, ap, alpha, uplo, layout, .. } = self;
115
116        // initialize intent(hide)
117        let incx = x.stride_of(Axis(0));
118        let incy = y.stride_of(Axis(0));
119        let n = x.len_of(Axis(0));
120
121        // only fortran-preferred (col-major) is accepted in inner wrapper
122        assert_eq!(layout, Some(BLASColMajor));
123
124        // check optional
125        blas_assert_eq!(y.len_of(Axis(0)), n, InvalidDim)?;
126
127        // prepare output
128        let ap = match ap {
129            Some(ap) => {
130                blas_assert_eq!(ap.len_of(Axis(0)), n * (n + 1) / 2, InvalidDim)?;
131                if ap.is_standard_layout() {
132                    ArrayOut1::ViewMut(ap)
133                } else {
134                    let ap_buffer = ap.view().to_seq_layout()?.into_owned();
135                    ArrayOut1::ToBeCloned(ap, ap_buffer)
136                }
137            },
138            None => ArrayOut1::Owned(Array1::zeros(n * (n + 1) / 2)),
139        };
140
141        // finalize
142        let driver = HPR2_Driver {
143            uplo: uplo.try_into()?,
144            n: n.try_into()?,
145            alpha,
146            x,
147            incx: incx.try_into()?,
148            y,
149            incy: incy.try_into()?,
150            ap,
151        };
152        return Ok(driver);
153    }
154}
155
156/* #endregion */
157
158/* #region BLAS wrapper */
159
160pub type HPR2<'x, 'y, 'a, F> = HPR2_Builder<'x, 'y, 'a, F>;
161pub type SSPR2<'x, 'y, 'a> = HPR2<'x, 'y, 'a, f32>;
162pub type DSPR2<'x, 'y, 'a> = HPR2<'x, 'y, 'a, f64>;
163pub type CHPR2<'x, 'y, 'a> = HPR2<'x, 'y, 'a, c32>;
164pub type ZHPR2<'x, 'y, 'a> = HPR2<'x, 'y, 'a, c64>;
165
166impl<'x, 'y, 'a, F> BLASBuilder<'a, F, Ix1> for HPR2_Builder<'x, 'y, 'a, F>
167where
168    F: HPR2Num,
169{
170    fn run(self) -> Result<ArrayOut1<'a, F>, BLASError> {
171        // initialize
172        let obj = self.build()?;
173
174        if obj.layout == Some(BLASColMajor) {
175            // F-contiguous
176            return obj.driver()?.run_blas();
177        } else {
178            // C-contiguous
179            let uplo = obj.uplo.flip()?;
180            if F::is_complex() {
181                let x = obj.x.mapv(F::conj);
182                let y = obj.y.mapv(F::conj);
183                let obj = HPR2_ { y: x.view(), x: y.view(), uplo, layout: Some(BLASColMajor), ..obj };
184                return obj.driver()?.run_blas();
185            } else {
186                let obj = HPR2_ { uplo, x: obj.y, y: obj.x, layout: Some(BLASColMajor), ..obj };
187                return obj.driver()?.run_blas();
188            };
189        }
190    }
191}
192
193/* #endregion */