cubecl_ir/
tma.rs

1use crate::{Instruction, TypeHash};
2use alloc::{string::String, vec::Vec};
3use core::fmt::{Display, Write};
4
5use crate::OperationReflect;
6
7use super::Variable;
8
9#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
10#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationReflect)]
11#[operation(opcode_name = TmaOpCode)]
12/// Operations available on a barrier
13pub enum TmaOps {
14    TmaStore {
15        source: Variable,
16        coordinates: Vec<Variable>,
17        offset_source: Variable,
18    },
19    CommitGroup,
20    WaitGroup {
21        max_pending: u32,
22    },
23    WaitGroupRead {
24        max_pending: u32,
25    },
26}
27
28impl Display for TmaOps {
29    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
30        match self {
31            TmaOps::TmaStore {
32                source,
33                coordinates,
34                offset_source,
35            } => {
36                let rank = coordinates.len();
37                let coords = coordinates.iter().fold(String::new(), |mut s, coord| {
38                    let _ = write!(s, ", {coord}");
39                    s
40                });
41                write!(f, "tma_load::<{rank}>({source} + {offset_source} {coords})")
42            }
43            TmaOps::CommitGroup => write!(f, "memcpy_async_bulk_commit_group()"),
44            TmaOps::WaitGroup { max_pending } => {
45                write!(f, "tma_wait_group::<{max_pending}>()")
46            }
47            TmaOps::WaitGroupRead { max_pending } => {
48                write!(f, "tma_wait_group_read::<{max_pending}>()")
49            }
50        }
51    }
52}
53
54impl From<TmaOps> for Instruction {
55    fn from(value: TmaOps) -> Self {
56        Instruction::no_out(value)
57    }
58}