blas_array2/blas2/ger.rs

1use crate::ffi::{self, blas_int};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6/* #region BLAS func */
7
8pub trait GERNum: BLASFloat {
9    unsafe fn ger(
10        m: *const blas_int,
11        n: *const blas_int,
12        alpha: *const Self,
13        x: *const Self,
14        incx: *const blas_int,
15        y: *const Self,
16        incy: *const blas_int,
17        a: *mut Self,
18        lda: *const blas_int,
19    );
20}
21
22macro_rules! impl_func {
23    ($type: ty, $func: ident) => {
24        impl GERNum for $type {
25            unsafe fn ger(
26                m: *const blas_int,
27                n: *const blas_int,
28                alpha: *const Self,
29                x: *const Self,
30                incx: *const blas_int,
31                y: *const Self,
32                incy: *const blas_int,
33                a: *mut Self,
34                lda: *const blas_int,
35            ) {
36                ffi::$func(m, n, alpha, x, incx, y, incy, a, lda);
37            }
38        }
39    };
40}
41
42impl_func!(f32, sger_);
43impl_func!(f64, dger_);
44impl_func!(c32, cgeru_);
45impl_func!(c64, zgeru_);
46
47/* #endregion */
48
49/* #region BLAS driver */
50
51pub struct GER_Driver<'x, 'y, 'a, F>
52where
53    F: BLASFloat,
54{
55    m: blas_int,
56    n: blas_int,
57    alpha: F,
58    x: ArrayView1<'x, F>,
59    incx: blas_int,
60    y: ArrayView1<'y, F>,
61    incy: blas_int,
62    a: ArrayOut2<'a, F>,
63    lda: blas_int,
64}
65
66impl<'x, 'y, 'a, F> BLASDriver<'a, F, Ix2> for GER_Driver<'x, 'y, 'a, F>
67where
68    F: GERNum,
69{
70    fn run_blas(self) -> Result<ArrayOut2<'a, F>, BLASError> {
71        let Self { m, n, alpha, x, incx, y, incy, mut a, lda } = self;
72        let x_ptr = x.as_ptr();
73        let y_ptr = y.as_ptr();
74        let a_ptr = a.get_data_mut_ptr();
75
76        // assuming dimension checks has been performed
77        // unconditionally return Ok if output does not contain anything
78        if m == 0 || n == 0 {
79            return Ok(a.clone_to_view_mut());
80        }
81
82        unsafe {
83            F::ger(&m, &n, &alpha, x_ptr, &incx, y_ptr, &incy, a_ptr, &lda);
84        }
85        return Ok(a.clone_to_view_mut());
86    }
87}
88
89/* #endregion */
90
91/* #region BLAS builder */
92
93#[derive(Builder)]
94#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
95pub struct GER_<'x, 'y, 'a, F>
96where
97    F: GERNum,
98{
99    pub x: ArrayView1<'x, F>,
100    pub y: ArrayView1<'y, F>,
101
102    #[builder(setter(into, strip_option), default = "None")]
103    pub a: Option<ArrayViewMut2<'a, F>>,
104    #[builder(setter(into), default = "F::one()")]
105    pub alpha: F,
106}
107
108impl<'x, 'y, 'a, F> BLASBuilder_<'a, F, Ix2> for GER_<'x, 'y, 'a, F>
109where
110    F: GERNum,
111{
112    fn driver(self) -> Result<GER_Driver<'x, 'y, 'a, F>, BLASError> {
113        let Self { x, y, a, alpha } = self;
114
115        // initialize intent(hide)
116        let incx = x.stride_of(Axis(0));
117        let incy = y.stride_of(Axis(0));
118        let m = x.len_of(Axis(0));
119        let n = y.len_of(Axis(0));
120
121        // prepare output
122        let a = match a {
123            Some(a) => {
124                blas_assert_eq!(a.dim(), (m, n), InvalidDim)?;
125                if a.view().is_fpref() {
126                    ArrayOut2::ViewMut(a)
127                } else {
128                    let a_buffer = a.view().to_col_layout()?.into_owned();
129                    ArrayOut2::ToBeCloned(a, a_buffer)
130                }
131            },
132            None => ArrayOut2::Owned(Array2::zeros((m, n).f())),
133        };
134        let lda = a.view().stride_of(Axis(1));
135
136        // finalize
137        let driver = GER_Driver {
138            m: m.try_into()?,
139            n: n.try_into()?,
140            alpha,
141            x,
142            incx: incx.try_into()?,
143            y,
144            incy: incy.try_into()?,
145            a,
146            lda: lda.try_into()?,
147        };
148        return Ok(driver);
149    }
150}
151
152/* #endregion */
153
154/* #region BLAS wrapper */
155
156pub type GER<'x, 'y, 'a, F> = GER_Builder<'x, 'y, 'a, F>;
157pub type SGER<'x, 'y, 'a> = GER<'x, 'y, 'a, f32>;
158pub type DGER<'x, 'y, 'a> = GER<'x, 'y, 'a, f64>;
159pub type CGERU<'x, 'y, 'a> = GER<'x, 'y, 'a, c32>;
160pub type ZGERU<'x, 'y, 'a> = GER<'x, 'y, 'a, c64>;
161
162impl<'x, 'y, 'a, F> BLASBuilder<'a, F, Ix2> for GER_Builder<'x, 'y, 'a, F>
163where
164    F: GERNum,
165{
166    fn run(self) -> Result<ArrayOut2<'a, F>, BLASError> {
167        // initialize
168        let obj = self.build()?;
169
170        if obj.a.as_ref().map(|a| a.view().is_fpref()) == Some(true) {
171            // F-contiguous
172            return obj.driver()?.run_blas();
173        } else {
174            // C-contiguous
175            let a = obj.a.map(|a| a.reversed_axes());
176            let obj = GER_ { a, x: obj.y, y: obj.x, ..obj };
177            let a = obj.driver()?.run_blas()?;
178            return Ok(a.reversed_axes());
179        }
180    }
181}
182
183/* #endregion */