cubecl_core/frontend/
barrier.rs

//! This module exposes barriers for asynchronous data transfers.
2
3use std::marker::PhantomData;
4
5use cubecl_ir::{ExpandElement, Instruction};
6use paste::paste;
7
8use crate::{
9    ir::{BarrierOps, Item, Scope},
10    unexpanded,
11};
12
13use super::{
14    CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, IntoMut, Line, ReadOnly, ReadWrite,
15    Slice, SliceExpand, SliceMut, TensorMap,
16};
17
/// A mechanism for awaiting on asynchronous data transfers
/// Behaviour is defined by its [`BarrierLevel`].
#[derive(Clone, Copy)]
pub struct Barrier<C: CubePrimitive> {
    // Zero-sized marker tying the barrier to the element type being transferred.
    _c: PhantomData<C>,
}
24
// During expansion, a `Barrier` is represented by a `BarrierExpand` holding the
// IR variable of the barrier.
impl<C: CubePrimitive> CubeType for Barrier<C> {
    type ExpandType = BarrierExpand<C>;
}
28
impl<C: CubePrimitive> IntoMut for BarrierExpand<C> {
    // A barrier expansion needs no transformation when made mutable.
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}
34
impl<C: CubePrimitive> CubeDebug for BarrierExpand<C> {
    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
        // Forward the user-facing name to the underlying IR variable.
        scope.update_variable_name(*self.elem, name);
    }
}
40
#[derive(Clone)]
/// Expand type of [Barrier]
pub struct BarrierExpand<C: CubePrimitive> {
    // IR handle of the barrier variable declared in the kernel scope.
    elem: ExpandElement,
    // Element type marker, mirroring `Barrier<C>`.
    _c: PhantomData<C>,
}
47
/// Comptime wrapper around [`InnerBarrierLevel`]: selects which units
/// synchronize on a [`Barrier`] and how `memcpy_async` is handled.
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct BarrierLevel(InnerBarrierLevel);
50
// A barrier level is a pure comptime value: it expands to itself.
impl CubeType for BarrierLevel {
    type ExpandType = Self;
}
54
impl IntoMut for BarrierLevel {
    // Comptime value: no transformation required.
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}
60
impl CubeDebug for BarrierLevel {
    // Comptime value: there is no IR variable to name.
    fn set_debug_name(&self, _scope: &mut Scope, _name: &'static str) {}
}
64
#[derive(Copy, Clone, Eq, PartialEq)]
/// Defines how many units must reach the barrier before execution can continue.
/// This also determines how `memcpy_async` operations should be handled.
enum InnerBarrierLevel {
    /// Waits only for the unit that declared this barrier.
    /// Useful for synchronizing after async data loading.
    Unit,

    /// All units in the Cube must reach the barrier before continuing.
    /// The argument is the ID of the unit elected for initialization.
    ///
    /// `memcpy_async` is **cooperative**, so all units in the Cube must call `memcpy_async` with the same arguments.
    /// The caller is not elected by default, so it must be done manually if wanted.
    CubeCoop(u32),

    /// All units in the Cube must reach the barrier before continuing.
    /// The argument is the ID of the unit elected for initialization.
    ///
    /// `memcpy_async` is **not cooperative**, so each unit must manually handle its own data slice.
    CubeManual(u32),
}
86
87impl BarrierLevel {
88    /// Creates a Unit barrier level
89    pub fn unit() -> Self {
90        BarrierLevel(InnerBarrierLevel::Unit)
91    }
92
93    /// Creates a CubeCoop barrier level
94    ///
95    /// Will sync all units
96    pub fn cube_coop(elected_unit: u32) -> Self {
97        BarrierLevel(InnerBarrierLevel::CubeCoop(elected_unit))
98    }
99
100    /// Creates a CubeManual barrier level
101    ///
102    /// Will sync all units
103    pub fn cube_manual(elected_unit: u32) -> Self {
104        BarrierLevel(InnerBarrierLevel::CubeManual(elected_unit))
105    }
106
107    pub fn __expand_unit(_scope: &mut Scope) -> BarrierLevel {
108        BarrierLevel(InnerBarrierLevel::Unit)
109    }
110
111    pub fn __expand_cube_coop(_scope: &mut Scope, elected_unit: u32) -> Self {
112        BarrierLevel(InnerBarrierLevel::CubeCoop(elected_unit))
113    }
114
115    pub fn __expand_cube_manual(_scope: &mut Scope, elected_unit: u32) -> Self {
116        BarrierLevel(InnerBarrierLevel::CubeManual(elected_unit))
117    }
118}
119
120impl From<InnerBarrierLevel> for cubecl_ir::BarrierLevel {
121    fn from(val: InnerBarrierLevel) -> Self {
122        match val {
123            InnerBarrierLevel::Unit => cubecl_ir::BarrierLevel::Unit,
124            InnerBarrierLevel::CubeCoop(elected_unit) => {
125                cubecl_ir::BarrierLevel::CubeCoop(elected_unit)
126            }
127            InnerBarrierLevel::CubeManual(elected_unit) => {
128                cubecl_ir::BarrierLevel::CubeManual(elected_unit)
129            }
130        }
131    }
132}
133
/// Generates the `tma_load_<dim>d` method (and its expand counterparts) on
/// [`Barrier`] / [`BarrierExpand`], taking one `i32` coordinate per dimension.
macro_rules! tensor_map_load {
    ($dim: literal, $($arg: expr),*) => {
        paste! {
            impl<C: CubePrimitive> Barrier<C> {
                /// Copy a tile from a global memory `source` to a shared memory `destination`, with
                /// the provided offsets.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_ $dim d>](
                    &self,
                    source: &TensorMap<C>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32),*
                ) {
                    unexpanded!()
                }

                /// Expand function of the corresponding `tma_load_*` method.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d>](
                    scope: &mut Scope,
                    expand: BarrierExpand<C>,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    expand.[<__expand_tma_load_ $dim d_method>](scope, source, destination, $($arg),*);
                }
            }

            impl<C: CubePrimitive> BarrierExpand<C> {
                /// Registers a `TmaLoad` instruction targeting the destination slice.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d_method>](
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    let barrier = *self.elem;
                    let source = *source.expand;
                    let (destination, destination_offset) = destination.__to_raw_parts();

                    let mem_copy = BarrierOps::TmaLoad {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offset_out: destination_offset
                    };

                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}
188
/// Generates the `tma_load_im2col_<dim>d` method (and its expand counterparts),
/// taking `i32` tensor coordinates plus `u16` im2col offsets.
macro_rules! tensor_map_load_im2col {
    ($dim: literal, $($arg: expr),*; $($offset: expr),*) => {
        paste! {
            impl<C: CubePrimitive> Barrier<C> {
                /// Copy a tile from a global memory `source` to a shared memory `destination`, with
                /// the provided offsets.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_im2col_ $dim d>](
                    &self,
                    source: &TensorMap<C>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32,)*
                    $($offset: u16),*
                ) {
                    unexpanded!()
                }

                /// Expand function of the corresponding `tma_load_im2col_*` method.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d>](
                    scope: &mut Scope,
                    expand: BarrierExpand<C>,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    expand.[<__expand_tma_load_im2col_ $dim d_method>](scope, source, destination, $($arg),*, $($offset),*);
                }
            }

            impl<C: CubePrimitive> BarrierExpand<C> {
                /// Registers a `TmaLoadIm2col` instruction targeting the destination slice.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d_method>](
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    let barrier = *self.elem;
                    let source = *source.expand;
                    let (destination, destination_offset) = destination.__to_raw_parts();

                    let mem_copy = BarrierOps::TmaLoadIm2col {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offsets: vec![$(*$offset.expand),*],
                        offset_out: destination_offset,
                    };

                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}
247
// Generate the 2D–5D tile-load variants.
tensor_map_load!(2, y, x);
tensor_map_load!(3, z, y, x);
tensor_map_load!(4, w, z, y, x);
tensor_map_load!(5, v, w, z, y, x);

// Generate the 3D–5D im2col load variants: tensor coordinates, then one
// offset per spatial dimension.
tensor_map_load_im2col!(3, n, w, c; w_offset);
tensor_map_load_im2col!(4, n, h, w, c; h_offset, w_offset);
tensor_map_load_im2col!(5, n, d, h, w, c; d_offset, h_offset, w_offset);
256
257impl<C: CubePrimitive> Barrier<C> {
258    /// Creates a barrier using a user defined comptime barrier level
259    pub fn new(_level: BarrierLevel) -> Self {
260        Self { _c: PhantomData }
261    }
262
263    /// Creates a new barrier for use with TMA instructions. Adds a shared memory proxy barrier to
264    /// the initialization.
265    pub fn new_with_tma_proxy(_level: BarrierLevel) -> Self {
266        Self { _c: PhantomData }
267    }
268
269    /// Copy the source slice to destination
270    ///
271    /// # Safety
272    ///
273    /// This will try to copy the whole source slice, so
274    /// make sure source length <= destination length
275    pub fn memcpy_async(&self, _source: &Slice<Line<C>>, _destination: &mut SliceMut<Line<C>>) {
276        unexpanded!()
277    }
278
279    /// Arrive at the barrier, decrementing arrival count
280    pub fn arrive(&self) {
281        unexpanded!()
282    }
283
284    /// Arrive at the barrier, decrementing arrival count. Additionally increments expected count.
285    pub fn arrive_tx(&self, _arrival_count: u32, _transaction_count: u32) {
286        unexpanded!()
287    }
288
289    /// Increments the expected count of the barrier.
290    pub fn expect_tx(&self, _expected_count: u32) {
291        unexpanded!()
292    }
293
294    /// Wait at the barrier until all arrivals are done
295    pub fn wait(&self) {
296        unexpanded!()
297    }
298
299    /// Wait until all data is loaded
300    pub fn arrive_and_wait(&self) {
301        unexpanded!()
302    }
303
304    pub fn __expand_new(scope: &mut Scope, level: BarrierLevel) -> BarrierExpand<C> {
305        let elem = C::as_elem(scope);
306
307        let variable = scope.create_barrier(Item::new(elem), level.0.into());
308        scope.register(BarrierOps::Init {
309            barrier: *variable,
310            with_cta_fence: false,
311        });
312        BarrierExpand {
313            elem: variable,
314            _c: PhantomData,
315        }
316    }
317
318    pub fn __expand_new_with_tma_proxy(scope: &mut Scope, level: BarrierLevel) -> BarrierExpand<C> {
319        let elem = C::as_elem(scope);
320
321        let variable = scope.create_barrier(Item::new(elem), level.0.into());
322        scope.register(BarrierOps::Init {
323            barrier: *variable,
324            with_cta_fence: true,
325        });
326        BarrierExpand {
327            elem: variable,
328            _c: PhantomData,
329        }
330    }
331
332    pub fn __expand_memcpy_async(
333        scope: &mut Scope,
334        expand: BarrierExpand<C>,
335        source: SliceExpand<Line<C>, ReadOnly>,
336        destination: SliceExpand<Line<C>, ReadWrite>,
337    ) {
338        expand.__expand_memcpy_async_method(scope, source, destination);
339    }
340
341    pub fn __expand_arrive(scope: &mut Scope, expand: BarrierExpand<C>) {
342        expand.__expand_arrive_method(scope);
343    }
344
345    pub fn __expand_arrive_tx(
346        scope: &mut Scope,
347        expand: BarrierExpand<C>,
348        arrival_count: ExpandElementTyped<u32>,
349        transaction_count: ExpandElementTyped<u32>,
350    ) {
351        expand.__expand_arrive_tx_method(scope, arrival_count, transaction_count);
352    }
353
354    pub fn __expand_expect_tx(
355        scope: &mut Scope,
356        expand: BarrierExpand<C>,
357        expected_count: ExpandElementTyped<u32>,
358    ) {
359        expand.__expand_expect_tx_method(scope, expected_count);
360    }
361
362    pub fn __expand_wait(scope: &mut Scope, expand: BarrierExpand<C>) {
363        expand.__expand_wait_method(scope);
364    }
365
366    pub fn __expand_arrive_and_wait(scope: &mut Scope, expand: BarrierExpand<C>) {
367        expand.__expand_arrive_and_wait_method(scope);
368    }
369}
370
371impl<C: CubePrimitive> BarrierExpand<C> {
372    pub fn __expand_memcpy_async_method(
373        &self,
374        scope: &mut Scope,
375        source: SliceExpand<Line<C>, ReadOnly>,
376        destination: SliceExpand<Line<C>, ReadWrite>,
377    ) {
378        let barrier = *self.elem;
379        let source_length = *source.length.expand;
380        let (source, source_offset) = source.__to_raw_parts();
381        let (destination, destination_offset) = destination.__to_raw_parts();
382
383        let mem_copy = BarrierOps::MemCopyAsync {
384            barrier,
385            source,
386            source_length,
387            offset_source: source_offset,
388            offset_out: destination_offset,
389        };
390
391        scope.register(Instruction::new(mem_copy, destination));
392    }
393
394    pub fn __expand_arrive_method(&self, scope: &mut Scope) {
395        let barrier = *self.elem;
396        scope.register(BarrierOps::Arrive { barrier });
397    }
398
399    pub fn __expand_arrive_tx_method(
400        &self,
401        scope: &mut Scope,
402        arrival_count: ExpandElementTyped<u32>,
403        transaction_count: ExpandElementTyped<u32>,
404    ) {
405        let barrier = *self.elem;
406        let arrival_count: ExpandElement = arrival_count.into();
407        let transaction_count: ExpandElement = transaction_count.into();
408        scope.register(BarrierOps::ArriveTx {
409            barrier,
410            arrive_count_update: arrival_count.consume(),
411            transaction_count_update: transaction_count.consume(),
412        });
413    }
414
415    pub fn __expand_expect_tx_method(
416        &self,
417        scope: &mut Scope,
418        transaction_count: ExpandElementTyped<u32>,
419    ) {
420        let barrier = *self.elem;
421        let transaction_count: ExpandElement = transaction_count.into();
422        scope.register(BarrierOps::ExpectTx {
423            barrier,
424            transaction_count_update: transaction_count.consume(),
425        });
426    }
427
428    pub fn __expand_wait_method(&self, scope: &mut Scope) {
429        let barrier = *self.elem;
430        scope.register(BarrierOps::Wait { barrier });
431    }
432
433    pub fn __expand_arrive_and_wait_method(&self, scope: &mut Scope) {
434        let barrier = *self.elem;
435        scope.register(BarrierOps::ArriveAndWait { barrier });
436    }
437}