cubecl_ir/
barrier.rs

1use crate::{Instruction, TypeHash};
2use alloc::{string::String, vec::Vec};
3use core::fmt::{Display, Write};
4
5use crate::OperationReflect;
6
7use super::Variable;
8
/// Level associated with a barrier.
///
/// NOTE(review): presumably this is the granularity at which the barrier
/// synchronizes (a single unit vs. a whole cube) — confirm against the code
/// that consumes this type; nothing in this file reads it.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, Copy)]
pub enum BarrierLevel {
    /// Unit-level barrier.
    Unit,
    /// Cube-level barrier.
    Cube,
}
15
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationReflect)]
#[operation(opcode_name = BarrierOpCode)]
/// Operations available on a barrier
///
/// Every variant carries the `barrier` variable it operates on. The `Display`
/// implementation renders each operation as a single line of pseudo-code.
pub enum BarrierOps {
    /// Declare the barrier, without doing any initialization
    ///
    /// Rendered as an empty string by `Display`.
    Declare {
        barrier: Variable,
    },
    /// Initialize the barrier, optionally with a cta proxy fence
    ///
    /// `with_async_proxy_fence` selects between the `init_barrier_tma` and
    /// `init_barrier` renderings. `is_elected` is not part of the rendering;
    /// presumably it identifies the unit that performs the initialization —
    /// TODO confirm.
    Init {
        barrier: Variable,
        is_elected: Variable,
        arrival_count: Variable,
        with_async_proxy_fence: bool,
    },
    /// Manually initialize the barrier with an arrival count, without any sync or election handling
    ///
    /// NOTE(review): rendered as `init_barrier(..)`, the same text as `Init`
    /// without the proxy fence.
    InitManual {
        barrier: Variable,
        arrival_count: Variable,
    },
    /// Copy source to destination
    ///
    /// Reads from `source` starting at `offset_source` and writes to the
    /// output starting at `offset_out`. `source_length` is not part of the
    /// rendering; presumably it is the number of elements to copy — TODO
    /// confirm.
    MemCopyAsync {
        barrier: Variable,
        source: Variable,
        source_length: Variable,
        offset_source: Variable,
        offset_out: Variable,
    },
    /// Copy source to destination, with cooperative behaviour
    ///
    /// Same operands as `MemCopyAsync`.
    MemCopyAsyncCooperative {
        barrier: Variable,
        source: Variable,
        source_length: Variable,
        offset_source: Variable,
        offset_out: Variable,
    },
    /// Copy source to destination, with transaction count
    ///
    /// Same operands as `MemCopyAsync`.
    MemCopyAsyncTx {
        barrier: Variable,
        source: Variable,
        source_length: Variable,
        offset_source: Variable,
        offset_out: Variable,
    },
    /// Load through `tensor_map` at the given `indices`, writing to the
    /// output at `offset_out`.
    ///
    /// The rendered rank (`tma_load::<rank>`) is the number of `indices`.
    TmaLoad {
        barrier: Variable,
        tensor_map: Variable,
        indices: Vec<Variable>,
        offset_out: Variable,
    },
    /// Im2col variant of `TmaLoad`, taking an additional list of `offsets`.
    ///
    /// The rendered rank (`tma_load_im2col::<rank>`) is the number of
    /// `indices`; the length of `offsets` is not checked here.
    TmaLoadIm2col {
        barrier: Variable,
        tensor_map: Variable,
        indices: Vec<Variable>,
        offsets: Vec<Variable>,
        offset_out: Variable,
    },
    /// Arrives at the barrier (decrements barrier count)
    Arrive {
        barrier: Variable,
    },
    /// Arrive at the barrier with an arrival-count update and a
    /// transaction-count update.
    ///
    /// NOTE(review): the exact meaning of both updates is defined by the
    /// backends that lower this operation — confirm there.
    ArriveTx {
        barrier: Variable,
        arrive_count_update: Variable,
        transaction_count_update: Variable,
    },
    /// Apply a transaction-count update to the barrier.
    ///
    /// NOTE(review): presumably this raises the expected transaction count
    /// without arriving — confirm against the backends.
    ExpectTx {
        barrier: Variable,
        transaction_count_update: Variable,
    },
    /// Wait on the barrier using `token`.
    ///
    /// NOTE(review): presumably `token` comes from a previous arrive
    /// operation — confirm against the callers.
    Wait {
        barrier: Variable,
        token: Variable,
    },
    /// Wait on the barrier for the given `phase`.
    WaitParity {
        barrier: Variable,
        phase: Variable,
    },
    /// Waits until data is loaded
    ArriveAndWait {
        barrier: Variable,
    },
}
100
101impl Display for BarrierOps {
102    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
103        match self {
104            BarrierOps::Declare { .. } => Ok(()),
105            BarrierOps::Init {
106                barrier,
107                arrival_count,
108                with_async_proxy_fence,
109                ..
110            } => match with_async_proxy_fence {
111                true => write!(f, "init_barrier_tma({barrier}, {arrival_count})"),
112                false => write!(f, "init_barrier({barrier}, {arrival_count})"),
113            },
114            BarrierOps::InitManual {
115                barrier,
116                arrival_count,
117            } => {
118                write!(f, "init_barrier({barrier}, {arrival_count})")
119            }
120            BarrierOps::MemCopyAsync {
121                barrier,
122                source,
123                offset_source,
124                offset_out,
125                ..
126            } => {
127                write!(
128                    f,
129                    "out[{offset_out}] = mem_copy_async({barrier}, source: {source}[{offset_source}])",
130                )
131            }
132            BarrierOps::MemCopyAsyncCooperative {
133                barrier,
134                source,
135                offset_source,
136                offset_out,
137                ..
138            } => {
139                write!(
140                    f,
141                    "out[{offset_out}] = mem_copy_async_cooperative({barrier}, source: {source}[{offset_source}])",
142                )
143            }
144            BarrierOps::MemCopyAsyncTx {
145                barrier,
146                source,
147                offset_source,
148                offset_out,
149                ..
150            } => {
151                write!(
152                    f,
153                    "out[{offset_out}] = mem_copy_async_tx({barrier}, source: {source}[{offset_source}])",
154                )
155            }
156            BarrierOps::ArriveAndWait { barrier } => write!(f, "arrive_and_wait({barrier})"),
157            BarrierOps::TmaLoad {
158                barrier,
159                tensor_map,
160                offset_out,
161                indices,
162            } => {
163                let rank = indices.len();
164                let indices = indices.iter().fold(String::new(), |mut s, it| {
165                    let _ = write!(s, "{it}, ");
166                    s
167                });
168                write!(
169                    f,
170                    "out[{offset_out}] = tma_load::<{rank}>({barrier}, {tensor_map}, {indices})"
171                )
172            }
173            BarrierOps::TmaLoadIm2col {
174                barrier,
175                tensor_map,
176                indices,
177                offsets,
178                offset_out,
179            } => {
180                let rank = indices.len();
181                let indices = indices.iter().fold(String::new(), |mut s, it| {
182                    let _ = write!(s, "{it}, ");
183                    s
184                });
185                let offsets = offsets.iter().fold(String::new(), |mut s, it| {
186                    let _ = write!(s, "{it}, ");
187                    s
188                });
189                write!(
190                    f,
191                    "out[{offset_out}] = tma_load_im2col::<{rank}>({barrier}, {tensor_map}, indices: ({indices}), offsets: ({offsets}))"
192                )
193            }
194            BarrierOps::Arrive { barrier } => write!(f, "arrive({barrier})"),
195            BarrierOps::ArriveTx {
196                barrier,
197                arrive_count_update,
198                transaction_count_update,
199            } => write!(
200                f,
201                "arrive_tx({barrier}, {arrive_count_update}, {transaction_count_update})"
202            ),
203            BarrierOps::ExpectTx {
204                barrier,
205                transaction_count_update,
206            } => write!(f, "expect_tx({barrier}, {transaction_count_update})"),
207            BarrierOps::Wait { barrier, token } => write!(f, "wait({barrier}, {token})"),
208            BarrierOps::WaitParity { barrier, phase } => {
209                write!(f, "wait_parity({barrier}, {phase})")
210            }
211        }
212    }
213}
214
215impl From<BarrierOps> for Instruction {
216    fn from(value: BarrierOps) -> Self {
217        Instruction::no_out(value)
218    }
219}