// cubecl_core/frontend/barrier.rs

//! This module exposes barriers for asynchronous data transfer.
2
3use alloc::vec;
4use core::ops::{Deref, DerefMut};
5
6use crate as cubecl;
7use cubecl_ir::{ExpandElement, Instruction, OpaqueType};
8use cubecl_macros::intrinsic;
9use paste::paste;
10
11use crate::{
12    ir::{BarrierOps, Scope},
13    prelude::*,
14    unexpanded,
15};
16
17use super::{
18    CubePrimitive, CubeType, ExpandElementTyped, Line, ReadOnly, ReadWrite, Slice, SliceExpand,
19    SliceMut, TensorMap,
20};
21
/// A mechanism for awaiting on asynchronous data transfers.
/// Behavior is defined by its ``BarrierLevel``.
///
/// Opaque frontend handle: all real work is done by the `__expand_*` methods,
/// which emit `BarrierOps` IR instructions into a [`Scope`].
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct Barrier;
/// Expansion-time (IR-building) counterpart of [`Barrier`].
pub type BarrierExpand = ExpandElementTyped<Barrier>;
27
/// Token produced by `arrive`-style operations; pass it to [`Barrier::wait`]
/// to block until the corresponding barrier phase completes.
#[derive(Clone, Copy, PartialEq)]
pub struct BarrierToken;
30
impl CubeType for Barrier {
    // During expansion a `Barrier` is represented by a typed IR element.
    type ExpandType = ExpandElementTyped<Barrier>;
}
34
impl CubePrimitive for Barrier {
    /// Barriers are opaque runtime objects with no constant representation,
    /// so constructing one from a constant is a compiler bug.
    fn from_const_value(_value: cubecl_ir::ConstantValue) -> Self {
        unreachable!("Can't create from const value")
    }
}
40
impl ExpandElementIntoMut for Barrier {
    // A barrier element is passed through unchanged: no copy into a mutable
    // local is needed for the opaque handle.
    fn elem_into_mut(_scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
        elem
    }
}
46
impl CubeType for BarrierToken {
    // Tokens are likewise represented by a typed IR element at expansion time.
    type ExpandType = ExpandElementTyped<BarrierToken>;
}
50
51impl ExpandElementIntoMut for BarrierToken {
52    fn elem_into_mut(_scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
53        elem
54    }
55}
56
/// Generates `tma_load_{N}d` frontend stubs plus their expand implementations.
///
/// `$dim` is the tensor rank and each `$arg` names one `i32` coordinate
/// (passed outermost dimension first — see the invocations below).
macro_rules! tensor_map_load {
    ($dim: literal, $($arg: expr),*) => {
        paste! {
            impl Barrier {
                /// Copy a tile from a global memory `source` to a shared memory `destination`, with
                /// the provided offsets.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_ $dim d>]<C: CubePrimitive>(
                    &self,
                    source: &TensorMap<C, Tiled>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32),*
                ) {
                    // Frontend stub: only meaningful inside `#[cube]` code, where the
                    // macro rewrites calls into the `__expand_*` form below.
                    unexpanded!()
                }

                /// Expand counterpart of the stub above; forwards to the expand type's method.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d>]<C: CubePrimitive>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: ExpandElementTyped<TensorMap<C, Tiled>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    expand.[<__expand_tma_load_ $dim d_method>](scope, source, destination, $($arg),*);
                }
            }

            impl BarrierExpand {
                /// Registers a `BarrierOps::TmaLoad` instruction that copies one tile
                /// into `destination`.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d_method>]<C: CubePrimitive>(
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C, Tiled>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    let barrier = *self.expand;
                    let source = *source.expand;
                    // Split the destination slice into its base variable and start offset.
                    let (destination, destination_offset) = destination.__to_raw_parts();

                    let mem_copy = BarrierOps::TmaLoad {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offset_out: destination_offset
                    };

                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}
111
/// Generates `tma_load_im2col_{N}d` frontend stubs plus their expand implementations.
///
/// `$dim` is the tensor rank, each `$arg` names one `i32` coordinate and each
/// `$offset` names one `u16` kernel offset for a spatial dimension.
macro_rules! tensor_map_load_im2col {
    ($dim: literal, $($arg: expr),*; $($offset: expr),*) => {
        paste! {
            impl Barrier {
                /// Copy a tile from a global memory `source` to a shared memory `destination`, with
                /// the provided offsets.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_im2col_ $dim d>]<C: CubePrimitive>(
                    &self,
                    source: &TensorMap<C, Im2col>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32,)*
                    $($offset: u16),*
                ) {
                    // Frontend stub: only meaningful inside `#[cube]` code.
                    unexpanded!()
                }

                /// Expand counterpart of the stub above; forwards to the expand type's method.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d>]<C: CubePrimitive>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: ExpandElementTyped<TensorMap<C, Im2col>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    expand.[<__expand_tma_load_im2col_ $dim d_method>](scope, source, destination, $($arg),*, $($offset),*);
                }
            }

            impl BarrierExpand {
                /// Registers a `BarrierOps::TmaLoadIm2col` instruction that copies one
                /// im2col tile into `destination`.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d_method>]<C: CubePrimitive>(
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C, Im2col>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    let barrier = *self.expand;
                    let source = *source.expand;
                    // Split the destination slice into its base variable and start offset.
                    let (destination, destination_offset) = destination.__to_raw_parts();

                    let mem_copy = BarrierOps::TmaLoadIm2col {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offsets: vec![$(*$offset.expand),*],
                        offset_out: destination_offset,
                    };

                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}
170
// Generate TMA load stubs for ranks 1-5. Coordinate arguments are ordered
// outermost dimension first (e.g. `z, y, x` for 3D).
tensor_map_load!(1, x);
tensor_map_load!(2, y, x);
tensor_map_load!(3, z, y, x);
tensor_map_load!(4, w, z, y, x);
tensor_map_load!(5, v, w, z, y, x);

// Generate im2col TMA load stubs for ranks 3-5; the `u16` kernel offsets
// (after the `;`) cover only the spatial dimensions.
tensor_map_load_im2col!(3, n, w, c; w_offset);
tensor_map_load_im2col!(4, n, h, w, c; h_offset, w_offset);
tensor_map_load_im2col!(5, n, d, h, w, c; d_offset, h_offset, w_offset);
180
#[cube(self_type = "ref")]
impl Barrier {
    /// Create a local barrier object for the current unit. Automatically initialized with an
    /// arrival count of `1`.
    pub fn local() -> Self {
        intrinsic!(|scope| {
            // Unit-level barrier: lives in a mutable local variable.
            let variable =
                scope.create_local_mut(OpaqueType::Barrier(cubecl_ir::BarrierLevel::Unit));
            scope.register(BarrierOps::Init {
                barrier: *variable,
                // A single unit owns the barrier, so it is unconditionally
                // "elected" to perform the initialization.
                is_elected: true.into(),
                arrival_count: 1.into(),
            });
            variable.into()
        })
    }

    /// Create a shared memory barrier that can be accessed by all units in the cube. Initialized
    /// by the `is_elected` unit with an arrival count of `arrival_count`. This is the number of
    /// times `arrive` or one of its variants needs to be called before the barrier advances.
    ///
    /// If all units in the cube arrive on the barrier, use `CUBE_DIM` as the arrival count. For
    /// other purposes, only a subset may need to arrive.
    #[allow(unused_variables)]
    pub fn shared(arrival_count: u32, is_elected: bool) -> Shared<Barrier> {
        intrinsic!(|scope| {
            let variable = scope.create_shared(OpaqueType::Barrier(cubecl_ir::BarrierLevel::Cube));
            scope.register(BarrierOps::Init {
                barrier: *variable,
                is_elected: *is_elected.expand,
                arrival_count: *arrival_count.expand,
            });
            variable.into()
        })
    }

    /// Create a shared memory barrier that can be accessed by all units in the cube. Only declared,
    /// but not initialized. Pair with `init_manual` before use.
    pub fn shared_uninit() -> Shared<Barrier> {
        intrinsic!(|scope| {
            // Declare storage only; no `Init` op is emitted here.
            let variable = scope.create_shared(OpaqueType::Barrier(cubecl_ir::BarrierLevel::Cube));
            scope.register(BarrierOps::Declare { barrier: *variable });
            variable.into()
        })
    }

    /// Initializes a barrier with a given `arrival_count`. This is the number of
    /// times `arrive` or one of its variants needs to be called before the barrier advances.
    ///
    /// If all units in the cube arrive on the barrier, use `CUBE_DIM` as the arrival count. For
    /// other purposes, only a subset may need to arrive.
    ///
    /// # Note
    ///
    /// No synchronization or election is performed, this is raw initialization. For shared barriers
    /// ensure only one unit performs the initialization, and synchronize the cube afterwards. There
    /// may also be additional synchronization requirements for bulk copy operations, like
    /// [`sync_async_proxy_shared()`].
    #[allow(unused_variables)]
    pub fn init_manual(&self, arrival_count: u32) {
        intrinsic!(|scope| {
            // Clone first: `self` is borrowed here, unlike in the constructors above.
            let barrier = *self.expand.clone();

            scope.register(BarrierOps::InitManual {
                barrier,
                arrival_count: *arrival_count.expand,
            });
        })
    }
}
251
// MemcpyAsync

#[cube(self_type = "ref")]
impl Barrier {
    /// Copy the source slice to destination
    ///
    /// # Safety
    ///
    /// This will try to copy the whole source slice, so
    /// make sure source length <= destination length
    #[allow(unused_variables)]
    pub fn memcpy_async<C: CubePrimitive>(
        &self,
        source: &Slice<Line<C>>,
        destination: &mut SliceMut<Line<C>>,
    ) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            // Read the runtime length before `__to_raw_parts` consumes `source`.
            let source_length = *source.length.expand;
            let (source, source_offset) = source.__to_raw_parts();
            let (destination, destination_offset) = destination.__to_raw_parts();

            let mem_copy = BarrierOps::MemCopyAsync {
                barrier,
                source,
                source_length,
                offset_source: source_offset,
                offset_out: destination_offset,
            };

            scope.register(Instruction::new(mem_copy, destination));
        })
    }

    /// Copy the source slice to destination
    ///
    /// # Safety
    ///
    /// This will try to copy the whole source slice, so
    /// make sure source length <= destination length
    #[allow(unused_variables)]
    pub fn memcpy_async_cooperative<C: CubePrimitive>(
        &self,
        source: &Slice<Line<C>>,
        destination: &mut SliceMut<Line<C>>,
    ) {
        intrinsic!(|scope| {
            // Same expansion as `memcpy_async`, but emits the cooperative IR op.
            let barrier = *self.expand;
            let source_length = *source.length.expand;
            let (source, source_offset) = source.__to_raw_parts();
            let (destination, destination_offset) = destination.__to_raw_parts();

            let mem_copy = BarrierOps::MemCopyAsyncCooperative {
                barrier,
                source,
                source_length,
                offset_source: source_offset,
                offset_out: destination_offset,
            };

            scope.register(Instruction::new(mem_copy, destination));
        })
    }

    /// Copy the source slice to destination. Uses transaction count like TMA, so use with
    /// `expect_tx` or `arrive_and_expect_tx`.
    ///
    /// # Safety
    ///
    /// This will try to copy the whole source slice, so
    /// make sure source length <= destination length
    #[allow(unused_variables)]
    pub fn memcpy_async_tx<C: CubePrimitive>(
        &self,
        source: &Slice<Line<C>>,
        destination: &mut SliceMut<Line<C>>,
    ) {
        intrinsic!(|scope| {
            // Same expansion as `memcpy_async`, but emits the transaction-counted op.
            let barrier = *self.expand;
            let source_length = *source.length.expand;
            let (source, source_offset) = source.__to_raw_parts();
            let (destination, destination_offset) = destination.__to_raw_parts();

            let mem_copy = BarrierOps::MemCopyAsyncTx {
                barrier,
                source,
                source_length,
                offset_source: source_offset,
                offset_out: destination_offset,
            };

            scope.register(Instruction::new(mem_copy, destination));
        })
    }
}
347
// Arrival and Wait

#[cube(self_type = "ref")]
impl Barrier {
    /// Arrive at the barrier, decrementing arrival count
    pub fn arrive(&self) -> BarrierToken {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            // Recover the barrier level from the variable's opaque storage type;
            // anything else means the frontend produced an invalid barrier variable.
            let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type() else {
                unreachable!()
            };
            // The token records which barrier (by index) and level it belongs to.
            let token = scope.create_barrier_token(barrier.index().unwrap(), level);
            scope.register(Instruction::new(BarrierOps::Arrive { barrier }, *token));
            token.into()
        })
    }

    /// Arrive at the barrier, decrementing arrival count. Additionally increments expected count.
    #[allow(unused_variables)]
    pub fn arrive_and_expect_tx(&self, arrival_count: u32, transaction_count: u32) -> BarrierToken {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type() else {
                unreachable!()
            };
            let token = scope.create_barrier_token(barrier.index().unwrap(), level);
            // `consume` transfers ownership of the expand elements into the op.
            let arrival_count: ExpandElement = arrival_count.into();
            let transaction_count: ExpandElement = transaction_count.into();
            scope.register(Instruction::new(
                BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update: arrival_count.consume(),
                    transaction_count_update: transaction_count.consume(),
                },
                *token,
            ));
            token.into()
        })
    }

    /// Increments the expected count of the barrier.
    #[allow(unused_variables)]
    pub fn expect_tx(&self, expected_count: u32) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let transaction_count: ExpandElement = expected_count.into();
            scope.register(BarrierOps::ExpectTx {
                barrier,
                transaction_count_update: transaction_count.consume(),
            });
        })
    }

    /// Arrive at the barrier and wait until all arrivals are done.
    pub fn arrive_and_wait(&self) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            scope.register(BarrierOps::ArriveAndWait { barrier });
        })
    }

    /// Wait at the barrier until all arrivals are done
    #[allow(unused_variables)]
    pub fn wait(&self, token: BarrierToken) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let token = *token.expand;
            scope.register(BarrierOps::Wait { barrier, token });
        })
    }

    /// Wait at the barrier until the `phase` is completed. Doesn't require a token, but needs phase
    /// to be managed manually.
    #[allow(unused_variables)]
    pub fn wait_parity(&self, phase: u32) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let phase = *phase.expand;
            scope.register(BarrierOps::WaitParity { barrier, phase });
        })
    }
}
430
// Copy async

/// Copy the source slice in global memory to destination in shared memory with a low level async
/// copy. This only copies up to 128 bits/16 bytes, and does not synchronize. Use
/// `barrier.copy_async_arrive` to make the reads visible.
/// `copy_size` is in terms of elements to simplify copying between different line sizes.
///
/// # Safety
///
/// This will try to copy the entire `copy_size`, so make sure the full width is in bounds.
/// Starting address must be aligned to the full copy size.
pub fn copy_async<C: CubePrimitive>(
    _source: &Slice<Line<C>>,
    _destination: &mut SliceMut<Line<C>>,
    _copy_size: u32,
) {
    // Frontend stub; the expansion lives in the `copy_async` module below.
    unexpanded!()
}
449
450pub mod copy_async {
451    use super::*;
452
453    pub fn expand<C: CubePrimitive>(
454        scope: &mut Scope,
455        source: SliceExpand<Line<C>, ReadOnly>,
456        destination: SliceExpand<Line<C>, ReadWrite>,
457        copy_length: u32,
458    ) {
459        let source_length = copy_length.into();
460        let (source, source_offset) = source.__to_raw_parts();
461        let (destination, destination_offset) = destination.__to_raw_parts();
462
463        let mem_copy = BarrierOps::CopyAsync {
464            source,
465            source_length,
466            offset_source: source_offset,
467            offset_out: destination_offset,
468            copy_length: copy_length * C::as_type(scope).size() as u32,
469            checked: false,
470        };
471
472        scope.register(Instruction::new(mem_copy, destination));
473    }
474}
475
/// Copy the source slice in global memory to destination in shared memory with a low level async
/// copy. This only copies up to 128 bits/16 bytes, and does not synchronize. Use
/// `barrier.copy_async_arrive` to make the reads visible.
/// `copy_size` is in terms of elements to simplify copying between different line sizes.
///
/// Will only copy the length of the source slice, and zero fill the rest. Source length must be
/// <= copy size.
///
/// # Safety
/// Starting address must be aligned to the full copy size.
/// **This will silently fail if the address is only aligned to the source length and not the copy size!**
pub fn copy_async_checked<C: CubePrimitive>(
    _source: &Slice<Line<C>>,
    _destination: &mut SliceMut<Line<C>>,
    _copy_size: u32,
) {
    // Frontend stub; the expansion lives in the `copy_async_checked` module below.
    unexpanded!();
}
494
495pub mod copy_async_checked {
496    use super::*;
497
498    pub fn expand<C: CubePrimitive>(
499        scope: &mut Scope,
500        source: SliceExpand<Line<C>, ReadOnly>,
501        destination: SliceExpand<Line<C>, ReadWrite>,
502        copy_length: u32,
503    ) {
504        let source_length = *source.length.expand;
505        let (source, source_offset) = source.__to_raw_parts();
506        let (destination, destination_offset) = destination.__to_raw_parts();
507
508        let mem_copy = BarrierOps::CopyAsync {
509            source,
510            source_length,
511            offset_source: source_offset,
512            offset_out: destination_offset,
513            copy_length: copy_length * C::as_type(scope).size() as u32,
514            checked: true,
515        };
516
517        scope.register(Instruction::new(mem_copy, destination));
518    }
519}
520
#[cube(self_type = "ref")]
impl Barrier {
    /// Makes all previous `copy_async` operations visible on the barrier.
    /// Should be called once after all copies have been dispatched, before reading from the shared
    /// memory.
    ///
    /// Does *not* count as an arrive in terms of the barrier arrival count. So `arrive` or
    /// `arrive_and_wait` should still be called afterwards.
    pub fn commit_copy_async(&self) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            // Recover the barrier level from the variable's opaque storage type.
            let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type() else {
                unreachable!()
            };
            // NOTE(review): a token is created as the instruction output but is not
            // returned to the caller — presumably only consumed by the backend.
            let token = scope.create_barrier_token(barrier.index().unwrap(), level);
            scope.register(Instruction::new(
                BarrierOps::CommitCopyAsync { barrier },
                *token,
            ));
        })
    }
}
543
impl Deref for Shared<Barrier> {
    type Target = Barrier;

    /// GPU-frontend stub: never executed on the host, only rewritten by the
    /// `#[cube]` macro.
    fn deref(&self) -> &Self::Target {
        unexpanded!()
    }
}
impl Deref for SharedExpand<Barrier> {
    type Target = BarrierExpand;

    fn deref(&self) -> &Self::Target {
        // SAFETY: reinterprets the shared expand element as a `BarrierExpand`;
        // presumably `SharedExpand<Barrier>` and `ExpandElementTyped<Barrier>`
        // share the same representation — TODO(review) confirm the invariant
        // guaranteed by `as_type_ref_unchecked`.
        unsafe { self.as_type_ref_unchecked::<Barrier>() }
    }
}
558
559impl DerefMut for Shared<Barrier> {
560    fn deref_mut(&mut self) -> &mut Self::Target {
561        todo!()
562    }
563}
impl DerefMut for SharedExpand<Barrier> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: mutable counterpart of the `Deref` cast; presumably relies on
        // the same representation invariant — TODO(review) confirm against
        // `as_type_mut_unchecked`'s contract.
        unsafe { self.as_type_mut_unchecked::<Barrier>() }
    }
}
569
impl From<SharedExpand<Barrier>> for BarrierExpand {
    /// Converts a shared-memory expand handle into a plain `BarrierExpand`
    /// by re-wrapping its inner expand element.
    fn from(value: SharedExpand<Barrier>) -> Self {
        value.expand.into()
    }
}