1use cubecl_ir::{ExpandElement, Instruction, Variable, VariableKind};
4use paste::paste;
5
6use crate::{
7 ir::{BarrierOps, Scope},
8 prelude::{CUBE_DIM, ExpandElementIntoMut},
9 unexpanded,
10};
11
12use super::{
13 CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, IntoMut, Line, ReadOnly, ReadWrite,
14 Slice, SliceExpand, SliceMut, TensorMap,
15};
16
/// Frontend handle to a barrier.
///
/// Only used from `#[cube]` code: every operation on it is an `unexpanded!()` stub whose
/// expand-time counterpart lives on [`BarrierExpand`].
#[derive(Clone, Copy, Debug)]
pub struct Barrier;
21
/// Token produced by arriving on a barrier and later handed back to wait on it.
///
/// A unit struct, so equality is trivially total and `Eq` is derived alongside `PartialEq`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct BarrierToken;
24
25impl CubeType for Barrier {
26 type ExpandType = BarrierExpand;
27}
28
29impl IntoMut for BarrierExpand {
30 fn into_mut(self, _scope: &mut Scope) -> Self {
31 self
32 }
33}
34
35impl CubeDebug for BarrierExpand {
36 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
37 scope.update_variable_name(*self.elem, name);
38 }
39}
40
41impl CubeType for BarrierToken {
42 type ExpandType = ExpandElementTyped<BarrierToken>;
43}
44
45impl ExpandElementIntoMut for BarrierToken {
46 fn elem_into_mut(_scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
47 elem
48 }
49}
50
51#[derive(Clone)]
52pub struct BarrierExpand {
54 elem: ExpandElement,
55}
56
57#[derive(Clone)]
58pub struct BarrierLevel(InnerBarrierLevel);
59
60impl CubeType for BarrierLevel {
61 type ExpandType = Self;
62}
63
64impl IntoMut for BarrierLevel {
65 fn into_mut(self, _scope: &mut Scope) -> Self {
66 self
67 }
68}
69
70impl CubeDebug for BarrierLevel {
71 fn set_debug_name(&self, _scope: &mut Scope, _name: &'static str) {}
72}
73
/// Internal representation of a [`BarrierLevel`].
///
/// The `ExpandElement` payloads are the expand-time values captured by the
/// `BarrierLevel::__expand_*` constructors.
#[derive(Clone)]
enum InnerBarrierLevel {
    /// Unit-scoped: arrival count `1`, always elected; lowers to the IR `Unit` level.
    Unit,

    /// Cube-scoped with an arrival count of `1`. Payload: the `is_elected` flag.
    CubeUnit(ExpandElement),

    /// Cube-scoped with an arrival count of `CUBE_DIM`. Payload: the `is_elected` flag.
    CubeFull(ExpandElement),

    /// Cube-scoped with an explicit elected flag and arrival count.
    CubeCustom {
        /// Forwarded as `is_elected` of the barrier's `Init` op.
        is_elected: ExpandElement,
        /// Forwarded as `arrival_count` of the barrier's `Init` op.
        arrival_count: ExpandElement,
    },

    /// Cube-scoped barrier that is only declared on creation (presumably initialized later
    /// via `Barrier::init_manual`). Its arrival count and elected flag cannot be queried.
    CubeManual,
}
107
108impl BarrierLevel {
109 pub fn unit() -> Self {
111 BarrierLevel(InnerBarrierLevel::Unit)
112 }
113
114 pub fn cube_unit(_is_elected: bool) -> Self {
119 unexpanded!()
120 }
121
122 pub fn cube_full(_is_elected: bool) -> Self {
126 unexpanded!()
127 }
128
129 pub fn cube_custom(_arrival_count: u32) -> Self {
133 unexpanded!()
134 }
135
136 pub fn cube_manual() -> Self {
139 unexpanded!()
140 }
141
142 fn arrival_count(&self, scope: &mut Scope) -> Variable {
143 match &self.0 {
144 InnerBarrierLevel::Unit | InnerBarrierLevel::CubeUnit(_) => 1.into(),
145 InnerBarrierLevel::CubeFull(_) => *CUBE_DIM::expand(scope).expand,
146 InnerBarrierLevel::CubeCustom { arrival_count, .. } => **arrival_count,
147 InnerBarrierLevel::CubeManual => panic!("Can't get arrival count of manual barrier"),
148 }
149 }
150
151 fn is_elected(&self) -> Variable {
152 match &self.0 {
153 InnerBarrierLevel::Unit => true.into(),
154 InnerBarrierLevel::CubeUnit(is_elected)
155 | InnerBarrierLevel::CubeFull(is_elected)
156 | InnerBarrierLevel::CubeCustom { is_elected, .. } => **is_elected,
157 InnerBarrierLevel::CubeManual => panic!("Can't get `is_elected` of manual barrier"),
158 }
159 }
160
161 pub fn __expand_unit(_scope: &mut Scope) -> BarrierLevel {
162 BarrierLevel(InnerBarrierLevel::Unit)
163 }
164
165 pub fn __expand_cube_unit(_scope: &mut Scope, is_elected: ExpandElementTyped<bool>) -> Self {
166 BarrierLevel(InnerBarrierLevel::CubeUnit(is_elected.expand))
167 }
168
169 pub fn __expand_cube_full(_scope: &mut Scope, is_elected: ExpandElementTyped<bool>) -> Self {
170 BarrierLevel(InnerBarrierLevel::CubeFull(is_elected.expand))
171 }
172
173 pub fn __expand_cube_custom(
174 _scope: &mut Scope,
175 is_elected: ExpandElementTyped<bool>,
176 arrival_count: ExpandElementTyped<u32>,
177 ) -> Self {
178 BarrierLevel(InnerBarrierLevel::CubeCustom {
179 is_elected: is_elected.expand,
180 arrival_count: arrival_count.expand,
181 })
182 }
183
184 pub fn __expand_cube_manual(_scope: &mut Scope) -> Self {
185 BarrierLevel(InnerBarrierLevel::CubeManual)
186 }
187}
188
189impl From<InnerBarrierLevel> for cubecl_ir::BarrierLevel {
190 fn from(val: InnerBarrierLevel) -> Self {
191 match val {
192 InnerBarrierLevel::Unit => cubecl_ir::BarrierLevel::Unit,
193 InnerBarrierLevel::CubeUnit(_)
194 | InnerBarrierLevel::CubeFull(_)
195 | InnerBarrierLevel::CubeCustom { .. }
196 | InnerBarrierLevel::CubeManual => cubecl_ir::BarrierLevel::Cube,
197 }
198 }
199}
200
/// Generates the `tma_load_{dim}d` family for a `dim`-dimensional tensor map:
///
/// - `Barrier::tma_load_{dim}d`: the stub called from `#[cube]` code,
/// - `Barrier::__expand_tma_load_{dim}d`: forwards to the method form,
/// - `BarrierExpand::__expand_tma_load_{dim}d_method`: registers `BarrierOps::TmaLoad`.
///
/// Each `$arg` names one `i32` coordinate; coordinates are placed in `indices` in the order
/// they are listed.
macro_rules! tensor_map_load {
    // The coordinate names are bound as function parameters, so they are captured as `ident`
    // (valid both as a parameter pattern and as an expression) rather than `expr`.
    ($dim: literal, $($arg: ident),*) => {
        paste! {
            impl Barrier {
                /// Copies a tile of `source` into `destination` through this barrier, at the
                /// given coordinates.
                // `unused`: the parameters are never read, the stub body never runs.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_ $dim d>]<C: CubePrimitive>(
                    &self,
                    source: &TensorMap<C>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32),*
                ) {
                    unexpanded!()
                }

                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d>]<C: CubePrimitive>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    expand.[<__expand_tma_load_ $dim d_method>](scope, source, destination, $($arg),*);
                }
            }

            impl BarrierExpand {
                /// Registers a `BarrierOps::TmaLoad` from `source` into `destination` (at the
                /// slice's offset), with the coordinates as `indices`.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d_method>]<C: CubePrimitive>(
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    let barrier = *self.elem;
                    let source = *source.expand;
                    let (destination, destination_offset) = destination.__to_raw_parts();

                    let mem_copy = BarrierOps::TmaLoad {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offset_out: destination_offset,
                    };

                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}
255
/// Generates the `tma_load_im2col_{dim}d` family for a `dim`-dimensional tensor map:
///
/// - `Barrier::tma_load_im2col_{dim}d`: the stub called from `#[cube]` code,
/// - `Barrier::__expand_tma_load_im2col_{dim}d`: forwards to the method form,
/// - `BarrierExpand::__expand_tma_load_im2col_{dim}d_method`: registers
///   `BarrierOps::TmaLoadIm2col`.
///
/// Names before `;` are `i32` coordinates (placed in `indices`), names after `;` are `u16`
/// offsets (placed in `offsets`), each in the order listed.
macro_rules! tensor_map_load_im2col {
    // Names are bound as function parameters, so they are captured as `ident` rather than
    // `expr`.
    ($dim: literal, $($arg: ident),*; $($offset: ident),*) => {
        paste! {
            impl Barrier {
                /// Copies an im2col tile of `source` into `destination` through this barrier,
                /// at the given coordinates and offsets.
                // `unused`: the parameters are never read, the stub body never runs.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_im2col_ $dim d>]<C: CubePrimitive>(
                    &self,
                    source: &TensorMap<C>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32,)*
                    $($offset: u16),*
                ) {
                    unexpanded!()
                }

                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d>]<C: CubePrimitive>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    expand.[<__expand_tma_load_im2col_ $dim d_method>](scope, source, destination, $($arg),*, $($offset),*);
                }
            }

            impl BarrierExpand {
                /// Registers a `BarrierOps::TmaLoadIm2col` from `source` into `destination`
                /// (at the slice's offset), with the coordinates as `indices` and the
                /// offsets as `offsets`.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d_method>]<C: CubePrimitive>(
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    let barrier = *self.elem;
                    let source = *source.expand;
                    let (destination, destination_offset) = destination.__to_raw_parts();

                    let mem_copy = BarrierOps::TmaLoadIm2col {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offsets: vec![$(*$offset.expand),*],
                        offset_out: destination_offset,
                    };

                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}
314
// Tiled TMA loads for 1D..5D tensor maps. The coordinate names are forwarded, in the order
// listed, as the `indices` of `BarrierOps::TmaLoad` (`x` is always last).
tensor_map_load!(1, x);
tensor_map_load!(2, y, x);
tensor_map_load!(3, z, y, x);
tensor_map_load!(4, w, z, y, x);
tensor_map_load!(5, v, w, z, y, x);

// im2col TMA loads for 3D..5D tensor maps. Names before `;` become `i32` coordinates
// (`indices`); names after `;` become `u16` offsets (`offsets`) of `BarrierOps::TmaLoadIm2col`.
tensor_map_load_im2col!(3, n, w, c; w_offset);
tensor_map_load_im2col!(4, n, h, w, c; h_offset, w_offset);
tensor_map_load_im2col!(5, n, d, h, w, c; d_offset, h_offset, w_offset);
324
325impl Barrier {
326 pub fn new(_level: BarrierLevel) -> Self {
328 Self
329 }
330
331 pub fn new_with_async_proxy_fence(_level: BarrierLevel) -> Self {
334 Self
335 }
336
337 pub fn init_manual(&self, _arrival_count: u32) -> BarrierToken {
339 unexpanded!()
340 }
341
342 pub fn memcpy_async<C: CubePrimitive>(
349 &self,
350 _source: &Slice<Line<C>>,
351 _destination: &mut SliceMut<Line<C>>,
352 ) {
353 unexpanded!()
354 }
355
356 pub fn memcpy_async_cooperative<C: CubePrimitive>(
363 &self,
364 _source: &Slice<Line<C>>,
365 _destination: &mut SliceMut<Line<C>>,
366 ) {
367 unexpanded!()
368 }
369
370 pub fn memcpy_async_tx<C: CubePrimitive>(
378 &self,
379 _source: &Slice<Line<C>>,
380 _destination: &mut SliceMut<Line<C>>,
381 ) {
382 unexpanded!()
383 }
384
385 pub fn arrive(&self) -> BarrierToken {
387 unexpanded!()
388 }
389
390 pub fn arrive_and_expect_tx(
392 &self,
393 _arrival_count: u32,
394 _transaction_count: u32,
395 ) -> BarrierToken {
396 unexpanded!()
397 }
398
399 pub fn expect_tx(&self, _expected_count: u32) {
401 unexpanded!()
402 }
403
404 pub fn wait(&self, _token: BarrierToken) {
406 unexpanded!()
407 }
408
409 pub fn wait_parity(&self, _phase: u32) {
412 unexpanded!()
413 }
414
415 pub fn arrive_and_wait(&self) {
417 unexpanded!()
418 }
419
420 pub fn __expand_new(scope: &mut Scope, level: BarrierLevel) -> BarrierExpand {
421 let variable = scope.create_barrier(level.0.clone().into());
422 match &level.0 {
423 InnerBarrierLevel::CubeManual => {
424 scope.register(BarrierOps::Declare { barrier: *variable });
425 }
426 _ => {
427 let is_elected = level.is_elected();
428 let arrival_count = level.arrival_count(scope);
429 scope.register(BarrierOps::Init {
430 barrier: *variable,
431 is_elected,
432 arrival_count,
433 with_async_proxy_fence: false,
434 });
435 }
436 }
437
438 BarrierExpand { elem: variable }
439 }
440
441 pub fn __expand_new_with_async_proxy_fence(
442 scope: &mut Scope,
443 level: BarrierLevel,
444 ) -> BarrierExpand {
445 let is_elected = level.is_elected();
446 let arrival_count = level.arrival_count(scope);
447 let variable = scope.create_barrier(level.0.clone().into());
448 scope.register(BarrierOps::Init {
449 barrier: *variable,
450 is_elected,
451 arrival_count,
452 with_async_proxy_fence: true,
453 });
454 BarrierExpand { elem: variable }
455 }
456
457 pub fn __expand_init_manual(
458 scope: &mut Scope,
459 expand: BarrierExpand,
460 arrival_count: ExpandElementTyped<u32>,
461 ) {
462 expand.__expand_init_manual_method(scope, arrival_count);
463 }
464
465 pub fn __expand_memcpy_async<C: CubePrimitive>(
466 scope: &mut Scope,
467 expand: BarrierExpand,
468 source: SliceExpand<Line<C>, ReadOnly>,
469 destination: SliceExpand<Line<C>, ReadWrite>,
470 ) {
471 expand.__expand_memcpy_async_method(scope, source, destination);
472 }
473
474 pub fn __expand_memcpy_async_cooperative<C: CubePrimitive>(
475 scope: &mut Scope,
476 expand: BarrierExpand,
477 source: SliceExpand<Line<C>, ReadOnly>,
478 destination: SliceExpand<Line<C>, ReadWrite>,
479 ) {
480 expand.__expand_memcpy_async_method(scope, source, destination);
481 }
482
483 pub fn __expand_memcpy_async_tx<C: CubePrimitive>(
484 scope: &mut Scope,
485 expand: BarrierExpand,
486 source: SliceExpand<Line<C>, ReadOnly>,
487 destination: SliceExpand<Line<C>, ReadWrite>,
488 ) {
489 expand.__expand_memcpy_async_tx_method(scope, source, destination);
490 }
491
492 pub fn __expand_arrive(
493 scope: &mut Scope,
494 expand: BarrierExpand,
495 ) -> ExpandElementTyped<BarrierToken> {
496 expand.__expand_arrive_method(scope)
497 }
498
499 pub fn __expand_arrive_and_expect_tx(
500 scope: &mut Scope,
501 expand: BarrierExpand,
502 arrival_count: ExpandElementTyped<u32>,
503 transaction_count: ExpandElementTyped<u32>,
504 ) -> ExpandElementTyped<BarrierToken> {
505 expand.__expand_arrive_and_expect_tx_method(scope, arrival_count, transaction_count)
506 }
507
508 pub fn __expand_expect_tx(
509 scope: &mut Scope,
510 expand: BarrierExpand,
511 expected_count: ExpandElementTyped<u32>,
512 ) {
513 expand.__expand_expect_tx_method(scope, expected_count);
514 }
515
516 pub fn __expand_wait(
517 scope: &mut Scope,
518 expand: BarrierExpand,
519 token: ExpandElementTyped<BarrierToken>,
520 ) {
521 expand.__expand_wait_method(scope, token);
522 }
523
524 pub fn __expand_wait_parity(
525 scope: &mut Scope,
526 expand: BarrierExpand,
527 phase: ExpandElementTyped<u32>,
528 ) {
529 expand.__expand_wait_parity_method(scope, phase);
530 }
531
532 pub fn __expand_arrive_and_wait(scope: &mut Scope, expand: BarrierExpand) {
533 expand.__expand_arrive_and_wait_method(scope);
534 }
535}
536
537impl BarrierExpand {
538 pub fn __expand_init_manual_method(
539 &self,
540 scope: &mut Scope,
541 arrival_count: ExpandElementTyped<u32>,
542 ) {
543 let barrier = *self.elem;
544
545 scope.register(BarrierOps::InitManual {
546 barrier,
547 arrival_count: *arrival_count.expand,
548 });
549 }
550
551 pub fn __expand_memcpy_async_method<C: CubePrimitive>(
552 &self,
553 scope: &mut Scope,
554 source: SliceExpand<Line<C>, ReadOnly>,
555 destination: SliceExpand<Line<C>, ReadWrite>,
556 ) {
557 let barrier = *self.elem;
558 let source_length = *source.length.expand;
559 let (source, source_offset) = source.__to_raw_parts();
560 let (destination, destination_offset) = destination.__to_raw_parts();
561
562 let mem_copy = BarrierOps::MemCopyAsync {
563 barrier,
564 source,
565 source_length,
566 offset_source: source_offset,
567 offset_out: destination_offset,
568 };
569
570 scope.register(Instruction::new(mem_copy, destination));
571 }
572
573 pub fn __expand_memcpy_async_cooperative_method<C: CubePrimitive>(
574 &self,
575 scope: &mut Scope,
576 source: SliceExpand<Line<C>, ReadOnly>,
577 destination: SliceExpand<Line<C>, ReadWrite>,
578 ) {
579 let barrier = *self.elem;
580 let source_length = *source.length.expand;
581 let (source, source_offset) = source.__to_raw_parts();
582 let (destination, destination_offset) = destination.__to_raw_parts();
583
584 let mem_copy = BarrierOps::MemCopyAsyncCooperative {
585 barrier,
586 source,
587 source_length,
588 offset_source: source_offset,
589 offset_out: destination_offset,
590 };
591
592 scope.register(Instruction::new(mem_copy, destination));
593 }
594
595 pub fn __expand_memcpy_async_tx_method<C: CubePrimitive>(
596 &self,
597 scope: &mut Scope,
598 source: SliceExpand<Line<C>, ReadOnly>,
599 destination: SliceExpand<Line<C>, ReadWrite>,
600 ) {
601 let barrier = *self.elem;
602 let source_length = *source.length.expand;
603 let (source, source_offset) = source.__to_raw_parts();
604 let (destination, destination_offset) = destination.__to_raw_parts();
605
606 let mem_copy = BarrierOps::MemCopyAsyncTx {
607 barrier,
608 source,
609 source_length,
610 offset_source: source_offset,
611 offset_out: destination_offset,
612 };
613
614 scope.register(Instruction::new(mem_copy, destination));
615 }
616
617 pub fn __expand_arrive_method(&self, scope: &mut Scope) -> ExpandElementTyped<BarrierToken> {
618 let barrier = *self.elem;
619 let VariableKind::Barrier { id, level, .. } = barrier.kind else {
620 unreachable!()
621 };
622 let token = scope.create_barrier_token(id, level);
623 scope.register(Instruction::new(BarrierOps::Arrive { barrier }, *token));
624 token.into()
625 }
626
627 pub fn __expand_arrive_and_expect_tx_method(
628 &self,
629 scope: &mut Scope,
630 arrival_count: ExpandElementTyped<u32>,
631 transaction_count: ExpandElementTyped<u32>,
632 ) -> ExpandElementTyped<BarrierToken> {
633 let barrier = *self.elem;
634 let VariableKind::Barrier { id, level, .. } = barrier.kind else {
635 unreachable!()
636 };
637 let token = scope.create_barrier_token(id, level);
638 let arrival_count: ExpandElement = arrival_count.into();
639 let transaction_count: ExpandElement = transaction_count.into();
640 scope.register(Instruction::new(
641 BarrierOps::ArriveTx {
642 barrier,
643 arrive_count_update: arrival_count.consume(),
644 transaction_count_update: transaction_count.consume(),
645 },
646 *token,
647 ));
648 token.into()
649 }
650
651 pub fn __expand_expect_tx_method(
652 &self,
653 scope: &mut Scope,
654 transaction_count: ExpandElementTyped<u32>,
655 ) {
656 let barrier = *self.elem;
657 let transaction_count: ExpandElement = transaction_count.into();
658 scope.register(BarrierOps::ExpectTx {
659 barrier,
660 transaction_count_update: transaction_count.consume(),
661 });
662 }
663
664 pub fn __expand_wait_method(&self, scope: &mut Scope, token: ExpandElementTyped<BarrierToken>) {
665 let barrier = *self.elem;
666 let token = *token.expand;
667 scope.register(BarrierOps::Wait { barrier, token });
668 }
669
670 pub fn __expand_wait_parity_method(&self, scope: &mut Scope, phase: ExpandElementTyped<u32>) {
671 let barrier = *self.elem;
672 let phase = *phase.expand;
673 scope.register(BarrierOps::WaitParity { barrier, phase });
674 }
675
676 pub fn __expand_arrive_and_wait_method(&self, scope: &mut Scope) {
677 let barrier = *self.elem;
678 scope.register(BarrierOps::ArriveAndWait { barrier });
679 }
680}