blas_array2/blas2/
gerc.rs

1use crate::blas2::ger::{GERNum, GER_};
2use crate::ffi::{self, blas_int};
3use crate::util::*;
4use derive_builder::Builder;
5use ndarray::prelude::*;
6
7/* #region BLAS func */
8
9pub trait GERCNum: BLASFloat + GERNum {
10    unsafe fn gerc(
11        m: *const blas_int,
12        n: *const blas_int,
13        alpha: *const Self,
14        x: *const Self,
15        incx: *const blas_int,
16        y: *const Self,
17        incy: *const blas_int,
18        a: *mut Self,
19        lda: *const blas_int,
20    );
21}
22
23macro_rules! impl_func {
24    ($type: ty, $func: ident) => {
25        impl GERCNum for $type {
26            unsafe fn gerc(
27                m: *const blas_int,
28                n: *const blas_int,
29                alpha: *const Self,
30                x: *const Self,
31                incx: *const blas_int,
32                y: *const Self,
33                incy: *const blas_int,
34                a: *mut Self,
35                lda: *const blas_int,
36            ) {
37                ffi::$func(m, n, alpha, x, incx, y, incy, a, lda);
38            }
39        }
40    };
41}
42
43impl_func!(c32, cgerc_);
44impl_func!(c64, zgerc_);
45
46/* #endregion */
47
48/* #region BLAS driver */
49
50pub struct GERC_Driver<'x, 'y, 'a, F>
51where
52    F: GERCNum,
53{
54    m: blas_int,
55    n: blas_int,
56    alpha: F,
57    x: ArrayView1<'x, F>,
58    incx: blas_int,
59    y: ArrayView1<'y, F>,
60    incy: blas_int,
61    a: ArrayOut2<'a, F>,
62    lda: blas_int,
63}
64
65impl<'x, 'y, 'a, F> BLASDriver<'a, F, Ix2> for GERC_Driver<'x, 'y, 'a, F>
66where
67    F: GERCNum,
68{
69    fn run_blas(self) -> Result<ArrayOut2<'a, F>, BLASError> {
70        let Self { m, n, alpha, x, incx, y, incy, mut a, lda } = self;
71        let x_ptr = x.as_ptr();
72        let y_ptr = y.as_ptr();
73        let a_ptr = a.get_data_mut_ptr();
74
75        // assuming dimension checks has been performed
76        // unconditionally return Ok if output does not contain anything
77        if m == 0 || n == 0 {
78            return Ok(a.clone_to_view_mut());
79        }
80
81        unsafe {
82            F::gerc(&m, &n, &alpha, x_ptr, &incx, y_ptr, &incy, a_ptr, &lda);
83        }
84        return Ok(a.clone_to_view_mut());
85    }
86}
87
88/* #endregion */
89
90/* #region BLAS builder */
91
92#[derive(Builder)]
93#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
94pub struct GERC_<'x, 'y, 'a, F>
95where
96    F: BLASFloat,
97{
98    pub x: ArrayView1<'x, F>,
99    pub y: ArrayView1<'y, F>,
100
101    #[builder(setter(into, strip_option), default = "None")]
102    pub a: Option<ArrayViewMut2<'a, F>>,
103    #[builder(setter(into), default = "F::one()")]
104    pub alpha: F,
105}
106
107impl<'x, 'y, 'a, F> BLASBuilder_<'a, F, Ix2> for GERC_<'x, 'y, 'a, F>
108where
109    F: GERCNum,
110{
111    fn driver(self) -> Result<GERC_Driver<'x, 'y, 'a, F>, BLASError> {
112        let Self { x, y, a, alpha } = self;
113
114        // initialize intent(hide)
115        let incx = x.stride_of(Axis(0));
116        let incy = y.stride_of(Axis(0));
117        let m = x.len_of(Axis(0));
118        let n = y.len_of(Axis(0));
119
120        // prepare output
121        let a = match a {
122            Some(a) => {
123                blas_assert_eq!(a.dim(), (m, n), InvalidDim)?;
124                if a.view().is_fpref() {
125                    ArrayOut2::ViewMut(a)
126                } else {
127                    let a_buffer = a.view().to_col_layout()?.into_owned();
128                    ArrayOut2::ToBeCloned(a, a_buffer)
129                }
130            },
131            None => ArrayOut2::Owned(Array2::zeros((m, n).f())),
132        };
133        let lda = a.view().stride_of(Axis(1));
134
135        // finalize
136        let driver = GERC_Driver {
137            m: m.try_into()?,
138            n: n.try_into()?,
139            alpha,
140            x,
141            incx: incx.try_into()?,
142            y,
143            incy: incy.try_into()?,
144            a,
145            lda: lda.try_into()?,
146        };
147        return Ok(driver);
148    }
149}
150
151/* #endregion */
152
153/* #region BLAS wrapper */
154
155pub type GERC<'x, 'y, 'a, F> = GERC_Builder<'x, 'y, 'a, F>;
156pub type CGERC<'x, 'y, 'a> = GERC<'x, 'y, 'a, c32>;
157pub type ZGERC<'x, 'y, 'a> = GERC<'x, 'y, 'a, c64>;
158
159impl<'x, 'y, 'a, F> BLASBuilder<'a, F, Ix2> for GERC_Builder<'x, 'y, 'a, F>
160where
161    F: GERCNum,
162{
163    fn run(self) -> Result<ArrayOut2<'a, F>, BLASError> {
164        // initialize
165        let obj = self.build()?;
166
167        if obj.a.as_ref().map(|a| a.view().is_fpref()) == Some(true) {
168            // F-contiguous
169            return obj.driver()?.run_blas();
170        } else {
171            // C-contiguous
172            let a = obj.a.map(|a| a.reversed_axes());
173            let y = obj.y.mapv(F::conj);
174            let obj = GER_ { a, x: y.view(), y: obj.x, alpha: obj.alpha };
175            let a = obj.driver()?.run_blas()?;
176            return Ok(a.reversed_axes());
177        }
178    }
179}
180
181/* #endregion */