cubecl_core/frontend/
barrier.rs

//! This module exposes barriers for asynchronous data transfers.
2
3use std::marker::PhantomData;
4
5use cubecl_ir::{ExpandElement, Instruction};
6use paste::paste;
7
8use crate::{
9    ir::{BarrierOps, Item, Scope},
10    unexpanded,
11};
12
13use super::{
14    CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, Init, Line, Slice, SliceMut, TensorMap,
15};
16
/// A mechanism for awaiting on asynchronous data transfers
/// Behaviour is defined by its [BarrierLevel](BarrierLevel).
#[derive(Clone, Copy)]
pub struct Barrier<C: CubePrimitive> {
    // Zero-sized marker tying this barrier to the element type it transfers;
    // the struct carries no runtime state (the real barrier lives in the IR).
    _c: PhantomData<C>,
}
23
impl<C: CubePrimitive> CubeType for Barrier<C> {
    // During kernel expansion, a `Barrier` is represented by `BarrierExpand`.
    type ExpandType = BarrierExpand<C>;
}
27
impl<C: CubePrimitive> Init for BarrierExpand<C> {
    /// A barrier handle needs no extra IR when (re)entering a scope; return it unchanged.
    fn init(self, _scope: &mut Scope) -> Self {
        self
    }
}
33
impl<C: CubePrimitive> CubeDebug for BarrierExpand<C> {
    /// Attach a human-readable name to the underlying IR variable for debug output.
    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
        scope.update_variable_name(*self.elem, name);
    }
}
39
#[derive(Clone)]
/// Expand type of [Barrier]
pub struct BarrierExpand<C: CubePrimitive> {
    // IR handle of the barrier variable created in the kernel scope.
    elem: ExpandElement,
    // Marker for the element type transferred through this barrier.
    _c: PhantomData<C>,
}
46
#[derive(Copy, Clone, PartialEq, Eq)]
/// Comptime wrapper around the synchronization level of a [Barrier];
/// construct via [BarrierLevel::unit], [BarrierLevel::cube_coop] or [BarrierLevel::cube_manual].
pub struct BarrierLevel(InnerBarrierLevel);
49
impl CubeType for BarrierLevel {
    // A barrier level is a pure comptime value: it expands to itself.
    type ExpandType = Self;
}
53
impl Init for BarrierLevel {
    /// Comptime value; nothing to register in the scope.
    fn init(self, _scope: &mut Scope) -> Self {
        self
    }
}
59
impl CubeDebug for BarrierLevel {
    /// Comptime value with no backing IR variable, so there is nothing to name.
    fn set_debug_name(&self, _scope: &mut Scope, _name: &'static str) {}
}
63
#[derive(Copy, Clone, Eq, PartialEq)]
/// Defines how many units must reach the barrier before execution can continue.
/// This also determines how `memcpy_async` operations should be handled.
enum InnerBarrierLevel {
    /// Waits only for the unit that declared this barrier.
    /// Useful for synchronizing after async data loading.
    Unit,

    /// All units in the Cube must reach the barrier before continuing.
    /// The argument is the ID of the unit elected for initialization.
    ///
    /// `memcpy_async` is **cooperative**, so all units in the Cube must call `memcpy_async` with the same arguments.
    /// The caller is not elected by default, so election must be done manually if wanted.
    CubeCoop(u32),

    /// All units in the Cube must reach the barrier before continuing.
    /// The argument is the ID of the unit elected for initialization.
    ///
    /// `memcpy_async` is **not cooperative**, so each unit must manually handle its own data slice.
    CubeManual(u32),
}
85
86impl BarrierLevel {
87    /// Creates a Unit barrier level
88    pub fn unit() -> Self {
89        BarrierLevel(InnerBarrierLevel::Unit)
90    }
91
92    /// Creates a CubeCoop barrier level
93    ///
94    /// Will sync all units
95    pub fn cube_coop(elected_unit: u32) -> Self {
96        BarrierLevel(InnerBarrierLevel::CubeCoop(elected_unit))
97    }
98
99    /// Creates a CubeManual barrier level
100    ///
101    /// Will sync all units
102    pub fn cube_manual(elected_unit: u32) -> Self {
103        BarrierLevel(InnerBarrierLevel::CubeManual(elected_unit))
104    }
105
106    pub fn __expand_unit(_scope: &mut Scope) -> BarrierLevel {
107        BarrierLevel(InnerBarrierLevel::Unit)
108    }
109
110    pub fn __expand_cube_coop(_scope: &mut Scope, elected_unit: u32) -> Self {
111        BarrierLevel(InnerBarrierLevel::CubeCoop(elected_unit))
112    }
113
114    pub fn __expand_cube_manual(_scope: &mut Scope, elected_unit: u32) -> Self {
115        BarrierLevel(InnerBarrierLevel::CubeManual(elected_unit))
116    }
117}
118
119impl From<InnerBarrierLevel> for cubecl_ir::BarrierLevel {
120    fn from(val: InnerBarrierLevel) -> Self {
121        match val {
122            InnerBarrierLevel::Unit => cubecl_ir::BarrierLevel::Unit,
123            InnerBarrierLevel::CubeCoop(elected_unit) => {
124                cubecl_ir::BarrierLevel::CubeCoop(elected_unit)
125            }
126            InnerBarrierLevel::CubeManual(elected_unit) => {
127                cubecl_ir::BarrierLevel::CubeManual(elected_unit)
128            }
129        }
130    }
131}
132
/// Generates, for a given rank `$dim`, the `tma_load_{dim}d` method on [Barrier]
/// together with its `__expand_*` functions on [Barrier] and [BarrierExpand].
/// Each `$arg` names one `i32` coordinate (outermost dimension first); the
/// expansion registers a [BarrierOps::TmaLoad] instruction writing into the
/// shared-memory destination slice.
macro_rules! tensor_map_load {
    ($dim: literal, $($arg: expr),*) => {
        paste! {
            impl<C: CubePrimitive> Barrier<C> {
                /// Copy a tile from a global memory `source` to a shared memory `destination`, with
                /// the provided offsets.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_ $dim d>](
                    &self,
                    source: &TensorMap<C>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32),*
                ) {
                    // Placeholder: only callable inside a cube kernel, where the
                    // expand function below replaces it.
                    unexpanded!()
                }

                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d>](
                    scope: &mut Scope,
                    expand: BarrierExpand<C>,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: ExpandElementTyped<SliceMut<Line<C>>>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    expand.[<__expand_tma_load_ $dim d_method>](scope, source, destination, $($arg),*);
                }
            }

            impl<C: CubePrimitive> BarrierExpand<C> {
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d_method>](
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: ExpandElementTyped<SliceMut<Line<C>>>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    let barrier = *self.elem;
                    let source = *source.expand;
                    let destination = *destination.expand;

                    let mem_copy = BarrierOps::TmaLoad {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                    };

                    // The load writes into `destination`, so it is the instruction's output.
                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}
186
/// Generates, for a given rank `$dim`, the `tma_load_im2col_{dim}d` method on
/// [Barrier] together with its `__expand_*` functions. Each `$arg` names one
/// `i32` tensor coordinate and each `$offset` one `u16` im2col kernel offset;
/// the expansion registers a [BarrierOps::TmaLoadIm2col] instruction writing
/// into the shared-memory destination slice.
macro_rules! tensor_map_load_im2col {
    ($dim: literal, $($arg: expr),*; $($offset: expr),*) => {
        paste! {
            impl<C: CubePrimitive> Barrier<C> {
                /// Copy a tile from a global memory `source` to a shared memory `destination`, with
                /// the provided offsets.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_im2col_ $dim d>](
                    &self,
                    source: &TensorMap<C>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32,)*
                    $($offset: u16),*
                ) {
                    // Placeholder: only callable inside a cube kernel, where the
                    // expand function below replaces it.
                    unexpanded!()
                }

                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d>](
                    scope: &mut Scope,
                    expand: BarrierExpand<C>,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: ExpandElementTyped<SliceMut<Line<C>>>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    expand.[<__expand_tma_load_im2col_ $dim d_method>](scope, source, destination, $($arg),*, $($offset),*);
                }
            }

            impl<C: CubePrimitive> BarrierExpand<C> {
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d_method>](
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: ExpandElementTyped<SliceMut<Line<C>>>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    let barrier = *self.elem;
                    let source = *source.expand;
                    let destination = *destination.expand;

                    let mem_copy = BarrierOps::TmaLoadIm2col {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offsets: vec![$(*$offset.expand),*],
                    };

                    // The load writes into `destination`, so it is the instruction's output.
                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}
244
// Tiled TMA loads for 2D..5D tensors; coordinates are listed outermost-first
// (e.g. `y, x` for 2D), matching the generated parameter order.
tensor_map_load!(2, y, x);
tensor_map_load!(3, z, y, x);
tensor_map_load!(4, w, z, y, x);
tensor_map_load!(5, v, w, z, y, x);

// im2col TMA loads for 3D..5D tensors; tensor coordinates first, then one
// kernel offset per spatial dimension.
tensor_map_load_im2col!(3, n, w, c; w_offset);
tensor_map_load_im2col!(4, n, h, w, c; h_offset, w_offset);
tensor_map_load_im2col!(5, n, d, h, w, c; d_offset, h_offset, w_offset);
253
254impl<C: CubePrimitive> Barrier<C> {
255    /// Creates a barrier using a user defined comptime barrier level
256    pub fn new(_level: BarrierLevel) -> Self {
257        Self { _c: PhantomData }
258    }
259
260    /// Creates a new barrier for use with TMA instructions. Adds a shared memory proxy barrier to
261    /// the initialization.
262    pub fn new_with_tma_proxy(_level: BarrierLevel) -> Self {
263        Self { _c: PhantomData }
264    }
265
266    /// Copy the source slice to destination
267    ///
268    /// # Safety
269    ///
270    /// This will try to copy the whole source slice, so
271    /// make sure source length <= destination length
272    pub fn memcpy_async(&self, _source: &Slice<Line<C>>, _destination: &mut SliceMut<Line<C>>) {
273        unexpanded!()
274    }
275
276    /// Arrive at the barrier, decrementing arrival count
277    pub fn arrive(&self) {
278        unexpanded!()
279    }
280
281    /// Arrive at the barrier, decrementing arrival count. Additionally increments expected count.
282    pub fn arrive_tx(&self, _arrival_count: u32, _transaction_count: u32) {
283        unexpanded!()
284    }
285
286    /// Increments the expected count of the barrier.
287    pub fn expect_tx(&self, _expected_count: u32) {
288        unexpanded!()
289    }
290
291    /// Wait at the barrier until all arrivals are done
292    pub fn wait(&self) {
293        unexpanded!()
294    }
295
296    /// Wait until all data is loaded
297    pub fn arrive_and_wait(&self) {
298        unexpanded!()
299    }
300
301    pub fn __expand_new(scope: &mut Scope, level: BarrierLevel) -> BarrierExpand<C> {
302        let elem = C::as_elem(scope);
303
304        let variable = scope.create_barrier(Item::new(elem), level.0.into());
305        scope.register(BarrierOps::Init {
306            barrier: *variable,
307            with_cta_fence: false,
308        });
309        BarrierExpand {
310            elem: variable,
311            _c: PhantomData,
312        }
313    }
314
315    pub fn __expand_new_with_tma_proxy(scope: &mut Scope, level: BarrierLevel) -> BarrierExpand<C> {
316        let elem = C::as_elem(scope);
317
318        let variable = scope.create_barrier(Item::new(elem), level.0.into());
319        scope.register(BarrierOps::Init {
320            barrier: *variable,
321            with_cta_fence: true,
322        });
323        BarrierExpand {
324            elem: variable,
325            _c: PhantomData,
326        }
327    }
328
329    pub fn __expand_memcpy_async(
330        scope: &mut Scope,
331        expand: BarrierExpand<C>,
332        source: ExpandElementTyped<Slice<Line<C>>>,
333        destination: ExpandElementTyped<SliceMut<Line<C>>>,
334    ) {
335        expand.__expand_memcpy_async_method(scope, source, destination);
336    }
337
338    pub fn __expand_arrive(scope: &mut Scope, expand: BarrierExpand<C>) {
339        expand.__expand_arrive_method(scope);
340    }
341
342    pub fn __expand_arrive_tx(
343        scope: &mut Scope,
344        expand: BarrierExpand<C>,
345        arrival_count: ExpandElementTyped<u32>,
346        transaction_count: ExpandElementTyped<u32>,
347    ) {
348        expand.__expand_arrive_tx_method(scope, arrival_count, transaction_count);
349    }
350
351    pub fn __expand_expect_tx(
352        scope: &mut Scope,
353        expand: BarrierExpand<C>,
354        expected_count: ExpandElementTyped<u32>,
355    ) {
356        expand.__expand_expect_tx_method(scope, expected_count);
357    }
358
359    pub fn __expand_wait(scope: &mut Scope, expand: BarrierExpand<C>) {
360        expand.__expand_wait_method(scope);
361    }
362
363    pub fn __expand_arrive_and_wait(scope: &mut Scope, expand: BarrierExpand<C>) {
364        expand.__expand_arrive_and_wait_method(scope);
365    }
366}
367
368impl<C: CubePrimitive> BarrierExpand<C> {
369    pub fn __expand_memcpy_async_method(
370        &self,
371        scope: &mut Scope,
372        source: ExpandElementTyped<Slice<Line<C>>>,
373        destination: ExpandElementTyped<SliceMut<Line<C>>>,
374    ) {
375        let barrier = *self.elem;
376        let source = *source.expand;
377        let destination = *destination.expand;
378
379        let mem_copy = BarrierOps::MemCopyAsync { barrier, source };
380
381        scope.register(Instruction::new(mem_copy, destination));
382    }
383
384    pub fn __expand_arrive_method(&self, scope: &mut Scope) {
385        let barrier = *self.elem;
386        scope.register(BarrierOps::Arrive { barrier });
387    }
388
389    pub fn __expand_arrive_tx_method(
390        &self,
391        scope: &mut Scope,
392        arrival_count: ExpandElementTyped<u32>,
393        transaction_count: ExpandElementTyped<u32>,
394    ) {
395        let barrier = *self.elem;
396        let arrival_count: ExpandElement = arrival_count.into();
397        let transaction_count: ExpandElement = transaction_count.into();
398        scope.register(BarrierOps::ArriveTx {
399            barrier,
400            arrive_count_update: arrival_count.consume(),
401            transaction_count_update: transaction_count.consume(),
402        });
403    }
404
405    pub fn __expand_expect_tx_method(
406        &self,
407        scope: &mut Scope,
408        transaction_count: ExpandElementTyped<u32>,
409    ) {
410        let barrier = *self.elem;
411        let transaction_count: ExpandElement = transaction_count.into();
412        scope.register(BarrierOps::ExpectTx {
413            barrier,
414            transaction_count_update: transaction_count.consume(),
415        });
416    }
417
418    pub fn __expand_wait_method(&self, scope: &mut Scope) {
419        let barrier = *self.elem;
420        scope.register(BarrierOps::Wait { barrier });
421    }
422
423    pub fn __expand_arrive_and_wait_method(&self, scope: &mut Scope) {
424        let barrier = *self.elem;
425        scope.register(BarrierOps::ArriveAndWait { barrier });
426    }
427}