use super::DummyKernel;
use crate::dummy::DummyCompiler;
use cubecl_common::{bytes::Bytes, future::DynFut, profile::ProfileDuration, stream_id::StreamId};
use cubecl_ir::{
DeviceProperties, ElemType, HardwareProperties, MemoryDeviceProperties, StorageType, UIntKind,
VectorSize, features::Features,
};
use cubecl_runtime::{
allocator::ContiguousMemoryLayoutPolicy,
compiler::{CompilationError, CubeTask},
id::KernelId,
kernel::{CompiledKernel, KernelMetadata},
logging::ServerLogger,
memory_management::{ManagedMemoryHandle, MemoryAllocationMode, MemoryManagement, MemoryUsage},
server::{
Binding, ComputeServer, CopyDescriptor, CubeCount, CubeDim, ExecutionMode, Handle,
KernelArguments, ProfileError, ProfilingToken, ServerCommunication, ServerError,
ServerUtilities,
},
storage::{BytesResource, BytesStorage, ComputeStorage, ManagedResource},
timestamp_profiler::TimestampProfiler,
};
use cubecl_zspace::{Shape, Strides};
use std::sync::Arc;
/// Dummy [`ComputeServer`] whose memory lives in host [`BytesStorage`] and whose kernels are
/// executed synchronously on the calling thread (see `launch`, which calls the kernel's
/// `compute` directly).
#[derive(Debug)]
pub struct DummyServer {
    // Owns every allocation; reads, writes and launches resolve their buffers through it.
    memory_management: MemoryManagement<BytesStorage>,
    // Backs `start_profile` / `end_profile`.
    timestamps: TimestampProfiler,
    // Device properties, logger and memory-layout policy built in `DummyServer::new`.
    utilities: Arc<ServerUtilities<Self>>,
}
/// Type-erased wrapper that lets any [`DummyKernel`] be compiled and launched as a
/// [`CubeTask`] by the dummy server.
#[derive(Debug, Clone)]
pub struct KernelTask {
    // Shared so that `compile` can cheaply clone the whole task into `CompiledKernel::repr`.
    kernel: Arc<dyn DummyKernel>,
}
impl KernelMetadata for KernelTask {
fn name(&self) -> &'static str {
self.kernel.name()
}
fn id(&self) -> KernelId {
self.kernel.id()
}
fn address_type(&self) -> cubecl_ir::StorageType {
ElemType::UInt(UIntKind::U32).into()
}
}
/// Fixed, human-readable label used whenever a [`KernelTask`] is printed.
impl core::fmt::Display for KernelTask {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Dummy kernel")
    }
}
impl CubeTask<DummyCompiler> for KernelTask {
fn compile(
&self,
_compiler: &mut DummyCompiler,
_compilation_options: &<DummyCompiler as cubecl_runtime::compiler::Compiler>::CompilationOptions,
_mode: ExecutionMode,
_addr_type: StorageType,
) -> Result<cubecl_runtime::kernel::CompiledKernel<DummyCompiler>, CompilationError> {
Ok(CompiledKernel {
entrypoint_name: self.kernel.name().to_string(),
debug_name: None,
source: String::new(),
repr: Some(self.clone()),
cube_dim: CubeDim::new_single(),
debug_info: None,
})
}
}
impl KernelTask {
pub fn new(kernel: impl DummyKernel) -> Self {
Self {
kernel: Arc::new(kernel),
}
}
pub fn compute(&self, resources: &mut [&mut BytesResource]) {
self.kernel.compute(resources);
}
}
/// Server-to-server communication is turned off for the dummy backend.
impl ServerCommunication for DummyServer {
    const SERVER_COMM_ENABLED: bool = false;
}
impl ComputeServer for DummyServer {
    type Kernel = Box<dyn CubeTask<DummyCompiler>>;
    type Storage = BytesStorage;
    type MemoryLayoutPolicy = ContiguousMemoryLayoutPolicy;
    type Info = ();

    /// Logger shared through the server utilities.
    fn logger(&self) -> Arc<ServerLogger> {
        self.utilities.logger.clone()
    }

    /// Shared utilities built in `DummyServer::new`.
    fn utilities(&self) -> Arc<cubecl_runtime::server::ServerUtilities<Self>> {
        self.utilities.clone()
    }

    /// Reserves `size` bytes and binds them to `memory` at offset 0.
    ///
    /// # Panics
    /// Panics if the reservation or the bind fails.
    fn initialize_memory(&mut self, memory: ManagedMemoryHandle, size: u64, _stream_id: StreamId) {
        let reserved = self
            .memory_management
            .reserve(size)
            .expect("dummy server: failed to reserve memory");
        self.memory_management
            .bind(reserved, memory.clone(), 0)
            .expect("dummy server: failed to bind reserved memory");
    }

    /// Reads back the bytes covered by each descriptor, in order.
    ///
    /// Resources are resolved immediately (using each handle's offsets). The returned future
    /// copies their contents out when polled. If any descriptor cannot be resolved, the
    /// returned future yields that error instead of panicking.
    fn read(
        &mut self,
        descriptors: Vec<CopyDescriptor>,
        _stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, ServerError>> {
        let pending = descriptors
            .into_iter()
            .map(|descriptor| -> Result<_, ServerError> {
                let handle = descriptor.handle;
                let size = handle.size_in_used();
                let resource = self.memory_management.get_resource(
                    handle.memory.clone(),
                    handle.offset_start,
                    handle.offset_end,
                )?;
                // NOTE(review): the binding is carried into the future alongside its resource,
                // presumably to keep the allocation referenced until the copy runs — confirm.
                Ok((resource, size, handle.memory))
            })
            .collect::<Result<Vec<_>, ServerError>>();

        let pending = match pending {
            Ok(pending) => pending,
            Err(err) => return Box::pin(async move { Err(err) }),
        };

        Box::pin(async move {
            Ok(pending
                .into_iter()
                .map(|(resource, size, _memory)| {
                    Bytes::from_bytes_vec(resource.read()[0..size as usize].to_vec())
                })
                .collect())
        })
    }

    /// Copies each payload into the memory covered by its descriptor.
    ///
    /// The destination is resolved with the descriptor handle's `offset_start`/`offset_end`,
    /// exactly like `read`, `get_resource` and `launch` do. Writes to a sub-range of an
    /// allocation therefore land at that sub-range rather than at the start of the allocation.
    ///
    /// # Panics
    /// Panics if the destination cannot be resolved or is shorter than the payload.
    fn write(&mut self, descriptors: Vec<(CopyDescriptor, Bytes)>, _stream_id: StreamId) {
        for (descriptor, data) in descriptors {
            let handle = descriptor.handle;
            let mut resource = self
                .memory_management
                .get_resource(handle.memory, handle.offset_start, handle.offset_end)
                .expect("dummy server: failed to resolve write destination");
            resource.write()[..data.len()].copy_from_slice(&data);
        }
    }

    /// No-op: every launch already runs to completion inside `launch`.
    fn sync(&mut self, _stream_id: StreamId) -> DynFut<Result<(), ServerError>> {
        Box::pin(async move { Ok(()) })
    }

    /// Resolves `binding` (including its offsets) to the underlying byte resource.
    fn get_resource(
        &mut self,
        binding: Binding,
        _stream_id: StreamId,
    ) -> Result<ManagedResource<BytesResource>, ServerError> {
        let resource = self.memory_management.get_resource(
            binding.memory.clone(),
            binding.offset_start,
            binding.offset_end,
        )?;
        Ok(ManagedResource::new(binding.memory, resource))
    }

    /// Compiles `kernel` with the dummy compiler and runs it synchronously on the host.
    ///
    /// The bound buffers are passed to the kernel in argument order. The `info` words are
    /// reinterpreted as bytes, uploaded into a freshly allocated buffer and appended as the
    /// last resource. `_count` is ignored.
    ///
    /// # Panics
    /// Panics if a buffer cannot be resolved, compilation fails, or the compiled kernel carries
    /// no host representation in `repr`.
    unsafe fn launch(
        &mut self,
        kernel: Self::Kernel,
        _count: CubeCount,
        bindings: KernelArguments,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) {
        let mut resources: Vec<_> = bindings
            .buffers
            .into_iter()
            .map(|b| {
                self.memory_management
                    .get_resource(b.memory, b.offset_start, b.offset_end)
                    .expect("dummy server: failed to resolve kernel buffer")
            })
            .collect();

        // Upload the info words into their own buffer so the kernel can read them as a resource.
        let data = bytemuck::cast_slice(&bindings.info.data);
        let metadata = Handle::new(stream_id, data.len() as u64);
        self.bind_with_data(data, metadata.clone(), stream_id);
        let metadata_resource = self
            .memory_management
            .get_resource(
                metadata.memory.binding(),
                metadata.offset_start,
                metadata.offset_end,
            )
            .expect("dummy server: failed to resolve metadata buffer");
        resources.push(metadata_resource);

        let mut resources: Vec<_> = resources.iter_mut().collect();
        let address_type = kernel.address_type();
        let compiled = kernel
            .compile(&mut DummyCompiler, &(), mode, address_type)
            .expect("dummy server: kernel compilation failed");
        compiled
            .repr
            .expect("dummy server: compiled kernel has no host representation")
            .compute(resources.as_mut_slice());
    }

    /// No-op: there is no queued work to flush.
    fn flush(&mut self, _stream_id: StreamId) -> Result<(), ServerError> {
        Ok(())
    }

    /// Usage statistics reported by the memory manager.
    fn memory_usage(&mut self, _stream_id: StreamId) -> Result<MemoryUsage, ServerError> {
        Ok(self.memory_management.memory_usage())
    }

    /// Asks the memory manager to clean up.
    // NOTE(review): the `true` flag presumably requests an explicit/forced cleanup — confirm.
    fn memory_cleanup(&mut self, _stream_id: StreamId) {
        self.memory_management.cleanup(true);
    }

    /// Starts a timestamp-based profile.
    fn start_profile(&mut self, _stream_id: StreamId) -> Result<ProfilingToken, ServerError> {
        Ok(self.timestamps.start())
    }

    /// Stops the profile identified by `token` and returns its duration.
    fn end_profile(
        &mut self,
        _stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError> {
        self.timestamps.stop(token)
    }

    /// Forwards the allocation mode to the memory manager.
    fn allocation_mode(&mut self, mode: MemoryAllocationMode, _stream_id: StreamId) {
        self.memory_management.mode(mode)
    }
}
impl DummyServer {
pub fn new(
memory_management: MemoryManagement<BytesStorage>,
mem_props: MemoryDeviceProperties,
) -> Self {
let hardware = HardwareProperties {
load_width: 128,
plane_size_min: 32,
plane_size_max: 32,
max_bindings: 32,
max_shared_memory_size: 48000,
max_cube_count: (u16::MAX as u32, u16::MAX as u32, u16::MAX as u32),
max_units_per_cube: 1024,
max_cube_dim: (1024, 1024, 64),
num_streaming_multiprocessors: None,
num_tensor_cores: None,
min_tensor_cores_dim: None,
num_cpu_cores: None,
max_vector_size: VectorSize::MAX,
};
let features = Features::default();
let timing_method = cubecl_common::profile::TimingMethod::System;
let props = DeviceProperties::new(features, mem_props, hardware, timing_method);
let logger = Arc::new(ServerLogger::default());
let utilities = Arc::new(ServerUtilities::new(
props,
logger,
(),
ContiguousMemoryLayoutPolicy::new(4),
));
Self {
memory_management,
utilities,
timestamps: TimestampProfiler::default(),
}
}
fn bind_with_data(&mut self, data: &[u8], handle: Handle, stream_id: StreamId) {
let strides: Strides = [1].into();
let shape: Shape = [data.len()].into();
self.initialize_memory(handle.memory.clone(), handle.size(), stream_id);
self.write(
vec![(
CopyDescriptor::new(handle.binding(), shape, strides, 1),
Bytes::from_bytes_vec(data.to_vec()),
)],
stream_id,
);
}
}