use std::fmt::Display;
use burn_compute::tune::AutotuneKey;
use crate::kernel::{matmul::MatmulAutotuneKey, reduce::ReduceAutotuneKey};
/// Key identifying one autotunable wgpu operation together with the parameters it is tuned on.
///
/// Each variant wraps the operation-specific key type. `SumDim` and `MeanDim` share
/// [`ReduceAutotuneKey`], but as distinct variants they never compare equal to each other.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum WgpuAutotuneKey {
    /// Key for matrix-multiplication kernels.
    Matmul(MatmulAutotuneKey),
    /// Key for sum-along-a-dimension reductions.
    SumDim(ReduceAutotuneKey),
    /// Key for mean-along-a-dimension reductions.
    MeanDim(ReduceAutotuneKey),
}
/// Human-readable rendering of an autotune key.
///
/// `Matmul` keys delegate directly to [`MatmulAutotuneKey`]'s `Display`. The two reduce
/// variants wrap the same [`ReduceAutotuneKey`] type. Each is therefore prefixed with its
/// operation name (`Sum - ` / `Mean - `). Without the prefix, a sum-dim key and a
/// mean-dim key over the same reduce parameters would format to identical strings.
impl Display for WgpuAutotuneKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // `matmul_key` is already a reference; no extra `&` needed.
            WgpuAutotuneKey::Matmul(matmul_key) => Display::fmt(matmul_key, f),
            WgpuAutotuneKey::SumDim(reduce_key) => write!(f, "Sum - {}", reduce_key),
            WgpuAutotuneKey::MeanDim(reduce_key) => write!(f, "Mean - {}", reduce_key),
        }
    }
}
/// Marks [`WgpuAutotuneKey`] as an autotune key type for `burn_compute::tune`.
///
/// The body is empty: no trait methods are overridden here. Presumably the trait's
/// requirements are satisfied by the derives and the `Display` impl on this type —
/// NOTE(review): confirm against `AutotuneKey`'s supertrait bounds.
impl AutotuneKey for WgpuAutotuneKey {}