use std::ops::{Deref, DerefMut};

use crate as cubecl;
use cubecl_ir::{ExpandElement, Instruction, OpaqueType};
use cubecl_macros::intrinsic;
use paste::paste;

use crate::{
    ir::{BarrierOps, Scope},
    prelude::*,
    unexpanded,
};

use super::{
    CubePrimitive, CubeType, ExpandElementTyped, Line, ReadOnly, ReadWrite, Slice, SliceExpand,
    SliceMut, TensorMap,
};

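/// A synchronization primitive for tracking and awaiting asynchronous data
/// transfers, in the spirit of CUDA's asynchronous barriers (`mbarrier`).
///
/// A minimal usage sketch (inside a `#[cube]` kernel; the slice variables are
/// illustrative, not part of this module):
///
/// ```ignore
/// let barrier = Barrier::local();
/// barrier.memcpy_async(&source_slice, &mut destination_slice);
/// barrier.arrive_and_wait();
/// ```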
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct Barrier;
pub type BarrierExpand = ExpandElementTyped<Barrier>;

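/// A token returned by the `arrive` family of methods, identifying the barrier
/// phase that a subsequent [`Barrier::wait`] should block on.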
#[derive(Clone, Copy, PartialEq)]
pub struct BarrierToken;

impl CubeType for Barrier {
    type ExpandType = ExpandElementTyped<Barrier>;
}

impl CubePrimitive for Barrier {
    fn from_const_value(_value: cubecl_ir::ConstantScalarValue) -> Self {
        unreachable!("Can't create from const value")
    }
}

impl ExpandElementIntoMut for Barrier {
    fn elem_into_mut(_scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
        elem
    }
}

impl CubeType for BarrierToken {
    type ExpandType = ExpandElementTyped<BarrierToken>;
}

impl ExpandElementIntoMut for BarrierToken {
    fn elem_into_mut(_scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
        elem
    }
}

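// Generates, for each supported rank, the host-side stub `tma_load_{N}d` and
// the `__expand_*` pair that registers a `BarrierOps::TmaLoad` instruction
// during kernel expansion.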
macro_rules! tensor_map_load {
    ($dim: literal, $($arg: ident),*) => {
        paste! {
            impl Barrier {
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_ $dim d>]<C: CubePrimitive>(
                    &self,
                    source: &TensorMap<C>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32),*
                ) {
                    unexpanded!()
                }

                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d>]<C: CubePrimitive>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    expand.[<__expand_tma_load_ $dim d_method>](scope, source, destination, $($arg),*);
                }
            }

            impl BarrierExpand {
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d_method>]<C: CubePrimitive>(
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    let barrier = *self.expand;
                    let source = *source.expand;
                    let (destination, destination_offset) = destination.__to_raw_parts();

                    let mem_copy = BarrierOps::TmaLoad {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offset_out: destination_offset,
                    };

                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}

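// Same pattern as `tensor_map_load!`, but for im2col loads, which take
// additional per-axis pixel offsets alongside the base indices.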
macro_rules! tensor_map_load_im2col {
    ($dim: literal, $($arg: ident),*; $($offset: ident),*) => {
        paste! {
            impl Barrier {
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_im2col_ $dim d>]<C: CubePrimitive>(
                    &self,
                    source: &TensorMap<C>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32,)*
                    $($offset: u16),*
                ) {
                    unexpanded!()
                }

                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d>]<C: CubePrimitive>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    expand.[<__expand_tma_load_im2col_ $dim d_method>](scope, source, destination, $($arg),*, $($offset),*);
                }
            }

            impl BarrierExpand {
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d_method>]<C: CubePrimitive>(
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    let barrier = *self.expand;
                    let source = *source.expand;
                    let (destination, destination_offset) = destination.__to_raw_parts();

                    let mem_copy = BarrierOps::TmaLoadIm2col {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offsets: vec![$(*$offset.expand),*],
                        offset_out: destination_offset,
                    };

                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}

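// Instantiate the 1-5D loads; indices are passed outermost-first (e.g. `y, x`
// for 2D), and the im2col variants pair each spatial axis with an offset.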
tensor_map_load!(1, x);
tensor_map_load!(2, y, x);
tensor_map_load!(3, z, y, x);
tensor_map_load!(4, w, z, y, x);
tensor_map_load!(5, v, w, z, y, x);

tensor_map_load_im2col!(3, n, w, c; w_offset);
tensor_map_load_im2col!(4, n, h, w, c; h_offset, w_offset);
tensor_map_load_im2col!(5, n, d, h, w, c; d_offset, h_offset, w_offset);

#[cube(self_type = "ref")]
impl Barrier {
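    /// Creates a unit-level barrier, initialized with an arrival count of 1.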
    pub fn local() -> Self {
        intrinsic!(|scope| {
            let variable =
                scope.create_local_mut(OpaqueType::Barrier(cubecl_ir::BarrierLevel::Unit));
            scope.register(BarrierOps::Init {
                barrier: *variable,
                is_elected: true.into(),
                arrival_count: 1.into(),
            });
            variable.into()
        })
    }

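    /// Creates a cube-level barrier in shared memory, initialized with
    /// `arrival_count` expected arrivals. Initialization is only performed by
    /// the unit(s) for which `is_elected` is true.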
    #[allow(unused_variables)]
    pub fn shared(arrival_count: u32, is_elected: bool) -> Shared<Barrier> {
        intrinsic!(|scope| {
            let variable = scope.create_shared(OpaqueType::Barrier(cubecl_ir::BarrierLevel::Cube));
            scope.register(BarrierOps::Init {
                barrier: *variable,
                is_elected: *is_elected.expand,
                arrival_count: *arrival_count.expand,
            });
            variable.into()
        })
    }

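    /// Declares a cube-level barrier in shared memory without initializing it.
    /// It must be initialized with [`Self::init_manual`] before use.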
    pub fn shared_uninit() -> Shared<Barrier> {
        intrinsic!(|scope| {
            let variable = scope.create_shared(OpaqueType::Barrier(cubecl_ir::BarrierLevel::Cube));
            scope.register(BarrierOps::Declare { barrier: *variable });
            variable.into()
        })
    }

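    /// Manually initializes a barrier created with [`Self::shared_uninit`],
    /// setting the number of expected arrivals per phase.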
    #[allow(unused_variables)]
    pub fn init_manual(&self, arrival_count: u32) {
        intrinsic!(|scope| {
            let barrier = *self.expand.clone();

            scope.register(BarrierOps::InitManual {
                barrier,
                arrival_count: *arrival_count.expand,
            });
        })
    }
}

#[cube(self_type = "ref")]
impl Barrier {
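    /// Asynchronously copies the whole `source` slice into `destination`,
    /// tracking completion through this barrier.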
    #[allow(unused_variables)]
    pub fn memcpy_async<C: CubePrimitive>(
        &self,
        source: &Slice<Line<C>>,
        destination: &mut SliceMut<Line<C>>,
    ) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let source_length = *source.length.expand;
            let (source, source_offset) = source.__to_raw_parts();
            let (destination, destination_offset) = destination.__to_raw_parts();

            let mem_copy = BarrierOps::MemCopyAsync {
                barrier,
                source,
                source_length,
                offset_source: source_offset,
                offset_out: destination_offset,
            };

            scope.register(Instruction::new(mem_copy, destination));
        })
    }

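    /// Like [`Self::memcpy_async`], but the transfer is performed
    /// cooperatively by the participating units.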
    #[allow(unused_variables)]
    pub fn memcpy_async_cooperative<C: CubePrimitive>(
        &self,
        source: &Slice<Line<C>>,
        destination: &mut SliceMut<Line<C>>,
    ) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let source_length = *source.length.expand;
            let (source, source_offset) = source.__to_raw_parts();
            let (destination, destination_offset) = destination.__to_raw_parts();

            let mem_copy = BarrierOps::MemCopyAsyncCooperative {
                barrier,
                source,
                source_length,
                offset_source: source_offset,
                offset_out: destination_offset,
            };

            scope.register(Instruction::new(mem_copy, destination));
        })
    }

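    /// Like [`Self::memcpy_async`], but completion is tracked through the
    /// barrier's transaction count (see [`Self::arrive_and_expect_tx`] and
    /// [`Self::expect_tx`]).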
    #[allow(unused_variables)]
    pub fn memcpy_async_tx<C: CubePrimitive>(
        &self,
        source: &Slice<Line<C>>,
        destination: &mut SliceMut<Line<C>>,
    ) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let source_length = *source.length.expand;
            let (source, source_offset) = source.__to_raw_parts();
            let (destination, destination_offset) = destination.__to_raw_parts();

            let mem_copy = BarrierOps::MemCopyAsyncTx {
                barrier,
                source,
                source_length,
                offset_source: source_offset,
                offset_out: destination_offset,
            };

            scope.register(Instruction::new(mem_copy, destination));
        })
    }
}

#[cube(self_type = "ref")]
impl Barrier {
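    /// Signals arrival at the barrier without blocking, returning a
    /// [`BarrierToken`] that can later be passed to [`Self::wait`].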
    pub fn arrive(&self) -> BarrierToken {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type() else {
                unreachable!()
            };
            let token = scope.create_barrier_token(barrier.index().unwrap(), level);
            scope.register(Instruction::new(BarrierOps::Arrive { barrier }, *token));
            token.into()
        })
    }

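    /// Arrives at the barrier while also updating the expected arrival and
    /// transaction counts, returning a token for [`Self::wait`].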
    #[allow(unused_variables)]
    pub fn arrive_and_expect_tx(&self, arrival_count: u32, transaction_count: u32) -> BarrierToken {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type() else {
                unreachable!()
            };
            let token = scope.create_barrier_token(barrier.index().unwrap(), level);
            let arrival_count: ExpandElement = arrival_count.into();
            let transaction_count: ExpandElement = transaction_count.into();
            scope.register(Instruction::new(
                BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update: arrival_count.consume(),
                    transaction_count_update: transaction_count.consume(),
                },
                *token,
            ));
            token.into()
        })
    }

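    /// Updates the barrier's expected transaction count without arriving.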
    #[allow(unused_variables)]
    pub fn expect_tx(&self, expected_count: u32) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let transaction_count: ExpandElement = expected_count.into();
            scope.register(BarrierOps::ExpectTx {
                barrier,
                transaction_count_update: transaction_count.consume(),
            });
        })
    }

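    /// Arrives at the barrier, then blocks until the current phase completes.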
    pub fn arrive_and_wait(&self) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            scope.register(BarrierOps::ArriveAndWait { barrier });
        })
    }

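    /// Blocks until the phase identified by `token` (returned by one of the
    /// arrive methods) completes.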
    #[allow(unused_variables)]
    pub fn wait(&self, token: BarrierToken) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let token = *token.expand;
            scope.register(BarrierOps::Wait { barrier, token });
        })
    }

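    /// Blocks until the barrier phase with the given parity completes.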
    #[allow(unused_variables)]
    pub fn wait_parity(&self, phase: u32) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let phase = *phase.expand;
            scope.register(BarrierOps::WaitParity { barrier, phase });
        })
    }
}

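/// Asynchronously copies `copy_size` lines from `source` to `destination`
/// without bounds checking. Pending copies are committed to a barrier with
/// [`Barrier::commit_copy_async`].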
pub fn copy_async<C: CubePrimitive>(
    _source: &Slice<Line<C>>,
    _destination: &mut SliceMut<Line<C>>,
    _copy_size: u32,
) {
    unexpanded!()
}

pub mod copy_async {
    use super::*;

    pub fn expand<C: CubePrimitive>(
        scope: &mut Scope,
        source: SliceExpand<Line<C>, ReadOnly>,
        destination: SliceExpand<Line<C>, ReadWrite>,
        copy_length: u32,
    ) {
        let source_length = copy_length.into();
        let (source, source_offset) = source.__to_raw_parts();
        let (destination, destination_offset) = destination.__to_raw_parts();

        let mem_copy = BarrierOps::CopyAsync {
            source,
            source_length,
            offset_source: source_offset,
            offset_out: destination_offset,
            copy_length: copy_length * C::as_type(scope).size() as u32,
            checked: false,
        };

        scope.register(Instruction::new(mem_copy, destination));
    }
}

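/// Like [`copy_async`], but the transfer is bounds-checked against the source
/// slice's runtime length.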
pub fn copy_async_checked<C: CubePrimitive>(
    _source: &Slice<Line<C>>,
    _destination: &mut SliceMut<Line<C>>,
    _copy_size: u32,
) {
    unexpanded!()
}

pub mod copy_async_checked {
    use super::*;

    pub fn expand<C: CubePrimitive>(
        scope: &mut Scope,
        source: SliceExpand<Line<C>, ReadOnly>,
        destination: SliceExpand<Line<C>, ReadWrite>,
        copy_length: u32,
    ) {
        let source_length = *source.length.expand;
        let (source, source_offset) = source.__to_raw_parts();
        let (destination, destination_offset) = destination.__to_raw_parts();

        let mem_copy = BarrierOps::CopyAsync {
            source,
            source_length,
            offset_source: source_offset,
            offset_out: destination_offset,
            copy_length: copy_length * C::as_type(scope).size() as u32,
            checked: true,
        };

        scope.register(Instruction::new(mem_copy, destination));
    }
}

#[cube(self_type = "ref")]
impl Barrier {
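    /// Commits all previously issued [`copy_async`] transfers to this barrier
    /// so that they can be awaited.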
    pub fn commit_copy_async(&self) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type() else {
                unreachable!()
            };
            let token = scope.create_barrier_token(barrier.index().unwrap(), level);
            scope.register(Instruction::new(
                BarrierOps::CommitCopyAsync { barrier },
                *token,
            ));
        })
    }
}

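// Deref so a `Shared<Barrier>` handle can be used anywhere a `Barrier` is
// expected; the host-side bodies are never executed.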
impl Deref for Shared<Barrier> {
    type Target = Barrier;

    fn deref(&self) -> &Self::Target {
        unexpanded!()
    }
}
impl Deref for SharedExpand<Barrier> {
    type Target = BarrierExpand;

    fn deref(&self) -> &Self::Target {
        unsafe { self.as_type_ref_unchecked::<Barrier>() }
    }
}

impl DerefMut for Shared<Barrier> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        unexpanded!()
    }
}
impl DerefMut for SharedExpand<Barrier> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        unsafe { self.as_type_mut_unchecked::<Barrier>() }
    }
}

impl From<SharedExpand<Barrier>> for BarrierExpand {
    fn from(value: SharedExpand<Barrier>) -> Self {
        value.expand.into()
    }
}