cubecl_core/frontend/
barrier.rs

//! This module exposes barriers for asynchronous data transfers
2
3use cubecl_ir::{ExpandElement, Instruction};
4use paste::paste;
5
6use crate::{
7    ir::{BarrierOps, Scope},
8    unexpanded,
9};
10
11use super::{
12    CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, IntoMut, Line, ReadOnly, ReadWrite,
13    Slice, SliceExpand, SliceMut, TensorMap,
14};
15
/// A mechanism for awaiting on asynchronous data transfers
/// Behaviour is defined by its [BarrierLevel](BarrierLevel).
#[derive(Clone, Copy)]
pub struct Barrier;

impl CubeType for Barrier {
    // During kernel expansion a `Barrier` is represented by a [BarrierExpand],
    // which wraps the IR handle of the barrier variable.
    type ExpandType = BarrierExpand;
}
24
impl IntoMut for BarrierExpand {
    // A barrier handle is immutable from the frontend's perspective;
    // nothing to promote, so the expand value is returned unchanged.
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}

impl CubeDebug for BarrierExpand {
    // Attach a human-readable name to the underlying IR variable so it shows
    // up in debug output of generated kernels.
    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
        scope.update_variable_name(*self.elem, name);
    }
}
36
#[derive(Clone)]
/// Expand type of [Barrier]
pub struct BarrierExpand {
    // IR handle of the barrier variable created in the kernel scope.
    elem: ExpandElement,
}
42
#[derive(Copy, Clone, PartialEq, Eq)]
/// Synchronization scope of a [Barrier], chosen at comptime.
/// See [InnerBarrierLevel] for the available levels.
pub struct BarrierLevel(InnerBarrierLevel);

impl CubeType for BarrierLevel {
    // The level is a pure comptime value: it expands to itself.
    type ExpandType = Self;
}

impl IntoMut for BarrierLevel {
    // Comptime value; no mutable promotion needed.
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}

impl CubeDebug for BarrierLevel {
    // No runtime variable backs a comptime level, so there is nothing to name.
    fn set_debug_name(&self, _scope: &mut Scope, _name: &'static str) {}
}
59
#[derive(Copy, Clone, Eq, PartialEq)]
/// Defines how many units must reach the barrier before execution can continue.
/// This also determines how `memcpy_async` operations should be handled.
enum InnerBarrierLevel {
    /// Waits only for the unit that declared this barrier.
    /// Useful for synchronizing after async data loading.
    Unit,

    /// All units in the Cube must reach the barrier before continuing.
    /// The argument is the ID of the unit elected for initialization.
    ///
    /// `memcpy_async` is **cooperative**, so all units in the Cube must call `memcpy_async` with the same arguments.
    /// The caller is not elected by default, so it must be done manually if wanted
    CubeCoop(u32),

    /// All units in the Cube must reach the barrier before continuing.
    /// The argument is the ID of the unit elected for initialization.
    ///
    /// `memcpy_async` is **not cooperative**, so each unit must manually handle its own data slice.
    CubeManual(u32),
}
81
82impl BarrierLevel {
83    /// Creates a Unit barrier level
84    pub fn unit() -> Self {
85        BarrierLevel(InnerBarrierLevel::Unit)
86    }
87
88    /// Creates a CubeCoop barrier level
89    ///
90    /// Will sync all units
91    pub fn cube_coop(elected_unit: u32) -> Self {
92        BarrierLevel(InnerBarrierLevel::CubeCoop(elected_unit))
93    }
94
95    /// Creates a CubeManual barrier level
96    ///
97    /// Will sync all units
98    pub fn cube_manual(elected_unit: u32) -> Self {
99        BarrierLevel(InnerBarrierLevel::CubeManual(elected_unit))
100    }
101
102    pub fn __expand_unit(_scope: &mut Scope) -> BarrierLevel {
103        BarrierLevel(InnerBarrierLevel::Unit)
104    }
105
106    pub fn __expand_cube_coop(_scope: &mut Scope, elected_unit: u32) -> Self {
107        BarrierLevel(InnerBarrierLevel::CubeCoop(elected_unit))
108    }
109
110    pub fn __expand_cube_manual(_scope: &mut Scope, elected_unit: u32) -> Self {
111        BarrierLevel(InnerBarrierLevel::CubeManual(elected_unit))
112    }
113}
114
115impl From<InnerBarrierLevel> for cubecl_ir::BarrierLevel {
116    fn from(val: InnerBarrierLevel) -> Self {
117        match val {
118            InnerBarrierLevel::Unit => cubecl_ir::BarrierLevel::Unit,
119            InnerBarrierLevel::CubeCoop(elected_unit) => {
120                cubecl_ir::BarrierLevel::CubeCoop(elected_unit)
121            }
122            InnerBarrierLevel::CubeManual(elected_unit) => {
123                cubecl_ir::BarrierLevel::CubeManual(elected_unit)
124            }
125        }
126    }
127}
128
// Generates the `tma_load_<dim>d` family of methods on [Barrier] and their
// expand counterparts on [BarrierExpand]. `$dim` is the tensor rank and each
// `$arg` becomes one `i32` coordinate parameter of the load.
macro_rules! tensor_map_load {
    ($dim: literal, $($arg: expr),*) => {
        paste! {
            impl Barrier {
                /// Copy a tile from a global memory `source` to a shared memory `destination`, with
                /// the provided offsets.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_ $dim d>]<C: CubePrimitive>(
                    &self,
                    source: &TensorMap<C>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32),*
                ) {
                    // Placeholder: only callable inside a cube kernel, where it is expanded.
                    unexpanded!()
                }

                /// Expand function of the corresponding `tma_load_<dim>d` method.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d>]<C: CubePrimitive>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    expand.[<__expand_tma_load_ $dim d_method>](scope, source, destination, $($arg),*);
                }
            }

            impl BarrierExpand {
                /// Registers a `TmaLoad` instruction copying a tile from `source`
                /// into `destination` at the given indices.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d_method>]<C: CubePrimitive>(
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    let barrier = *self.elem;
                    let source = *source.expand;
                    let (destination, destination_offset) = destination.__to_raw_parts();

                    let mem_copy = BarrierOps::TmaLoad {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offset_out: destination_offset
                    };

                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}
183
// Generates the `tma_load_im2col_<dim>d` family of methods on [Barrier] and
// their expand counterparts on [BarrierExpand]. `$arg`s become the `i32`
// tensor coordinates; `$offset`s become the `u16` im2col kernel offsets.
macro_rules! tensor_map_load_im2col {
    ($dim: literal, $($arg: expr),*; $($offset: expr),*) => {
        paste! {
            impl Barrier {
                /// Copy a tile from a global memory `source` to a shared memory `destination`, with
                /// the provided offsets.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_im2col_ $dim d>]<C: CubePrimitive>(
                    &self,
                    source: &TensorMap<C>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32,)*
                    $($offset: u16),*
                ) {
                    // Placeholder: only callable inside a cube kernel, where it is expanded.
                    unexpanded!()
                }

                /// Expand function of the corresponding `tma_load_im2col_<dim>d` method.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d>]<C: CubePrimitive>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    expand.[<__expand_tma_load_im2col_ $dim d_method>](scope, source, destination, $($arg),*, $($offset),*);
                }
            }

            impl BarrierExpand {
                /// Registers a `TmaLoadIm2col` instruction copying a tile from `source`
                /// into `destination` at the given indices and im2col offsets.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d_method>]<C: CubePrimitive>(
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    let barrier = *self.elem;
                    let source = *source.expand;
                    let (destination, destination_offset) = destination.__to_raw_parts();

                    let mem_copy = BarrierOps::TmaLoadIm2col {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offsets: vec![$(*$offset.expand),*],
                        offset_out: destination_offset,
                    };

                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}
242
// Generate `tma_load_1d` … `tma_load_5d`, with one coordinate per dimension
// (outermost first, innermost `x` last).
tensor_map_load!(1, x);
tensor_map_load!(2, y, x);
tensor_map_load!(3, z, y, x);
tensor_map_load!(4, w, z, y, x);
tensor_map_load!(5, v, w, z, y, x);

// Generate `tma_load_im2col_3d` … `tma_load_im2col_5d`; coordinates before the
// `;` and the matching im2col kernel offsets after it.
tensor_map_load_im2col!(3, n, w, c; w_offset);
tensor_map_load_im2col!(4, n, h, w, c; h_offset, w_offset);
tensor_map_load_im2col!(5, n, d, h, w, c; d_offset, h_offset, w_offset);
252
253impl Barrier {
254    /// Creates a barrier using a user defined comptime barrier level
255    pub fn new(_level: BarrierLevel) -> Self {
256        Self
257    }
258
259    /// Creates a new barrier for use with TMA instructions. Adds a shared memory proxy barrier to
260    /// the initialization.
261    pub fn new_with_tma_proxy(_level: BarrierLevel) -> Self {
262        Self
263    }
264
265    /// Copy the source slice to destination
266    ///
267    /// # Safety
268    ///
269    /// This will try to copy the whole source slice, so
270    /// make sure source length <= destination length
271    pub fn memcpy_async<C: CubePrimitive>(
272        &self,
273        _source: &Slice<Line<C>>,
274        _destination: &mut SliceMut<Line<C>>,
275    ) {
276        unexpanded!()
277    }
278
279    /// Arrive at the barrier, decrementing arrival count
280    pub fn arrive(&self) {
281        unexpanded!()
282    }
283
284    /// Arrive at the barrier, decrementing arrival count. Additionally increments expected count.
285    pub fn arrive_tx(&self, _arrival_count: u32, _transaction_count: u32) {
286        unexpanded!()
287    }
288
289    /// Increments the expected count of the barrier.
290    pub fn expect_tx(&self, _expected_count: u32) {
291        unexpanded!()
292    }
293
294    /// Wait at the barrier until all arrivals are done
295    pub fn wait(&self) {
296        unexpanded!()
297    }
298
299    /// Wait until all data is loaded
300    pub fn arrive_and_wait(&self) {
301        unexpanded!()
302    }
303
304    pub fn __expand_new(scope: &mut Scope, level: BarrierLevel) -> BarrierExpand {
305        let variable = scope.create_barrier(level.0.into());
306        scope.register(BarrierOps::Init {
307            barrier: *variable,
308            with_cta_fence: false,
309        });
310        BarrierExpand { elem: variable }
311    }
312
313    pub fn __expand_new_with_tma_proxy(scope: &mut Scope, level: BarrierLevel) -> BarrierExpand {
314        let variable = scope.create_barrier(level.0.into());
315        scope.register(BarrierOps::Init {
316            barrier: *variable,
317            with_cta_fence: true,
318        });
319        BarrierExpand { elem: variable }
320    }
321
322    pub fn __expand_memcpy_async<C: CubePrimitive>(
323        scope: &mut Scope,
324        expand: BarrierExpand,
325        source: SliceExpand<Line<C>, ReadOnly>,
326        destination: SliceExpand<Line<C>, ReadWrite>,
327    ) {
328        expand.__expand_memcpy_async_method(scope, source, destination);
329    }
330
331    pub fn __expand_arrive(scope: &mut Scope, expand: BarrierExpand) {
332        expand.__expand_arrive_method(scope);
333    }
334
335    pub fn __expand_arrive_tx(
336        scope: &mut Scope,
337        expand: BarrierExpand,
338        arrival_count: ExpandElementTyped<u32>,
339        transaction_count: ExpandElementTyped<u32>,
340    ) {
341        expand.__expand_arrive_tx_method(scope, arrival_count, transaction_count);
342    }
343
344    pub fn __expand_expect_tx(
345        scope: &mut Scope,
346        expand: BarrierExpand,
347        expected_count: ExpandElementTyped<u32>,
348    ) {
349        expand.__expand_expect_tx_method(scope, expected_count);
350    }
351
352    pub fn __expand_wait(scope: &mut Scope, expand: BarrierExpand) {
353        expand.__expand_wait_method(scope);
354    }
355
356    pub fn __expand_arrive_and_wait(scope: &mut Scope, expand: BarrierExpand) {
357        expand.__expand_arrive_and_wait_method(scope);
358    }
359}
360
impl BarrierExpand {
    /// Expand method of [Barrier::memcpy_async]: registers a `MemCopyAsync`
    /// instruction copying `source` into `destination`.
    pub fn __expand_memcpy_async_method<C: CubePrimitive>(
        &self,
        scope: &mut Scope,
        source: SliceExpand<Line<C>, ReadOnly>,
        destination: SliceExpand<Line<C>, ReadWrite>,
    ) {
        let barrier = *self.elem;
        // Capture the length before `source` is consumed by `__to_raw_parts`.
        let source_length = *source.length.expand;
        let (source, source_offset) = source.__to_raw_parts();
        let (destination, destination_offset) = destination.__to_raw_parts();

        let mem_copy = BarrierOps::MemCopyAsync {
            barrier,
            source,
            source_length,
            offset_source: source_offset,
            offset_out: destination_offset,
        };

        scope.register(Instruction::new(mem_copy, destination));
    }

    /// Expand method of [Barrier::arrive].
    pub fn __expand_arrive_method(&self, scope: &mut Scope) {
        let barrier = *self.elem;
        scope.register(BarrierOps::Arrive { barrier });
    }

    /// Expand method of [Barrier::arrive_tx]: arrives while also updating the
    /// expected transaction count.
    pub fn __expand_arrive_tx_method(
        &self,
        scope: &mut Scope,
        arrival_count: ExpandElementTyped<u32>,
        transaction_count: ExpandElementTyped<u32>,
    ) {
        let barrier = *self.elem;
        let arrival_count: ExpandElement = arrival_count.into();
        let transaction_count: ExpandElement = transaction_count.into();
        // `consume()` transfers ownership of the operand variables into the op.
        scope.register(BarrierOps::ArriveTx {
            barrier,
            arrive_count_update: arrival_count.consume(),
            transaction_count_update: transaction_count.consume(),
        });
    }

    /// Expand method of [Barrier::expect_tx]: increments the expected
    /// transaction count without arriving.
    pub fn __expand_expect_tx_method(
        &self,
        scope: &mut Scope,
        transaction_count: ExpandElementTyped<u32>,
    ) {
        let barrier = *self.elem;
        let transaction_count: ExpandElement = transaction_count.into();
        scope.register(BarrierOps::ExpectTx {
            barrier,
            transaction_count_update: transaction_count.consume(),
        });
    }

    /// Expand method of [Barrier::wait].
    pub fn __expand_wait_method(&self, scope: &mut Scope) {
        let barrier = *self.elem;
        scope.register(BarrierOps::Wait { barrier });
    }

    /// Expand method of [Barrier::arrive_and_wait].
    pub fn __expand_arrive_and_wait_method(&self, scope: &mut Scope) {
        let barrier = *self.elem;
        scope.register(BarrierOps::ArriveAndWait { barrier });
    }
}