use std::fmt;
pub use morok_ir::AxisType;
use super::error::*;
/// Kind of optimization operation carried by an [`Opt`].
///
/// The named constructors on [`Opt`] fix which [`OptArg`] payload (and whether an
/// axis) each kind is built with; see the per-variant notes below.
/// NOTE(review): the variant set matches tinygrad's `OptOps`; the exact semantics
/// presumably follow it — confirm against the consumer of these values.
///
/// Serde derives are present: in index-based formats the variant order may be part
/// of the encoded form, so keep the declaration order stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum OptOps {
/// Tensor-core op; built by [`Opt::tc`] with an optional axis and an
/// [`OptArg::TensorCore`] payload.
TC,
/// Built by [`Opt::upcast`]: an axis plus an [`OptArg::Int`] amount.
UPCAST,
/// Built by [`Opt::unroll`]: an axis plus an [`OptArg::Int`] amount.
UNROLL,
/// Built by [`Opt::local`]: an axis plus an [`OptArg::Int`] amount.
LOCAL,
/// Built by [`Opt::thread`]: an axis plus an [`OptArg::Int`] amount.
THREAD,
/// Built by [`Opt::group`]: an axis plus an [`OptArg::Int`] amount.
GROUP,
/// Built by [`Opt::grouptop`]: an axis plus an [`OptArg::Int`] amount.
GROUPTOP,
/// Built by [`Opt::nolocals`]: no axis, payload `OptArg::Int(0)`.
NOLOCALS,
/// Built by [`Opt::padto`]: an axis plus an [`OptArg::Int`] size.
PADTO,
/// Built by [`Opt::swap`]: an axis plus an [`OptArg::Swap`] naming the other axis.
SWAP,
}
impl fmt::Display for OptOps {
    /// Writes the variant's name exactly as spelled in the enum (e.g. `UPCAST`).
    ///
    /// The name is emitted through [`fmt::Formatter::pad`], so width, fill,
    /// alignment and precision flags (`{:>8}`, `{:<6}`, `{:.2}`) are honoured the
    /// same way they are for `str`. Without flags the output is just the name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::TC => "TC",
            Self::UPCAST => "UPCAST",
            Self::UNROLL => "UNROLL",
            Self::LOCAL => "LOCAL",
            Self::THREAD => "THREAD",
            Self::GROUP => "GROUP",
            Self::GROUPTOP => "GROUPTOP",
            Self::NOLOCALS => "NOLOCALS",
            Self::PADTO => "PADTO",
            Self::SWAP => "SWAP",
        };
        // `write!(f, ..)` would bypass the caller's formatting flags; `pad` applies them.
        f.pad(name)
    }
}
/// Payload attached to an [`Opt`].
///
/// Use the checked accessors [`OptArg::int`], [`OptArg::tc`] and [`OptArg::swap`]
/// to extract a specific shape; on a mismatch they fail with an error built from
/// `InvalidArgTypeSnafu`.
///
/// Serde derives are present: keep variant and field order stable, since it may be
/// part of the encoded form in index-based formats.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum OptArg {
/// Plain integer argument: an amount or a size, depending on the [`Opt`] constructor.
Int(usize),
/// Tensor-core parameters, as passed to [`Opt::tc`].
/// NOTE(review): `tc_select` is signed — presumably a negative value means
/// "no specific selection"; confirm against the consumer.
TensorCore { tc_select: i32, opt_level: usize, use_tc: usize },
/// The other axis involved in an [`OptOps::SWAP`].
Swap { other_axis: usize },
}
impl OptArg {
pub fn type_name(&self) -> &'static str {
match self {
Self::Int(_) => "Int",
Self::TensorCore { .. } => "TensorCore",
Self::Swap { .. } => "Swap",
}
}
pub fn int(&self) -> Result<usize, OptError> {
match self {
Self::Int(v) => Ok(*v),
_ => InvalidArgTypeSnafu { expected: "Int", found: self.type_name() }.fail(),
}
}
pub fn tc(&self) -> Result<(i32, usize, usize), OptError> {
match self {
Self::TensorCore { tc_select, opt_level, use_tc } => Ok((*tc_select, *opt_level, *use_tc)),
_ => InvalidArgTypeSnafu { expected: "TensorCore", found: self.type_name() }.fail(),
}
}
pub fn swap(&self) -> Result<usize, OptError> {
match self {
Self::Swap { other_axis } => Ok(*other_axis),
_ => InvalidArgTypeSnafu { expected: "Swap", found: self.type_name() }.fail(),
}
}
}
impl From<usize> for OptArg {
fn from(v: usize) -> Self {
Self::Int(v)
}
}
/// A single optimization step: an operation, an optional target axis, and its argument.
///
/// Prefer the named constructors on `Opt` (`upcast`, `local`, `tc`, `nolocals`, …)
/// over building the struct by hand; they pair each [`OptOps`] with the matching
/// [`OptArg`] shape. The `Display` impl renders it as `OP(axis, arg)`, leaving out
/// the `axis, ` part when `axis` is `None`.
///
/// Serde derives are present: keep field order stable, since it may be part of the
/// encoded form in non-self-describing formats.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct Opt {
/// Which optimization to apply.
pub op: OptOps,
/// Target axis. The constructors leave it `None` for `NOLOCALS` and make it optional for `TC`.
/// NOTE(review): what the index is relative to is not visible here — confirm with callers.
pub axis: Option<usize>,
/// Operation-specific argument.
pub arg: OptArg,
}
impl Opt {
pub fn new(op: OptOps, axis: Option<usize>, arg: OptArg) -> Self {
Self { op, axis, arg }
}
pub fn upcast(axis: usize, amount: usize) -> Self {
Self::new(OptOps::UPCAST, Some(axis), OptArg::Int(amount))
}
pub fn local(axis: usize, amount: usize) -> Self {
Self::new(OptOps::LOCAL, Some(axis), OptArg::Int(amount))
}
pub fn unroll(axis: usize, amount: usize) -> Self {
Self::new(OptOps::UNROLL, Some(axis), OptArg::Int(amount))
}
pub fn group(axis: usize, amount: usize) -> Self {
Self::new(OptOps::GROUP, Some(axis), OptArg::Int(amount))
}
pub fn grouptop(axis: usize, amount: usize) -> Self {
Self::new(OptOps::GROUPTOP, Some(axis), OptArg::Int(amount))
}
pub fn thread(axis: usize, amount: usize) -> Self {
Self::new(OptOps::THREAD, Some(axis), OptArg::Int(amount))
}
pub fn padto(axis: usize, size: usize) -> Self {
Self::new(OptOps::PADTO, Some(axis), OptArg::Int(size))
}
pub fn swap(axis: usize, other_axis: usize) -> Self {
Self::new(OptOps::SWAP, Some(axis), OptArg::Swap { other_axis })
}
pub fn tc(axis: Option<usize>, tc_select: i32, opt_level: usize, use_tc: usize) -> Self {
Self::new(OptOps::TC, axis, OptArg::TensorCore { tc_select, opt_level, use_tc })
}
pub fn nolocals() -> Self {
Self::new(OptOps::NOLOCALS, None, OptArg::Int(0))
}
}
impl fmt::Display for Opt {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}(", self.op)?;
if let Some(axis) = self.axis {
write!(f, "{}, ", axis)?;
}
match &self.arg {
OptArg::Int(v) => write!(f, "{}", v),
OptArg::TensorCore { tc_select, opt_level, use_tc } => {
write!(f, "tc_sel={}, opt={}, use={}", tc_select, opt_level, use_tc)
}
OptArg::Swap { other_axis } => write!(f, "swap={}", other_axis),
}?;
write!(f, ")")
}
}