cubecl_ir/
barrier.rs

1use crate::{Instruction, TypeHash};
2use alloc::{string::String, vec::Vec};
3use core::fmt::{Display, Write};
4
5use crate::OperationReflect;
6
7use super::Variable;
8
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, Copy)]
/// Level a barrier is created at.
///
/// The type is a small `Copy` value. What each level means at runtime is decided by the code
/// that consumes this enum, which is not part of this file; the notes below are hedged
/// accordingly.
pub enum BarrierLevel {
    /// Level without any payload.
    ///
    /// NOTE(review): presumably a barrier owned by a single unit — confirm against the backends.
    Unit,
    /// Level carrying one `u32`.
    ///
    /// NOTE(review): presumably a cube-wide barrier with cooperative copies, the `u32` being the
    /// position of the unit elected to initialize it — confirm.
    CubeCoop(u32),
    /// Level carrying one `u32`.
    ///
    /// NOTE(review): presumably a cube-wide barrier where each unit issues its own copies, the
    /// `u32` being the position of the unit elected to initialize it — confirm.
    CubeManual(u32),
}
16
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationReflect)]
#[operation(opcode_name = BarrierOpCode)]
/// Operations available on a barrier
///
/// Every variant names the `barrier` variable it operates on. The `Display` implementation of
/// this type prints a pseudo-code form of each operation; the per-variant notes quote that form.
pub enum BarrierOps {
    /// Initialize the barrier, optionally with a cta proxy fence
    ///
    /// `with_cta_fence` selects the printed name: `init_barrier_tma` when `true`,
    /// `init_barrier` when `false`.
    Init {
        barrier: Variable,
        with_cta_fence: bool,
    },
    /// Copy source to destination
    ///
    /// Printed as `out[offset_out] = mem_copy_async(barrier, source: source[offset_source])`.
    /// `source_length` is carried by the operation but is not part of the printed form.
    ///
    /// NOTE(review): presumably `source_length` counts elements starting at `offset_source` —
    /// confirm the unit.
    MemCopyAsync {
        barrier: Variable,
        source: Variable,
        source_length: Variable,
        offset_source: Variable,
        offset_out: Variable,
    },
    /// Load through `tensor_map` at `indices`, printed as an assignment to `out[offset_out]`.
    ///
    /// The rank shown in the printed form is `indices.len()`.
    ///
    /// NOTE(review): presumably a TMA (tensor memory accelerator) tile load whose completion is
    /// signalled on `barrier` — confirm.
    TmaLoad {
        barrier: Variable,
        tensor_map: Variable,
        indices: Vec<Variable>,
        offset_out: Variable,
    },
    /// Variant of [`BarrierOps::TmaLoad`] that additionally carries a list of `offsets`.
    ///
    /// The rank shown in the printed form is `indices.len()`; `offsets` is printed as a separate
    /// list.
    ///
    /// NOTE(review): presumably the im2col flavour of a TMA load, with `offsets` being the
    /// kernel-window offsets — confirm.
    TmaLoadIm2col {
        barrier: Variable,
        tensor_map: Variable,
        indices: Vec<Variable>,
        offsets: Vec<Variable>,
        offset_out: Variable,
    },
    /// Arrives at the barrier (decrements barrier count)
    Arrive {
        barrier: Variable,
    },
    /// Arrive operation that also carries an arrival-count update and a transaction-count update.
    ///
    /// Printed as `arrive_tx(barrier, arrive_count_update, transaction_count_update)`.
    ///
    /// NOTE(review): presumably arrives and adjusts the expected transaction count in one step —
    /// confirm against the backends.
    ArriveTx {
        barrier: Variable,
        arrive_count_update: Variable,
        transaction_count_update: Variable,
    },
    /// Passes `transaction_count_update` to the barrier.
    ///
    /// Printed as `expect_tx(barrier, transaction_count_update)`.
    ///
    /// NOTE(review): presumably adjusts the expected transaction count without arriving —
    /// confirm.
    ExpectTx {
        barrier: Variable,
        transaction_count_update: Variable,
    },
    /// Wait on the barrier; printed as `wait(barrier)`.
    ///
    /// NOTE(review): presumably blocks until the barrier's current phase completes — confirm.
    Wait {
        barrier: Variable,
    },
    /// Waits until data is loaded
    ///
    /// Printed as `arrive_and_wait(barrier)`.
    ArriveAndWait {
        barrier: Variable,
    },
}
69
70impl Display for BarrierOps {
71    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
72        match self {
73            BarrierOps::Init {
74                barrier,
75                with_cta_fence,
76            } => match with_cta_fence {
77                true => write!(f, "init_barrier_tma({barrier})"),
78                false => write!(f, "init_barrier({barrier})"),
79            },
80            BarrierOps::MemCopyAsync {
81                barrier,
82                source,
83                offset_source,
84                offset_out,
85                ..
86            } => {
87                write!(
88                    f,
89                    "out[{offset_out}] = mem_copy_async({barrier}, source: {source}[{offset_source}])",
90                )
91            }
92            BarrierOps::ArriveAndWait { barrier } => write!(f, "arrive_and_wait({barrier})"),
93            BarrierOps::TmaLoad {
94                barrier,
95                tensor_map,
96                offset_out,
97                indices,
98            } => {
99                let rank = indices.len();
100                let indices = indices.iter().fold(String::new(), |mut s, it| {
101                    let _ = write!(s, "{it}, ");
102                    s
103                });
104                write!(
105                    f,
106                    "out[{offset_out}] = tma_load::<{rank}>({barrier}, {tensor_map}, {indices})"
107                )
108            }
109            BarrierOps::TmaLoadIm2col {
110                barrier,
111                tensor_map,
112                indices,
113                offsets,
114                offset_out,
115            } => {
116                let rank = indices.len();
117                let indices = indices.iter().fold(String::new(), |mut s, it| {
118                    let _ = write!(s, "{it}, ");
119                    s
120                });
121                let offsets = offsets.iter().fold(String::new(), |mut s, it| {
122                    let _ = write!(s, "{it}, ");
123                    s
124                });
125                write!(
126                    f,
127                    "out[{offset_out}] = tma_load_im2col::<{rank}>({barrier}, {tensor_map}, indices: ({indices}), offsets: ({offsets}))"
128                )
129            }
130            BarrierOps::Arrive { barrier } => write!(f, "arrive({barrier})"),
131            BarrierOps::ArriveTx {
132                barrier,
133                arrive_count_update,
134                transaction_count_update,
135            } => write!(
136                f,
137                "arrive_tx({barrier}, {arrive_count_update}, {transaction_count_update})"
138            ),
139            BarrierOps::ExpectTx {
140                barrier,
141                transaction_count_update,
142            } => write!(f, "expect_tx({barrier}, {transaction_count_update})"),
143            BarrierOps::Wait { barrier } => write!(f, "wait({barrier})"),
144        }
145    }
146}
147
148impl From<BarrierOps> for Instruction {
149    fn from(value: BarrierOps) -> Self {
150        Instruction::no_out(value)
151    }
152}