cubecl_core/frontend/
barrier.rs

//! This module exposes barriers for asynchronous data transfers.
2
3use cubecl_ir::{ExpandElement, Instruction, Variable, VariableKind};
4use paste::paste;
5
6use crate::{
7    ir::{BarrierOps, Scope},
8    prelude::{CUBE_DIM, ExpandElementIntoMut},
9    unexpanded,
10};
11
12use super::{
13    CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, IntoMut, Line, ReadOnly, ReadWrite,
14    Slice, SliceExpand, SliceMut, TensorMap,
15};
16
/// A mechanism for awaiting on asynchronous data transfers
/// Behaviour is defined by its [BarrierLevel](BarrierLevel).
///
/// This is the kernel-facing placeholder type: its methods are stubs
/// (`unexpanded!()`) that are replaced at expansion time.
#[derive(Clone, Copy)]
pub struct Barrier;
21
/// Token returned by arrival operations (e.g. [`Barrier::arrive`]) and later
/// consumed by [`Barrier::wait`] to block until the corresponding phase completes.
#[derive(Clone, Copy, PartialEq)]
pub struct BarrierToken;
24
/// A [`Barrier`] expands to a [`BarrierExpand`] wrapping the underlying IR element.
impl CubeType for Barrier {
    type ExpandType = BarrierExpand;
}
28
impl IntoMut for BarrierExpand {
    // Barriers are opaque handles; no copy is needed to make them "mutable".
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}
34
impl CubeDebug for BarrierExpand {
    /// Forwards the debug name to the wrapped IR variable.
    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
        scope.update_variable_name(*self.elem, name);
    }
}
40
/// Tokens expand to a plain typed expand element.
impl CubeType for BarrierToken {
    type ExpandType = ExpandElementTyped<BarrierToken>;
}
44
impl ExpandElementIntoMut for BarrierToken {
    // Tokens are value-like; the element can be reused as-is.
    fn elem_into_mut(_scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
        elem
    }
}
50
#[derive(Clone)]
/// Expand type of [Barrier]
pub struct BarrierExpand {
    // IR element referencing the barrier variable in the kernel scope.
    elem: ExpandElement,
}
56
/// Comptime wrapper around [`InnerBarrierLevel`], selecting how a barrier is
/// initialized and how many units must arrive before it releases.
#[derive(Clone)]
pub struct BarrierLevel(InnerBarrierLevel);
59
/// Barrier levels are comptime values: the expand type is the value itself.
impl CubeType for BarrierLevel {
    type ExpandType = Self;
}
63
impl IntoMut for BarrierLevel {
    // Comptime value; nothing to do.
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}
69
impl CubeDebug for BarrierLevel {
    // Comptime value with no runtime variable to name.
    fn set_debug_name(&self, _scope: &mut Scope, _name: &'static str) {}
}
73
#[derive(Clone)]
/// Defines how many units must reach the barrier before execution can continue.
/// This also determines how `memcpy_async` operations should be handled.
enum InnerBarrierLevel {
    /// Waits only for the unit that declared this barrier.
    /// Useful for synchronizing after async data loading.
    Unit,

    /// Only the leader unit is required to reach the barrier before continuing.
    /// The argument is whether the current unit is the one elected to perform the
    /// initialization (an expanded `bool`).
    ///
    /// TMA loads are issued from only a single unit, and this leader is the one that should arrive
    /// on the barrier. Unlike `Unit`, this barrier is *shared*, so all threads can wait on it.
    CubeUnit(ExpandElement),

    /// All units in the Cube must reach the barrier before continuing.
    /// The argument is whether the current unit is elected for initialization.
    CubeFull(ExpandElement),

    /// `arrival_count` units are required before the barrier can continue.
    /// The arguments are whether the current unit is elected for initialization, and the number
    /// of units that should call `arrive`.
    CubeCustom {
        // Expanded `bool`: is the current unit elected to initialize the barrier?
        is_elected: ExpandElement,
        // Expanded `u32`: number of units that must arrive each phase.
        arrival_count: ExpandElement,
    },

    /// Fully manual Cube barrier, no automatic initialization
    CubeManual,
}
107
108impl BarrierLevel {
109    /// Creates a Unit barrier level
110    pub fn unit() -> Self {
111        BarrierLevel(InnerBarrierLevel::Unit)
112    }
113
114    /// Creates a CubeUnit barrier level
115    ///
116    /// Same as `cube_full` but with an expected arrival count of `1`. Only the leader thread will
117    /// arrive on the barrier. Useful for TMA
118    pub fn cube_unit(_is_elected: bool) -> Self {
119        unexpanded!()
120    }
121
122    /// Creates a CubeCoop barrier level
123    ///
124    /// Will sync all units
125    pub fn cube_full(_is_elected: bool) -> Self {
126        unexpanded!()
127    }
128
129    /// Creates a CubeCustom barrier level
130    ///
131    /// Will sync `arrival_count` units
132    pub fn cube_custom(_arrival_count: u32) -> Self {
133        unexpanded!()
134    }
135
136    /// Creates a CubeManual barrier level
137    /// Not initialized automatically
138    pub fn cube_manual() -> Self {
139        unexpanded!()
140    }
141
142    fn arrival_count(&self, scope: &mut Scope) -> Variable {
143        match &self.0 {
144            InnerBarrierLevel::Unit | InnerBarrierLevel::CubeUnit(_) => 1.into(),
145            InnerBarrierLevel::CubeFull(_) => *CUBE_DIM::expand(scope).expand,
146            InnerBarrierLevel::CubeCustom { arrival_count, .. } => **arrival_count,
147            InnerBarrierLevel::CubeManual => panic!("Can't get arrival count of manual barrier"),
148        }
149    }
150
151    fn is_elected(&self) -> Variable {
152        match &self.0 {
153            InnerBarrierLevel::Unit => true.into(),
154            InnerBarrierLevel::CubeUnit(is_elected)
155            | InnerBarrierLevel::CubeFull(is_elected)
156            | InnerBarrierLevel::CubeCustom { is_elected, .. } => **is_elected,
157            InnerBarrierLevel::CubeManual => panic!("Can't get `is_elected` of manual barrier"),
158        }
159    }
160
161    pub fn __expand_unit(_scope: &mut Scope) -> BarrierLevel {
162        BarrierLevel(InnerBarrierLevel::Unit)
163    }
164
165    pub fn __expand_cube_unit(_scope: &mut Scope, is_elected: ExpandElementTyped<bool>) -> Self {
166        BarrierLevel(InnerBarrierLevel::CubeUnit(is_elected.expand))
167    }
168
169    pub fn __expand_cube_full(_scope: &mut Scope, is_elected: ExpandElementTyped<bool>) -> Self {
170        BarrierLevel(InnerBarrierLevel::CubeFull(is_elected.expand))
171    }
172
173    pub fn __expand_cube_custom(
174        _scope: &mut Scope,
175        is_elected: ExpandElementTyped<bool>,
176        arrival_count: ExpandElementTyped<u32>,
177    ) -> Self {
178        BarrierLevel(InnerBarrierLevel::CubeCustom {
179            is_elected: is_elected.expand,
180            arrival_count: arrival_count.expand,
181        })
182    }
183
184    pub fn __expand_cube_manual(_scope: &mut Scope) -> Self {
185        BarrierLevel(InnerBarrierLevel::CubeManual)
186    }
187}
188
189impl From<InnerBarrierLevel> for cubecl_ir::BarrierLevel {
190    fn from(val: InnerBarrierLevel) -> Self {
191        match val {
192            InnerBarrierLevel::Unit => cubecl_ir::BarrierLevel::Unit,
193            InnerBarrierLevel::CubeUnit(_)
194            | InnerBarrierLevel::CubeFull(_)
195            | InnerBarrierLevel::CubeCustom { .. }
196            | InnerBarrierLevel::CubeManual => cubecl_ir::BarrierLevel::Cube,
197        }
198    }
199}
200
/// Generates the `tma_load_{dim}d` / `__expand_tma_load_{dim}d` pair on [`Barrier`]
/// and the matching `__expand_tma_load_{dim}d_method` on [`BarrierExpand`].
///
/// `$dim` is the tensor rank; each `$arg` names one `i32` coordinate (outermost
/// first) that becomes a TMA tile index.
macro_rules! tensor_map_load {
    ($dim: literal, $($arg: expr),*) => {
        paste! {
            impl Barrier {
                /// Copy a tile from a global memory `source` to a shared memory `destination`, with
                /// the provided offsets.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_ $dim d>]<C: CubePrimitive>(
                    &self,
                    source: &TensorMap<C>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32),*
                ) {
                    unexpanded!()
                }

                // Expand function: forwards to the `BarrierExpand` method below.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d>]<C: CubePrimitive>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    expand.[<__expand_tma_load_ $dim d_method>](scope, source, destination, $($arg),*);
                }
            }

            impl BarrierExpand {
                // Registers a `TmaLoad` instruction with `destination` as the output.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d_method>]<C: CubePrimitive>(
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    let barrier = *self.elem;
                    let source = *source.expand;
                    let (destination, destination_offset) = destination.__to_raw_parts();

                    let mem_copy = BarrierOps::TmaLoad {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offset_out: destination_offset
                    };

                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}
255
/// Generates the im2col variants of the TMA load: `tma_load_im2col_{dim}d`,
/// `__expand_tma_load_im2col_{dim}d` and the `_method` on [`BarrierExpand`].
///
/// `$arg` entries are `i32` tensor coordinates; `$offset` entries are `u16`
/// per-spatial-dimension kernel offsets, passed after the coordinates.
macro_rules! tensor_map_load_im2col {
    ($dim: literal, $($arg: expr),*; $($offset: expr),*) => {
        paste! {
            impl Barrier {
                /// Copy a tile from a global memory `source` to a shared memory `destination`, with
                /// the provided offsets.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_im2col_ $dim d>]<C: CubePrimitive>(
                    &self,
                    source: &TensorMap<C>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32,)*
                    $($offset: u16),*
                ) {
                    unexpanded!()
                }

                // Expand function: forwards to the `BarrierExpand` method below.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d>]<C: CubePrimitive>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    expand.[<__expand_tma_load_im2col_ $dim d_method>](scope, source, destination, $($arg),*, $($offset),*);
                }
            }

            impl BarrierExpand {
                // Registers a `TmaLoadIm2col` instruction with `destination` as the output.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d_method>]<C: CubePrimitive>(
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    let barrier = *self.elem;
                    let source = *source.expand;
                    let (destination, destination_offset) = destination.__to_raw_parts();

                    let mem_copy = BarrierOps::TmaLoadIm2col {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offsets: vec![$(*$offset.expand),*],
                        offset_out: destination_offset,
                    };

                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}
314
// Tiled TMA loads for 1-5 dimensional tensors; coordinates are listed from the
// outermost dimension to the innermost (e.g. `z, y, x`).
tensor_map_load!(1, x);
tensor_map_load!(2, y, x);
tensor_map_load!(3, z, y, x);
tensor_map_load!(4, w, z, y, x);
tensor_map_load!(5, v, w, z, y, x);

// im2col TMA loads: `n, [d, h,] w, c` tensor coordinates plus one `u16` kernel
// offset per spatial dimension.
tensor_map_load_im2col!(3, n, w, c; w_offset);
tensor_map_load_im2col!(4, n, h, w, c; h_offset, w_offset);
tensor_map_load_im2col!(5, n, d, h, w, c; d_offset, h_offset, w_offset);
324
325impl Barrier {
326    /// Creates a barrier using a user defined comptime barrier level
327    pub fn new(_level: BarrierLevel) -> Self {
328        Self
329    }
330
331    /// Creates a new barrier for use with TMA instructions. Adds a shared memory proxy barrier to
332    /// the initialization.
333    pub fn new_with_async_proxy_fence(_level: BarrierLevel) -> Self {
334        Self
335    }
336
337    /// Manually initialize the barrier, without handling synchronization, etc.
338    pub fn init_manual(&self, _arrival_count: u32) -> BarrierToken {
339        unexpanded!()
340    }
341
342    /// Copy the source slice to destination
343    ///
344    /// # Safety
345    ///
346    /// This will try to copy the whole source slice, so
347    /// make sure source length <= destination length
348    pub fn memcpy_async<C: CubePrimitive>(
349        &self,
350        _source: &Slice<Line<C>>,
351        _destination: &mut SliceMut<Line<C>>,
352    ) {
353        unexpanded!()
354    }
355
356    /// Copy the source slice to destination
357    ///
358    /// # Safety
359    ///
360    /// This will try to copy the whole source slice, so
361    /// make sure source length <= destination length
362    pub fn memcpy_async_cooperative<C: CubePrimitive>(
363        &self,
364        _source: &Slice<Line<C>>,
365        _destination: &mut SliceMut<Line<C>>,
366    ) {
367        unexpanded!()
368    }
369
370    /// Copy the source slice to destination. Uses transaction count like TMA, so use with
371    /// `expect_tx` or `arrive_and_expect_tx`.
372    ///
373    /// # Safety
374    ///
375    /// This will try to copy the whole source slice, so
376    /// make sure source length <= destination length
377    pub fn memcpy_async_tx<C: CubePrimitive>(
378        &self,
379        _source: &Slice<Line<C>>,
380        _destination: &mut SliceMut<Line<C>>,
381    ) {
382        unexpanded!()
383    }
384
385    /// Arrive at the barrier, decrementing arrival count
386    pub fn arrive(&self) -> BarrierToken {
387        unexpanded!()
388    }
389
390    /// Arrive at the barrier, decrementing arrival count. Additionally increments expected count.
391    pub fn arrive_and_expect_tx(
392        &self,
393        _arrival_count: u32,
394        _transaction_count: u32,
395    ) -> BarrierToken {
396        unexpanded!()
397    }
398
399    /// Increments the expected count of the barrier.
400    pub fn expect_tx(&self, _expected_count: u32) {
401        unexpanded!()
402    }
403
404    /// Wait at the barrier until all arrivals are done
405    pub fn wait(&self, _token: BarrierToken) {
406        unexpanded!()
407    }
408
409    /// Wait at the barrier until the `phase` is completed. Doesn't require a token, but needs phase
410    /// to be managed manually.
411    pub fn wait_parity(&self, _phase: u32) {
412        unexpanded!()
413    }
414
415    /// Wait until all data is loaded
416    pub fn arrive_and_wait(&self) {
417        unexpanded!()
418    }
419
420    pub fn __expand_new(scope: &mut Scope, level: BarrierLevel) -> BarrierExpand {
421        let variable = scope.create_barrier(level.0.clone().into());
422        match &level.0 {
423            InnerBarrierLevel::CubeManual => {
424                scope.register(BarrierOps::Declare { barrier: *variable });
425            }
426            _ => {
427                let is_elected = level.is_elected();
428                let arrival_count = level.arrival_count(scope);
429                scope.register(BarrierOps::Init {
430                    barrier: *variable,
431                    is_elected,
432                    arrival_count,
433                    with_async_proxy_fence: false,
434                });
435            }
436        }
437
438        BarrierExpand { elem: variable }
439    }
440
441    pub fn __expand_new_with_async_proxy_fence(
442        scope: &mut Scope,
443        level: BarrierLevel,
444    ) -> BarrierExpand {
445        let is_elected = level.is_elected();
446        let arrival_count = level.arrival_count(scope);
447        let variable = scope.create_barrier(level.0.clone().into());
448        scope.register(BarrierOps::Init {
449            barrier: *variable,
450            is_elected,
451            arrival_count,
452            with_async_proxy_fence: true,
453        });
454        BarrierExpand { elem: variable }
455    }
456
457    pub fn __expand_init_manual(
458        scope: &mut Scope,
459        expand: BarrierExpand,
460        arrival_count: ExpandElementTyped<u32>,
461    ) {
462        expand.__expand_init_manual_method(scope, arrival_count);
463    }
464
465    pub fn __expand_memcpy_async<C: CubePrimitive>(
466        scope: &mut Scope,
467        expand: BarrierExpand,
468        source: SliceExpand<Line<C>, ReadOnly>,
469        destination: SliceExpand<Line<C>, ReadWrite>,
470    ) {
471        expand.__expand_memcpy_async_method(scope, source, destination);
472    }
473
474    pub fn __expand_memcpy_async_cooperative<C: CubePrimitive>(
475        scope: &mut Scope,
476        expand: BarrierExpand,
477        source: SliceExpand<Line<C>, ReadOnly>,
478        destination: SliceExpand<Line<C>, ReadWrite>,
479    ) {
480        expand.__expand_memcpy_async_method(scope, source, destination);
481    }
482
483    pub fn __expand_memcpy_async_tx<C: CubePrimitive>(
484        scope: &mut Scope,
485        expand: BarrierExpand,
486        source: SliceExpand<Line<C>, ReadOnly>,
487        destination: SliceExpand<Line<C>, ReadWrite>,
488    ) {
489        expand.__expand_memcpy_async_tx_method(scope, source, destination);
490    }
491
492    pub fn __expand_arrive(
493        scope: &mut Scope,
494        expand: BarrierExpand,
495    ) -> ExpandElementTyped<BarrierToken> {
496        expand.__expand_arrive_method(scope)
497    }
498
499    pub fn __expand_arrive_and_expect_tx(
500        scope: &mut Scope,
501        expand: BarrierExpand,
502        arrival_count: ExpandElementTyped<u32>,
503        transaction_count: ExpandElementTyped<u32>,
504    ) -> ExpandElementTyped<BarrierToken> {
505        expand.__expand_arrive_and_expect_tx_method(scope, arrival_count, transaction_count)
506    }
507
508    pub fn __expand_expect_tx(
509        scope: &mut Scope,
510        expand: BarrierExpand,
511        expected_count: ExpandElementTyped<u32>,
512    ) {
513        expand.__expand_expect_tx_method(scope, expected_count);
514    }
515
516    pub fn __expand_wait(
517        scope: &mut Scope,
518        expand: BarrierExpand,
519        token: ExpandElementTyped<BarrierToken>,
520    ) {
521        expand.__expand_wait_method(scope, token);
522    }
523
524    pub fn __expand_wait_parity(
525        scope: &mut Scope,
526        expand: BarrierExpand,
527        phase: ExpandElementTyped<u32>,
528    ) {
529        expand.__expand_wait_parity_method(scope, phase);
530    }
531
532    pub fn __expand_arrive_and_wait(scope: &mut Scope, expand: BarrierExpand) {
533        expand.__expand_arrive_and_wait_method(scope);
534    }
535}
536
537impl BarrierExpand {
538    pub fn __expand_init_manual_method(
539        &self,
540        scope: &mut Scope,
541        arrival_count: ExpandElementTyped<u32>,
542    ) {
543        let barrier = *self.elem;
544
545        scope.register(BarrierOps::InitManual {
546            barrier,
547            arrival_count: *arrival_count.expand,
548        });
549    }
550
551    pub fn __expand_memcpy_async_method<C: CubePrimitive>(
552        &self,
553        scope: &mut Scope,
554        source: SliceExpand<Line<C>, ReadOnly>,
555        destination: SliceExpand<Line<C>, ReadWrite>,
556    ) {
557        let barrier = *self.elem;
558        let source_length = *source.length.expand;
559        let (source, source_offset) = source.__to_raw_parts();
560        let (destination, destination_offset) = destination.__to_raw_parts();
561
562        let mem_copy = BarrierOps::MemCopyAsync {
563            barrier,
564            source,
565            source_length,
566            offset_source: source_offset,
567            offset_out: destination_offset,
568        };
569
570        scope.register(Instruction::new(mem_copy, destination));
571    }
572
573    pub fn __expand_memcpy_async_cooperative_method<C: CubePrimitive>(
574        &self,
575        scope: &mut Scope,
576        source: SliceExpand<Line<C>, ReadOnly>,
577        destination: SliceExpand<Line<C>, ReadWrite>,
578    ) {
579        let barrier = *self.elem;
580        let source_length = *source.length.expand;
581        let (source, source_offset) = source.__to_raw_parts();
582        let (destination, destination_offset) = destination.__to_raw_parts();
583
584        let mem_copy = BarrierOps::MemCopyAsyncCooperative {
585            barrier,
586            source,
587            source_length,
588            offset_source: source_offset,
589            offset_out: destination_offset,
590        };
591
592        scope.register(Instruction::new(mem_copy, destination));
593    }
594
595    pub fn __expand_memcpy_async_tx_method<C: CubePrimitive>(
596        &self,
597        scope: &mut Scope,
598        source: SliceExpand<Line<C>, ReadOnly>,
599        destination: SliceExpand<Line<C>, ReadWrite>,
600    ) {
601        let barrier = *self.elem;
602        let source_length = *source.length.expand;
603        let (source, source_offset) = source.__to_raw_parts();
604        let (destination, destination_offset) = destination.__to_raw_parts();
605
606        let mem_copy = BarrierOps::MemCopyAsyncTx {
607            barrier,
608            source,
609            source_length,
610            offset_source: source_offset,
611            offset_out: destination_offset,
612        };
613
614        scope.register(Instruction::new(mem_copy, destination));
615    }
616
617    pub fn __expand_arrive_method(&self, scope: &mut Scope) -> ExpandElementTyped<BarrierToken> {
618        let barrier = *self.elem;
619        let VariableKind::Barrier { id, level, .. } = barrier.kind else {
620            unreachable!()
621        };
622        let token = scope.create_barrier_token(id, level);
623        scope.register(Instruction::new(BarrierOps::Arrive { barrier }, *token));
624        token.into()
625    }
626
627    pub fn __expand_arrive_and_expect_tx_method(
628        &self,
629        scope: &mut Scope,
630        arrival_count: ExpandElementTyped<u32>,
631        transaction_count: ExpandElementTyped<u32>,
632    ) -> ExpandElementTyped<BarrierToken> {
633        let barrier = *self.elem;
634        let VariableKind::Barrier { id, level, .. } = barrier.kind else {
635            unreachable!()
636        };
637        let token = scope.create_barrier_token(id, level);
638        let arrival_count: ExpandElement = arrival_count.into();
639        let transaction_count: ExpandElement = transaction_count.into();
640        scope.register(Instruction::new(
641            BarrierOps::ArriveTx {
642                barrier,
643                arrive_count_update: arrival_count.consume(),
644                transaction_count_update: transaction_count.consume(),
645            },
646            *token,
647        ));
648        token.into()
649    }
650
651    pub fn __expand_expect_tx_method(
652        &self,
653        scope: &mut Scope,
654        transaction_count: ExpandElementTyped<u32>,
655    ) {
656        let barrier = *self.elem;
657        let transaction_count: ExpandElement = transaction_count.into();
658        scope.register(BarrierOps::ExpectTx {
659            barrier,
660            transaction_count_update: transaction_count.consume(),
661        });
662    }
663
664    pub fn __expand_wait_method(&self, scope: &mut Scope, token: ExpandElementTyped<BarrierToken>) {
665        let barrier = *self.elem;
666        let token = *token.expand;
667        scope.register(BarrierOps::Wait { barrier, token });
668    }
669
670    pub fn __expand_wait_parity_method(&self, scope: &mut Scope, phase: ExpandElementTyped<u32>) {
671        let barrier = *self.elem;
672        let phase = *phase.expand;
673        scope.register(BarrierOps::WaitParity { barrier, phase });
674    }
675
676    pub fn __expand_arrive_and_wait_method(&self, scope: &mut Scope) {
677        let barrier = *self.elem;
678        scope.register(BarrierOps::ArriveAndWait { barrier });
679    }
680}