use cubecl_core::{MemoryConfiguration, ir::MemoryDeviceProperties, server::ExecutionError};
use cubecl_hip_sys::HIP_SUCCESS;
use cubecl_runtime::{
logging::ServerLogger,
memory_management::{MemoryAllocationMode, MemoryManagement, MemoryManagementOptions},
stream::EventStreamBackend,
};
use std::sync::Arc;
use crate::compute::{
cpu::{PINNED_MEMORY_ALIGNMENT, PinnedMemoryStorage},
fence::Fence,
gpu::GpuStorage,
};
/// A HIP stream handle bundled with the memory pools that were created alongside it.
///
/// NOTE(review): no `Drop` impl / `hipStreamDestroy` call for `sys` is visible in this file;
/// confirm the stream's lifetime is managed elsewhere.
#[derive(Debug)]
pub struct Stream {
    /// Raw HIP stream handle, obtained from `hipStreamCreate` when the stream is built.
    pub(crate) sys: cubecl_hip_sys::hipStream_t,
    /// Memory pool backed by [`GpuStorage`] (registered as "Main GPU Memory").
    pub memory_management_gpu: MemoryManagement<GpuStorage>,
    /// Memory pool backed by [`PinnedMemoryStorage`] (registered as "Pinned CPU Memory").
    pub memory_management_cpu: MemoryManagement<PinnedMemoryStorage>,
}
/// [`EventStreamBackend`] implementation for HIP.
///
/// Holds the settings used to build every new [`Stream`] together with its memory pools.
/// The derived `new` constructor takes the fields in declaration order.
#[derive(new, Debug)]
pub struct HipStreamBackend {
    /// Device memory properties handed to the GPU pool; `max_page_size` is also reused
    /// for the pinned-CPU pool.
    mem_props: MemoryDeviceProperties,
    /// Memory configuration, cloned into both the GPU and the pinned-CPU pool.
    mem_config: MemoryConfiguration,
    /// Alignment passed to `GpuStorage::new` for the GPU pool's storage.
    mem_alignment: usize,
    /// Logger shared (through `Arc` clones) by every memory pool built by this backend.
    logger: Arc<ServerLogger>,
}
impl EventStreamBackend for HipStreamBackend {
    type Stream = Stream;
    type Event = Fence;

    /// Creates a fresh HIP stream and gives it its own GPU and pinned-CPU memory pools.
    ///
    /// The GPU pool uses this backend's device properties and alignment. The pinned-CPU
    /// pool reuses the device's `max_page_size` but aligns to `PINNED_MEMORY_ALIGNMENT`,
    /// and it is created in `MemoryAllocationMode::Auto`.
    ///
    /// # Panics
    ///
    /// Panics if `hipStreamCreate` does not return `HIP_SUCCESS`.
    fn create_stream(&self) -> Self::Stream {
        let mut sys: cubecl_hip_sys::hipStream_t = std::ptr::null_mut();
        // SAFETY: `sys` is a live, writable local that serves as the out-parameter for the
        // call. The returned status is checked immediately below, before the handle is used.
        let status = unsafe { cubecl_hip_sys::hipStreamCreate(&mut sys) };
        assert_eq!(status, HIP_SUCCESS, "Should create a stream");

        // The pinned host pool shares the device page size but uses its own alignment.
        let pinned_props = MemoryDeviceProperties {
            max_page_size: self.mem_props.max_page_size,
            alignment: PINNED_MEMORY_ALIGNMENT as u64,
        };

        // Struct-literal fields are evaluated in source order, so the GPU pool is still
        // built before the CPU pool.
        Stream {
            sys,
            memory_management_gpu: MemoryManagement::from_configuration(
                GpuStorage::new(self.mem_alignment),
                &self.mem_props,
                self.mem_config.clone(),
                self.logger.clone(),
                MemoryManagementOptions::new("Main GPU Memory"),
            ),
            memory_management_cpu: MemoryManagement::from_configuration(
                PinnedMemoryStorage::new(),
                &pinned_props,
                self.mem_config.clone(),
                self.logger.clone(),
                MemoryManagementOptions::new("Pinned CPU Memory")
                    .mode(MemoryAllocationMode::Auto),
            ),
        }
    }

    /// Returns a [`Fence`] constructed from this stream's raw HIP handle.
    fn flush(stream: &mut Self::Stream) -> Self::Event {
        let handle = stream.sys;
        Fence::new(handle)
    }

    /// Hands `stream`'s handle to `Fence::wait_async`, consuming the event.
    fn wait_event(stream: &mut Self::Stream, event: Self::Event) {
        let handle = stream.sys;
        event.wait_async(handle);
    }

    /// Blocks on the event through `Fence::wait_sync` and forwards its result unchanged.
    fn wait_event_sync(event: Self::Event) -> Result<(), ExecutionError> {
        let outcome = event.wait_sync();
        outcome
    }
}