1use alloc::vec;
4use core::ops::{Deref, DerefMut};
5
6use crate as cubecl;
7use cubecl_ir::{Instruction, ManagedVariable, OpaqueType};
8use cubecl_macros::intrinsic;
9use paste::paste;
10
11use crate::{
12 ir::{BarrierOps, Scope},
13 prelude::*,
14 unexpanded,
15};
16
17use super::{
18 CubePrimitive, CubeType, NativeExpand, ReadOnly, ReadWrite, Slice, SliceExpand, SliceMut,
19 TensorMap,
20};
21
/// Opaque handle to a device-side barrier used to synchronize asynchronous
/// operations (async memcopies, TMA loads). Host-side this is a marker type:
/// the real state lives in the generated IR/kernel.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct Barrier;
/// Compile-time (expand) counterpart of [`Barrier`].
pub type BarrierExpand = NativeExpand<Barrier>;

/// Token produced by an arrive operation on a [`Barrier`]; passed back to
/// [`Barrier::wait`] to block until the corresponding phase completes.
#[derive(Clone, Copy, PartialEq)]
pub struct BarrierToken;
30
impl CubeType for Barrier {
    // Same wrapper as the `BarrierExpand` alias above.
    type ExpandType = NativeExpand<Barrier>;
}
34
35impl CubePrimitive for Barrier {
36 type Scalar = u32; type Size = Const<1>;
38 type WithScalar<S: Scalar> = S;
39 fn from_const_value(_value: cubecl_ir::ConstantValue) -> Self {
40 unreachable!("Can't create from const value")
41 }
42}
43
44impl NativeAssign for Barrier {
45 fn elem_init_mut(_scope: &mut Scope, elem: ManagedVariable) -> ManagedVariable {
46 elem
47 }
48}
49
impl CubeType for BarrierToken {
    // Tokens expand through the same generic native-expand wrapper.
    type ExpandType = NativeExpand<BarrierToken>;
}
53
54impl NativeAssign for BarrierToken {
55 fn elem_init_mut(_scope: &mut crate::ir::Scope, elem: ManagedVariable) -> ManagedVariable {
56 elem
57 }
58}
59
/// Generates the `tma_load_{dim}d` method family on [`Barrier`] for tiled TMA
/// loads: a host-side stub plus the expand functions that register a
/// [`BarrierOps::TmaLoad`] instruction in the scope.
macro_rules! tensor_map_load {
    ($dim: literal, $($arg: expr),*) => {
        paste! {
            impl Barrier {
                /// Asynchronously loads a tile of `source` at the given indices into
                /// `destination`, tracked by this barrier. Host-side stub: only
                /// meaningful inside a kernel.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_ $dim d>]<C1: CubePrimitive, C2: CubePrimitive<Scalar = C1::Scalar>>(
                    &self,
                    source: &TensorMap<C1, Tiled>,
                    destination: &mut SliceMut<C2>,
                    $($arg: i32),*
                ) {
                    unexpanded!()
                }

                /// Free-function expand entry point; forwards to the method-style
                /// expansion on `BarrierExpand`.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d>]<C1: CubePrimitive, C2: CubePrimitive<Scalar = C1::Scalar>>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: NativeExpand<TensorMap<C1, Tiled>>,
                    destination: SliceExpand<C2, ReadWrite>,
                    $($arg: NativeExpand<i32>),*
                ) {
                    expand.[<__expand_tma_load_ $dim d_method>](scope, source, destination, $($arg),*);
                }
            }

            impl BarrierExpand {
                /// Registers the `TmaLoad` IR instruction, using the destination
                /// slice's underlying variable as the instruction output.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d_method>]<C1: CubePrimitive, C2: CubePrimitive<Scalar = C1::Scalar>>(
                    &self,
                    scope: &mut Scope,
                    source: NativeExpand<TensorMap<C1, Tiled>>,
                    destination: SliceExpand<C2, ReadWrite>,
                    $($arg: NativeExpand<i32>),*
                ) {
                    let barrier = *self.expand;
                    let source = *source.expand;
                    // Decompose the destination into its base variable plus a
                    // linear offset into it.
                    let (destination, destination_offset) = destination.__to_raw_parts();

                    let mem_copy = BarrierOps::TmaLoad {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offset_out: destination_offset
                    };

                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}
114
/// Generates the `tma_load_im2col_{dim}d` method family on [`Barrier`]:
/// like [`tensor_map_load!`] but for im2col-mode tensor maps, taking extra
/// per-spatial-dimension `u16` offsets registered as
/// [`BarrierOps::TmaLoadIm2col`].
macro_rules! tensor_map_load_im2col {
    ($dim: literal, $($arg: expr),*; $($offset: expr),*) => {
        paste! {
            impl Barrier {
                /// Asynchronously loads an im2col tile of `source` at the given
                /// indices/offsets into `destination`, tracked by this barrier.
                /// Host-side stub: only meaningful inside a kernel.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_im2col_ $dim d>]<C1: CubePrimitive, C2: CubePrimitive<Scalar = C1::Scalar>>(
                    &self,
                    source: &TensorMap<C1, Im2col>,
                    destination: &mut SliceMut<C2>,
                    $($arg: i32,)*
                    $($offset: u16),*
                ) {
                    unexpanded!()
                }

                /// Free-function expand entry point; forwards to the method-style
                /// expansion on `BarrierExpand`.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d>]<C1: CubePrimitive, C2: CubePrimitive<Scalar = C1::Scalar>>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: NativeExpand<TensorMap<C1, Im2col>>,
                    destination: SliceExpand<C2, ReadWrite>,
                    $($arg: NativeExpand<i32>,)*
                    $($offset: NativeExpand<u16>),*
                ) {
                    expand.[<__expand_tma_load_im2col_ $dim d_method>](scope, source, destination, $($arg),*, $($offset),*);
                }
            }

            impl BarrierExpand {
                /// Registers the `TmaLoadIm2col` IR instruction, using the
                /// destination slice's underlying variable as the output.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d_method>]<C1: CubePrimitive, C2: CubePrimitive<Scalar = C1::Scalar>>(
                    &self,
                    scope: &mut Scope,
                    source: NativeExpand<TensorMap<C1, Im2col>>,
                    destination: SliceExpand<C2, ReadWrite>,
                    $($arg: NativeExpand<i32>,)*
                    $($offset: NativeExpand<u16>),*
                ) {
                    let barrier = *self.expand;
                    let source = *source.expand;
                    // Decompose the destination into its base variable plus a
                    // linear offset into it.
                    let (destination, destination_offset) = destination.__to_raw_parts();

                    let mem_copy = BarrierOps::TmaLoadIm2col {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offsets: vec![$(*$offset.expand),*],
                        offset_out: destination_offset,
                    };

                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}
173
// Tiled TMA loads for 1D..5D tensors. Index arguments are listed from the
// outermost (slowest-moving) dimension down to the innermost `x`.
tensor_map_load!(1, x);
tensor_map_load!(2, y, x);
tensor_map_load!(3, z, y, x);
tensor_map_load!(4, w, z, y, x);
tensor_map_load!(5, v, w, z, y, x);

// Im2col TMA loads for 3D/4D/5D tensors. The argument names suggest
// channel-last layouts (e.g. NHWC for 4D) with one kernel offset per spatial
// dimension — NOTE(review): confirm against the `TensorMap<_, Im2col>` docs.
tensor_map_load_im2col!(3, n, w, c; w_offset);
tensor_map_load_im2col!(4, n, h, w, c; h_offset, w_offset);
tensor_map_load_im2col!(5, n, d, h, w, c; d_offset, h_offset, w_offset);
183
#[cube(self_type = "ref")]
impl Barrier {
    /// Creates a unit-level barrier, immediately initialized with an arrival
    /// count of 1.
    pub fn local() -> Self {
        intrinsic!(|scope| {
            let variable =
                scope.create_local_mut(OpaqueType::Barrier(cubecl_ir::BarrierLevel::Unit));
            scope.register(BarrierOps::Init {
                barrier: *variable,
                // A single unit owns this barrier, so it always does the init.
                is_elected: true.into(),
                arrival_count: 1.into(),
            });
            variable.into()
        })
    }

    /// Creates a cube-level barrier in shared memory and registers its
    /// initialization.
    ///
    /// * `arrival_count` - number of arrivals expected per phase.
    /// * `is_elected` - whether this unit performs the initialization.
    ///   NOTE(review): presumably exactly one unit should pass `true`; that
    ///   invariant is the caller's responsibility, nothing enforces it here.
    #[allow(unused_variables)]
    pub fn shared(arrival_count: u32, is_elected: bool) -> Shared<Barrier> {
        intrinsic!(|scope| {
            let variable = scope.create_shared(OpaqueType::Barrier(cubecl_ir::BarrierLevel::Cube));
            scope.register(BarrierOps::Init {
                barrier: *variable,
                is_elected: *is_elected.expand,
                arrival_count: *arrival_count.expand,
            });
            variable.into()
        })
    }

    /// Declares a cube-level barrier in shared memory WITHOUT initializing it.
    /// Pair with [`Barrier::init_manual`] before first use.
    pub fn shared_uninit() -> Shared<Barrier> {
        intrinsic!(|scope| {
            let variable = scope.create_shared(OpaqueType::Barrier(cubecl_ir::BarrierLevel::Cube));
            scope.register(BarrierOps::Declare { barrier: *variable });
            variable.into()
        })
    }

    /// Manually initializes a barrier (e.g. one created with
    /// [`Barrier::shared_uninit`]) with the given expected arrival count.
    #[allow(unused_variables)]
    pub fn init_manual(&self, arrival_count: u32) {
        intrinsic!(|scope| {
            let barrier = *self.expand.clone();

            scope.register(BarrierOps::InitManual {
                barrier,
                arrival_count: *arrival_count.expand,
            });
        })
    }
}
254
#[cube(self_type = "ref")]
impl Barrier {
    /// Asynchronously copies `source` into `destination`, with completion
    /// tracked by this barrier.
    #[allow(unused_variables)]
    pub fn memcpy_async<C: CubePrimitive>(&self, source: &Slice<C>, destination: &mut SliceMut<C>) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            // Capture the runtime length before the slice is decomposed into
            // its raw (variable, offset) parts.
            let source_length = *source.length.expand;
            let (source, source_offset) = source.__to_raw_parts();
            let (destination, destination_offset) = destination.__to_raw_parts();

            let mem_copy = BarrierOps::MemCopyAsync {
                barrier,
                source,
                source_length,
                offset_source: source_offset,
                offset_out: destination_offset,
            };

            scope.register(Instruction::new(mem_copy, destination));
        })
    }

    /// Same shape as [`Self::memcpy_async`] but lowered to
    /// `MemCopyAsyncCooperative`. NOTE(review): presumably the copy is split
    /// across participating units — confirm against the backend lowering.
    #[allow(unused_variables)]
    pub fn memcpy_async_cooperative<C: CubePrimitive>(
        &self,
        source: &Slice<C>,
        destination: &mut SliceMut<C>,
    ) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let source_length = *source.length.expand;
            let (source, source_offset) = source.__to_raw_parts();
            let (destination, destination_offset) = destination.__to_raw_parts();

            let mem_copy = BarrierOps::MemCopyAsyncCooperative {
                barrier,
                source,
                source_length,
                offset_source: source_offset,
                offset_out: destination_offset,
            };

            scope.register(Instruction::new(mem_copy, destination));
        })
    }

    /// Same shape as [`Self::memcpy_async`] but lowered to `MemCopyAsyncTx`.
    /// NOTE(review): presumably the transaction-count variant for use with
    /// `expect_tx`/`arrive_and_expect_tx` — confirm against the backend.
    #[allow(unused_variables)]
    pub fn memcpy_async_tx<C: CubePrimitive>(
        &self,
        source: &Slice<C>,
        destination: &mut SliceMut<C>,
    ) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let source_length = *source.length.expand;
            let (source, source_offset) = source.__to_raw_parts();
            let (destination, destination_offset) = destination.__to_raw_parts();

            let mem_copy = BarrierOps::MemCopyAsyncTx {
                barrier,
                source,
                source_length,
                offset_source: source_offset,
                offset_out: destination_offset,
            };

            scope.register(Instruction::new(mem_copy, destination));
        })
    }
}
346
#[cube(self_type = "ref")]
impl Barrier {
    /// Arrives at the barrier, returning a token to later pass to
    /// [`Self::wait`] for completion of the current phase.
    pub fn arrive(&self) -> BarrierToken {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            // Recover the barrier level (Unit/Cube) recorded in the variable's
            // storage type at creation.
            let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type() else {
                unreachable!()
            };
            let token = scope.create_barrier_token(barrier.index().unwrap(), level);
            scope.register(Instruction::new(BarrierOps::Arrive { barrier }, *token));
            token.into()
        })
    }

    /// Arrives at the barrier while updating the expected arrival and
    /// transaction counts, returning a wait token.
    #[allow(unused_variables)]
    pub fn arrive_and_expect_tx(&self, arrival_count: u32, transaction_count: u32) -> BarrierToken {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type() else {
                unreachable!()
            };
            let token = scope.create_barrier_token(barrier.index().unwrap(), level);
            // NOTE(review): `consume()` presumably releases the managed temp
            // into the IR op — confirm against `ManagedVariable`'s contract.
            let arrival_count: ManagedVariable = arrival_count.into();
            let transaction_count: ManagedVariable = transaction_count.into();
            scope.register(Instruction::new(
                BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update: arrival_count.consume(),
                    transaction_count_update: transaction_count.consume(),
                },
                *token,
            ));
            token.into()
        })
    }

    /// Updates the barrier's expected transaction count without arriving.
    #[allow(unused_variables)]
    pub fn expect_tx(&self, expected_count: u32) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let transaction_count: ManagedVariable = expected_count.into();
            scope.register(BarrierOps::ExpectTx {
                barrier,
                transaction_count_update: transaction_count.consume(),
            });
        })
    }

    /// Arrives at the barrier and blocks until the phase completes, in one op.
    pub fn arrive_and_wait(&self) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            scope.register(BarrierOps::ArriveAndWait { barrier });
        })
    }

    /// Blocks until the phase identified by `token` (from a prior arrive)
    /// completes.
    #[allow(unused_variables)]
    pub fn wait(&self, token: BarrierToken) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let token = *token.expand;
            scope.register(BarrierOps::Wait { barrier, token });
        })
    }

    /// Blocks until the barrier's phase parity matches `phase`.
    #[allow(unused_variables)]
    pub fn wait_parity(&self, phase: u32) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let phase = *phase.expand;
            scope.register(BarrierOps::WaitParity { barrier, phase });
        })
    }
}
429
/// Asynchronously copies `_copy_size` elements from `_source` to
/// `_destination` without a source-length check (the expand registers the op
/// with `checked: false`).
///
/// Host-side stub: only meaningful inside a kernel.
pub fn copy_async<C: CubePrimitive>(
    _source: &Slice<C>,
    _destination: &mut SliceMut<C>,
    _copy_size: u32,
) {
    unexpanded!()
}
448
449pub mod copy_async {
450 use super::*;
451
452 pub fn expand<C: CubePrimitive>(
453 scope: &mut Scope,
454 source: SliceExpand<C, ReadOnly>,
455 destination: SliceExpand<C, ReadWrite>,
456 copy_length: u32,
457 ) {
458 let source_length = copy_length.into();
459 let (source, source_offset) = source.__to_raw_parts();
460 let (destination, destination_offset) = destination.__to_raw_parts();
461 let scalar_size = C::as_type(scope).storage_type().size();
462
463 let mem_copy = BarrierOps::CopyAsync {
464 source,
465 source_length,
466 offset_source: source_offset,
467 offset_out: destination_offset,
468 copy_length: copy_length * scalar_size as u32,
469 checked: false,
470 };
471
472 scope.register(Instruction::new(mem_copy, destination));
473 }
474}
475
/// Like [`copy_async`], but the generated op carries the slice's runtime
/// length and `checked: true`, so the backend can bound the copy.
///
/// Host-side stub: only meaningful inside a kernel.
pub fn copy_async_checked<C: CubePrimitive>(
    _source: &Slice<C>,
    _destination: &mut SliceMut<C>,
    _copy_size: u32,
) {
    unexpanded!();
}
494
495pub mod copy_async_checked {
496 use super::*;
497
498 pub fn expand<C: CubePrimitive>(
499 scope: &mut Scope,
500 source: SliceExpand<C, ReadOnly>,
501 destination: SliceExpand<C, ReadWrite>,
502 copy_length: u32,
503 ) {
504 let source_length = *source.length.expand;
505 let (source, source_offset) = source.__to_raw_parts();
506 let (destination, destination_offset) = destination.__to_raw_parts();
507 let scalar_size = C::as_type(scope).storage_type().size();
508
509 let mem_copy = BarrierOps::CopyAsync {
510 source,
511 source_length,
512 offset_source: source_offset,
513 offset_out: destination_offset,
514 copy_length: copy_length * scalar_size as u32,
515 checked: true,
516 };
517
518 scope.register(Instruction::new(mem_copy, destination));
519 }
520}
521
#[cube(self_type = "ref")]
impl Barrier {
    /// Commits previously issued async copies to this barrier.
    ///
    /// NOTE(review): a barrier token is created and used as the instruction
    /// output, but never returned to the caller — confirm whether this is
    /// intentional or whether the token should be surfaced like `arrive()`.
    pub fn commit_copy_async(&self) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type() else {
                unreachable!()
            };
            let token = scope.create_barrier_token(barrier.index().unwrap(), level);
            scope.register(Instruction::new(
                BarrierOps::CommitCopyAsync { barrier },
                *token,
            ));
        })
    }
}
544
impl Deref for Shared<Barrier> {
    type Target = Barrier;

    /// Host-side stub: deref of a shared barrier only exists in kernel code.
    fn deref(&self) -> &Self::Target {
        unexpanded!()
    }
}
impl Deref for SharedExpand<Barrier> {
    type Target = BarrierExpand;

    fn deref(&self) -> &Self::Target {
        // SAFETY: reinterprets the shared expand handle as a `BarrierExpand`
        // (`NativeExpand<Barrier>`). NOTE(review): soundness relies on
        // `as_type_ref_unchecked`'s layout contract — confirm there.
        unsafe { self.as_type_ref_unchecked::<Barrier>() }
    }
}
559
560impl DerefMut for Shared<Barrier> {
561 fn deref_mut(&mut self) -> &mut Self::Target {
562 todo!()
563 }
564}
impl DerefMut for SharedExpand<Barrier> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: mutable counterpart of the `Deref` impl above; reinterprets
        // the handle as a `BarrierExpand`. NOTE(review): soundness relies on
        // `as_type_mut_unchecked`'s layout contract — confirm there.
        unsafe { self.as_type_mut_unchecked::<Barrier>() }
    }
}
570
impl From<SharedExpand<Barrier>> for BarrierExpand {
    /// Converts a shared-barrier expand handle into a plain barrier expand by
    /// converting the underlying expand variable.
    fn from(value: SharedExpand<Barrier>) -> Self {
        value.expand.into()
    }
}