use std::collections::HashMap;
use cubecl_core::{
Info, Metadata,
ir::{Builtin, ElemType, StorageType, UIntKind},
prelude::{KernelDefinition, ScalarKernelArg},
};
use tracel_llvm::mlir_rs::{
dialect::{arith, memref},
ir::{
Block, BlockRef, Location, Region,
r#type::{FunctionType, IntegerType, MemRefType},
},
};
use crate::compiler::{
builtin::BuiltinArray,
passes::shared_memories::{SharedMemories, SharedMemory},
};
use super::prelude::*;
const NB_BUILTIN: usize = 31;
pub(super) struct ArgsManagerBuilder<'a, 'b> {
scalars: Vec<ScalarKernelArg>,
buffers_len: usize,
function_types: Vec<Type<'a>>,
info: Info,
ext_meta_positions: Vec<u32>,
block_inputs: Vec<(Type<'a>, Location<'a>)>,
shared_memories: &'b SharedMemories,
addr_type: Type<'a>,
addr_size: usize,
}
impl<'a, 'b> ArgsManagerBuilder<'a, 'b> {
pub fn new(
kernel: &KernelDefinition,
context: &'a Context,
location: Location<'a>,
shared_memories: &'b SharedMemories,
addr_type: StorageType,
) -> Self {
let total_arg_len = kernel.buffers.len()
+ kernel.scalars.len()
+ NB_PASSED_BUILTIN
+ shared_memories.0.len();
let mut num_ext = 0;
let mut ext_meta_positions = vec![];
let mut all_meta: Vec<_> = kernel
.buffers
.iter()
.map(|buf| (buf.id, buf.has_extended_meta))
.collect();
all_meta.sort_by_key(|(id, _)| *id);
for (_, has_extended_meta) in &all_meta {
ext_meta_positions.push(num_ext);
if *has_extended_meta {
num_ext += 1;
}
}
let num_meta = all_meta.len();
let metadata = Metadata::new(num_meta as u32, num_ext);
let info = Info::new(&kernel.scalars, metadata, addr_type);
let scalars = kernel.scalars.clone();
let mut args = Self {
buffers_len: kernel.buffers.len(),
scalars,
function_types: Vec::with_capacity(total_arg_len),
block_inputs: Vec::with_capacity(total_arg_len),
ext_meta_positions,
shared_memories,
info,
addr_type: addr_type.to_type(context),
addr_size: addr_type.size(),
};
for binding in kernel.buffers.iter() {
let inner_type = binding.ty.storage_type().to_type(context);
let memref = MemRefType::new(inner_type, &[i64::MIN], None, None).into();
args.function_types.push(memref);
args.block_inputs.push((memref, location));
}
for shared_memory in args.shared_memories.0.iter() {
let memref = match shared_memory {
SharedMemory::Array { ty, length, .. } => {
let inner_type = ty.to_type(context);
MemRefType::new(inner_type, &[*length as i64], None, None).into()
}
SharedMemory::Value { ty, .. } => {
let inner_type = ty.to_type(context);
MemRefType::new(inner_type, &[1], None, None).into()
}
};
args.function_types.push(memref);
args.block_inputs.push((memref, location));
}
let inner_type = ElemType::UInt(UIntKind::U8).to_type(context);
let memref = MemRefType::new(inner_type, &[i64::MIN], None, None).into();
args.function_types.push(memref);
args.block_inputs.push((memref, location));
let integer_type: Type<'_> = IntegerType::new(context, 32).into();
for _ in 0..9 {
args.function_types.push(integer_type);
args.block_inputs.push((integer_type, location));
}
args
}
pub fn get_fn_type(&self, context: &'a Context) -> FunctionType<'a> {
FunctionType::new(context, &self.function_types, &[])
}
pub fn create_top_block(
self,
region: &Region<'a>,
context: &'a Context,
location: Location<'a>,
) -> ArgsManager<'a> {
let mut args = ArgsManager {
buffers: Vec::with_capacity(self.buffers_len),
scalars_memref: HashMap::with_capacity(self.scalars.len()),
static_metadata_memref: None,
dynamic_metadata_memref: None,
builtin: [None; NB_BUILTIN],
metadata: self.info.metadata,
shared_memory_values: HashMap::with_capacity(self.shared_memories.0.len()),
ext_meta_positions: self.ext_meta_positions.clone(),
addr_type: self.addr_type,
addr_size: self.addr_size,
};
let block = Block::new(&self.block_inputs);
let mut total_len = 0;
for i in 0..self.buffers_len {
args.buffers.push(block.argument(i).unwrap().into());
}
total_len += self.buffers_len;
for (i, shared_memory) in self.shared_memories.0.iter().enumerate() {
let i = i + total_len;
args.shared_memory_values
.insert(shared_memory.id(), block.argument(i).unwrap().into());
}
total_len += self.shared_memories.0.len();
let info_arg: Value<'a, 'a> = block.argument(total_len).unwrap().into();
total_len += 1;
for field in self.info.scalars.iter() {
let byte_shift = block
.const_int_from_type(context, location, field.offset as i64, Type::index(context))
.unwrap();
let elem_ty = field.ty.to_type(context);
let memref_ty = MemRefType::new(elem_ty, &[field.size as i64], None, None);
let view = memref::view(context, info_arg, byte_shift, &[], memref_ty, location);
args.scalars_memref
.insert(field.ty, block.append_op_result(view).unwrap());
}
if let Some(field) = self.info.sized_meta.as_ref() {
let byte_shift = block
.const_int_from_type(context, location, field.offset as i64, Type::index(context))
.unwrap();
let memref_ty = MemRefType::new(self.addr_type, &[field.size as i64], None, None);
let view = memref::view(context, info_arg, byte_shift, &[], memref_ty, location);
args.static_metadata_memref =
Some(block.append_operation(view).result(0).unwrap().into());
}
{
let zero = block
.const_int_from_type(context, location, 0, Type::index(context))
.unwrap();
let info_size = block
.append_op_result(memref::dim(info_arg, zero, location))
.unwrap();
let byte_shift = block
.const_int_from_type(
context,
location,
self.info.dynamic_meta_offset as i64,
Type::index(context),
)
.unwrap();
let dynamic_size = block
.append_op_result(arith::subi(info_size, byte_shift, location))
.unwrap();
let type_size = block
.const_int_from_type(
context,
location,
self.addr_size as i64,
Type::index(context),
)
.unwrap();
let dynamic_size_elems = block
.append_op_result(arith::divui(dynamic_size, type_size, location))
.unwrap();
let memref_ty = MemRefType::new(self.addr_type, &[i64::MIN], None, None);
let view = memref::view(
context,
info_arg,
byte_shift,
&[dynamic_size_elems],
memref_ty,
location,
);
args.dynamic_metadata_memref =
Some(block.append_operation(view).result(0).unwrap().into());
}
for (i, builtin) in BuiltinArray::builtin_order().into_iter().enumerate() {
let i = i + total_len;
args.set(builtin, block.argument(i).unwrap().into());
}
region.append_block(block);
args
}
}
pub(super) struct ArgsManager<'a> {
pub buffers: Vec<Value<'a, 'a>>,
pub scalars_memref: HashMap<StorageType, Value<'a, 'a>>,
pub static_metadata_memref: Option<Value<'a, 'a>>,
pub dynamic_metadata_memref: Option<Value<'a, 'a>>,
pub ext_meta_positions: Vec<u32>,
pub metadata: Metadata,
pub shared_memory_values: HashMap<u32, Value<'a, 'a>>,
pub builtin: [Option<Value<'a, 'a>>; NB_BUILTIN],
pub addr_type: Type<'a>,
pub addr_size: usize,
}
const NB_PASSED_BUILTIN: usize = 9;
impl<'a> ArgsManager<'a> {
pub fn buffer_position(&self, var: Variable) -> u32 {
var.index().expect("Variable should have index")
}
pub fn ext_meta_position(&self, var: Variable) -> u32 {
let id = var.index().expect("Variable should have index");
self.ext_meta_positions[id as usize]
}
pub fn compute_derived_args_builtin(
&mut self,
block: BlockRef<'a, 'a>,
location: Location<'a>,
) {
let cube_dim_xy = block
.muli(
self.get(Builtin::CubeDimX),
self.get(Builtin::CubeDimY),
location,
)
.unwrap();
let cube_dim = block
.muli(cube_dim_xy, self.get(Builtin::CubeDimZ), location)
.unwrap();
self.set(Builtin::CubeDim, cube_dim);
let unit_pos_z_corrected = block
.muli(self.get(Builtin::UnitPosZ), cube_dim_xy, location)
.unwrap();
let unit_pos_y_corrected = block
.muli(
self.get(Builtin::UnitPosY),
self.get(Builtin::CubeDimX),
location,
)
.unwrap();
let unit_pos_yz_corrected = block
.addi(unit_pos_z_corrected, unit_pos_y_corrected, location)
.unwrap();
let unit_pos = block
.addi(unit_pos_yz_corrected, self.get(Builtin::UnitPosX), location)
.unwrap();
self.set(Builtin::UnitPos, unit_pos);
}
pub fn set(&mut self, builtin: Builtin, value: Value<'a, 'a>) {
self.builtin[builtin as usize] = Some(value);
}
pub fn get(&self, builtin: Builtin) -> Value<'a, 'a> {
self.builtin[builtin as usize]
.unwrap_or_else(|| panic!("Unsupported builtin was used: {builtin:?}"))
}
pub fn as_address_type<'b, 'c: 'b>(
&self,
value: Value<'c, 'c>,
block: &'b Block<'c>,
location: Location<'c>,
) -> Value<'c, 'c>
where
'a: 'c,
{
if self.addr_size > 4 {
let value: Value<'c, 'b> = block.extui(value, self.addr_type, location).unwrap();
unsafe { core::mem::transmute::<Value<'c, 'b>, Value<'c, 'c>>(value) }
} else {
value
}
}
}