use std::marker::PhantomData;
use crate::{
self as cubecl,
prelude::{Lined, LinedExpand},
unexpanded,
};
use cubecl_ir::{Instruction, Operation, VariableKind};
use cubecl_macros::{cube, intrinsic};
use crate::{
frontend::{CubePrimitive, CubeType, ExpandElementTyped, IntoMut, indexation::Index},
ir::{Scope, Type},
prelude::{
Line, List, ListExpand, ListMut, ListMutExpand, index, index_assign, index_unchecked,
},
};
// NOTE(review): this alias is not referenced by name in the visible code; presumably the
// `#[cube]` macro expansion below relies on it — confirm before renaming or removing.
/// Expand-time representation of a [`SharedMemory`]: an [`ExpandElementTyped`] tagged with
/// the shared memory type.
type SharedMemoryExpand<T> = ExpandElementTyped<SharedMemory<T>>;
/// Marker type for a shared memory buffer whose elements are of type `T`.
///
/// The struct carries no data besides a [`PhantomData`]; its expand-time representation is an
/// `ExpandElementTyped<SharedMemory<T>>`.
#[derive(Clone, Copy)]
pub struct SharedMemory<T: CubeType> {
// Zero-sized marker tying the buffer to its element type.
_val: PhantomData<T>,
}
/// `IntoMut` for shared memory handles is the identity conversion.
impl<T: CubePrimitive> IntoMut for ExpandElementTyped<SharedMemory<T>> {
/// Returns `self` unchanged; the scope is not used.
fn into_mut(self, _scope: &mut Scope) -> Self {
self
}
}
/// Declares the expand-time counterpart of [`SharedMemory`].
impl<T: CubePrimitive> CubeType for SharedMemory<T> {
/// A typed expand element tagged with this shared memory type.
type ExpandType = ExpandElementTyped<SharedMemory<T>>;
}
impl<T: CubePrimitive + Clone> SharedMemory<T> {
    /// Declares a shared memory buffer of `T`.
    ///
    /// This body only builds the marker value and ignores `_size`; see
    /// [`Self::__expand_new`] for the expansion-time counterpart.
    pub fn new<S: Index>(_size: S) -> Self {
        SharedMemory { _val: PhantomData }
    }

    /// Declares a shared memory buffer of `Line<T>`.
    ///
    /// This body only builds the marker value and ignores both arguments; see
    /// [`Self::__expand_new_lined`] for the expansion-time counterpart.
    pub fn new_lined<S: Index>(_size: S, _vectorization_factor: u32) -> SharedMemory<Line<T>> {
        SharedMemory { _val: PhantomData }
    }

    /// Length of the shared memory.
    ///
    /// Placeholder whose body is `unexpanded!()`; see [`Self::__expand_len`] for the
    /// expansion-time counterpart.
    // No `is_empty` is defined on this type, hence the lint exemption.
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> u32 {
        unexpanded!()
    }

    /// Buffer length of the shared memory.
    ///
    /// Placeholder whose body is `unexpanded!()`; see [`Self::__expand_buffer_len`] for the
    /// expansion-time counterpart.
    pub fn buffer_len(&self) -> u32 {
        unexpanded!()
    }

    /// Creates a shared memory variable in `scope` whose element type is `T` lined by
    /// `line_size`, with no explicit alignment.
    ///
    /// # Panics
    ///
    /// Panics if `size` is not a constant.
    pub fn __expand_new_lined(
        scope: &mut Scope,
        size: ExpandElementTyped<u32>,
        line_size: u32,
    ) -> <SharedMemory<Line<T>> as CubeType>::ExpandType {
        let size = constant_shared_size(size);
        let var = scope.create_shared(Type::new(T::as_type(scope)).line(line_size), size, None);
        ExpandElementTyped::new(var)
    }

    /// Declares a shared memory buffer typed as `SharedMemory<T>`.
    ///
    /// This body only builds the marker value and ignores both arguments; see
    /// [`Self::__expand_vectorized`] for the expansion-time counterpart.
    pub fn vectorized<S: Index>(_size: S, _vectorization_factor: u32) -> Self {
        SharedMemory { _val: PhantomData }
    }

    /// Creates a shared memory variable in `scope` whose element type is `T` lined by
    /// `line_size`, with no explicit alignment.
    ///
    /// Unlike [`Self::__expand_new_lined`], the returned handle is typed as
    /// `SharedMemory<T>` rather than `SharedMemory<Line<T>>`, although the variable itself is
    /// created with the same lined type.
    ///
    /// # Panics
    ///
    /// Panics if `size` is not a constant.
    pub fn __expand_vectorized(
        scope: &mut Scope,
        size: ExpandElementTyped<u32>,
        line_size: u32,
    ) -> <Self as CubeType>::ExpandType {
        let size = constant_shared_size(size);
        let var = scope.create_shared(Type::new(T::as_type(scope)).line(line_size), size, None);
        ExpandElementTyped::new(var)
    }

    /// Creates a shared memory variable in `scope` with element type `T` and no explicit
    /// alignment.
    ///
    /// # Panics
    ///
    /// Panics if `size` is not a constant.
    pub fn __expand_new(
        scope: &mut Scope,
        size: ExpandElementTyped<u32>,
    ) -> <Self as CubeType>::ExpandType {
        let size = constant_shared_size(size);
        let var = scope.create_shared(Type::new(T::as_type(scope)), size, None);
        ExpandElementTyped::new(var)
    }

    /// Expansion of [`Self::len`]; forwards to the `__expand_len_method` of the handle.
    pub fn __expand_len(
        scope: &mut Scope,
        this: ExpandElementTyped<Self>,
    ) -> ExpandElementTyped<u32> {
        this.__expand_len_method(scope)
    }

    /// Expansion of [`Self::buffer_len`]; forwards to the `__expand_buffer_len_method` of the
    /// handle.
    pub fn __expand_buffer_len(
        scope: &mut Scope,
        this: ExpandElementTyped<Self>,
    ) -> ExpandElementTyped<u32> {
        this.__expand_buffer_len_method(scope)
    }
}

/// Extracts the constant `u32` size requested for a shared memory allocation.
///
/// Shared by every `__expand_*` constructor that receives its size as an expand element, so
/// the check and its panic message live in a single place.
///
/// # Panics
///
/// Panics if `size` is not a constant.
fn constant_shared_size(size: ExpandElementTyped<u32>) -> u32 {
    size.constant()
        .expect("Shared memory need constant initialization value")
        .as_u32()
}
/// Expand-time length queries on a shared memory handle.
impl<T: CubePrimitive> ExpandElementTyped<SharedMemory<T>> {
    /// Returns the length stored in the handle's variable kind.
    ///
    /// The scope is accepted for signature uniformity only; it is not used.
    pub fn __expand_len_method(self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
        let handle = self;
        len_static(&handle)
    }

    /// Returns the same value as [`Self::__expand_len_method`].
    pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> ExpandElementTyped<u32> {
        Self::__expand_len_method(self, scope)
    }
}
#[cube]
impl<T: CubePrimitive + Clone> SharedMemory<T> {
/// Creates a shared memory of length `size` whose element type is `T` lined by `line_size`,
/// passing `Some(alignment)` as the alignment argument of `create_shared`.
///
/// All three arguments are comptime values.
// NOTE(review): the unit of `alignment` (presumably bytes) is not visible here — confirm.
#[allow(unused_variables)]
pub fn new_aligned(
#[comptime] size: u32,
#[comptime] line_size: u32,
#[comptime] alignment: u32,
) -> SharedMemory<Line<T>> {
intrinsic!(|scope| {
let var = scope.create_shared(
Type::new(T::as_type(scope)).line(line_size),
size,
Some(alignment),
);
ExpandElementTyped::new(var)
})
}
/// Registers an `Operation::Free` instruction (with no output) for this shared memory.
///
/// # Safety
///
/// NOTE(review): presumably the shared memory must not be accessed once it has been freed;
/// the exact contract depends on how `Operation::Free` is handled downstream — confirm.
pub unsafe fn free(self) {
intrinsic!(|scope| { scope.register(Instruction::no_out(Operation::Free(*self.expand))) })
}
}
/// Reads the length recorded in the `VariableKind::SharedMemory` of `shared` and wraps it as
/// a `u32` expand element.
///
/// # Panics
///
/// Panics if the variable behind `shared` is not of the `SharedMemory` kind.
fn len_static<T: CubePrimitive>(
    shared: &ExpandElementTyped<SharedMemory<T>>,
) -> ExpandElementTyped<u32> {
    match shared.expand.kind {
        VariableKind::SharedMemory { length, .. } => length.into(),
        _ => unreachable!("Kind of shared memory is always shared memory"),
    }
}
/// Unchecked indexing operations on [`SharedMemory`].
mod indexation {
use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
use crate::ir::Instruction;
use super::*;
// NOTE(review): not referenced by name in this module's visible code; presumably required
// by the `#[cube]` expansion below — confirm before removing.
type SharedMemoryExpand<E> = ExpandElementTyped<SharedMemory<E>>;
#[cube]
impl<E: CubePrimitive> SharedMemory<E> {
/// Reads the element at index `i` by registering an `Operator::UncheckedIndex`
/// instruction whose output is a fresh local created from the shared memory's type.
///
/// # Safety
///
/// NOTE(review): as the operator name suggests, presumably no bounds check is emitted,
/// so the caller must guarantee that `i` is in bounds — confirm.
#[allow(unused_variables)]
pub unsafe fn index_unchecked(&self, i: u32) -> &E {
intrinsic!(|scope| {
let out = scope.create_local(self.expand.ty);
scope.register(Instruction::new(
Operator::UncheckedIndex(IndexOperator {
list: *self.expand,
index: i.expand.consume(),
line_size: 0,
unroll_factor: 1,
}),
*out,
));
out.into()
})
}
/// Writes `value` at index `i` by registering an `Operator::UncheckedIndexAssign`
/// instruction whose output is the shared memory itself.
///
/// # Safety
///
/// NOTE(review): as the operator name suggests, presumably no bounds check is emitted,
/// so the caller must guarantee that `i` is in bounds — confirm.
#[allow(unused_variables)]
pub unsafe fn index_assign_unchecked(&mut self, i: u32, value: E) {
intrinsic!(|scope| {
scope.register(Instruction::new(
Operator::UncheckedIndexAssign(IndexAssignOperator {
index: i.expand.consume(),
value: value.expand.consume(),
line_size: 0,
unroll_factor: 1,
}),
*self.expand,
));
})
}
}
}
/// Read access to [`SharedMemory`] through the `List` abstraction.
impl<T: CubePrimitive> List<T> for SharedMemory<T> {
/// Reads the element at `idx` by forwarding to `index::expand`.
fn __expand_read(
scope: &mut Scope,
this: ExpandElementTyped<SharedMemory<T>>,
idx: ExpandElementTyped<u32>,
) -> ExpandElementTyped<T> {
index::expand(scope, this, idx)
}
}
/// Expand-time read access to a shared memory handle.
impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<SharedMemory<T>> {
    /// Reads the element at `idx` by forwarding a clone of the handle to `index::expand`.
    fn __expand_read_method(
        &self,
        scope: &mut Scope,
        idx: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<T> {
        let list = self.clone();
        index::expand(scope, list, idx)
    }

    /// Reads the element at `idx` by forwarding a clone of the handle to
    /// `index_unchecked::expand`.
    fn __expand_read_unchecked_method(
        &self,
        scope: &mut Scope,
        idx: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<T> {
        let list = self.clone();
        index_unchecked::expand(scope, list, idx)
    }

    /// Returns the length of the shared memory.
    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
        // The path below resolves to the inherent, by-value method of the same name (inherent
        // associated functions take precedence over trait ones), so this does not recurse.
        let list = self.clone();
        Self::__expand_len_method(list, scope)
    }
}
// Empty impl: no items of `Lined` are overridden for shared memory.
impl<T: CubePrimitive> Lined for SharedMemory<T> {}
/// Line-size query on a shared memory handle.
impl<T: CubePrimitive> LinedExpand for ExpandElementTyped<SharedMemory<T>> {
/// Returns the line size of the type of the underlying variable.
fn line_size(&self) -> u32 {
self.expand.ty.line_size()
}
}
/// Write access to [`SharedMemory`] through the `ListMut` abstraction.
impl<T: CubePrimitive> ListMut<T> for SharedMemory<T> {
/// Writes `value` at `idx` by forwarding to `index_assign::expand`.
fn __expand_write(
scope: &mut Scope,
this: ExpandElementTyped<SharedMemory<T>>,
idx: ExpandElementTyped<u32>,
value: ExpandElementTyped<T>,
) {
index_assign::expand(scope, this, idx, value);
}
}
/// Expand-time write access to a shared memory handle.
impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<SharedMemory<T>> {
    /// Writes `value` at `idx` by forwarding a clone of the handle to `index_assign::expand`.
    fn __expand_write_method(
        &self,
        scope: &mut Scope,
        idx: ExpandElementTyped<u32>,
        value: ExpandElementTyped<T>,
    ) {
        let target = self.clone();
        index_assign::expand(scope, target, idx, value);
    }
}