1use alloc::vec;
4use core::ops::{Deref, DerefMut};
5
6use crate as cubecl;
7use cubecl_ir::{ExpandElement, Instruction, OpaqueType};
8use cubecl_macros::intrinsic;
9use paste::paste;
10
11use crate::{
12 ir::{BarrierOps, Scope},
13 prelude::*,
14 unexpanded,
15};
16
17use super::{
18 CubePrimitive, CubeType, ExpandElementTyped, Line, ReadOnly, ReadWrite, Slice, SliceExpand,
19 SliceMut, TensorMap,
20};
21
/// Marker type for an asynchronous barrier. Its methods are kernel-side
/// stubs whose expand counterparts register [`BarrierOps`] IR instructions
/// (TMA loads, async memcopies, arrive/wait).
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct Barrier;
/// Expand-time (IR-building) counterpart of [`Barrier`].
pub type BarrierExpand = ExpandElementTyped<Barrier>;
27
/// Token produced by arrival operations (`arrive`, `arrive_and_expect_tx`)
/// and later consumed by `wait` to block until the corresponding phase
/// completes.
#[derive(Clone, Copy, PartialEq)]
pub struct BarrierToken;
30
31impl CubeType for Barrier {
32 type ExpandType = ExpandElementTyped<Barrier>;
33}
34
impl CubePrimitive for Barrier {
    /// Barriers are opaque GPU-side objects and have no constant host
    /// representation, so constructing one from a constant is impossible.
    fn from_const_value(_value: cubecl_ir::ConstantValue) -> Self {
        unreachable!("Can't create from const value")
    }
}
40
41impl ExpandElementIntoMut for Barrier {
42 fn elem_into_mut(_scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
43 elem
44 }
45}
46
47impl CubeType for BarrierToken {
48 type ExpandType = ExpandElementTyped<BarrierToken>;
49}
50
51impl ExpandElementIntoMut for BarrierToken {
52 fn elem_into_mut(_scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
53 elem
54 }
55}
56
/// Generates the `tma_load_{$dim}d` method on [`Barrier`] plus its expand
/// functions, taking one `i32` index per `$arg` and registering a
/// [`BarrierOps::TmaLoad`] instruction.
macro_rules! tensor_map_load {
    ($dim: literal, $($arg: expr),*) => {
        paste! {
            impl Barrier {
                /// Copy a tile of `source` into `destination`; completion is
                /// tracked by this barrier. Kernel-side stub only — panics if
                /// executed outside of expansion.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_ $dim d>]<C: CubePrimitive>(
                    &self,
                    source: &TensorMap<C, Tiled>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32),*
                ) {
                    unexpanded!()
                }

                /// Associated-function expand form; forwards to the method
                /// expansion on [`BarrierExpand`].
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d>]<C: CubePrimitive>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: ExpandElementTyped<TensorMap<C, Tiled>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    expand.[<__expand_tma_load_ $dim d_method>](scope, source, destination, $($arg),*);
                }
            }

            impl BarrierExpand {
                /// Registers the [`BarrierOps::TmaLoad`] instruction with the
                /// raw destination slice as its output.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d_method>]<C: CubePrimitive>(
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C, Tiled>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    let barrier = *self.expand;
                    let source = *source.expand;
                    // Decompose the slice into its backing variable and offset so
                    // the instruction can address the destination directly.
                    let (destination, destination_offset) = destination.__to_raw_parts();

                    let mem_copy = BarrierOps::TmaLoad {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offset_out: destination_offset
                    };

                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}
111
/// Generates the `tma_load_im2col_{$dim}d` method on [`Barrier`] plus its
/// expand functions. `$arg`s are `i32` tensor coordinates and `$offset`s are
/// `u16` kernel offsets; both are forwarded to
/// [`BarrierOps::TmaLoadIm2col`].
macro_rules! tensor_map_load_im2col {
    ($dim: literal, $($arg: expr),*; $($offset: expr),*) => {
        paste! {
            impl Barrier {
                /// Copy an im2col-transformed tile of `source` into
                /// `destination`; completion is tracked by this barrier.
                /// Kernel-side stub only — panics if executed outside of
                /// expansion.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_im2col_ $dim d>]<C: CubePrimitive>(
                    &self,
                    source: &TensorMap<C, Im2col>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32,)*
                    $($offset: u16),*
                ) {
                    unexpanded!()
                }

                /// Associated-function expand form; forwards to the method
                /// expansion on [`BarrierExpand`].
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d>]<C: CubePrimitive>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: ExpandElementTyped<TensorMap<C, Im2col>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    expand.[<__expand_tma_load_im2col_ $dim d_method>](scope, source, destination, $($arg),*, $($offset),*);
                }
            }

            impl BarrierExpand {
                /// Registers the [`BarrierOps::TmaLoadIm2col`] instruction with
                /// the raw destination slice as its output.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d_method>]<C: CubePrimitive>(
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C, Im2col>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    let barrier = *self.expand;
                    let source = *source.expand;
                    // Decompose the slice into its backing variable and offset so
                    // the instruction can address the destination directly.
                    let (destination, destination_offset) = destination.__to_raw_parts();

                    let mem_copy = BarrierOps::TmaLoadIm2col {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offsets: vec![$(*$offset.expand),*],
                        offset_out: destination_offset,
                    };

                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}
170
// Tiled TMA loads for 1 to 5 dimensions; index arguments are named
// (v, w, z, y, x) with `x` always last.
tensor_map_load!(1, x);
tensor_map_load!(2, y, x);
tensor_map_load!(3, z, y, x);
tensor_map_load!(4, w, z, y, x);
tensor_map_load!(5, v, w, z, y, x);

// im2col TMA loads: tensor coordinates before the `;`, kernel offsets after.
// NOTE(review): the naming suggests an N[D][H]WC-style layout — confirm
// against the `TensorMap` im2col encoding.
tensor_map_load_im2col!(3, n, w, c; w_offset);
tensor_map_load_im2col!(4, n, h, w, c; h_offset, w_offset);
tensor_map_load_im2col!(5, n, d, h, w, c; d_offset, h_offset, w_offset);
180
181#[cube(self_type = "ref")]
182impl Barrier {
183 pub fn local() -> Self {
186 intrinsic!(|scope| {
187 let variable =
188 scope.create_local_mut(OpaqueType::Barrier(cubecl_ir::BarrierLevel::Unit));
189 scope.register(BarrierOps::Init {
190 barrier: *variable,
191 is_elected: true.into(),
192 arrival_count: 1.into(),
193 });
194 variable.into()
195 })
196 }
197
198 #[allow(unused_variables)]
205 pub fn shared(arrival_count: u32, is_elected: bool) -> Shared<Barrier> {
206 intrinsic!(|scope| {
207 let variable = scope.create_shared(OpaqueType::Barrier(cubecl_ir::BarrierLevel::Cube));
208 scope.register(BarrierOps::Init {
209 barrier: *variable,
210 is_elected: *is_elected.expand,
211 arrival_count: *arrival_count.expand,
212 });
213 variable.into()
214 })
215 }
216
217 pub fn shared_uninit() -> Shared<Barrier> {
220 intrinsic!(|scope| {
221 let variable = scope.create_shared(OpaqueType::Barrier(cubecl_ir::BarrierLevel::Cube));
222 scope.register(BarrierOps::Declare { barrier: *variable });
223 variable.into()
224 })
225 }
226
227 #[allow(unused_variables)]
240 pub fn init_manual(&self, arrival_count: u32) {
241 intrinsic!(|scope| {
242 let barrier = *self.expand.clone();
243
244 scope.register(BarrierOps::InitManual {
245 barrier,
246 arrival_count: *arrival_count.expand,
247 });
248 })
249 }
250}
251
#[cube(self_type = "ref")]
impl Barrier {
    /// Copies `source` into `destination`, with completion tracked by this
    /// barrier; registers a [`BarrierOps::MemCopyAsync`] instruction.
    #[allow(unused_variables)]
    pub fn memcpy_async<C: CubePrimitive>(
        &self,
        source: &Slice<Line<C>>,
        destination: &mut SliceMut<Line<C>>,
    ) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            // Capture the length before the slice is split into raw parts.
            let source_length = *source.length.expand;
            let (source, source_offset) = source.__to_raw_parts();
            let (destination, destination_offset) = destination.__to_raw_parts();

            let mem_copy = BarrierOps::MemCopyAsync {
                barrier,
                source,
                source_length,
                offset_source: source_offset,
                offset_out: destination_offset,
            };

            scope.register(Instruction::new(mem_copy, destination));
        })
    }

    /// Same shape as [`Self::memcpy_async`], but registers the cooperative
    /// variant [`BarrierOps::MemCopyAsyncCooperative`].
    #[allow(unused_variables)]
    pub fn memcpy_async_cooperative<C: CubePrimitive>(
        &self,
        source: &Slice<Line<C>>,
        destination: &mut SliceMut<Line<C>>,
    ) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let source_length = *source.length.expand;
            let (source, source_offset) = source.__to_raw_parts();
            let (destination, destination_offset) = destination.__to_raw_parts();

            let mem_copy = BarrierOps::MemCopyAsyncCooperative {
                barrier,
                source,
                source_length,
                offset_source: source_offset,
                offset_out: destination_offset,
            };

            scope.register(Instruction::new(mem_copy, destination));
        })
    }

    /// Same shape as [`Self::memcpy_async`], but registers the
    /// transaction-counting variant [`BarrierOps::MemCopyAsyncTx`]
    /// (pairs with the `*_tx` arrival/expect operations below).
    #[allow(unused_variables)]
    pub fn memcpy_async_tx<C: CubePrimitive>(
        &self,
        source: &Slice<Line<C>>,
        destination: &mut SliceMut<Line<C>>,
    ) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let source_length = *source.length.expand;
            let (source, source_offset) = source.__to_raw_parts();
            let (destination, destination_offset) = destination.__to_raw_parts();

            let mem_copy = BarrierOps::MemCopyAsyncTx {
                barrier,
                source,
                source_length,
                offset_source: source_offset,
                offset_out: destination_offset,
            };

            scope.register(Instruction::new(mem_copy, destination));
        })
    }
}
347
#[cube(self_type = "ref")]
impl Barrier {
    /// Signals arrival at the barrier and returns a token that can later be
    /// passed to [`Self::wait`].
    pub fn arrive(&self) -> BarrierToken {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            // Recover the barrier level from the variable's opaque storage type.
            let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type() else {
                unreachable!()
            };
            let token = scope.create_barrier_token(barrier.index().unwrap(), level);
            scope.register(Instruction::new(BarrierOps::Arrive { barrier }, *token));
            token.into()
        })
    }

    /// Signals arrival while also updating the arrival and transaction
    /// counts ([`BarrierOps::ArriveTx`]); returns a token for [`Self::wait`].
    #[allow(unused_variables)]
    pub fn arrive_and_expect_tx(&self, arrival_count: u32, transaction_count: u32) -> BarrierToken {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type() else {
                unreachable!()
            };
            let token = scope.create_barrier_token(barrier.index().unwrap(), level);
            // Convert the counts to expand elements and consume them; they are
            // moved into the instruction.
            let arrival_count: ExpandElement = arrival_count.into();
            let transaction_count: ExpandElement = transaction_count.into();
            scope.register(Instruction::new(
                BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update: arrival_count.consume(),
                    transaction_count_update: transaction_count.consume(),
                },
                *token,
            ));
            token.into()
        })
    }

    /// Updates the barrier's expected transaction count by `expected_count`
    /// without arriving ([`BarrierOps::ExpectTx`]).
    #[allow(unused_variables)]
    pub fn expect_tx(&self, expected_count: u32) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let transaction_count: ExpandElement = expected_count.into();
            scope.register(BarrierOps::ExpectTx {
                barrier,
                transaction_count_update: transaction_count.consume(),
            });
        })
    }

    /// Arrives at the barrier and blocks until it completes, in a single
    /// operation ([`BarrierOps::ArriveAndWait`]).
    pub fn arrive_and_wait(&self) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            scope.register(BarrierOps::ArriveAndWait { barrier });
        })
    }

    /// Blocks until the phase identified by `token` (obtained from an
    /// arrival operation) completes.
    #[allow(unused_variables)]
    pub fn wait(&self, token: BarrierToken) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let token = *token.expand;
            scope.register(BarrierOps::Wait { barrier, token });
        })
    }

    /// Blocks on the given phase parity ([`BarrierOps::WaitParity`]).
    #[allow(unused_variables)]
    pub fn wait_parity(&self, phase: u32) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let phase = *phase.expand;
            scope.register(BarrierOps::WaitParity { barrier, phase });
        })
    }
}
430
/// Asynchronously copies `_copy_size` lines from `_source` into
/// `_destination` without barrier tracking (the registered op carries no
/// barrier). Kernel-side stub; the expansion lives in [`copy_async::expand`].
pub fn copy_async<C: CubePrimitive>(
    _source: &Slice<Line<C>>,
    _destination: &mut SliceMut<Line<C>>,
    _copy_size: u32,
) {
    unexpanded!()
}
449
/// Expand module for [`copy_async`].
pub mod copy_async {
    use super::*;

    /// Registers an unchecked [`BarrierOps::CopyAsync`] instruction.
    pub fn expand<C: CubePrimitive>(
        scope: &mut Scope,
        source: SliceExpand<Line<C>, ReadOnly>,
        destination: SliceExpand<Line<C>, ReadWrite>,
        copy_length: u32,
    ) {
        // Unchecked variant: the requested copy length is used as the source
        // length (contrast with `copy_async_checked::expand`, which uses the
        // slice's actual length).
        let source_length = copy_length.into();
        let (source, source_offset) = source.__to_raw_parts();
        let (destination, destination_offset) = destination.__to_raw_parts();

        let mem_copy = BarrierOps::CopyAsync {
            source,
            source_length,
            offset_source: source_offset,
            offset_out: destination_offset,
            // Scale by the element type's size — presumably to bytes; confirm
            // the unit expected by `BarrierOps::CopyAsync`.
            copy_length: copy_length * C::as_type(scope).size() as u32,
            checked: false,
        };

        scope.register(Instruction::new(mem_copy, destination));
    }
}
475
/// Bounds-checked variant of [`copy_async`]. Kernel-side stub; the
/// expansion lives in [`copy_async_checked::expand`].
pub fn copy_async_checked<C: CubePrimitive>(
    _source: &Slice<Line<C>>,
    _destination: &mut SliceMut<Line<C>>,
    _copy_size: u32,
) {
    unexpanded!();
}
494
/// Expand module for [`copy_async_checked`].
pub mod copy_async_checked {
    use super::*;

    /// Registers a checked [`BarrierOps::CopyAsync`] instruction.
    pub fn expand<C: CubePrimitive>(
        scope: &mut Scope,
        source: SliceExpand<Line<C>, ReadOnly>,
        destination: SliceExpand<Line<C>, ReadWrite>,
        copy_length: u32,
    ) {
        // Checked variant: use the source slice's actual length (not
        // `copy_length`) so the copy can be bounds-checked.
        let source_length = *source.length.expand;
        let (source, source_offset) = source.__to_raw_parts();
        let (destination, destination_offset) = destination.__to_raw_parts();

        let mem_copy = BarrierOps::CopyAsync {
            source,
            source_length,
            offset_source: source_offset,
            offset_out: destination_offset,
            // Scale by the element type's size — presumably to bytes; confirm
            // the unit expected by `BarrierOps::CopyAsync`.
            copy_length: copy_length * C::as_type(scope).size() as u32,
            checked: true,
        };

        scope.register(Instruction::new(mem_copy, destination));
    }
}
520
#[cube(self_type = "ref")]
impl Barrier {
    /// Commits previously issued async copies to this barrier
    /// ([`BarrierOps::CommitCopyAsync`]).
    pub fn commit_copy_async(&self) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type() else {
                unreachable!()
            };
            // A token is created as the instruction's output variable, but it
            // is not returned to the caller.
            let token = scope.create_barrier_token(barrier.index().unwrap(), level);
            scope.register(Instruction::new(
                BarrierOps::CommitCopyAsync { barrier },
                *token,
            ));
        })
    }
}
543
impl Deref for Shared<Barrier> {
    type Target = Barrier;

    /// Host-side stub: shared barriers are only dereferenced in expanded
    /// kernel code.
    fn deref(&self) -> &Self::Target {
        unexpanded!()
    }
}
impl Deref for SharedExpand<Barrier> {
    type Target = BarrierExpand;

    fn deref(&self) -> &Self::Target {
        // SAFETY: relies on `as_type_ref_unchecked`'s contract that a
        // `SharedExpand<Barrier>` shares its representation with a
        // `BarrierExpand` — confirm at the method's definition.
        unsafe { self.as_type_ref_unchecked::<Barrier>() }
    }
}
558
559impl DerefMut for Shared<Barrier> {
560 fn deref_mut(&mut self) -> &mut Self::Target {
561 todo!()
562 }
563}
impl DerefMut for SharedExpand<Barrier> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: relies on `as_type_mut_unchecked`'s contract that a
        // `SharedExpand<Barrier>` shares its representation with a
        // `BarrierExpand` — confirm at the method's definition.
        unsafe { self.as_type_mut_unchecked::<Barrier>() }
    }
}
569
570impl From<SharedExpand<Barrier>> for BarrierExpand {
571 fn from(value: SharedExpand<Barrier>) -> Self {
572 value.expand.into()
573 }
574}