use cubecl_ir::{AtomicOp, ConstantScalarValue, ExpandElement, StorageType};
use super::{ExpandElementIntoMut, ExpandElementTyped, Int, Numeric, into_mut_expand_element};
use crate::{
frontend::{CubePrimitive, CubeType},
ir::{BinaryOperator, CompareAndSwapOperator, Instruction, Scope, Type, UnaryOperator},
unexpanded,
};
/// Kernel-side wrapper marking a value of type `Inner` as atomic storage.
///
/// The wrapped value is public; the atomic operations themselves live in the
/// `impl` blocks below (numeric operations for every `Numeric` inner type, bitwise and
/// compare-and-swap operations only for `Int` inner types).
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct Atomic<Inner: CubePrimitive> {
    /// The wrapped inner value.
    pub val: Inner,
}
impl<Inner: Numeric> Atomic<Inner> {
#[allow(unused_variables)]
pub fn load(pointer: &Self) -> Inner {
unexpanded!()
}
#[allow(unused_variables)]
pub fn store(pointer: &Self, value: Inner) {
unexpanded!()
}
#[allow(unused_variables)]
pub fn swap(pointer: &Self, value: Inner) -> Inner {
unexpanded!()
}
#[allow(unused_variables)]
pub fn add(pointer: &Self, value: Inner) -> Inner {
unexpanded!()
}
#[allow(unused_variables)]
pub fn max(pointer: &Self, value: Inner) -> Inner {
unexpanded!()
}
#[allow(unused_variables)]
pub fn min(pointer: &Self, value: Inner) -> Inner {
unexpanded!()
}
#[allow(unused_variables)]
pub fn sub(pointer: &Self, value: Inner) -> Inner {
unexpanded!()
}
pub fn __expand_load(
scope: &mut Scope,
pointer: <Self as CubeType>::ExpandType,
) -> <Inner as CubeType>::ExpandType {
let pointer: ExpandElement = pointer.into();
let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
scope.register(Instruction::new(
AtomicOp::Load(UnaryOperator { input: *pointer }),
*new_var,
));
new_var.into()
}
pub fn __expand_store(
scope: &mut Scope,
pointer: <Self as CubeType>::ExpandType,
value: <Inner as CubeType>::ExpandType,
) {
let ptr: ExpandElement = pointer.into();
let value: ExpandElement = value.into();
scope.register(Instruction::new(
AtomicOp::Store(UnaryOperator { input: *value }),
*ptr,
));
}
pub fn __expand_swap(
scope: &mut Scope,
pointer: <Self as CubeType>::ExpandType,
value: <Inner as CubeType>::ExpandType,
) -> <Inner as CubeType>::ExpandType {
let ptr: ExpandElement = pointer.into();
let value: ExpandElement = value.into();
let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
scope.register(Instruction::new(
AtomicOp::Swap(BinaryOperator {
lhs: *ptr,
rhs: *value,
}),
*new_var,
));
new_var.into()
}
pub fn __expand_add(
scope: &mut Scope,
pointer: <Self as CubeType>::ExpandType,
value: <Inner as CubeType>::ExpandType,
) -> <Inner as CubeType>::ExpandType {
let ptr: ExpandElement = pointer.into();
let value: ExpandElement = value.into();
let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
scope.register(Instruction::new(
AtomicOp::Add(BinaryOperator {
lhs: *ptr,
rhs: *value,
}),
*new_var,
));
new_var.into()
}
pub fn __expand_sub(
scope: &mut Scope,
pointer: <Self as CubeType>::ExpandType,
value: <Inner as CubeType>::ExpandType,
) -> <Inner as CubeType>::ExpandType {
let ptr: ExpandElement = pointer.into();
let value: ExpandElement = value.into();
let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
scope.register(Instruction::new(
AtomicOp::Sub(BinaryOperator {
lhs: *ptr,
rhs: *value,
}),
*new_var,
));
new_var.into()
}
pub fn __expand_max(
scope: &mut Scope,
pointer: <Self as CubeType>::ExpandType,
value: <Inner as CubeType>::ExpandType,
) -> <Inner as CubeType>::ExpandType {
let ptr: ExpandElement = pointer.into();
let value: ExpandElement = value.into();
let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
scope.register(Instruction::new(
AtomicOp::Max(BinaryOperator {
lhs: *ptr,
rhs: *value,
}),
*new_var,
));
new_var.into()
}
pub fn __expand_min(
scope: &mut Scope,
pointer: <Self as CubeType>::ExpandType,
value: <Inner as CubeType>::ExpandType,
) -> <Inner as CubeType>::ExpandType {
let ptr: ExpandElement = pointer.into();
let value: ExpandElement = value.into();
let new_var = scope.create_local(Type::new(Inner::as_type(scope)));
scope.register(Instruction::new(
AtomicOp::Min(BinaryOperator {
lhs: *ptr,
rhs: *value,
}),
*new_var,
));
new_var.into()
}
}
/// Integer-only atomic operations: compare-and-swap and the bitwise `and`/`or`/`xor`.
///
/// The plain methods are signature-only placeholders (`unexpanded!()`) whose parameters
/// are never read, hence `#[allow(unused_variables)]` on each. The `__expand_*` functions
/// register the matching [`AtomicOp`] in the [`Scope`] and return a fresh local of type
/// `Inner` that receives the instruction's output.
impl<Inner: Int> Atomic<Inner> {
    /// Atomic compare-and-swap of the value behind `pointer` against `cmp`, with `value`
    /// as the replacement.
    /// NOTE(review): exact semantics and returned value come from the back-end lowering of
    /// `AtomicOp::CompareAndSwap` — confirm there.
    #[allow(unused_variables)]
    pub fn compare_and_swap(pointer: &Self, cmp: Inner, value: Inner) -> Inner {
        unexpanded!()
    }

    /// Atomic bitwise AND of the value behind `pointer` with `value`.
    #[allow(unused_variables)]
    pub fn and(pointer: &Self, value: Inner) -> Inner {
        unexpanded!()
    }

    /// Atomic bitwise OR of the value behind `pointer` with `value`.
    #[allow(unused_variables)]
    pub fn or(pointer: &Self, value: Inner) -> Inner {
        unexpanded!()
    }

    /// Atomic bitwise XOR of the value behind `pointer` with `value`.
    #[allow(unused_variables)]
    pub fn xor(pointer: &Self, value: Inner) -> Inner {
        unexpanded!()
    }

    /// Emits an `AtomicOp::CompareAndSwap` with `input = pointer`, `cmp = cmp` and
    /// `val = value`, writing into a fresh local that is returned.
    pub fn __expand_compare_and_swap(
        scope: &mut Scope,
        pointer: <Self as CubeType>::ExpandType,
        cmp: <Inner as CubeType>::ExpandType,
        value: <Inner as CubeType>::ExpandType,
    ) -> <Inner as CubeType>::ExpandType {
        let (target, expected, replacement): (ExpandElement, ExpandElement, ExpandElement) =
            (pointer.into(), cmp.into(), value.into());
        let result = scope.create_local(Type::new(Inner::as_type(scope)));
        let op = AtomicOp::CompareAndSwap(CompareAndSwapOperator {
            input: *target,
            cmp: *expected,
            val: *replacement,
        });
        scope.register(Instruction::new(op, *result));
        result.into()
    }

    /// Emits an `AtomicOp::And` with `lhs = pointer`, `rhs = value` into a fresh local.
    pub fn __expand_and(
        scope: &mut Scope,
        pointer: <Self as CubeType>::ExpandType,
        value: <Inner as CubeType>::ExpandType,
    ) -> <Inner as CubeType>::ExpandType {
        let (lhs, rhs): (ExpandElement, ExpandElement) = (pointer.into(), value.into());
        let result = scope.create_local(Type::new(Inner::as_type(scope)));
        let op = AtomicOp::And(BinaryOperator {
            lhs: *lhs,
            rhs: *rhs,
        });
        scope.register(Instruction::new(op, *result));
        result.into()
    }

    /// Emits an `AtomicOp::Or` with `lhs = pointer`, `rhs = value` into a fresh local.
    pub fn __expand_or(
        scope: &mut Scope,
        pointer: <Self as CubeType>::ExpandType,
        value: <Inner as CubeType>::ExpandType,
    ) -> <Inner as CubeType>::ExpandType {
        let (lhs, rhs): (ExpandElement, ExpandElement) = (pointer.into(), value.into());
        let result = scope.create_local(Type::new(Inner::as_type(scope)));
        let op = AtomicOp::Or(BinaryOperator {
            lhs: *lhs,
            rhs: *rhs,
        });
        scope.register(Instruction::new(op, *result));
        result.into()
    }

    /// Emits an `AtomicOp::Xor` with `lhs = pointer`, `rhs = value` into a fresh local.
    pub fn __expand_xor(
        scope: &mut Scope,
        pointer: <Self as CubeType>::ExpandType,
        value: <Inner as CubeType>::ExpandType,
    ) -> <Inner as CubeType>::ExpandType {
        let (lhs, rhs): (ExpandElement, ExpandElement) = (pointer.into(), value.into());
        let result = scope.create_local(Type::new(Inner::as_type(scope)));
        let op = AtomicOp::Xor(BinaryOperator {
            lhs: *lhs,
            rhs: *rhs,
        });
        scope.register(Instruction::new(op, *result));
        result.into()
    }
}
/// During kernel expansion an `Atomic<Inner>` is represented by a typed expand element
/// of the atomic type itself.
impl<Inner: CubePrimitive> CubeType for Atomic<Inner> {
    type ExpandType = ExpandElementTyped<Atomic<Inner>>;
}
/// Type description of `Atomic<Inner>`: every storage-type query wraps the element type
/// of `Inner` in `StorageType::Atomic`.
impl<Inner: CubePrimitive> CubePrimitive for Atomic<Inner> {
    /// Atomic storage over `Inner`'s native element type, or `None` when `Inner` reports
    /// no native type.
    fn as_type_native() -> Option<StorageType> {
        let inner = Inner::as_type_native()?;
        Some(StorageType::Atomic(inner.elem_type()))
    }

    /// Atomic storage over the element type `Inner` resolves to in `scope`.
    fn as_type(scope: &Scope) -> StorageType {
        let elem = Inner::as_type(scope).elem_type();
        StorageType::Atomic(elem)
    }

    /// Atomic storage over `Inner`'s unchecked native element type.
    fn as_type_native_unchecked() -> StorageType {
        let elem = Inner::as_type_native_unchecked().elem_type();
        StorageType::Atomic(elem)
    }

    /// Reports the same size as the wrapped type.
    fn size() -> Option<usize> {
        Inner::size()
    }

    /// Wraps a raw expand element as a typed `Atomic<Inner>` element.
    fn from_expand_elem(elem: ExpandElement) -> Self::ExpandType {
        ExpandElementTyped::<Self>::new(elem)
    }

    /// Atomics have no constant form.
    ///
    /// # Panics
    ///
    /// Always panics.
    fn from_const_value(_value: ConstantScalarValue) -> Self {
        panic!("Can't have constant atomic");
    }
}
/// Mutable conversion for atomic expand elements; there is no atomic-specific logic here.
impl<Inner: CubePrimitive> ExpandElementIntoMut for Atomic<Inner> {
    /// Forwards `elem` unchanged to the shared `into_mut_expand_element` helper.
    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
        into_mut_expand_element(scope, elem)
    }
}