use core::fmt::Display;
use super::{Branch, CoopMma, NonSemantic, Plane, Synchronization, Type, Variable};
use crate::{
Arithmetic, AtomicOp, Bitwise, InstructionModes, Metadata, OperationArgs, OperationReflect,
Operator, TmaOps, VectorSize, comparison::Comparison, marker::Marker,
};
use crate::{BarrierOps, SourceLoc, TypeHash};
use alloc::{
format,
string::{String, ToString},
vec::Vec,
};
use derive_more::derive::From;
/// Top-level IR operation carried by an [`Instruction`].
///
/// Apart from `Copy`, every variant wraps one family of operations and is tagged
/// `#[operation(nested)]` for the `OperationReflect` derive. `derive_more::From` generates a
/// `From<Inner> for Operation` conversion for each wrapped type, except where `#[from(ignore)]`
/// opts out.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, From, OperationReflect)]
// NOTE(review): presumably `OperationReflect` generates an opcode enum named `OpCode` from this
// attribute — confirm in the derive implementation.
#[operation(opcode_name = OpCode)]
// NOTE(review): the reason for these allows is not stated here; presumably some variants are
// unused under certain feature configurations — confirm.
#[allow(dead_code, missing_docs, clippy::large_enum_variant)]
pub enum Operation {
    // Wraps a single `Variable`. It is marked pure for `OperationReflect`. It is excluded from the
    // derived `From` impls, so a bare `Variable` does not convert into an `Operation`.
    #[operation(pure)]
    #[from(ignore)]
    Copy(Variable),
    // The remaining variants each wrap a nested operation family.
    #[operation(nested)]
    Arithmetic(Arithmetic),
    #[operation(nested)]
    Comparison(Comparison),
    #[operation(nested)]
    Bitwise(Bitwise),
    #[operation(nested)]
    Operator(Operator),
    #[operation(nested)]
    Atomic(AtomicOp),
    #[operation(nested)]
    Metadata(Metadata),
    #[operation(nested)]
    Branch(Branch),
    #[operation(nested)]
    Synchronization(Synchronization),
    #[operation(nested)]
    Plane(Plane),
    #[operation(nested)]
    CoopMma(CoopMma),
    #[operation(nested)]
    Barrier(BarrierOps),
    #[operation(nested)]
    Tma(TmaOps),
    #[operation(nested)]
    NonSemantic(NonSemantic),
    #[operation(nested)]
    Marker(Marker),
}
/// A single IR instruction.
///
/// It holds an [`Operation`] together with an optional output variable, an optional source
/// location and per-instruction modes.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, Hash, TypeHash)]
pub struct Instruction {
    /// Variable that receives the result, or `None` for instructions without an output.
    pub out: Option<Variable>,
    /// Source location associated with the instruction, if any.
    pub source_loc: Option<SourceLoc>,
    /// Per-instruction modes. The constructors in this module initialize them with `Default`.
    pub modes: InstructionModes,
    /// The operation this instruction performs.
    pub operation: Operation,
}
impl Instruction {
    /// Creates an instruction whose result is written to `out`.
    ///
    /// The instruction starts with no source location and default [`InstructionModes`].
    pub fn new(operation: impl Into<Operation>, out: Variable) -> Self {
        // Reuse `no_out` so the default field values are defined in exactly one place.
        Instruction {
            out: Some(out),
            ..Self::no_out(operation)
        }
    }

    /// Creates an instruction that has no output variable.
    ///
    /// The instruction starts with no source location and default [`InstructionModes`].
    pub fn no_out(operation: impl Into<Operation>) -> Self {
        Instruction {
            out: None,
            operation: operation.into(),
            source_loc: None,
            modes: Default::default(),
        }
    }

    /// Returns the output variable of this instruction.
    ///
    /// # Panics
    ///
    /// Panics if the instruction has no output variable, for example one built with
    /// [`Instruction::no_out`].
    #[track_caller]
    pub fn out(&self) -> Variable {
        // `#[track_caller]` makes the panic report the caller's location, which is where the
        // broken invariant actually lives.
        self.out
            .expect("Instruction::out called on an instruction that has no output variable")
    }

    /// Returns the type of this instruction's output variable.
    ///
    /// # Panics
    ///
    /// Panics if the instruction has no output variable.
    #[track_caller]
    pub fn ty(&self) -> Type {
        self.out().ty
    }
}
/// Human-readable rendering of an instruction.
///
/// A few memory, indexing and cast operators use dedicated layouts that embed the destination
/// variable. All other operations print as `out = op : (arg types) -> (out type)`, or just `op`
/// when the instruction has no output.
impl Display for Instruction {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        if let Operation::Operator(operator) = &self.operation {
            match operator {
                Operator::CopyMemory(copy) => {
                    let dst = self.out();
                    return write!(
                        f,
                        "copy_mem({}[{}], {}[{}])",
                        dst, copy.out_index, copy.input, copy.in_index
                    );
                }
                Operator::CopyMemoryBulk(copy) => {
                    let dst = self.out();
                    return write!(
                        f,
                        "copy_mem_bulk({}[{}], {}[{}], {})",
                        dst, copy.out_index, copy.input, copy.in_index, copy.len
                    );
                }
                Operator::IndexAssign(assign) => {
                    let dst = self.out();
                    return write!(
                        f,
                        "{}[{}] = {} : ({}, {}) -> ({})",
                        dst,
                        assign.index,
                        assign.value,
                        assign.index.ty,
                        assign.value.ty,
                        dst.ty,
                    );
                }
                Operator::UncheckedIndexAssign(assign) => {
                    let dst = self.out();
                    return write!(
                        f,
                        "unchecked {}[{}] = {} : ({}, {}) -> ({})",
                        dst,
                        assign.index,
                        assign.value,
                        assign.index.ty,
                        assign.value.ty,
                        dst.ty,
                    );
                }
                Operator::Cast(cast) => {
                    let dst = self.out();
                    return write!(
                        f,
                        "{} = cast<{}>({}) : ({}) -> ({})",
                        dst, dst.ty, cast.input, cast.input.ty, dst.ty,
                    );
                }
                Operator::Reinterpret(cast) => {
                    let dst = self.out();
                    return write!(f, "{} = bitcast<{}>({})", dst, dst.ty, cast.input);
                }
                _ => {}
            }
        }

        // Generic layout, shared by every other operation.
        match self.out {
            None => write!(f, "{}", self.operation),
            Some(out) => {
                let arg_types = self
                    .operation
                    .args()
                    .unwrap_or_default()
                    .iter()
                    .map(|arg| arg.ty.to_string())
                    .collect::<Vec<_>>()
                    .join(", ");
                write!(f, "{out} = {} : ({}) -> ({})", self.operation, arg_types, out.ty)
            }
        }
    }
}
/// Renders an operation by delegating to the value wrapped in its variant.
impl Display for Operation {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Every variant wraps exactly one displayable value. Select it as a trait object, then
        // print it with default formatting (caller-supplied flags are not forwarded).
        let wrapped: &dyn Display = match self {
            Operation::Copy(variable) => variable,
            Operation::Arithmetic(op) => op,
            Operation::Comparison(op) => op,
            Operation::Bitwise(op) => op,
            Operation::Operator(op) => op,
            Operation::Atomic(op) => op,
            Operation::Metadata(op) => op,
            Operation::Branch(op) => op,
            Operation::Synchronization(op) => op,
            Operation::Plane(op) => op,
            Operation::CoopMma(op) => op,
            Operation::Barrier(op) => op,
            Operation::Tma(op) => op,
            Operation::NonSemantic(op) => op,
            Operation::Marker(op) => op,
        };
        write!(f, "{wrapped}")
    }
}
/// Formats `args` as a trailing argument list.
///
/// Each element is prefixed with `", "`, so the result can be appended directly after
/// preceding arguments. For example `[a, b]` becomes `", a, b"`. An empty slice yields an
/// empty string.
pub fn fmt_vararg(args: &[impl Display]) -> String {
    let mut rendered = String::new();
    for arg in args {
        rendered.push_str(", ");
        rendered.push_str(&arg.to_string());
    }
    rendered
}
/// Operands for indexing `list` at `index`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct IndexOperator {
    pub list: Variable,
    pub index: Variable,
    // NOTE(review): presumably the vector width of the access; confirm whether a special value
    // means "same as `list`".
    pub vector_size: VectorSize,
    // NOTE(review): presumably an adjustment factor used by bounds checking when the access is
    // unrolled — confirm against the lowering code.
    pub unroll_factor: usize,
}
/// Operands for an indexed store of `value` at `index`.
///
/// The destination is not stored here. `Display for Instruction` renders an `IndexAssign`
/// operator as `out[index] = value`, where `out` is the enclosing instruction's output
/// variable.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct IndexAssignOperator {
    pub index: Variable,
    pub value: Variable,
    // NOTE(review): presumably the vector width of the store; confirm whether a special value
    // means "same as the destination".
    pub vector_size: VectorSize,
    // NOTE(review): presumably an adjustment factor used by bounds checking when the access is
    // unrolled — confirm against the lowering code.
    pub unroll_factor: usize,
}
/// Operand pair (`lhs`, `rhs`) for two-input operations.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct BinaryOperator {
    pub lhs: Variable,
    pub rhs: Variable,
}
/// Single operand (`input`) for one-input operations.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct UnaryOperator {
    pub input: Variable,
}
impl From<Branch> for Instruction {
fn from(value: Branch) -> Self {
Instruction::no_out(value)
}
}
impl From<Synchronization> for Instruction {
fn from(value: Synchronization) -> Self {
Instruction::no_out(value)
}
}
impl From<NonSemantic> for Instruction {
fn from(value: NonSemantic) -> Self {
Instruction::no_out(value)
}
}