use crate::{
compute::{
MB, context::HipContext, fence::Fence, gpu::GpuResource,
io::controller::PinnedMemoryManagedAllocController, stream::HipStreamBackend,
},
runtime::HipCompiler,
};
use cubecl_common::{backtrace::BackTrace, bytes::Bytes, stream_id::StreamId};
use cubecl_core::{
MemoryUsage,
bytes::AllocationProperty,
future::DynFut,
server::{
Binding, CopyDescriptor, ExecutionMode, Handle, IoError, LaunchError, ProfileError,
ServerError,
},
zspace::{Shape, Strides, striding::has_pitched_row_major_strides},
};
use cubecl_hip_sys::{
HIP_SUCCESS, hipMemcpyKind_hipMemcpyDeviceToHost, hipMemcpyKind_hipMemcpyHostToDevice,
ihipStream_t,
};
use cubecl_runtime::{
compiler::CubeTask,
id::KernelId,
logging::ServerLogger,
memory_management::{ManagedMemoryHandle, MemoryAllocationMode, MemoryHandle},
stream::ResolvedStreams,
};
use std::{ffi::c_void, sync::Arc};
/// Short-lived recorder for server operations (allocation, transfers, kernel
/// launches) executed against a set of resolved HIP streams.
///
/// Constructed via the `new` derive: `Command::new(ctx, streams)`.
#[derive(new)]
pub struct Command<'a> {
    /// HIP context holding the compiled kernel modules (`module_names`),
    /// profiling timestamps, and task execution entry points.
    ctx: &'a mut HipContext,
    /// Streams resolved for this command; `current()` is the stream new work
    /// is enqueued on, `get(&id)` resolves work targeting another stream.
    pub(crate) streams: ResolvedStreams<'a, HipStreamBackend>,
}
impl<'a> Command<'a> {
    /// Resolves a GPU [`Binding`] into a concrete [`GpuResource`] (device
    /// pointer + size) via the memory manager of the stream that owns the
    /// binding (`binding.stream`), honoring the binding's offset window.
    pub fn resource(&mut self, binding: Binding) -> Result<GpuResource, IoError> {
        self.streams
            .get(&binding.stream)
            .memory_management_gpu
            .get_resource(binding.memory, binding.offset_start, binding.offset_end)
    }

    /// Reports the memory usage of the current stream's GPU memory pool.
    pub fn memory_usage(&mut self) -> MemoryUsage {
        self.streams.current().memory_management_gpu.memory_usage()
    }

    /// Triggers a cleanup pass on the current stream's GPU memory pool.
    // NOTE(review): the `true` flag presumably forces an immediate/exhaustive
    // cleanup — confirm against the memory-management API.
    pub fn memory_cleanup(&mut self) {
        self.streams.current().memory_management_gpu.cleanup(true)
    }

    /// Switches the allocation mode of the current stream's GPU memory pool.
    pub fn allocation_mode(&mut self, mode: MemoryAllocationMode) {
        self.streams.current().memory_management_gpu.mode(mode)
    }

    /// Reserves `size` bytes of device memory on the current stream.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
    pub fn reserve(&mut self, size: u64) -> Result<ManagedMemoryHandle, IoError> {
        let handle = self.streams.current().memory_management_gpu.reserve(size)?;
        Ok(handle)
    }

    /// Current position of the multi-stream cursor, used to order bindings
    /// across streams (see [`Self::bind`]).
    pub fn cursor(&self) -> u64 {
        self.streams.cursor
    }

    /// Allocates an uninitialized device buffer of `size` bytes and returns a
    /// [`Handle`] to it, tied to the current stream.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
    pub fn empty(&mut self, size: u64) -> Result<Handle, IoError> {
        // `self.streams.current` here is the stream id field (not the
        // `current()` accessor): the handle records which stream owns it.
        let handle = Handle::new(self.streams.current, size);
        let reserved = self.reserve(size)?;
        self.bind(reserved, handle.memory.clone());
        Ok(handle)
    }

    /// Binds freshly `reserved` backing memory to the logical handle `new`,
    /// stamped with the current cursor so the memory manager can track the
    /// ordering of the binding relative to other streams.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
    pub fn bind(&mut self, reserved: ManagedMemoryHandle, new: ManagedMemoryHandle) {
        let cursor = self.cursor();
        self.streams
            .current()
            .memory_management_gpu
            .bind(reserved, new, cursor)
            .unwrap();
    }

    /// Allocates a host-side buffer of `size` bytes.
    ///
    /// Pinned (page-locked) memory is preferred since it enables fast async
    /// DMA transfers, but large buffers (> 100 MB) that were not explicitly
    /// marked as pinned fall back to a plain heap allocation, as does any
    /// failed pinned reservation.
    pub fn reserve_cpu(
        &mut self,
        size: usize,
        marked_pinned: bool,
        origin: Option<StreamId>,
    ) -> Bytes {
        if !marked_pinned && size > 100 * MB {
            return Bytes::from_bytes_vec(vec![0; size]);
        }
        self.reserve_pinned(size, origin)
            .unwrap_or_else(|| Bytes::from_bytes_vec(vec![0; size]))
    }

    /// Tries to allocate `size` bytes of pinned host memory from the CPU
    /// memory pool of `origin` (or the current stream), wrapping it in a
    /// [`Bytes`] whose controller returns the allocation to the pool on drop.
    /// Returns `None` when the reservation or resource lookup fails.
    fn reserve_pinned(&mut self, size: usize, origin: Option<StreamId>) -> Option<Bytes> {
        let stream = match origin {
            Some(id) => self.streams.get(&id),
            None => self.streams.current(),
        };
        let handle = stream.memory_management_cpu.reserve(size as u64).ok()?;
        let binding = MemoryHandle::binding(handle);
        let resource = stream
            .memory_management_cpu
            .get_resource(binding.clone(), None, None)
            .ok()?;
        let controller = Box::new(PinnedMemoryManagedAllocController::init(binding, resource));
        // SAFETY-relevant: the controller ties the Bytes lifetime to the
        // pinned allocation backing it.
        Some(unsafe { Bytes::from_controller(controller, size) })
    }

    /// Enqueues device-to-host copies for `descriptors` and returns a future
    /// resolving to the copied bytes once the GPU work completes.
    pub fn read_async(
        &mut self,
        descriptors: Vec<CopyDescriptor>,
    ) -> impl Future<Output = Result<Vec<Bytes>, ServerError>> + Send + use<> {
        // Clone the handles so the device buffers stay bound while the async
        // copies are in flight; they are dropped only after the fence signals.
        let descriptors_moved = descriptors
            .iter()
            .map(|b| b.handle.clone())
            .collect::<Vec<_>>();
        let result = self.copies_to_bytes(descriptors, true);
        // Fence recorded after the copies so waiting on it guarantees the
        // copies have completed.
        let fence = Fence::new(self.streams.current().sys);
        async move {
            let sync = fence.wait_sync();
            core::mem::drop(descriptors_moved);
            sync?;
            let bytes = result?;
            Ok(bytes)
        }
    }

    /// Copies each descriptor into a freshly allocated host buffer
    /// (pinned when `pinned` is set), preserving order.
    fn copies_to_bytes(
        &mut self,
        descriptors: Vec<CopyDescriptor>,
        pinned: bool,
    ) -> Result<Vec<Bytes>, IoError> {
        let mut result = Vec::with_capacity(descriptors.len());
        for descriptor in descriptors {
            result.push(self.copy_to_bytes(descriptor, pinned, None)?);
        }
        Ok(result)
    }

    /// Allocates a host buffer sized from the descriptor's shape and element
    /// size, then enqueues the device-to-host copy into it.
    fn copy_to_bytes(
        &mut self,
        descriptor: CopyDescriptor,
        pinned: bool,
        stream_id: Option<StreamId>,
    ) -> Result<Bytes, IoError> {
        let num_bytes = descriptor.shape.iter().product::<usize>() * descriptor.elem_size;
        let mut bytes = self.reserve_cpu(num_bytes, pinned, stream_id);
        self.write_to_cpu(descriptor, &mut bytes, stream_id)?;
        Ok(bytes)
    }

    /// Enqueues an async device-to-host copy of the region described by
    /// `descriptor` into `bytes`, on `stream_id` (or the current stream).
    ///
    /// # Errors
    /// [`IoError::UnsupportedStrides`] when the strides are not pitched
    /// row-major; any resource-resolution or copy failure otherwise.
    pub fn write_to_cpu(
        &mut self,
        descriptor: CopyDescriptor,
        bytes: &mut Bytes,
        stream_id: Option<StreamId>,
    ) -> Result<(), IoError> {
        let CopyDescriptor {
            handle: binding,
            shape,
            strides,
            elem_size,
        } = descriptor;
        if !has_pitched_row_major_strides(&shape, &strides) {
            return Err(IoError::UnsupportedStrides {
                backtrace: BackTrace::capture(),
            });
        }
        let resource = self.resource(binding)?;
        let stream = match stream_id {
            Some(id) => self.streams.get(&id),
            None => self.streams.current(),
        };
        // SAFETY: `resource.ptr` is valid device memory for the binding, and
        // the caller keeps `bytes` alive until the stream is synchronized.
        unsafe { write_to_cpu(&shape, &strides, elem_size, bytes, resource.ptr, stream.sys) }
    }

    /// Enqueues an async host-to-device copy of `data` into the buffer
    /// described by `descriptor`, on the current stream.
    ///
    /// File-backed [`Bytes`] are first staged through a pinned host buffer.
    /// The host buffer is pushed to the stream's drop queue so it outlives
    /// the asynchronous copy.
    pub fn write_to_gpu(&mut self, descriptor: CopyDescriptor, data: Bytes) -> Result<(), IoError> {
        let CopyDescriptor {
            handle: binding,
            shape,
            strides,
            elem_size,
        } = descriptor;
        if !has_pitched_row_major_strides(&shape, &strides) {
            return Err(IoError::UnsupportedStrides {
                backtrace: BackTrace::capture(),
            });
        }
        let resource = self.resource(binding)?;
        let size = data.len();
        let data = match data.property() {
            AllocationProperty::File => {
                // NOTE(review): `unwrap` panics if pinned reservation fails
                // for file-backed data — confirm this is acceptable here.
                let mut buffer = self.reserve_pinned(size, None).unwrap();
                data.copy_into(&mut buffer);
                buffer
            }
            _ => data,
        };
        let current = self.streams.current();
        unsafe {
            write_to_gpu(resource, &shape, &strides, elem_size, &data, current.sys)?;
        };
        // Keep the host buffer alive until the async copy has completed.
        current.drop_queue.push(data);
        Ok(())
    }

    /// Allocates a device buffer and uploads `data` into it through a pinned
    /// staging buffer, returning the handle.
    pub fn create_with_data(&mut self, data: &[u8]) -> Result<Handle, IoError> {
        let mut staging =
            self.reserve_pinned(data.len(), None)
                .ok_or_else(|| IoError::Unknown {
                    backtrace: BackTrace::capture(),
                    description: "Unable to reserve pinned memory".into(),
                })?;
        staging.copy_from_slice(data);
        let handle = self.empty(staging.len() as u64)?;
        // Treat the payload as a flat 1D byte buffer.
        self.write_to_gpu(
            CopyDescriptor {
                handle: handle.clone().binding(),
                shape: [data.len()].into(),
                strides: [1].into(),
                elem_size: 1,
            },
            staging,
        )?;
        Ok(handle)
    }

    /// Returns a future that completes when all work currently enqueued on
    /// the current stream has finished.
    pub fn sync(&mut self) -> DynFut<Result<(), ServerError>> {
        let fence = Fence::new(self.streams.current().sys);
        Box::pin(async { fence.wait_sync() })
    }

    /// Launches a kernel on the current stream, compiling it first if this
    /// `kernel_id` has not been seen before.
    ///
    /// Launch errors are routed to the profiler's timestamp error slot when a
    /// profile is active (timestamps non-empty), otherwise returned directly.
    pub fn kernel(
        &mut self,
        kernel_id: KernelId,
        kernel: Box<dyn CubeTask<HipCompiler>>,
        mode: ExecutionMode,
        dispatch_count: (u32, u32, u32),
        resources: &[GpuResource],
        logger: Arc<ServerLogger>,
    ) -> Result<(), LaunchError> {
        if !self.ctx.module_names.contains_key(&kernel_id) {
            self.ctx.compile_kernel(&kernel_id, kernel, mode, logger)?;
        }
        let stream = self.streams.current();
        let result = self
            .ctx
            .execute_task(stream, kernel_id, dispatch_count, resources);
        // Opportunistically release retired host buffers; the fence is only
        // created if the queue actually flushes.
        if stream.drop_queue.should_flush() {
            stream.drop_queue.flush(|| Fence::new(stream.sys));
        }
        if let Err(err) = result {
            match self.ctx.timestamps.is_empty() {
                true => Err(err)?,
                false => self.ctx.timestamps.error(ProfileError::Launch(err)),
            }
        };
        Ok(())
    }

    /// Records a server error on the current stream.
    pub fn error(&mut self, error: ServerError) {
        let stream = self.streams.current();
        stream.errors.push(error);
    }
}
/// Enqueues an asynchronous device-to-host copy from `resource_ptr` into
/// `bytes` on `stream`.
///
/// Rank 0/1 shapes use a plain contiguous copy of `bytes.len()` bytes.
/// Higher ranks use a pitched 2D copy where the row width is
/// `shape[rank - 1] * elem_size`, all outer dimensions collapse into the row
/// count, and the device pitch comes from `strides[rank - 2]` (callers have
/// already validated pitched row-major strides).
///
/// # Errors
/// Returns [`IoError::Unknown`] when HIP fails to enqueue the copy.
///
/// # Safety
/// `resource_ptr` must point to valid device memory covering the described
/// region, `stream` must be a valid HIP stream, and `bytes` must stay alive
/// until the stream is synchronized (callers wait on a fence).
pub(crate) unsafe fn write_to_cpu(
    shape: &[usize],
    strides: &[usize],
    elem_size: usize,
    bytes: &mut Bytes,
    resource_ptr: *mut c_void,
    stream: *mut ihipStream_t,
) -> Result<(), IoError> {
    // Map a HIP status to the function's error type (keeps all copy paths
    // consistent instead of panicking on some of them).
    fn check(status: u32, what: &str) -> Result<(), IoError> {
        if status == HIP_SUCCESS {
            Ok(())
        } else {
            Err(IoError::Unknown {
                description: format!("{what} failed: {status}"),
                backtrace: BackTrace::capture(),
            })
        }
    }
    let rank = shape.len();
    if rank <= 1 {
        // Contiguous buffer: the total size is simply `bytes.len()`.
        let status = unsafe {
            cubecl_hip_sys::hipMemcpyDtoHAsync(
                bytes.as_mut_ptr() as *mut _,
                resource_ptr,
                bytes.len(),
                stream,
            )
        };
        return check(status, "HIP memcpy");
    }
    let dim_x = shape[rank - 1];
    let width_bytes = dim_x * elem_size;
    // All outer dimensions collapse into the 2D copy's row count.
    let dim_y: usize = shape.iter().rev().skip(1).product();
    let pitch = strides[rank - 2] * elem_size;
    unsafe {
        let status = cubecl_hip_sys::hipMemcpy2DAsync(
            bytes.as_mut_ptr() as *mut _,
            width_bytes,
            resource_ptr,
            pitch,
            width_bytes,
            dim_y,
            hipMemcpyKind_hipMemcpyDeviceToHost,
            stream,
        );
        if status != HIP_SUCCESS {
            // Fallback: retry as a contiguous copy. NOTE(review): this assumes
            // the device buffer is tightly packed (pitch == width_bytes) —
            // confirm for padded pitches.
            let status = cubecl_hip_sys::hipMemcpyDtoHAsync(
                bytes.as_mut_ptr() as *mut _,
                resource_ptr,
                bytes.len(),
                stream,
            );
            // BUG FIX: this path previously panicked via `assert_eq!` with a
            // misleading "Should send data to device" message even though the
            // function returns `Result`; surface the failure as an error.
            return check(status, "HIP fallback memcpy");
        }
    }
    Ok(())
}
unsafe fn write_to_gpu(
resource: GpuResource,
shape: &Shape,
strides: &Strides,
elem_size: usize,
data: &[u8],
stream: *mut ihipStream_t,
) -> Result<(), IoError> {
let rank = shape.len();
if !has_pitched_row_major_strides(shape, strides) {
return Err(IoError::UnsupportedStrides {
backtrace: BackTrace::capture(),
});
}
let ptr = data as *const _ as *mut _;
if rank > 1 {
let stride = strides[rank - 2];
let width = *shape.last().unwrap_or(&1);
let height: usize = shape.iter().rev().skip(1).product();
let width_bytes = width * elem_size;
let stride_bytes = stride * elem_size;
unsafe {
let status = cubecl_hip_sys::hipMemcpy2DAsync(
resource.ptr,
stride_bytes,
ptr,
width_bytes,
width_bytes,
height.max(1),
hipMemcpyKind_hipMemcpyHostToDevice,
stream,
);
assert_eq!(status, HIP_SUCCESS, "Should send data to device");
}
} else {
unsafe {
assert!(resource.size >= data.len() as u64);
let status = cubecl_hip_sys::hipMemcpyHtoDAsync(resource.ptr, ptr, data.len(), stream);
assert_eq!(status, HIP_SUCCESS, "Should send data to device");
}
};
Ok(())
}