blas_array2/blas2/
her2.rs

1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6/* #region BLAS func */
7
8pub trait HER2Num: BLASFloat {
9    unsafe fn syr2(
10        uplo: *const c_char,
11        n: *const blas_int,
12        alpha: *const Self,
13        x: *const Self,
14        incx: *const blas_int,
15        y: *const Self,
16        incy: *const blas_int,
17        a: *mut Self,
18        lda: *const blas_int,
19    );
20}
21
22macro_rules! impl_func {
23    ($type: ty, $func: ident) => {
24        impl HER2Num for $type {
25            unsafe fn syr2(
26                uplo: *const c_char,
27                n: *const blas_int,
28                alpha: *const Self,
29                x: *const Self,
30                incx: *const blas_int,
31                y: *const Self,
32                incy: *const blas_int,
33                a: *mut Self,
34                lda: *const blas_int,
35            ) {
36                ffi::$func(uplo, n, alpha, x, incx, y, incy, a, lda);
37            }
38        }
39    };
40}
41
42impl_func!(f32, ssyr2_);
43impl_func!(f64, dsyr2_);
44impl_func!(c32, cher2_);
45impl_func!(c64, zher2_);
46
47/* #endregion */
48
49/* #region BLAS driver */
50
51pub struct HER2_Driver<'x, 'y, 'a, F>
52where
53    F: HER2Num,
54{
55    uplo: c_char,
56    n: blas_int,
57    alpha: F,
58    x: ArrayView1<'x, F>,
59    incx: blas_int,
60    y: ArrayView1<'y, F>,
61    incy: blas_int,
62    a: ArrayOut2<'a, F>,
63    lda: blas_int,
64}
65
66impl<'x, 'y, 'a, F> BLASDriver<'a, F, Ix2> for HER2_Driver<'x, 'y, 'a, F>
67where
68    F: HER2Num,
69{
70    fn run_blas(self) -> Result<ArrayOut2<'a, F>, BLASError> {
71        let Self { uplo, n, alpha, x, incx, y, incy, mut a, lda } = self;
72        let x_ptr = x.as_ptr();
73        let y_ptr = y.as_ptr();
74        let a_ptr = a.get_data_mut_ptr();
75
76        // assuming dimension checks has been performed
77        // unconditionally return Ok if output does not contain anything
78        if n == 0 {
79            return Ok(a.clone_to_view_mut());
80        }
81
82        unsafe {
83            F::syr2(&uplo, &n, &alpha, x_ptr, &incx, y_ptr, &incy, a_ptr, &lda);
84        }
85        return Ok(a.clone_to_view_mut());
86    }
87}
88
89/* #endregion */
90
91/* #region BLAS builder */
92
93#[derive(Builder)]
94#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
95pub struct HER2_<'x, 'y, 'a, F>
96where
97    F: BLASFloat,
98{
99    pub x: ArrayView1<'x, F>,
100    pub y: ArrayView1<'y, F>,
101
102    #[builder(setter(into, strip_option), default = "None")]
103    pub a: Option<ArrayViewMut2<'a, F>>,
104    #[builder(setter(into), default = "F::one()")]
105    pub alpha: F,
106    #[builder(setter(into), default = "BLASUpper")]
107    pub uplo: BLASUpLo,
108}
109
110impl<'x, 'y, 'a, F> BLASBuilder_<'a, F, Ix2> for HER2_<'x, 'y, 'a, F>
111where
112    F: HER2Num,
113{
114    fn driver(self) -> Result<HER2_Driver<'x, 'y, 'a, F>, BLASError> {
115        let Self { x, y, a, alpha, uplo, .. } = self;
116
117        // initialize intent(hide)
118        let incx = x.stride_of(Axis(0));
119        let incy = y.stride_of(Axis(0));
120        let n = x.len_of(Axis(0));
121
122        // check optional
123        blas_assert_eq!(y.len_of(Axis(0)), n, InvalidDim)?;
124
125        // prepare output
126        let a = match a {
127            Some(a) => {
128                blas_assert_eq!(a.dim(), (n, n), InvalidDim)?;
129                if a.view().is_fpref() {
130                    ArrayOut2::ViewMut(a)
131                } else {
132                    let a_buffer = a.view().to_col_layout()?.into_owned();
133                    ArrayOut2::ToBeCloned(a, a_buffer)
134                }
135            },
136            None => ArrayOut2::Owned(Array2::zeros((n, n).f())),
137        };
138        let lda = a.view().stride_of(Axis(1));
139
140        // finalize
141        let driver = HER2_Driver {
142            uplo: uplo.try_into()?,
143            n: n.try_into()?,
144            alpha,
145            x,
146            incx: incx.try_into()?,
147            y,
148            incy: incy.try_into()?,
149            a,
150            lda: lda.try_into()?,
151        };
152        return Ok(driver);
153    }
154}
155
156/* #endregion */
157
158/* #region BLAS wrapper */
159
160pub type SYR2<'x, 'y, 'a, F> = HER2_Builder<'x, 'y, 'a, F>;
161pub type SSYR2<'x, 'y, 'a> = SYR2<'x, 'y, 'a, f32>;
162pub type DSYR2<'x, 'y, 'a> = SYR2<'x, 'y, 'a, f64>;
163
164pub type HER2<'x, 'y, 'a, F> = HER2_Builder<'x, 'y, 'a, F>;
165pub type CHER2<'x, 'y, 'a> = HER2<'x, 'y, 'a, c32>;
166pub type ZHER2<'x, 'y, 'a> = HER2<'x, 'y, 'a, c64>;
167
168impl<'x, 'y, 'a, F> BLASBuilder<'a, F, Ix2> for HER2_Builder<'x, 'y, 'a, F>
169where
170    F: HER2Num,
171{
172    fn run(self) -> Result<ArrayOut2<'a, F>, BLASError> {
173        // initialize
174        let obj = self.build()?;
175
176        if obj.a.as_ref().map(|a| a.view().is_fpref()) == Some(true) {
177            // F-contiguous
178            return obj.driver()?.run_blas();
179        } else {
180            // C-contiguous
181            let uplo = obj.uplo.flip()?;
182            let a = obj.a.map(|a| a.reversed_axes());
183            if F::is_complex() {
184                let x = obj.x.mapv(F::conj);
185                let y = obj.y.mapv(F::conj);
186                let obj = HER2_ { a, y: x.view(), x: y.view(), uplo, ..obj };
187                let a = obj.driver()?.run_blas()?;
188                return Ok(a.reversed_axes());
189            } else {
190                let obj = HER2_ { a, uplo, x: obj.y, y: obj.x, ..obj };
191                let a = obj.driver()?.run_blas()?;
192                return Ok(a.reversed_axes());
193            };
194        }
195    }
196}
197
198/* #endregion */