// burn_jit/kernel/matmul/tune/key.rs

1use crate::{tensor::JitTensor, FloatElement, JitAutotuneKey, JitRuntime};
2use burn_tensor::{DType, Shape};
3use core::fmt::Debug;
4use cubecl::AutotuneKey;
5use serde::{Deserialize, Serialize};
6use std::{cmp::max, hash::Hash};
7
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
/// Autotune key representative of matmul versions
pub struct MatmulAutotuneKey {
    round: bool,     // True when all matmul dims are multiples of 64
    broadcast: bool, // True when there are differences in batch size
    // The `anchor` attribute makes the derived `new` constructor bucket the raw
    // value (the tests below show 511 -> 512, 513 -> 1024, i.e. it appears to
    // round up to the next power of two — confirm against the AutotuneKey
    // derive in cubecl) so that similar problem sizes share one cache entry.
    #[autotune(anchor)]
    m: usize,
    #[autotune(anchor)]
    k: usize,
    #[autotune(anchor)]
    n: usize,
    // Same anchoring, additionally clamped to 256 (see the large-batch test).
    #[autotune(anchor(max = 256))]
    batch: usize,
    dtype: DType,
}
23
24impl MatmulAutotuneKey {
25    fn from_shape(lhs_shape: &Shape, rhs_shape: &Shape, dtype: DType) -> Self {
26        let ndims = lhs_shape.num_dims();
27        let m = lhs_shape.dims[ndims - 2];
28        let k = lhs_shape.dims[ndims - 1];
29        let n = rhs_shape.dims[ndims - 1];
30
31        let mut broadcast = false;
32        let mut batch_product_lhs = 1;
33        let mut batch_product_rhs = 1;
34
35        for b in 0..ndims - 2 {
36            batch_product_lhs *= lhs_shape.dims[b];
37            batch_product_rhs *= rhs_shape.dims[b];
38            if lhs_shape.dims[b] != rhs_shape.dims[b] {
39                broadcast = true;
40            }
41        }
42        let batch_product = max(batch_product_lhs, batch_product_rhs);
43
44        let round = m % 64 == 0 && k % 64 == 0 && n % 64 == 0;
45
46        Self::new(round, broadcast, m, k, n, batch_product, dtype)
47    }
48}
49
50pub(crate) fn create_key<R: JitRuntime, E: FloatElement>(
51    lhs: &JitTensor<R>,
52    rhs: &JitTensor<R>,
53    _out: &JitTensor<R>,
54) -> JitAutotuneKey {
55    JitAutotuneKey::Matmul(MatmulAutotuneKey::from_shape(
56        &lhs.shape,
57        &rhs.shape,
58        E::dtype(),
59    ))
60}
61
#[cfg(test)]
mod tests {
    use super::*;

    // All dims are multiples of 64 and already powers of two, so the anchored
    // fields keep their raw values and `round` is set.
    #[test]
    fn matmul_autotune_key_all_same_and_round() {
        let lhs_shape: Shape = [4, 512, 512].into();
        let rhs_shape: Shape = [4, 512, 512].into();
        let key = MatmulAutotuneKey::from_shape(&lhs_shape, &rhs_shape, DType::F32);

        assert!(key.round);
        assert!(!key.broadcast);
        assert_eq!(key.m, 512);
        assert_eq!(key.k, 512);
        assert_eq!(key.n, 512);
    }

    // Mismatched batch dims ([2,3] vs [3,2]) => broadcast. The anchored fields
    // show the rounding performed by the derived constructor:
    //   m: 511 -> 512, n: 513 -> 1024, batch: max(2*3, 3*2) = 6 -> 8.
    // (Rounding appears to be "up to the next power of two" — behavior comes
    // from the AutotuneKey derive, not this file.)
    #[test]
    fn matmul_autotune_key_all_different() {
        let lhs_shape: Shape = [2, 3, 511, 512].into();
        let rhs_shape: Shape = [3, 2, 512, 513].into();
        let key = MatmulAutotuneKey::from_shape(&lhs_shape, &rhs_shape, DType::F32);

        assert!(!key.round);
        assert!(key.broadcast);
        assert_eq!(key.m, 512);
        assert_eq!(key.k, 512);
        assert_eq!(key.n, 1024);
        assert_eq!(key.batch, 8);
    }

    // Batch products are huge (128*512 vs 200*400); the `anchor(max = 256)`
    // attribute clamps the anchored batch to 256.
    #[test]
    fn matmul_autotune_key_large_batch() {
        let lhs_shape: Shape = [128, 512, 511, 512].into();
        let rhs_shape: Shape = [200, 400, 512, 513].into();
        let key = MatmulAutotuneKey::from_shape(&lhs_shape, &rhs_shape, DType::F32);

        assert_eq!(key.batch, 256);
    }
}
101}