use super::{Branch, CoopMma, Procedure, Subcube, Synchronization, Variable};
use serde::{Deserialize, Serialize};
/// A single operation, tagged by category.
///
/// Every variant wraps a payload type of the same name. Variant order is part of the
/// serialized form produced by the serde derives, so new variants should be appended.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
// NOTE(review): the reason for allowing `dead_code` is not visible here; presumably some
// variants are only constructed under certain features/backends — confirm.
#[allow(dead_code, missing_docs)]
pub enum Operation {
    /// Arithmetic, comparison, logical, bitwise, or indexing operator; see [`Operator`].
    Operator(Operator),
    /// Wraps a [`Procedure`] (defined in the parent module).
    Procedure(Procedure),
    /// Stride, shape, or length query on a variable; see [`Metadata`].
    Metadata(Metadata),
    /// Wraps a [`Branch`] (defined in the parent module; presumably control flow).
    Branch(Branch),
    /// Wraps a [`Synchronization`] (defined in the parent module).
    Synchronization(Synchronization),
    /// Wraps a [`Subcube`] (defined in the parent module).
    Subcube(Subcube),
    /// Wraps a [`CoopMma`] (defined in the parent module).
    CoopMma(CoopMma),
}
/// An operator over [`Variable`] operands that writes its result to an `out` variable.
///
/// The payload type fixes the operand layout: [`UnaryOperator`] (`input`),
/// [`BinaryOperator`] (`lhs`, `rhs`), [`ClampOperator`], [`SliceOperator`], or
/// [`FmaOperator`].
///
/// The per-variant descriptions below are read from the variant names. Precise semantics
/// are defined by whatever consumes these operations, not by this type. That includes NaN
/// handling, integer signedness, the sign conventions of `Modulo` vs `Remainder`, and which
/// operand is the container for the indexing variants. Variant order is part of the
/// serialized form, so append new variants.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
// NOTE(review): the reason for allowing `dead_code` is not visible here; presumably some
// variants are only constructed under certain features/backends — confirm.
#[allow(dead_code, missing_docs)]
pub enum Operator {
    /// Addition.
    Add(BinaryOperator),
    /// Fused multiply-add over `a`, `b`, `c` (presumably `a * b + c`).
    Fma(FmaOperator),
    /// Subtraction.
    Sub(BinaryOperator),
    /// Multiplication.
    Mul(BinaryOperator),
    /// Division.
    Div(BinaryOperator),
    /// Absolute value.
    Abs(UnaryOperator),
    /// Exponential.
    Exp(UnaryOperator),
    /// Logarithm.
    Log(UnaryOperator),
    /// Logarithm of one plus the input.
    Log1p(UnaryOperator),
    /// Cosine.
    Cos(UnaryOperator),
    /// Sine.
    Sin(UnaryOperator),
    /// Hyperbolic tangent.
    Tanh(UnaryOperator),
    /// Power (presumably `lhs` raised to `rhs`).
    Powf(BinaryOperator),
    /// Square root.
    Sqrt(UnaryOperator),
    /// Round down.
    Floor(UnaryOperator),
    /// Round up.
    Ceil(UnaryOperator),
    /// Error function.
    Erf(UnaryOperator),
    /// Reciprocal.
    Recip(UnaryOperator),
    /// Equality comparison.
    Equal(BinaryOperator),
    /// Inequality comparison.
    NotEqual(BinaryOperator),
    /// Less-than comparison.
    Lower(BinaryOperator),
    /// Clamp `input` using `min_value` and `max_value`.
    Clamp(ClampOperator),
    /// Greater-than comparison.
    Greater(BinaryOperator),
    /// Less-than-or-equal comparison.
    LowerEqual(BinaryOperator),
    /// Greater-than-or-equal comparison.
    GreaterEqual(BinaryOperator),
    /// Assignment of `input` to `out`.
    Assign(UnaryOperator),
    /// Modulo.
    Modulo(BinaryOperator),
    /// Indexed read.
    Index(BinaryOperator),
    /// Slice of `input` delimited by `start` and `end`.
    Slice(SliceOperator),
    /// Indexed read; by name, without bounds checking.
    UncheckedIndex(BinaryOperator),
    /// Indexed write.
    IndexAssign(BinaryOperator),
    /// Indexed write; by name, without bounds checking.
    UncheckedIndexAssign(BinaryOperator),
    /// Logical AND.
    And(BinaryOperator),
    /// Logical OR.
    Or(BinaryOperator),
    /// Logical NOT.
    Not(UnaryOperator),
    /// Maximum.
    Max(BinaryOperator),
    /// Minimum.
    Min(BinaryOperator),
    /// Bitwise AND.
    BitwiseAnd(BinaryOperator),
    /// Bitwise XOR.
    BitwiseXor(BinaryOperator),
    /// Left shift.
    ShiftLeft(BinaryOperator),
    /// Right shift.
    ShiftRight(BinaryOperator),
    /// Remainder.
    Remainder(BinaryOperator),
}
/// A metadata query on a variable `var`, with the answer written to `out`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(missing_docs)]
pub enum Metadata {
    /// Stride of `var` along dimension `dim`.
    Stride {
        dim: Variable,
        var: Variable,
        out: Variable,
    },
    /// Shape of `var` along dimension `dim`.
    Shape {
        dim: Variable,
        var: Variable,
        out: Variable,
    },
    /// Length of `var`.
    Length { var: Variable, out: Variable },
}
/// Operands of a two-input operator: `lhs` and `rhs` are the inputs, `out` the destination.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(missing_docs)]
pub struct BinaryOperator {
    /// Left-hand operand.
    pub lhs: Variable,
    /// Right-hand operand.
    pub rhs: Variable,
    /// Destination of the result.
    pub out: Variable,
}
/// Operands of a single-input operator: `input` is read, `out` is the destination.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(missing_docs)]
pub struct UnaryOperator {
    /// The sole operand.
    pub input: Variable,
    /// Destination of the result.
    pub out: Variable,
}
/// Operand set consisting only of an output variable (no inputs).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(missing_docs)]
pub struct InitOperator {
    /// Destination of the result.
    pub out: Variable,
}
/// Operands of a clamp: `input` bounded by `min_value` and `max_value`, result in `out`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(missing_docs)]
pub struct ClampOperator {
    /// Value being clamped.
    pub input: Variable,
    /// Lower bound.
    pub min_value: Variable,
    /// Upper bound.
    pub max_value: Variable,
    /// Destination of the result.
    pub out: Variable,
}
/// Operands of a slice of `input` between `start` and `end`, written to `out`.
///
/// Whether `end` is inclusive or exclusive is decided by the consumer of this operation,
/// not by this type.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(missing_docs)]
pub struct SliceOperator {
    /// Variable being sliced.
    pub input: Variable,
    /// Start position of the slice.
    pub start: Variable,
    /// End position of the slice.
    pub end: Variable,
    /// Destination of the slice.
    pub out: Variable,
}
/// Operand set naming a single `variable` to read.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(missing_docs)]
pub struct ReadGlobalOperator {
    /// The variable to read.
    pub variable: Variable,
}
/// Operand set for a read that is also parameterized by two `usize` positions.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(missing_docs)]
pub struct ReadGlobalWithLayoutOperator {
    /// The variable to read.
    pub variable: Variable,
    /// NOTE(review): presumably the position of the tensor being read — confirm.
    pub tensor_read_pos: usize,
    /// NOTE(review): presumably the position of the tensor whose layout is used — confirm.
    pub tensor_layout_pos: usize,
}
/// Operands of a fused multiply-add over `a`, `b`, and `c`, with the result in `out`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(missing_docs)]
pub struct FmaOperator {
    /// First operand.
    pub a: Variable,
    /// Second operand.
    pub b: Variable,
    /// Third operand.
    pub c: Variable,
    /// Destination of the result.
    pub out: Variable,
}
impl From<Operator> for Operation {
fn from(val: Operator) -> Self {
Operation::Operator(val)
}
}
impl From<Branch> for Operation {
fn from(value: Branch) -> Self {
Self::Branch(value)
}
}
impl From<Synchronization> for Operation {
fn from(value: Synchronization) -> Self {
Self::Synchronization(value)
}
}
impl From<Metadata> for Operation {
fn from(val: Metadata) -> Self {
Operation::Metadata(val)
}
}
impl From<Procedure> for Operation {
fn from(val: Procedure) -> Self {
Operation::Procedure(val)
}
}