cubecl_core/frontend/element/
atomic.rs1use cubecl_ir::{AtomicOp, ExpandElement};
2
3use super::{
4 ExpandElementBaseInit, ExpandElementTyped, Int, LaunchArgExpand, Numeric, init_expand_element,
5};
6use crate::{
7 frontend::{CubePrimitive, CubeType},
8 ir::{BinaryOperator, CompareAndSwapOperator, Elem, Instruction, Item, Scope, UnaryOperator},
9 prelude::KernelBuilder,
10 unexpanded,
11};
12
13#[derive(Clone, Copy, Hash, PartialEq, Eq)]
17pub struct Atomic<Inner: CubePrimitive> {
18 pub val: Inner,
19}
20
21impl<Inner: Numeric> Atomic<Inner> {
22 #[allow(unused_variables)]
24 pub fn load(pointer: &Self) -> Inner {
25 unexpanded!()
26 }
27
28 #[allow(unused_variables)]
30 pub fn store(pointer: &Self, value: Inner) {
31 unexpanded!()
32 }
33
34 #[allow(unused_variables)]
36 pub fn swap(pointer: &Self, value: Inner) -> Inner {
37 unexpanded!()
38 }
39
40 #[allow(unused_variables)]
42 pub fn add(pointer: &Self, value: Inner) -> Inner {
43 unexpanded!()
44 }
45
46 #[allow(unused_variables)]
49 pub fn max(pointer: &Self, value: Inner) -> Inner {
50 unexpanded!()
51 }
52
53 #[allow(unused_variables)]
56 pub fn min(pointer: &Self, value: Inner) -> Inner {
57 unexpanded!()
58 }
59
60 #[allow(unused_variables)]
62 pub fn sub(pointer: &Self, value: Inner) -> Inner {
63 unexpanded!()
64 }
65
66 pub fn __expand_load(
67 scope: &mut Scope,
68 pointer: <Self as CubeType>::ExpandType,
69 ) -> <Inner as CubeType>::ExpandType {
70 let pointer: ExpandElement = pointer.into();
71 let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
72 scope.register(Instruction::new(
73 AtomicOp::Load(UnaryOperator { input: *pointer }),
74 *new_var,
75 ));
76 new_var.into()
77 }
78
79 pub fn __expand_store(
80 scope: &mut Scope,
81 pointer: <Self as CubeType>::ExpandType,
82 value: <Inner as CubeType>::ExpandType,
83 ) {
84 let ptr: ExpandElement = pointer.into();
85 let value: ExpandElement = value.into();
86 scope.register(Instruction::new(
87 AtomicOp::Store(UnaryOperator { input: *value }),
88 *ptr,
89 ));
90 }
91
92 pub fn __expand_swap(
93 scope: &mut Scope,
94 pointer: <Self as CubeType>::ExpandType,
95 value: <Inner as CubeType>::ExpandType,
96 ) -> <Inner as CubeType>::ExpandType {
97 let ptr: ExpandElement = pointer.into();
98 let value: ExpandElement = value.into();
99 let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
100 scope.register(Instruction::new(
101 AtomicOp::Swap(BinaryOperator {
102 lhs: *ptr,
103 rhs: *value,
104 }),
105 *new_var,
106 ));
107 new_var.into()
108 }
109
110 pub fn __expand_add(
111 scope: &mut Scope,
112 pointer: <Self as CubeType>::ExpandType,
113 value: <Inner as CubeType>::ExpandType,
114 ) -> <Inner as CubeType>::ExpandType {
115 let ptr: ExpandElement = pointer.into();
116 let value: ExpandElement = value.into();
117 let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
118 scope.register(Instruction::new(
119 AtomicOp::Add(BinaryOperator {
120 lhs: *ptr,
121 rhs: *value,
122 }),
123 *new_var,
124 ));
125 new_var.into()
126 }
127
128 pub fn __expand_sub(
129 scope: &mut Scope,
130 pointer: <Self as CubeType>::ExpandType,
131 value: <Inner as CubeType>::ExpandType,
132 ) -> <Inner as CubeType>::ExpandType {
133 let ptr: ExpandElement = pointer.into();
134 let value: ExpandElement = value.into();
135 let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
136 scope.register(Instruction::new(
137 AtomicOp::Sub(BinaryOperator {
138 lhs: *ptr,
139 rhs: *value,
140 }),
141 *new_var,
142 ));
143 new_var.into()
144 }
145
146 pub fn __expand_max(
147 scope: &mut Scope,
148 pointer: <Self as CubeType>::ExpandType,
149 value: <Inner as CubeType>::ExpandType,
150 ) -> <Inner as CubeType>::ExpandType {
151 let ptr: ExpandElement = pointer.into();
152 let value: ExpandElement = value.into();
153 let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
154 scope.register(Instruction::new(
155 AtomicOp::Max(BinaryOperator {
156 lhs: *ptr,
157 rhs: *value,
158 }),
159 *new_var,
160 ));
161 new_var.into()
162 }
163
164 pub fn __expand_min(
165 scope: &mut Scope,
166 pointer: <Self as CubeType>::ExpandType,
167 value: <Inner as CubeType>::ExpandType,
168 ) -> <Inner as CubeType>::ExpandType {
169 let ptr: ExpandElement = pointer.into();
170 let value: ExpandElement = value.into();
171 let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
172 scope.register(Instruction::new(
173 AtomicOp::Min(BinaryOperator {
174 lhs: *ptr,
175 rhs: *value,
176 }),
177 *new_var,
178 ));
179 new_var.into()
180 }
181}
182
183impl<Inner: Int> Atomic<Inner> {
184 #[allow(unused_variables)]
190 pub fn compare_and_swap(pointer: &Self, cmp: Inner, value: Inner) -> Inner {
191 unexpanded!()
192 }
193
194 #[allow(unused_variables)]
196 pub fn and(pointer: &Self, value: Inner) -> Inner {
197 unexpanded!()
198 }
199
200 #[allow(unused_variables)]
202 pub fn or(pointer: &Self, value: Inner) -> Inner {
203 unexpanded!()
204 }
205
206 #[allow(unused_variables)]
208 pub fn xor(pointer: &Self, value: Inner) -> Inner {
209 unexpanded!()
210 }
211
212 pub fn __expand_compare_and_swap(
213 scope: &mut Scope,
214 pointer: <Self as CubeType>::ExpandType,
215 cmp: <Inner as CubeType>::ExpandType,
216 value: <Inner as CubeType>::ExpandType,
217 ) -> <Inner as CubeType>::ExpandType {
218 let pointer: ExpandElement = pointer.into();
219 let cmp: ExpandElement = cmp.into();
220 let value: ExpandElement = value.into();
221 let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
222 scope.register(Instruction::new(
223 AtomicOp::CompareAndSwap(CompareAndSwapOperator {
224 input: *pointer,
225 cmp: *cmp,
226 val: *value,
227 }),
228 *new_var,
229 ));
230 new_var.into()
231 }
232
233 pub fn __expand_and(
234 scope: &mut Scope,
235 pointer: <Self as CubeType>::ExpandType,
236 value: <Inner as CubeType>::ExpandType,
237 ) -> <Inner as CubeType>::ExpandType {
238 let ptr: ExpandElement = pointer.into();
239 let value: ExpandElement = value.into();
240 let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
241 scope.register(Instruction::new(
242 AtomicOp::And(BinaryOperator {
243 lhs: *ptr,
244 rhs: *value,
245 }),
246 *new_var,
247 ));
248 new_var.into()
249 }
250
251 pub fn __expand_or(
252 scope: &mut Scope,
253 pointer: <Self as CubeType>::ExpandType,
254 value: <Inner as CubeType>::ExpandType,
255 ) -> <Inner as CubeType>::ExpandType {
256 let ptr: ExpandElement = pointer.into();
257 let value: ExpandElement = value.into();
258 let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
259 scope.register(Instruction::new(
260 AtomicOp::Or(BinaryOperator {
261 lhs: *ptr,
262 rhs: *value,
263 }),
264 *new_var,
265 ));
266 new_var.into()
267 }
268
269 pub fn __expand_xor(
270 scope: &mut Scope,
271 pointer: <Self as CubeType>::ExpandType,
272 value: <Inner as CubeType>::ExpandType,
273 ) -> <Inner as CubeType>::ExpandType {
274 let ptr: ExpandElement = pointer.into();
275 let value: ExpandElement = value.into();
276 let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
277 scope.register(Instruction::new(
278 AtomicOp::Xor(BinaryOperator {
279 lhs: *ptr,
280 rhs: *value,
281 }),
282 *new_var,
283 ));
284 new_var.into()
285 }
286}
287
288impl<Inner: CubePrimitive> CubeType for Atomic<Inner> {
289 type ExpandType = ExpandElementTyped<Self>;
290}
291
292impl<Inner: CubePrimitive> CubePrimitive for Atomic<Inner> {
293 fn as_elem_native() -> Option<Elem> {
294 match Inner::as_elem_native() {
295 Some(Elem::Float(kind)) => Some(Elem::AtomicFloat(kind)),
296 Some(Elem::Int(kind)) => Some(Elem::AtomicInt(kind)),
297 Some(Elem::UInt(kind)) => Some(Elem::AtomicUInt(kind)),
298 None => None,
299 _ => unreachable!("Atomics can only be float/int/uint"),
300 }
301 }
302
303 fn as_elem(scope: &Scope) -> Elem {
304 match Inner::as_elem(scope) {
305 Elem::Float(kind) => Elem::AtomicFloat(kind),
306 Elem::Int(kind) => Elem::AtomicInt(kind),
307 Elem::UInt(kind) => Elem::AtomicUInt(kind),
308 _ => unreachable!("Atomics can only be float/int/uint"),
309 }
310 }
311
312 fn as_elem_native_unchecked() -> Elem {
313 match Inner::as_elem_native_unchecked() {
314 Elem::Float(kind) => Elem::AtomicFloat(kind),
315 Elem::Int(kind) => Elem::AtomicInt(kind),
316 Elem::UInt(kind) => Elem::AtomicUInt(kind),
317 _ => unreachable!("Atomics can only be float/int/uint"),
318 }
319 }
320
321 fn size() -> Option<usize> {
322 Inner::size()
323 }
324
325 fn from_expand_elem(elem: ExpandElement) -> Self::ExpandType {
326 ExpandElementTyped::new(elem)
327 }
328}
329
330impl<Inner: CubePrimitive> ExpandElementBaseInit for Atomic<Inner> {
331 fn init_elem(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
332 init_expand_element(scope, elem)
333 }
334}
335
336impl<Inner: CubePrimitive> LaunchArgExpand for Atomic<Inner> {
337 type CompilationArg = ();
338
339 fn expand(_: &Self::CompilationArg, builder: &mut KernelBuilder) -> ExpandElementTyped<Self> {
340 builder.scalar(Self::as_elem_native_unchecked()).into()
341 }
342}