cubecl_core/frontend/element/
atomic.rs1use cubecl_ir::{AtomicOp, ExpandElement};
2
3use super::{
4 ExpandElementIntoMut, ExpandElementTyped, Int, LaunchArgExpand, Numeric,
5 into_mut_expand_element,
6};
7use crate::{
8 frontend::{CubePrimitive, CubeType},
9 ir::{BinaryOperator, CompareAndSwapOperator, Elem, Instruction, Item, Scope, UnaryOperator},
10 prelude::KernelBuilder,
11 unexpanded,
12};
13
14#[derive(Clone, Copy, Hash, PartialEq, Eq)]
18pub struct Atomic<Inner: CubePrimitive> {
19 pub val: Inner,
20}
21
22impl<Inner: Numeric> Atomic<Inner> {
23 #[allow(unused_variables)]
25 pub fn load(pointer: &Self) -> Inner {
26 unexpanded!()
27 }
28
29 #[allow(unused_variables)]
31 pub fn store(pointer: &Self, value: Inner) {
32 unexpanded!()
33 }
34
35 #[allow(unused_variables)]
37 pub fn swap(pointer: &Self, value: Inner) -> Inner {
38 unexpanded!()
39 }
40
41 #[allow(unused_variables)]
43 pub fn add(pointer: &Self, value: Inner) -> Inner {
44 unexpanded!()
45 }
46
47 #[allow(unused_variables)]
50 pub fn max(pointer: &Self, value: Inner) -> Inner {
51 unexpanded!()
52 }
53
54 #[allow(unused_variables)]
57 pub fn min(pointer: &Self, value: Inner) -> Inner {
58 unexpanded!()
59 }
60
61 #[allow(unused_variables)]
63 pub fn sub(pointer: &Self, value: Inner) -> Inner {
64 unexpanded!()
65 }
66
67 pub fn __expand_load(
68 scope: &mut Scope,
69 pointer: <Self as CubeType>::ExpandType,
70 ) -> <Inner as CubeType>::ExpandType {
71 let pointer: ExpandElement = pointer.into();
72 let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
73 scope.register(Instruction::new(
74 AtomicOp::Load(UnaryOperator { input: *pointer }),
75 *new_var,
76 ));
77 new_var.into()
78 }
79
80 pub fn __expand_store(
81 scope: &mut Scope,
82 pointer: <Self as CubeType>::ExpandType,
83 value: <Inner as CubeType>::ExpandType,
84 ) {
85 let ptr: ExpandElement = pointer.into();
86 let value: ExpandElement = value.into();
87 scope.register(Instruction::new(
88 AtomicOp::Store(UnaryOperator { input: *value }),
89 *ptr,
90 ));
91 }
92
93 pub fn __expand_swap(
94 scope: &mut Scope,
95 pointer: <Self as CubeType>::ExpandType,
96 value: <Inner as CubeType>::ExpandType,
97 ) -> <Inner as CubeType>::ExpandType {
98 let ptr: ExpandElement = pointer.into();
99 let value: ExpandElement = value.into();
100 let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
101 scope.register(Instruction::new(
102 AtomicOp::Swap(BinaryOperator {
103 lhs: *ptr,
104 rhs: *value,
105 }),
106 *new_var,
107 ));
108 new_var.into()
109 }
110
111 pub fn __expand_add(
112 scope: &mut Scope,
113 pointer: <Self as CubeType>::ExpandType,
114 value: <Inner as CubeType>::ExpandType,
115 ) -> <Inner as CubeType>::ExpandType {
116 let ptr: ExpandElement = pointer.into();
117 let value: ExpandElement = value.into();
118 let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
119 scope.register(Instruction::new(
120 AtomicOp::Add(BinaryOperator {
121 lhs: *ptr,
122 rhs: *value,
123 }),
124 *new_var,
125 ));
126 new_var.into()
127 }
128
129 pub fn __expand_sub(
130 scope: &mut Scope,
131 pointer: <Self as CubeType>::ExpandType,
132 value: <Inner as CubeType>::ExpandType,
133 ) -> <Inner as CubeType>::ExpandType {
134 let ptr: ExpandElement = pointer.into();
135 let value: ExpandElement = value.into();
136 let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
137 scope.register(Instruction::new(
138 AtomicOp::Sub(BinaryOperator {
139 lhs: *ptr,
140 rhs: *value,
141 }),
142 *new_var,
143 ));
144 new_var.into()
145 }
146
147 pub fn __expand_max(
148 scope: &mut Scope,
149 pointer: <Self as CubeType>::ExpandType,
150 value: <Inner as CubeType>::ExpandType,
151 ) -> <Inner as CubeType>::ExpandType {
152 let ptr: ExpandElement = pointer.into();
153 let value: ExpandElement = value.into();
154 let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
155 scope.register(Instruction::new(
156 AtomicOp::Max(BinaryOperator {
157 lhs: *ptr,
158 rhs: *value,
159 }),
160 *new_var,
161 ));
162 new_var.into()
163 }
164
165 pub fn __expand_min(
166 scope: &mut Scope,
167 pointer: <Self as CubeType>::ExpandType,
168 value: <Inner as CubeType>::ExpandType,
169 ) -> <Inner as CubeType>::ExpandType {
170 let ptr: ExpandElement = pointer.into();
171 let value: ExpandElement = value.into();
172 let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
173 scope.register(Instruction::new(
174 AtomicOp::Min(BinaryOperator {
175 lhs: *ptr,
176 rhs: *value,
177 }),
178 *new_var,
179 ));
180 new_var.into()
181 }
182}
183
184impl<Inner: Int> Atomic<Inner> {
185 #[allow(unused_variables)]
191 pub fn compare_and_swap(pointer: &Self, cmp: Inner, value: Inner) -> Inner {
192 unexpanded!()
193 }
194
195 #[allow(unused_variables)]
197 pub fn and(pointer: &Self, value: Inner) -> Inner {
198 unexpanded!()
199 }
200
201 #[allow(unused_variables)]
203 pub fn or(pointer: &Self, value: Inner) -> Inner {
204 unexpanded!()
205 }
206
207 #[allow(unused_variables)]
209 pub fn xor(pointer: &Self, value: Inner) -> Inner {
210 unexpanded!()
211 }
212
213 pub fn __expand_compare_and_swap(
214 scope: &mut Scope,
215 pointer: <Self as CubeType>::ExpandType,
216 cmp: <Inner as CubeType>::ExpandType,
217 value: <Inner as CubeType>::ExpandType,
218 ) -> <Inner as CubeType>::ExpandType {
219 let pointer: ExpandElement = pointer.into();
220 let cmp: ExpandElement = cmp.into();
221 let value: ExpandElement = value.into();
222 let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
223 scope.register(Instruction::new(
224 AtomicOp::CompareAndSwap(CompareAndSwapOperator {
225 input: *pointer,
226 cmp: *cmp,
227 val: *value,
228 }),
229 *new_var,
230 ));
231 new_var.into()
232 }
233
234 pub fn __expand_and(
235 scope: &mut Scope,
236 pointer: <Self as CubeType>::ExpandType,
237 value: <Inner as CubeType>::ExpandType,
238 ) -> <Inner as CubeType>::ExpandType {
239 let ptr: ExpandElement = pointer.into();
240 let value: ExpandElement = value.into();
241 let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
242 scope.register(Instruction::new(
243 AtomicOp::And(BinaryOperator {
244 lhs: *ptr,
245 rhs: *value,
246 }),
247 *new_var,
248 ));
249 new_var.into()
250 }
251
252 pub fn __expand_or(
253 scope: &mut Scope,
254 pointer: <Self as CubeType>::ExpandType,
255 value: <Inner as CubeType>::ExpandType,
256 ) -> <Inner as CubeType>::ExpandType {
257 let ptr: ExpandElement = pointer.into();
258 let value: ExpandElement = value.into();
259 let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
260 scope.register(Instruction::new(
261 AtomicOp::Or(BinaryOperator {
262 lhs: *ptr,
263 rhs: *value,
264 }),
265 *new_var,
266 ));
267 new_var.into()
268 }
269
270 pub fn __expand_xor(
271 scope: &mut Scope,
272 pointer: <Self as CubeType>::ExpandType,
273 value: <Inner as CubeType>::ExpandType,
274 ) -> <Inner as CubeType>::ExpandType {
275 let ptr: ExpandElement = pointer.into();
276 let value: ExpandElement = value.into();
277 let new_var = scope.create_local(Item::new(Inner::as_elem(scope)));
278 scope.register(Instruction::new(
279 AtomicOp::Xor(BinaryOperator {
280 lhs: *ptr,
281 rhs: *value,
282 }),
283 *new_var,
284 ));
285 new_var.into()
286 }
287}
288
289impl<Inner: CubePrimitive> CubeType for Atomic<Inner> {
290 type ExpandType = ExpandElementTyped<Self>;
291}
292
293impl<Inner: CubePrimitive> CubePrimitive for Atomic<Inner> {
294 fn as_elem_native() -> Option<Elem> {
295 match Inner::as_elem_native() {
296 Some(Elem::Float(kind)) => Some(Elem::AtomicFloat(kind)),
297 Some(Elem::Int(kind)) => Some(Elem::AtomicInt(kind)),
298 Some(Elem::UInt(kind)) => Some(Elem::AtomicUInt(kind)),
299 None => None,
300 _ => unreachable!("Atomics can only be float/int/uint"),
301 }
302 }
303
304 fn as_elem(scope: &Scope) -> Elem {
305 match Inner::as_elem(scope) {
306 Elem::Float(kind) => Elem::AtomicFloat(kind),
307 Elem::Int(kind) => Elem::AtomicInt(kind),
308 Elem::UInt(kind) => Elem::AtomicUInt(kind),
309 _ => unreachable!("Atomics can only be float/int/uint"),
310 }
311 }
312
313 fn as_elem_native_unchecked() -> Elem {
314 match Inner::as_elem_native_unchecked() {
315 Elem::Float(kind) => Elem::AtomicFloat(kind),
316 Elem::Int(kind) => Elem::AtomicInt(kind),
317 Elem::UInt(kind) => Elem::AtomicUInt(kind),
318 _ => unreachable!("Atomics can only be float/int/uint"),
319 }
320 }
321
322 fn size() -> Option<usize> {
323 Inner::size()
324 }
325
326 fn from_expand_elem(elem: ExpandElement) -> Self::ExpandType {
327 ExpandElementTyped::new(elem)
328 }
329}
330
331impl<Inner: CubePrimitive> ExpandElementIntoMut for Atomic<Inner> {
332 fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
333 into_mut_expand_element(scope, elem)
334 }
335}
336
337impl<Inner: CubePrimitive> LaunchArgExpand for Atomic<Inner> {
338 type CompilationArg = ();
339
340 fn expand(_: &Self::CompilationArg, builder: &mut KernelBuilder) -> ExpandElementTyped<Self> {
341 builder.scalar(Self::as_elem_native_unchecked()).into()
342 }
343}