1use crate::ffi::{self, blas_int};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6pub trait GERNum: BLASFloat {
9 unsafe fn ger(
10 m: *const blas_int,
11 n: *const blas_int,
12 alpha: *const Self,
13 x: *const Self,
14 incx: *const blas_int,
15 y: *const Self,
16 incy: *const blas_int,
17 a: *mut Self,
18 lda: *const blas_int,
19 );
20}
21
22macro_rules! impl_func {
23 ($type: ty, $func: ident) => {
24 impl GERNum for $type {
25 unsafe fn ger(
26 m: *const blas_int,
27 n: *const blas_int,
28 alpha: *const Self,
29 x: *const Self,
30 incx: *const blas_int,
31 y: *const Self,
32 incy: *const blas_int,
33 a: *mut Self,
34 lda: *const blas_int,
35 ) {
36 ffi::$func(m, n, alpha, x, incx, y, incy, a, lda);
37 }
38 }
39 };
40}
41
42impl_func!(f32, sger_);
43impl_func!(f64, dger_);
44impl_func!(c32, cgeru_);
45impl_func!(c64, zgeru_);
46
47pub struct GER_Driver<'x, 'y, 'a, F>
52where
53 F: BLASFloat,
54{
55 m: blas_int,
56 n: blas_int,
57 alpha: F,
58 x: ArrayView1<'x, F>,
59 incx: blas_int,
60 y: ArrayView1<'y, F>,
61 incy: blas_int,
62 a: ArrayOut2<'a, F>,
63 lda: blas_int,
64}
65
66impl<'x, 'y, 'a, F> BLASDriver<'a, F, Ix2> for GER_Driver<'x, 'y, 'a, F>
67where
68 F: GERNum,
69{
70 fn run_blas(self) -> Result<ArrayOut2<'a, F>, BLASError> {
71 let Self { m, n, alpha, x, incx, y, incy, mut a, lda } = self;
72 let x_ptr = x.as_ptr();
73 let y_ptr = y.as_ptr();
74 let a_ptr = a.get_data_mut_ptr();
75
76 if m == 0 || n == 0 {
79 return Ok(a.clone_to_view_mut());
80 }
81
82 unsafe {
83 F::ger(&m, &n, &alpha, x_ptr, &incx, y_ptr, &incy, a_ptr, &lda);
84 }
85 return Ok(a.clone_to_view_mut());
86 }
87}
88
89#[derive(Builder)]
94#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
95pub struct GER_<'x, 'y, 'a, F>
96where
97 F: GERNum,
98{
99 pub x: ArrayView1<'x, F>,
100 pub y: ArrayView1<'y, F>,
101
102 #[builder(setter(into, strip_option), default = "None")]
103 pub a: Option<ArrayViewMut2<'a, F>>,
104 #[builder(setter(into), default = "F::one()")]
105 pub alpha: F,
106}
107
108impl<'x, 'y, 'a, F> BLASBuilder_<'a, F, Ix2> for GER_<'x, 'y, 'a, F>
109where
110 F: GERNum,
111{
112 fn driver(self) -> Result<GER_Driver<'x, 'y, 'a, F>, BLASError> {
113 let Self { x, y, a, alpha } = self;
114
115 let incx = x.stride_of(Axis(0));
117 let incy = y.stride_of(Axis(0));
118 let m = x.len_of(Axis(0));
119 let n = y.len_of(Axis(0));
120
121 let a = match a {
123 Some(a) => {
124 blas_assert_eq!(a.dim(), (m, n), InvalidDim)?;
125 if a.view().is_fpref() {
126 ArrayOut2::ViewMut(a)
127 } else {
128 let a_buffer = a.view().to_col_layout()?.into_owned();
129 ArrayOut2::ToBeCloned(a, a_buffer)
130 }
131 },
132 None => ArrayOut2::Owned(Array2::zeros((m, n).f())),
133 };
134 let lda = a.view().stride_of(Axis(1));
135
136 let driver = GER_Driver {
138 m: m.try_into()?,
139 n: n.try_into()?,
140 alpha,
141 x,
142 incx: incx.try_into()?,
143 y,
144 incy: incy.try_into()?,
145 a,
146 lda: lda.try_into()?,
147 };
148 return Ok(driver);
149 }
150}
151
152pub type GER<'x, 'y, 'a, F> = GER_Builder<'x, 'y, 'a, F>;
157pub type SGER<'x, 'y, 'a> = GER<'x, 'y, 'a, f32>;
158pub type DGER<'x, 'y, 'a> = GER<'x, 'y, 'a, f64>;
159pub type CGERU<'x, 'y, 'a> = GER<'x, 'y, 'a, c32>;
160pub type ZGERU<'x, 'y, 'a> = GER<'x, 'y, 'a, c64>;
161
162impl<'x, 'y, 'a, F> BLASBuilder<'a, F, Ix2> for GER_Builder<'x, 'y, 'a, F>
163where
164 F: GERNum,
165{
166 fn run(self) -> Result<ArrayOut2<'a, F>, BLASError> {
167 let obj = self.build()?;
169
170 if obj.a.as_ref().map(|a| a.view().is_fpref()) == Some(true) {
171 return obj.driver()?.run_blas();
173 } else {
174 let a = obj.a.map(|a| a.reversed_axes());
176 let obj = GER_ { a, x: obj.y, y: obj.x, ..obj };
177 let a = obj.driver()?.run_blas()?;
178 return Ok(a.reversed_axes());
179 }
180 }
181}
182
183