1use crate::{Instruction, TypeHash};
2use alloc::{string::String, vec::Vec};
3use core::fmt::{Display, Write};
4
5use crate::OperationReflect;
6
7use super::Variable;
8
/// Operations on TMA-style bulk asynchronous transfers.
///
/// NOTE(review): "TMA" presumably refers to a tensor-memory-accelerator style bulk
/// copy unit; the actual lowering of each variant lives in the backends — confirm there.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationReflect)]
// Presumably names the opcode type produced by the `OperationReflect` derive — confirm.
#[operation(opcode_name = TmaOpCode)]
pub enum TmaOps {
    /// Store issued through the TMA path.
    ///
    /// The number of `coordinates` is used as the rank when this op is displayed.
    TmaStore {
        /// Variable being stored (presumably the source buffer — TODO confirm).
        source: Variable,
        /// Coordinates of the store; one entry per dimension (assumed — confirm).
        coordinates: Vec<Variable>,
        /// Offset applied to `source` (displayed as `source + offset_source`).
        offset_source: Variable,
    },
    /// Commit marker for previously issued bulk async operations.
    ///
    /// NOTE(review): presumably groups all outstanding bulk copies so they can be
    /// waited on with `WaitGroup` / `WaitGroupRead` — confirm against the backend.
    CommitGroup,
    /// Wait on committed groups.
    WaitGroup {
        /// Presumably the maximum number of groups allowed to remain pending — confirm.
        max_pending: u32,
    },
    /// Read-side variant of `WaitGroup`.
    ///
    /// NOTE(review): presumably only waits until the reads of the source data have
    /// completed rather than the full transfer — confirm against the backend.
    WaitGroupRead {
        /// Presumably the maximum number of groups allowed to remain pending — confirm.
        max_pending: u32,
    },
}
27
28impl Display for TmaOps {
29 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
30 match self {
31 TmaOps::TmaStore {
32 source,
33 coordinates,
34 offset_source,
35 } => {
36 let rank = coordinates.len();
37 let coords = coordinates.iter().fold(String::new(), |mut s, coord| {
38 let _ = write!(s, ", {coord}");
39 s
40 });
41 write!(f, "tma_load::<{rank}>({source} + {offset_source} {coords})")
42 }
43 TmaOps::CommitGroup => write!(f, "memcpy_async_bulk_commit_group()"),
44 TmaOps::WaitGroup { max_pending } => {
45 write!(f, "tma_wait_group::<{max_pending}>()")
46 }
47 TmaOps::WaitGroupRead { max_pending } => {
48 write!(f, "tma_wait_group_read::<{max_pending}>()")
49 }
50 }
51 }
52}
53
54impl From<TmaOps> for Instruction {
55 fn from(value: TmaOps) -> Self {
56 Instruction::no_out(value)
57 }
58}