// cubecl_ir/barrier.rs

1use crate::{Instruction, TypeHash};
2use alloc::{format, string::String, vec::Vec};
3use core::fmt::{Display, Write};
4
5use crate::OperationReflect;
6
7use super::Variable;
8
/// Level of a barrier: either [`BarrierLevel::Unit`] or [`BarrierLevel::Cube`].
///
/// NOTE(review): presumably this selects which threads take part in the barrier (a single
/// unit vs. the whole cube) — confirm against the code that consumes this type.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, Copy, PartialOrd, Ord)]
// `PartialOrd`/`Ord` are derived, so the declaration order below defines the ordering
// (`Unit < Cube`); reordering the variants changes comparison results.
pub enum BarrierLevel {
    /// Unit level; displayed as `unit`.
    Unit,
    /// Cube level; displayed as `cube`.
    Cube,
}
15
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationReflect)]
#[operation(opcode_name = BarrierOpCode)]
/// Operations available on a barrier
///
/// Operands are IR [`Variable`]s, except for the compile-time `copy_length` and `checked`
/// fields of [`BarrierOps::CopyAsync`]. The `Display` impl renders each operation as a line of
/// pseudo-code; the variant docs below quote that rendering where it helps.
pub enum BarrierOps {
    /// Declare the barrier, without doing any initialization
    ///
    /// Renders as an empty string in `Display`.
    Declare {
        barrier: Variable,
    },
    /// Initialize the barrier, optionally with a cta proxy fence
    ///
    /// `Display` shows only `barrier` and `arrival_count`; `is_elected` is not rendered.
    Init {
        barrier: Variable,
        is_elected: Variable,
        arrival_count: Variable,
    },
    /// Manually initialize the barrier with an arrival count, without any sync or election handling
    ///
    /// Renders exactly like [`BarrierOps::Init`] (`{barrier}.init_barrier({arrival_count})`).
    InitManual {
        barrier: Variable,
        arrival_count: Variable,
    },
    /// Copy source to destination
    ///
    /// Rendered as `out[offset_out] = mem_copy_async(barrier, source: source[offset_source])`;
    /// `source_length` is not rendered.
    MemCopyAsync {
        barrier: Variable,
        source: Variable,
        source_length: Variable,
        offset_source: Variable,
        offset_out: Variable,
    },
    /// Copy source to destination, with cooperative behaviour
    ///
    /// Same operands as [`BarrierOps::MemCopyAsync`]; rendered as `mem_copy_async_cooperative`.
    MemCopyAsyncCooperative {
        barrier: Variable,
        source: Variable,
        source_length: Variable,
        offset_source: Variable,
        offset_out: Variable,
    },
    /// Copy source to destination, with transaction count
    ///
    /// Same operands as [`BarrierOps::MemCopyAsync`]; rendered as `mem_copy_async_tx`.
    MemCopyAsyncTx {
        barrier: Variable,
        source: Variable,
        source_length: Variable,
        offset_source: Variable,
        offset_out: Variable,
    },
    /// Copy source to destination
    ///
    /// Unlike the `MemCopyAsync*` variants this one carries no barrier operand.
    /// `copy_length` is rendered as `bytes: {copy_length}`. When `checked` is true the source is
    /// rendered as `source[offset_source..][..source_length]`, otherwise as
    /// `source[offset_source]`.
    /// NOTE(review): presumably `checked` enables bounds handling against `source_length` in
    /// the backends — confirm.
    CopyAsync {
        source: Variable,
        source_length: Variable,
        offset_source: Variable,
        offset_out: Variable,
        copy_length: u32,
        checked: bool,
    },
    /// Tensor-map load, rendered as
    /// `out[offset_out] = tma_load::<rank>(barrier, tensor_map, indices..)` where `rank` is
    /// `indices.len()`.
    TmaLoad {
        barrier: Variable,
        tensor_map: Variable,
        indices: Vec<Variable>,
        offset_out: Variable,
    },
    /// Im2col flavour of the tensor-map load, rendered as `tma_load_im2col::<rank>(..)` with
    /// separate `indices` and `offsets` lists; `rank` is `indices.len()`.
    TmaLoadIm2col {
        barrier: Variable,
        tensor_map: Variable,
        indices: Vec<Variable>,
        offsets: Vec<Variable>,
        offset_out: Variable,
    },
    /// Arrives at the barrier (decrements barrier count)
    Arrive {
        barrier: Variable,
    },
    /// Rendered as `arrive_tx(barrier, arrive_count_update, transaction_count_update)`.
    /// NOTE(review): presumably an arrive that also updates the expected transaction count —
    /// confirm against the backend lowering.
    ArriveTx {
        barrier: Variable,
        arrive_count_update: Variable,
        transaction_count_update: Variable,
    },
    /// Rendered as `commit_copy_async(barrier)`.
    /// NOTE(review): presumably associates previously issued async copies with `barrier` —
    /// confirm.
    CommitCopyAsync {
        barrier: Variable,
    },
    /// Rendered as `expect_tx(barrier, transaction_count_update)`.
    /// NOTE(review): presumably updates the expected transaction count without arriving —
    /// confirm.
    ExpectTx {
        barrier: Variable,
        transaction_count_update: Variable,
    },
    /// Rendered as `wait(barrier, token)`.
    /// NOTE(review): `token` is presumably the value produced by an earlier arrive — confirm.
    Wait {
        barrier: Variable,
        token: Variable,
    },
    /// Rendered as `wait_parity(barrier, phase)`.
    WaitParity {
        barrier: Variable,
        phase: Variable,
    },
    /// Waits until data is loaded
    ArriveAndWait {
        barrier: Variable,
    },
}
111
112impl Display for BarrierOps {
113    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
114        match self {
115            BarrierOps::Declare { .. } => Ok(()),
116            BarrierOps::Init {
117                barrier,
118                arrival_count,
119                ..
120            } => write!(f, "{barrier}.init_barrier({arrival_count})"),
121            BarrierOps::InitManual {
122                barrier,
123                arrival_count,
124            } => write!(f, "{barrier}.init_barrier({arrival_count})"),
125            BarrierOps::MemCopyAsync {
126                barrier,
127                source,
128                offset_source,
129                offset_out,
130                ..
131            } => {
132                write!(
133                    f,
134                    "out[{offset_out}] = mem_copy_async({barrier}, source: {source}[{offset_source}])",
135                )
136            }
137            BarrierOps::MemCopyAsyncCooperative {
138                barrier,
139                source,
140                offset_source,
141                offset_out,
142                ..
143            } => {
144                write!(
145                    f,
146                    "out[{offset_out}] = mem_copy_async_cooperative({barrier}, source: {source}[{offset_source}])",
147                )
148            }
149            BarrierOps::MemCopyAsyncTx {
150                barrier,
151                source,
152                offset_source,
153                offset_out,
154                ..
155            } => {
156                write!(
157                    f,
158                    "out[{offset_out}] = mem_copy_async_tx({barrier}, source: {source}[{offset_source}])",
159                )
160            }
161            BarrierOps::CopyAsync {
162                source,
163                source_length,
164                offset_source,
165                offset_out,
166                copy_length,
167                checked,
168            } => {
169                let source_slice = if *checked {
170                    format!("[{offset_source}..][..{source_length}]")
171                } else {
172                    format!("[{offset_source}]")
173                };
174                write!(
175                    f,
176                    "out[{offset_out}] = copy_async(source: {source}{source_slice}, bytes: {copy_length})",
177                )
178            }
179            BarrierOps::ArriveAndWait { barrier } => write!(f, "arrive_and_wait({barrier})"),
180            BarrierOps::TmaLoad {
181                barrier,
182                tensor_map,
183                offset_out,
184                indices,
185            } => {
186                let rank = indices.len();
187                let indices = indices.iter().fold(String::new(), |mut s, it| {
188                    let _ = write!(s, "{it}, ");
189                    s
190                });
191                write!(
192                    f,
193                    "out[{offset_out}] = tma_load::<{rank}>({barrier}, {tensor_map}, {indices})"
194                )
195            }
196            BarrierOps::TmaLoadIm2col {
197                barrier,
198                tensor_map,
199                indices,
200                offsets,
201                offset_out,
202            } => {
203                let rank = indices.len();
204                let indices = indices.iter().fold(String::new(), |mut s, it| {
205                    let _ = write!(s, "{it}, ");
206                    s
207                });
208                let offsets = offsets.iter().fold(String::new(), |mut s, it| {
209                    let _ = write!(s, "{it}, ");
210                    s
211                });
212                write!(
213                    f,
214                    "out[{offset_out}] = tma_load_im2col::<{rank}>({barrier}, {tensor_map}, indices: ({indices}), offsets: ({offsets}))"
215                )
216            }
217            BarrierOps::Arrive { barrier } => write!(f, "arrive({barrier})"),
218            BarrierOps::CommitCopyAsync { barrier } => write!(f, "commit_copy_async({barrier})"),
219            BarrierOps::ArriveTx {
220                barrier,
221                arrive_count_update,
222                transaction_count_update,
223            } => write!(
224                f,
225                "arrive_tx({barrier}, {arrive_count_update}, {transaction_count_update})"
226            ),
227            BarrierOps::ExpectTx {
228                barrier,
229                transaction_count_update,
230            } => write!(f, "expect_tx({barrier}, {transaction_count_update})"),
231            BarrierOps::Wait { barrier, token } => write!(f, "wait({barrier}, {token})"),
232            BarrierOps::WaitParity { barrier, phase } => {
233                write!(f, "wait_parity({barrier}, {phase})")
234            }
235        }
236    }
237}
238
239impl Display for BarrierLevel {
240    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
241        match self {
242            BarrierLevel::Unit => f.write_str("unit"),
243            BarrierLevel::Cube => f.write_str("cube"),
244        }
245    }
246}
247
248impl From<BarrierOps> for Instruction {
249    fn from(value: BarrierOps) -> Self {
250        Instruction::no_out(value)
251    }
252}