// autograd/ops/mkl_ffi.rs

/// Integer type used for every size, count and leading-dimension argument of
/// the MKL declarations in this module (MKL's `MKL_INT`).
///
/// `c_longlong` is exactly `i64`, the width of `MKL_INT` in MKL's ILP64
/// interface layer (the LP64 layer uses a 32-bit `MKL_INT`).
/// NOTE(review): confirm the build links the ILP64 interface library; under an
/// LP64 link every integer argument declared here would be mis-sized.
#[cfg(feature = "mkl")]
pub(crate) type MklInt = std::os::raw::c_longlong;
3
/// Element order of a dense array in memory.
///
/// Presumably `C` means row-major (C order) and `F` column-major (Fortran
/// order), following the NumPy/ndarray naming convention. TODO confirm against
/// callers.
#[cfg(feature = "mkl")]
// NOTE(review): the allow suggests the type is not referenced by current
// callers; kept deliberately.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum MemoryOrder {
    C,
    F,
}
10
/// `CBLAS_TRANSPOSE` from `cblas.h`: whether a GEMM operand is used as-is,
/// transposed, or conjugate-transposed.
///
/// `#[repr(C)]` plus the explicit discriminants 111-113 keep the in-memory
/// representation identical to the C enumerators, so values can be handed
/// directly to the `cblas_*` functions.
#[cfg(feature = "mkl")]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
// Variant names mirror the C enumerators verbatim, so the shared `Cblas`
// prefix flagged by `clippy::enum_variant_names` is intentional. Not every
// variant is necessarily used, hence `dead_code`.
#[allow(dead_code, clippy::enum_variant_names)]
pub(crate) enum CblasTranspose {
    CblasNoTrans = 111,
    CblasTrans = 112,
    CblasConjTrans = 113,
}
20
/// Value type of the `layout` argument of the `cblas_*` functions (C's
/// `CBLAS_LAYOUT`).
///
/// NOTE(review): in `cblas.h` `CBLAS_LAYOUT` is a C enum, i.e. `int`-sized,
/// while `usize` is 64 bits on 64-bit targets. On the common 64-bit ABIs this
/// is benign because `layout` is the first argument and travels in a register,
/// but `libc::c_int` would match the C declaration exactly. Switching needs a
/// check of every place that uses `CBLAS_ROW_MAJOR` as a plain `usize`.
#[cfg(feature = "mkl")]
pub(crate) type CblasLayout = usize;

/// `CblasRowMajor` from `cblas.h`: operands are stored row by row.
///
/// Typed through the `CblasLayout` alias (the same type as before) so it stays
/// in sync with the `layout` parameter of the `cblas_*` declarations.
#[cfg(feature = "mkl")]
pub(crate) const CBLAS_ROW_MAJOR: CblasLayout = 101;
26
// Raw FFI declarations for the subset of Intel MKL used by the autograd ops:
//   * CBLAS general matrix multiply (`cblas_?gemm`) and its grouped batch
//     variant (`cblas_?gemm_batch`);
//   * Vector Mathematics (VM) element-wise kernels, `vs*` on `f32` buffers and
//     `vd*` on `f64` buffers.
//
// Every function is `unsafe` to call: callers must pass pointers that are
// valid for as many elements as MKL reads or writes given the size arguments.
#[cfg(feature = "mkl")]
extern "C" {
    /// Single-precision GEMM: `C = alpha * op(A) * op(B) + beta * C`, where
    /// `op(A)` is `m x k`, `op(B)` is `k x n` and `C` is `m x n`.
    pub(crate) fn cblas_sgemm(
        layout: CblasLayout,
        transa: CblasTranspose,
        transb: CblasTranspose,
        m: MklInt,
        n: MklInt,
        k: MklInt,
        alpha: libc::c_float,
        a: *const libc::c_float,
        lda: MklInt,
        b: *const libc::c_float,
        ldb: MklInt,
        beta: libc::c_float,
        c: *mut libc::c_float,
        ldc: MklInt,
    );

    /// Double-precision counterpart of `cblas_sgemm`.
    pub(crate) fn cblas_dgemm(
        layout: CblasLayout,
        transa: CblasTranspose,
        transb: CblasTranspose,
        m: MklInt,
        n: MklInt,
        k: MklInt,
        alpha: libc::c_double,
        a: *const libc::c_double,
        lda: MklInt,
        b: *const libc::c_double,
        ldb: MklInt,
        beta: libc::c_double,
        c: *mut libc::c_double,
        ldc: MklInt,
    );

    /// Grouped batch of single-precision GEMMs.
    ///
    /// Per the Intel MKL `cblas_?gemm_batch` reference, the GEMMs are split
    /// into `group_count` groups and `group_size[g]` is the number of GEMMs in
    /// group `g`. The per-group parameter arrays (transpose flags, `m`/`n`/`k`,
    /// `alpha`/`beta`, leading dimensions) hold `group_count` entries. The
    /// pointer arrays `a_array`, `b_array` and `c_array` hold one pointer per
    /// GEMM, i.e. the sum of all `group_size` entries.
    pub(crate) fn cblas_sgemm_batch(
        layout: CblasLayout,
        transa_array: *const CblasTranspose,  // per-group op(A) flag
        transb_array: *const CblasTranspose,  // per-group op(B) flag
        m_array: *const MklInt,               // per-group m
        n_array: *const MklInt,               // per-group n
        k_array: *const MklInt,               // per-group k
        alpha_array: *const libc::c_float,    // per-group alpha
        a_array: *const *const libc::c_float, // one A pointer per GEMM
        lda_array: *const MklInt,             // per-group lda
        b_array: *const *const libc::c_float, // one B pointer per GEMM
        ldb_array: *const MklInt,             // per-group ldb
        beta_array: *const libc::c_float,     // per-group beta
        c_array: *mut *mut libc::c_float,     // one C pointer per GEMM
        ldc_array: *const MklInt,             // per-group ldc
        group_count: MklInt,                  // number of groups
        group_size: *const MklInt,            // number of GEMMs in each group
    );

    /// Double-precision counterpart of `cblas_sgemm_batch`; the array layout
    /// is identical.
    pub(crate) fn cblas_dgemm_batch(
        layout: CblasLayout,
        transa_array: *const CblasTranspose,   // per-group op(A) flag
        transb_array: *const CblasTranspose,   // per-group op(B) flag
        m_array: *const MklInt,                // per-group m
        n_array: *const MklInt,                // per-group n
        k_array: *const MklInt,                // per-group k
        alpha_array: *const libc::c_double,    // per-group alpha
        a_array: *const *const libc::c_double, // one A pointer per GEMM
        lda_array: *const MklInt,              // per-group lda
        b_array: *const *const libc::c_double, // one B pointer per GEMM
        ldb_array: *const MklInt,              // per-group ldb
        beta_array: *const libc::c_double,     // per-group beta
        c_array: *mut *mut libc::c_double,     // one C pointer per GEMM
        ldc_array: *const MklInt,              // per-group ldc
        group_count: MklInt,                   // number of groups
        group_size: *const MklInt,             // number of GEMMs in each group
    );

    // Vector Mathematics (VM) kernels: each reads `n` elements from its input
    // buffer(s) and writes `n` results to `y`.

    // y = sin(a)
    pub(crate) fn vsSin(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdSin(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = asin(a)
    pub(crate) fn vsAsin(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdAsin(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = sinh(a)
    pub(crate) fn vsSinh(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdSinh(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = asinh(a)
    pub(crate) fn vsAsinh(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdAsinh(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = cos(a)
    pub(crate) fn vsCos(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdCos(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = acos(a)
    pub(crate) fn vsAcos(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdAcos(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = cosh(a)
    pub(crate) fn vsCosh(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdCosh(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = acosh(a)
    pub(crate) fn vsAcosh(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdAcosh(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = tan(a)
    pub(crate) fn vsTan(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdTan(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = atan(a)
    pub(crate) fn vsAtan(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdAtan(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = tanh(a)
    pub(crate) fn vsTanh(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdTanh(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = atanh(a)
    pub(crate) fn vsAtanh(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdAtanh(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = e^a
    pub(crate) fn vsExp(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdExp(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = 2^a
    pub(crate) fn vsExp2(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdExp2(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = 10^a
    pub(crate) fn vsExp10(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdExp10(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = ln(a) (natural logarithm)
    pub(crate) fn vsLn(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdLn(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = log2(a)
    pub(crate) fn vsLog2(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdLog2(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = log10(a)
    pub(crate) fn vsLog10(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdLog10(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = 1 / a
    pub(crate) fn vsInv(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdInv(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = a / b (element-wise)
    pub(crate) fn vsDiv(
        n: MklInt,
        a: *const libc::c_float,
        b: *const libc::c_float,
        y: *mut libc::c_float,
    );
    pub(crate) fn vdDiv(
        n: MklInt,
        a: *const libc::c_double,
        b: *const libc::c_double,
        y: *mut libc::c_double,
    );

    // y = sqrt(a)
    pub(crate) fn vsSqrt(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdSqrt(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = a^b with a single scalar exponent `b` shared by all elements
    pub(crate) fn vsPowx(
        n: MklInt,
        a: *const libc::c_float,
        b: libc::c_float,
        y: *mut libc::c_float,
    );
    pub(crate) fn vdPowx(
        n: MklInt,
        a: *const libc::c_double,
        b: libc::c_double,
        y: *mut libc::c_double,
    );

    // y = 1 / sqrt(a)
    pub(crate) fn vsInvSqrt(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdInvSqrt(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = a + b (element-wise)
    pub(crate) fn vsAdd(
        n: MklInt,
        a: *const libc::c_float,
        b: *const libc::c_float,
        y: *mut libc::c_float,
    );
    pub(crate) fn vdAdd(
        n: MklInt,
        a: *const libc::c_double,
        b: *const libc::c_double,
        y: *mut libc::c_double,
    );

    // y = a - b (element-wise)
    pub(crate) fn vsSub(
        n: MklInt,
        a: *const libc::c_float,
        b: *const libc::c_float,
        y: *mut libc::c_float,
    );
    pub(crate) fn vdSub(
        n: MklInt,
        a: *const libc::c_double,
        b: *const libc::c_double,
        y: *mut libc::c_double,
    );

    // y = a * a
    pub(crate) fn vsSqr(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdSqr(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = a * b (element-wise)
    pub(crate) fn vsMul(
        n: MklInt,
        a: *const libc::c_float,
        b: *const libc::c_float,
        y: *mut libc::c_float,
    );
    pub(crate) fn vdMul(
        n: MklInt,
        a: *const libc::c_double,
        b: *const libc::c_double,
        y: *mut libc::c_double,
    );

    // y = |a|
    pub(crate) fn vsAbs(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdAbs(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = floor(a)
    pub(crate) fn vsFloor(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdFloor(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);

    // y = ceil(a)
    pub(crate) fn vsCeil(n: MklInt, a: *const libc::c_float, y: *mut libc::c_float);
    pub(crate) fn vdCeil(n: MklInt, a: *const libc::c_double, y: *mut libc::c_double);
}