use std::fmt::Display;
use serde::{Deserialize, Serialize};
use crate::definition::ConvBlueprint;
/// How the convolution configuration is chosen.
///
/// Either only an algorithm and a tile kind are given (`Inferred`), or an
/// explicit `ConvBlueprint` is supplied together with the algorithm (`Forced`).
/// NOTE(review): the consumer of this enum is not visible here — confirm how
/// each variant is resolved into a launch configuration.
#[derive(Clone, Debug)]
pub enum Strategy {
/// Only the algorithm and accelerated tile kind are specified; the rest of the
/// configuration is presumably derived elsewhere (TODO confirm).
Inferred {
algorithm: ConvAlgorithm,
tile_kind: AcceleratedTileKind,
},
/// The algorithm is paired with a caller-provided blueprint.
/// The `Display` name of this variant is `<algorithm>_forced`; it does not
/// encode the blueprint, so two forced strategies with the same algorithm but
/// different blueprints render identically.
Forced {
algorithm: ConvAlgorithm,
blueprint: ConvBlueprint,
},
}
/// Convolution algorithm variant.
///
/// The `Display` form is the snake_case name of the variant
/// (e.g. `SimpleSyncCyclic` displays as `simple_sync_cyclic`).
///
/// NOTE(review): the serde derive carries no `rename_all` attribute, so the
/// serialized form uses the Rust variant names (`"SimpleSyncCyclic"`), which
/// differ from the `Display` names — confirm this is intended if both are used
/// as identifiers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ConvAlgorithm {
    SimpleSyncCyclic,
    SimpleSyncStrided,
    SimpleSyncTilewise,
    SimpleAsyncCyclic,
    SimpleAsyncStrided,
    SimpleAsyncTma,
    SpecializedAsyncCyclic,
    SpecializedAsyncStrided,
    SpecializedTma,
}

impl Display for ConvAlgorithm {
    /// Writes the snake_case name of the algorithm.
    ///
    /// Uses `Formatter::pad` rather than `write_str` so that width, fill,
    /// alignment and precision flags (e.g. `{:<24}` when printing aligned
    /// tables) are honored exactly as they are for a plain `&str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            ConvAlgorithm::SimpleSyncCyclic => "simple_sync_cyclic",
            ConvAlgorithm::SimpleSyncStrided => "simple_sync_strided",
            ConvAlgorithm::SimpleSyncTilewise => "simple_sync_tilewise",
            ConvAlgorithm::SimpleAsyncCyclic => "simple_async_cyclic",
            ConvAlgorithm::SimpleAsyncStrided => "simple_async_strided",
            ConvAlgorithm::SimpleAsyncTma => "simple_async_tma",
            ConvAlgorithm::SpecializedAsyncCyclic => "specialized_async_cyclic",
            ConvAlgorithm::SpecializedAsyncStrided => "specialized_async_strided",
            ConvAlgorithm::SpecializedTma => "specialized_tma",
        };
        f.pad(name)
    }
}
/// Accelerated tile instruction family; `Cmma` is the default.
///
/// NOTE(review): presumably selects between CMMA- and MMA-based tile
/// operations — confirm against the kernel code that consumes it.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AcceleratedTileKind {
    #[default]
    Cmma,
    Mma,
}

impl Display for AcceleratedTileKind {
    /// Writes the lowercase name of the tile kind (`cmma` or `mma`).
    ///
    /// Uses `Formatter::pad` so width, fill, alignment and precision flags are
    /// honored like they are for a plain `&str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            AcceleratedTileKind::Cmma => "cmma",
            AcceleratedTileKind::Mma => "mma",
        };
        f.pad(name)
    }
}
impl Display for Strategy {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Strategy::Inferred {
algorithm,
tile_kind,
} => write!(f, "{algorithm}_{tile_kind}"),
Strategy::Forced {
algorithm,
blueprint: _,
} => write!(f, "{algorithm}_forced"),
}
}
}