custos_math/raw_ops/opencl/
gemm.rs1use custos::{opencl::enqueue_kernel, prelude::CLBuffer, CDatatype, Device, Error, OpenCL};
2use std::fmt::Write;
3
4pub fn cl_gemm<'a, T: CDatatype>(
21 device: &'a OpenCL,
22 m: usize,
23 k: usize,
24 n: usize,
25 lhs: &CLBuffer<T>,
26 rhs: &CLBuffer<T>,
27) -> Result<CLBuffer<'a, T>, Error> {
28 let mut mw = 1;
29 for x in &[16, 8, 4, 2, 1] {
30 if m % x == 0 {
31 mw = *x;
32 break;
33 }
34 }
35 let mut kw = 1;
36 for x in &[8, 4, 2, 1] {
37 if n % x == 0 && k % x == 0 {
38 kw = *x;
39 break;
40 }
41 }
42 let nw = kw;
43 let mt = (((m / mw) as f32).floor()) as usize;
44 let kt = (((k / kw) as f32).floor()) as usize;
45
46 let f = (((m / mw) as f32).floor()) as usize;
47 let s = (((n / nw) as f32).floor()) as usize;
48 let mut float_mw = String::new();
51 if mw == 1 {
52 write!(&mut float_mw, "{}", T::as_c_type_str()).unwrap();
53 } else {
54 write!(&mut float_mw, "{}{}", T::as_c_type_str(), mw).unwrap();
55 }
56
57 let mut float_kw = String::new();
58 if kw == 1 {
59 write!(&mut float_kw, "{}", T::as_c_type_str()).unwrap();
60 } else {
61 write!(&mut float_kw, "{}{}", T::as_c_type_str(), kw).unwrap();
62 }
63
64 let dt = T::as_c_type_str();
65
66 let src = format!("
67 #define K {k}
68 #define N {n}
69 #define MW {mw} // M tile Width
70 #define NW {nw} // N tile Width -- NW & KW should be the same !
71 #define KW {kw} // K tile Width
72 #define MT {mt} // MT is max for 'mt' (M tile count)
73 #define KT {kt} // KT is max for 'kt' (K tile count)
74 #define floatMW {float_mw}
75 #define floatKW {float_kw}
76 __kernel void GeMM(const __global floatMW* restrict A, const __global floatKW* restrict B, __global floatMW* C)
77 {{
78 size_t mt = get_global_id(0); //global M-tile id
79 size_t nc = get_global_id(1); //global N-tile id
80
81 {dt} AT[KW][MW]; // sub tiles
82 {dt} BT[NW][KW];
83 {dt} CT[NW][MW];
84
85 #pragma unroll
86 for (uint i=0; i<NW*MW; ++i) // zero CT tile
87 (({dt }*) CT)[i] = 0.0;
88
89 for (uint kt=0; kt<KT; ++kt) // iterate over K-dim tiles
90 {{
91 #pragma unroll
92 for (uint k=0; k<KW; ++k) // every k-element inside K-dim tile
93 *( (floatMW*) AT[k] ) = A[(kt*KW + k)*MT + mt]; // store M-Width floats
94
95 #pragma unroll
96 for (uint n=0; n<NW; ++n) // every n-element inside N-dim tile
97 *( (floatKW*) BT[n] ) = B[(nc*NW + n)*KT + kt]; // store K-Width floats
98
99 #pragma unroll
100 for (uint k=0; k<KW; ++k)
101 #pragma unroll
102 for (uint n=0; n<NW; ++n) // sub tiles multiplication
103 #pragma unroll
104 for (uint m=0; m<MW; ++m)
105 CT[n][m] += AT[k][m] * BT[n][k];
106 }}
107
108 #pragma unroll
109 for (uint n=0; n<NW; ++n)
110 C[(nc*NW + n)*MT + mt] = *( (floatMW*) CT[n]);
111 }}");
112
113 let gws = [f, s, 0];
114
115 let out: CLBuffer<T> = device.retrieve(n * m, (lhs.node.idx, rhs.node.idx));
116 enqueue_kernel(device, &src, gws, None, &[lhs, rhs, &out])?;
117 Ok(out)
118}