1use crate::{Instruction, TypeHash};
2use alloc::{string::String, vec::Vec};
3use core::fmt::{Display, Write};
4
5use crate::OperationReflect;
6
7use super::Variable;
8
/// Level at which a barrier operates.
///
/// NOTE(review): the semantics of the variants are not visible in this file; the
/// per-variant notes below are inferred from the names only and must be confirmed
/// against the code that consumes this enum.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, Copy)]
pub enum BarrierLevel {
    /// Presumably a barrier scoped to a single unit — TODO confirm.
    Unit,
    /// Presumably a cube-wide barrier used cooperatively. The meaning of the `u32`
    /// payload is not shown here — TODO confirm.
    CubeCoop(u32),
    /// Presumably a cube-wide barrier driven manually. The meaning of the `u32`
    /// payload is not shown here — TODO confirm.
    CubeManual(u32),
}
16
/// Operations that act on a barrier.
///
/// Every variant carries the `barrier` variable it targets. A `BarrierOps` value is
/// turned into an [`Instruction`] through `Instruction::no_out`, and its `Display`
/// implementation renders each variant as one line of pseudo-code.
///
/// NOTE(review): the precise runtime semantics of each operation are defined by the
/// backends, not by this file; the notes marked "presumably" are inferred from names
/// and should be confirmed.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationReflect)]
#[operation(opcode_name = BarrierOpCode)]
pub enum BarrierOps {
    /// Barrier initialization.
    ///
    /// Displayed as `init_barrier_tma(..)` when `with_cta_fence` is `true` and as
    /// `init_barrier(..)` otherwise.
    Init {
        barrier: Variable,
        // NOTE(review): presumably requests an additional fence needed for TMA
        // usage — confirm against the backend.
        with_cta_fence: bool,
    },
    /// Asynchronous memory copy tied to `barrier`.
    ///
    /// Displayed as `out[offset_out] = mem_copy_async(barrier, source: source[offset_source])`.
    MemCopyAsync {
        barrier: Variable,
        source: Variable,
        // Not included in the `Display` output.
        source_length: Variable,
        offset_source: Variable,
        offset_out: Variable,
    },
    /// Tensor-map load tied to `barrier`.
    ///
    /// Displayed as `tma_load::<N>(..)`, where `N` is `indices.len()`.
    TmaLoad {
        barrier: Variable,
        tensor_map: Variable,
        indices: Vec<Variable>,
        offset_out: Variable,
    },
    /// Im2col flavour of the tensor-map load, taking `offsets` in addition to
    /// `indices`.
    ///
    /// Displayed as `tma_load_im2col::<N>(..)`, where `N` is `indices.len()`.
    TmaLoadIm2col {
        barrier: Variable,
        tensor_map: Variable,
        indices: Vec<Variable>,
        offsets: Vec<Variable>,
        offset_out: Variable,
    },
    /// Arrival on the barrier; displayed as `arrive(barrier)`.
    Arrive {
        barrier: Variable,
    },
    /// Arrival that also carries an arrive-count update and a transaction-count
    /// update; displayed as `arrive_tx(..)`.
    ArriveTx {
        barrier: Variable,
        arrive_count_update: Variable,
        transaction_count_update: Variable,
    },
    /// Transaction-count update without an arrival; displayed as `expect_tx(..)`.
    ///
    /// NOTE(review): "without an arrival" is inferred from the absence of an
    /// arrive-count field — confirm.
    ExpectTx {
        barrier: Variable,
        transaction_count_update: Variable,
    },
    /// Wait on the barrier; displayed as `wait(barrier)`.
    Wait {
        barrier: Variable,
    },
    /// Combined arrive-then-wait on the barrier; displayed as
    /// `arrive_and_wait(barrier)`.
    ArriveAndWait {
        barrier: Variable,
    },
}
69
70impl Display for BarrierOps {
71 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
72 match self {
73 BarrierOps::Init {
74 barrier,
75 with_cta_fence,
76 } => match with_cta_fence {
77 true => write!(f, "init_barrier_tma({barrier})"),
78 false => write!(f, "init_barrier({barrier})"),
79 },
80 BarrierOps::MemCopyAsync {
81 barrier,
82 source,
83 offset_source,
84 offset_out,
85 ..
86 } => {
87 write!(
88 f,
89 "out[{offset_out}] = mem_copy_async({barrier}, source: {source}[{offset_source}])",
90 )
91 }
92 BarrierOps::ArriveAndWait { barrier } => write!(f, "arrive_and_wait({barrier})"),
93 BarrierOps::TmaLoad {
94 barrier,
95 tensor_map,
96 offset_out,
97 indices,
98 } => {
99 let rank = indices.len();
100 let indices = indices.iter().fold(String::new(), |mut s, it| {
101 let _ = write!(s, "{it}, ");
102 s
103 });
104 write!(
105 f,
106 "out[{offset_out}] = tma_load::<{rank}>({barrier}, {tensor_map}, {indices})"
107 )
108 }
109 BarrierOps::TmaLoadIm2col {
110 barrier,
111 tensor_map,
112 indices,
113 offsets,
114 offset_out,
115 } => {
116 let rank = indices.len();
117 let indices = indices.iter().fold(String::new(), |mut s, it| {
118 let _ = write!(s, "{it}, ");
119 s
120 });
121 let offsets = offsets.iter().fold(String::new(), |mut s, it| {
122 let _ = write!(s, "{it}, ");
123 s
124 });
125 write!(
126 f,
127 "out[{offset_out}] = tma_load_im2col::<{rank}>({barrier}, {tensor_map}, indices: ({indices}), offsets: ({offsets}))"
128 )
129 }
130 BarrierOps::Arrive { barrier } => write!(f, "arrive({barrier})"),
131 BarrierOps::ArriveTx {
132 barrier,
133 arrive_count_update,
134 transaction_count_update,
135 } => write!(
136 f,
137 "arrive_tx({barrier}, {arrive_count_update}, {transaction_count_update})"
138 ),
139 BarrierOps::ExpectTx {
140 barrier,
141 transaction_count_update,
142 } => write!(f, "expect_tx({barrier}, {transaction_count_update})"),
143 BarrierOps::Wait { barrier } => write!(f, "wait({barrier})"),
144 }
145 }
146}
147
148impl From<BarrierOps> for Instruction {
149 fn from(value: BarrierOps) -> Self {
150 Instruction::no_out(value)
151 }
152}