use crate::kernel::{matmul::MatmulAutotuneKey, reduce::ReduceAutotuneKey};
use burn_compute::tune::AutotuneKey;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
#[cfg(any(feature = "fusion", test))]
use crate::fusion::FusionElemWiseAutotuneKey;
/// Autotune key with one variant per tunable operation; each variant wraps
/// that operation's own key type.
///
/// Derives `Hash`/`Eq` and the serde traits so keys can be hashed, compared and
/// serialized (presumably for the autotune cache — confirm against
/// `burn_compute::tune`).
///
/// Note: reordering variants changes their discriminants. Those discriminants
/// feed the derived `Hash` and index-based serde formats. Renaming a variant
/// changes name-based serde formats. Either change may invalidate any
/// previously persisted keys.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum WgpuAutotuneKey {
    /// Wraps a [`MatmulAutotuneKey`] (by name, matrix-multiplication kernels).
    Matmul(MatmulAutotuneKey),
    /// Wraps a [`ReduceAutotuneKey`] for the `SumDim` operation
    /// (by name, a sum over one dimension).
    SumDim(ReduceAutotuneKey),
    /// Wraps a [`ReduceAutotuneKey`] for the `MeanDim` operation
    /// (by name, a mean over one dimension).
    MeanDim(ReduceAutotuneKey),
    /// Wraps a [`FusionElemWiseAutotuneKey`]. Only compiled when the `fusion`
    /// feature is enabled or under `cfg(test)`, matching the gated import.
    #[cfg(any(feature = "fusion", test))]
    FusionElemWise(FusionElemWiseAutotuneKey),
}
impl Display for WgpuAutotuneKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
WgpuAutotuneKey::Matmul(matmul_key) => std::fmt::Display::fmt(&matmul_key, f),
WgpuAutotuneKey::SumDim(reduce_key) => std::fmt::Display::fmt(&reduce_key, f),
WgpuAutotuneKey::MeanDim(reduce_key) => std::fmt::Display::fmt(&reduce_key, f),
#[cfg(any(feature = "fusion", test))]
WgpuAutotuneKey::FusionElemWise(reduce_key) => std::fmt::Display::fmt(&reduce_key, f),
}
}
}
// Registers `WgpuAutotuneKey` as an `AutotuneKey`. The impl body is empty, so
// the trait has no required items here. Presumably it is a marker trait whose
// supertraits are satisfied by the derives and the `Display` impl above —
// confirm against `burn_compute::tune::AutotuneKey`.
impl AutotuneKey for WgpuAutotuneKey {}