1use std::marker::PhantomData;
4
5use cubecl_ir::{ExpandElement, Instruction};
6use paste::paste;
7
8use crate::{
9 ir::{BarrierOps, Item, Scope},
10 unexpanded,
11};
12
13use super::{
14 CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, IntoMut, Line, ReadOnly, ReadWrite,
15 Slice, SliceExpand, SliceMut, TensorMap,
16};
17
18#[derive(Clone, Copy)]
21pub struct Barrier<C: CubePrimitive> {
22 _c: PhantomData<C>,
23}
24
25impl<C: CubePrimitive> CubeType for Barrier<C> {
26 type ExpandType = BarrierExpand<C>;
27}
28
29impl<C: CubePrimitive> IntoMut for BarrierExpand<C> {
30 fn into_mut(self, _scope: &mut Scope) -> Self {
31 self
32 }
33}
34
35impl<C: CubePrimitive> CubeDebug for BarrierExpand<C> {
36 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
37 scope.update_variable_name(*self.elem, name);
38 }
39}
40
41#[derive(Clone)]
42pub struct BarrierExpand<C: CubePrimitive> {
44 elem: ExpandElement,
45 _c: PhantomData<C>,
46}
47
48#[derive(Copy, Clone, PartialEq, Eq)]
49pub struct BarrierLevel(InnerBarrierLevel);
50
51impl CubeType for BarrierLevel {
52 type ExpandType = Self;
53}
54
55impl IntoMut for BarrierLevel {
56 fn into_mut(self, _scope: &mut Scope) -> Self {
57 self
58 }
59}
60
61impl CubeDebug for BarrierLevel {
62 fn set_debug_name(&self, _scope: &mut Scope, _name: &'static str) {}
63}
64
/// Private payload of `BarrierLevel`.
///
/// Each variant is converted one-to-one into the `cubecl_ir::BarrierLevel`
/// variant of the same name.
// NOTE(review): the synchronization semantics of each level are defined by the
// IR/backends and are not visible in this file — confirm there before relying on them.
#[derive(Clone, Copy, PartialEq, Eq)]
enum InnerBarrierLevel {
    /// Level without a payload.
    Unit,

    /// Level carrying a `u32`; the constructors name this value `elected_unit`.
    CubeCoop(u32),

    /// Level carrying a `u32`; the constructors name this value `elected_unit`.
    CubeManual(u32),
}
86
87impl BarrierLevel {
88 pub fn unit() -> Self {
90 BarrierLevel(InnerBarrierLevel::Unit)
91 }
92
93 pub fn cube_coop(elected_unit: u32) -> Self {
97 BarrierLevel(InnerBarrierLevel::CubeCoop(elected_unit))
98 }
99
100 pub fn cube_manual(elected_unit: u32) -> Self {
104 BarrierLevel(InnerBarrierLevel::CubeManual(elected_unit))
105 }
106
107 pub fn __expand_unit(_scope: &mut Scope) -> BarrierLevel {
108 BarrierLevel(InnerBarrierLevel::Unit)
109 }
110
111 pub fn __expand_cube_coop(_scope: &mut Scope, elected_unit: u32) -> Self {
112 BarrierLevel(InnerBarrierLevel::CubeCoop(elected_unit))
113 }
114
115 pub fn __expand_cube_manual(_scope: &mut Scope, elected_unit: u32) -> Self {
116 BarrierLevel(InnerBarrierLevel::CubeManual(elected_unit))
117 }
118}
119
120impl From<InnerBarrierLevel> for cubecl_ir::BarrierLevel {
121 fn from(val: InnerBarrierLevel) -> Self {
122 match val {
123 InnerBarrierLevel::Unit => cubecl_ir::BarrierLevel::Unit,
124 InnerBarrierLevel::CubeCoop(elected_unit) => {
125 cubecl_ir::BarrierLevel::CubeCoop(elected_unit)
126 }
127 InnerBarrierLevel::CubeManual(elected_unit) => {
128 cubecl_ir::BarrierLevel::CubeManual(elected_unit)
129 }
130 }
131 }
132}
133
// Generates the `tma_load_<dim>d` family for one dimensionality.
//
// `$dim` is the literal spliced into the method names and each `$arg` names
// one `i32` index parameter. For every invocation three items are emitted:
//   * `Barrier::tma_load_<dim>d`                   - `unexpanded!()` placeholder,
//   * `Barrier::__expand_tma_load_<dim>d`          - forwards to the method below,
//   * `BarrierExpand::__expand_tma_load_<dim>d_method` - registers `BarrierOps::TmaLoad`.
macro_rules! tensor_map_load {
    ($dim: literal, $($arg: expr),*) => {
        paste! {
            impl<C: CubePrimitive> Barrier<C> {
                /// Placeholder for the expand-time TMA load; the body is `unexpanded!()`.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_ $dim d>](
                    &self,
                    source: &TensorMap<C>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32),*
                ) {
                    unexpanded!()
                }

                /// Expand entry point: forwards every argument to the
                /// `_method` variant on `expand`.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d>](
                    scope: &mut Scope,
                    expand: BarrierExpand<C>,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    expand.[<__expand_tma_load_ $dim d_method>](
                        scope,
                        source,
                        destination,
                        $($arg),*
                    );
                }
            }

            impl<C: CubePrimitive> BarrierExpand<C> {
                /// Registers a `BarrierOps::TmaLoad` instruction whose output
                /// is the raw destination and whose indices are the given
                /// arguments, in declaration order.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d_method>](
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    let barrier = *self.elem;
                    let tensor_map = *source.expand;
                    // Split the destination slice into its backing variable and offset.
                    let (out, offset_out) = destination.__to_raw_parts();

                    let operation = BarrierOps::TmaLoad {
                        barrier,
                        tensor_map,
                        indices: vec![$(*$arg.expand),*],
                        offset_out,
                    };

                    scope.register(Instruction::new(operation, out));
                }
            }
        }
    };
}
188
// Generates the `tma_load_im2col_<dim>d` family for one dimensionality.
//
// `$dim` is the literal spliced into the method names; the `$arg` list (before
// the `;`) names the `i32` index parameters and the `$offset` list (after the
// `;`) names the `u16` offset parameters. Three items are emitted:
//   * `Barrier::tma_load_im2col_<dim>d`                   - `unexpanded!()` placeholder,
//   * `Barrier::__expand_tma_load_im2col_<dim>d`          - forwards to the method below,
//   * `BarrierExpand::__expand_tma_load_im2col_<dim>d_method` - registers
//     `BarrierOps::TmaLoadIm2col`.
macro_rules! tensor_map_load_im2col {
    ($dim: literal, $($arg: expr),*; $($offset: expr),*) => {
        paste! {
            impl<C: CubePrimitive> Barrier<C> {
                /// Placeholder for the expand-time im2col TMA load; the body
                /// is `unexpanded!()`.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_im2col_ $dim d>](
                    &self,
                    source: &TensorMap<C>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32,)*
                    $($offset: u16),*
                ) {
                    unexpanded!()
                }

                /// Expand entry point: forwards every argument to the
                /// `_method` variant on `expand`.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d>](
                    scope: &mut Scope,
                    expand: BarrierExpand<C>,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    expand.[<__expand_tma_load_im2col_ $dim d_method>](
                        scope,
                        source,
                        destination,
                        $($arg),*,
                        $($offset),*
                    );
                }
            }

            impl<C: CubePrimitive> BarrierExpand<C> {
                /// Registers a `BarrierOps::TmaLoadIm2col` instruction whose
                /// output is the raw destination; indices and offsets are
                /// passed in declaration order.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d_method>](
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    let barrier = *self.elem;
                    let tensor_map = *source.expand;
                    // Split the destination slice into its backing variable and offset.
                    let (out, offset_out) = destination.__to_raw_parts();

                    let operation = BarrierOps::TmaLoadIm2col {
                        barrier,
                        tensor_map,
                        indices: vec![$(*$arg.expand),*],
                        offsets: vec![$(*$offset.expand),*],
                        offset_out,
                    };

                    scope.register(Instruction::new(operation, out));
                }
            }
        }
    };
}
247
// Instantiate `tma_load_<dim>d` (and its expand variants) for 2 to 5 dimensions.
// The identifiers become the names of the `i32` index parameters, in this order.
tensor_map_load!(2, y, x);
tensor_map_load!(3, z, y, x);
tensor_map_load!(4, w, z, y, x);
tensor_map_load!(5, v, w, z, y, x);

// Instantiate `tma_load_im2col_<dim>d` for 3 to 5 dimensions. Identifiers before
// the `;` are `i32` index parameters; those after it are `u16` offset parameters.
tensor_map_load_im2col!(3, n, w, c; w_offset);
tensor_map_load_im2col!(4, n, h, w, c; h_offset, w_offset);
tensor_map_load_im2col!(5, n, d, h, w, c; d_offset, h_offset, w_offset);
256
257impl<C: CubePrimitive> Barrier<C> {
258 pub fn new(_level: BarrierLevel) -> Self {
260 Self { _c: PhantomData }
261 }
262
263 pub fn new_with_tma_proxy(_level: BarrierLevel) -> Self {
266 Self { _c: PhantomData }
267 }
268
269 pub fn memcpy_async(&self, _source: &Slice<Line<C>>, _destination: &mut SliceMut<Line<C>>) {
276 unexpanded!()
277 }
278
279 pub fn arrive(&self) {
281 unexpanded!()
282 }
283
284 pub fn arrive_tx(&self, _arrival_count: u32, _transaction_count: u32) {
286 unexpanded!()
287 }
288
289 pub fn expect_tx(&self, _expected_count: u32) {
291 unexpanded!()
292 }
293
294 pub fn wait(&self) {
296 unexpanded!()
297 }
298
299 pub fn arrive_and_wait(&self) {
301 unexpanded!()
302 }
303
304 pub fn __expand_new(scope: &mut Scope, level: BarrierLevel) -> BarrierExpand<C> {
305 let elem = C::as_elem(scope);
306
307 let variable = scope.create_barrier(Item::new(elem), level.0.into());
308 scope.register(BarrierOps::Init {
309 barrier: *variable,
310 with_cta_fence: false,
311 });
312 BarrierExpand {
313 elem: variable,
314 _c: PhantomData,
315 }
316 }
317
318 pub fn __expand_new_with_tma_proxy(scope: &mut Scope, level: BarrierLevel) -> BarrierExpand<C> {
319 let elem = C::as_elem(scope);
320
321 let variable = scope.create_barrier(Item::new(elem), level.0.into());
322 scope.register(BarrierOps::Init {
323 barrier: *variable,
324 with_cta_fence: true,
325 });
326 BarrierExpand {
327 elem: variable,
328 _c: PhantomData,
329 }
330 }
331
332 pub fn __expand_memcpy_async(
333 scope: &mut Scope,
334 expand: BarrierExpand<C>,
335 source: SliceExpand<Line<C>, ReadOnly>,
336 destination: SliceExpand<Line<C>, ReadWrite>,
337 ) {
338 expand.__expand_memcpy_async_method(scope, source, destination);
339 }
340
341 pub fn __expand_arrive(scope: &mut Scope, expand: BarrierExpand<C>) {
342 expand.__expand_arrive_method(scope);
343 }
344
345 pub fn __expand_arrive_tx(
346 scope: &mut Scope,
347 expand: BarrierExpand<C>,
348 arrival_count: ExpandElementTyped<u32>,
349 transaction_count: ExpandElementTyped<u32>,
350 ) {
351 expand.__expand_arrive_tx_method(scope, arrival_count, transaction_count);
352 }
353
354 pub fn __expand_expect_tx(
355 scope: &mut Scope,
356 expand: BarrierExpand<C>,
357 expected_count: ExpandElementTyped<u32>,
358 ) {
359 expand.__expand_expect_tx_method(scope, expected_count);
360 }
361
362 pub fn __expand_wait(scope: &mut Scope, expand: BarrierExpand<C>) {
363 expand.__expand_wait_method(scope);
364 }
365
366 pub fn __expand_arrive_and_wait(scope: &mut Scope, expand: BarrierExpand<C>) {
367 expand.__expand_arrive_and_wait_method(scope);
368 }
369}
370
371impl<C: CubePrimitive> BarrierExpand<C> {
372 pub fn __expand_memcpy_async_method(
373 &self,
374 scope: &mut Scope,
375 source: SliceExpand<Line<C>, ReadOnly>,
376 destination: SliceExpand<Line<C>, ReadWrite>,
377 ) {
378 let barrier = *self.elem;
379 let source_length = *source.length.expand;
380 let (source, source_offset) = source.__to_raw_parts();
381 let (destination, destination_offset) = destination.__to_raw_parts();
382
383 let mem_copy = BarrierOps::MemCopyAsync {
384 barrier,
385 source,
386 source_length,
387 offset_source: source_offset,
388 offset_out: destination_offset,
389 };
390
391 scope.register(Instruction::new(mem_copy, destination));
392 }
393
394 pub fn __expand_arrive_method(&self, scope: &mut Scope) {
395 let barrier = *self.elem;
396 scope.register(BarrierOps::Arrive { barrier });
397 }
398
399 pub fn __expand_arrive_tx_method(
400 &self,
401 scope: &mut Scope,
402 arrival_count: ExpandElementTyped<u32>,
403 transaction_count: ExpandElementTyped<u32>,
404 ) {
405 let barrier = *self.elem;
406 let arrival_count: ExpandElement = arrival_count.into();
407 let transaction_count: ExpandElement = transaction_count.into();
408 scope.register(BarrierOps::ArriveTx {
409 barrier,
410 arrive_count_update: arrival_count.consume(),
411 transaction_count_update: transaction_count.consume(),
412 });
413 }
414
415 pub fn __expand_expect_tx_method(
416 &self,
417 scope: &mut Scope,
418 transaction_count: ExpandElementTyped<u32>,
419 ) {
420 let barrier = *self.elem;
421 let transaction_count: ExpandElement = transaction_count.into();
422 scope.register(BarrierOps::ExpectTx {
423 barrier,
424 transaction_count_update: transaction_count.consume(),
425 });
426 }
427
428 pub fn __expand_wait_method(&self, scope: &mut Scope) {
429 let barrier = *self.elem;
430 scope.register(BarrierOps::Wait { barrier });
431 }
432
433 pub fn __expand_arrive_and_wait_method(&self, scope: &mut Scope) {
434 let barrier = *self.elem;
435 scope.register(BarrierOps::ArriveAndWait { barrier });
436 }
437}