1use cubecl_ir::{ExpandElement, Instruction};
4use paste::paste;
5
6use crate::{
7 ir::{BarrierOps, Scope},
8 unexpanded,
9};
10
11use super::{
12 CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, IntoMut, Line, ReadOnly, ReadWrite,
13 Slice, SliceExpand, SliceMut, TensorMap,
14};
15
/// Zero-sized handle that kernel source uses to refer to a barrier.
///
/// The type stores nothing itself; while a kernel is being expanded it is
/// replaced by [`BarrierExpand`].
#[derive(Copy, Clone)]
pub struct Barrier;
20
21impl CubeType for Barrier {
22 type ExpandType = BarrierExpand;
23}
24
25impl IntoMut for BarrierExpand {
26 fn into_mut(self, _scope: &mut Scope) -> Self {
27 self
28 }
29}
30
31impl CubeDebug for BarrierExpand {
32 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
33 scope.update_variable_name(*self.elem, name);
34 }
35}
36
37#[derive(Clone)]
38pub struct BarrierExpand {
40 elem: ExpandElement,
41}
42
43#[derive(Copy, Clone, PartialEq, Eq)]
44pub struct BarrierLevel(InnerBarrierLevel);
45
46impl CubeType for BarrierLevel {
47 type ExpandType = Self;
48}
49
50impl IntoMut for BarrierLevel {
51 fn into_mut(self, _scope: &mut Scope) -> Self {
52 self
53 }
54}
55
56impl CubeDebug for BarrierLevel {
57 fn set_debug_name(&self, _scope: &mut Scope, _name: &'static str) {}
58}
59
/// Private payload of [`BarrierLevel`]; each variant is converted one-to-one
/// into the `cubecl_ir::BarrierLevel` variant of the same name.
#[derive(Clone, Copy, PartialEq, Eq)]
enum InnerBarrierLevel {
    /// Level without any payload.
    Unit,

    /// Holds the `elected_unit` passed to `BarrierLevel::cube_coop`.
    CubeCoop(u32),

    /// Holds the `elected_unit` passed to `BarrierLevel::cube_manual`.
    CubeManual(u32),
}
81
82impl BarrierLevel {
83 pub fn unit() -> Self {
85 BarrierLevel(InnerBarrierLevel::Unit)
86 }
87
88 pub fn cube_coop(elected_unit: u32) -> Self {
92 BarrierLevel(InnerBarrierLevel::CubeCoop(elected_unit))
93 }
94
95 pub fn cube_manual(elected_unit: u32) -> Self {
99 BarrierLevel(InnerBarrierLevel::CubeManual(elected_unit))
100 }
101
102 pub fn __expand_unit(_scope: &mut Scope) -> BarrierLevel {
103 BarrierLevel(InnerBarrierLevel::Unit)
104 }
105
106 pub fn __expand_cube_coop(_scope: &mut Scope, elected_unit: u32) -> Self {
107 BarrierLevel(InnerBarrierLevel::CubeCoop(elected_unit))
108 }
109
110 pub fn __expand_cube_manual(_scope: &mut Scope, elected_unit: u32) -> Self {
111 BarrierLevel(InnerBarrierLevel::CubeManual(elected_unit))
112 }
113}
114
115impl From<InnerBarrierLevel> for cubecl_ir::BarrierLevel {
116 fn from(val: InnerBarrierLevel) -> Self {
117 match val {
118 InnerBarrierLevel::Unit => cubecl_ir::BarrierLevel::Unit,
119 InnerBarrierLevel::CubeCoop(elected_unit) => {
120 cubecl_ir::BarrierLevel::CubeCoop(elected_unit)
121 }
122 InnerBarrierLevel::CubeManual(elected_unit) => {
123 cubecl_ir::BarrierLevel::CubeManual(elected_unit)
124 }
125 }
126 }
127}
128
/// Generates the `tma_load_<dim>d` family for one dimensionality.
///
/// For a given `$dim` and list of index names, this emits:
/// * `Barrier::tma_load_<dim>d` – the unexpanded placeholder taking `i32` indices,
/// * `Barrier::__expand_tma_load_<dim>d` – forwards to the method below,
/// * `BarrierExpand::__expand_tma_load_<dim>d_method` – registers a
///   `BarrierOps::TmaLoad` instruction in the scope.
macro_rules! tensor_map_load {
    ($dim: literal, $($arg: expr),*) => {
        paste! {
            impl Barrier {
                /// Placeholder for use in kernel source; the body is `unexpanded!()`.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_ $dim d>]<C: CubePrimitive>(
                    &self,
                    source: &TensorMap<C>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32),*
                ) {
                    unexpanded!()
                }

                /// Free-function form of the expansion; delegates to the
                /// `_method` variant on `expand`.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d>]<C: CubePrimitive>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    expand.[<__expand_tma_load_ $dim d_method>](scope, source, destination, $($arg),*);
                }
            }

            impl BarrierExpand {
                /// Registers a `BarrierOps::TmaLoad` whose output is the raw
                /// destination slice and whose indices are the given arguments,
                /// in declaration order.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d_method>]<C: CubePrimitive>(
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    // Split the destination slice into its backing variable and offset.
                    let (out, offset_out) = destination.__to_raw_parts();

                    let load_op = BarrierOps::TmaLoad {
                        barrier: *self.elem,
                        tensor_map: *source.expand,
                        indices: vec![$(*$arg.expand),*],
                        offset_out,
                    };

                    scope.register(Instruction::new(load_op, out));
                }
            }
        }
    };
}
183
/// Generates the `tma_load_im2col_<dim>d` family for one dimensionality.
///
/// The invocation lists index names, a `;`, then offset names. For each
/// invocation this emits:
/// * `Barrier::tma_load_im2col_<dim>d` – the unexpanded placeholder taking
///   `i32` indices followed by `u16` offsets,
/// * `Barrier::__expand_tma_load_im2col_<dim>d` – forwards to the method below,
/// * `BarrierExpand::__expand_tma_load_im2col_<dim>d_method` – registers a
///   `BarrierOps::TmaLoadIm2col` instruction in the scope.
macro_rules! tensor_map_load_im2col {
    ($dim: literal, $($arg: expr),*; $($offset: expr),*) => {
        paste! {
            impl Barrier {
                /// Placeholder for use in kernel source; the body is `unexpanded!()`.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_im2col_ $dim d>]<C: CubePrimitive>(
                    &self,
                    source: &TensorMap<C>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32,)*
                    $($offset: u16),*
                ) {
                    unexpanded!()
                }

                /// Free-function form of the expansion; delegates to the
                /// `_method` variant on `expand`.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d>]<C: CubePrimitive>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    expand.[<__expand_tma_load_im2col_ $dim d_method>](scope, source, destination, $($arg),*, $($offset),*);
                }
            }

            impl BarrierExpand {
                /// Registers a `BarrierOps::TmaLoadIm2col` whose output is the
                /// raw destination slice; indices and offsets are passed along
                /// in declaration order.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d_method>]<C: CubePrimitive>(
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    // Split the destination slice into its backing variable and offset.
                    let (out, offset_out) = destination.__to_raw_parts();

                    let load_op = BarrierOps::TmaLoadIm2col {
                        barrier: *self.elem,
                        tensor_map: *source.expand,
                        indices: vec![$(*$arg.expand),*],
                        offsets: vec![$(*$offset.expand),*],
                        offset_out,
                    };

                    scope.register(Instruction::new(load_op, out));
                }
            }
        }
    };
}
242
243tensor_map_load!(1, x);
244tensor_map_load!(2, y, x);
245tensor_map_load!(3, z, y, x);
246tensor_map_load!(4, w, z, y, x);
247tensor_map_load!(5, v, w, z, y, x);
248
249tensor_map_load_im2col!(3, n, w, c; w_offset);
250tensor_map_load_im2col!(4, n, h, w, c; h_offset, w_offset);
251tensor_map_load_im2col!(5, n, d, h, w, c; d_offset, h_offset, w_offset);
252
253impl Barrier {
254 pub fn new(_level: BarrierLevel) -> Self {
256 Self
257 }
258
259 pub fn new_with_tma_proxy(_level: BarrierLevel) -> Self {
262 Self
263 }
264
265 pub fn memcpy_async<C: CubePrimitive>(
272 &self,
273 _source: &Slice<Line<C>>,
274 _destination: &mut SliceMut<Line<C>>,
275 ) {
276 unexpanded!()
277 }
278
279 pub fn arrive(&self) {
281 unexpanded!()
282 }
283
284 pub fn arrive_tx(&self, _arrival_count: u32, _transaction_count: u32) {
286 unexpanded!()
287 }
288
289 pub fn expect_tx(&self, _expected_count: u32) {
291 unexpanded!()
292 }
293
294 pub fn wait(&self) {
296 unexpanded!()
297 }
298
299 pub fn arrive_and_wait(&self) {
301 unexpanded!()
302 }
303
304 pub fn __expand_new(scope: &mut Scope, level: BarrierLevel) -> BarrierExpand {
305 let variable = scope.create_barrier(level.0.into());
306 scope.register(BarrierOps::Init {
307 barrier: *variable,
308 with_cta_fence: false,
309 });
310 BarrierExpand { elem: variable }
311 }
312
313 pub fn __expand_new_with_tma_proxy(scope: &mut Scope, level: BarrierLevel) -> BarrierExpand {
314 let variable = scope.create_barrier(level.0.into());
315 scope.register(BarrierOps::Init {
316 barrier: *variable,
317 with_cta_fence: true,
318 });
319 BarrierExpand { elem: variable }
320 }
321
322 pub fn __expand_memcpy_async<C: CubePrimitive>(
323 scope: &mut Scope,
324 expand: BarrierExpand,
325 source: SliceExpand<Line<C>, ReadOnly>,
326 destination: SliceExpand<Line<C>, ReadWrite>,
327 ) {
328 expand.__expand_memcpy_async_method(scope, source, destination);
329 }
330
331 pub fn __expand_arrive(scope: &mut Scope, expand: BarrierExpand) {
332 expand.__expand_arrive_method(scope);
333 }
334
335 pub fn __expand_arrive_tx(
336 scope: &mut Scope,
337 expand: BarrierExpand,
338 arrival_count: ExpandElementTyped<u32>,
339 transaction_count: ExpandElementTyped<u32>,
340 ) {
341 expand.__expand_arrive_tx_method(scope, arrival_count, transaction_count);
342 }
343
344 pub fn __expand_expect_tx(
345 scope: &mut Scope,
346 expand: BarrierExpand,
347 expected_count: ExpandElementTyped<u32>,
348 ) {
349 expand.__expand_expect_tx_method(scope, expected_count);
350 }
351
352 pub fn __expand_wait(scope: &mut Scope, expand: BarrierExpand) {
353 expand.__expand_wait_method(scope);
354 }
355
356 pub fn __expand_arrive_and_wait(scope: &mut Scope, expand: BarrierExpand) {
357 expand.__expand_arrive_and_wait_method(scope);
358 }
359}
360
361impl BarrierExpand {
362 pub fn __expand_memcpy_async_method<C: CubePrimitive>(
363 &self,
364 scope: &mut Scope,
365 source: SliceExpand<Line<C>, ReadOnly>,
366 destination: SliceExpand<Line<C>, ReadWrite>,
367 ) {
368 let barrier = *self.elem;
369 let source_length = *source.length.expand;
370 let (source, source_offset) = source.__to_raw_parts();
371 let (destination, destination_offset) = destination.__to_raw_parts();
372
373 let mem_copy = BarrierOps::MemCopyAsync {
374 barrier,
375 source,
376 source_length,
377 offset_source: source_offset,
378 offset_out: destination_offset,
379 };
380
381 scope.register(Instruction::new(mem_copy, destination));
382 }
383
384 pub fn __expand_arrive_method(&self, scope: &mut Scope) {
385 let barrier = *self.elem;
386 scope.register(BarrierOps::Arrive { barrier });
387 }
388
389 pub fn __expand_arrive_tx_method(
390 &self,
391 scope: &mut Scope,
392 arrival_count: ExpandElementTyped<u32>,
393 transaction_count: ExpandElementTyped<u32>,
394 ) {
395 let barrier = *self.elem;
396 let arrival_count: ExpandElement = arrival_count.into();
397 let transaction_count: ExpandElement = transaction_count.into();
398 scope.register(BarrierOps::ArriveTx {
399 barrier,
400 arrive_count_update: arrival_count.consume(),
401 transaction_count_update: transaction_count.consume(),
402 });
403 }
404
405 pub fn __expand_expect_tx_method(
406 &self,
407 scope: &mut Scope,
408 transaction_count: ExpandElementTyped<u32>,
409 ) {
410 let barrier = *self.elem;
411 let transaction_count: ExpandElement = transaction_count.into();
412 scope.register(BarrierOps::ExpectTx {
413 barrier,
414 transaction_count_update: transaction_count.consume(),
415 });
416 }
417
418 pub fn __expand_wait_method(&self, scope: &mut Scope) {
419 let barrier = *self.elem;
420 scope.register(BarrierOps::Wait { barrier });
421 }
422
423 pub fn __expand_arrive_and_wait_method(&self, scope: &mut Scope) {
424 let barrier = *self.elem;
425 scope.register(BarrierOps::ArriveAndWait { barrier });
426 }
427}