1use crate::{Instruction, TypeHash};
2use alloc::{string::String, vec::Vec};
3use core::fmt::{Display, Write};
4
5use crate::OperationReflect;
6
7use super::Variable;
8
/// Level associated with a barrier.
///
/// NOTE(review): this enum is not referenced in this part of the file. By the variant names it
/// presumably distinguishes a per-unit barrier from a cube-wide one — confirm against its users.
/// The variant order may be significant to the `TypeHash` derive; do not reorder casually.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, Copy)]
pub enum BarrierLevel {
    /// Unit-level barrier (see the NOTE on the enum).
    Unit,
    /// Cube-level barrier (see the NOTE on the enum).
    Cube,
}
15
/// Barrier operations of the IR.
///
/// Every variant carries the `barrier` variable it acts on. A `BarrierOps` value is turned
/// into an [`Instruction`] through `Instruction::no_out` (see the `From<BarrierOps>` impl in
/// this file).
///
/// NOTE(review): the `OperationReflect` / `TypeHash` derives consume this declaration.
/// `#[operation(opcode_name = BarrierOpCode)]` presumably names a generated opcode enum.
/// Treat the variant and field order as significant until confirmed.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationReflect)]
#[operation(opcode_name = BarrierOpCode)]
pub enum BarrierOps {
    /// Declaration of a barrier. Renders as the empty string in the `Display` output.
    Declare {
        /// The barrier being declared.
        barrier: Variable,
    },
    /// Initialization of a barrier.
    ///
    /// Displayed as `init_barrier_tma(barrier, arrival_count)` when `with_async_proxy_fence`
    /// is set, and as `init_barrier(barrier, arrival_count)` otherwise.
    Init {
        /// The barrier to initialize.
        barrier: Variable,
        /// Not shown in the `Display` output. Presumably selects which unit performs the
        /// initialization — TODO confirm.
        is_elected: Variable,
        /// Arrival count the barrier is initialized with (presumably the number of arrivals
        /// that complete a phase — TODO confirm).
        arrival_count: Variable,
        /// Selects the `init_barrier_tma` rendering. By its name, presumably also requests an
        /// async-proxy fence from the backend — confirm against the backends.
        with_async_proxy_fence: bool,
    },
    /// Initialization that carries only an arrival count (no election variable, no fence flag).
    ///
    /// Displayed as `init_barrier(barrier, arrival_count)`, the same text as a non-fenced `Init`.
    InitManual {
        /// The barrier to initialize.
        barrier: Variable,
        /// Arrival count the barrier is initialized with.
        arrival_count: Variable,
    },
    /// Asynchronous memory copy associated with `barrier`.
    ///
    /// Displayed as `out[offset_out] = mem_copy_async(barrier, source: source[offset_source])`.
    /// `source_length` is not displayed.
    MemCopyAsync {
        /// Barrier associated with the copy.
        barrier: Variable,
        /// Variable the data is copied from.
        source: Variable,
        /// Length of the copied region (units not determinable here — TODO confirm).
        source_length: Variable,
        /// Offset into `source` where the copy starts.
        offset_source: Variable,
        /// Offset into the destination (rendered as `out[offset_out]`).
        offset_out: Variable,
    },
    /// Cooperative variant of `MemCopyAsync` with the same fields.
    /// Displayed with the `mem_copy_async_cooperative` name.
    MemCopyAsyncCooperative {
        /// Barrier associated with the copy.
        barrier: Variable,
        /// Variable the data is copied from.
        source: Variable,
        /// Length of the copied region (units not determinable here — TODO confirm).
        source_length: Variable,
        /// Offset into `source` where the copy starts.
        offset_source: Variable,
        /// Offset into the destination (rendered as `out[offset_out]`).
        offset_out: Variable,
    },
    /// Transaction-counted variant of `MemCopyAsync` with the same fields.
    /// Displayed with the `mem_copy_async_tx` name.
    MemCopyAsyncTx {
        /// Barrier associated with the copy.
        barrier: Variable,
        /// Variable the data is copied from.
        source: Variable,
        /// Length of the copied region (units not determinable here — TODO confirm).
        source_length: Variable,
        /// Offset into `source` where the copy starts.
        offset_source: Variable,
        /// Offset into the destination (rendered as `out[offset_out]`).
        offset_out: Variable,
    },
    /// Load through a tensor map. The `Display` output renders `indices.len()` as the rank:
    /// `out[offset_out] = tma_load::<rank>(barrier, tensor_map, ..indices)`.
    TmaLoad {
        /// Barrier associated with the load.
        barrier: Variable,
        /// Tensor map describing the source.
        tensor_map: Variable,
        /// One coordinate per dimension. The length is displayed as the load's rank.
        indices: Vec<Variable>,
        /// Offset into the destination (rendered as `out[offset_out]`).
        offset_out: Variable,
    },
    /// Im2col load through a tensor map. The rank displayed is `indices.len()`.
    TmaLoadIm2col {
        /// Barrier associated with the load.
        barrier: Variable,
        /// Tensor map describing the source.
        tensor_map: Variable,
        /// One coordinate per dimension. The length is displayed as the load's rank.
        indices: Vec<Variable>,
        /// Additional im2col offsets (their count is independent of `indices` here).
        offsets: Vec<Variable>,
        /// Offset into the destination (rendered as `out[offset_out]`).
        offset_out: Variable,
    },
    /// Arrival on `barrier`. Displayed as `arrive(barrier)`.
    Arrive {
        /// The barrier arrived on.
        barrier: Variable,
    },
    /// Arrival that also updates a transaction count.
    /// Displayed as `arrive_tx(barrier, arrive_count_update, transaction_count_update)`.
    ArriveTx {
        /// The barrier arrived on.
        barrier: Variable,
        /// Update applied to the arrival count.
        arrive_count_update: Variable,
        /// Update applied to the transaction count.
        transaction_count_update: Variable,
    },
    /// Transaction-count expectation without an arrival.
    /// Displayed as `expect_tx(barrier, transaction_count_update)`.
    ExpectTx {
        /// The barrier whose expected transaction count is updated.
        barrier: Variable,
        /// Update applied to the transaction count.
        transaction_count_update: Variable,
    },
    /// Wait on `barrier` using `token`. Displayed as `wait(barrier, token)`.
    Wait {
        /// The barrier waited on.
        barrier: Variable,
        /// Token identifying what is waited for (presumably produced by an arrive — TODO confirm).
        token: Variable,
    },
    /// Wait on `barrier` selected by phase parity. Displayed as `wait_parity(barrier, phase)`.
    WaitParity {
        /// The barrier waited on.
        barrier: Variable,
        /// Phase whose parity is waited on.
        phase: Variable,
    },
    /// Combined arrive + wait. Displayed as `arrive_and_wait(barrier)`.
    ArriveAndWait {
        /// The barrier arrived on and waited on.
        barrier: Variable,
    },
}
100
101impl Display for BarrierOps {
102 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
103 match self {
104 BarrierOps::Declare { .. } => Ok(()),
105 BarrierOps::Init {
106 barrier,
107 arrival_count,
108 with_async_proxy_fence,
109 ..
110 } => match with_async_proxy_fence {
111 true => write!(f, "init_barrier_tma({barrier}, {arrival_count})"),
112 false => write!(f, "init_barrier({barrier}, {arrival_count})"),
113 },
114 BarrierOps::InitManual {
115 barrier,
116 arrival_count,
117 } => {
118 write!(f, "init_barrier({barrier}, {arrival_count})")
119 }
120 BarrierOps::MemCopyAsync {
121 barrier,
122 source,
123 offset_source,
124 offset_out,
125 ..
126 } => {
127 write!(
128 f,
129 "out[{offset_out}] = mem_copy_async({barrier}, source: {source}[{offset_source}])",
130 )
131 }
132 BarrierOps::MemCopyAsyncCooperative {
133 barrier,
134 source,
135 offset_source,
136 offset_out,
137 ..
138 } => {
139 write!(
140 f,
141 "out[{offset_out}] = mem_copy_async_cooperative({barrier}, source: {source}[{offset_source}])",
142 )
143 }
144 BarrierOps::MemCopyAsyncTx {
145 barrier,
146 source,
147 offset_source,
148 offset_out,
149 ..
150 } => {
151 write!(
152 f,
153 "out[{offset_out}] = mem_copy_async_tx({barrier}, source: {source}[{offset_source}])",
154 )
155 }
156 BarrierOps::ArriveAndWait { barrier } => write!(f, "arrive_and_wait({barrier})"),
157 BarrierOps::TmaLoad {
158 barrier,
159 tensor_map,
160 offset_out,
161 indices,
162 } => {
163 let rank = indices.len();
164 let indices = indices.iter().fold(String::new(), |mut s, it| {
165 let _ = write!(s, "{it}, ");
166 s
167 });
168 write!(
169 f,
170 "out[{offset_out}] = tma_load::<{rank}>({barrier}, {tensor_map}, {indices})"
171 )
172 }
173 BarrierOps::TmaLoadIm2col {
174 barrier,
175 tensor_map,
176 indices,
177 offsets,
178 offset_out,
179 } => {
180 let rank = indices.len();
181 let indices = indices.iter().fold(String::new(), |mut s, it| {
182 let _ = write!(s, "{it}, ");
183 s
184 });
185 let offsets = offsets.iter().fold(String::new(), |mut s, it| {
186 let _ = write!(s, "{it}, ");
187 s
188 });
189 write!(
190 f,
191 "out[{offset_out}] = tma_load_im2col::<{rank}>({barrier}, {tensor_map}, indices: ({indices}), offsets: ({offsets}))"
192 )
193 }
194 BarrierOps::Arrive { barrier } => write!(f, "arrive({barrier})"),
195 BarrierOps::ArriveTx {
196 barrier,
197 arrive_count_update,
198 transaction_count_update,
199 } => write!(
200 f,
201 "arrive_tx({barrier}, {arrive_count_update}, {transaction_count_update})"
202 ),
203 BarrierOps::ExpectTx {
204 barrier,
205 transaction_count_update,
206 } => write!(f, "expect_tx({barrier}, {transaction_count_update})"),
207 BarrierOps::Wait { barrier, token } => write!(f, "wait({barrier}, {token})"),
208 BarrierOps::WaitParity { barrier, phase } => {
209 write!(f, "wait_parity({barrier}, {phase})")
210 }
211 }
212 }
213}
214
215impl From<BarrierOps> for Instruction {
216 fn from(value: BarrierOps) -> Self {
217 Instruction::no_out(value)
218 }
219}