use crate::{
client::ComputeClient,
compiler::CompilationError,
kernel::KernelMetadata,
logging::ServerLogger,
memory_management::{
MemoryAllocationMode, MemoryHandle, MemoryUsage,
memory_pool::{SliceBinding, SliceHandle},
},
runtime::Runtime,
storage::{BindingResource, ComputeStorage},
tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
};
use alloc::collections::BTreeMap;
use alloc::string::String;
use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use core::fmt::Debug;
use cubecl_common::{
backtrace::BackTrace, bytes::Bytes, device, future::DynFut, profile::ProfileDuration,
stream_id::StreamId,
};
use cubecl_ir::{DeviceProperties, StorageType};
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Errors that can occur while profiling work on a compute server.
#[derive(Error, Clone)]
pub enum ProfileError {
    /// Catch-all profiling failure described by `reason`.
    #[error(
        "An unknown error happened during profiling\nCaused by:\n {reason}\nBacktrace:\n{backtrace}"
    )]
    Unknown {
        reason: String,
        backtrace: BackTrace,
    },
    /// No profiling span was registered for the request (e.g. unknown token).
    #[error("No profiling registered\nBacktrace:\n{backtrace}")]
    NotRegistered {
        backtrace: BackTrace,
    },
    /// The launch performed while profiling failed.
    #[error("A launch error happened during profiling\nCaused by:\n {0}")]
    Launch(#[from] LaunchError),
    /// The execution performed while profiling failed.
    #[error("An execution error happened during profiling\nCaused by:\n {0}")]
    Execution(#[from] ExecutionError),
}
impl core::fmt::Debug for ProfileError {
    // Debug output delegates to the Display implementation.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "{self}")
    }
}
/// Shared per-server utilities handed out to clients.
pub struct ServerUtilities<Server: ComputeServer> {
    /// Instant captured when the utilities were created (tracy profiling).
    #[cfg(feature = "profile-tracy")]
    pub epoch_time: web_time::Instant,
    /// Tracy GPU profiling context registered in [`ServerUtilities::new`].
    #[cfg(feature = "profile-tracy")]
    pub gpu_client: tracy_client::GpuContext,
    /// Properties of the device this server drives.
    pub properties: DeviceProperties,
    /// Backend-specific server information.
    pub info: Server::Info,
    /// Logger shared with the server.
    pub logger: Arc<ServerLogger>,
}
impl<Server: core::fmt::Debug> core::fmt::Debug for ServerUtilities<Server>
where
    Server: ComputeServer,
    Server::Info: core::fmt::Debug,
{
    // Formats the always-present fields; tracy handles (if compiled in) are
    // intentionally not part of the Debug output.
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        let mut builder = f.debug_struct("ServerUtilities");
        builder.field("properties", &self.properties);
        builder.field("info", &self.info);
        builder.field("logger", &self.logger);
        builder.finish()
    }
}
impl<S: ComputeServer> ServerUtilities<S> {
    /// Creates utilities from device properties, a logger and server info.
    ///
    /// With the `profile-tracy` feature enabled, this also starts a tracy
    /// client, registers a GPU context named after `info`, and records the
    /// creation instant as the profiling epoch.
    pub fn new(properties: DeviceProperties, logger: Arc<ServerLogger>, info: S::Info) -> Self {
        #[cfg(feature = "profile-tracy")]
        let client = tracy_client::Client::start();
        Self {
            properties,
            logger,
            #[cfg(feature = "profile-tracy")]
            gpu_client: client
                .clone()
                .new_gpu_context(
                    Some(&format!("{info:?}")),
                    // NOTE(review): context type `Invalid` with timestamp
                    // (0, 1.0) — confirm GPU timestamp calibration is handled
                    // elsewhere.
                    tracy_client::GpuContextType::Invalid,
                    0, 1.0, )
                .unwrap(),
            #[cfg(feature = "profile-tracy")]
            epoch_time: web_time::Instant::now(),
            info,
        }
    }
}
/// Errors that can occur while launching a kernel.
#[derive(Error, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum LaunchError {
    /// The kernel failed to compile.
    #[error("A compilation error happened during launch\nCaused by:\n {0}")]
    CompilationError(#[from] CompilationError),
    /// The server could not allocate the memory required by the launch.
    // Fixed: "Backtrace:" (with colon) for consistency with `ProfileError`.
    #[error(
        "An out-of-memory error happened during launch\nCaused by:\n {reason}\nBacktrace:\n{backtrace}"
    )]
    OutOfMemory {
        reason: String,
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
    /// Catch-all launch failure described by `reason`.
    #[error(
        "An unknown error happened during launch\nCaused by:\n {reason}\nBacktrace:\n{backtrace}"
    )]
    Unknown {
        reason: String,
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
    /// An IO operation performed as part of the launch failed.
    #[error("An io error happened during launch\nCaused by:\n {0}")]
    IoError(#[from] IoError),
}
impl core::fmt::Debug for LaunchError {
    // Debug output delegates to the Display implementation.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "{self}")
    }
}
/// Errors that can occur while executing queued work on a server.
#[derive(Error, Debug, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum ExecutionError {
    /// Generic execution failure described by `reason`.
    #[error("An error happened during execution\nCaused by:\n {reason}\nBacktrace:\n{backtrace}")]
    Generic {
        reason: String,
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
}
/// Server-side interface of a compute backend.
///
/// A `ComputeServer` handles memory allocation, data transfer, kernel launch,
/// synchronization and profiling for a device; every operation is addressed
/// to a specific [`StreamId`].
pub trait ComputeServer:
    Send + core::fmt::Debug + ServerCommunication + device::DeviceState + 'static
where
    Self: Sized,
{
    /// Kernel type executed by this server.
    type Kernel: KernelMetadata;
    /// Backend-specific information describing the server/device.
    type Info: Debug + Send + Sync;
    /// Storage backend providing buffer resources.
    type Storage: ComputeStorage;

    /// Allocates memory for every descriptor, returning one [`Allocation`]
    /// (handle + strides) per entry, in order.
    fn create(
        &mut self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        stream_id: StreamId,
    ) -> Result<Vec<Allocation>, IoError>;

    /// Reserves staging byte buffers of the given sizes.
    ///
    /// The default implementation reports
    /// [`IoError::UnsupportedIoOperation`] for servers without staging
    /// support.
    fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
        Err(IoError::UnsupportedIoOperation {
            backtrace: BackTrace::capture(),
        })
    }

    /// Returns the logger used by this server.
    fn logger(&self) -> Arc<ServerLogger>;

    /// Returns the shared utilities (device properties, info and logger).
    fn utilities(&self) -> Arc<ServerUtilities<Self>>;

    /// Allocates a contiguous buffer of `data.len()` bytes and fills it with
    /// `data`, returning the handle.
    ///
    /// Delegates to [`ComputeServer::create_with_bytes`] so the
    /// allocate-then-write logic lives in one place.
    fn create_with_data(&mut self, data: &[u8], stream_id: StreamId) -> Result<Handle, IoError> {
        self.create_with_bytes(Bytes::from_bytes_vec(data.to_vec()), stream_id)
    }

    /// Allocates a contiguous buffer of `data.len()` bytes and fills it with
    /// the given [`Bytes`], returning the handle.
    fn create_with_bytes(&mut self, data: Bytes, stream_id: StreamId) -> Result<Handle, IoError> {
        // Single 1-D allocation: shape is the byte length, element size 1.
        let alloc = self
            .create(
                vec![AllocationDescriptor::new(
                    AllocationKind::Contiguous,
                    &[data.len()],
                    1,
                )],
                stream_id,
            )?
            .remove(0);
        self.write(
            vec![(
                CopyDescriptor::new(
                    alloc.handle.clone().binding(),
                    &[data.len()],
                    &alloc.strides,
                    1,
                ),
                data,
            )],
            stream_id,
        )?;
        Ok(alloc.handle)
    }

    /// Reads the described regions back as [`Bytes`], one per descriptor.
    fn read<'a>(
        &mut self,
        descriptors: Vec<CopyDescriptor<'a>>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, IoError>>;

    /// Writes the paired [`Bytes`] into each described region.
    fn write(
        &mut self,
        descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
        stream_id: StreamId,
    ) -> Result<(), IoError>;

    /// Resolves once all work previously queued on the stream has completed.
    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>>;

    /// Resolves a binding into its underlying storage resource.
    fn get_resource(
        &mut self,
        binding: Binding,
        stream_id: StreamId,
    ) -> BindingResource<<Self::Storage as ComputeStorage>::Resource>;

    /// Launches a kernel with the given cube count and bindings.
    ///
    /// # Safety
    ///
    /// With [`ExecutionMode::Unchecked`], kernel memory accesses are not
    /// guarded; the caller must guarantee the kernel stays in bounds.
    unsafe fn launch(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: Bindings,
        kind: ExecutionMode,
        stream_id: StreamId,
    ) -> Result<(), LaunchError>;

    /// Flushes queued commands on the stream to the device.
    fn flush(&mut self, stream_id: StreamId);

    /// Returns the current memory usage of the stream.
    fn memory_usage(&mut self, stream_id: StreamId) -> MemoryUsage;

    /// Asks the stream to release memory it no longer needs.
    fn memory_cleanup(&mut self, stream_id: StreamId);

    /// Starts a profiling span, returning the token that closes it.
    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken;

    /// Ends the profiling span identified by `token`, yielding its duration.
    fn end_profile(
        &mut self,
        stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError>;

    /// Changes the memory allocation mode of the stream.
    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId);
}
/// Optional direct server-to-server data transfer.
pub trait ServerCommunication {
    /// Whether this server supports direct server-to-server copies.
    const SERVER_COMM_ENABLED: bool;
    /// Copies the region described by `src` from `server_src` into a new
    /// allocation on `server_dst`.
    ///
    /// Servers that set [`Self::SERVER_COMM_ENABLED`] to `true` must override
    /// this method: the default body can only panic.
    #[allow(unused_variables)]
    fn copy(
        server_src: &mut Self,
        server_dst: &mut Self,
        src: CopyDescriptor<'_>,
        stream_id_src: StreamId,
        stream_id_dst: StreamId,
    ) -> Result<Allocation, IoError> {
        if !Self::SERVER_COMM_ENABLED {
            panic!("Server-to-server communication is not supported by this server.");
        } else {
            // The flag claims support but `copy` was not overridden.
            panic!(
                "[Internal Error] The `ServerCommunication` trait is incorrectly implemented by the server."
            );
        }
    }
}
/// Opaque token pairing a `start_profile` call with its `end_profile`.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ProfilingToken {
    /// Server-assigned identifier of the profiling span.
    pub id: u64,
}
/// Owned reference to an allocated memory slice, plus optional view offsets.
#[derive(new, Debug, PartialEq, Eq)]
pub struct Handle {
    /// Underlying memory slice handle.
    pub memory: SliceHandle,
    /// Extra offset (bytes) applied at the start of the slice, if any.
    pub offset_start: Option<u64>,
    /// Extra offset (bytes) applied at the end of the slice, if any.
    pub offset_end: Option<u64>,
    /// Stream that owns this handle (checked by `can_mut`).
    pub stream: cubecl_common::stream_id::StreamId,
    // Position marker within the owning stream — NOTE(review): semantics are
    // defined by the stream implementation; not interpreted in this file.
    pub cursor: u64,
    // Total size in bytes before offsets are applied; see `Handle::size`.
    size: u64,
}
/// Memory layout requested for an allocation.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum AllocationKind {
    /// Dense allocation with contiguous strides.
    Contiguous,
    /// Allocation whose strides may be chosen by the server for performance
    /// (e.g. padding/alignment) — see `AllocationDescriptor::optimized`.
    Optimized,
}
/// Describes one allocation request handed to `ComputeServer::create`.
#[derive(new, Debug, Clone, Copy)]
pub struct AllocationDescriptor<'a> {
    /// Requested memory layout.
    pub kind: AllocationKind,
    /// Logical shape of the tensor to allocate.
    pub shape: &'a [usize],
    /// Size of one element in bytes.
    pub elem_size: usize,
}
impl<'a> AllocationDescriptor<'a> {
    /// Descriptor for an allocation whose layout the server may optimize.
    pub fn optimized(shape: &'a [usize], elem_size: usize) -> Self {
        Self {
            kind: AllocationKind::Optimized,
            shape,
            elem_size,
        }
    }
    /// Descriptor for a contiguous allocation.
    pub fn contiguous(shape: &'a [usize], elem_size: usize) -> Self {
        Self {
            kind: AllocationKind::Contiguous,
            shape,
            elem_size,
        }
    }
}
/// Result of one allocation request: the handle plus the strides the server
/// chose for the requested shape.
#[derive(new, Debug)]
pub struct Allocation {
    /// Handle to the allocated memory.
    pub handle: Handle,
    /// Strides (in elements) chosen for the allocation.
    pub strides: Vec<usize>,
}
#[derive(Error, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum IoError {
#[error("can't allocate buffer of size: {size}\n{backtrace}")]
BufferTooBig {
size: u64,
#[cfg_attr(std_io, serde(skip))]
backtrace: BackTrace,
},
#[error("the provided strides are not supported for this operation\n{backtrace}")]
UnsupportedStrides {
#[cfg_attr(std_io, serde(skip))]
backtrace: BackTrace,
},
#[error("couldn't find resource for that handle\n{backtrace}")]
InvalidHandle {
#[cfg_attr(std_io, serde(skip))]
backtrace: BackTrace,
},
#[error("Unknown error happened during execution\n{backtrace}")]
Unknown {
description: String,
#[cfg_attr(std_io, serde(skip))]
backtrace: BackTrace,
},
#[error("The current IO operation is not supported\n{backtrace}")]
UnsupportedIoOperation {
#[cfg_attr(std_io, serde(skip))]
backtrace: BackTrace,
},
#[error("Can't perform the IO operation because of a runtime error")]
Execution(#[from] ExecutionError),
}
impl core::fmt::Debug for IoError {
    // Debug output delegates to the Display implementation.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "{self}")
    }
}
impl Handle {
    /// Adds `offset` bytes to the start offset, accumulating with any offset
    /// already present.
    pub fn offset_start(mut self, offset: u64) -> Self {
        self.offset_start = Some(self.offset_start.map_or(offset, |current| current + offset));
        self
    }
    /// Adds `offset` bytes to the end offset, accumulating with any offset
    /// already present.
    pub fn offset_end(mut self, offset: u64) -> Self {
        self.offset_end = Some(self.offset_end.map_or(offset, |current| current + offset));
        self
    }
    /// Effective size in bytes once both offsets are subtracted.
    pub fn size(&self) -> u64 {
        let start = self.offset_start.unwrap_or(0);
        let end = self.offset_end.unwrap_or(0);
        self.size - start - end
    }
}
/// All resources bound to a kernel launch.
#[derive(Debug, Default)]
pub struct Bindings {
    /// Buffer bindings, in kernel argument order.
    pub buffers: Vec<Binding>,
    /// Packed metadata values (see [`MetadataBinding`]).
    pub metadata: MetadataBinding,
    /// Scalar bindings, keyed (and deduplicated) by storage type.
    pub scalars: BTreeMap<StorageType, ScalarBinding>,
    /// Tensor-map bindings (TMA descriptors).
    pub tensor_maps: Vec<TensorMapBinding>,
}
impl Bindings {
    /// Creates an empty set of bindings.
    pub fn new() -> Self {
        Default::default()
    }
    /// Adds one buffer binding.
    pub fn with_buffer(mut self, binding: Binding) -> Self {
        self.buffers.push(binding);
        self
    }
    /// Adds several buffer bindings, preserving order.
    pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
        for binding in bindings {
            self.buffers.push(binding);
        }
        self
    }
    /// Adds (or replaces) the scalar binding for the given storage type.
    pub fn with_scalar(mut self, ty: StorageType, length: usize, data: Vec<u64>) -> Self {
        let binding = ScalarBinding::new(ty, length, data);
        self.scalars.insert(ty, binding);
        self
    }
    /// Adds several scalar bindings, each keyed by its own storage type.
    pub fn with_scalars(mut self, bindings: Vec<ScalarBinding>) -> Self {
        for binding in bindings {
            self.scalars.insert(binding.ty, binding);
        }
        self
    }
    /// Sets the metadata binding.
    pub fn with_metadata(mut self, meta: MetadataBinding) -> Self {
        self.metadata = meta;
        self
    }
    /// Adds several tensor-map bindings, preserving order.
    pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
        for binding in bindings {
            self.tensor_maps.push(binding);
        }
        self
    }
}
/// Packed metadata passed to a kernel launch.
#[derive(new, Debug, Default)]
pub struct MetadataBinding {
    /// Raw metadata words.
    pub data: Vec<u64>,
    // Length of the statically-known prefix of `data` — NOTE(review): exact
    // split semantics are defined by the consumer; confirm against launch code.
    pub static_len: usize,
}
/// Scalar values of a single storage type bound to a launch.
#[derive(new, Debug, Clone)]
pub struct ScalarBinding {
    /// Storage type of the scalars.
    pub ty: StorageType,
    /// Number of scalar values (may differ from `data.len()` since each `u64`
    /// word can pack several narrow scalars).
    pub length: usize,
    /// Scalar payload packed into 64-bit words.
    pub data: Vec<u64>,
}
impl ScalarBinding {
    /// Returns the scalar payload reinterpreted as raw bytes.
    ///
    /// The cast exposes the `u64` words in host byte order — NOTE(review):
    /// assumes consumers expect host endianness; confirm against kernel ABI.
    pub fn data(&self) -> &[u8] {
        bytemuck::cast_slice(&self.data)
    }
}
/// Memory binding derived from a [`Handle`] (see [`Handle::binding`]).
#[derive(new, Debug)]
pub struct Binding {
    /// Underlying memory slice binding.
    pub memory: SliceBinding,
    /// Extra offset (bytes) applied at the start of the slice, if any.
    pub offset_start: Option<u64>,
    /// Extra offset (bytes) applied at the end of the slice, if any.
    pub offset_end: Option<u64>,
    /// Stream that owns the originating handle.
    pub stream: cubecl_common::stream_id::StreamId,
    // Position marker carried over from the originating `Handle`.
    pub cursor: u64,
    // Total size in bytes before offsets are applied; see `Binding::size`.
    size: u64,
}
impl Binding {
    /// Effective size in bytes once both offsets are subtracted.
    pub fn size(&self) -> u64 {
        let start = self.offset_start.unwrap_or(0);
        let end = self.offset_end.unwrap_or(0);
        self.size - start - end
    }
}
/// Describes one region to read or write: a binding plus its layout.
#[derive(new, Debug, Clone)]
pub struct CopyDescriptor<'a> {
    /// Memory binding the copy targets.
    pub binding: Binding,
    /// Logical shape of the region.
    pub shape: &'a [usize],
    /// Strides (in elements) of the region.
    pub strides: &'a [usize],
    /// Size of one element in bytes.
    pub elem_size: usize,
}
/// A tensor-map binding: the backing memory plus its map metadata.
#[derive(new, Debug, Clone)]
pub struct TensorMapBinding {
    /// Memory binding backing the tensor map.
    pub binding: Binding,
    /// Tensor-map descriptor metadata.
    pub map: TensorMapMeta,
}
/// Metadata describing a tensor map (TMA descriptor).
#[derive(Debug, Clone)]
pub struct TensorMapMeta {
    /// Encoding format of the tensor map.
    pub format: TensorMapFormat,
    /// Number of dimensions.
    pub rank: usize,
    /// Shape of the mapped tensor.
    pub shape: Vec<usize>,
    /// Strides of the mapped tensor.
    pub strides: Vec<usize>,
    /// Per-dimension element stride.
    pub elem_stride: Vec<usize>,
    /// Interleave configuration.
    pub interleave: TensorMapInterleave,
    /// Swizzle configuration.
    pub swizzle: TensorMapSwizzle,
    /// Prefetch configuration.
    pub prefetch: TensorMapPrefetch,
    /// Fill value used for out-of-bounds accesses.
    pub oob_fill: OobFill,
    /// Storage type of the elements.
    pub storage_ty: StorageType,
}
impl Handle {
    /// A handle may be mutated only when its memory is exclusively held and
    /// the caller is on the stream that owns the handle.
    pub fn can_mut(&self) -> bool {
        if !self.memory.can_mut() {
            return false;
        }
        self.stream == StreamId::current()
    }
}
impl Handle {
pub fn binding(self) -> Binding {
Binding {
memory: MemoryHandle::binding(self.memory),
offset_start: self.offset_start,
offset_end: self.offset_end,
size: self.size,
stream: self.stream,
cursor: self.cursor,
}
}
pub fn copy_descriptor<'a>(
&'a self,
shape: &'a [usize],
strides: &'a [usize],
elem_size: usize,
) -> CopyDescriptor<'a> {
CopyDescriptor {
shape,
strides,
elem_size,
binding: self.clone().binding(),
}
}
}
impl Clone for Handle {
    // Manual impl: clones the slice handle and copies every scalar field.
    fn clone(&self) -> Self {
        Self {
            memory: self.memory.clone(),
            stream: self.stream,
            cursor: self.cursor,
            size: self.size,
            offset_start: self.offset_start,
            offset_end: self.offset_end,
        }
    }
}
impl Clone for Binding {
    // Manual impl: clones the slice binding and copies every scalar field.
    fn clone(&self) -> Self {
        Self {
            memory: self.memory.clone(),
            stream: self.stream,
            cursor: self.cursor,
            size: self.size,
            offset_start: self.offset_start,
            offset_end: self.offset_end,
        }
    }
}
/// Number of cubes launched along each dispatch axis.
#[allow(clippy::large_enum_variant)]
pub enum CubeCount {
    /// Dispatch dimensions known at launch time: (x, y, z).
    Static(u32, u32, u32),
    /// Dispatch dimensions resolved from a buffer binding at execution time.
    Dynamic(Binding),
}
/// Result of spreading a requested cube count over the dispatch axes.
pub enum CubeCountSelection {
    /// The product of the chosen dimensions equals the requested count.
    Exact(CubeCount),
    /// The product exceeds the requested count; the second field holds the
    /// actual number of cubes launched.
    Approx(CubeCount, u32),
}
impl CubeCountSelection {
    /// Spreads `num_cubes` across the dispatch axes allowed by the device.
    pub fn new<R: Runtime>(client: &ComputeClient<R>, num_cubes: u32) -> Self {
        let max = &client.properties().hardware.max_cube_count;
        let [x, y, z] = cube_count_spread(max, num_cubes);
        let total = x * y * z;
        let cube_count = CubeCount::Static(x, y, z);
        if total == num_cubes {
            Self::Exact(cube_count)
        } else {
            Self::Approx(cube_count, total)
        }
    }
    /// Whether the launch would contain idle cubes (approximate selection).
    pub fn has_idle(&self) -> bool {
        match self {
            Self::Exact(_) => false,
            Self::Approx(..) => true,
        }
    }
    /// Unwraps the underlying [`CubeCount`].
    pub fn cube_count(self) -> CubeCount {
        match self {
            Self::Exact(count) | Self::Approx(count, _) => count,
        }
    }
}
impl From<CubeCountSelection> for CubeCount {
fn from(value: CubeCountSelection) -> Self {
value.cube_count()
}
}
impl CubeCount {
    /// A static count of a single cube.
    pub fn new_single() -> Self {
        Self::new_3d(1, 1, 1)
    }
    /// A static 1D count (`y` and `z` set to 1).
    pub fn new_1d(x: u32) -> Self {
        Self::new_3d(x, 1, 1)
    }
    /// A static 2D count (`z` set to 1).
    pub fn new_2d(x: u32, y: u32) -> Self {
        Self::new_3d(x, y, 1)
    }
    /// A static 3D count.
    pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
        CubeCount::Static(x, y, z)
    }
}
impl Debug for CubeCount {
    // Static counts print as a tuple; dynamic counts hide the binding.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Static(x, y, z) => write!(f, "({x}, {y}, {z})"),
            Self::Dynamic(_) => write!(f, "binding"),
        }
    }
}
impl Clone for CubeCount {
fn clone(&self) -> Self {
match self {
Self::Static(x, y, z) => Self::Static(*x, *y, *z),
Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
}
}
}
/// Number of units along each axis inside a single cube.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, serde::Serialize, serde::Deserialize)]
#[allow(missing_docs)]
pub struct CubeDim {
    pub x: u32,
    pub y: u32,
    pub z: u32,
}
impl CubeDim {
    /// Picks a cube dimension suited to `working_units` units of parallel
    /// work on the client's device: x is the maximum plane size, y the plane
    /// count capped by the units-per-cube limit.
    pub fn new<R: Runtime>(client: &ComputeClient<R>, working_units: usize) -> Self {
        let properties = client.properties();
        let plane_size = properties.hardware.plane_size_max;
        let plane_count = Self::calculate_plane_count_per_cube(
            working_units as u32,
            plane_size,
            properties.hardware.num_cpu_cores,
        );
        // Never exceed the device's units-per-cube limit.
        let limit = properties.hardware.max_units_per_cube / plane_size;
        Self::new_2d(plane_size, u32::min(limit, plane_count))
    }
    /// Heuristic plane count: one per CPU core when known, otherwise the
    /// largest power of two that covers the work, capped at 8 planes.
    fn calculate_plane_count_per_cube(
        working_units: u32,
        plane_dim: u32,
        num_cpu_cores: Option<u32>,
    ) -> u32 {
        match num_cpu_cores {
            Some(num_cores) => core::cmp::min(num_cores, working_units),
            None => {
                const NUM_PLANE_MAX: u32 = 8u32;
                const NUM_PLANE_MAX_LOG2: u32 = NUM_PLANE_MAX.ilog2();
                // At least one plane, even when work is smaller than a plane.
                let plane_count_max = core::cmp::max(1, working_units / plane_dim);
                let log2 = core::cmp::min(NUM_PLANE_MAX_LOG2, plane_count_max.ilog2());
                // 2^log2, i.e. plane count rounded down to a power of two.
                1u32 << log2
            }
        }
    }
    /// A single-unit cube.
    pub const fn new_single() -> Self {
        Self::new_3d(1, 1, 1)
    }
    /// A 1D cube dimension (`y` and `z` set to 1).
    pub const fn new_1d(x: u32) -> Self {
        Self::new_3d(x, 1, 1)
    }
    /// A 2D cube dimension (`z` set to 1).
    pub const fn new_2d(x: u32, y: u32) -> Self {
        Self::new_3d(x, y, 1)
    }
    /// A 3D cube dimension.
    pub const fn new_3d(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }
    /// Total number of units in the cube.
    pub const fn num_elems(&self) -> u32 {
        self.x * self.y * self.z
    }
    /// Whether `other` fits inside this cube dimension on every axis.
    pub const fn can_contain(&self, other: CubeDim) -> bool {
        self.x >= other.x && self.y >= other.y && self.z >= other.z
    }
}
/// Bounds-checking mode used when launching kernels (see the safety contract
/// of `ComputeServer::launch`).
#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy, Serialize, Deserialize)]
pub enum ExecutionMode {
    /// Default mode; kernels run with safety checks.
    #[default]
    Checked,
    /// Kernels run without safety checks; callers of `launch` must uphold its
    /// safety contract.
    Unchecked,
}
/// Spreads `num_cubes` over three dispatch axes so that each axis stays
/// within the per-axis `max`. The count starts entirely on x; whenever an
/// axis overflows, it is repeatedly halved (rounding up) while the next axis
/// doubles. Because of the rounding, the returned product may exceed
/// `num_cubes`.
fn cube_count_spread(max: &(u32, u32, u32), num_cubes: u32) -> [u32; 3] {
    const BASE: u32 = 2;
    let limits = [max.0, max.1, max.2];
    let mut counts = [num_cubes, 1, 1];
    for axis in 0..2 {
        // Stop as soon as an axis is already within its limit.
        if counts[axis] <= limits[axis] {
            break;
        }
        while counts[axis] > limits[axis] {
            counts[axis] = counts[axis].div_ceil(BASE);
            counts[axis + 1] *= BASE;
        }
    }
    counts
}
#[cfg(test)]
mod tests {
    use super::*;
    // When the request divides evenly, the spread fills x and y up to their
    // limits and leaves the remainder on z.
    #[test_log::test]
    fn safe_num_cubes_even() {
        let max = (32, 32, 32);
        let required = 2048;
        let actual = cube_count_spread(&max, required);
        let expected = [32, 32, 2];
        assert_eq!(actual, expected);
    }
    // Odd requests round up via `div_ceil`, so the product may exceed the
    // request (25 * 32 * 4 = 3200 >= 3177).
    #[test_log::test]
    fn safe_num_cubes_odd() {
        let max = (48, 32, 16);
        let required = 3177;
        let actual = cube_count_spread(&max, required);
        let expected = [25, 32, 4];
        assert_eq!(actual, expected);
    }
}