use cubecl_ir::{AtomicOp, ConstantValue, ManagedVariable, StorageType};
use cubecl_macros::intrinsic;
use super::{NativeAssign, NativeExpand, Numeric};
use crate::{
self as cubecl,
frontend::{CubePrimitive, CubeType},
ir::{BinaryOperator, CompareAndSwapOperator, Instruction, Scope, Type, UnaryOperator},
prelude::*,
};
/// Wrapper marking an `Inner` primitive as atomic.
///
/// The IR type of `Atomic<Inner>` is the IR type of `Inner` with its storage
/// type replaced by `StorageType::Atomic` over the same element type (see the
/// `CubePrimitive` impl for this type). Values are manipulated through the
/// atomic methods defined on this type (`load`, `store`, `fetch_*`, ...).
#[derive(Clone, Copy, Hash, PartialEq, Eq)]
pub struct Atomic<Inner: CubePrimitive> {
/// The wrapped value.
pub val: Inner,
}
/// Shorthand for the expand-side representation of `Atomic<Inner>`; this is the
/// same type as `<Atomic<Inner> as CubeType>::ExpandType`.
// NOTE(review): nothing visible in this section refers to this alias by name;
// presumably it is picked up by code generated by `#[cube]` - confirm before
// removing or renaming it.
type AtomicExpand<Inner> = NativeExpand<Atomic<Inner>>;
#[cube]
impl<Inner: CubePrimitive<Scalar: Numeric>> Atomic<Inner> {
    /// Atomic load.
    ///
    /// Registers an `AtomicOp::Load` instruction whose input is this atomic and
    /// whose output is a freshly created local of type `Inner`; that local is
    /// returned.
    #[allow(unused_variables)]
    pub fn load(&self) -> Inner {
        intrinsic!(|scope| {
            let atomic: ManagedVariable = self.into();
            let out = scope.create_local(Inner::as_type(scope));
            let op = AtomicOp::Load(UnaryOperator { input: *atomic });
            scope.register(Instruction::new(op, *out));
            out.into()
        })
    }

    /// Atomic store.
    ///
    /// Registers an `AtomicOp::Store` instruction whose input is `value` and
    /// whose output (the destination) is this atomic. Nothing is returned.
    #[allow(unused_variables)]
    pub fn store(&self, value: Inner) {
        intrinsic!(|scope| {
            let target: ManagedVariable = self.into();
            let stored: ManagedVariable = value.into();
            let op = AtomicOp::Store(UnaryOperator { input: *stored });
            scope.register(Instruction::new(op, *target));
        })
    }

    /// Atomic swap (exchange).
    ///
    /// Registers an `AtomicOp::Swap` instruction with this atomic as `lhs` and
    /// `value` as `rhs`, and returns the fresh `Inner`-typed local receiving the
    /// instruction's output.
    // NOTE(review): by the usual exchange convention the result is presumably
    // the value held before the swap - confirm against the backends.
    #[allow(unused_variables)]
    pub fn swap(&self, value: Inner) -> Inner {
        intrinsic!(|scope| {
            let atomic: ManagedVariable = self.into();
            let operand: ManagedVariable = value.into();
            let out = scope.create_local(Inner::as_type(scope));
            let op = AtomicOp::Swap(BinaryOperator {
                lhs: *atomic,
                rhs: *operand,
            });
            scope.register(Instruction::new(op, *out));
            out.into()
        })
    }

    /// Atomic add.
    ///
    /// Registers an `AtomicOp::Add` instruction with this atomic as `lhs` and
    /// `value` as `rhs`, and returns the fresh `Inner`-typed local receiving the
    /// instruction's output.
    // NOTE(review): "fetch" suggests the result is the value held before the
    // addition - confirm against the backends.
    #[allow(unused_variables)]
    pub fn fetch_add(&self, value: Inner) -> Inner {
        intrinsic!(|scope| {
            let atomic: ManagedVariable = self.into();
            let operand: ManagedVariable = value.into();
            let out = scope.create_local(Inner::as_type(scope));
            let op = AtomicOp::Add(BinaryOperator {
                lhs: *atomic,
                rhs: *operand,
            });
            scope.register(Instruction::new(op, *out));
            out.into()
        })
    }

    /// Atomic subtract.
    ///
    /// Registers an `AtomicOp::Sub` instruction with this atomic as `lhs` and
    /// `value` as `rhs`, and returns the fresh `Inner`-typed local receiving the
    /// instruction's output.
    // NOTE(review): "fetch" suggests the result is the value held before the
    // subtraction - confirm against the backends.
    #[allow(unused_variables)]
    pub fn fetch_sub(&self, value: Inner) -> Inner {
        intrinsic!(|scope| {
            let atomic: ManagedVariable = self.into();
            let operand: ManagedVariable = value.into();
            let out = scope.create_local(Inner::as_type(scope));
            let op = AtomicOp::Sub(BinaryOperator {
                lhs: *atomic,
                rhs: *operand,
            });
            scope.register(Instruction::new(op, *out));
            out.into()
        })
    }

    /// Atomic maximum.
    ///
    /// Registers an `AtomicOp::Max` instruction with this atomic as `lhs` and
    /// `value` as `rhs`, and returns the fresh `Inner`-typed local receiving the
    /// instruction's output.
    // NOTE(review): "fetch" suggests the result is the value held before the
    // update - confirm against the backends.
    #[allow(unused_variables)]
    pub fn fetch_max(&self, value: Inner) -> Inner {
        intrinsic!(|scope| {
            let atomic: ManagedVariable = self.into();
            let operand: ManagedVariable = value.into();
            let out = scope.create_local(Inner::as_type(scope));
            let op = AtomicOp::Max(BinaryOperator {
                lhs: *atomic,
                rhs: *operand,
            });
            scope.register(Instruction::new(op, *out));
            out.into()
        })
    }

    /// Atomic minimum.
    ///
    /// Registers an `AtomicOp::Min` instruction with this atomic as `lhs` and
    /// `value` as `rhs`, and returns the fresh `Inner`-typed local receiving the
    /// instruction's output.
    // NOTE(review): "fetch" suggests the result is the value held before the
    // update - confirm against the backends.
    #[allow(unused_variables)]
    pub fn fetch_min(&self, value: Inner) -> Inner {
        intrinsic!(|scope| {
            let atomic: ManagedVariable = self.into();
            let operand: ManagedVariable = value.into();
            let out = scope.create_local(Inner::as_type(scope));
            let op = AtomicOp::Min(BinaryOperator {
                lhs: *atomic,
                rhs: *operand,
            });
            scope.register(Instruction::new(op, *out));
            out.into()
        })
    }
}
#[cube]
impl<Inner: CubePrimitive<Scalar: Int>> Atomic<Inner> {
    /// Atomic compare-and-swap.
    ///
    /// Registers an `AtomicOp::CompareAndSwap` instruction with this atomic as
    /// `input`, `cmp` as the comparison operand and `value` as `val`, and
    /// returns the fresh `Inner`-typed local receiving the instruction's output.
    // NOTE(review): presumably `value` is stored only when the current content
    // equals `cmp`, and the result is the content observed before the
    // operation - confirm against the backends, including whether the "weak"
    // form may fail spuriously.
    #[allow(unused_variables)]
    pub fn compare_exchange_weak(&self, cmp: Inner, value: Inner) -> Inner {
        intrinsic!(|scope| {
            let atomic: ManagedVariable = self.into();
            let expected: ManagedVariable = cmp.into();
            let replacement: ManagedVariable = value.into();
            let out = scope.create_local(Inner::as_type(scope));
            let op = AtomicOp::CompareAndSwap(CompareAndSwapOperator {
                input: *atomic,
                cmp: *expected,
                val: *replacement,
            });
            scope.register(Instruction::new(op, *out));
            out.into()
        })
    }

    /// Atomic bitwise AND.
    ///
    /// Registers an `AtomicOp::And` instruction with this atomic as `lhs` and
    /// `value` as `rhs`, and returns the fresh `Inner`-typed local receiving the
    /// instruction's output.
    // NOTE(review): "fetch" suggests the result is the value held before the
    // update - confirm against the backends.
    #[allow(unused_variables)]
    pub fn fetch_and(&self, value: Inner) -> Inner {
        intrinsic!(|scope| {
            let atomic: ManagedVariable = self.into();
            let operand: ManagedVariable = value.into();
            let out = scope.create_local(Inner::as_type(scope));
            let op = AtomicOp::And(BinaryOperator {
                lhs: *atomic,
                rhs: *operand,
            });
            scope.register(Instruction::new(op, *out));
            out.into()
        })
    }

    /// Atomic bitwise OR.
    ///
    /// Registers an `AtomicOp::Or` instruction with this atomic as `lhs` and
    /// `value` as `rhs`, and returns the fresh `Inner`-typed local receiving the
    /// instruction's output.
    // NOTE(review): "fetch" suggests the result is the value held before the
    // update - confirm against the backends.
    #[allow(unused_variables)]
    pub fn fetch_or(&self, value: Inner) -> Inner {
        intrinsic!(|scope| {
            let atomic: ManagedVariable = self.into();
            let operand: ManagedVariable = value.into();
            let out = scope.create_local(Inner::as_type(scope));
            let op = AtomicOp::Or(BinaryOperator {
                lhs: *atomic,
                rhs: *operand,
            });
            scope.register(Instruction::new(op, *out));
            out.into()
        })
    }

    /// Atomic bitwise XOR.
    ///
    /// Registers an `AtomicOp::Xor` instruction with this atomic as `lhs` and
    /// `value` as `rhs`, and returns the fresh `Inner`-typed local receiving the
    /// instruction's output.
    // NOTE(review): "fetch" suggests the result is the value held before the
    // update - confirm against the backends.
    #[allow(unused_variables)]
    pub fn fetch_xor(&self, value: Inner) -> Inner {
        intrinsic!(|scope| {
            let atomic: ManagedVariable = self.into();
            let operand: ManagedVariable = value.into();
            let out = scope.create_local(Inner::as_type(scope));
            let op = AtomicOp::Xor(BinaryOperator {
                lhs: *atomic,
                rhs: *operand,
            });
            scope.register(Instruction::new(op, *out));
            out.into()
        })
    }
}
/// `Atomic<Inner>` expands to a `NativeExpand` wrapper over itself.
impl<Inner: CubePrimitive> CubeType for Atomic<Inner> {
    type ExpandType = NativeExpand<Atomic<Inner>>;
}
/// Primitive-type description of `Atomic<Inner>`.
///
/// Every type query is answered by asking `Inner` and then replacing the
/// storage type of the answer with `StorageType::Atomic` over the same element
/// type.
impl<Inner: CubePrimitive> CubePrimitive for Atomic<Inner> {
    type Scalar = Inner::Scalar;
    type Size = Const<1>;
    type WithScalar<S: Scalar> = Atomic<S>;

    /// Atomic counterpart of `Inner::as_type_native()`, or `None` when `Inner`
    /// reports no native type.
    fn as_type_native() -> Option<Type> {
        let base = Inner::as_type_native()?;
        let elem = base.elem_type();
        Some(base.with_storage_type(StorageType::Atomic(elem)))
    }

    /// Atomic counterpart of `Inner::as_type(scope)`.
    fn as_type(scope: &Scope) -> Type {
        let base = Inner::as_type(scope);
        let elem = base.elem_type();
        base.with_storage_type(StorageType::Atomic(elem))
    }

    /// Atomic counterpart of `Inner::as_type_native_unchecked()`.
    fn as_type_native_unchecked() -> Type {
        let base = Inner::as_type_native_unchecked();
        let elem = base.elem_type();
        base.with_storage_type(StorageType::Atomic(elem))
    }

    /// Same size as reported by `Inner::size()`.
    fn size() -> Option<usize> {
        Inner::size()
    }

    /// Wraps an existing variable as the expand type of this atomic.
    fn from_expand_elem(elem: ManagedVariable) -> Self::ExpandType {
        NativeExpand::new(elem)
    }

    /// Atomics cannot be built from a constant.
    ///
    /// # Panics
    ///
    /// Always panics.
    fn from_const_value(_value: ConstantValue) -> Self {
        panic!("Can't have constant atomic");
    }
}
/// Opts `Atomic<Inner>` into `NativeAssign`; no trait items are overridden, so
/// only the trait's provided defaults apply.
impl<Inner: CubePrimitive> NativeAssign for Atomic<Inner> {}