// cubecl_core/frontend/barrier.rs
//! This module exposes barrier for asynchronous data transfer

use alloc::vec;
use core::ops::{Deref, DerefMut};

use crate as cubecl;
use cubecl_ir::{Instruction, ManagedVariable, OpaqueType};
use cubecl_macros::intrinsic;
use paste::paste;

use crate::{
    ir::{BarrierOps, Scope},
    prelude::*,
    unexpanded,
};

use super::{
    CubePrimitive, CubeType, NativeExpand, ReadOnly, ReadWrite, Slice, SliceExpand, SliceMut,
    TensorMap,
};
21
/// A mechanism for awaiting on asynchronous data transfers
/// Behavior is defined by its ``BarrierLevel``.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct Barrier;
/// Compile-time (expand) counterpart of [`Barrier`], wrapping its IR variable.
pub type BarrierExpand = NativeExpand<Barrier>;
27
/// Token produced by `arrive`-style operations, used to later `wait` on the
/// corresponding barrier phase.
// `Eq` is derived for consistency with `Barrier`: equality on a unit struct is
// trivially total, so this is purely additive.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct BarrierToken;
30
// Ties `Barrier` to its expand-side representation used during kernel expansion.
impl CubeType for Barrier {
    type ExpandType = NativeExpand<Barrier>;
}
34
impl CubePrimitive for Barrier {
    type Scalar = u32; // Dummy, maybe we need another trait for non-standard primitives
    type Size = Const<1>;
    type WithScalar<S: Scalar> = S;
    // Barriers are opaque runtime handles; they can never be materialized from
    // a compile-time constant.
    fn from_const_value(_value: cubecl_ir::ConstantValue) -> Self {
        unreachable!("Can't create from const value")
    }
}
43
impl NativeAssign for Barrier {
    // Barriers require no per-element initialization: the variable is returned
    // untouched.
    fn elem_init_mut(_scope: &mut Scope, elem: ManagedVariable) -> ManagedVariable {
        elem
    }
}
49
// Ties `BarrierToken` to its expand-side representation.
impl CubeType for BarrierToken {
    type ExpandType = NativeExpand<BarrierToken>;
}
53
54impl NativeAssign for BarrierToken {
55    fn elem_init_mut(_scope: &mut crate::ir::Scope, elem: ManagedVariable) -> ManagedVariable {
56        elem
57    }
58}
59
// Generates the `tma_load_<dim>d` method family on `Barrier`, plus the
// `__expand_*` functions used by the `#[cube]` macro. `$dim` is the tensor
// rank and each `$arg` names one `i32` coordinate, in the order given at the
// invocation site.
macro_rules! tensor_map_load {
    ($dim: literal, $($arg: expr),*) => {
        paste! {
            impl Barrier {
                /// Copy a tile from a global memory `source` to a shared memory `destination`, with
                /// the provided offsets.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_ $dim d>]<C1: CubePrimitive, C2: CubePrimitive<Scalar = C1::Scalar>>(
                    &self,
                    source: &TensorMap<C1, Tiled>,
                    destination: &mut SliceMut<C2>,
                    $($arg: i32),*
                ) {
                    // Runtime placeholder: the real behavior is produced by the
                    // expansion below when used inside `#[cube]` code.
                    unexpanded!()
                }

                // Free-function expansion entry point expected by the `#[cube]`
                // macro; delegates to the method-style expansion.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d>]<C1: CubePrimitive, C2: CubePrimitive<Scalar = C1::Scalar>>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: NativeExpand<TensorMap<C1, Tiled>>,
                    destination: SliceExpand<C2, ReadWrite>,
                    $($arg: NativeExpand<i32>),*
                ) {
                    expand.[<__expand_tma_load_ $dim d_method>](scope, source, destination, $($arg),*);
                }
            }

            impl BarrierExpand {
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d_method>]<C1: CubePrimitive, C2: CubePrimitive<Scalar = C1::Scalar>>(
                    &self,
                    scope: &mut Scope,
                    source: NativeExpand<TensorMap<C1, Tiled>>,
                    destination: SliceExpand<C2, ReadWrite>,
                    $($arg: NativeExpand<i32>),*
                ) {
                    // Unwrap the expand wrappers down to raw IR variables.
                    let barrier = *self.expand;
                    let source = *source.expand;
                    let (destination, destination_offset) = destination.__to_raw_parts();

                    // Register a TMA load instruction writing into `destination`.
                    let mem_copy = BarrierOps::TmaLoad {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offset_out: destination_offset
                    };

                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}
114
// Generates the `tma_load_im2col_<dim>d` method family on `Barrier` plus its
// `__expand_*` counterparts. `$arg` names the `i32` tensor coordinates and
// `$offset` names the `u16` im2col kernel offsets.
macro_rules! tensor_map_load_im2col {
    ($dim: literal, $($arg: expr),*; $($offset: expr),*) => {
        paste! {
            impl Barrier {
                /// Copy a tile from a global memory `source` to a shared memory `destination`, with
                /// the provided offsets.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_im2col_ $dim d>]<C1: CubePrimitive, C2: CubePrimitive<Scalar = C1::Scalar>>(
                    &self,
                    source: &TensorMap<C1, Im2col>,
                    destination: &mut SliceMut<C2>,
                    $($arg: i32,)*
                    $($offset: u16),*
                ) {
                    // Runtime placeholder: the real behavior is produced by the
                    // expansion below when used inside `#[cube]` code.
                    unexpanded!()
                }

                // Free-function expansion entry point expected by the `#[cube]`
                // macro; delegates to the method-style expansion.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d>]<C1: CubePrimitive, C2: CubePrimitive<Scalar = C1::Scalar>>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: NativeExpand<TensorMap<C1, Im2col>>,
                    destination: SliceExpand<C2, ReadWrite>,
                    $($arg: NativeExpand<i32>,)*
                    $($offset: NativeExpand<u16>),*
                ) {
                    expand.[<__expand_tma_load_im2col_ $dim d_method>](scope, source, destination, $($arg),*, $($offset),*);
                }
            }

            impl BarrierExpand {
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d_method>]<C1: CubePrimitive, C2: CubePrimitive<Scalar = C1::Scalar>>(
                    &self,
                    scope: &mut Scope,
                    source: NativeExpand<TensorMap<C1, Im2col>>,
                    destination: SliceExpand<C2, ReadWrite>,
                    $($arg: NativeExpand<i32>,)*
                    $($offset: NativeExpand<u16>),*
                ) {
                    // Unwrap the expand wrappers down to raw IR variables.
                    let barrier = *self.expand;
                    let source = *source.expand;
                    let (destination, destination_offset) = destination.__to_raw_parts();

                    let mem_copy = BarrierOps::TmaLoadIm2col {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offsets: vec![$(*$offset.expand),*],
                        offset_out: destination_offset,
                    };

                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}
173
// Instantiate TMA tiled loads for ranks 1 through 5. Coordinates are named
// outermost-to-innermost (e.g. `z, y, x`).
tensor_map_load!(1, x);
tensor_map_load!(2, y, x);
tensor_map_load!(3, z, y, x);
tensor_map_load!(4, w, z, y, x);
tensor_map_load!(5, v, w, z, y, x);

// Instantiate im2col loads for ranks 3 through 5; the extra `u16` offsets
// cover the spatial kernel dimensions (`d`/`h`/`w`).
tensor_map_load_im2col!(3, n, w, c; w_offset);
tensor_map_load_im2col!(4, n, h, w, c; h_offset, w_offset);
tensor_map_load_im2col!(5, n, d, h, w, c; d_offset, h_offset, w_offset);
183
#[cube(self_type = "ref")]
impl Barrier {
    /// Create a local barrier object for the current unit. Automatically initialized with an
    /// arrival count of `1`.
    pub fn local() -> Self {
        intrinsic!(|scope| {
            // Unit-level barrier: lives in a mutable local variable.
            let variable =
                scope.create_local_mut(OpaqueType::Barrier(cubecl_ir::BarrierLevel::Unit));
            scope.register(BarrierOps::Init {
                barrier: *variable,
                is_elected: true.into(),
                arrival_count: 1.into(),
            });
            variable.into()
        })
    }

    /// Create a shared memory barrier that can be accessed by all units in the cube. Initialized
    /// by the `is_elected` unit with an arrival count of `arrival_count`. This is the number of
    /// times `arrive` or one of its variants needs to be called before the barrier advances.
    ///
    /// If all units in the cube arrive on the barrier, use `CUBE_DIM` as the arrival count. For
    /// other purposes, only a subset may need to arrive.
    #[allow(unused_variables)]
    pub fn shared(arrival_count: u32, is_elected: bool) -> Shared<Barrier> {
        intrinsic!(|scope| {
            let variable = scope.create_shared(OpaqueType::Barrier(cubecl_ir::BarrierLevel::Cube));
            scope.register(BarrierOps::Init {
                barrier: *variable,
                is_elected: *is_elected.expand,
                arrival_count: *arrival_count.expand,
            });
            variable.into()
        })
    }

    /// Create a shared memory barrier that can be accessed by all units in the cube. Only declared,
    /// but not initialized.
    pub fn shared_uninit() -> Shared<Barrier> {
        intrinsic!(|scope| {
            // Declare only; initialization must follow via `init_manual`.
            let variable = scope.create_shared(OpaqueType::Barrier(cubecl_ir::BarrierLevel::Cube));
            scope.register(BarrierOps::Declare { barrier: *variable });
            variable.into()
        })
    }

    /// Initializes a barrier with a given `arrival_count`. This is the number of
    /// times `arrive` or one of its variants needs to be called before the barrier advances.
    ///
    /// If all units in the cube arrive on the barrier, use `CUBE_DIM` as the arrival count. For
    /// other purposes, only a subset may need to arrive.
    ///
    /// # Note
    ///
    /// No synchronization or election is performed, this is raw initialization. For shared barriers
    /// ensure only one unit performs the initialization, and synchronize the cube afterwards. There
    /// may also be additional synchronization requirements for bulk copy operations, like
    /// [`sync_async_proxy_shared()`].
    #[allow(unused_variables)]
    pub fn init_manual(&self, arrival_count: u32) {
        intrinsic!(|scope| {
            // NOTE(review): unlike the other methods this clones the expand
            // before dereferencing — presumably equivalent to `*self.expand`;
            // confirm whether the clone is needed.
            let barrier = *self.expand.clone();

            scope.register(BarrierOps::InitManual {
                barrier,
                arrival_count: *arrival_count.expand,
            });
        })
    }
}
254
// MemcpyAsync

#[cube(self_type = "ref")]
impl Barrier {
    /// Copy the source slice to destination
    ///
    /// # Safety
    ///
    /// This will try to copy the whole source slice, so
    /// make sure source length <= destination length
    #[allow(unused_variables)]
    pub fn memcpy_async<C: CubePrimitive>(&self, source: &Slice<C>, destination: &mut SliceMut<C>) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            // Capture the slice length before the slice is decomposed below.
            let source_length = *source.length.expand;
            let (source, source_offset) = source.__to_raw_parts();
            let (destination, destination_offset) = destination.__to_raw_parts();

            let mem_copy = BarrierOps::MemCopyAsync {
                barrier,
                source,
                source_length,
                offset_source: source_offset,
                offset_out: destination_offset,
            };

            scope.register(Instruction::new(mem_copy, destination));
        })
    }

    /// Copy the source slice to destination
    ///
    /// # Safety
    ///
    /// This will try to copy the whole source slice, so
    /// make sure source length <= destination length
    #[allow(unused_variables)]
    pub fn memcpy_async_cooperative<C: CubePrimitive>(
        &self,
        source: &Slice<C>,
        destination: &mut SliceMut<C>,
    ) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let source_length = *source.length.expand;
            let (source, source_offset) = source.__to_raw_parts();
            let (destination, destination_offset) = destination.__to_raw_parts();

            // Same as `memcpy_async`, but the cooperative IR variant.
            let mem_copy = BarrierOps::MemCopyAsyncCooperative {
                barrier,
                source,
                source_length,
                offset_source: source_offset,
                offset_out: destination_offset,
            };

            scope.register(Instruction::new(mem_copy, destination));
        })
    }

    /// Copy the source slice to destination. Uses transaction count like TMA, so use with
    /// `expect_tx` or `arrive_and_expect_tx`.
    ///
    /// # Safety
    ///
    /// This will try to copy the whole source slice, so
    /// make sure source length <= destination length
    #[allow(unused_variables)]
    pub fn memcpy_async_tx<C: CubePrimitive>(
        &self,
        source: &Slice<C>,
        destination: &mut SliceMut<C>,
    ) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let source_length = *source.length.expand;
            let (source, source_offset) = source.__to_raw_parts();
            let (destination, destination_offset) = destination.__to_raw_parts();

            // Transaction-counting IR variant; pairs with `expect_tx`.
            let mem_copy = BarrierOps::MemCopyAsyncTx {
                barrier,
                source,
                source_length,
                offset_source: source_offset,
                offset_out: destination_offset,
            };

            scope.register(Instruction::new(mem_copy, destination));
        })
    }
}
346
// Arrival and Wait

#[cube(self_type = "ref")]
impl Barrier {
    /// Arrive at the barrier, decrementing arrival count
    pub fn arrive(&self) -> BarrierToken {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            // Recover the barrier level from the variable's opaque type; any
            // other storage type here would be a compiler bug.
            let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type() else {
                unreachable!()
            };
            let token = scope.create_barrier_token(barrier.index().unwrap(), level);
            scope.register(Instruction::new(BarrierOps::Arrive { barrier }, *token));
            token.into()
        })
    }

    /// Arrive at the barrier, decrementing arrival count. Additionally increments expected count.
    #[allow(unused_variables)]
    pub fn arrive_and_expect_tx(&self, arrival_count: u32, transaction_count: u32) -> BarrierToken {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type() else {
                unreachable!()
            };
            let token = scope.create_barrier_token(barrier.index().unwrap(), level);
            // Wrap the counts as managed IR variables and consume them into the op.
            let arrival_count: ManagedVariable = arrival_count.into();
            let transaction_count: ManagedVariable = transaction_count.into();
            scope.register(Instruction::new(
                BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update: arrival_count.consume(),
                    transaction_count_update: transaction_count.consume(),
                },
                *token,
            ));
            token.into()
        })
    }

    /// Increments the expected count of the barrier.
    #[allow(unused_variables)]
    pub fn expect_tx(&self, expected_count: u32) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let transaction_count: ManagedVariable = expected_count.into();
            scope.register(BarrierOps::ExpectTx {
                barrier,
                transaction_count_update: transaction_count.consume(),
            });
        })
    }

    /// Wait until all data is loaded
    pub fn arrive_and_wait(&self) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            scope.register(BarrierOps::ArriveAndWait { barrier });
        })
    }

    /// Wait at the barrier until all arrivals are done
    #[allow(unused_variables)]
    pub fn wait(&self, token: BarrierToken) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let token = *token.expand;
            scope.register(BarrierOps::Wait { barrier, token });
        })
    }

    /// Wait at the barrier until the `phase` is completed. Doesn't require a token, but needs phase
    /// to be managed manually.
    #[allow(unused_variables)]
    pub fn wait_parity(&self, phase: u32) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let phase = *phase.expand;
            scope.register(BarrierOps::WaitParity { barrier, phase });
        })
    }
}
429
// Copy async

/// Copy the source slice in global memory to destination in shared memory with a low level async
/// copy. This only copies up to 128 bits/16 bytes, and does not synchronize. Use
/// `barrier.copy_async_arrive` to make the reads visible.
/// `copy_size` is in terms of elements to simplify copying between different vector sizes.
///
/// # Safety
///
/// This will try to copy the entire `copy_size`, so make sure the full width is in bounds.
/// Starting address must be aligned to the full copy size.
pub fn copy_async<C: CubePrimitive>(
    _source: &Slice<C>,
    _destination: &mut SliceMut<C>,
    _copy_size: u32,
) {
    // Runtime placeholder; the `copy_async::expand` function below supplies the
    // `#[cube]` expansion.
    unexpanded!()
}
448
pub mod copy_async {
    use super::*;

    /// Expansion of [`copy_async`]: registers the raw, unchecked `CopyAsync`
    /// IR instruction.
    pub fn expand<C: CubePrimitive>(
        scope: &mut Scope,
        source: SliceExpand<C, ReadOnly>,
        destination: SliceExpand<C, ReadWrite>,
        copy_length: u32,
    ) {
        // Unchecked variant: the source length is taken to be the full copy
        // length (no clamping to the actual slice length).
        let source_length = copy_length.into();
        let (source, source_offset) = source.__to_raw_parts();
        let (destination, destination_offset) = destination.__to_raw_parts();
        let scalar_size = C::as_type(scope).storage_type().size();

        let mem_copy = BarrierOps::CopyAsync {
            source,
            source_length,
            offset_source: source_offset,
            offset_out: destination_offset,
            // `copy_length` is in elements; multiply by the scalar size to get bytes.
            copy_length: copy_length * scalar_size as u32,
            checked: false,
        };

        scope.register(Instruction::new(mem_copy, destination));
    }
}
475
/// Copy the source slice in global memory to destination in shared memory with a low level async
/// copy. This only copies up to 128 bits/16 bytes, and does not synchronize. Use
/// `barrier.copy_async_arrive` to make the reads visible.
/// `copy_size` is in terms of elements to simplify copying between different vector sizes.
///
/// Will only copy the length of the source slice, and zero fill the rest. Source length must be
/// <= copy size.
///
/// # Safety
/// Starting address must be aligned to the full copy size.
/// **This will silently fail if the address is only aligned to the source length and not the copy size!**
pub fn copy_async_checked<C: CubePrimitive>(
    _source: &Slice<C>,
    _destination: &mut SliceMut<C>,
    _copy_size: u32,
) {
    // Runtime placeholder; the `copy_async_checked::expand` function below
    // supplies the `#[cube]` expansion.
    unexpanded!();
}
494
pub mod copy_async_checked {
    use super::*;

    /// Expansion of [`copy_async_checked`]: registers a `CopyAsync` IR
    /// instruction with bounds checking enabled.
    pub fn expand<C: CubePrimitive>(
        scope: &mut Scope,
        source: SliceExpand<C, ReadOnly>,
        destination: SliceExpand<C, ReadWrite>,
        copy_length: u32,
    ) {
        // Checked variant: the real slice length is passed so the backend can
        // clamp the copy and zero-fill the remainder.
        let source_length = *source.length.expand;
        let (source, source_offset) = source.__to_raw_parts();
        let (destination, destination_offset) = destination.__to_raw_parts();
        let scalar_size = C::as_type(scope).storage_type().size();

        let mem_copy = BarrierOps::CopyAsync {
            source,
            source_length,
            offset_source: source_offset,
            offset_out: destination_offset,
            // `copy_length` is in elements; multiply by the scalar size to get bytes.
            copy_length: copy_length * scalar_size as u32,
            checked: true,
        };

        scope.register(Instruction::new(mem_copy, destination));
    }
}
521
#[cube(self_type = "ref")]
impl Barrier {
    /// Makes all previous `copy_async` operations visible on the barrier.
    /// Should be called once after all copies have been dispatched, before reading from the shared
    /// memory.
    ///
    /// Does *not* count as an arrive in terms of the barrier arrival count. So `arrive` or
    /// `arrive_and_wait` should still be called afterwards.
    pub fn commit_copy_async(&self) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type() else {
                unreachable!()
            };
            // NOTE(review): a token is created as the instruction's output but
            // never returned to the caller — presumably the IR requires an
            // output variable here; confirm this is intentional.
            let token = scope.create_barrier_token(barrier.index().unwrap(), level);
            scope.register(Instruction::new(
                BarrierOps::CommitCopyAsync { barrier },
                *token,
            ));
        })
    }
}
544
// `Shared<Barrier>` is used as if it were a `Barrier`; the runtime deref is a
// placeholder and is only meaningful through the expand-side impl below.
impl Deref for Shared<Barrier> {
    type Target = Barrier;

    fn deref(&self) -> &Self::Target {
        unexpanded!()
    }
}
impl Deref for SharedExpand<Barrier> {
    type Target = BarrierExpand;

    fn deref(&self) -> &Self::Target {
        // SAFETY: presumably `SharedExpand<Barrier>` and `NativeExpand<Barrier>`
        // share representation per `as_type_ref_unchecked`'s contract — TODO
        // confirm against its declaration.
        unsafe { self.as_type_ref_unchecked::<Barrier>() }
    }
}
559
560impl DerefMut for Shared<Barrier> {
561    fn deref_mut(&mut self) -> &mut Self::Target {
562        todo!()
563    }
564}
565impl DerefMut for SharedExpand<Barrier> {
566    fn deref_mut(&mut self) -> &mut Self::Target {
567        unsafe { self.as_type_mut_unchecked::<Barrier>() }
568    }
569}
570
impl From<SharedExpand<Barrier>> for BarrierExpand {
    /// Convert a shared-barrier expand handle into a plain barrier expand by
    /// reusing its underlying expand value.
    fn from(value: SharedExpand<Barrier>) -> Self {
        value.expand.into()
    }
}