1use crate::{Instruction, TypeHash};
2use alloc::{format, string::String, vec::Vec};
3use core::fmt::{Display, Write};
4
5use crate::OperationReflect;
6
7use super::Variable;
8
/// Level of a barrier: `Unit` or `Cube`.
///
/// NOTE(review): presumably this is the scope the barrier synchronizes over
/// (a single unit vs. the whole cube) — confirm against the backends.
///
/// Because `PartialOrd`/`Ord` are derived, variants compare in declaration
/// order (`Unit < Cube`); keep that order if the comparison is relied upon.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, Copy, PartialOrd, Ord)]
pub enum BarrierLevel {
    Unit,
    Cube,
}
15
/// Barrier and asynchronous-copy operations of the IR.
///
/// NOTE(review): `TypeHash` and `OperationReflect` are derived from this
/// definition (`opcode_name = BarrierOpCode` presumably names the generated
/// opcode type); reordering variants or fields may change what those derives
/// produce — confirm before reordering.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationReflect)]
#[operation(opcode_name = BarrierOpCode)]
pub enum BarrierOps {
    /// Declares `barrier`. Its `Display` rendering is empty.
    Declare {
        barrier: Variable,
    },
    /// Initializes `barrier` with `arrival_count`, rendered as
    /// `{barrier}.init_barrier({arrival_count})`.
    ///
    /// `is_elected` is not included in the `Display` output.
    /// NOTE(review): presumably it selects which unit performs the
    /// initialization — confirm.
    Init {
        barrier: Variable,
        is_elected: Variable,
        arrival_count: Variable,
    },
    /// Like `Init`, but without an `is_elected` operand. Renders identically
    /// to `Init`: `{barrier}.init_barrier({arrival_count})`.
    InitManual {
        barrier: Variable,
        arrival_count: Variable,
    },
    /// Asynchronous memory copy from `source` (starting at `offset_source`)
    /// into the output at `offset_out`, associated with `barrier`.
    /// Rendered as `out[..] = mem_copy_async(..)`; `source_length` is not shown.
    MemCopyAsync {
        barrier: Variable,
        source: Variable,
        source_length: Variable,
        offset_source: Variable,
        offset_out: Variable,
    },
    /// Same operands as `MemCopyAsync`, rendered as
    /// `mem_copy_async_cooperative`. NOTE(review): presumably performed
    /// cooperatively by several units — confirm.
    MemCopyAsyncCooperative {
        barrier: Variable,
        source: Variable,
        source_length: Variable,
        offset_source: Variable,
        offset_out: Variable,
    },
    /// Same operands as `MemCopyAsync`, rendered as `mem_copy_async_tx`.
    /// NOTE(review): presumably the transaction-count-tracked flavour — confirm.
    MemCopyAsyncTx {
        barrier: Variable,
        source: Variable,
        source_length: Variable,
        offset_source: Variable,
        offset_out: Variable,
    },
    /// Asynchronous copy of `copy_length` bytes (rendered as
    /// `bytes: {copy_length}`) from `source` at `offset_source` into the output
    /// at `offset_out`. Unlike the `MemCopyAsync*` variants, it carries no
    /// `barrier` operand.
    ///
    /// When `checked` is true, `Display` shows the source as the bounded slice
    /// `[offset_source..][..source_length]`; otherwise as `[offset_source]`.
    /// NOTE(review): presumably `checked` enables bounds checking against
    /// `source_length` — confirm.
    CopyAsync {
        source: Variable,
        source_length: Variable,
        offset_source: Variable,
        offset_out: Variable,
        copy_length: u32,
        checked: bool,
    },
    /// TMA load from `tensor_map` at the coordinates in `indices` into the
    /// output at `offset_out`, associated with `barrier`. `Display` shows the
    /// rank as `indices.len()`.
    TmaLoad {
        barrier: Variable,
        tensor_map: Variable,
        indices: Vec<Variable>,
        offset_out: Variable,
    },
    /// Im2col flavour of `TmaLoad`, taking an additional list of `offsets`.
    /// `Display` shows the rank as `indices.len()`.
    TmaLoadIm2col {
        barrier: Variable,
        tensor_map: Variable,
        indices: Vec<Variable>,
        offsets: Vec<Variable>,
        offset_out: Variable,
    },
    /// Arrives on `barrier`; rendered as `arrive({barrier})`.
    Arrive {
        barrier: Variable,
    },
    /// Arrives on `barrier` with an arrive-count update and a transaction-count
    /// update; rendered as `arrive_tx(..)`.
    ArriveTx {
        barrier: Variable,
        arrive_count_update: Variable,
        transaction_count_update: Variable,
    },
    /// Rendered as `commit_copy_async({barrier})`. NOTE(review): presumably
    /// ties previously issued `CopyAsync` operations to `barrier` — confirm.
    CommitCopyAsync {
        barrier: Variable,
    },
    /// Adds `transaction_count_update` to the transactions expected by
    /// `barrier`; rendered as `expect_tx(..)`.
    ExpectTx {
        barrier: Variable,
        transaction_count_update: Variable,
    },
    /// Waits on `barrier` using `token`; rendered as `wait({barrier}, {token})`.
    /// NOTE(review): presumably `token` comes from a prior arrive — confirm.
    Wait {
        barrier: Variable,
        token: Variable,
    },
    /// Waits on `barrier` for the given `phase` parity; rendered as
    /// `wait_parity({barrier}, {phase})`.
    WaitParity {
        barrier: Variable,
        phase: Variable,
    },
    /// Combined arrive and wait on `barrier`; rendered as
    /// `arrive_and_wait({barrier})`.
    ArriveAndWait {
        barrier: Variable,
    },
}
111
112impl Display for BarrierOps {
113 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
114 match self {
115 BarrierOps::Declare { .. } => Ok(()),
116 BarrierOps::Init {
117 barrier,
118 arrival_count,
119 ..
120 } => write!(f, "{barrier}.init_barrier({arrival_count})"),
121 BarrierOps::InitManual {
122 barrier,
123 arrival_count,
124 } => write!(f, "{barrier}.init_barrier({arrival_count})"),
125 BarrierOps::MemCopyAsync {
126 barrier,
127 source,
128 offset_source,
129 offset_out,
130 ..
131 } => {
132 write!(
133 f,
134 "out[{offset_out}] = mem_copy_async({barrier}, source: {source}[{offset_source}])",
135 )
136 }
137 BarrierOps::MemCopyAsyncCooperative {
138 barrier,
139 source,
140 offset_source,
141 offset_out,
142 ..
143 } => {
144 write!(
145 f,
146 "out[{offset_out}] = mem_copy_async_cooperative({barrier}, source: {source}[{offset_source}])",
147 )
148 }
149 BarrierOps::MemCopyAsyncTx {
150 barrier,
151 source,
152 offset_source,
153 offset_out,
154 ..
155 } => {
156 write!(
157 f,
158 "out[{offset_out}] = mem_copy_async_tx({barrier}, source: {source}[{offset_source}])",
159 )
160 }
161 BarrierOps::CopyAsync {
162 source,
163 source_length,
164 offset_source,
165 offset_out,
166 copy_length,
167 checked,
168 } => {
169 let source_slice = if *checked {
170 format!("[{offset_source}..][..{source_length}]")
171 } else {
172 format!("[{offset_source}]")
173 };
174 write!(
175 f,
176 "out[{offset_out}] = copy_async(source: {source}{source_slice}, bytes: {copy_length})",
177 )
178 }
179 BarrierOps::ArriveAndWait { barrier } => write!(f, "arrive_and_wait({barrier})"),
180 BarrierOps::TmaLoad {
181 barrier,
182 tensor_map,
183 offset_out,
184 indices,
185 } => {
186 let rank = indices.len();
187 let indices = indices.iter().fold(String::new(), |mut s, it| {
188 let _ = write!(s, "{it}, ");
189 s
190 });
191 write!(
192 f,
193 "out[{offset_out}] = tma_load::<{rank}>({barrier}, {tensor_map}, {indices})"
194 )
195 }
196 BarrierOps::TmaLoadIm2col {
197 barrier,
198 tensor_map,
199 indices,
200 offsets,
201 offset_out,
202 } => {
203 let rank = indices.len();
204 let indices = indices.iter().fold(String::new(), |mut s, it| {
205 let _ = write!(s, "{it}, ");
206 s
207 });
208 let offsets = offsets.iter().fold(String::new(), |mut s, it| {
209 let _ = write!(s, "{it}, ");
210 s
211 });
212 write!(
213 f,
214 "out[{offset_out}] = tma_load_im2col::<{rank}>({barrier}, {tensor_map}, indices: ({indices}), offsets: ({offsets}))"
215 )
216 }
217 BarrierOps::Arrive { barrier } => write!(f, "arrive({barrier})"),
218 BarrierOps::CommitCopyAsync { barrier } => write!(f, "commit_copy_async({barrier})"),
219 BarrierOps::ArriveTx {
220 barrier,
221 arrive_count_update,
222 transaction_count_update,
223 } => write!(
224 f,
225 "arrive_tx({barrier}, {arrive_count_update}, {transaction_count_update})"
226 ),
227 BarrierOps::ExpectTx {
228 barrier,
229 transaction_count_update,
230 } => write!(f, "expect_tx({barrier}, {transaction_count_update})"),
231 BarrierOps::Wait { barrier, token } => write!(f, "wait({barrier}, {token})"),
232 BarrierOps::WaitParity { barrier, phase } => {
233 write!(f, "wait_parity({barrier}, {phase})")
234 }
235 }
236 }
237}
238
239impl Display for BarrierLevel {
240 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
241 match self {
242 BarrierLevel::Unit => f.write_str("unit"),
243 BarrierLevel::Cube => f.write_str("cube"),
244 }
245 }
246}
247
248impl From<BarrierOps> for Instruction {
249 fn from(value: BarrierOps) -> Self {
250 Instruction::no_out(value)
251 }
252}