use crate::{
CudaCompiler,
compute::{
MB, context::CudaContext, io::controller::PinnedMemoryManagedAllocController,
storage::gpu::GpuResource, stream::CudaStreamBackend, sync::Fence,
},
};
use cubecl_common::{
backtrace::BackTrace,
bytes::{AllocationProperty, Bytes},
stream_id::StreamId,
};
#[cfg(debug_assertions)]
use cubecl_core::zspace::striding::try_check_pitched_row_major_strides;
use cubecl_core::{
MemoryUsage,
future::DynFut,
server::{
Binding, CopyDescriptor, ExecutionMode, Handle, IoError, LaunchError, ProfileError,
ServerError,
},
zspace::{Shape, Strides, striding::has_pitched_row_major_strides},
};
use cubecl_runtime::{
compiler::CubeTask,
id::KernelId,
logging::ServerLogger,
memory_management::{ManagedMemoryHandle, MemoryAllocationMode, MemoryHandle},
stream::ResolvedStreams,
};
use cudarc::driver::sys::{
CUDA_MEMCPY2D_st, CUmemorytype, CUstream_st, CUtensorMap, cuMemcpy2DAsync_v2,
};
use std::{ffi::c_void, ops::DerefMut, sync::Arc};
/// Short-lived handle that bundles the CUDA context with the set of streams
/// resolved for the current server operation. All memory, copy, sync and
/// kernel-launch commands are issued through this type.
#[derive(new)]
pub struct Command<'a> {
    // Device-side context: owns compiled kernel modules (`module_names`,
    // `compile_kernel`, `execute_task`) and profiling state (`timestamps`).
    ctx: &'a mut CudaContext,
    // Streams resolved for this command: the current stream plus any origin
    // streams referenced by bindings (see `ResolvedStreams::get`).
    pub(crate) streams: ResolvedStreams<'a, CudaStreamBackend>,
}
impl<'a> Command<'a> {
    /// Resolves a binding into a concrete GPU resource, looked up on the
    /// stream that owns the binding (not necessarily the current stream).
    pub fn resource(&mut self, binding: Binding) -> Result<GpuResource, IoError> {
        self.streams
            .get(&binding.stream)
            .memory_management_gpu
            .get_resource(binding.memory, binding.offset_start, binding.offset_end)
    }

    /// Current position of the resolved-streams cursor.
    pub fn cursor(&self) -> u64 {
        self.streams.cursor
    }

    /// GPU memory usage statistics for the current stream.
    pub fn memory_usage(&mut self) -> MemoryUsage {
        self.streams.current().memory_management_gpu.memory_usage()
    }

    /// Forces a cleanup pass on the current stream's GPU memory pools.
    pub fn memory_cleanup(&mut self) {
        self.streams.current().memory_management_gpu.cleanup(true)
    }

    /// Switches the allocation strategy of the current stream's GPU allocator.
    pub fn allocation_mode(&mut self, mode: MemoryAllocationMode) {
        self.streams.current().memory_management_gpu.mode(mode)
    }

    /// Reserves `size` bytes of GPU memory on the current stream.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
    pub fn reserve(&mut self, size: u64) -> Result<ManagedMemoryHandle, IoError> {
        let handle = self.streams.current().memory_management_gpu.reserve(size)?;
        Ok(handle)
    }

    /// Allocates an uninitialized GPU buffer of `size` bytes and returns a
    /// handle bound to the current stream.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
    pub fn empty(&mut self, size: u64) -> Result<Handle, IoError> {
        let handle = Handle::new(self.streams.current, size);
        let reserved = self.reserve(size)?;
        self.bind(reserved, handle.memory.clone());
        Ok(handle)
    }

    /// Binds a reserved GPU allocation to a new managed handle at the current
    /// cursor position.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
    pub fn bind(&mut self, reserved: ManagedMemoryHandle, new: ManagedMemoryHandle) {
        let cursor = self.cursor();
        // A bind failure here is an invariant violation (the reservation was
        // just made on this same stream), so panicking is intentional.
        self.streams
            .current()
            .memory_management_gpu
            .bind(reserved, new, cursor)
            .unwrap();
    }

    /// Reserves a CPU-side buffer of `size` bytes, preferring pinned (page-locked)
    /// memory when available, falling back to a plain zeroed `Vec` otherwise.
    ///
    /// Allocations larger than 100 MB that were not explicitly marked as pinned
    /// skip the pinned pool entirely — presumably to avoid exhausting the
    /// limited page-locked memory budget (TODO confirm rationale).
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
    pub fn reserve_cpu(
        &mut self,
        size: usize,
        marked_pinned: bool,
        origin: Option<StreamId>,
    ) -> Bytes {
        if !marked_pinned && size > 100 * MB {
            return Bytes::from_bytes_vec(vec![0; size]);
        }
        self.reserve_pinned(size, origin)
            .unwrap_or_else(|| Bytes::from_bytes_vec(vec![0; size]))
    }

    /// Tries to reserve pinned host memory on the given stream (or the current
    /// one). Returns `None` when the reservation or resource lookup fails.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
    fn reserve_pinned(&mut self, size: usize, origin: Option<StreamId>) -> Option<Bytes> {
        let stream = match origin {
            Some(id) => self.streams.get(&id),
            None => self.streams.current(),
        };
        let handle = stream.memory_management_cpu.reserve(size as u64).ok()?;
        let binding = MemoryHandle::binding(handle);
        let resource = stream
            .memory_management_cpu
            .get_resource(binding.clone(), None, None)
            .ok()?;
        let controller = Box::new(PinnedMemoryManagedAllocController::init(binding, resource));
        // SAFETY: the controller owns the pinned allocation backing the bytes.
        Some(unsafe { Bytes::from_controller(controller, size) })
    }

    /// Schedules device-to-host reads for all descriptors on the current
    /// stream and returns a future that resolves once a fence recorded after
    /// the copies has been signaled.
    pub fn read_async(
        &mut self,
        descriptors: Vec<CopyDescriptor>,
    ) -> impl Future<Output = Result<Vec<Bytes>, ServerError>> + Send + use<> {
        // Keep the bindings alive until the fence signals, so the GPU memory
        // they reference cannot be reclaimed while the copies are in flight.
        let descriptors_moved = descriptors
            .iter()
            .map(|b| b.handle.clone())
            .collect::<Vec<_>>();
        let result = self.copies_to_bytes(descriptors, true);
        let fence = Fence::new(self.streams.current().sys);
        async move {
            let sync = fence.wait_sync();
            core::mem::drop(descriptors_moved);
            sync?;
            let bytes = result?;
            Ok(bytes)
        }
    }

    /// Like [`Self::read_async`], but issues each copy on the stream that owns
    /// its descriptor's handle, fencing each involved stream once.
    #[allow(unused)]
    pub fn read_async_origin(
        &mut self,
        descriptors: Vec<CopyDescriptor>,
    ) -> impl Future<Output = Result<Vec<Bytes>, IoError>> + Send + use<> {
        let results = self.copies_to_bytes_origin(descriptors, true);
        async move {
            let (bytes, fences) = results?;
            for fence in fences {
                // NOTE(review): the `wait_sync` result is discarded here; a
                // sync failure would go unnoticed. Propagating it needs a
                // ServerError -> IoError conversion — confirm and fix upstream.
                fence.wait_sync();
            }
            Ok(bytes)
        }
    }

    /// Copies every descriptor to host bytes on the current stream.
    fn copies_to_bytes(
        &mut self,
        descriptors: Vec<CopyDescriptor>,
        pinned: bool,
    ) -> Result<Vec<Bytes>, IoError> {
        let mut result = Vec::with_capacity(descriptors.len());
        for descriptor in descriptors {
            result.push(self.copy_to_bytes(descriptor, pinned, None)?);
        }
        Ok(result)
    }

    /// Copies every descriptor to host bytes on its origin stream, creating at
    /// most one fence per distinct stream involved.
    fn copies_to_bytes_origin(
        &mut self,
        descriptors: Vec<CopyDescriptor>,
        pinned: bool,
    ) -> Result<(Vec<Bytes>, Vec<Fence>), IoError> {
        let mut data = Vec::with_capacity(descriptors.len());
        let mut fences = Vec::with_capacity(descriptors.len());
        let mut fenced = Vec::with_capacity(descriptors.len());
        for descriptor in descriptors {
            let stream = descriptor.handle.stream;
            let bytes = self.copy_to_bytes(descriptor, pinned, Some(stream))?;
            if !fenced.contains(&stream) {
                let fence = Fence::new(self.streams.get(&stream).sys);
                fenced.push(stream);
                fences.push(fence);
            }
            data.push(bytes);
        }
        Ok((data, fences))
    }

    /// Allocates a dense host buffer sized from the descriptor's shape and
    /// element size, then schedules the device-to-host copy into it.
    pub fn copy_to_bytes(
        &mut self,
        descriptor: CopyDescriptor,
        pinned: bool,
        stream_id: Option<StreamId>,
    ) -> Result<Bytes, IoError> {
        // The host buffer is densely packed regardless of the device pitch.
        let num_bytes = descriptor.shape.iter().product::<usize>() * descriptor.elem_size;
        let mut bytes = self.reserve_cpu(num_bytes, pinned, stream_id);
        self.write_to_cpu(descriptor, &mut bytes, stream_id)?;
        Ok(bytes)
    }

    /// Schedules an async device-to-host copy described by `descriptor` into
    /// `bytes`. Only pitched row-major strides are supported.
    pub fn write_to_cpu(
        &mut self,
        descriptor: CopyDescriptor,
        bytes: &mut Bytes,
        stream_id: Option<StreamId>,
    ) -> Result<(), IoError> {
        let CopyDescriptor {
            handle: binding,
            shape,
            strides,
            elem_size,
        } = descriptor;
        if !has_pitched_row_major_strides(&shape, &strides) {
            return Err(IoError::UnsupportedStrides {
                backtrace: BackTrace::capture(),
            });
        }
        let resource = self.resource(binding)?;
        let stream = match stream_id {
            Some(id) => self.streams.get(&id),
            None => self.streams.current(),
        };
        // SAFETY: `resource.ptr` is a valid device pointer for the binding and
        // `bytes` is large enough for the dense copy (sized by the caller).
        unsafe {
            write_to_cpu(&shape, &strides, elem_size, bytes, resource.ptr, stream.sys)?;
        }
        Ok(())
    }

    /// Records a server error on the current stream.
    pub fn error(&mut self, error: ServerError) {
        let stream = self.streams.current();
        stream.errors.push(error);
    }

    /// Schedules an async host-to-device copy of `data` into the buffer
    /// referenced by the descriptor's handle. File-backed bytes are staged
    /// through pinned memory first; the source bytes are parked in the drop
    /// queue so they outlive the async copy.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip(self, descriptor, data))
    )]
    pub fn write_to_gpu(&mut self, descriptor: CopyDescriptor, data: Bytes) -> Result<(), IoError> {
        let CopyDescriptor {
            handle,
            shape,
            strides,
            elem_size,
        } = descriptor;
        if !has_pitched_row_major_strides(&shape, &strides) {
            return Err(IoError::UnsupportedStrides {
                backtrace: BackTrace::capture(),
            });
        }
        let resource = self.resource(handle)?;
        let size = data.len();
        let data = match data.property() {
            AllocationProperty::File => {
                // File-backed bytes can't feed an async HtoD copy directly;
                // stage through pinned memory. Surface a reservation failure
                // as an IoError (same handling as `create_with_data`) instead
                // of panicking.
                let mut buffer =
                    self.reserve_pinned(size, None)
                        .ok_or_else(|| IoError::Unknown {
                            backtrace: BackTrace::capture(),
                            description: "Unable to reserve pinned memory".into(),
                        })?;
                data.copy_into(&mut buffer);
                buffer
            }
            _ => data,
        };
        let current = self.streams.current();
        // SAFETY: `resource.ptr` is a valid device pointer and `data` holds at
        // least the dense byte count implied by `shape` and `elem_size`.
        unsafe {
            write_to_gpu(
                &shape,
                &strides,
                elem_size,
                &data,
                resource.ptr,
                current.sys,
            )
        }?;
        // Defer freeing until the stream has consumed the host buffer.
        current.drop_queue.push(data);
        Ok(())
    }

    /// Allocates a GPU buffer and fills it with `data`, staging the bytes
    /// through pinned host memory for the async upload.
    pub fn create_with_data(&mut self, data: &[u8]) -> Result<Handle, IoError> {
        let mut staging =
            self.reserve_pinned(data.len(), None)
                .ok_or_else(|| IoError::Unknown {
                    backtrace: BackTrace::capture(),
                    description: "Unable to reserve pinned memory".into(),
                })?;
        staging.copy_from_slice(data);
        let handle = self.empty(staging.len() as u64)?;
        // Upload as a flat 1D buffer of bytes.
        self.write_to_gpu(
            CopyDescriptor {
                handle: handle.clone().binding(),
                shape: [data.len()].into(),
                strides: [1].into(),
                elem_size: 1,
            },
            staging,
        )?;
        Ok(handle)
    }

    /// Returns a future that resolves once all work currently queued on the
    /// current stream has completed (fence-based).
    pub fn sync(&mut self) -> DynFut<Result<(), ServerError>> {
        let fence = Fence::new(self.streams.current().sys);
        Box::pin(async { fence.wait_sync() })
    }

    /// Launches a kernel on the current stream, compiling it first if this
    /// `kernel_id` has not been seen before.
    ///
    /// When profiling is active (`timestamps` non-empty), launch errors are
    /// recorded as profile errors instead of being returned, so the profile
    /// session stays consistent.
    #[allow(clippy::too_many_arguments)]
    pub fn kernel(
        &mut self,
        kernel_id: KernelId,
        kernel: Box<dyn CubeTask<CudaCompiler>>,
        mode: ExecutionMode,
        dispatch_count: (u32, u32, u32),
        tensor_maps: &[CUtensorMap],
        resources: &[GpuResource],
        const_info: Option<*mut c_void>,
        logger: Arc<ServerLogger>,
    ) -> Result<(), LaunchError> {
        // Lazy compilation: only compile when the module is not cached yet.
        if !self.ctx.module_names.contains_key(&kernel_id) {
            self.ctx.compile_kernel(&kernel_id, kernel, mode, logger)?;
        }
        let stream = self.streams.current();
        let result = self.ctx.execute_task(
            stream,
            kernel_id,
            dispatch_count,
            tensor_maps,
            resources,
            const_info,
        );
        // Opportunistically release deferred host buffers once the queue is
        // full enough, gated behind a fresh fence on this stream.
        if stream.drop_queue.should_flush() {
            stream.drop_queue.flush(|| Fence::new(stream.sys));
        }
        if let Err(err) = result {
            match self.ctx.timestamps.is_empty() {
                true => return Err(err),
                false => self.ctx.timestamps.error(ProfileError::Launch(err)),
            }
        };
        Ok(())
    }
}
#[cfg_attr(
    feature = "tracing",
    tracing::instrument(level = "trace", skip(strides, data, dst_ptr, stream))
)]
/// Schedules an async host-to-device copy of `data` into `dst_ptr`.
///
/// Rank 0/1 tensors take a flat `memcpy_htod`; higher ranks collapse all outer
/// dimensions into rows and use a 2D copy where the host side is densely
/// packed and the device side may be pitched.
///
/// # Safety
/// `dst_ptr` must be a valid device pointer with room for the copy, and
/// `data` must stay alive until the copy on `stream` completes.
pub(crate) unsafe fn write_to_gpu(
    shape: &Shape,
    strides: &Strides,
    elem_size: usize,
    data: &[u8],
    dst_ptr: u64,
    stream: *mut CUstream_st,
) -> Result<(), IoError> {
    // Debug-only layout validation: strides must be pitched row-major.
    #[cfg(debug_assertions)]
    try_check_pitched_row_major_strides(shape, strides).map_err(|e| IoError::Unknown {
        description: format!("write_to_gpu: invalid strides: {e}"),
        backtrace: BackTrace::capture(),
    })?;

    let rank = shape.len();

    // Scalars and 1D buffers: a single flat copy suffices.
    if rank <= 1 {
        return unsafe {
            cudarc::driver::result::memcpy_htod_async(dst_ptr, data, stream).map_err(|e| {
                IoError::Unknown {
                    description: format!("CUDA memcpy_htod failed: {e}"),
                    backtrace: BackTrace::capture(),
                }
            })
        };
    }

    // Row width comes from the innermost dimension; all outer dimensions
    // collapse into the row count. The device pitch is the stride of the
    // second-to-last dimension, in bytes.
    let row_bytes = shape[rank - 1] * elem_size;
    let row_count: usize = shape[..rank - 1].iter().product();
    let device_pitch = strides[rank - 2] * elem_size;

    let descriptor = CUDA_MEMCPY2D_st {
        srcMemoryType: CUmemorytype::CU_MEMORYTYPE_HOST,
        srcHost: data.as_ptr() as *const c_void,
        srcPitch: row_bytes,
        srcXInBytes: Default::default(),
        srcY: Default::default(),
        srcDevice: Default::default(),
        srcArray: Default::default(),
        dstMemoryType: CUmemorytype::CU_MEMORYTYPE_DEVICE,
        dstDevice: dst_ptr,
        dstPitch: device_pitch,
        dstXInBytes: Default::default(),
        dstY: Default::default(),
        dstHost: Default::default(),
        dstArray: Default::default(),
        WidthInBytes: row_bytes,
        Height: row_count,
    };
    unsafe {
        cuMemcpy2DAsync_v2(&descriptor, stream)
            .result()
            .map_err(|e| IoError::Unknown {
                description: format!("CUDA memcpy failed: {e}"),
                backtrace: BackTrace::capture(),
            })
    }
}
#[cfg_attr(
    feature = "tracing",
    tracing::instrument(level = "trace", skip(strides, bytes, resource_ptr, stream))
)]
/// Schedules an async device-to-host copy from `resource_ptr` into `bytes`.
///
/// Rank 0/1 tensors take a flat `memcpy_dtoh`; higher ranks collapse all outer
/// dimensions into rows and use a 2D copy where the device side may be pitched
/// while the host destination is densely packed.
///
/// # Safety
/// `resource_ptr` must be a valid device pointer covering the copy, and
/// `bytes` must be large enough and stay alive until the copy on `stream`
/// completes.
pub(crate) unsafe fn write_to_cpu(
    shape: &Shape,
    strides: &Strides,
    elem_size: usize,
    bytes: &mut Bytes,
    resource_ptr: u64,
    stream: *mut CUstream_st,
) -> Result<(), IoError> {
    // Debug-only layout validation: strides must be pitched row-major.
    #[cfg(debug_assertions)]
    try_check_pitched_row_major_strides(shape, strides).map_err(|e| IoError::Unknown {
        description: format!("write_to_cpu: invalid strides: {e}"),
        backtrace: BackTrace::capture(),
    })?;

    let rank = shape.len();
    let bytes = bytes.deref_mut();

    // Scalars and 1D buffers: a single flat copy suffices.
    if rank <= 1 {
        return unsafe {
            cudarc::driver::result::memcpy_dtoh_async(bytes, resource_ptr, stream).map_err(|e| {
                IoError::Unknown {
                    description: format!("CUDA memcpy_dtoh failed: {e}"),
                    backtrace: BackTrace::capture(),
                }
            })
        };
    }

    // Row width comes from the innermost dimension; all outer dimensions
    // collapse into the row count. The device pitch is the stride of the
    // second-to-last dimension, in bytes.
    let row_bytes = shape[rank - 1] * elem_size;
    let row_count: usize = shape[..rank - 1].iter().product();
    let device_pitch = strides[rank - 2] * elem_size;

    let descriptor = CUDA_MEMCPY2D_st {
        srcMemoryType: CUmemorytype::CU_MEMORYTYPE_DEVICE,
        srcDevice: resource_ptr,
        srcPitch: device_pitch,
        srcXInBytes: Default::default(),
        srcY: Default::default(),
        srcHost: Default::default(),
        srcArray: Default::default(),
        dstMemoryType: CUmemorytype::CU_MEMORYTYPE_HOST,
        dstHost: bytes.as_mut_ptr() as *mut c_void,
        dstPitch: row_bytes,
        dstXInBytes: Default::default(),
        dstY: Default::default(),
        dstDevice: Default::default(),
        dstArray: Default::default(),
        WidthInBytes: row_bytes,
        Height: row_count,
    };
    unsafe {
        cuMemcpy2DAsync_v2(&descriptor, stream)
            .result()
            .map_err(|e| IoError::Unknown {
                description: format!("CUDA 2D memcpy failed: {e}"),
                backtrace: BackTrace::capture(),
            })
    }
}