blas_array2/blas2/her.rs

1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5use num_traits::*;
6
7/* #region BLAS func */
8
9pub trait HERNum: BLASFloat {
10    unsafe fn her(
11        uplo: *const c_char,
12        n: *const blas_int,
13        alpha: *const Self::RealFloat,
14        x: *const Self,
15        incx: *const blas_int,
16        a: *mut Self,
17        lda: *const blas_int,
18    );
19}
20
21macro_rules! impl_func {
22    ($type: ty, $func: ident) => {
23        impl HERNum for $type {
24            unsafe fn her(
25                uplo: *const c_char,
26                n: *const blas_int,
27                alpha: *const Self::RealFloat,
28                x: *const Self,
29                incx: *const blas_int,
30                a: *mut Self,
31                lda: *const blas_int,
32            ) {
33                ffi::$func(uplo, n, alpha, x, incx, a, lda);
34            }
35        }
36    };
37}
38
39impl_func!(f32, ssyr_);
40impl_func!(f64, dsyr_);
41impl_func!(c32, cher_);
42impl_func!(c64, zher_);
43
44/* #endregion */
45
46/* #region BLAS driver */
47
48pub struct HER_Driver<'x, 'a, F>
49where
50    F: HERNum,
51{
52    uplo: c_char,
53    n: blas_int,
54    alpha: F::RealFloat,
55    x: ArrayView1<'x, F>,
56    incx: blas_int,
57    a: ArrayOut2<'a, F>,
58    lda: blas_int,
59}
60
61impl<'x, 'a, F> BLASDriver<'a, F, Ix2> for HER_Driver<'x, 'a, F>
62where
63    F: HERNum,
64{
65    fn run_blas(self) -> Result<ArrayOut2<'a, F>, BLASError> {
66        let Self { uplo, n, alpha, x, incx, mut a, lda, .. } = self;
67        let x_ptr = x.as_ptr();
68        let a_ptr = a.get_data_mut_ptr();
69
70        // assuming dimension checks has been performed
71        // unconditionally return Ok if output does not contain anything
72        if n == 0 {
73            return Ok(a.clone_to_view_mut());
74        }
75
76        unsafe {
77            F::her(&uplo, &n, &alpha, x_ptr, &incx, a_ptr, &lda);
78        }
79        return Ok(a.clone_to_view_mut());
80    }
81}
82
83/* #endregion */
84
85/* #region BLAS builder */
86
87#[derive(Builder)]
88#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
89pub struct HER_<'x, 'a, F>
90where
91    F: HERNum,
92{
93    pub x: ArrayView1<'x, F>,
94
95    #[builder(setter(into, strip_option), default = "None")]
96    pub a: Option<ArrayViewMut2<'a, F>>,
97    #[builder(setter(into), default = "F::RealFloat::one()")]
98    pub alpha: F::RealFloat,
99    #[builder(setter(into), default = "BLASUpper")]
100    pub uplo: BLASUpLo,
101}
102
103impl<'x, 'a, F> BLASBuilder_<'a, F, Ix2> for HER_<'x, 'a, F>
104where
105    F: HERNum,
106{
107    fn driver(self) -> Result<HER_Driver<'x, 'a, F>, BLASError> {
108        let Self { x, a, alpha, uplo, .. } = self;
109
110        // initialize intent(hide)
111        let incx = x.stride_of(Axis(0));
112        let n = x.len_of(Axis(0));
113
114        // prepare output
115        let a = match a {
116            Some(a) => {
117                blas_assert_eq!(a.dim(), (n, n), InvalidDim)?;
118                if a.view().is_fpref() {
119                    ArrayOut2::ViewMut(a)
120                } else {
121                    let a_buffer = a.view().to_col_layout()?.into_owned();
122                    ArrayOut2::ToBeCloned(a, a_buffer)
123                }
124            },
125            None => ArrayOut2::Owned(Array2::zeros((n, n).f())),
126        };
127        let lda = a.view().stride_of(Axis(1));
128
129        // finalize
130        let driver = HER_Driver {
131            uplo: uplo.try_into()?,
132            n: n.try_into()?,
133            alpha,
134            x,
135            incx: incx.try_into()?,
136            a,
137            lda: lda.try_into()?,
138        };
139        return Ok(driver);
140    }
141}
142
143/* #endregion */
144
145/* #region BLAS wrapper */
146
147pub type HER<'x, 'a, F> = HER_Builder<'x, 'a, F>;
148pub type SSYR<'x, 'a> = HER<'x, 'a, f32>;
149pub type DSYR<'x, 'a> = HER<'x, 'a, f64>;
150pub type CHER<'x, 'a> = HER<'x, 'a, c32>;
151pub type ZHER<'x, 'a> = HER<'x, 'a, c64>;
152
153impl<'x, 'a, F> BLASBuilder<'a, F, Ix2> for HER_Builder<'x, 'a, F>
154where
155    F: HERNum,
156{
157    fn run(self) -> Result<ArrayOut2<'a, F>, BLASError> {
158        // initialize
159        let obj = self.build()?;
160
161        if obj.a.as_ref().map(|a| a.view().is_fpref()) == Some(true) {
162            // F-contiguous
163            return obj.driver()?.run_blas();
164        } else {
165            // C-contiguous
166            let uplo = obj.uplo.flip()?;
167            let a = obj.a.map(|a| a.reversed_axes());
168            if F::is_complex() {
169                let x = obj.x.mapv(F::conj);
170                let obj = HER_ { a, x: x.view(), uplo, ..obj };
171                let a = obj.driver()?.run_blas()?;
172                return Ok(a.reversed_axes());
173            } else {
174                let obj = HER_ { a, uplo, ..obj };
175                let a = obj.driver()?.run_blas()?;
176                return Ok(a.reversed_axes());
177            };
178        }
179    }
180}
181
182/* #endregion */