custos_math/raw_ops/opencl/
gemm.rs

1use custos::{opencl::enqueue_kernel, prelude::CLBuffer, CDatatype, Device, Error, OpenCL};
2use std::fmt::Write;
3
4/// OpenCL matrix multiplication of two buffers / matrices.
5/// # Example
6/// ```
7/// use custos::{OpenCL, Buffer, Read};
8/// use custos_math::cl_gemm;
9///
10/// fn main() -> Result<(), custos::Error> {
11///     let device = OpenCL::new(0)?;
12///     let lhs = Buffer::from((&device, [15i16, 30, 21, 5, 8, 5]));
13///     let rhs = Buffer::from((&device, [3i16, 2, 7, 1, 9, 20]));
14///     
15///     let out = cl_gemm(&device, 2, 3, 2, &rhs, &lhs)?;
16///     assert_eq!(device.read(&out), vec![444, 480, 116, 118]);
17///     Ok(())
18/// }
19/// ```
20pub fn cl_gemm<'a, T: CDatatype>(
21    device: &'a OpenCL,
22    m: usize,
23    k: usize,
24    n: usize,
25    lhs: &CLBuffer<T>,
26    rhs: &CLBuffer<T>,
27) -> Result<CLBuffer<'a, T>, Error> {
28    let mut mw = 1;
29    for x in &[16, 8, 4, 2, 1] {
30        if m % x == 0 {
31            mw = *x;
32            break;
33        }
34    }
35    let mut kw = 1;
36    for x in &[8, 4, 2, 1] {
37        if n % x == 0 && k % x == 0 {
38            kw = *x;
39            break;
40        }
41    }
42    let nw = kw;
43    let mt = (((m / mw) as f32).floor()) as usize;
44    let kt = (((k / kw) as f32).floor()) as usize;
45
46    let f = (((m / mw) as f32).floor()) as usize;
47    let s = (((n / nw) as f32).floor()) as usize;
48    //'testing'/excellent code for gemm - 'currently' stolen from litenn
49
50    let mut float_mw = String::new();
51    if mw == 1 {
52        write!(&mut float_mw, "{}", T::as_c_type_str()).unwrap();
53    } else {
54        write!(&mut float_mw, "{}{}", T::as_c_type_str(), mw).unwrap();
55    }
56
57    let mut float_kw = String::new();
58    if kw == 1 {
59        write!(&mut float_kw, "{}", T::as_c_type_str()).unwrap();
60    } else {
61        write!(&mut float_kw, "{}{}", T::as_c_type_str(), kw).unwrap();
62    }
63
64    let dt = T::as_c_type_str();
65
66    let src = format!("
67        #define K {k}
68        #define N {n}
69        #define MW {mw}     // M tile Width
70        #define NW {nw}     // N tile Width  -- NW & KW should be the same !
71        #define KW {kw}     // K tile Width
72        #define MT {mt}  // MT is max for 'mt' (M tile count)
73        #define KT {kt}  // KT is max for 'kt' (K tile count)
74        #define floatMW {float_mw}
75        #define floatKW {float_kw}
76        __kernel void GeMM(const __global floatMW* restrict A, const __global floatKW* restrict B, __global floatMW* C)
77            {{
78                size_t mt = get_global_id(0);    //global M-tile id
79                size_t nc = get_global_id(1);    //global N-tile id
80
81                {dt} AT[KW][MW]; // sub tiles
82                {dt} BT[NW][KW];
83                {dt} CT[NW][MW];
84
85                #pragma unroll
86                for (uint i=0; i<NW*MW; ++i) // zero CT tile
87                    (({dt }*) CT)[i] = 0.0;
88
89                for (uint kt=0; kt<KT; ++kt)  // iterate over K-dim tiles
90                {{
91                    #pragma unroll
92                    for (uint k=0; k<KW; ++k)  // every k-element inside K-dim tile
93                        *( (floatMW*) AT[k] ) = A[(kt*KW + k)*MT + mt]; // store M-Width floats
94
95                    #pragma unroll
96                    for (uint n=0; n<NW; ++n)  // every n-element inside N-dim tile
97                        *( (floatKW*) BT[n] ) = B[(nc*NW + n)*KT + kt]; // store K-Width floats
98
99                    #pragma unroll
100                    for (uint k=0; k<KW; ++k)
101                    #pragma unroll
102                    for (uint n=0; n<NW; ++n)  // sub tiles multiplication
103                    #pragma unroll
104                    for (uint m=0; m<MW; ++m)
105                        CT[n][m] += AT[k][m] * BT[n][k];
106                }}
107
108                #pragma unroll
109                for (uint n=0; n<NW; ++n)
110                    C[(nc*NW + n)*MT + mt] = *( (floatMW*) CT[n]);
111            }}");
112
113    let gws = [f, s, 0];
114
115    let out: CLBuffer<T> = device.retrieve(n * m, (lhs.node.idx, rhs.node.idx));
116    enqueue_kernel(device, &src, gws, None, &[lhs, rhs, &out])?;
117    Ok(out)
118}