burn_jit/kernel/matmul/tune/
key.rs1use crate::{tensor::JitTensor, FloatElement, JitAutotuneKey, JitRuntime};
2use burn_tensor::{DType, Shape};
3use core::fmt::Debug;
4use cubecl::AutotuneKey;
5use serde::{Deserialize, Serialize};
6use std::{cmp::max, hash::Hash};
7
/// Autotune key describing a matmul problem, used to look up tuned kernel choices.
///
/// Fields tagged `#[autotune(anchor)]` are stored in bucketed form rather than verbatim.
/// For example, the tests in this file expect an `m` of 511 to be keyed as 512 and an
/// `n` of 513 to be keyed as 1024.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
pub struct MatmulAutotuneKey {
    // True when m, k and n are all multiples of 64.
    round: bool,
    // True when at least one batch dimension differs between lhs and rhs.
    broadcast: bool,
    // Second-to-last dimension of lhs.
    #[autotune(anchor)]
    m: usize,
    // Last dimension of lhs.
    #[autotune(anchor)]
    k: usize,
    // Last dimension of rhs.
    #[autotune(anchor)]
    n: usize,
    // Larger of the two operands' batch-dimension products; anchored with an upper cap
    // of 256 (the large-batch test expects 256 for much larger products).
    #[autotune(anchor(max = 256))]
    batch: usize,
    // Data type supplied by the caller (`create_key` passes the float element type).
    dtype: DType,
}

24impl MatmulAutotuneKey {
25 fn from_shape(lhs_shape: &Shape, rhs_shape: &Shape, dtype: DType) -> Self {
26 let ndims = lhs_shape.num_dims();
27 let m = lhs_shape.dims[ndims - 2];
28 let k = lhs_shape.dims[ndims - 1];
29 let n = rhs_shape.dims[ndims - 1];
30
31 let mut broadcast = false;
32 let mut batch_product_lhs = 1;
33 let mut batch_product_rhs = 1;
34
35 for b in 0..ndims - 2 {
36 batch_product_lhs *= lhs_shape.dims[b];
37 batch_product_rhs *= rhs_shape.dims[b];
38 if lhs_shape.dims[b] != rhs_shape.dims[b] {
39 broadcast = true;
40 }
41 }
42 let batch_product = max(batch_product_lhs, batch_product_rhs);
43
44 let round = m % 64 == 0 && k % 64 == 0 && n % 64 == 0;
45
46 Self::new(round, broadcast, m, k, n, batch_product, dtype)
47 }
48}
49
50pub(crate) fn create_key<R: JitRuntime, E: FloatElement>(
51 lhs: &JitTensor<R>,
52 rhs: &JitTensor<R>,
53 _out: &JitTensor<R>,
54) -> JitAutotuneKey {
55 JitAutotuneKey::Matmul(MatmulAutotuneKey::from_shape(
56 &lhs.shape,
57 &rhs.shape,
58 E::dtype(),
59 ))
60}
61
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `F32` matmul key from a pair of operand shapes.
    fn f32_key(lhs: Shape, rhs: Shape) -> MatmulAutotuneKey {
        MatmulAutotuneKey::from_shape(&lhs, &rhs, DType::F32)
    }

    #[test]
    fn matmul_autotune_key_all_same_and_round() {
        // Identical batch dims, and every matrix dim is a multiple of 64.
        let key = f32_key([4, 512, 512].into(), [4, 512, 512].into());

        assert!(key.round);
        assert!(!key.broadcast);
        assert_eq!((key.m, key.k, key.n), (512, 512, 512));
    }

    #[test]
    fn matmul_autotune_key_all_different() {
        // Batch dims differ and the matrix dims are not multiples of 64. The anchored
        // fields are bucketed upward (511 -> 512, 513 -> 1024, batch 6 -> 8).
        let key = f32_key([2, 3, 511, 512].into(), [3, 2, 512, 513].into());

        assert!(!key.round);
        assert!(key.broadcast);
        assert_eq!((key.m, key.k, key.n), (512, 512, 1024));
        assert_eq!(key.batch, 8);
    }

    #[test]
    fn matmul_autotune_key_large_batch() {
        // The batch product is far above 256 and is capped by `anchor(max = 256)`.
        let key = f32_key([128, 512, 511, 512].into(), [200, 400, 512, 513].into());

        assert_eq!(key.batch, 256);
    }
}