use super::Handle;
use crate::{
client::ComputeClient,
compiler::CompilationError,
config::{GlobalConfig, compilation::BoundsCheckMode},
kernel::KernelMetadata,
logging::ServerLogger,
memory_management::{ManagedMemoryHandle, MemoryAllocationMode, MemoryUsage},
runtime::Runtime,
server::Binding,
storage::{ComputeStorage, ManagedResource},
tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
};
use alloc::boxed::Box;
#[cfg(feature = "profile-tracy")]
use alloc::format;
use alloc::string::String;
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::fmt::Debug;
use cubecl_common::{
backtrace::BackTrace,
bytes::Bytes,
device::{self, DeviceId},
future::DynFut,
profile::ProfileDuration,
stream_id::StreamId,
};
use cubecl_ir::{DeviceProperties, ElemType, StorageType};
use cubecl_zspace::{Shape, Strides, metadata::Metadata};
use thiserror::Error;
/// Errors that can be reported while profiling.
///
/// The `Display` text of every variant is defined by its `#[error(...)]` attribute.
#[derive(Error, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum ProfileError {
/// A profiling failure that fits no other variant.
#[error(
"An unknown error happened during profiling\nCaused by:\n {reason}\nBacktrace:\n{backtrace}"
)]
Unknown {
/// Human-readable description of the failure.
reason: String,
/// Backtrace attached to the error; skipped by serde when `std_io` is set.
#[cfg_attr(std_io, serde(skip))]
backtrace: BackTrace,
},
/// No profiling was registered.
#[error("No profiling registered\nBacktrace:\n{backtrace}")]
NotRegistered {
/// Backtrace attached to the error; skipped by serde when `std_io` is set.
#[cfg_attr(std_io, serde(skip))]
backtrace: BackTrace,
},
/// A [`LaunchError`] surfaced during profiling (convertible with `?` through `From`).
#[error("A launch error happened during profiling\nCaused by:\n {0}")]
Launch(#[from] LaunchError),
/// A boxed [`ServerError`] surfaced during profiling (convertible with `?` through `From`).
#[error("An execution error happened during profiling\nCaused by:\n {0}")]
Server(#[from] Box<ServerError>),
}
/// `Debug` prints exactly the same text as `Display`.
impl core::fmt::Debug for ProfileError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "{self}")
    }
}
/// State bundled alongside a [`ComputeServer`]; servers hand it out behind an `Arc` through
/// `ComputeServer::utilities`.
pub struct ServerUtilities<Server: ComputeServer> {
/// Instant captured when the utilities were created (only with the `profile-tracy` feature).
#[cfg(feature = "profile-tracy")]
pub epoch_time: web_time::Instant,
/// Tracy GPU context created at construction (only with the `profile-tracy` feature).
#[cfg(feature = "profile-tracy")]
pub gpu_client: tracy_client::GpuContext,
/// Properties of the device.
pub properties: DeviceProperties,
/// Checksum of `properties`, computed once at construction.
pub properties_hash: u64,
/// Server-specific information.
pub info: Server::Info,
/// Logger used by the server.
pub logger: Arc<ServerLogger>,
/// Memory layout policy of the server.
pub layout_policy: Server::MemoryLayoutPolicy,
/// Bounds-check mode read from the global compilation configuration at construction.
pub check_mode: BoundsCheckMode,
}
/// Policy that turns [`MemoryLayoutDescriptor`]s into concrete [`MemoryLayout`]s.
pub trait MemoryLayoutPolicy: Send + Sync + 'static {
/// Applies the policy to `descriptors` for the stream `stream_id`.
///
/// Returns a [`Handle`] together with a list of [`MemoryLayout`]s.
/// NOTE(review): presumably the handle is the backing allocation and one layout is produced
/// per descriptor — confirm against the implementors.
fn apply(
&self,
stream_id: StreamId,
descriptors: &[MemoryLayoutDescriptor],
) -> (Handle, Vec<MemoryLayout>);
}
/// `Debug` output lists only `properties`, `info` and `logger`; the other fields are omitted.
impl<Server> core::fmt::Debug for ServerUtilities<Server>
where
    Server: ComputeServer + core::fmt::Debug,
    Server::Info: core::fmt::Debug,
{
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        let Self {
            properties,
            info,
            logger,
            ..
        } = self;

        let mut out = f.debug_struct("ServerUtilities");
        out.field("properties", properties);
        out.field("info", info);
        out.field("logger", logger);
        out.finish()
    }
}
impl<S: ComputeServer> ServerUtilities<S> {
    /// Builds the utilities for a server.
    ///
    /// The checksum of `properties` is cached in `properties_hash` and the bounds-check mode is
    /// read from the global configuration. With the `profile-tracy` feature enabled, a Tracy
    /// client is started and a GPU context named after the `Debug` output of `info` is created.
    ///
    /// # Panics
    ///
    /// With the `profile-tracy` feature, panics if the Tracy GPU context cannot be created.
    pub fn new(
        properties: DeviceProperties,
        logger: Arc<ServerLogger>,
        info: S::Info,
        allocator: S::MemoryLayoutPolicy,
    ) -> Self {
        #[cfg(feature = "profile-tracy")]
        let client = tracy_client::Client::start();

        // Must be computed before `properties` is moved into the struct below.
        let properties_hash = properties.checksum();

        #[cfg(feature = "profile-tracy")]
        let gpu_client = {
            let context_name = format!("{info:?}");
            client
                .clone()
                .new_gpu_context(
                    Some(&context_name),
                    tracy_client::GpuContextType::Invalid,
                    0,
                    1.0,
                )
                .unwrap()
        };
        #[cfg(feature = "profile-tracy")]
        let epoch_time = web_time::Instant::now();

        Self {
            #[cfg(feature = "profile-tracy")]
            epoch_time,
            #[cfg(feature = "profile-tracy")]
            gpu_client,
            properties,
            properties_hash,
            info,
            logger,
            layout_policy: allocator,
            check_mode: GlobalConfig::get().compilation.check_mode,
        }
    }
}
/// Errors that can be reported while launching a kernel.
///
/// The `Display` text of every variant is defined by its `#[error(...)]` attribute.
#[derive(Error, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum LaunchError {
/// The kernel failed to compile (convertible from [`CompilationError`] through `From`).
#[error("A compilation error happened during launch\nCaused by:\n {0}")]
CompilationError(#[from] CompilationError),
/// The launch ran out of memory.
#[error(
"An out-of-memory error happened during launch\nCaused by:\n {reason}\nBacktrace\n{backtrace}"
)]
OutOfMemory {
/// Human-readable description of the failure.
reason: String,
/// Backtrace attached to the error; skipped by serde when `std_io` is set.
#[cfg_attr(std_io, serde(skip))]
backtrace: BackTrace,
},
/// The launch requested more resources than allowed (see [`ResourceLimitError`]).
#[error("Too many resources were requested during launch\n{0}")]
TooManyResources(#[from] ResourceLimitError),
/// A launch failure that fits no other variant.
#[error(
"An unknown error happened during launch\nCaused by:\n {reason}\nBacktrace\n{backtrace}"
)]
Unknown {
/// Human-readable description of the failure.
reason: String,
/// Backtrace attached to the error; skipped by serde when `std_io` is set.
#[cfg_attr(std_io, serde(skip))]
backtrace: BackTrace,
},
/// An [`IoError`] surfaced during launch (convertible through `From`).
#[error("An io error happened during launch\nCaused by:\n {0}")]
IoError(#[from] IoError),
}
/// Describes which resource limit a launch request exceeded.
///
/// Each variant carries the requested amount, the maximum, and a backtrace.
#[derive(Error, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum ResourceLimitError {
/// More shared memory was requested than is available.
#[error(
"Too much shared memory requested.\nRequested {requested} bytes, maximum {max} bytes available.\nBacktrace\n{backtrace}"
)]
SharedMemory {
/// Requested amount, in bytes.
requested: usize,
/// Maximum available amount, in bytes.
max: usize,
/// Backtrace attached to the error; skipped by serde when `std_io` is set.
#[cfg_attr(std_io, serde(skip))]
backtrace: BackTrace,
},
/// The total unit count exceeds the maximum.
#[error(
"Total unit count exceeds maximum.\nRequested {requested} units, max units is {max}.\nBacktrace\n{backtrace}"
)]
Units {
/// Requested number of units.
requested: u32,
/// Maximum number of units.
max: u32,
/// Backtrace attached to the error; skipped by serde when `std_io` is set.
#[cfg_attr(std_io, serde(skip))]
backtrace: BackTrace,
},
/// The cube dimensions exceed the maximum bounds.
#[error(
"Cube dim exceeds maximum bounds.\nRequested {requested:?}, max is {max:?}.\nBacktrace\n{backtrace}"
)]
CubeDim {
/// Requested cube dimensions.
requested: (u32, u32, u32),
/// Maximum cube dimensions.
max: (u32, u32, u32),
/// Backtrace attached to the error; skipped by serde when `std_io` is set.
#[cfg_attr(std_io, serde(skip))]
backtrace: BackTrace,
},
}
/// `Debug` prints exactly the same text as `Display`.
impl core::fmt::Debug for LaunchError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "{self}")
    }
}
/// `Debug` prints exactly the same text as `Display`.
impl core::fmt::Debug for ResourceLimitError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "{self}")
    }
}
#[derive(Error, Debug, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum ServerError {
#[error("An error happened during execution\nCaused by:\n {reason}\nBacktrace:\n{backtrace}")]
Generic {
reason: String,
#[cfg_attr(std_io, serde(skip))]
backtrace: BackTrace,
},
#[error("A launch error happened during profiling\nCaused by:\n {0}")]
Launch(#[from] LaunchError),
#[error("An execution error happened during profiling\nCaused by:\n {0}")]
Profile(#[from] ProfileError),
#[error("An execution error happened during profiling\nCaused by:\n {0}")]
Io(#[from] IoError),
#[error("The server is in an invalid state\nCaused by:\n {errors:?}")]
ServerUnhealthy {
errors: Vec<Self>,
#[cfg_attr(std_io, serde(skip))]
backtrace: BackTrace,
},
}
/// Pair of flags describing how errors on a stream are handled.
///
/// NOTE(review): the consumers of these flags are not visible in this file; the field
/// descriptions below are inferred from the names and should be confirmed.
#[derive(Clone, Copy)]
pub struct StreamErrorMode {
/// Presumably: whether errors should be ignored.
pub ignore: bool,
/// Presumably: whether the stream should be flushed.
pub flush: bool,
}
/// Interface implemented by a compute backend.
///
/// Every operation is issued on behalf of a [`StreamId`]. Implementors must also provide
/// [`ServerCommunication`] and `device::DeviceService`.
pub trait ComputeServer:
    Send + core::fmt::Debug + ServerCommunication + device::DeviceService + 'static
where
    Self: Sized,
{
    /// Kernel type accepted by [`ComputeServer::launch`].
    type Kernel: KernelMetadata;
    /// Server-specific information, stored in [`ServerUtilities::info`].
    type Info: Debug + Send + Sync;
    /// Policy type stored in [`ServerUtilities::layout_policy`].
    type MemoryLayoutPolicy: MemoryLayoutPolicy;
    /// Storage whose `Resource` type is handed out by [`ComputeServer::get_resource`].
    type Storage: ComputeStorage;

    /// Initializes the managed `memory` with the given `size` on the stream `stream_id`.
    fn initialize_memory(&mut self, memory: ManagedMemoryHandle, size: u64, stream_id: StreamId);

    /// Requests staging buffers for the given `_sizes`.
    ///
    /// # Errors
    ///
    /// The default implementation performs no staging and always returns
    /// [`IoError::UnsupportedIoOperation`] converted into a [`ServerError`].
    fn staging(
        &mut self,
        _sizes: &[usize],
        _stream_id: StreamId,
    ) -> Result<Vec<Bytes>, ServerError> {
        let unsupported = IoError::UnsupportedIoOperation {
            backtrace: BackTrace::capture(),
        };
        Err(ServerError::from(unsupported))
    }

    /// Returns the logger of the server.
    fn logger(&self) -> Arc<ServerLogger>;

    /// Returns the utilities of the server.
    fn utilities(&self) -> Arc<ServerUtilities<Self>>;

    /// Reads the data described by `descriptors`; the returned future resolves to the bytes
    /// or to a [`ServerError`].
    fn read(
        &mut self,
        descriptors: Vec<CopyDescriptor>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, ServerError>>;

    /// Writes each `Bytes` payload to the location described by its paired [`CopyDescriptor`].
    fn write(&mut self, descriptors: Vec<(CopyDescriptor, Bytes)>, stream_id: StreamId);

    /// Returns a future for synchronizing the stream `stream_id`.
    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ServerError>>;

    /// Resolves `binding` to the managed resource of the underlying storage.
    fn get_resource(
        &mut self,
        binding: Binding,
        stream_id: StreamId,
    ) -> Result<ManagedResource<<Self::Storage as ComputeStorage>::Resource>, ServerError>;

    /// Launches `kernel` with the cube `count`, the arguments `bindings` and the execution
    /// mode `kind`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the caller-side contract is not stated in this file. Presumably
    /// out-of-bounds accesses are possible depending on `kind` — confirm with implementors.
    unsafe fn launch(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: KernelArguments,
        kind: ExecutionMode,
        stream_id: StreamId,
    );

    /// Flushes the stream `stream_id`.
    fn flush(&mut self, stream_id: StreamId) -> Result<(), ServerError>;

    /// Reports the memory usage of the server.
    fn memory_usage(&mut self, stream_id: StreamId) -> Result<MemoryUsage, ServerError>;

    /// Asks the server to clean up memory.
    fn memory_cleanup(&mut self, stream_id: StreamId);

    /// Starts a profiling session and returns the token that identifies it.
    fn start_profile(&mut self, stream_id: StreamId) -> Result<ProfilingToken, ServerError>;

    /// Ends the profiling session identified by `token` and returns its measured duration.
    fn end_profile(
        &mut self,
        stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError>;

    /// Sets the memory allocation `mode` of the server.
    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId);
}
/// Reduction operator passed to `ServerCommunication::all_reduce`.
pub enum ReduceOperation {
/// Sum reduction.
Sum,
/// Mean reduction.
Mean,
}
/// Server-to-server communication capabilities of a compute server.
///
/// All provided methods are placeholders that panic; servers that set
/// [`ServerCommunication::SERVER_COMM_ENABLED`] are expected to override them.
pub trait ServerCommunication {
    /// Whether the server supports server-to-server communication.
    const SERVER_COMM_ENABLED: bool;

    /// Synchronizes collective operations on the stream `stream_id`.
    ///
    /// # Panics
    ///
    /// The default implementation is a `todo!()` stub and always panics.
    #[allow(unused_variables)]
    fn sync_collective(&mut self, stream_id: StreamId) -> Result<(), ServerError> {
        todo!()
    }

    /// Performs an all-reduce of `src` into `dst` with the operator `op` across `device_ids`.
    ///
    /// # Panics
    ///
    /// The default implementation is an `unimplemented!()` stub and always panics.
    #[allow(unused_variables)]
    fn all_reduce(
        &mut self,
        src: Binding,
        dst: Binding,
        dtype: ElemType,
        stream_id: StreamId,
        op: ReduceOperation,
        device_ids: Vec<DeviceId>,
    ) -> Result<(), ServerError> {
        unimplemented!()
    }

    /// Copies the data described by `src` from `server_src` to `server_dst`.
    ///
    /// # Panics
    ///
    /// The default implementation always panics: with a "not supported" message when
    /// `SERVER_COMM_ENABLED` is `false`, and with an internal-error message otherwise, since a
    /// server that enables communication must override this method.
    #[allow(unused_variables)]
    fn copy(
        handle_dst: Handle,
        server_src: &mut Self,
        server_dst: &mut Self,
        src: CopyDescriptor,
        stream_id_src: StreamId,
        stream_id_dst: StreamId,
    ) -> Result<(), ServerError> {
        if Self::SERVER_COMM_ENABLED {
            panic!(
                "[Internal Error] The `ServerCommunication` trait is incorrectly implemented by the server."
            );
        } else {
            panic!("Server-to-server communication is not supported by this server.");
        }
    }
}
/// Token returned by `ComputeServer::start_profile` and consumed by `ComputeServer::end_profile`.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ProfilingToken {
/// Identifier of the profiling session.
pub id: u64,
}
/// Layout strategy requested in a [`MemoryLayoutDescriptor`].
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum MemoryLayoutStrategy {
/// Request a contiguous layout.
Contiguous,
/// Request an optimized layout.
/// NOTE(review): what "optimized" means is decided by the `MemoryLayoutPolicy` — confirm there.
Optimized,
}
/// Describes a piece of memory whose layout is decided by a [`MemoryLayoutPolicy`].
///
/// The derived `new` constructor takes the fields in declaration order.
#[derive(new, Debug, Clone)]
pub struct MemoryLayoutDescriptor {
/// Requested layout strategy.
pub strategy: MemoryLayoutStrategy,
/// Shape of the data.
pub shape: Shape,
/// Size of one element (presumably in bytes — confirm against callers).
pub elem_size: usize,
}
impl MemoryLayoutDescriptor {
    /// Creates a descriptor that requests the [`MemoryLayoutStrategy::Optimized`] strategy.
    pub fn optimized(shape: Shape, elem_size: usize) -> Self {
        Self {
            strategy: MemoryLayoutStrategy::Optimized,
            shape,
            elem_size,
        }
    }

    /// Creates a descriptor that requests the [`MemoryLayoutStrategy::Contiguous`] strategy.
    pub fn contiguous(shape: Shape, elem_size: usize) -> Self {
        Self {
            strategy: MemoryLayoutStrategy::Contiguous,
            shape,
            elem_size,
        }
    }
}
/// A memory handle paired with the strides used to address it.
#[derive(Debug, Clone)]
pub struct MemoryLayout {
/// Handle to the memory.
pub memory: Handle,
/// Strides associated with the memory.
pub strides: Strides,
}
impl MemoryLayout {
pub fn new(handle: Handle, strides: impl Into<Strides>) -> Self {
MemoryLayout {
memory: handle,
strides: strides.into(),
}
}
}
/// Explanation attached to an error.
///
/// A reason is built from a `&'static str` or from an owned `String` through `From`. The
/// default value carries no text and displays a generic placeholder.
#[derive(Default, Clone)]
pub struct Reason {
    inner: ReasonInner,
}

/// Serde support: a [`Reason`] is serialized as its displayed text and deserialized from a
/// string.
#[cfg(std_io)]
mod _reason_serde {
    use super::*;
    use alloc::string::ToString;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    impl Serialize for Reason {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            let rendered = self.to_string();
            serializer.serialize_str(&rendered)
        }
    }

    impl<'de> Deserialize<'de> for Reason {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let text = String::deserialize(deserializer)?;
            let inner = ReasonInner::Dynamic(Arc::new(text));
            Ok(Reason { inner })
        }
    }
}

/// Storage behind a [`Reason`].
#[derive(Default, Clone)]
enum ReasonInner {
    /// Text borrowed for the whole program lifetime.
    Static(&'static str),
    /// Owned text behind an `Arc`, so cloning does not copy the string.
    Dynamic(Arc<String>),
    /// No text was supplied.
    #[default]
    NotProvided,
}

impl core::fmt::Display for Reason {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let text: &str = match &self.inner {
            ReasonInner::Static(content) => *content,
            ReasonInner::Dynamic(content) => content.as_str(),
            ReasonInner::NotProvided => "No reason provided for the error",
        };
        f.write_str(text)
    }
}

/// `Debug` prints exactly the same text as `Display`.
impl core::fmt::Debug for Reason {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        core::fmt::Display::fmt(self, f)
    }
}

impl From<&'static str> for Reason {
    fn from(value: &'static str) -> Self {
        let inner = ReasonInner::Static(value);
        Self { inner }
    }
}

impl From<String> for Reason {
    fn from(value: String) -> Self {
        let inner = ReasonInner::Dynamic(Arc::new(value));
        Self { inner }
    }
}
#[derive(Error, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum IoError {
#[error("can't allocate buffer of size: {size}\n{backtrace}")]
BufferTooBig {
size: u64,
#[cfg_attr(std_io, serde(skip))]
backtrace: BackTrace,
},
#[error("the provided strides are not supported for this operation\n{backtrace}")]
UnsupportedStrides {
#[cfg_attr(std_io, serde(skip))]
backtrace: BackTrace,
},
#[error("couldn't find resource for that handle: {reason}\n{backtrace}")]
NotFound {
#[cfg_attr(std_io, serde(skip))]
backtrace: BackTrace,
reason: Reason,
},
#[error("couldn't free the handle, since it is currently in used. \n{backtrace}")]
FreeError {
#[cfg_attr(std_io, serde(skip))]
backtrace: BackTrace,
},
#[error("Unknown error happened during execution\n{backtrace}")]
Unknown {
description: String,
#[cfg_attr(std_io, serde(skip))]
backtrace: BackTrace,
},
#[error("The current IO operation is not supported\n{backtrace}")]
UnsupportedIoOperation {
#[cfg_attr(std_io, serde(skip))]
backtrace: BackTrace,
},
#[error("Can't perform the IO operation because of a runtime error: {0}")]
Execution(#[from] Box<ServerError>),
}
/// `Debug` prints exactly the same text as `Display`.
impl core::fmt::Debug for IoError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "{self}")
    }
}
/// Arguments passed to a kernel launch (see `ComputeServer::launch`).
#[derive(Debug, Default)]
pub struct KernelArguments {
/// Buffer bindings.
pub buffers: Vec<Binding>,
/// Metadata passed alongside the buffers.
pub info: MetadataBindingInfo,
/// Tensor map bindings.
pub tensor_maps: Vec<TensorMapBinding>,
}
/// Prints a `KernelArguments` header followed by one line per buffer binding; `info` and
/// `tensor_maps` are not printed.
impl core::fmt::Display for KernelArguments {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("KernelArguments")?;
        self.buffers
            .iter()
            .try_for_each(|buffer| write!(f, "\n - buffer: {buffer:?}\n"))
    }
}
impl KernelArguments {
    /// Creates an empty set of arguments (same as `Default`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Appends a single buffer binding.
    pub fn with_buffer(mut self, binding: Binding) -> Self {
        self.buffers.push(binding);
        self
    }

    /// Appends several buffer bindings, keeping their order.
    pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
        let mut incoming = bindings;
        self.buffers.append(&mut incoming);
        self
    }

    /// Replaces the metadata info.
    pub fn with_info(self, info: MetadataBindingInfo) -> Self {
        Self { info, ..self }
    }

    /// Appends several tensor map bindings, keeping their order.
    pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
        let mut incoming = bindings;
        self.tensor_maps.append(&mut incoming);
        self
    }
}
/// Metadata passed to a kernel as a list of `u64` values.
///
/// The derived `new` constructor takes the fields in declaration order.
#[derive(new, Debug, Default)]
pub struct MetadataBindingInfo {
/// Raw metadata values.
pub data: Vec<u64>,
/// Offset of the dynamic metadata.
/// NOTE(review): presumably an index into `data` — confirm against the consumers.
pub dynamic_metadata_offset: usize,
}
impl MetadataBindingInfo {
    /// Creates metadata from raw `data` with a dynamic metadata offset of `0`.
    pub fn custom(data: Vec<u64>) -> Self {
        Self {
            data,
            dynamic_metadata_offset: 0,
        }
    }
}
/// Describes a region of memory to read from or write to.
///
/// The derived `new` constructor takes the fields in declaration order.
#[derive(new, Debug)]
pub struct CopyDescriptor {
/// Binding of the memory.
pub handle: Binding,
/// Shape of the data.
pub shape: Shape,
/// Strides of the data.
pub strides: Strides,
/// Size of one element (presumably in bytes — confirm against callers).
pub elem_size: usize,
}
/// A binding paired with the tensor map metadata that describes it.
///
/// The derived `new` constructor takes the fields in declaration order.
#[derive(new, Debug)]
pub struct TensorMapBinding {
/// Binding of the memory.
pub binding: Binding,
/// Tensor map metadata.
pub map: TensorMapMeta,
}
/// Metadata describing a tensor map.
///
/// NOTE(review): the exact meaning of each setting is defined by the `tma` module types and
/// is not visible in this file.
#[derive(Debug, Clone)]
pub struct TensorMapMeta {
/// Tensor map format.
pub format: TensorMapFormat,
/// Metadata of the mapped tensor.
pub metadata: Metadata,
/// Element strides.
pub elem_stride: Strides,
/// Interleave setting.
pub interleave: TensorMapInterleave,
/// Swizzle setting.
pub swizzle: TensorMapSwizzle,
/// Prefetch setting.
pub prefetch: TensorMapPrefetch,
/// Out-of-bounds fill setting.
pub oob_fill: OobFill,
/// Storage type of the elements.
pub storage_ty: StorageType,
}
/// Number of cubes to launch along each of the three axes.
#[allow(clippy::large_enum_variant)]
pub enum CubeCount {
/// Count known on the host, as `(x, y, z)`.
Static(u32, u32, u32),
/// Count provided through a binding.
/// NOTE(review): presumably read from device memory at launch time — confirm.
Dynamic(Binding),
}
/// Result of spreading a requested number of cubes over the three launch axes.
pub enum CubeCountSelection {
/// The selected cube count launches exactly the requested number of cubes.
Exact(CubeCount),
/// The selected cube count differs from the request; the `u32` is the number of cubes
/// actually launched.
Approx(CubeCount, u32),
}
impl CubeCountSelection {
    /// Picks a cube count for `num_cubes` cubes that respects the maximum cube count of the
    /// client's hardware.
    ///
    /// Returns [`CubeCountSelection::Exact`] when the product of the three selected dimensions
    /// equals `num_cubes`, and [`CubeCountSelection::Approx`] (carrying that product) otherwise.
    pub fn new<R: Runtime>(client: &ComputeClient<R>, num_cubes: u32) -> Self {
        let [x, y, z] = cube_count_spread(&client.properties().hardware.max_cube_count, num_cubes);
        let num_cubes_actual = x * y * z;
        let cube_count = CubeCount::Static(x, y, z);

        if num_cubes_actual == num_cubes {
            CubeCountSelection::Exact(cube_count)
        } else {
            CubeCountSelection::Approx(cube_count, num_cubes_actual)
        }
    }

    /// Returns `true` when the selection is approximate, i.e. the launched cube count differs
    /// from the requested one.
    pub fn has_idle(&self) -> bool {
        match self {
            Self::Approx(..) => true,
            Self::Exact(..) => false,
        }
    }

    /// Consumes the selection and returns the selected cube count.
    pub fn cube_count(self) -> CubeCount {
        match self {
            CubeCountSelection::Exact(cube_count) | CubeCountSelection::Approx(cube_count, _) => {
                cube_count
            }
        }
    }
}
impl From<CubeCountSelection> for CubeCount {
fn from(value: CubeCountSelection) -> Self {
value.cube_count()
}
}
impl CubeCount {
    /// Creates a static cube count of `(1, 1, 1)`.
    pub fn new_single() -> Self {
        Self::new_3d(1, 1, 1)
    }

    /// Creates a static cube count of `(x, 1, 1)`.
    pub fn new_1d(x: u32) -> Self {
        Self::new_3d(x, 1, 1)
    }

    /// Creates a static cube count of `(x, y, 1)`.
    pub fn new_2d(x: u32, y: u32) -> Self {
        Self::new_3d(x, y, 1)
    }

    /// Creates a static cube count of `(x, y, z)`.
    pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
        CubeCount::Static(x, y, z)
    }

    /// Returns `true` when a static count has a zero dimension.
    ///
    /// A dynamic count is never reported as empty because its value is not known here.
    pub fn is_empty(&self) -> bool {
        match self {
            Self::Static(x, y, z) => [*x, *y, *z].contains(&0),
            Self::Dynamic(_) => false,
        }
    }
}
/// Prints a static count as `(x, y, z)` and a dynamic count as `binding`.
impl Debug for CubeCount {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            CubeCount::Dynamic(_) => f.write_str("binding"),
            CubeCount::Static(x, y, z) => write!(f, "({x}, {y}, {z})"),
        }
    }
}
/// Clones the count; a dynamic count clones its binding.
impl Clone for CubeCount {
    fn clone(&self) -> Self {
        match *self {
            Self::Static(x, y, z) => Self::Static(x, y, z),
            Self::Dynamic(ref binding) => Self::Dynamic(binding.clone()),
        }
    }
}
/// Size of a cube along the three axes; `num_elems` is the product of the components.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
#[allow(missing_docs)]
pub struct CubeDim {
/// Size along the x axis.
pub x: u32,
/// Size along the y axis.
pub y: u32,
/// Size along the z axis.
pub z: u32,
}
impl CubeDim {
    /// Picks a two-dimensional cube size for `working_units` units of work.
    ///
    /// `x` is the maximum plane size of the hardware. `y` is the plane count computed by
    /// `calculate_plane_count_per_cube`, limited to `max_units_per_cube / plane_size` and
    /// never less than one.
    ///
    /// `working_units` values above `u32::MAX` are saturated to `u32::MAX` rather than
    /// truncated, so a huge workload is never mistaken for a tiny one.
    ///
    /// # Panics
    ///
    /// Panics if the hardware reports a maximum plane size of zero (division by zero).
    pub fn new<R: Runtime>(client: &ComputeClient<R>, working_units: usize) -> Self {
        let properties = client.properties();
        let plane_size = properties.hardware.plane_size_max;
        // `as u32` would keep only the low 32 bits (2^32 would become 0).
        let working_units = u32::try_from(working_units).unwrap_or(u32::MAX);
        let plane_count = Self::calculate_plane_count_per_cube(
            working_units,
            plane_size,
            properties.hardware.num_cpu_cores,
        );
        let limit = properties.hardware.max_units_per_cube / plane_size;
        Self::new_2d(plane_size, u32::min(limit, plane_count).max(1))
    }

    /// Computes how many planes a cube should contain.
    ///
    /// When a CPU core count is known, the result is the smaller of the core count and
    /// `working_units`. Otherwise it is the largest power of two that does not exceed
    /// `max(1, working_units / plane_dim)`, capped at 8.
    fn calculate_plane_count_per_cube(
        working_units: u32,
        plane_dim: u32,
        num_cpu_cores: Option<u32>,
    ) -> u32 {
        match num_cpu_cores {
            Some(num_cores) => core::cmp::min(num_cores, working_units),
            None => {
                let plane_count_max = core::cmp::max(1, working_units / plane_dim);
                const NUM_PLANE_MAX: u32 = 8u32;
                const NUM_PLANE_MAX_LOG2: u32 = NUM_PLANE_MAX.ilog2();
                // Rounding down to a power of two is done in log2 space; `plane_count_max` is
                // at least one, so `ilog2` cannot panic.
                let plane_count_max_log2 =
                    core::cmp::min(NUM_PLANE_MAX_LOG2, u32::ilog2(plane_count_max));
                2u32.pow(plane_count_max_log2)
            }
        }
    }

    /// Creates a cube dimension of `(1, 1, 1)`.
    pub const fn new_single() -> Self {
        Self { x: 1, y: 1, z: 1 }
    }

    /// Creates a cube dimension of `(x, 1, 1)`.
    pub const fn new_1d(x: u32) -> Self {
        Self { x, y: 1, z: 1 }
    }

    /// Creates a cube dimension of `(x, y, 1)`.
    pub const fn new_2d(x: u32, y: u32) -> Self {
        Self { x, y, z: 1 }
    }

    /// Creates a cube dimension of `(x, y, z)`.
    pub const fn new_3d(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// Returns the product `x * y * z`.
    pub const fn num_elems(&self) -> u32 {
        self.x * self.y * self.z
    }

    /// Returns `true` when every component of `self` is at least the matching component of
    /// `other`.
    pub const fn can_contain(&self, other: CubeDim) -> bool {
        self.x >= other.x && self.y >= other.y && self.z >= other.z
    }
}
impl From<(u32, u32, u32)> for CubeDim {
fn from(value: (u32, u32, u32)) -> Self {
CubeDim::new_3d(value.0, value.1, value.2)
}
}
impl From<CubeDim> for (u32, u32, u32) {
fn from(val: CubeDim) -> Self {
(val.x, val.y, val.z)
}
}
/// Execution mode passed to `ComputeServer::launch`.
///
/// NOTE(review): the exact semantics of each mode are implemented outside this file; the
/// descriptions below are inferred from the names and should be confirmed.
#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum ExecutionMode {
/// The default mode; presumably runs with safety checks enabled.
#[default]
Checked,
/// Presumably runs with validation enabled.
Validate,
/// Presumably runs without safety checks.
Unchecked,
}
/// Spreads `num_cubes` over three axes so that the x and y counts fit within `max`.
///
/// Starting from `[num_cubes, 1, 1]`, an axis that exceeds its limit is repeatedly halved
/// (rounding up) while the next axis is doubled, so the product never drops below
/// `num_cubes` but may exceed it. Processing stops at the first axis that already fits.
/// The z axis is never reduced and is therefore not checked against `max.2`.
///
/// Degenerate inputs are handled without panicking or looping forever:
/// - a limit of `0` is treated as `1`, since no positive count can satisfy it;
/// - doubling saturates at `u32::MAX` instead of overflowing.
fn cube_count_spread(max: &(u32, u32, u32), num_cubes: u32) -> [u32; 3] {
    const BASE: u32 = 2;

    // Only x and y can spill into a following axis. Clamping to one guarantees that the
    // halving loop below terminates, because ceil-halving converges to one.
    let limits = [max.0.max(1), max.1.max(1)];
    let mut counts = [num_cubes, 1, 1];

    for (axis, &limit) in limits.iter().enumerate() {
        if counts[axis] <= limit {
            // This axis fits, so nothing spills into the remaining axes.
            break;
        }
        while counts[axis] > limit {
            counts[axis] = counts[axis].div_ceil(BASE);
            counts[axis + 1] = counts[axis + 1].saturating_mul(BASE);
        }
    }

    counts
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A power-of-two request spills evenly from x into y and then into z.
    #[test_log::test]
    fn safe_num_cubes_even() {
        let spread = cube_count_spread(&(32, 32, 32), 2048);
        assert_eq!(spread, [32, 32, 2]);
    }

    /// An odd request is rounded up while halving, so the product exceeds the request.
    #[test_log::test]
    fn safe_num_cubes_odd() {
        let spread = cube_count_spread(&(48, 32, 16), 3177);
        assert_eq!(spread, [25, 32, 4]);
    }
}