arrayfire/blas/mod.rs
1use super::core::{
2 af_array, AfError, Array, CublasMathMode, FloatingPoint, HasAfEnum, MatProp, HANDLE_ERROR,
3};
4
5use libc::{c_int, c_uint, c_void};
6use std::vec::Vec;
7
8extern "C" {
9 fn af_gemm(
10 out: *mut af_array,
11 optlhs: c_uint,
12 optrhs: c_uint,
13 alpha: *const c_void,
14 lhs: af_array,
15 rhs: af_array,
16 beta: *const c_void,
17 ) -> c_int;
18
19 fn af_matmul(
20 out: *mut af_array,
21 lhs: af_array,
22 rhs: af_array,
23 optlhs: c_uint,
24 optrhs: c_uint,
25 ) -> c_int;
26
27 fn af_dot(
28 out: *mut af_array,
29 lhs: af_array,
30 rhs: af_array,
31 optlhs: c_uint,
32 optrhs: c_uint,
33 ) -> c_int;
34
35 fn af_transpose(out: *mut af_array, arr: af_array, conjugate: bool) -> c_int;
36 fn af_transpose_inplace(arr: af_array, conjugate: bool) -> c_int;
37
38 fn afcu_cublasSetMathMode(mode: c_int) -> c_int;
39}
40
41/// BLAS general matrix multiply (GEMM) of two Array objects
42///
43///
44/// This provides a general interface to the BLAS level 3 general matrix multiply (GEMM),
45/// which is generally defined as:
46///
47/// \begin{equation}
48/// C = \alpha * opA(A)opB(B) + \beta * C
49/// \end{equation}
50///
51/// where $\alpha$ (**alpha**) and $\beta$ (**beta**) are both scalars; $A$ and $B$ are the matrix
52/// multiply operands; and $opA$ and $opB$ are noop
53/// (if optLhs is [MatProp::NONE](./enum.MatProp.html)) or transpose
54/// (if optLhs is [MatProp::TRANS](./enum.MatProp.html)) operations on $A$ or $B$ before the
55/// actual GEMM operation. Batched GEMM is supported if at least either $A$ or $B$ have more than
56/// two dimensions (see [af::matmul](http://arrayfire.org/docs/group__blas__func__matmul.htm#ga63306b6ed967bd1055086db862fe885b)
57/// for more details on broadcasting). However, only one **alpha** and one **beta** can be used
58/// for all of the batched matrix operands.
59///
60/// The `output` Array can be used both as an input and output. An allocation will be performed
61/// if you pass an empty Array (i.e. `let c: Array<f32> = (0 as i64).into();`). If a valid Array
62/// is passed as $C$, the operation will be performed on that Array itself. The C Array must be
63/// the correct type and shape; otherwise, an error will be thrown.
64///
65/// Note: Passing an Array that has not been initialized to the C array
66/// will cause undefined behavior.
67///
68/// # Examples
69///
70/// Given below is an example of using gemm API with existing Arrays
71///
72/// ```rust
73/// use arrayfire::{Array, Dim4, print, randu, gemm};
74///
75/// let dims = Dim4::new(&[5, 5, 1, 1]);
76///
77/// let alpha = vec![1.0 as f32];
78/// let beta = vec![2.0 as f32];
79///
80/// let lhs = randu::<f32>(dims);
81/// let rhs = randu::<f32>(dims);
82///
83/// let mut result = Array::new_empty(dims);
84/// gemm(&mut result, arrayfire::MatProp::NONE, arrayfire::MatProp::NONE,
85/// alpha, &lhs, &rhs, beta);
86/// ```
87///
88/// If you don't have an existing Array, you can also use gemm in the following fashion.
89/// However, if there is no existing Array that you need to fill and your use case doesn't
90/// deal with alpha and beta from gemm equation, it is recommended to use
91/// [matmul](./fn.matmul.html) for more terse code.
92///
93/// ```rust
94/// use arrayfire::{Array, Dim4, af_array, print, randu, gemm};
95///
96/// let dims = Dim4::new(&[5, 5, 1, 1]);
97///
98/// let alpha = vec![1.0 as f32];
99/// let beta = vec![2.0 as f32];
100///
101/// let lhs = randu::<f32>(dims);
102/// let rhs = randu::<f32>(dims);
103///
104/// let mut result: Array::<f32> = (std::ptr::null_mut() as af_array).into();
105///
106/// gemm(&mut result, arrayfire::MatProp::NONE, arrayfire::MatProp::NONE,
107/// alpha, &lhs, &rhs, beta);
108/// ```
109///
110/// # Parameters
111///
112/// - `optlhs` - Transpose left hand side before the function is performed, uses one of the values of [MatProp](./enum.MatProp.html)
113/// - `optrhs` - Transpose right hand side before the function is performed, uses one of the values of [MatProp](./enum.MatProp.html)
114/// - `alpha` is alpha value;
115/// - `lhs` is the Array on left hand side
116/// - `rhs` is the Array on right hand side
117/// - `beta` is beta value;
118///
119/// # Return Values
120///
121/// Array, result of gemm operation
122pub fn gemm<T>(
123 output: &mut Array<T>,
124 optlhs: MatProp,
125 optrhs: MatProp,
126 alpha: Vec<T>,
127 lhs: &Array<T>,
128 rhs: &Array<T>,
129 beta: Vec<T>,
130) where
131 T: HasAfEnum + FloatingPoint,
132{
133 unsafe {
134 let mut out = output.get();
135 let err_val = af_gemm(
136 &mut out as *mut af_array,
137 optlhs as c_uint,
138 optrhs as c_uint,
139 alpha.as_ptr() as *const c_void,
140 lhs.get(),
141 rhs.get(),
142 beta.as_ptr() as *const c_void,
143 );
144 HANDLE_ERROR(AfError::from(err_val));
145 output.set(out);
146 }
147}
148
149/// Matrix multiple of two Arrays
150///
151/// # Parameters
152///
153/// - `lhs` is the Array on left hand side
154/// - `rhs` is the Array on right hand side
155/// - `optlhs` - Transpose left hand side before the function is performed, uses one of the values of [MatProp](./enum.MatProp.html)
156/// - `optrhs` - Transpose right hand side before the function is performed, uses one of the values of [MatProp](./enum.MatProp.html)
157///
158/// # Return Values
159///
160/// The result Array of matrix multiplication
161pub fn matmul<T>(lhs: &Array<T>, rhs: &Array<T>, optlhs: MatProp, optrhs: MatProp) -> Array<T>
162where
163 T: HasAfEnum + FloatingPoint,
164{
165 unsafe {
166 let mut temp: af_array = std::ptr::null_mut();
167 let err_val = af_matmul(
168 &mut temp as *mut af_array,
169 lhs.get(),
170 rhs.get(),
171 optlhs as c_uint,
172 optrhs as c_uint,
173 );
174 HANDLE_ERROR(AfError::from(err_val));
175 temp.into()
176 }
177}
178
179/// Calculate the dot product of vectors.
180///
181/// Scalar dot product between two vectors. Also referred to as the inner product. This function returns the scalar product of two equal sized vectors.
182///
183/// # Parameters
184///
185/// - `lhs` - Left hand side of dot operation
186/// - `rhs` - Right hand side of dot operation
187/// - `optlhs` - Options for lhs. Currently only NONE value from [MatProp](./enum.MatProp.html) is supported.
188/// - `optrhs` - Options for rhs. Currently only NONE value from [MatProp](./enum.MatProp.html) is supported.
189///
190/// # Return Values
191///
192/// The result of dot product.
193pub fn dot<T>(lhs: &Array<T>, rhs: &Array<T>, optlhs: MatProp, optrhs: MatProp) -> Array<T>
194where
195 T: HasAfEnum + FloatingPoint,
196{
197 unsafe {
198 let mut temp: af_array = std::ptr::null_mut();
199 let err_val = af_dot(
200 &mut temp as *mut af_array,
201 lhs.get(),
202 rhs.get(),
203 optlhs as c_uint,
204 optrhs as c_uint,
205 );
206 HANDLE_ERROR(AfError::from(err_val));
207 temp.into()
208 }
209}
210
211/// Transpose of a matrix.
212///
213/// # Parameters
214///
215/// - `arr` is the input Array
216/// - `conjugate` is a boolean that indicates if the transpose operation needs to be a conjugate
217/// transpose
218///
219/// # Return Values
220///
221/// Transposed Array.
222pub fn transpose<T: HasAfEnum>(arr: &Array<T>, conjugate: bool) -> Array<T> {
223 unsafe {
224 let mut temp: af_array = std::ptr::null_mut();
225 let err_val = af_transpose(&mut temp as *mut af_array, arr.get(), conjugate);
226 HANDLE_ERROR(AfError::from(err_val));
227 temp.into()
228 }
229}
230
231/// Inplace transpose of a matrix.
232///
233/// # Parameters
234///
235/// - `arr` is the input Array that has to be transposed
236/// - `conjugate` is a boolean that indicates if the transpose operation needs to be a conjugate
237/// transpose
238pub fn transpose_inplace<T: HasAfEnum>(arr: &mut Array<T>, conjugate: bool) {
239 unsafe {
240 let err_val = af_transpose_inplace(arr.get(), conjugate);
241 HANDLE_ERROR(AfError::from(err_val));
242 }
243}
244
245/// Sets the cuBLAS math mode for the internal handle.
246///
247/// See the cuBLAS documentation for additional details
248///
249/// # Parameters
250///
251/// - `mode` takes a value of [CublasMathMode](./enum.CublasMathMode.html) enum
252pub fn set_cublas_mode(mode: CublasMathMode) {
253 unsafe {
254 afcu_cublasSetMathMode(mode as c_int);
255 //let err_val = afcu_cublasSetMathMode(mode as c_int);
256 // FIXME(wonder if this something to throw off,
257 // the program state is not invalid or anything
258 // HANDLE_ERROR(AfError::from(err_val));
259 }
260}