// arrayfire/blas/mod.rs

1use super::core::{
2    af_array, AfError, Array, CublasMathMode, FloatingPoint, HasAfEnum, MatProp, HANDLE_ERROR,
3};
4
5use libc::{c_int, c_uint, c_void};
6use std::vec::Vec;
7
8extern "C" {
9    fn af_gemm(
10        out: *mut af_array,
11        optlhs: c_uint,
12        optrhs: c_uint,
13        alpha: *const c_void,
14        lhs: af_array,
15        rhs: af_array,
16        beta: *const c_void,
17    ) -> c_int;
18
19    fn af_matmul(
20        out: *mut af_array,
21        lhs: af_array,
22        rhs: af_array,
23        optlhs: c_uint,
24        optrhs: c_uint,
25    ) -> c_int;
26
27    fn af_dot(
28        out: *mut af_array,
29        lhs: af_array,
30        rhs: af_array,
31        optlhs: c_uint,
32        optrhs: c_uint,
33    ) -> c_int;
34
35    fn af_transpose(out: *mut af_array, arr: af_array, conjugate: bool) -> c_int;
36    fn af_transpose_inplace(arr: af_array, conjugate: bool) -> c_int;
37
38    fn afcu_cublasSetMathMode(mode: c_int) -> c_int;
39}
40
41/// BLAS general matrix multiply (GEMM) of two Array objects
42///
43///
44/// This provides a general interface to the BLAS level 3 general matrix multiply (GEMM),
45/// which is generally defined as:
46///
47/// \begin{equation}
48///     C = \alpha * opA(A)opB(B) + \beta * C
49/// \end{equation}
50///
51///   where $\alpha$ (**alpha**) and $\beta$ (**beta**) are both scalars; $A$ and $B$ are the matrix
52///   multiply operands; and $opA$ and $opB$ are noop
53///   (if optLhs is [MatProp::NONE](./enum.MatProp.html)) or transpose
54///   (if optLhs is [MatProp::TRANS](./enum.MatProp.html)) operations on $A$ or $B$ before the
55///   actual GEMM operation. Batched GEMM is supported if at least either $A$ or $B$ have more than
56///   two dimensions (see [af::matmul](http://arrayfire.org/docs/group__blas__func__matmul.htm#ga63306b6ed967bd1055086db862fe885b)
57///   for more details on broadcasting). However, only one **alpha** and one **beta** can be used
58///   for all of the batched matrix operands.
59///
60///   The `output` Array can be used both as an input and output. An allocation will be performed
61///   if you pass an empty Array (i.e. `let c: Array<f32> = (0 as i64).into();`). If a valid Array
62///   is passed as $C$, the operation will be performed on that Array itself. The C Array must be
63///   the correct type and shape; otherwise, an error will be thrown.
64///
65///   Note: Passing an Array that has not been initialized to the C array
66///   will cause undefined behavior.
67///
68/// # Examples
69///
70/// Given below is an example of using gemm API with existing Arrays
71///
72/// ```rust
73/// use arrayfire::{Array, Dim4, print, randu, gemm};
74///
75/// let dims = Dim4::new(&[5, 5, 1, 1]);
76///
77/// let alpha = vec![1.0 as f32];
78/// let  beta = vec![2.0 as f32];
79///
80/// let lhs = randu::<f32>(dims);
81/// let rhs = randu::<f32>(dims);
82///
83/// let mut result = Array::new_empty(dims);
84/// gemm(&mut result, arrayfire::MatProp::NONE, arrayfire::MatProp::NONE,
85///      alpha, &lhs, &rhs, beta);
86/// ```
87///
88/// If you don't have an existing Array, you can also use gemm in the following fashion.
89/// However, if there is no existing Array that you need to fill and your use case doesn't
90/// deal with alpha and beta from gemm equation, it is recommended to use
91/// [matmul](./fn.matmul.html) for more terse code.
92///
93/// ```rust
94/// use arrayfire::{Array, Dim4, af_array, print, randu, gemm};
95///
96/// let dims = Dim4::new(&[5, 5, 1, 1]);
97///
98/// let alpha = vec![1.0 as f32];
99/// let  beta = vec![2.0 as f32];
100///
101/// let lhs = randu::<f32>(dims);
102/// let rhs = randu::<f32>(dims);
103///
104/// let mut result: Array::<f32> = (std::ptr::null_mut() as af_array).into();
105///
106/// gemm(&mut result, arrayfire::MatProp::NONE, arrayfire::MatProp::NONE,
107///      alpha, &lhs, &rhs, beta);
108/// ```
109///
110/// # Parameters
111///
112/// - `optlhs` - Transpose left hand side before the function is performed, uses one of the values of [MatProp](./enum.MatProp.html)
113/// - `optrhs` - Transpose right hand side before the function is performed, uses one of the values of [MatProp](./enum.MatProp.html)
114/// - `alpha` is alpha value;
115/// - `lhs` is the Array on left hand side
116/// - `rhs` is the Array on right hand side
117/// - `beta` is beta value;
118///
119/// # Return Values
120///
121/// Array, result of gemm operation
122pub fn gemm<T>(
123    output: &mut Array<T>,
124    optlhs: MatProp,
125    optrhs: MatProp,
126    alpha: Vec<T>,
127    lhs: &Array<T>,
128    rhs: &Array<T>,
129    beta: Vec<T>,
130) where
131    T: HasAfEnum + FloatingPoint,
132{
133    unsafe {
134        let mut out = output.get();
135        let err_val = af_gemm(
136            &mut out as *mut af_array,
137            optlhs as c_uint,
138            optrhs as c_uint,
139            alpha.as_ptr() as *const c_void,
140            lhs.get(),
141            rhs.get(),
142            beta.as_ptr() as *const c_void,
143        );
144        HANDLE_ERROR(AfError::from(err_val));
145        output.set(out);
146    }
147}
148
149/// Matrix multiple of two Arrays
150///
151/// # Parameters
152///
153/// - `lhs` is the Array on left hand side
154/// - `rhs` is the Array on right hand side
155/// - `optlhs` - Transpose left hand side before the function is performed, uses one of the values of [MatProp](./enum.MatProp.html)
156/// - `optrhs` - Transpose right hand side before the function is performed, uses one of the values of [MatProp](./enum.MatProp.html)
157///
158/// # Return Values
159///
160/// The result Array of matrix multiplication
161pub fn matmul<T>(lhs: &Array<T>, rhs: &Array<T>, optlhs: MatProp, optrhs: MatProp) -> Array<T>
162where
163    T: HasAfEnum + FloatingPoint,
164{
165    unsafe {
166        let mut temp: af_array = std::ptr::null_mut();
167        let err_val = af_matmul(
168            &mut temp as *mut af_array,
169            lhs.get(),
170            rhs.get(),
171            optlhs as c_uint,
172            optrhs as c_uint,
173        );
174        HANDLE_ERROR(AfError::from(err_val));
175        temp.into()
176    }
177}
178
179/// Calculate the dot product of vectors.
180///
181/// Scalar dot product between two vectors. Also referred to as the inner product. This function returns the scalar product of two equal sized vectors.
182///
183/// # Parameters
184///
185/// - `lhs` - Left hand side of dot operation
186/// - `rhs` - Right hand side of dot operation
187/// - `optlhs` - Options for lhs. Currently only NONE value from [MatProp](./enum.MatProp.html) is supported.
188/// - `optrhs` - Options for rhs. Currently only NONE value from [MatProp](./enum.MatProp.html) is supported.
189///
190/// # Return Values
191///
192/// The result of dot product.
193pub fn dot<T>(lhs: &Array<T>, rhs: &Array<T>, optlhs: MatProp, optrhs: MatProp) -> Array<T>
194where
195    T: HasAfEnum + FloatingPoint,
196{
197    unsafe {
198        let mut temp: af_array = std::ptr::null_mut();
199        let err_val = af_dot(
200            &mut temp as *mut af_array,
201            lhs.get(),
202            rhs.get(),
203            optlhs as c_uint,
204            optrhs as c_uint,
205        );
206        HANDLE_ERROR(AfError::from(err_val));
207        temp.into()
208    }
209}
210
211/// Transpose of a matrix.
212///
213/// # Parameters
214///
215/// - `arr` is the input Array
216/// - `conjugate` is a boolean that indicates if the transpose operation needs to be a conjugate
217/// transpose
218///
219/// # Return Values
220///
221/// Transposed Array.
222pub fn transpose<T: HasAfEnum>(arr: &Array<T>, conjugate: bool) -> Array<T> {
223    unsafe {
224        let mut temp: af_array = std::ptr::null_mut();
225        let err_val = af_transpose(&mut temp as *mut af_array, arr.get(), conjugate);
226        HANDLE_ERROR(AfError::from(err_val));
227        temp.into()
228    }
229}
230
231/// Inplace transpose of a matrix.
232///
233/// # Parameters
234///
235/// - `arr` is the input Array that has to be transposed
236/// - `conjugate` is a boolean that indicates if the transpose operation needs to be a conjugate
237/// transpose
238pub fn transpose_inplace<T: HasAfEnum>(arr: &mut Array<T>, conjugate: bool) {
239    unsafe {
240        let err_val = af_transpose_inplace(arr.get(), conjugate);
241        HANDLE_ERROR(AfError::from(err_val));
242    }
243}
244
245/// Sets the cuBLAS math mode for the internal handle.
246///
247/// See the cuBLAS documentation for additional details
248///
249/// # Parameters
250///
251/// - `mode` takes a value of [CublasMathMode](./enum.CublasMathMode.html) enum
252pub fn set_cublas_mode(mode: CublasMathMode) {
253    unsafe {
254        afcu_cublasSetMathMode(mode as c_int);
255        //let err_val = afcu_cublasSetMathMode(mode as c_int);
256        // FIXME(wonder if this something to throw off,
257        // the program state is not invalid or anything
258        // HANDLE_ERROR(AfError::from(err_val));
259    }
260}