use bytemuck::cast_slice_mut;
use cubecl_ir::StorageType;
use cubecl_runtime::server::MetadataBinding;
use crate::prelude::InputScalar;
const BUFFER_LEN: u32 = 0;
const LENGTH: u32 = 1;
const BASE_LEN: u32 = 2;
const RANK: u32 = 0;
const SHAPE_OFFSETS: u32 = 1;
const STRIDE_OFFSETS: u32 = 2;
const EXTENDED_LEN: u32 = 3;
#[derive(Clone, Debug, Default)]
pub struct Metadata {
num_meta: u32,
num_extended_meta: u32,
}
impl Metadata {
pub fn new(num_meta: u32, num_extended_meta: u32) -> Self {
Self {
num_meta,
num_extended_meta,
}
}
fn offset_of(&self, id: u32) -> u32 {
self.num_meta * id
}
fn base_len(&self) -> u32 {
self.num_meta * BASE_LEN
}
pub fn static_len(&self) -> u32 {
self.num_meta * BASE_LEN + self.num_extended_meta * EXTENDED_LEN
}
fn offset_of_extended(&self, id: u32) -> u32 {
self.base_len() + self.num_extended_meta * id
}
pub fn buffer_len_index(&self, buffer_idx: u32) -> u32 {
self.offset_of(BUFFER_LEN) + buffer_idx
}
pub fn len_index(&self, buffer_idx: u32) -> u32 {
self.offset_of(LENGTH) + buffer_idx
}
pub fn rank_index(&self, buffer_idx: u32) -> u32 {
self.offset_of_extended(RANK) + buffer_idx
}
pub fn shape_offset_index(&self, buffer_idx: u32) -> u32 {
self.offset_of_extended(SHAPE_OFFSETS) + buffer_idx
}
pub fn stride_offset_index(&self, buffer_idx: u32) -> u32 {
self.offset_of_extended(STRIDE_OFFSETS) + buffer_idx
}
}
pub struct MetadataBuilder {
buffer_lens: Vec<InputScalar>,
lengths: Vec<InputScalar>,
ranks: Vec<InputScalar>,
shapes: Vec<Vec<InputScalar>>,
strides: Vec<Vec<InputScalar>>,
address_type: StorageType,
}
impl MetadataBuilder {
pub fn new(address_type: StorageType) -> Self {
Self {
buffer_lens: Default::default(),
lengths: Default::default(),
ranks: Default::default(),
shapes: Default::default(),
strides: Default::default(),
address_type,
}
}
pub fn with_array(&mut self, buffer_len: u64, len: u64) {
self.buffer_lens
.push(InputScalar::new(buffer_len, self.address_type));
self.lengths.push(InputScalar::new(len, self.address_type));
}
pub fn with_tensor(
&mut self,
rank: u64,
buffer_len: u64,
len: u64,
shape: Vec<u64>,
strides: Vec<u64>,
) {
self.buffer_lens
.push(InputScalar::new(buffer_len, self.address_type));
self.lengths.push(InputScalar::new(len, self.address_type));
self.ranks.push(InputScalar::new(rank, self.address_type));
self.shapes.push(
shape
.into_iter()
.map(|s| InputScalar::new(s, self.address_type))
.collect(),
);
self.strides.push(
strides
.into_iter()
.map(|s| InputScalar::new(s, self.address_type))
.collect(),
);
}
pub fn finish(self) -> MetadataBinding {
let addr_size = self.address_type.size();
let mut meta = self
.buffer_lens
.iter()
.flat_map(|it| it.as_bytes())
.collect::<Vec<_>>();
meta.extend(self.lengths.iter().flat_map(|it| it.as_bytes()));
meta.extend(self.ranks.iter().flat_map(|it| it.as_bytes()));
let num_ext = self.ranks.len();
let mut shape_offsets = Vec::with_capacity(num_ext * addr_size);
let mut stride_offsets = Vec::with_capacity(num_ext * addr_size);
let mut current_offset = meta.len() / addr_size + num_ext * 2;
for shape in self.shapes.iter() {
let offset = InputScalar::new(current_offset, self.address_type);
shape_offsets.extend(offset.as_bytes());
current_offset += shape.len();
}
meta.extend(shape_offsets);
for stride in self.strides.iter() {
let offset = InputScalar::new(current_offset, self.address_type);
stride_offsets.extend(offset.as_bytes());
current_offset += stride.len();
}
meta.extend(stride_offsets);
let static_len = meta.len() / addr_size;
meta.extend(self.shapes.iter().flatten().flat_map(|it| it.as_bytes()));
meta.extend(
self.strides
.into_iter()
.flatten()
.flat_map(|it| it.as_bytes()),
);
let total_size_64 = meta.len().div_ceil(size_of::<u64>());
let mut meta_64 = vec![0u64; total_size_64];
cast_slice_mut::<u64, u8>(&mut meta_64)[..meta.len()].copy_from_slice(&meta);
MetadataBinding::new(meta_64, static_len)
}
}