candle_gemm/
lib.rs

1#![cfg_attr(feature = "nightly", feature(stdsimd), feature(avx512_target_feature))]
2#![cfg_attr(not(feature = "std"), no_std)]
3#![warn(rust_2018_idioms)]
4
5mod gemm;
6
7pub use crate::gemm::*;
8pub use gemm_common::Parallelism;
9
10pub use gemm_f16::f16;
11
/// Cross-checks the optimised `gemm::gemm` entry point against the reference
/// `gemm::gemm_fallback` / `gemm::gemm_cplx_fallback` implementations over a
/// spread of shapes (including zero-sized and ragged ones), scalar factors and
/// parallelism settings.
#[cfg(test)]
mod tests {
    use super::*;
    extern crate alloc;
    use alloc::vec::Vec;
    use core::fmt::Debug;
    use num_traits::Float;

    /// Returns `len` samples of `rand::random::<f64>()`.
    fn random_f64s(len: usize) -> Vec<f64> {
        (0..len).map(|_| rand::random()).collect()
    }

    /// Returns `len` half-precision values converted from `rand::random::<f32>()`.
    fn random_f16s(len: usize) -> Vec<f16> {
        (0..len).map(|_| f16::from_f32(rand::random())).collect()
    }

    /// Asserts that `actual` and `expected` have equal length and that every
    /// element pair satisfies `|a - e| < tolerance(a)`.
    ///
    /// The panic message names the first offending index together with
    /// `context`, so a failing combination is identifiable without emitting
    /// per-iteration debug output for the passing ones.
    fn assert_all_close<T: Float + Debug>(
        actual: &[T],
        expected: &[T],
        tolerance: impl Fn(T) -> T,
        context: &dyn Debug,
    ) {
        // `zip` silently truncates, so check lengths explicitly first.
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch for {:?}",
            context
        );
        for (idx, (&a, &e)) in actual.iter().zip(expected).enumerate() {
            let tol = tolerance(a);
            assert!(
                (a - e).abs() < tol,
                "element {} differs: got {:?}, expected {:?} (tolerance {:?}) for {:?}",
                idx,
                a,
                e,
                tol,
                context,
            );
        }
    }

    /// Half-precision `gemm` must agree with `gemm_fallback` for every shape,
    /// scalar pair and parallelism setting.
    #[test]
    fn test_gemm_f16() {
        let shapes: &[(usize, usize, usize)] = &[
            (16, 2, 1),
            (0, 0, 4),
            (16, 1, 1),
            (16, 3, 1),
            (16, 4, 1),
            (16, 1, 2),
            (16, 2, 2),
            (16, 3, 2),
            (16, 4, 2),
            (16, 16, 1),
            (64, 64, 0),
            (256, 256, 256),
            (4096, 4096, 4),
            (64, 64, 4),
            (0, 64, 4),
            (64, 0, 4),
            (8, 16, 1),
            (16, 8, 1),
            (1, 1, 2),
            (4, 4, 4),
            (1024, 1024, 1),
            (1024, 1024, 4),
            (63, 1, 10),
            (63, 2, 10),
            (63, 3, 10),
            (63, 4, 10),
            (1, 63, 10),
            (2, 63, 10),
            (3, 63, 10),
            (4, 63, 10),
        ];
        let parallelisms = [Parallelism::None, Parallelism::Rayon(0)];

        // Half precision is coarse: accept a 10% error relative to the optimised
        // result, floored at 0.1 in absolute terms.
        let eps = f16::from_f32(1e-1);
        let tolerance = |c: f16| {
            let eps_rel = c.abs() * eps;
            if eps_rel > eps {
                eps_rel
            } else {
                eps
            }
        };

        for &(m, n, k) in shapes {
            for alpha in [0.0, 1.0, 2.3] {
                for beta in [0.0, 1.0, 2.3] {
                    let alpha = f16::from_f32(alpha);
                    let beta = f16::from_f32(beta);
                    let lhs = random_f16s(m * k);
                    let rhs = random_f16s(k * n);
                    let dst = random_f16s(m * n);

                    // The reference depends only on the inputs, so compute it once
                    // and compare every parallelism setting against it.
                    let mut expected = dst.clone();
                    // SAFETY: `lhs`, `rhs` and `expected` hold exactly m*k, k*n and
                    // m*n elements; with column stride = row count and row stride 1
                    // every access stays inside those buffers.
                    unsafe {
                        gemm::gemm_fallback(
                            m,
                            n,
                            k,
                            expected.as_mut_ptr(),
                            m as isize,
                            1,
                            true,
                            lhs.as_ptr(),
                            m as isize,
                            1,
                            rhs.as_ptr(),
                            k as isize,
                            1,
                            alpha,
                            beta,
                        );
                    }

                    for parallelism in parallelisms {
                        let mut actual = dst.clone();
                        // SAFETY: same buffer sizes and strides as the reference call.
                        unsafe {
                            gemm::gemm(
                                m,
                                n,
                                k,
                                actual.as_mut_ptr(),
                                m as isize,
                                1,
                                true,
                                lhs.as_ptr(),
                                m as isize,
                                1,
                                rhs.as_ptr(),
                                k as isize,
                                1,
                                alpha,
                                beta,
                                false,
                                false,
                                false,
                                parallelism,
                            );
                        }
                        assert_all_close(
                            &actual,
                            &expected,
                            &tolerance,
                            &(m, n, k, alpha, beta, parallelism),
                        );
                    }
                }
            }
        }
    }

    /// Double-precision `gemm` must agree with `gemm_fallback` for every shape,
    /// scalar pair and parallelism setting (including an oversubscribed Rayon
    /// thread count).
    #[test]
    fn test_gemm_real() {
        let shapes: &[(usize, usize, usize)] = &[
            // large m to trigger parallelized rhs packing with big number of threads and small n
            (2048, 255, 255),
            (256, 256, 256),
            (4096, 4096, 4),
            (64, 64, 4),
            (0, 64, 4),
            (64, 0, 4),
            (0, 0, 4),
            (64, 64, 0),
            (16, 1, 1),
            (16, 2, 1),
            (16, 3, 1),
            (16, 4, 1),
            (16, 1, 2),
            (16, 2, 2),
            (16, 3, 2),
            (16, 4, 2),
            (16, 16, 1),
            (8, 16, 1),
            (16, 8, 1),
            (1, 1, 2),
            (4, 4, 4),
            (1024, 1024, 1),
            (1024, 1024, 4),
            (63, 1, 10),
            (63, 2, 10),
            (63, 3, 10),
            (63, 4, 10),
            (1, 63, 10),
            (2, 63, 10),
            (3, 63, 10),
            (4, 63, 10),
        ];
        let parallelisms = [
            Parallelism::None,
            Parallelism::Rayon(0),
            Parallelism::Rayon(128),
        ];
        // Absolute tolerance of 1e-6 per element.
        let tolerance = |_: f64| 1.0e-6;

        for &(m, n, k) in shapes {
            for alpha in [0.0, 1.0, 2.3] {
                for beta in [0.0, 1.0, 2.3] {
                    let lhs = random_f64s(m * k);
                    let rhs = random_f64s(k * n);
                    let dst = random_f64s(m * n);

                    // The reference depends only on the inputs, so compute it once
                    // and compare every parallelism setting against it.
                    let mut expected = dst.clone();
                    // SAFETY: `lhs`, `rhs` and `expected` hold exactly m*k, k*n and
                    // m*n elements; with column stride = row count and row stride 1
                    // every access stays inside those buffers.
                    unsafe {
                        gemm::gemm_fallback(
                            m,
                            n,
                            k,
                            expected.as_mut_ptr(),
                            m as isize,
                            1,
                            true,
                            lhs.as_ptr(),
                            m as isize,
                            1,
                            rhs.as_ptr(),
                            k as isize,
                            1,
                            alpha,
                            beta,
                        );
                    }

                    for parallelism in parallelisms {
                        let mut actual = dst.clone();
                        // SAFETY: same buffer sizes and strides as the reference call.
                        unsafe {
                            gemm::gemm(
                                m,
                                n,
                                k,
                                actual.as_mut_ptr(),
                                m as isize,
                                1,
                                true,
                                lhs.as_ptr(),
                                m as isize,
                                1,
                                rhs.as_ptr(),
                                k as isize,
                                1,
                                alpha,
                                beta,
                                false,
                                false,
                                false,
                                parallelism,
                            );
                        }
                        assert_all_close(
                            &actual,
                            &expected,
                            &tolerance,
                            &(m, n, k, alpha, beta, parallelism),
                        );
                    }
                }
            }
        }
    }

    /// Complex `gemm` must agree with `gemm_cplx_fallback` for every shape,
    /// scalar pair, conjugation-flag combination, and both single-threaded and
    /// Rayon execution.
    #[test]
    fn test_gemm_cplx() {
        let shapes: &[(usize, usize, usize)] = &[
            (0, 64, 4),
            (64, 0, 4),
            (0, 0, 4),
            (64, 64, 4),
            (64, 64, 0),
            (6, 3, 1),
            (1, 1, 2),
            (128, 128, 128),
            (16, 1, 1),
            (16, 2, 1),
            (16, 3, 1),
            (16, 4, 1),
            (16, 1, 2),
            (16, 2, 2),
            (16, 3, 2),
            (16, 4, 2),
            (16, 16, 1),
            (8, 16, 1),
            (16, 8, 1),
            (4, 4, 4),
            (1024, 1024, 4),
            (1024, 1024, 1),
            (63, 1, 10),
            (63, 2, 10),
            (63, 3, 10),
            (63, 4, 10),
            (1, 63, 10),
            (2, 63, 10),
            (3, 63, 10),
            (4, 63, 10),
        ];
        let parallelisms = [Parallelism::None, Parallelism::Rayon(0)];
        // Absolute tolerance of 1e-6 per real/imaginary component.
        let tolerance = |_: f64| 1.0e-6;

        let zero = c64::new(0.0, 0.0);
        let one = c64::new(1.0, 0.0);
        let arbitrary = c64::new(2.3, 4.1);

        for &(m, n, k) in shapes {
            for alpha in [zero, one, arbitrary] {
                for beta in [zero, one, arbitrary] {
                    for conj_dst in [false, true] {
                        for conj_lhs in [false, true] {
                            for conj_rhs in [false, true] {
                                // Complex matrices are stored as interleaved f64 pairs
                                // and reinterpreted as `c64` for the gemm calls.
                                let lhs = random_f64s(2 * m * k);
                                let rhs = random_f64s(2 * k * n);
                                let dst = random_f64s(2 * m * n);

                                // The reference depends only on the inputs, so compute
                                // it once and compare every parallelism setting to it.
                                let mut expected = dst.clone();
                                // SAFETY: the buffers hold 2*m*k, 2*k*n and 2*m*n f64s,
                                // i.e. m*k, k*n and m*n `c64` values, assuming `c64` is
                                // two f64s with no padding (as this test relies on).
                                // Column stride = row count and row stride 1 keep every
                                // access in bounds.
                                unsafe {
                                    gemm::gemm_cplx_fallback(
                                        m,
                                        n,
                                        k,
                                        expected.as_mut_ptr() as *mut c64,
                                        m as isize,
                                        1,
                                        true,
                                        lhs.as_ptr() as *const c64,
                                        m as isize,
                                        1,
                                        rhs.as_ptr() as *const c64,
                                        k as isize,
                                        1,
                                        alpha,
                                        beta,
                                        conj_dst,
                                        conj_lhs,
                                        conj_rhs,
                                    );
                                }

                                for parallelism in parallelisms {
                                    let mut actual = dst.clone();
                                    // SAFETY: same buffer sizes, layout assumption and
                                    // strides as the reference call.
                                    unsafe {
                                        gemm::gemm(
                                            m,
                                            n,
                                            k,
                                            actual.as_mut_ptr() as *mut c64,
                                            m as isize,
                                            1,
                                            true,
                                            lhs.as_ptr() as *const c64,
                                            m as isize,
                                            1,
                                            rhs.as_ptr() as *const c64,
                                            k as isize,
                                            1,
                                            alpha,
                                            beta,
                                            conj_dst,
                                            conj_lhs,
                                            conj_rhs,
                                            parallelism,
                                        );
                                    }
                                    assert_all_close(
                                        &actual,
                                        &expected,
                                        &tolerance,
                                        &(
                                            m,
                                            n,
                                            k,
                                            alpha,
                                            beta,
                                            conj_dst,
                                            conj_lhs,
                                            conj_rhs,
                                            parallelism,
                                        ),
                                    );
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}