burn_cubecl/
tune_key.rs

1use crate::kernel::{
2    conv::{Conv2dAutotuneKey, ConvTranspose2dAutotuneKey},
3    reduce::SumAutotuneKey,
4};
5use cubecl::tune::AutotuneKey;
6use serde::{Deserialize, Serialize};
7use std::fmt::Display;
8
9#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
10/// Key for all autotune-enabled operations
11pub enum CubeAutotuneKey {
12    /// Key for sum operations
13    Sum(SumAutotuneKey),
14    /// Key for convolution operations
15    Conv2d(Conv2dAutotuneKey),
16    /// Key for transpose convolution operations
17    ConvTranspose2d(ConvTranspose2dAutotuneKey),
18}
19
20impl Display for CubeAutotuneKey {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        match self {
23            CubeAutotuneKey::Sum(reduce_key) => std::fmt::Debug::fmt(&reduce_key, f),
24            CubeAutotuneKey::Conv2d(conv2d_key) => std::fmt::Debug::fmt(&conv2d_key, f),
25            CubeAutotuneKey::ConvTranspose2d(conv2d_key) => std::fmt::Debug::fmt(&conv2d_key, f),
26        }
27    }
28}
29
30impl AutotuneKey for CubeAutotuneKey {}