use crate::error::{CoreError, CoreResult};
use crate::gpu::GpuContext;
use std::any::TypeId;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
/// Errors raised by the cross-device memory layer.
///
/// Each variant converts into `CoreError` through the `From` impl that follows
/// this enum, which wraps the variant's `Display` text in
/// `CoreError::ComputationError`.
#[derive(Debug, thiserror::Error)]
pub enum CrossDeviceError {
/// No device matches the request. The payload is free-form text: usually the
/// `Debug` form of a `DeviceType`, but also messages such as
/// "No default device set" or "GPU memory info unavailable".
#[error("Device not found: {0}")]
DeviceNotFound(String),
/// A device could not provide the requested memory; `device` names the
/// device and `reason` explains the failure.
#[error("Memory allocation failed on device {device}: {reason}")]
AllocationFailed { device: String, reason: String },
/// Copying data from one device to another failed.
#[error("Data transfer failed from {from} to {to}: {reason}")]
TransferFailed {
from: String,
to: String,
reason: String,
},
/// A device synchronization call failed (not constructed in this module).
#[error("Device synchronization failed: {0}")]
SynchronizationFailed(String),
/// An invalid device type was supplied.
/// NOTE(review): `CrossDeviceBuffer::copy_from_host` also uses this variant to
/// report a data-length mismatch; a dedicated variant may be clearer.
#[error("Invalid device type: {0}")]
InvalidDeviceType(String),
/// A memory allocation could not be found (not constructed in this module).
#[error("Memory allocation not found: {0}")]
MemoryNotFound(String),
}
impl From<CrossDeviceError> for CoreError {
fn from(err: CrossDeviceError) -> Self {
CoreError::ComputationError(crate::error::ErrorContext::new(err.to_string()))
}
}
/// Identifies a compute device that can own memory.
///
/// Every variant except [`DeviceType::Cpu`] carries a numeric device index.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// The host CPU (always reports index 0).
    Cpu,
    /// CUDA GPU with the given index.
    CudaGpu(u32),
    /// ROCm GPU with the given index.
    RocmGpu(u32),
    /// Intel GPU with the given index.
    IntelGpu(u32),
    /// Metal GPU with the given index.
    MetalGpu(u32),
    /// TPU with the given index.
    Tpu(u32),
    /// OpenCL device with the given index.
    OpenClDevice(u32),
}

impl DeviceType {
    /// Upper-case label for the device family; the index is not included.
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Cpu => "CPU",
            Self::CudaGpu(..) => "CUDA_GPU",
            Self::RocmGpu(..) => "ROCM_GPU",
            Self::IntelGpu(..) => "INTEL_GPU",
            Self::MetalGpu(..) => "METAL_GPU",
            Self::Tpu(..) => "TPU",
            Self::OpenClDevice(..) => "OPENCL",
        }
    }

    /// Index carried by the variant, or `0` for [`DeviceType::Cpu`].
    pub fn device_id(&self) -> u32 {
        match *self {
            Self::Cpu => 0,
            Self::CudaGpu(index)
            | Self::RocmGpu(index)
            | Self::IntelGpu(index)
            | Self::MetalGpu(index)
            | Self::Tpu(index)
            | Self::OpenClDevice(index) => index,
        }
    }

    /// `true` only for CUDA and ROCm devices.
    pub fn supports_unified_memory(&self) -> bool {
        match self {
            Self::CudaGpu(_) | Self::RocmGpu(_) => true,
            Self::Cpu
            | Self::IntelGpu(_)
            | Self::MetalGpu(_)
            | Self::Tpu(_)
            | Self::OpenClDevice(_) => false,
        }
    }

    /// `true` when both endpoints are CUDA devices or both are ROCm devices
    /// (indices are not compared, so a device pairs with itself too).
    pub fn supports_p2p_transfer(&self, other: &DeviceType) -> bool {
        let both_cuda = matches!(self, Self::CudaGpu(_)) && matches!(other, Self::CudaGpu(_));
        let both_rocm = matches!(self, Self::RocmGpu(_)) && matches!(other, Self::RocmGpu(_));
        both_cuda || both_rocm
    }
}
/// Book-keeping record for one allocation handed out by the memory manager.
#[derive(Debug, Clone)]
pub struct MemoryAllocation {
    /// Manager-assigned identifier.
    pub id: String,
    /// Device that owns the memory.
    pub device: DeviceType,
    /// Size of the allocation in bytes.
    pub size: usize,
    /// Device-specific address returned by the device's `allocate`.
    pub address: usize,
    /// `TypeId` of the element type the allocation was created for.
    pub datatype: TypeId,
    /// Instant the record was created.
    pub created_at: std::time::Instant,
    /// Instant of the most recent `touch` (initially equal to `created_at`).
    pub last_accessed: std::time::Instant,
    /// Number of live handles sharing this allocation.
    pub ref_count: usize,
}

impl MemoryAllocation {
    /// Builds a record with a reference count of one; `created_at` and
    /// `last_accessed` are stamped with the same instant.
    pub fn new(
        allocation_id: String,
        device: DeviceType,
        size: usize,
        address: usize,
        datatype: TypeId,
    ) -> Self {
        let timestamp = std::time::Instant::now();
        Self {
            ref_count: 1,
            created_at: timestamp,
            last_accessed: timestamp,
            datatype,
            address,
            size,
            device,
            id: allocation_id,
        }
    }

    /// Refreshes `last_accessed` to the current instant.
    pub fn touch(&mut self) {
        let now = std::time::Instant::now();
        self.last_accessed = now;
    }

    /// Records one more handle sharing this allocation.
    pub fn add_ref(&mut self) {
        self.ref_count = self.ref_count + 1;
    }

    /// Releases one handle and returns the remaining count, never going below zero.
    pub fn remove_ref(&mut self) -> usize {
        if self.ref_count > 0 {
            self.ref_count -= 1;
        }
        self.ref_count
    }
}
/// Backend abstraction over a memory-owning compute device.
///
/// Addresses are opaque `usize` handles produced by `allocate`; what they mean
/// is backend-specific (the `CpuDevice` backend in this module returns raw
/// host pointers).
pub trait Device: Send + Sync {
/// Key under which the manager registers this device.
fn device_type(&self) -> DeviceType;
/// Allocates `size` bytes and returns a device address for them.
fn allocate(&self, size: usize) -> CoreResult<usize>;
/// Releases memory previously returned by `allocate`.
fn deallocate(&self, address: usize) -> CoreResult<()>;
/// Copies `size` bytes from host memory at `src` into device memory at `dst`.
///
/// # Safety
/// `src` must be valid for reads of `size` bytes, and `dst` must address at
/// least `size` bytes allocated on this device. The regions must not overlap
/// (the CPU backend uses `copy_nonoverlapping`).
unsafe fn copy_from_host(&self, src: *const u8, dst: usize, size: usize) -> CoreResult<()>;
/// Copies `size` bytes from device memory at `src` into host memory at `dst`.
///
/// # Safety
/// `dst` must be valid for writes of `size` bytes, and `src` must address at
/// least `size` bytes allocated on this device. The regions must not overlap.
unsafe fn copy_to_host(&self, src: usize, dst: *mut u8, size: usize) -> CoreResult<()>;
/// Copies `size` bytes from `src` on this device to `dst` on `dst_device`
/// without staging through the host. Backends without peer support return an
/// error (the CPU backend always does).
fn copy_peer(
&self,
src: usize,
dst_device: &dyn Device,
dst: usize,
size: usize,
) -> CoreResult<()>;
/// Synchronizes with the device. The semantics are backend-defined; the CPU
/// backend is a no-op.
fn synchronize(&self) -> CoreResult<()>;
/// Free device memory in bytes.
fn available_memory(&self) -> CoreResult<usize>;
/// Total device memory in bytes.
fn total_memory(&self) -> CoreResult<usize>;
}
pub struct CpuDevice {
device_type: DeviceType,
}
impl CpuDevice {
pub fn new() -> Self {
Self {
device_type: DeviceType::Cpu,
}
}
}
impl Default for CpuDevice {
fn default() -> Self {
Self::new()
}
}
impl Device for CpuDevice {
fn device_type(&self) -> DeviceType {
self.device_type.clone()
}
fn allocate(&self, size: usize) -> CoreResult<usize> {
let layout = std::alloc::Layout::from_size_align(size, 64).map_err(|e| {
CrossDeviceError::AllocationFailed {
device: "CPU".to_string(),
reason: e.to_string(),
}
})?;
unsafe {
let ptr = std::alloc::alloc(layout);
if ptr.is_null() {
Err(CrossDeviceError::AllocationFailed {
device: "CPU".to_string(),
reason: "Out of memory".to_string(),
}
.into())
} else {
Ok(ptr as usize)
}
}
}
fn deallocate(&self, address: usize) -> CoreResult<()> {
let _ = address;
Ok(())
}
unsafe fn copy_from_host(&self, src: *const u8, dst: usize, size: usize) -> CoreResult<()> {
std::ptr::copy_nonoverlapping(src, dst as *mut u8, size);
Ok(())
}
unsafe fn copy_to_host(&self, src: usize, dst: *mut u8, size: usize) -> CoreResult<()> {
std::ptr::copy_nonoverlapping(src as *const u8, dst, size);
Ok(())
}
fn copy_peer(
&self,
src: usize,
_dst_device: &dyn Device,
_dst: usize,
_size: usize,
) -> CoreResult<()> {
Err(CrossDeviceError::TransferFailed {
from: "CPU".to_string(),
to: "unknown".to_string(),
reason: "Peer-to-peer not supported for CPU".to_string(),
}
.into())
}
fn synchronize(&self) -> CoreResult<()> {
Ok(())
}
fn available_memory(&self) -> CoreResult<usize> {
Ok(8 * 1024 * 1024 * 1024) }
fn total_memory(&self) -> CoreResult<usize> {
Ok(16 * 1024 * 1024 * 1024) }
}
/// Adapts a shared `GpuContext` to the `Device` trait under a caller-chosen
/// `DeviceType` key.
pub struct GpuContextWrapper {
/// The wrapped GPU context.
inner: Arc<GpuContext>,
/// Key reported by `Device::device_type`.
device_type: DeviceType,
}
impl GpuContextWrapper {
pub fn new(gpu_device: Arc<GpuContext>, devicetype: DeviceType) -> Self {
Self {
inner: gpu_device,
device_type: devicetype,
}
}
}
impl Device for GpuContextWrapper {
fn device_type(&self) -> DeviceType {
self.device_type.clone()
}
fn allocate(&self, size: usize) -> CoreResult<usize> {
let _buffer = self.inner.create_buffer::<u8>(size);
Ok(size) }
fn deallocate(&self, address: usize) -> CoreResult<()> {
Ok(())
}
unsafe fn copy_from_host(&self, src: *const u8, _dst: usize, size: usize) -> CoreResult<()> {
Ok(())
}
unsafe fn copy_to_host(&self, src: usize, _dst: *mut u8, size: usize) -> CoreResult<()> {
Ok(())
}
fn copy_peer(
&self,
src: usize,
_dst_device: &dyn Device,
_dst: usize,
_size: usize,
) -> CoreResult<()> {
Ok(())
}
fn synchronize(&self) -> CoreResult<()> {
Ok(())
}
fn available_memory(&self) -> CoreResult<usize> {
self.inner.get_available_memory().ok_or_else(|| {
CrossDeviceError::DeviceNotFound("GPU memory info unavailable".to_string()).into()
})
}
fn total_memory(&self) -> CoreResult<usize> {
self.inner.get_total_memory().ok_or_else(|| {
CrossDeviceError::DeviceNotFound("GPU memory info unavailable".to_string()).into()
})
}
}
/// Tracks registered devices and every live allocation made through them.
pub struct CrossDeviceMemoryManager {
/// Registered devices keyed by their `DeviceType`.
devices: RwLock<HashMap<DeviceType, Arc<dyn Device>>>,
/// Allocation records keyed by allocation id.
allocations: RwLock<HashMap<String, MemoryAllocation>>,
/// Source of allocation ids; incremented once per generated id.
allocation_counter: Mutex<u64>,
/// Target of `allocate_default`. Initialised from the first registered
/// device; changeable through `set_default_device`.
default_device: RwLock<Option<DeviceType>>,
}
impl CrossDeviceMemoryManager {
pub fn new() -> Self {
Self {
devices: RwLock::new(HashMap::new()),
allocations: RwLock::new(HashMap::new()),
allocation_counter: Mutex::new(0),
default_device: RwLock::new(None),
}
}
pub fn register_device(&self, device: Arc<dyn Device>) -> CoreResult<()> {
let device_type = device.device_type();
let mut devices = self.devices.write().expect("Operation failed");
devices.insert(device_type.clone(), device);
let mut default_device = self.default_device.write().expect("Operation failed");
if default_device.is_none() {
*default_device = Some(device_type);
}
Ok(())
}
pub fn set_default_device(&self, devicetype: DeviceType) -> CoreResult<()> {
let devices = self.devices.read().expect("Operation failed");
if !devices.contains_key(&devicetype) {
return Err(CrossDeviceError::DeviceNotFound(format!("{devicetype:?}")).into());
}
let mut default_device = self.default_device.write().expect("Operation failed");
*default_device = Some(devicetype);
Ok(())
}
pub fn get_default_device(&self) -> Option<DeviceType> {
self.default_device
.read()
.expect("Operation failed")
.clone()
}
pub fn allocate<T: 'static>(
self: &Arc<Self>,
device_type: &DeviceType,
count: usize,
) -> CoreResult<CrossDeviceBuffer<T>> {
let devices = self.devices.read().expect("Operation failed");
let device = devices
.get(device_type)
.ok_or_else(|| CrossDeviceError::DeviceNotFound(format!("{device_type:?}")))?;
let size = count * std::mem::size_of::<T>();
let address = device.allocate(size)?;
let allocation_id = self.generate_allocation_id();
let allocation = MemoryAllocation::new(
allocation_id.clone(),
device_type.clone(),
size,
address,
TypeId::of::<T>(),
);
let mut allocations = self.allocations.write().expect("Operation failed");
allocations.insert(allocation_id.clone(), allocation);
Ok(CrossDeviceBuffer::new(
allocation_id,
device_type.clone(),
address,
count,
self.clone(),
))
}
pub fn allocate_default<T: 'static>(
self: &Arc<Self>,
count: usize,
) -> CoreResult<CrossDeviceBuffer<T>> {
let default_device = self
.get_default_device()
.ok_or_else(|| CrossDeviceError::DeviceNotFound("No default device set".to_string()))?;
self.allocate(&default_device, count)
}
pub fn transfer<T: 'static + Copy>(
self: &Arc<Self>,
src_buffer: &CrossDeviceBuffer<T>,
dst_device: &DeviceType,
) -> CoreResult<CrossDeviceBuffer<T>> {
let devices = self.devices.read().expect("Operation failed");
let src_device = devices.get(&src_buffer.device_type).ok_or_else(|| {
CrossDeviceError::DeviceNotFound(format!("{0:?}", src_buffer.device_type))
})?;
let dst_device_obj = devices
.get(dst_device)
.ok_or_else(|| CrossDeviceError::DeviceNotFound(format!("{dst_device:?}")))?;
let dst_buffer = self.allocate::<T>(dst_device, src_buffer.count)?;
let size = src_buffer.count * std::mem::size_of::<T>();
if src_buffer.device_type.supports_p2p_transfer(dst_device) {
src_device.copy_peer(
src_buffer.address,
dst_device_obj.as_ref(),
dst_buffer.address,
size,
)?;
} else {
let staging_buffer = self.allocate::<T>(&DeviceType::Cpu, src_buffer.count)?;
unsafe {
src_device.copy_to_host(
src_buffer.address,
staging_buffer.address as *mut u8,
size,
)?;
}
unsafe {
dst_device_obj.copy_from_host(
staging_buffer.address as *const u8,
dst_buffer.address,
size,
)?;
}
}
Ok(dst_buffer)
}
pub fn synchronize_all(&self) -> CoreResult<()> {
let devices = self.devices.read().expect("Operation failed");
for device in devices.values() {
device.synchronize()?;
}
Ok(())
}
pub fn get_memory_statistics(&self) -> MemoryStatistics {
let allocations = self.allocations.read().expect("Operation failed");
let devices = self.devices.read().expect("Operation failed");
let mut stats_by_device = HashMap::new();
let mut total_allocated = 0;
let mut total_allocations = 0;
for allocation in allocations.values() {
let device_stats =
stats_by_device
.entry(allocation.device.clone())
.or_insert(DeviceMemoryStats {
device_type: allocation.device.clone(),
allocated_bytes: 0,
allocation_count: 0,
available_bytes: 0,
total_bytes: 0,
});
device_stats.allocated_bytes += allocation.size;
device_stats.allocation_count += 1;
total_allocated += allocation.size;
total_allocations += 1;
}
for (device_type, device) in devices.iter() {
let device_stats =
stats_by_device
.entry(device_type.clone())
.or_insert(DeviceMemoryStats {
device_type: device_type.clone(),
allocated_bytes: 0,
allocation_count: 0,
available_bytes: 0,
total_bytes: 0,
});
device_stats.available_bytes = device.available_memory().unwrap_or(0);
device_stats.total_bytes = device.total_memory().unwrap_or(0);
}
MemoryStatistics {
total_allocated_bytes: total_allocated,
total_allocations,
device_stats: stats_by_device.into_values().collect(),
}
}
pub fn cleanup_unused_allocations(&self, maxage: std::time::Duration) -> usize {
let mut allocations = self.allocations.write().expect("Operation failed");
let now = std::time::Instant::now();
let mut cleaned = 0;
allocations.retain(|_, allocation| {
if allocation.ref_count == 0 && now.duration_since(allocation.last_accessed) > maxage {
cleaned += 1;
false
} else {
true
}
});
cleaned
}
fn generate_allocation_id(&self) -> String {
let counter = {
let mut counter = self.allocation_counter.lock().expect("Operation failed");
*counter += 1;
*counter
};
format!("{counter:016x}")
}
pub(crate) fn remove_allocation(&self, allocationid: &str) {
let mut allocations = self.allocations.write().expect("Operation failed");
if let Some(allocation) = allocations.get_mut(allocationid) {
if allocation.remove_ref() == 0 {
allocations.remove(allocationid);
}
}
}
pub(crate) fn touch_allocation(&self, allocationid: &str) {
let mut allocations = self.allocations.write().expect("Operation failed");
if let Some(allocation) = allocations.get_mut(allocationid) {
allocation.touch();
}
}
}
impl Default for CrossDeviceMemoryManager {
fn default() -> Self {
Self::new()
}
}
/// Typed handle to `count` elements of `T` living on one device.
///
/// Handles produced by `clone` share one allocation record in the manager.
/// Dropping a handle gives one reference back through the manager's
/// `remove_allocation`.
pub struct CrossDeviceBuffer<T> {
/// Key of this buffer's record in the manager.
allocation_id: String,
/// Device the memory lives on.
device_type: DeviceType,
/// Device-specific address returned by `Device::allocate`.
address: usize,
/// Number of `T` elements (not bytes).
count: usize,
/// Manager that owns the allocation record.
manager: Arc<CrossDeviceMemoryManager>,
/// Marks the element type; no `T` is stored inline.
phantom: std::marker::PhantomData<T>,
}
impl<T> CrossDeviceBuffer<T> {
fn new(
allocation_id: String,
device_type: DeviceType,
address: usize,
count: usize,
manager: Arc<CrossDeviceMemoryManager>,
) -> Self {
Self {
allocation_id,
device_type,
address,
count,
manager,
phantom: std::marker::PhantomData,
}
}
pub const fn device_type(&self) -> &DeviceType {
&self.device_type
}
pub fn len(&self) -> usize {
self.count
}
pub fn is_empty(&self) -> bool {
self.count == 0
}
pub fn size_bytes(&self) -> usize {
self.count * std::mem::size_of::<T>()
}
pub fn raw_address(&self) -> usize {
self.manager.touch_allocation(&self.allocation_id);
self.address
}
pub fn to_device(&self, devicetype: &DeviceType) -> CoreResult<CrossDeviceBuffer<T>>
where
T: Copy + 'static,
{
self.manager.transfer(self, devicetype)
}
pub fn copy_from_host(&self, data: &[T]) -> CoreResult<()>
where
T: Copy,
{
if data.len() != self.count {
return Err(CrossDeviceError::InvalidDeviceType(format!(
"Data length {} doesn't match buffer capacity {}",
data.len(),
self.count
))
.into());
}
let devices = self.manager.devices.read().expect("Operation failed");
let device = devices
.get(&self.device_type)
.ok_or_else(|| CrossDeviceError::DeviceNotFound(format!("{0:?}", self.device_type)))?;
unsafe {
device.copy_from_host(data.as_ptr() as *const u8, self.address, self.size_bytes())?;
}
self.manager.touch_allocation(&self.allocation_id);
Ok(())
}
pub fn copy_to_host(&self) -> CoreResult<Vec<T>>
where
T: Copy + Default,
{
let mut result = vec![T::default(); self.count];
let devices = self.manager.devices.read().expect("Operation failed");
let device = devices
.get(&self.device_type)
.ok_or_else(|| CrossDeviceError::DeviceNotFound(format!("{0:?}", self.device_type)))?;
unsafe {
device.copy_to_host(
self.address,
result.as_mut_ptr() as *mut u8,
self.size_bytes(),
)?;
}
self.manager.touch_allocation(&self.allocation_id);
Ok(result)
}
}
impl<T> Clone for CrossDeviceBuffer<T> {
fn clone(&self) -> Self {
{
let mut allocations = self.manager.allocations.write().expect("Operation failed");
if let Some(allocation) = allocations.get_mut(&self.allocation_id) {
allocation.add_ref();
}
}
Self {
allocation_id: self.allocation_id.clone(),
device_type: self.device_type.clone(),
address: self.address,
count: self.count,
manager: self.manager.clone(),
phantom: std::marker::PhantomData,
}
}
}
impl<T> Drop for CrossDeviceBuffer<T> {
    /// Gives this handle's reference back to the manager via `remove_allocation`.
    fn drop(&mut self) {
        let id: &str = &self.allocation_id;
        self.manager.remove_allocation(id);
    }
}
/// Snapshot of allocation and device-memory figures produced by
/// `CrossDeviceMemoryManager::get_memory_statistics`.
#[derive(Debug, Clone)]
pub struct MemoryStatistics {
/// Sum of the sizes, in bytes, of all tracked allocations.
pub total_allocated_bytes: usize,
/// Number of tracked allocations.
pub total_allocations: usize,
/// One entry per device that is registered or has tracked allocations; the
/// order is unspecified (collected from a `HashMap`).
pub device_stats: Vec<DeviceMemoryStats>,
}
/// Memory figures for a single device.
#[derive(Debug, Clone)]
pub struct DeviceMemoryStats {
    /// Device these figures describe.
    pub device_type: DeviceType,
    /// Bytes held by allocations the manager is tracking on this device.
    pub allocated_bytes: usize,
    /// Number of allocations the manager is tracking on this device.
    pub allocation_count: usize,
    /// Free memory reported by the device (0 when the query fails).
    pub available_bytes: usize,
    /// Total memory reported by the device (0 when the query fails).
    pub total_bytes: usize,
}

impl DeviceMemoryStats {
    /// Tracked bytes as a percentage of the device's total memory, or `0.0`
    /// when the device reports no total.
    pub fn usage_percentage(&self) -> f64 {
        match self.total_bytes {
            0 => 0.0,
            total => self.allocated_bytes as f64 / total as f64 * 100.0,
        }
    }
}
/// Lazily created process-wide manager behind `global_manager`.
static GLOBAL_MANAGER: std::sync::OnceLock<Arc<CrossDeviceMemoryManager>> =
    std::sync::OnceLock::new();

/// Returns the process-wide memory manager.
///
/// On first use the manager is created with a `CpuDevice` registered, which
/// therefore becomes its default device.
// NOTE(review): `dead_code` allow kept from the original; presumably for
// builds where nothing in the crate calls this — confirm it is still needed.
#[allow(dead_code)]
pub fn global_manager() -> Arc<CrossDeviceMemoryManager> {
    let shared = GLOBAL_MANAGER.get_or_init(|| {
        let fresh = Arc::new(CrossDeviceMemoryManager::new());
        // `register_device` only returns `Ok` in this module, so the result is
        // deliberately ignored.
        let _ = fresh.register_device(Arc::new(CpuDevice::new()));
        fresh
    });
    Arc::clone(shared)
}
/// Registers each GPU context with the global manager, keyed as
/// `DeviceType::CudaGpu(index)` in iteration order.
///
/// NOTE(review): every context is registered as a CUDA device regardless of
/// its actual backend. Calling this twice re-uses indices from 0, which
/// replaces the earlier registrations. Confirm both behaviours are intended.
///
/// # Errors
/// Propagates the first error returned by `register_device`.
#[allow(dead_code)]
pub fn initialize_with_gpu_devices(gpudevices: Vec<Arc<GpuContext>>) -> CoreResult<()> {
    let manager = global_manager();
    gpudevices
        .into_iter()
        .enumerate()
        .try_for_each(|(index, context)| {
            let key = DeviceType::CudaGpu(index as u32);
            manager.register_device(Arc::new(GpuContextWrapper::new(context, key)))
        })
}
/// Convenience helpers built on the process-wide `global_manager`.
pub mod utils {
    use super::*;

    /// Allocates `count` elements on the device reporting the most available
    /// memory, falling back to the CPU when there are no statistics.
    ///
    /// NOTE(review): ties go to the last candidate (`max_by_key` semantics),
    /// and candidate order comes from a `HashMap`, so tie-breaking is not
    /// deterministic.
    pub fn allocate_optimal<T: 'static>(count: usize) -> CoreResult<CrossDeviceBuffer<T>> {
        let manager = global_manager();
        let statistics = manager.get_memory_statistics();
        let target = statistics
            .device_stats
            .iter()
            .max_by_key(|entry| entry.available_bytes)
            .map_or(DeviceType::Cpu, |entry| entry.device_type.clone());
        manager.allocate(&target, count)
    }

    /// Allocates a buffer for `data.len()` elements on `device_type` and uploads `data`.
    pub fn create_buffer_with_data<T: Copy + 'static>(
        data: &[T],
        device_type: &DeviceType,
    ) -> CoreResult<CrossDeviceBuffer<T>> {
        let buffer = global_manager().allocate(device_type, data.len())?;
        buffer.copy_from_host(data)?;
        Ok(buffer)
    }

    /// Copies `src_buffer` to `dst_device` through the global manager.
    ///
    /// NOTE(review): the global manager is used even if `src_buffer` was
    /// created by a different manager — confirm this is intended.
    pub fn transfer_data<T: Copy + 'static>(
        src_buffer: &CrossDeviceBuffer<T>,
        dst_device: &DeviceType,
    ) -> CoreResult<CrossDeviceBuffer<T>> {
        global_manager().transfer(src_buffer, dst_device)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_type_creation() {
        let cpu = DeviceType::Cpu;
        let gpu = DeviceType::CudaGpu(0);
        let tpu = DeviceType::Tpu(1);
        assert_eq!(cpu.as_str(), "CPU");
        assert_eq!(gpu.as_str(), "CUDA_GPU");
        assert_eq!(tpu.as_str(), "TPU");
        assert_eq!(cpu.device_id(), 0);
        assert_eq!(gpu.device_id(), 0);
        assert_eq!(tpu.device_id(), 1);
    }

    #[test]
    fn test_device_capabilities() {
        let cpu = DeviceType::Cpu;
        let cuda = DeviceType::CudaGpu(0);
        let rocm = DeviceType::RocmGpu(0);
        assert!(!cpu.supports_unified_memory());
        assert!(cuda.supports_unified_memory());
        assert!(rocm.supports_unified_memory());
        assert!(cuda.supports_p2p_transfer(&DeviceType::CudaGpu(1)));
        assert!(!cuda.supports_p2p_transfer(&DeviceType::RocmGpu(0)));
        assert!(!cpu.supports_p2p_transfer(&DeviceType::CudaGpu(0)));
    }

    #[test]
    fn test_memory_allocation_creation() {
        let allocation = MemoryAllocation::new(
            "test_alloc".to_string(),
            DeviceType::Cpu,
            1024,
            0x1000,
            TypeId::of::<f32>(),
        );
        assert_eq!(allocation.id, "test_alloc");
        assert_eq!(allocation.size, 1024);
        assert_eq!(allocation.address, 0x1000);
        assert_eq!(allocation.ref_count, 1);
    }

    #[test]
    fn test_cpu_device() {
        let cpu = CpuDevice::new();
        assert_eq!(cpu.device_type(), DeviceType::Cpu);
        assert!(cpu.available_memory().is_ok());
        assert!(cpu.total_memory().is_ok());
        assert!(cpu.synchronize().is_ok());
    }

    #[test]
    fn test_cross_device_manager() {
        let manager = CrossDeviceMemoryManager::new();
        let cpu_device = Arc::new(CpuDevice::new());
        assert!(manager.register_device(cpu_device).is_ok());
        assert_eq!(manager.get_default_device(), Some(DeviceType::Cpu));
        let stats = manager.get_memory_statistics();
        assert_eq!(stats.total_allocations, 0);
        assert_eq!(stats.total_allocated_bytes, 0);
    }

    #[test]
    fn test_global_manager() {
        let manager = global_manager();
        assert_eq!(manager.get_default_device(), Some(DeviceType::Cpu));
        let stats = manager.get_memory_statistics();
        assert!(!stats.device_stats.is_empty());
    }

    #[test]
    fn test_memory_statistics() {
        // MiB-scale figures keep the constants representable on 32-bit
        // targets, where `8 * 1024 * 1024 * 1024` overflows `usize`.
        const MIB: usize = 1024 * 1024;
        let stats = DeviceMemoryStats {
            device_type: DeviceType::Cpu,
            allocated_bytes: 1024,
            allocation_count: 1,
            available_bytes: 7 * MIB,
            total_bytes: 8 * MIB,
        };
        let usage = stats.usage_percentage();
        assert!(usage > 0.0 && usage < 1.0);
    }

    #[test]
    fn test_cpu_buffer_round_trip_and_refcount() {
        let manager = Arc::new(CrossDeviceMemoryManager::new());
        manager
            .register_device(Arc::new(CpuDevice::new()))
            .expect("registering the CPU device should succeed");
        let data = [1.0f32, -2.5, 3.25, 4.0];
        let buffer = manager
            .allocate::<f32>(&DeviceType::Cpu, data.len())
            .expect("CPU allocation should succeed");
        assert_eq!(buffer.len(), data.len());
        assert_eq!(buffer.size_bytes(), data.len() * std::mem::size_of::<f32>());
        assert!(buffer.copy_from_host(&data[..2]).is_err());
        buffer.copy_from_host(&data).expect("upload");
        assert_eq!(buffer.copy_to_host().expect("download"), data.to_vec());

        let alias = buffer.clone();
        drop(buffer);
        assert_eq!(manager.get_memory_statistics().total_allocations, 1);
        drop(alias);
        assert_eq!(manager.get_memory_statistics().total_allocations, 0);
    }

    #[test]
    fn test_cpu_to_cpu_transfer_copies_data() {
        let manager = Arc::new(CrossDeviceMemoryManager::new());
        manager
            .register_device(Arc::new(CpuDevice::new()))
            .expect("registering the CPU device should succeed");
        let data = [7u32, 8, 9];
        let src = manager
            .allocate::<u32>(&DeviceType::Cpu, data.len())
            .expect("CPU allocation should succeed");
        src.copy_from_host(&data).expect("upload");
        let dst = src.to_device(&DeviceType::Cpu).expect("CPU to CPU transfer");
        assert_eq!(dst.copy_to_host().expect("download"), data.to_vec());
        assert!(src.to_device(&DeviceType::CudaGpu(0)).is_err());
        drop(src);
        drop(dst);
        assert_eq!(manager.get_memory_statistics().total_allocations, 0);
    }
}