use crate::{
memory_management::{
MemoryHandle, MemoryUsage,
memory_pool::{SliceBinding, SliceHandle},
},
storage::{BindingResource, ComputeStorage},
tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
};
use alloc::collections::BTreeMap;
use alloc::vec::Vec;
use core::{fmt::Debug, future::Future};
use cubecl_common::{ExecutionMode, benchmark::ProfileDuration};
use cubecl_ir::Elem;
/// Interface a compute backend implements: allocating and reading memory, launching kernels,
/// synchronizing, and reporting memory usage and profiling information.
pub trait ComputeServer: Send + core::fmt::Debug
where
    Self: Sized,
{
    /// Kernel representation accepted by [`execute`](Self::execute).
    type Kernel: Send;
    /// Backend-specific information exposed by the server (must be `Debug + Send + Sync`).
    type Info: Debug + Send + Sync;
    /// Storage backend whose resources are returned by [`get_resource`](Self::get_resource).
    type Storage: ComputeStorage;
    /// Feature identifiers the backend can report.
    /// NOTE(review): `Ord` presumably exists so features can be kept in ordered sets — confirm.
    type Feature: Ord + Copy + Debug + Send + Sync;

    /// Reads back the contents of the given bindings.
    ///
    /// The returned future is `Send + 'static`, so it does not borrow the server.
    /// Presumably yields one byte buffer per binding, in input order — confirm against
    /// implementations.
    fn read(
        &mut self,
        bindings: Vec<Binding>,
    ) -> impl Future<Output = Vec<Vec<u8>>> + Send + 'static;

    /// Like [`read`](Self::read), but each binding carries tensor layout metadata
    /// (shape, strides, element size; see [`BindingWithMeta`]).
    fn read_tensor(
        &mut self,
        bindings: Vec<BindingWithMeta>,
    ) -> impl Future<Output = Vec<Vec<u8>>> + Send + 'static;

    /// Returns the storage-level resource backing `binding`.
    fn get_resource(
        &mut self,
        binding: Binding,
    ) -> BindingResource<<Self::Storage as ComputeStorage>::Resource>;

    /// Creates a new buffer initialized with `data` and returns a handle to it.
    fn create(&mut self, data: &[u8]) -> Handle;

    /// Creates a tensor buffer initialized with `data`, laid out for `shape` with elements of
    /// `elem_size`.
    ///
    /// Returns the handle plus a `Vec<usize>`. NOTE(review): the vector is presumably the
    /// strides chosen by the backend (which may pad/align rows) — confirm.
    fn create_tensor(
        &mut self,
        data: &[u8],
        shape: &[usize],
        elem_size: usize,
    ) -> (Handle, Vec<usize>);

    /// Reserves a buffer of `size` (presumably bytes) without host data to copy in.
    fn empty(&mut self, size: usize) -> Handle;

    /// Tensor counterpart of [`empty`](Self::empty); returns the handle and, presumably,
    /// the backend-chosen strides (see [`create_tensor`](Self::create_tensor)).
    fn empty_tensor(&mut self, shape: &[usize], elem_size: usize) -> (Handle, Vec<usize>);

    /// Launches `kernel` with `count` cubes over `bindings`, in the given execution `kind`.
    ///
    /// # Safety
    ///
    /// The precise contract is defined by implementations and is not visible here.
    /// NOTE(review): presumably the caller must guarantee the kernel stays within the bounds
    /// of its bindings, especially for execution modes that skip runtime checks — document.
    unsafe fn execute(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: Bindings,
        kind: ExecutionMode,
    );

    /// Flushes pending work. NOTE(review): whether this only submits or also waits is
    /// backend-defined.
    fn flush(&mut self);

    /// Returns a future for synchronizing with the server; presumably it resolves once all
    /// previously submitted work has completed — confirm.
    fn sync(&mut self) -> impl Future<Output = ()> + Send + 'static;

    /// Reports the current memory usage.
    fn memory_usage(&self) -> MemoryUsage;

    /// Asks the server to clean up memory; the exact policy is backend-defined.
    fn memory_cleanup(&mut self);

    /// Begins a profiling region.
    fn start_profile(&mut self);

    /// Ends the profiling region begun by [`start_profile`](Self::start_profile) and returns
    /// its duration.
    fn end_profile(&mut self) -> ProfileDuration;
}
/// Server-side handle to a memory slice, with optional offsets narrowing the visible range.
// Field order matters: `derive(new)` generates `new(memory, offset_start, offset_end, size)`.
#[derive(new, Debug)]
pub struct Handle {
    /// Handle to the underlying memory slice.
    pub memory: SliceHandle,
    /// Amount skipped at the start of the slice; `None` means no offset.
    pub offset_start: Option<u64>,
    /// Amount trimmed from the end of the slice; `None` means no offset.
    pub offset_end: Option<u64>,
    /// Size of the whole underlying slice, before any offsets are subtracted.
    size: u64,
}
impl Handle {
    /// Moves the start of this handle's view forward by `offset`.
    ///
    /// Offsets accumulate: if a start offset is already set, `offset` is added to it.
    pub fn offset_start(mut self, offset: u64) -> Self {
        *self.offset_start.get_or_insert(0) += offset;
        self
    }

    /// Trims `offset` from the end of this handle's view.
    ///
    /// Offsets accumulate: if an end offset is already set, `offset` is added to it.
    pub fn offset_end(mut self, offset: u64) -> Self {
        *self.offset_end.get_or_insert(0) += offset;
        self
    }

    /// Returns the size of the view, which is the underlying size minus the start and end
    /// offsets.
    ///
    /// # Panics
    ///
    /// Panics if the combined offsets exceed the underlying size. Plain subtraction would
    /// only catch this in debug builds. In release builds it would silently wrap to an
    /// enormous size.
    pub fn size(&self) -> u64 {
        let start = self.offset_start.unwrap_or(0);
        let end = self.offset_end.unwrap_or(0);
        self.size
            .checked_sub(start)
            .and_then(|remaining| remaining.checked_sub(end))
            .expect("Handle offsets exceed the size of the underlying memory")
    }
}
/// All resources bound to one kernel launch (passed to `ComputeServer::execute`).
#[derive(Debug, Default)]
pub struct Bindings {
    /// Buffer bindings, in the order they were added.
    pub buffers: Vec<Binding>,
    /// Metadata binding for the launch.
    pub metadata: MetadataBinding,
    /// Scalar arguments keyed by element type; at most one [`ScalarBinding`] per `Elem`.
    pub scalars: BTreeMap<Elem, ScalarBinding>,
    /// Tensor-map bindings, in the order they were added.
    pub tensor_maps: Vec<TensorMapBinding>,
}
impl Bindings {
    /// Creates an empty set of bindings; identical to [`Default::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Appends one buffer binding.
    pub fn with_buffer(mut self, binding: Binding) -> Self {
        self.buffers.push(binding);
        self
    }

    /// Appends every binding in `bindings`, preserving their order.
    pub fn with_buffers(mut self, mut bindings: Vec<Binding>) -> Self {
        self.buffers.append(&mut bindings);
        self
    }

    /// Sets the scalars for `elem`, replacing any scalars previously set for that type.
    pub fn with_scalar(mut self, elem: Elem, length: usize, data: Vec<u64>) -> Self {
        let scalar = ScalarBinding::new(elem, length, data);
        self.scalars.insert(elem, scalar);
        self
    }

    /// Adds each scalar binding under its element type. A later entry replaces an earlier
    /// one of the same type, including entries that were already present.
    pub fn with_scalars(mut self, bindings: Vec<ScalarBinding>) -> Self {
        for scalar in bindings {
            self.scalars.insert(scalar.elem, scalar);
        }
        self
    }

    /// Replaces the metadata binding.
    pub fn with_metadata(self, meta: MetadataBinding) -> Self {
        Self {
            metadata: meta,
            ..self
        }
    }

    /// Appends every tensor-map binding in `bindings`, preserving their order.
    pub fn with_tensor_maps(mut self, mut bindings: Vec<TensorMapBinding>) -> Self {
        self.tensor_maps.append(&mut bindings);
        self
    }
}
/// Metadata words supplied to a kernel launch as part of [`Bindings`].
#[derive(new, Debug, Default)]
pub struct MetadataBinding {
    /// Raw metadata words.
    pub data: Vec<u32>,
    /// NOTE(review): presumably the number of leading entries of `data` that are static
    /// (the rest being dynamically sized) — confirm against the metadata writer.
    pub static_len: usize,
}
/// Scalar arguments of a single element type, packed into `u64` words.
#[derive(new, Debug)]
pub struct ScalarBinding {
    /// Element type of the scalars.
    pub elem: Elem,
    /// Number of scalar values. NOTE(review): presumably counted in `elem` units rather than
    /// `u64` words — confirm.
    pub length: usize,
    /// Packed scalar storage; `ScalarBinding::data` exposes it as bytes.
    pub data: Vec<u64>,
}
impl ScalarBinding {
    /// Views the packed `u64` storage as raw bytes in native byte order.
    ///
    /// The result is 8 bytes per stored word, so it may be longer than
    /// `length * size_of(elem)`.
    pub fn data(&self) -> &[u8] {
        let words: &[u64] = &self.data;
        bytemuck::cast_slice(words)
    }
}
/// A memory slice bound for use by the server, with the same optional offsets as [`Handle`].
// Produced by `Handle::binding`; `Clone` is implemented manually further down.
#[derive(new, Debug)]
pub struct Binding {
    /// Binding to the underlying memory slice.
    pub memory: SliceBinding,
    /// Amount skipped at the start of the slice; `None` means no offset.
    pub offset_start: Option<u64>,
    /// Amount trimmed from the end of the slice; `None` means no offset.
    pub offset_end: Option<u64>,
}
/// A [`Binding`] plus tensor layout metadata, as consumed by `ComputeServer::read_tensor`.
#[derive(new, Debug)]
pub struct BindingWithMeta {
    /// The bound memory.
    pub binding: Binding,
    /// Tensor shape.
    pub shape: Vec<usize>,
    /// Tensor strides; NOTE(review): presumably in elements, not bytes — confirm.
    pub strides: Vec<usize>,
    /// Size of one element; NOTE(review): presumably in bytes — confirm.
    pub elem_size: usize,
}
/// A buffer binding paired with the tensor-map description (from the `tma` module) over it.
#[derive(new, Debug, Clone)]
pub struct TensorMapBinding {
    /// The bound buffer the tensor map refers to.
    pub binding: Binding,
    /// Parameters of the tensor map.
    pub map: TensorMapMeta,
}
/// Parameters describing a tensor map over a bound buffer.
#[derive(Debug, Clone)]
pub struct TensorMapMeta {
    /// Tensor-map format (see `TensorMapFormat`).
    pub format: TensorMapFormat,
    /// Rank of the mapped tensor; NOTE(review): presumably equals `shape.len()` — confirm.
    pub rank: usize,
    /// Shape of the mapped tensor.
    pub shape: Vec<usize>,
    /// Strides of the mapped tensor; units not established here — confirm.
    pub strides: Vec<usize>,
    /// Per-dimension element strides used by the map.
    pub elem_stride: Vec<usize>,
    /// Interleave mode.
    pub interleave: TensorMapInterleave,
    /// Swizzle mode.
    pub swizzle: TensorMapSwizzle,
    /// Prefetch hint.
    pub prefetch: TensorMapPrefetch,
    /// How out-of-bounds accesses are filled.
    pub oob_fill: OobFill,
    /// Element type of the mapped tensor.
    pub elem: Elem,
}
impl Handle {
    /// Reports whether the underlying memory slice may be mutated, as decided by the slice
    /// handle itself.
    // NOTE(review): presumably true only when no other handle shares the slice — confirm.
    pub fn can_mut(&self) -> bool {
        let memory = &self.memory;
        memory.can_mut()
    }
}
impl Handle {
pub fn binding(self) -> Binding {
Binding {
memory: MemoryHandle::binding(self.memory),
offset_start: self.offset_start,
offset_end: self.offset_end,
}
}
pub fn binding_with_meta(
self,
shape: Vec<usize>,
strides: Vec<usize>,
elem_size: usize,
) -> BindingWithMeta {
BindingWithMeta {
shape,
strides,
elem_size,
binding: self.binding(),
}
}
}
impl Clone for Handle {
    /// Clones the memory handle; the offsets and size are copied as-is.
    fn clone(&self) -> Self {
        // Every field except `memory` is `Copy`, so struct-update from `*self` just copies them.
        Self {
            memory: self.memory.clone(),
            ..*self
        }
    }
}
impl Clone for Binding {
    /// Clones the memory binding; the offsets are copied as-is.
    fn clone(&self) -> Self {
        // Both offsets are `Option<u64>` (which is `Copy`), so struct-update from `*self`
        // copies them.
        Self {
            memory: self.memory.clone(),
            ..*self
        }
    }
}
/// Number of cubes a kernel is launched with (see `ComputeServer::execute`).
// `Dynamic` stores a whole `Binding`, which is far larger than three `u32`s. The size lint is
// silenced here rather than boxing the binding.
// NOTE(review): presumably this avoids a heap allocation per launch — confirm intent.
#[allow(clippy::large_enum_variant)]
pub enum CubeCount {
    /// Fixed cube count along x, y and z.
    Static(u32, u32, u32),
    /// Cube count stored in a buffer binding; NOTE(review): presumably read at launch time.
    Dynamic(Binding),
}
impl CubeCount {
pub fn new_single() -> Self {
CubeCount::Static(1, 1, 1)
}
pub fn new_1d(x: u32) -> Self {
CubeCount::Static(x, 1, 1)
}
pub fn new_2d(x: u32, y: u32) -> Self {
CubeCount::Static(x, y, 1)
}
pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
CubeCount::Static(x, y, z)
}
}
impl Debug for CubeCount {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
CubeCount::Dynamic(_) => f.write_str("binding"),
}
}
}
impl Clone for CubeCount {
fn clone(&self) -> Self {
match self {
Self::Static(x, y, z) => Self::Static(*x, *y, *z),
Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
}
}
}