use crate::{Instruction, TypeHash};
use alloc::{format, string::String, vec::Vec};
use core::fmt::{Display, Write};
use crate::OperationReflect;
use super::Variable;
/// Granularity at which a barrier operates.
///
/// NOTE(review): presumably `Unit` is scoped to a single unit and `Cube` to the
/// whole cube — confirm against the backends.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` rank `Unit`
/// below `Cube`, so do not reorder the variants.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, TypeHash)]
pub enum BarrierLevel {
    Unit,
    Cube,
}
/// Barrier-related IR operations.
///
/// Covers barrier declaration and initialization, asynchronous copies that are
/// associated with a barrier (including tensor-map loads), and arrive/wait
/// synchronization. The `From<BarrierOps> for Instruction` impl in this file
/// wraps every variant with `Instruction::no_out`, i.e. none of them carries an
/// output variable.
///
/// NOTE(review): `OperationReflect` (with `opcode_name = BarrierOpCode`) is
/// derived from this declaration, so variant and field order may be
/// significant — check the derive before reordering anything.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationReflect)]
#[operation(opcode_name = BarrierOpCode)]
pub enum BarrierOps {
/// Declares `barrier`.
Declare {
barrier: Variable,
},
/// Initializes `barrier` with `arrival_count`.
///
/// NOTE(review): `is_elected` presumably selects the unit that performs the
/// initialization — confirm.
Init {
barrier: Variable,
is_elected: Variable,
arrival_count: Variable,
},
/// Initializes `barrier` with `arrival_count`; unlike `Init`, it has no
/// `is_elected` operand.
InitManual {
barrier: Variable,
arrival_count: Variable,
},
/// Asynchronous copy from `source`, starting at `offset_source` and bounded
/// by `source_length`, to the destination at `offset_out`, associated with
/// `barrier`.
///
/// NOTE(review): units of the lengths/offsets (elements vs. bytes) are not
/// visible here — confirm.
MemCopyAsync {
barrier: Variable,
source: Variable,
source_length: Variable,
offset_source: Variable,
offset_out: Variable,
},
/// Same operands as `MemCopyAsync`, in a "cooperative" form.
///
/// NOTE(review): the exact cooperative semantics are defined by the backend
/// lowering, not here.
MemCopyAsyncCooperative {
barrier: Variable,
source: Variable,
source_length: Variable,
offset_source: Variable,
offset_out: Variable,
},
/// Same operands as `MemCopyAsync`.
///
/// NOTE(review): presumably the copy is tracked as a transaction on
/// `barrier` (cf. `ExpectTx` / `ArriveTx`) — confirm.
MemCopyAsyncTx {
barrier: Variable,
source: Variable,
source_length: Variable,
offset_source: Variable,
offset_out: Variable,
},
/// Asynchronous copy from `source` at `offset_source` to the destination at
/// `offset_out`, with no barrier operand. `copy_length` is rendered as the
/// byte count by `Display`.
///
/// NOTE(review): `source_length` is only displayed when `checked` is set,
/// presumably to bound the source read — confirm.
CopyAsync {
source: Variable,
source_length: Variable,
offset_source: Variable,
offset_out: Variable,
copy_length: u32,
checked: bool,
},
/// Tensor-map load from `tensor_map` at coordinates `indices` into the
/// destination at `offset_out`, associated with `barrier`. `Display` uses
/// `indices.len()` as the rank.
TmaLoad {
barrier: Variable,
tensor_map: Variable,
indices: Vec<Variable>,
offset_out: Variable,
},
/// Im2col variant of `TmaLoad` with additional per-load `offsets`.
TmaLoadIm2col {
barrier: Variable,
tensor_map: Variable,
indices: Vec<Variable>,
offsets: Vec<Variable>,
offset_out: Variable,
},
/// Arrives on `barrier`.
Arrive {
barrier: Variable,
},
/// Arrives on `barrier`.
///
/// NOTE(review): `arrive_count_update` and `transaction_count_update` are
/// presumably the arrival and transaction count increments — confirm.
ArriveTx {
barrier: Variable,
arrive_count_update: Variable,
transaction_count_update: Variable,
},
/// Commits previously issued async copies to `barrier`.
///
/// NOTE(review): inferred from the name; confirm which copies are covered.
CommitCopyAsync {
barrier: Variable,
},
/// Registers `transaction_count_update` expected transactions on `barrier`.
///
/// NOTE(review): inferred from the field name — confirm.
ExpectTx {
barrier: Variable,
transaction_count_update: Variable,
},
/// Waits on `barrier` using `token`.
///
/// NOTE(review): `token` is presumably produced by a prior arrive — confirm.
Wait {
barrier: Variable,
token: Variable,
},
/// Waits on `barrier` for the phase identified by `phase`.
WaitParity {
barrier: Variable,
phase: Variable,
},
/// Arrives on `barrier` and then waits on it.
ArriveAndWait {
barrier: Variable,
},
}
/// Joins the `Display` renderings of `items` with `", "`, without a trailing separator.
fn comma_separated<T: Display>(items: &[T]) -> String {
    let mut joined = String::new();
    for (position, item) in items.iter().enumerate() {
        if position > 0 {
            joined.push_str(", ");
        }
        // Formatting into a `String` cannot fail, so the `fmt::Result` carries no information.
        let _ = write!(joined, "{item}");
    }
    joined
}

/// Renders a barrier operation as one line of pseudo-code (used for IR dumps).
///
/// `Declare` renders as an empty string. Variable lists (tensor-map indices and
/// im2col offsets) are comma-separated with no trailing `", "`.
impl Display for BarrierOps {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            BarrierOps::Declare { .. } => Ok(()),
            // `Init`'s `is_elected` operand is not rendered, so both init forms print the same.
            BarrierOps::Init {
                barrier,
                arrival_count,
                ..
            }
            | BarrierOps::InitManual {
                barrier,
                arrival_count,
            } => write!(f, "{barrier}.init_barrier({arrival_count})"),
            BarrierOps::MemCopyAsync {
                barrier,
                source,
                offset_source,
                offset_out,
                ..
            } => write!(
                f,
                "out[{offset_out}] = mem_copy_async({barrier}, source: {source}[{offset_source}])",
            ),
            BarrierOps::MemCopyAsyncCooperative {
                barrier,
                source,
                offset_source,
                offset_out,
                ..
            } => write!(
                f,
                "out[{offset_out}] = mem_copy_async_cooperative({barrier}, source: {source}[{offset_source}])",
            ),
            BarrierOps::MemCopyAsyncTx {
                barrier,
                source,
                offset_source,
                offset_out,
                ..
            } => write!(
                f,
                "out[{offset_out}] = mem_copy_async_tx({barrier}, source: {source}[{offset_source}])",
            ),
            BarrierOps::CopyAsync {
                source,
                source_length,
                offset_source,
                offset_out,
                copy_length,
                checked,
            } => {
                // Only checked copies show the bound on the source read.
                let source_slice = if *checked {
                    format!("[{offset_source}..][..{source_length}]")
                } else {
                    format!("[{offset_source}]")
                };
                write!(
                    f,
                    "out[{offset_out}] = copy_async(source: {source}{source_slice}, bytes: {copy_length})",
                )
            }
            BarrierOps::TmaLoad {
                barrier,
                tensor_map,
                indices,
                offset_out,
            } => {
                let rank = indices.len();
                write!(
                    f,
                    "out[{offset_out}] = tma_load::<{rank}>({barrier}, {tensor_map}"
                )?;
                // Each coordinate carries its own leading separator, so an empty
                // list leaves no dangling ", " before the closing parenthesis.
                for index in indices {
                    write!(f, ", {index}")?;
                }
                f.write_str(")")
            }
            BarrierOps::TmaLoadIm2col {
                barrier,
                tensor_map,
                indices,
                offsets,
                offset_out,
            } => {
                let rank = indices.len();
                let indices = comma_separated(indices);
                let offsets = comma_separated(offsets);
                write!(
                    f,
                    "out[{offset_out}] = tma_load_im2col::<{rank}>({barrier}, {tensor_map}, indices: ({indices}), offsets: ({offsets}))"
                )
            }
            BarrierOps::Arrive { barrier } => write!(f, "arrive({barrier})"),
            BarrierOps::ArriveTx {
                barrier,
                arrive_count_update,
                transaction_count_update,
            } => write!(
                f,
                "arrive_tx({barrier}, {arrive_count_update}, {transaction_count_update})"
            ),
            BarrierOps::CommitCopyAsync { barrier } => write!(f, "commit_copy_async({barrier})"),
            BarrierOps::ExpectTx {
                barrier,
                transaction_count_update,
            } => write!(f, "expect_tx({barrier}, {transaction_count_update})"),
            BarrierOps::Wait { barrier, token } => write!(f, "wait({barrier}, {token})"),
            BarrierOps::WaitParity { barrier, phase } => {
                write!(f, "wait_parity({barrier}, {phase})")
            }
            BarrierOps::ArriveAndWait { barrier } => write!(f, "arrive_and_wait({barrier})"),
        }
    }
}
impl Display for BarrierLevel {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
BarrierLevel::Unit => f.write_str("unit"),
BarrierLevel::Cube => f.write_str("cube"),
}
}
}
impl From<BarrierOps> for Instruction {
fn from(value: BarrierOps) -> Self {
Instruction::no_out(value)
}
}