// cubecl_core/frontend/barrier.rs

//! This module exposes barriers for asynchronous data transfers.

use std::ops::{Deref, DerefMut};

use crate as cubecl;
use cubecl_ir::{ExpandElement, Instruction, OpaqueType};
use cubecl_macros::intrinsic;
use paste::paste;

use crate::{
    ir::{BarrierOps, Scope},
    prelude::*,
    unexpanded,
};

use super::{
    CubePrimitive, CubeType, ExpandElementTyped, Line, ReadOnly, ReadWrite, Slice, SliceExpand,
    SliceMut, TensorMap,
};

/// A mechanism for awaiting asynchronous data transfers.
/// Its behaviour is defined by its [BarrierLevel](BarrierLevel).
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct Barrier;
pub type BarrierExpand = ExpandElementTyped<Barrier>;

#[derive(Clone, Copy, PartialEq)]
pub struct BarrierToken;

impl CubeType for Barrier {
    type ExpandType = ExpandElementTyped<Barrier>;
}

impl CubePrimitive for Barrier {
    fn from_const_value(_value: cubecl_ir::ConstantScalarValue) -> Self {
        unreachable!("Can't create from const value")
    }
}

impl ExpandElementIntoMut for Barrier {
    fn elem_into_mut(_scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
        elem
    }
}

impl CubeType for BarrierToken {
    type ExpandType = ExpandElementTyped<BarrierToken>;
}

impl ExpandElementIntoMut for BarrierToken {
    fn elem_into_mut(_scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
        elem
    }
}

macro_rules! tensor_map_load {
    ($dim: literal, $($arg: expr),*) => {
        paste! {
            impl Barrier {
                /// Copy a tile from a global memory `source` to a shared memory `destination`, with
                /// the provided offsets.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_ $dim d>]<C: CubePrimitive>(
                    &self,
                    source: &TensorMap<C>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32),*
                ) {
                    unexpanded!()
                }

                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d>]<C: CubePrimitive>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    expand.[<__expand_tma_load_ $dim d_method>](scope, source, destination, $($arg),*);
                }
            }

            impl BarrierExpand {
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d_method>]<C: CubePrimitive>(
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    let barrier = *self.expand;
                    let source = *source.expand;
                    let (destination, destination_offset) = destination.__to_raw_parts();

                    let mem_copy = BarrierOps::TmaLoad {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offset_out: destination_offset
                    };

                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}

macro_rules! tensor_map_load_im2col {
    ($dim: literal, $($arg: expr),*; $($offset: expr),*) => {
        paste! {
            impl Barrier {
                /// Copy a tile from a global memory `source` to a shared memory `destination`, with
                /// the provided offsets.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_im2col_ $dim d>]<C: CubePrimitive>(
                    &self,
                    source: &TensorMap<C>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32,)*
                    $($offset: u16),*
                ) {
                    unexpanded!()
                }

                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d>]<C: CubePrimitive>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    expand.[<__expand_tma_load_im2col_ $dim d_method>](scope, source, destination, $($arg),*, $($offset),*);
                }
            }

            impl BarrierExpand {
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d_method>]<C: CubePrimitive>(
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    let barrier = *self.expand;
                    let source = *source.expand;
                    let (destination, destination_offset) = destination.__to_raw_parts();

                    let mem_copy = BarrierOps::TmaLoadIm2col {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offsets: vec![$(*$offset.expand),*],
                        offset_out: destination_offset,
                    };

                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}

tensor_map_load!(1, x);
tensor_map_load!(2, y, x);
tensor_map_load!(3, z, y, x);
tensor_map_load!(4, w, z, y, x);
tensor_map_load!(5, v, w, z, y, x);

tensor_map_load_im2col!(3, n, w, c; w_offset);
tensor_map_load_im2col!(4, n, h, w, c; h_offset, w_offset);
tensor_map_load_im2col!(5, n, d, h, w, c; d_offset, h_offset, w_offset);
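
// Illustrative sketch (not part of the API): calling one of the generated `tma_load_*` methods
// from a `#[cube]` kernel. `map`, `stage` (a `SliceMut<Line<f32>>` backed by shared memory) and
// `expected_bytes` are hypothetical names; whether the transaction count is expressed in bytes is
// an assumption borrowed from CUDA's mbarrier semantics, not something this module specifies.
//
//     // Copy the tile at coordinates (y, x) of `map: TensorMap<f32>` into shared memory.
//     barrier.tma_load_2d(&map, &mut stage, y, x);
//     // TMA transfers update the barrier's transaction count, so arrive with an expected
//     // transaction count and wait on the returned token.
//     let token = barrier.arrive_and_expect_tx(1, expected_bytes);
//     barrier.wait(token);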

#[cube(self_type = "ref")]
impl Barrier {
    /// Create a local barrier object for the current unit. Automatically initialized with an
    /// arrival count of `1`.
    pub fn local() -> Self {
        intrinsic!(|scope| {
            let variable =
                scope.create_local_mut(OpaqueType::Barrier(cubecl_ir::BarrierLevel::Unit));
            scope.register(BarrierOps::Init {
                barrier: *variable,
                is_elected: true.into(),
                arrival_count: 1.into(),
            });
            variable.into()
        })
    }

    /// Create a shared memory barrier that can be accessed by all units in the cube. Initialized
    /// by the `is_elected` unit with an arrival count of `arrival_count`, which is the number of
    /// times `arrive` or one of its variants needs to be called before the barrier advances.
    ///
    /// If all units in the cube arrive on the barrier, use `CUBE_DIM` as the arrival count. For
    /// other purposes, only a subset may need to arrive.
    #[allow(unused_variables)]
    pub fn shared(arrival_count: u32, is_elected: bool) -> Shared<Barrier> {
        intrinsic!(|scope| {
            let variable = scope.create_shared(OpaqueType::Barrier(cubecl_ir::BarrierLevel::Cube));
            scope.register(BarrierOps::Init {
                barrier: *variable,
                is_elected: *is_elected.expand,
                arrival_count: *arrival_count.expand,
            });
            variable.into()
        })
    }

    /// Create a shared memory barrier that can be accessed by all units in the cube. Only declared,
    /// but not initialized.
    pub fn shared_uninit() -> Shared<Barrier> {
        intrinsic!(|scope| {
            let variable = scope.create_shared(OpaqueType::Barrier(cubecl_ir::BarrierLevel::Cube));
            scope.register(BarrierOps::Declare { barrier: *variable });
            variable.into()
        })
    }

    /// Initializes a barrier with a given `arrival_count`. This is the number of
    /// times `arrive` or one of its variants needs to be called before the barrier advances.
    ///
    /// If all units in the cube arrive on the barrier, use `CUBE_DIM` as the arrival count. For
    /// other purposes, only a subset may need to arrive.
    ///
    /// # Note
    ///
    /// No synchronization or election is performed; this is raw initialization. For shared barriers,
    /// ensure only one unit performs the initialization, and synchronize the cube afterwards. There
    /// may also be additional synchronization requirements for bulk copy operations, like
    /// [`sync_async_proxy_shared()`].
    #[allow(unused_variables)]
    pub fn init_manual(&self, arrival_count: u32) {
        intrinsic!(|scope| {
            let barrier = *self.expand.clone();

            scope.register(BarrierOps::InitManual {
                barrier,
                arrival_count: *arrival_count.expand,
            });
        })
    }
}
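
// Illustrative sketch (not part of the API): the two ways of setting up a cube-level barrier
// described above. `UNIT_POS` and `CUBE_DIM` are the usual cubecl launch constants; the explicit
// cube synchronization after manual initialization is left as a comment because the exact sync
// primitive to use depends on the caller.
//
//     // Elected initialization: unit 0 initializes, every unit arrives.
//     let barrier = Barrier::shared(CUBE_DIM, UNIT_POS == 0);
//
//     // Manual initialization: declare, initialize from a single unit, then synchronize the cube
//     // before any unit uses the barrier.
//     let barrier = Barrier::shared_uninit();
//     if UNIT_POS == 0 {
//         barrier.init_manual(CUBE_DIM);
//     }
//     // ...cube-wide synchronization here...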

// MemcpyAsync

#[cube(self_type = "ref")]
impl Barrier {
    /// Copy the source slice to the destination.
    ///
    /// # Safety
    ///
    /// This will try to copy the whole source slice, so
    /// make sure the source length is <= the destination length.
    #[allow(unused_variables)]
    pub fn memcpy_async<C: CubePrimitive>(
        &self,
        source: &Slice<Line<C>>,
        destination: &mut SliceMut<Line<C>>,
    ) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let source_length = *source.length.expand;
            let (source, source_offset) = source.__to_raw_parts();
            let (destination, destination_offset) = destination.__to_raw_parts();

            let mem_copy = BarrierOps::MemCopyAsync {
                barrier,
                source,
                source_length,
                offset_source: source_offset,
                offset_out: destination_offset,
            };

            scope.register(Instruction::new(mem_copy, destination));
        })
    }

    /// Copy the source slice to the destination.
    ///
    /// # Safety
    ///
    /// This will try to copy the whole source slice, so
    /// make sure the source length is <= the destination length.
    #[allow(unused_variables)]
    pub fn memcpy_async_cooperative<C: CubePrimitive>(
        &self,
        source: &Slice<Line<C>>,
        destination: &mut SliceMut<Line<C>>,
    ) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let source_length = *source.length.expand;
            let (source, source_offset) = source.__to_raw_parts();
            let (destination, destination_offset) = destination.__to_raw_parts();

            let mem_copy = BarrierOps::MemCopyAsyncCooperative {
                barrier,
                source,
                source_length,
                offset_source: source_offset,
                offset_out: destination_offset,
            };

            scope.register(Instruction::new(mem_copy, destination));
        })
    }

    /// Copy the source slice to the destination. Uses the transaction count like TMA, so use with
    /// `expect_tx` or `arrive_and_expect_tx`.
    ///
    /// # Safety
    ///
    /// This will try to copy the whole source slice, so
    /// make sure the source length is <= the destination length.
    #[allow(unused_variables)]
    pub fn memcpy_async_tx<C: CubePrimitive>(
        &self,
        source: &Slice<Line<C>>,
        destination: &mut SliceMut<Line<C>>,
    ) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let source_length = *source.length.expand;
            let (source, source_offset) = source.__to_raw_parts();
            let (destination, destination_offset) = destination.__to_raw_parts();

            let mem_copy = BarrierOps::MemCopyAsyncTx {
                barrier,
                source,
                source_length,
                offset_source: source_offset,
                offset_out: destination_offset,
            };

            scope.register(Instruction::new(mem_copy, destination));
        })
    }
}
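
// Illustrative sketch (not part of the API): a minimal `memcpy_async` round trip, assuming
// `src: Slice<Line<f32>>` points at global memory and `dst: SliceMut<Line<f32>>` at shared memory.
//
//     barrier.memcpy_async(&src, &mut dst);
//     // Block until every participating unit has arrived and the copy is visible.
//     barrier.arrive_and_wait();
//
// The `_tx` variant instead completes through the transaction count, so it must be paired with
// `expect_tx` or `arrive_and_expect_tx` (see the next impl block).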

// Arrival and Wait

#[cube(self_type = "ref")]
impl Barrier {
    /// Arrive at the barrier, decrementing the arrival count.
    pub fn arrive(&self) -> BarrierToken {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type() else {
                unreachable!()
            };
            let token = scope.create_barrier_token(barrier.index().unwrap(), level);
            scope.register(Instruction::new(BarrierOps::Arrive { barrier }, *token));
            token.into()
        })
    }

    /// Arrive at the barrier, decrementing the arrival count. Additionally increments the expected
    /// transaction count.
    #[allow(unused_variables)]
    pub fn arrive_and_expect_tx(&self, arrival_count: u32, transaction_count: u32) -> BarrierToken {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type() else {
                unreachable!()
            };
            let token = scope.create_barrier_token(barrier.index().unwrap(), level);
            let arrival_count: ExpandElement = arrival_count.into();
            let transaction_count: ExpandElement = transaction_count.into();
            scope.register(Instruction::new(
                BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update: arrival_count.consume(),
                    transaction_count_update: transaction_count.consume(),
                },
                *token,
            ));
            token.into()
        })
    }

    /// Increments the expected transaction count of the barrier.
    #[allow(unused_variables)]
    pub fn expect_tx(&self, expected_count: u32) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let transaction_count: ExpandElement = expected_count.into();
            scope.register(BarrierOps::ExpectTx {
                barrier,
                transaction_count_update: transaction_count.consume(),
            });
        })
    }

    /// Arrive at the barrier, then wait until all arrivals are done.
    pub fn arrive_and_wait(&self) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            scope.register(BarrierOps::ArriveAndWait { barrier });
        })
    }

    /// Wait at the barrier until all arrivals are done
    #[allow(unused_variables)]
    pub fn wait(&self, token: BarrierToken) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let token = *token.expand;
            scope.register(BarrierOps::Wait { barrier, token });
        })
    }

    /// Wait at the barrier until the `phase` is completed. Doesn't require a token, but needs phase
    /// to be managed manually.
    #[allow(unused_variables)]
    pub fn wait_parity(&self, phase: u32) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let phase = *phase.expand;
            scope.register(BarrierOps::WaitParity { barrier, phase });
        })
    }
}
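
// Illustrative sketch (not part of the API): the token-based arrive/wait flow, which lets a unit
// do independent work between arriving and waiting. All names are hypothetical.
//
//     let token = barrier.arrive();
//     // ...work that does not depend on the other units having arrived...
//     barrier.wait(token);
//
// For transaction-counted transfers (`memcpy_async_tx`, TMA loads), arrive with the expected
// transaction count instead:
//
//     barrier.memcpy_async_tx(&src, &mut dst);
//     let token = barrier.arrive_and_expect_tx(1, expected);
//     barrier.wait(token);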

// Copy async

/// Copy the source slice in global memory to the destination in shared memory with a low-level
/// async copy. This copies at most 128 bits (16 bytes) and does not synchronize. Use
/// `barrier.commit_copy_async` to make the copies visible on the barrier.
/// `copy_size` is in terms of elements, to simplify copying between different line sizes.
///
/// # Safety
///
/// This will try to copy the entire `copy_size`, so make sure the full width is in bounds.
/// The starting address must be aligned to the full copy size.
pub fn copy_async<C: CubePrimitive>(
    _source: &Slice<Line<C>>,
    _destination: &mut SliceMut<Line<C>>,
    _copy_size: u32,
) {
    unexpanded!()
}

pub mod copy_async {
    use super::*;

    pub fn expand<C: CubePrimitive>(
        scope: &mut Scope,
        source: SliceExpand<Line<C>, ReadOnly>,
        destination: SliceExpand<Line<C>, ReadWrite>,
        copy_length: u32,
    ) {
        let source_length = copy_length.into();
        let (source, source_offset) = source.__to_raw_parts();
        let (destination, destination_offset) = destination.__to_raw_parts();

        let mem_copy = BarrierOps::CopyAsync {
            source,
            source_length,
            offset_source: source_offset,
            offset_out: destination_offset,
            copy_length: copy_length * C::as_type(scope).size() as u32,
            checked: false,
        };

        scope.register(Instruction::new(mem_copy, destination));
    }
}

/// Copy the source slice in global memory to the destination in shared memory with a low-level
/// async copy. This copies at most 128 bits (16 bytes) and does not synchronize. Use
/// `barrier.commit_copy_async` to make the copies visible on the barrier.
/// `copy_size` is in terms of elements, to simplify copying between different line sizes.
///
/// Will only copy the length of the source slice, and zero-fill the rest. The source length must
/// be <= the copy size.
///
/// # Safety
///
/// The starting address must be aligned to the full copy size.
/// **This will silently fail if the address is only aligned to the source length and not the copy size!**
pub fn copy_async_checked<C: CubePrimitive>(
    _source: &Slice<Line<C>>,
    _destination: &mut SliceMut<Line<C>>,
    _copy_size: u32,
) {
    unexpanded!();
}

pub mod copy_async_checked {
    use super::*;

    pub fn expand<C: CubePrimitive>(
        scope: &mut Scope,
        source: SliceExpand<Line<C>, ReadOnly>,
        destination: SliceExpand<Line<C>, ReadWrite>,
        copy_length: u32,
    ) {
        let source_length = *source.length.expand;
        let (source, source_offset) = source.__to_raw_parts();
        let (destination, destination_offset) = destination.__to_raw_parts();

        let mem_copy = BarrierOps::CopyAsync {
            source,
            source_length,
            offset_source: source_offset,
            offset_out: destination_offset,
            copy_length: copy_length * C::as_type(scope).size() as u32,
            checked: true,
        };

        scope.register(Instruction::new(mem_copy, destination));
    }
}

#[cube(self_type = "ref")]
impl Barrier {
    /// Makes all previous `copy_async` operations visible on the barrier.
    /// Should be called once after all copies have been dispatched, before reading from the shared
    /// memory.
    ///
    /// Does *not* count as an arrive in terms of the barrier arrival count. So `arrive` or
    /// `arrive_and_wait` should still be called afterwards.
    pub fn commit_copy_async(&self) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type() else {
                unreachable!()
            };
            let token = scope.create_barrier_token(barrier.index().unwrap(), level);
            scope.register(Instruction::new(
                BarrierOps::CommitCopyAsync { barrier },
                *token,
            ));
        })
    }
}
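
// Illustrative sketch (not part of the API): combining `copy_async` with `commit_copy_async`.
// `src`, `dst` and `COPY_SIZE` are hypothetical; per the docs above, the commit does not count as
// an arrival, so the barrier still has to be arrived on explicitly.
//
//     copy_async(&src, &mut dst, COPY_SIZE);
//     // Make the dispatched copies visible on the barrier...
//     barrier.commit_copy_async();
//     // ...then arrive and wait as usual.
//     let token = barrier.arrive();
//     barrier.wait(token);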

impl Deref for Shared<Barrier> {
    type Target = Barrier;

    fn deref(&self) -> &Self::Target {
        unexpanded!()
    }
}
impl Deref for SharedExpand<Barrier> {
    type Target = BarrierExpand;

    fn deref(&self) -> &Self::Target {
        unsafe { self.as_type_ref_unchecked::<Barrier>() }
    }
}

impl DerefMut for Shared<Barrier> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        todo!()
    }
}
impl DerefMut for SharedExpand<Barrier> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        unsafe { self.as_type_mut_unchecked::<Barrier>() }
    }
}

impl From<SharedExpand<Barrier>> for BarrierExpand {
    fn from(value: SharedExpand<Barrier>) -> Self {
        value.expand.into()
    }
}