1use std::marker::PhantomData;
4
5use cubecl_ir::{ExpandElement, Instruction};
6use paste::paste;
7
8use crate::{
9 ir::{BarrierOps, Item, Scope},
10 unexpanded,
11};
12
13use super::{
14 CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, Init, Line, Slice, SliceMut, TensorMap,
15};
16
/// Kernel-facing handle for a barrier, tagged with the primitive type `C`.
///
/// The struct stores no runtime data (only a `PhantomData<C>` marker). Its
/// methods are placeholders that call `unexpanded!()`; the matching
/// `__expand_*` functions emit `BarrierOps` instructions through a
/// [`BarrierExpand`].
#[derive(Clone, Copy)]
pub struct Barrier<C: CubePrimitive> {
    // Zero-sized marker tying the handle to `C`.
    _c: PhantomData<C>,
}
23
/// Declares [`BarrierExpand`] as the expansion-time representation of [`Barrier`].
impl<C: CubePrimitive> CubeType for Barrier<C> {
    type ExpandType = BarrierExpand<C>;
}
27
/// Initialization is a no-op: the expand value is returned unchanged and the
/// scope is not touched.
impl<C: CubePrimitive> Init for BarrierExpand<C> {
    fn init(self, _scope: &mut Scope) -> Self {
        self
    }
}
33
34impl<C: CubePrimitive> CubeDebug for BarrierExpand<C> {
35 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
36 scope.update_variable_name(*self.elem, name);
37 }
38}
39
/// Expansion-time counterpart of [`Barrier`]; the `__expand_*_method`
/// functions read `elem` to build `BarrierOps` instructions.
#[derive(Clone)]
pub struct BarrierExpand<C: CubePrimitive> {
    // Element identifying the barrier; dereferenced wherever an instruction needs it.
    elem: ExpandElement,
    // Zero-sized marker tying the expand value to `C`.
    _c: PhantomData<C>,
}
46
/// Opaque barrier level passed to [`Barrier::new`] and
/// [`Barrier::new_with_tma_proxy`].
///
/// Wraps the private `InnerBarrierLevel`, so values can only be built through
/// the constructors on this type (`unit`, `cube_coop`, `cube_manual`).
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct BarrierLevel(InnerBarrierLevel);
49
/// A level is its own expansion-time representation.
impl CubeType for BarrierLevel {
    type ExpandType = Self;
}
53
/// Initialization is a no-op: the level is returned unchanged and the scope
/// is unused.
impl Init for BarrierLevel {
    fn init(self, _scope: &mut Scope) -> Self {
        self
    }
}
59
/// Debug naming does nothing for levels: the method body is empty.
impl CubeDebug for BarrierLevel {
    fn set_debug_name(&self, _scope: &mut Scope, _name: &'static str) {}
}
63
/// Private payload of `BarrierLevel`; mirrors the variants of
/// `cubecl_ir::BarrierLevel` one-to-one (see the `From` conversion in this file).
#[derive(Clone, Copy, PartialEq, Eq)]
enum InnerBarrierLevel {
    /// Level without any payload.
    Unit,

    /// Carries the `elected_unit` value given to `BarrierLevel::cube_coop`.
    // NOTE(review): presumably the id of the unit elected to drive the barrier — confirm.
    CubeCoop(u32),

    /// Carries the `elected_unit` value given to `BarrierLevel::cube_manual`.
    // NOTE(review): presumably the id of the unit elected to drive the barrier — confirm.
    CubeManual(u32),
}
85
86impl BarrierLevel {
87 pub fn unit() -> Self {
89 BarrierLevel(InnerBarrierLevel::Unit)
90 }
91
92 pub fn cube_coop(elected_unit: u32) -> Self {
96 BarrierLevel(InnerBarrierLevel::CubeCoop(elected_unit))
97 }
98
99 pub fn cube_manual(elected_unit: u32) -> Self {
103 BarrierLevel(InnerBarrierLevel::CubeManual(elected_unit))
104 }
105
106 pub fn __expand_unit(_scope: &mut Scope) -> BarrierLevel {
107 BarrierLevel(InnerBarrierLevel::Unit)
108 }
109
110 pub fn __expand_cube_coop(_scope: &mut Scope, elected_unit: u32) -> Self {
111 BarrierLevel(InnerBarrierLevel::CubeCoop(elected_unit))
112 }
113
114 pub fn __expand_cube_manual(_scope: &mut Scope, elected_unit: u32) -> Self {
115 BarrierLevel(InnerBarrierLevel::CubeManual(elected_unit))
116 }
117}
118
119impl From<InnerBarrierLevel> for cubecl_ir::BarrierLevel {
120 fn from(val: InnerBarrierLevel) -> Self {
121 match val {
122 InnerBarrierLevel::Unit => cubecl_ir::BarrierLevel::Unit,
123 InnerBarrierLevel::CubeCoop(elected_unit) => {
124 cubecl_ir::BarrierLevel::CubeCoop(elected_unit)
125 }
126 InnerBarrierLevel::CubeManual(elected_unit) => {
127 cubecl_ir::BarrierLevel::CubeManual(elected_unit)
128 }
129 }
130 }
131}
132
// Generates the `tma_load_<dim>d` method family on `Barrier` and `BarrierExpand`.
//
// `$dim` is spliced into the method names; every `$arg` becomes one `i32`
// index parameter, in the order given at the invocation site.
macro_rules! tensor_map_load {
    ($dim: literal, $($arg: expr),*) => {
        paste! {
            impl<C: CubePrimitive> Barrier<C> {
                /// Placeholder for the load from `source` into `destination`
                /// at the given indices; the body only calls `unexpanded!()`.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_ $dim d>](
                    &self,
                    source: &TensorMap<C>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32),*
                ) {
                    unexpanded!()
                }

                /// Expansion entry point: forwards every argument to the
                /// `_method` variant on `expand`.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d>](
                    scope: &mut Scope,
                    expand: BarrierExpand<C>,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: ExpandElementTyped<SliceMut<Line<C>>>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    expand.[<__expand_tma_load_ $dim d_method>](scope, source, destination, $($arg),*);
                }
            }

            impl<C: CubePrimitive> BarrierExpand<C> {
                /// Registers a `BarrierOps::TmaLoad` instruction built from
                /// this barrier, `source` and the indices, with `destination`
                /// as the second argument of `Instruction::new`.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d_method>](
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: ExpandElementTyped<SliceMut<Line<C>>>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    let load_op = BarrierOps::TmaLoad {
                        barrier: *self.elem,
                        tensor_map: *source.expand,
                        // Indices keep the parameter order of the invocation.
                        indices: vec![$(*$arg.expand),*],
                    };

                    scope.register(Instruction::new(load_op, *destination.expand));
                }
            }
        }
    };
}
186
// Generates the `tma_load_im2col_<dim>d` method family on `Barrier` and
// `BarrierExpand`.
//
// `$dim` is spliced into the method names. The identifiers before `;` become
// `i32` index parameters, the identifiers after it become `u16` offset
// parameters, both in the order given at the invocation site.
macro_rules! tensor_map_load_im2col {
    ($dim: literal, $($arg: expr),*; $($offset: expr),*) => {
        paste! {
            impl<C: CubePrimitive> Barrier<C> {
                /// Placeholder for the im2col load from `source` into
                /// `destination`; the body only calls `unexpanded!()`.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_im2col_ $dim d>](
                    &self,
                    source: &TensorMap<C>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32,)*
                    $($offset: u16),*
                ) {
                    unexpanded!()
                }

                /// Expansion entry point: forwards every argument to the
                /// `_method` variant on `expand`.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d>](
                    scope: &mut Scope,
                    expand: BarrierExpand<C>,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: ExpandElementTyped<SliceMut<Line<C>>>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    expand.[<__expand_tma_load_im2col_ $dim d_method>](scope, source, destination, $($arg),*, $($offset),*);
                }
            }

            impl<C: CubePrimitive> BarrierExpand<C> {
                /// Registers a `BarrierOps::TmaLoadIm2col` instruction built
                /// from this barrier, `source`, the indices and the offsets,
                /// with `destination` as the second argument of
                /// `Instruction::new`.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d_method>](
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: ExpandElementTyped<SliceMut<Line<C>>>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    let load_op = BarrierOps::TmaLoadIm2col {
                        barrier: *self.elem,
                        tensor_map: *source.expand,
                        // Both lists keep the parameter order of the invocation.
                        indices: vec![$(*$arg.expand),*],
                        offsets: vec![$(*$offset.expand),*],
                    };

                    scope.register(Instruction::new(load_op, *destination.expand));
                }
            }
        }
    };
}
244
// Instantiate `tma_load_2d` .. `tma_load_5d`; the identifiers name the `i32`
// index parameters of each generated method, in order.
tensor_map_load!(2, y, x);
tensor_map_load!(3, z, y, x);
tensor_map_load!(4, w, z, y, x);
tensor_map_load!(5, v, w, z, y, x);

// Instantiate `tma_load_im2col_3d` .. `tma_load_im2col_5d`; identifiers before
// `;` name the `i32` index parameters, those after it the `u16` offset parameters.
tensor_map_load_im2col!(3, n, w, c; w_offset);
tensor_map_load_im2col!(4, n, h, w, c; h_offset, w_offset);
tensor_map_load_im2col!(5, n, d, h, w, c; d_offset, h_offset, w_offset);
253
254impl<C: CubePrimitive> Barrier<C> {
255 pub fn new(_level: BarrierLevel) -> Self {
257 Self { _c: PhantomData }
258 }
259
260 pub fn new_with_tma_proxy(_level: BarrierLevel) -> Self {
263 Self { _c: PhantomData }
264 }
265
266 pub fn memcpy_async(&self, _source: &Slice<Line<C>>, _destination: &mut SliceMut<Line<C>>) {
273 unexpanded!()
274 }
275
276 pub fn arrive(&self) {
278 unexpanded!()
279 }
280
281 pub fn arrive_tx(&self, _arrival_count: u32, _transaction_count: u32) {
283 unexpanded!()
284 }
285
286 pub fn expect_tx(&self, _expected_count: u32) {
288 unexpanded!()
289 }
290
291 pub fn wait(&self) {
293 unexpanded!()
294 }
295
296 pub fn arrive_and_wait(&self) {
298 unexpanded!()
299 }
300
301 pub fn __expand_new(scope: &mut Scope, level: BarrierLevel) -> BarrierExpand<C> {
302 let elem = C::as_elem(scope);
303
304 let variable = scope.create_barrier(Item::new(elem), level.0.into());
305 scope.register(BarrierOps::Init {
306 barrier: *variable,
307 with_cta_fence: false,
308 });
309 BarrierExpand {
310 elem: variable,
311 _c: PhantomData,
312 }
313 }
314
315 pub fn __expand_new_with_tma_proxy(scope: &mut Scope, level: BarrierLevel) -> BarrierExpand<C> {
316 let elem = C::as_elem(scope);
317
318 let variable = scope.create_barrier(Item::new(elem), level.0.into());
319 scope.register(BarrierOps::Init {
320 barrier: *variable,
321 with_cta_fence: true,
322 });
323 BarrierExpand {
324 elem: variable,
325 _c: PhantomData,
326 }
327 }
328
329 pub fn __expand_memcpy_async(
330 scope: &mut Scope,
331 expand: BarrierExpand<C>,
332 source: ExpandElementTyped<Slice<Line<C>>>,
333 destination: ExpandElementTyped<SliceMut<Line<C>>>,
334 ) {
335 expand.__expand_memcpy_async_method(scope, source, destination);
336 }
337
338 pub fn __expand_arrive(scope: &mut Scope, expand: BarrierExpand<C>) {
339 expand.__expand_arrive_method(scope);
340 }
341
342 pub fn __expand_arrive_tx(
343 scope: &mut Scope,
344 expand: BarrierExpand<C>,
345 arrival_count: ExpandElementTyped<u32>,
346 transaction_count: ExpandElementTyped<u32>,
347 ) {
348 expand.__expand_arrive_tx_method(scope, arrival_count, transaction_count);
349 }
350
351 pub fn __expand_expect_tx(
352 scope: &mut Scope,
353 expand: BarrierExpand<C>,
354 expected_count: ExpandElementTyped<u32>,
355 ) {
356 expand.__expand_expect_tx_method(scope, expected_count);
357 }
358
359 pub fn __expand_wait(scope: &mut Scope, expand: BarrierExpand<C>) {
360 expand.__expand_wait_method(scope);
361 }
362
363 pub fn __expand_arrive_and_wait(scope: &mut Scope, expand: BarrierExpand<C>) {
364 expand.__expand_arrive_and_wait_method(scope);
365 }
366}
367
368impl<C: CubePrimitive> BarrierExpand<C> {
369 pub fn __expand_memcpy_async_method(
370 &self,
371 scope: &mut Scope,
372 source: ExpandElementTyped<Slice<Line<C>>>,
373 destination: ExpandElementTyped<SliceMut<Line<C>>>,
374 ) {
375 let barrier = *self.elem;
376 let source = *source.expand;
377 let destination = *destination.expand;
378
379 let mem_copy = BarrierOps::MemCopyAsync { barrier, source };
380
381 scope.register(Instruction::new(mem_copy, destination));
382 }
383
384 pub fn __expand_arrive_method(&self, scope: &mut Scope) {
385 let barrier = *self.elem;
386 scope.register(BarrierOps::Arrive { barrier });
387 }
388
389 pub fn __expand_arrive_tx_method(
390 &self,
391 scope: &mut Scope,
392 arrival_count: ExpandElementTyped<u32>,
393 transaction_count: ExpandElementTyped<u32>,
394 ) {
395 let barrier = *self.elem;
396 let arrival_count: ExpandElement = arrival_count.into();
397 let transaction_count: ExpandElement = transaction_count.into();
398 scope.register(BarrierOps::ArriveTx {
399 barrier,
400 arrive_count_update: arrival_count.consume(),
401 transaction_count_update: transaction_count.consume(),
402 });
403 }
404
405 pub fn __expand_expect_tx_method(
406 &self,
407 scope: &mut Scope,
408 transaction_count: ExpandElementTyped<u32>,
409 ) {
410 let barrier = *self.elem;
411 let transaction_count: ExpandElement = transaction_count.into();
412 scope.register(BarrierOps::ExpectTx {
413 barrier,
414 transaction_count_update: transaction_count.consume(),
415 });
416 }
417
418 pub fn __expand_wait_method(&self, scope: &mut Scope) {
419 let barrier = *self.elem;
420 scope.register(BarrierOps::Wait { barrier });
421 }
422
423 pub fn __expand_arrive_and_wait_method(&self, scope: &mut Scope) {
424 let barrier = *self.elem;
425 scope.register(BarrierOps::ArriveAndWait { barrier });
426 }
427}