burn_jit/
tune_key.rs

1use crate::kernel::{
2    conv::{Conv2dAutotuneKey, ConvTranspose2dAutotuneKey},
3    matmul::MatmulAutotuneKey,
4    reduce::ReduceAutotuneKey,
5};
6use cubecl::tune::AutotuneKey;
7use serde::{Deserialize, Serialize};
8use std::fmt::Display;
9
10#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
11/// Key for all autotune-enabled operations
12pub enum JitAutotuneKey {
13    /// Key for matmul operation
14    Matmul(MatmulAutotuneKey),
15    /// Key for reduce dim operations
16    Reduce(ReduceAutotuneKey),
17    /// Key for convolution operations
18    Conv2d(Conv2dAutotuneKey),
19    /// Key for transpose convolution operations
20    ConvTranspose2d(ConvTranspose2dAutotuneKey),
21}
22
23impl Display for JitAutotuneKey {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        match self {
26            JitAutotuneKey::Matmul(matmul_key) => std::fmt::Display::fmt(&matmul_key, f),
27            JitAutotuneKey::Reduce(reduce_key) => std::fmt::Display::fmt(&reduce_key, f),
28            JitAutotuneKey::Conv2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f),
29            JitAutotuneKey::ConvTranspose2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f),
30        }
31    }
32}
33
34impl AutotuneKey for JitAutotuneKey {}