1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
extern crate libc;
use array::Array;
use defines::AfError;
use defines::MatProp;
use error::HANDLE_ERROR;
use self::libc::{c_uint, c_int};
use util::{AfArray, MutAfArray, to_u32};
use util::{FloatingPoint, HasAfEnum};
// Raw ArrayFire BLAS entry points. Each returns an ArrayFire status code as a
// `c_int`, which the safe wrappers below convert with `AfError::from`.
//
// The C headers declare `conjugate` as `const bool` for both transpose
// functions, so it is declared as Rust `bool` (ABI-compatible with C `_Bool`)
// rather than `c_int`, keeping the Rust prototypes identical to the C ones.
#[allow(dead_code)]
extern "C" {
    fn af_matmul(out: MutAfArray, lhs: AfArray, rhs: AfArray,
                 optlhs: c_uint, optrhs: c_uint) -> c_int;

    fn af_dot(out: MutAfArray, lhs: AfArray, rhs: AfArray,
              optlhs: c_uint, optrhs: c_uint) -> c_int;

    fn af_transpose(out: MutAfArray, arr: AfArray, conjugate: bool) -> c_int;

    fn af_transpose_inplace(arr: AfArray, conjugate: bool) -> c_int;
}

/// Matrix-multiplies `lhs` and `rhs` using ArrayFire's `af_matmul`.
///
/// # Parameters
///
/// - `lhs`, `rhs`: the operand arrays; they are only borrowed.
/// - `optlhs`, `optrhs`: [`MatProp`] options for the respective operand,
///   converted with `to_u32` and passed to ArrayFire as `c_uint`.
///
/// # Returns
///
/// A new `Array<T>` built from the output handle written by `af_matmul`.
///
/// # Errors
///
/// No `Result` is returned: the ArrayFire status code is handed to
/// `HANDLE_ERROR`. NOTE(review): presumably `HANDLE_ERROR` panics on a
/// non-success code — confirm.
pub fn matmul<T>(lhs: &Array<T>, rhs: &Array<T>,
                 optlhs: MatProp, optrhs: MatProp) -> Array<T>
    where T: HasAfEnum + FloatingPoint
{
    let mut temp: i64 = 0;
    // SAFETY: `temp` is a live, writable i64 that receives the output handle
    // (assumes `MutAfArray` is a pointer to the i64 handle type — confirm);
    // `lhs` and `rhs` are borrowed for the whole call, so their handles are
    // not released while ArrayFire uses them.
    unsafe {
        let err_val = af_matmul(&mut temp as MutAfArray,
                                lhs.get() as AfArray, rhs.get() as AfArray,
                                to_u32(optlhs) as c_uint, to_u32(optrhs) as c_uint);
        HANDLE_ERROR(AfError::from(err_val));
    }
    temp.into()
}

/// Computes the dot product of `lhs` and `rhs` using ArrayFire's `af_dot`.
///
/// # Parameters
///
/// - `lhs`, `rhs`: the operand arrays; they are only borrowed.
/// - `optlhs`, `optrhs`: [`MatProp`] options for the respective operand,
///   converted with `to_u32` and passed to ArrayFire as `c_uint`.
///
/// # Returns
///
/// A new `Array<T>` built from the output handle written by `af_dot`.
///
/// # Errors
///
/// No `Result` is returned: the ArrayFire status code is handed to
/// `HANDLE_ERROR`. NOTE(review): presumably `HANDLE_ERROR` panics on a
/// non-success code — confirm.
pub fn dot<T>(lhs: &Array<T>, rhs: &Array<T>,
              optlhs: MatProp, optrhs: MatProp) -> Array<T>
    where T: HasAfEnum + FloatingPoint
{
    let mut temp: i64 = 0;
    // SAFETY: `temp` is a live, writable i64 that receives the output handle
    // (assumes `MutAfArray` is a pointer to the i64 handle type — confirm);
    // `lhs` and `rhs` are borrowed for the whole call.
    unsafe {
        let err_val = af_dot(&mut temp as MutAfArray,
                             lhs.get() as AfArray, rhs.get() as AfArray,
                             to_u32(optlhs) as c_uint, to_u32(optrhs) as c_uint);
        HANDLE_ERROR(AfError::from(err_val));
    }
    temp.into()
}

/// Returns the transpose of `arr` as a new array, using ArrayFire's
/// `af_transpose`.
///
/// When `conjugate` is `true`, ArrayFire is asked for the conjugate
/// transpose. `arr` itself is only borrowed and is not modified.
///
/// # Errors
///
/// No `Result` is returned: the ArrayFire status code is handed to
/// `HANDLE_ERROR`. NOTE(review): presumably `HANDLE_ERROR` panics on a
/// non-success code — confirm.
pub fn transpose<T: HasAfEnum>(arr: &Array<T>, conjugate: bool) -> Array<T> {
    let mut temp: i64 = 0;
    // SAFETY: `temp` is a live, writable i64 that receives the output handle
    // (assumes `MutAfArray` is a pointer to the i64 handle type — confirm);
    // `arr` is borrowed for the whole call. `conjugate` is passed as a Rust
    // `bool`, matching the C `const bool` parameter.
    unsafe {
        let err_val = af_transpose(&mut temp as MutAfArray,
                                   arr.get() as AfArray, conjugate);
        HANDLE_ERROR(AfError::from(err_val));
    }
    temp.into()
}

/// Transposes `arr` in place using ArrayFire's `af_transpose_inplace`.
///
/// When `conjugate` is `true`, ArrayFire is asked for the conjugate
/// transpose. The exclusive borrow reflects that the array behind `arr`'s
/// handle is modified.
///
/// # Errors
///
/// No `Result` is returned: the ArrayFire status code is handed to
/// `HANDLE_ERROR`. NOTE(review): presumably `HANDLE_ERROR` panics on a
/// non-success code — confirm.
pub fn transpose_inplace<T: HasAfEnum>(arr: &mut Array<T>, conjugate: bool) {
    // SAFETY: `arr` is exclusively borrowed for the whole call, so no other
    // Rust code can observe the handle while ArrayFire mutates it;
    // `conjugate` is a Rust `bool`, matching the C `const bool` parameter.
    unsafe {
        let err_val = af_transpose_inplace(arr.get() as AfArray, conjugate);
        HANDLE_ERROR(AfError::from(err_val));
    }
}