use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
use crate::error::{NdimageError, NdimageResult};
/// Capability description for a single GPU/accelerator device.
///
/// Memory sizes are in bytes (detection code converts MiB reported by
/// tools such as `nvidia-smi` into bytes). Optional fields are `None`
/// when the backend cannot report or estimate them.
#[derive(Debug, Clone)]
pub struct DeviceCapability {
    /// Human-readable device name (detection code appends a backend tag
    /// and device index, e.g. "... (CUDA Device 0)").
    pub name: String,
    /// Total device memory, in bytes.
    pub total_memory: usize,
    /// Memory believed to be available for allocation, in bytes.
    pub available_memory: usize,
    /// CUDA compute capability as (major, minor); `None` for non-CUDA devices.
    pub compute_capability: Option<(u32, u32)>,
    /// Maximum threads per block / work-group.
    pub max_threads_per_block: Option<usize>,
    /// Maximum block / work-group dimensions (x, y, z).
    pub max_block_dims: Option<[usize; 3]>,
    /// Maximum grid dimensions (x, y, z); `None` where not applicable.
    pub max_grid_dims: Option<[usize; 3]>,
    /// Shared/local memory available per block, in bytes.
    pub shared_memory_per_block: Option<usize>,
    /// Number of multiprocessors / compute units.
    pub multiprocessor_count: Option<usize>,
    /// Core clock rate; values used elsewhere in this file (e.g. 2_500_000
    /// for a ~2.5 GHz part) suggest the unit is kHz — TODO confirm.
    pub clock_rate: Option<usize>,
    /// Estimated memory bandwidth; magnitudes suggest GB/s — TODO confirm.
    pub memory_bandwidth: Option<f64>,
}
impl Default for DeviceCapability {
fn default() -> Self {
Self {
name: "Unknown Device".to_string(),
total_memory: 0,
available_memory: 0,
compute_capability: None,
max_threads_per_block: None,
max_block_dims: None,
max_grid_dims: None,
shared_memory_per_block: None,
multiprocessor_count: None,
clock_rate: None,
memory_bandwidth: None,
}
}
}
/// Summary of GPU compute support detected on the current system.
#[derive(Debug, Clone)]
pub struct SystemCapabilities {
    /// At least one CUDA device was detected.
    pub cuda_available: bool,
    /// At least one OpenCL device was detected.
    pub opencl_available: bool,
    /// At least one Metal device was detected (macOS only).
    pub metal_available: bool,
    /// Any of the GPU backends above is available.
    pub gpu_available: bool,
    /// Largest single-device total memory seen, in MiB (not a sum).
    pub gpu_memory_mb: usize,
    /// Largest multiprocessor/compute-unit count seen across devices.
    pub compute_units: u32,
}
/// Registry of detected devices, one list per compiled-in GPU backend.
///
/// Fields only exist when the corresponding cargo feature is enabled, so
/// with no GPU features this struct has no fields at all.
pub struct DeviceManager {
    /// CUDA devices discovered at construction time.
    #[cfg(feature = "cuda")]
    cuda_devices: Vec<DeviceCapability>,
    /// OpenCL devices discovered at construction time.
    #[cfg(feature = "opencl")]
    opencl_devices: Vec<DeviceCapability>,
    /// Metal devices discovered at construction time (macOS only).
    #[cfg(all(target_os = "macos", feature = "metal"))]
    metal_devices: Vec<DeviceCapability>,
}
impl DeviceManager {
    /// Detect devices for every backend compiled into this build.
    ///
    /// # Errors
    ///
    /// Propagates any error raised by a backend's detection routine.
    pub fn new() -> NdimageResult<Self> {
        // Initialize fields directly instead of mutating after construction:
        // this avoids an `unused_mut` warning when no GPU feature is enabled
        // (in that configuration the struct has no fields to fill in).
        Ok(Self {
            #[cfg(feature = "cuda")]
            cuda_devices: detect_cuda_devices()?,
            #[cfg(feature = "opencl")]
            opencl_devices: detect_opencl_devices()?,
            #[cfg(all(target_os = "macos", feature = "metal"))]
            metal_devices: detect_metal_devices()?,
        })
    }

    /// Choose the best device with at least `required_memory` bytes free.
    ///
    /// Devices are ranked by `calculate_device_score`; OpenCL and Metal
    /// scores are down-weighted (x0.9 / x0.8) so CUDA wins ties. Returns
    /// the backend and device index, or `None` when nothing fits.
    pub fn get_best_device(&self, required_memory: usize) -> Option<(super::Backend, usize)> {
        let mut best_device = None;
        let mut best_score = 0.0;
        #[cfg(feature = "cuda")]
        {
            for (idx, device) in self.cuda_devices.iter().enumerate() {
                if device.available_memory >= required_memory {
                    let score = self.calculate_device_score(device);
                    if score > best_score {
                        best_score = score;
                        best_device = Some((super::Backend::Cuda, idx));
                    }
                }
            }
        }
        #[cfg(feature = "opencl")]
        {
            for (idx, device) in self.opencl_devices.iter().enumerate() {
                if device.available_memory >= required_memory {
                    // Slightly penalize OpenCL relative to CUDA.
                    let score = self.calculate_device_score(device) * 0.9;
                    if score > best_score {
                        best_score = score;
                        best_device = Some((super::Backend::OpenCL, idx));
                    }
                }
            }
        }
        #[cfg(all(target_os = "macos", feature = "metal"))]
        {
            for (idx, device) in self.metal_devices.iter().enumerate() {
                if device.available_memory >= required_memory {
                    // Penalize Metal below both CUDA and OpenCL.
                    let score = self.calculate_device_score(device) * 0.8;
                    if score > best_score {
                        best_score = score;
                        best_device = Some((super::Backend::Metal, idx));
                    }
                }
            }
        }
        best_device
    }

    /// Heuristic quality score for a device: 10 pts per GiB of memory,
    /// 5 pts per multiprocessor, 3 pts per 1e6 clock-rate units, and
    /// 0.1 pts per bandwidth unit. Higher is better; absolute values
    /// are only meaningful for comparison.
    fn calculate_device_score(&self, device: &DeviceCapability) -> f64 {
        let mut score = 0.0;
        score += (device.total_memory as f64) / (1024.0 * 1024.0 * 1024.0) * 10.0;
        if let Some(mp_count) = device.multiprocessor_count {
            score += (mp_count as f64) * 5.0;
        }
        if let Some(clock) = device.clock_rate {
            score += (clock as f64) / 1_000_000.0 * 3.0;
        }
        if let Some(bandwidth) = device.memory_bandwidth {
            score += bandwidth * 0.1;
        }
        score
    }

    /// Look up the capability record for `device_id` on `backend`.
    ///
    /// Returns `None` for out-of-range ids and for backends that have no
    /// device list (CPU, Auto, or features compiled out).
    pub fn get_device_info(
        &self,
        backend: super::Backend,
        device_id: usize,
    ) -> Option<&DeviceCapability> {
        match backend {
            #[cfg(feature = "cuda")]
            super::Backend::Cuda => self.cuda_devices.get(device_id),
            #[cfg(feature = "opencl")]
            super::Backend::OpenCL => self.opencl_devices.get(device_id),
            #[cfg(all(target_os = "macos", feature = "metal"))]
            super::Backend::Metal => self.metal_devices.get(device_id),
            _ => None,
        }
    }

    /// Whether `backend` can run work at all. CPU is always available,
    /// and `Auto` is always available because it can fall back to CPU.
    pub fn is_backend_available(&self, backend: super::Backend) -> bool {
        match backend {
            #[cfg(feature = "cuda")]
            super::Backend::Cuda => !self.cuda_devices.is_empty(),
            #[cfg(feature = "opencl")]
            super::Backend::OpenCL => !self.opencl_devices.is_empty(),
            #[cfg(all(target_os = "macos", feature = "metal"))]
            super::Backend::Metal => !self.metal_devices.is_empty(),
            super::Backend::Cpu => true,
            super::Backend::Auto => {
                // The early returns below short-circuit on the first GPU
                // backend with devices, but the final answer is `true`
                // either way thanks to the CPU fallback.
                #[cfg(feature = "cuda")]
                if !self.cuda_devices.is_empty() {
                    return true;
                }
                #[cfg(feature = "opencl")]
                if !self.opencl_devices.is_empty() {
                    return true;
                }
                #[cfg(all(target_os = "macos", feature = "metal"))]
                if !self.metal_devices.is_empty() {
                    return true;
                }
                true
            }
        }
    }

    /// Number of devices for `backend`. `Cpu` counts as one device and
    /// `Auto` counts the CPU plus every detected GPU device.
    pub fn device_count(&self, backend: super::Backend) -> usize {
        match backend {
            #[cfg(feature = "cuda")]
            super::Backend::Cuda => self.cuda_devices.len(),
            #[cfg(feature = "opencl")]
            super::Backend::OpenCL => self.opencl_devices.len(),
            #[cfg(all(target_os = "macos", feature = "metal"))]
            super::Backend::Metal => self.metal_devices.len(),
            super::Backend::Cpu => 1,
            super::Backend::Auto => {
                let mut total = 1; // the CPU itself
                #[cfg(feature = "cuda")]
                {
                    total += self.cuda_devices.len();
                }
                #[cfg(feature = "opencl")]
                {
                    total += self.opencl_devices.len();
                }
                #[cfg(all(target_os = "macos", feature = "metal"))]
                {
                    total += self.metal_devices.len();
                }
                total
            }
        }
    }

    /// Summarize detected hardware into a [`SystemCapabilities`] snapshot.
    ///
    /// `gpu_memory_mb` is the LARGEST single device's total memory in MiB
    /// (not a sum across devices), and `compute_units` is likewise the
    /// largest per-device multiprocessor count seen.
    pub fn get_capabilities(&self) -> SystemCapabilities {
        // Each availability flag collapses to `false` at compile time when
        // the corresponding feature is absent.
        let cuda_available = {
            #[cfg(feature = "cuda")]
            {
                !self.cuda_devices.is_empty()
            }
            #[cfg(not(feature = "cuda"))]
            {
                false
            }
        };
        let opencl_available = {
            #[cfg(feature = "opencl")]
            {
                !self.opencl_devices.is_empty()
            }
            #[cfg(not(feature = "opencl"))]
            {
                false
            }
        };
        let metal_available = {
            #[cfg(all(target_os = "macos", feature = "metal"))]
            {
                !self.metal_devices.is_empty()
            }
            #[cfg(not(all(target_os = "macos", feature = "metal")))]
            {
                false
            }
        };
        let gpu_available = cuda_available || opencl_available || metal_available;
        let mut total_memory_mb = 0;
        let mut max_compute_units = 0;
        #[cfg(feature = "cuda")]
        {
            for device in &self.cuda_devices {
                total_memory_mb = total_memory_mb.max(device.total_memory / (1024 * 1024));
                if let Some(mp_count) = device.multiprocessor_count {
                    max_compute_units = max_compute_units.max(mp_count as u32);
                }
            }
        }
        #[cfg(feature = "opencl")]
        {
            for device in &self.opencl_devices {
                total_memory_mb = total_memory_mb.max(device.total_memory / (1024 * 1024));
                if let Some(mp_count) = device.multiprocessor_count {
                    max_compute_units = max_compute_units.max(mp_count as u32);
                }
            }
        }
        #[cfg(all(target_os = "macos", feature = "metal"))]
        {
            for device in &self.metal_devices {
                total_memory_mb = total_memory_mb.max(device.total_memory / (1024 * 1024));
                if let Some(mp_count) = device.multiprocessor_count {
                    max_compute_units = max_compute_units.max(mp_count as u32);
                }
            }
        }
        SystemCapabilities {
            cuda_available,
            opencl_available,
            metal_available,
            gpu_available,
            gpu_memory_mb: total_memory_mb,
            compute_units: max_compute_units,
        }
    }
}
/// Lazily-initialized global device manager shared across the crate.
static DEVICE_MANAGER: OnceLock<Arc<Mutex<DeviceManager>>> = OnceLock::new();
/// Return the process-wide [`DeviceManager`], initializing it on first use.
///
/// Detection errors are swallowed: a manager with empty device lists is
/// cached instead, so later callers simply see zero GPU devices. The
/// `Result` return is kept for interface stability; it is always `Ok`.
#[allow(dead_code)]
pub fn get_device_manager() -> NdimageResult<Arc<Mutex<DeviceManager>>> {
    let manager = DEVICE_MANAGER.get_or_init(|| {
        let inner = DeviceManager::new().unwrap_or_else(|_| DeviceManager {
            #[cfg(feature = "cuda")]
            cuda_devices: Vec::new(),
            #[cfg(feature = "opencl")]
            opencl_devices: Vec::new(),
            #[cfg(all(target_os = "macos", feature = "metal"))]
            metal_devices: Vec::new(),
        });
        Arc::new(Mutex::new(inner))
    });
    Ok(Arc::clone(manager))
}
/// Probe for CUDA devices.
///
/// The CUDA runtime is considered present when a known `libcudart` path
/// exists or `CUDA_PATH` is set (Linux-centric paths — other platforms
/// rely on the environment variable). Device details are scraped from
/// `nvidia-smi`; if that yields nothing, one generic profile is returned.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn detect_cuda_devices() -> NdimageResult<Vec<DeviceCapability>> {
    let runtime_present = std::path::Path::new("/usr/local/cuda/lib64/libcudart.so").exists()
        || std::path::Path::new("/usr/lib/x86_64-linux-gnu/libcudart.so").exists()
        || std::env::var("CUDA_PATH").is_ok();
    if !runtime_present {
        return Ok(Vec::new());
    }
    let mut devices = Vec::new();
    let query = std::process::Command::new("nvidia-smi")
        .arg("--query-gpu=name,memory.total,memory.free")
        .arg("--format=csv,noheader,nounits")
        .output();
    if let Ok(out) = query {
        if out.status.success() {
            let text = String::from_utf8_lossy(&out.stdout);
            for (index, line) in text.lines().enumerate() {
                let fields: Vec<&str> = line.split(',').map(str::trim).collect();
                if fields.len() < 3 {
                    continue;
                }
                let gpu_name = fields[0].to_string();
                // nvidia-smi reports MiB (nounits); convert to bytes.
                let total_memory = fields[1].parse::<usize>().unwrap_or(0) * 1024 * 1024;
                let available_memory = fields[2].parse::<usize>().unwrap_or(0) * 1024 * 1024;
                let (compute_capability, multiprocessor_count, clock_rate) =
                    estimate_gpu_capabilities(&gpu_name);
                devices.push(DeviceCapability {
                    name: format!("{} (CUDA Device {})", gpu_name, index),
                    total_memory,
                    available_memory,
                    compute_capability,
                    max_threads_per_block: Some(1024),
                    max_block_dims: Some([1024, 1024, 64]),
                    max_grid_dims: Some([65535, 65535, 65535]),
                    shared_memory_per_block: Some(49152),
                    multiprocessor_count,
                    clock_rate,
                    memory_bandwidth: estimate_memory_bandwidth(&gpu_name),
                });
            }
        }
    }
    if devices.is_empty() {
        // Runtime looks installed but nvidia-smi gave nothing: fall back to
        // a single mid-range profile (roughly an RTX 2080-class card).
        devices.push(DeviceCapability {
            name: "Generic CUDA Device".to_string(),
            total_memory: 8_589_934_592,     // 8 GiB
            available_memory: 7_516_192_768, // 7 GiB
            compute_capability: Some((7, 5)),
            max_threads_per_block: Some(1024),
            max_block_dims: Some([1024, 1024, 64]),
            max_grid_dims: Some([65535, 65535, 65535]),
            shared_memory_per_block: Some(49152),
            multiprocessor_count: Some(68),
            clock_rate: Some(1_800_000),
            memory_bandwidth: Some(448.0),
        });
    }
    Ok(devices)
}
/// Rough (compute capability, SM count, clock rate) estimates keyed off
/// the marketing name reported by `nvidia-smi`.
///
/// Families are checked in order and the first match wins; unknown
/// hardware gets a conservative Pascal-era profile.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn estimate_gpu_capabilities(name: &str) -> (Option<(u32, u32)>, Option<usize>, Option<usize>) {
    let lowered = name.to_lowercase();
    let families: [(&[&str], (u32, u32), usize, usize); 5] = [
        (&["rtx 40", "ada lovelace"], (8, 9), 128, 2_500_000),
        (&["rtx 30", "ampere"], (8, 6), 104, 1_700_000),
        (&["rtx 20", "turing"], (7, 5), 72, 1_500_000),
        (&["gtx 16", "gtx 10"], (6, 1), 20, 1_400_000),
        (&["tesla", "quadro"], (7, 0), 80, 1_300_000),
    ];
    for (patterns, compute_capability, sm_count, clock) in families {
        if patterns.iter().any(|p| lowered.contains(p)) {
            return (Some(compute_capability), Some(sm_count), Some(clock));
        }
    }
    (Some((6, 0)), Some(32), Some(1_000_000))
}
/// Approximate memory bandwidth (presumably GB/s — confirm against vendor
/// specs) for well-known NVIDIA models, defaulting to 320.0 when the name
/// is unrecognized.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn estimate_memory_bandwidth(name: &str) -> Option<f64> {
    let lowered = name.to_lowercase();
    let known: [(&str, f64); 8] = [
        ("rtx 4090", 1008.0),
        ("rtx 4080", 717.0),
        ("rtx 3090", 936.0),
        ("rtx 3080", 760.0),
        ("rtx 3070", 448.0),
        ("rtx 2080", 448.0),
        ("tesla v100", 900.0),
        ("tesla a100", 1555.0),
    ];
    let bandwidth = known
        .iter()
        .find(|(model, _)| lowered.contains(model))
        .map_or(320.0, |&(_, gbps)| gbps);
    Some(bandwidth)
}
/// Probe for OpenCL devices.
///
/// Presence is inferred from common ICD loader paths or `OPENCL_ROOT`.
/// Device names are scraped from `clinfo --list`; if that yields nothing,
/// heuristic integrated/discrete profiles are synthesized from coarse
/// system hints (DRM card node, ROCm installation).
#[cfg(feature = "opencl")]
#[allow(dead_code)]
fn detect_opencl_devices() -> NdimageResult<Vec<DeviceCapability>> {
    let runtime_present = std::path::Path::new("/usr/lib/x86_64-linux-gnu/libOpenCL.so.1")
        .exists()
        || std::path::Path::new("/usr/local/lib/libOpenCL.so").exists()
        || std::env::var("OPENCL_ROOT").is_ok();
    if !runtime_present {
        return Ok(Vec::new());
    }
    let mut devices = Vec::new();
    if let Ok(out) = std::process::Command::new("clinfo").arg("--list").output() {
        if out.status.success() {
            let text = String::from_utf8_lossy(&out.stdout);
            for (index, line) in text.lines().enumerate() {
                // Device rows mention "Device" but not "Platform".
                if !line.contains("Device") || line.contains("Platform") {
                    continue;
                }
                let device_name = line
                    .split("Device")
                    .nth(1)
                    .unwrap_or("Unknown OpenCL Device")
                    .trim()
                    .to_string();
                let (memory, compute_units, clock) = estimate_opencl_capabilities(&device_name);
                devices.push(DeviceCapability {
                    name: format!("{} (OpenCL Device {})", device_name, index),
                    total_memory: memory,
                    // Assume roughly 80% of device memory is allocatable.
                    available_memory: (memory as f64 * 0.8) as usize,
                    compute_capability: None,
                    max_threads_per_block: Some(1024),
                    max_block_dims: Some([1024, 1024, 1024]),
                    max_grid_dims: None,
                    shared_memory_per_block: Some(32768),
                    multiprocessor_count: Some(compute_units),
                    clock_rate: Some(clock),
                    memory_bandwidth: estimate_opencl_bandwidth(&device_name),
                });
            }
        }
    }
    if devices.is_empty() {
        // Fallback heuristics when clinfo produced nothing usable.
        if std::path::Path::new("/sys/class/drm/card0").exists() {
            devices.push(DeviceCapability {
                name: "Intel Integrated Graphics (OpenCL)".to_string(),
                total_memory: 2_147_483_648,     // 2 GiB
                available_memory: 1_717_986_918, // ~1.6 GiB
                compute_capability: None,
                max_threads_per_block: Some(512),
                max_block_dims: Some([512, 512, 512]),
                max_grid_dims: None,
                shared_memory_per_block: Some(32768),
                multiprocessor_count: Some(24),
                clock_rate: Some(1_000_000),
                memory_bandwidth: Some(25.6),
            });
        }
        if std::env::var("HSA_ENABLE_SDMA").is_ok() || std::path::Path::new("/opt/rocm").exists() {
            devices.push(DeviceCapability {
                name: "AMD Discrete Graphics (OpenCL)".to_string(),
                total_memory: 8_589_934_592,     // 8 GiB
                available_memory: 6_871_947_674, // ~6.4 GiB
                compute_capability: None,
                max_threads_per_block: Some(1024),
                max_block_dims: Some([1024, 1024, 1024]),
                max_grid_dims: None,
                shared_memory_per_block: Some(65536),
                multiprocessor_count: Some(64),
                clock_rate: Some(1_500_000),
                memory_bandwidth: Some(448.0),
            });
        }
    }
    Ok(devices)
}
/// Heuristic (total memory in bytes, compute units, clock rate) for an
/// OpenCL device, inferred from vendor and model substrings in its name.
/// Vendors are checked in order: Intel, then AMD/Radeon, then NVIDIA.
#[cfg(feature = "opencl")]
#[allow(dead_code)]
fn estimate_opencl_capabilities(name: &str) -> (usize, usize, usize) {
    let lowered = name.to_lowercase();
    if lowered.contains("intel") {
        return if lowered.contains("iris") || lowered.contains("xe") {
            (4_294_967_296, 96, 1_300_000)
        } else {
            (2_147_483_648, 24, 1_000_000)
        };
    }
    if lowered.contains("amd") || lowered.contains("radeon") {
        return if lowered.contains("rx 7") || lowered.contains("rx 6") {
            (16_106_127_360, 80, 2_000_000)
        } else if lowered.contains("rx 5") {
            (8_589_934_592, 64, 1_800_000)
        } else {
            (4_294_967_296, 36, 1_500_000)
        };
    }
    let is_nvidia = lowered.contains("nvidia")
        || lowered.contains("geforce")
        || lowered.contains("quadro");
    if is_nvidia {
        return if lowered.contains("rtx") {
            (12_884_901_888, 84, 1_700_000)
        } else {
            (8_589_934_592, 56, 1_500_000)
        };
    }
    // Unknown vendor: conservative defaults.
    (2_147_483_648, 16, 1_000_000)
}
/// Approximate memory bandwidth (presumably GB/s — confirm) for OpenCL
/// devices by name. The table is ordered so more specific patterns are
/// matched before generic ones (e.g. "intel iris" before "intel").
#[cfg(feature = "opencl")]
#[allow(dead_code)]
fn estimate_opencl_bandwidth(name: &str) -> Option<f64> {
    let lowered = name.to_lowercase();
    let known: [(&str, f64); 7] = [
        ("intel iris", 68.0),
        ("intel xe", 68.0),
        ("intel", 25.6),
        ("rx 7", 960.0),
        ("rx 6", 512.0),
        ("rx 5", 448.0),
        ("nvidia", 760.0),
    ];
    let bandwidth = known
        .iter()
        .find(|(pattern, _)| lowered.contains(pattern))
        .map_or(100.0, |&(_, gbps)| gbps);
    Some(bandwidth)
}
/// Enumerate Metal-capable GPUs on macOS: the integrated GPU first, then
/// any discrete/Apple Silicon GPUs.
///
/// Failures from either probe are ignored so a partial list is still
/// returned; an empty `Vec` means no Metal device could be identified.
#[cfg(all(target_os = "macos", feature = "metal"))]
#[allow(dead_code)]
fn detect_metal_devices() -> NdimageResult<Vec<DeviceCapability>> {
    // NOTE: the previous version imported std::ffi::{...} and std::ptr
    // without using them; those dead imports have been removed.
    let mut devices = Vec::new();
    if let Ok(integrated) = detect_macos_integrated_gpu() {
        devices.push(integrated);
    }
    if let Ok(discrete) = detect_macos_discrete_gpus() {
        devices.extend(discrete);
    }
    Ok(devices)
}
/// Build a capability profile for the integrated GPU by scanning
/// `system_profiler SPDisplaysDataType` output for a vendor keyword.
/// All numbers are hard-coded per-vendor estimates — TODO: confirm
/// against real Metal API queries.
///
/// # Errors
///
/// Fails if `system_profiler` cannot be spawned or exits unsuccessfully.
#[cfg(all(target_os = "macos", feature = "metal"))]
#[allow(dead_code)]
fn detect_macos_integrated_gpu() -> NdimageResult<DeviceCapability> {
    use std::process::Command;
    let output = Command::new("system_profiler")
        .arg("SPDisplaysDataType")
        .arg("-xml")
        .output()
        .map_err(|e| {
            // Fixed message: previously said "systemprofiler".
            NdimageError::ComputationError(format!("Failed to run system_profiler: {}", e))
        })?;
    if !output.status.success() {
        return Err(NdimageError::ComputationError(
            "system_profiler failed".into(),
        ));
    }
    let output_str = String::from_utf8_lossy(&output.stdout);
    let mut capability = DeviceCapability::default();
    if output_str.contains("Intel") {
        capability.name = "Intel Integrated Graphics (Metal)".to_string();
        capability.total_memory = 1_073_741_824; // 1 GiB
        capability.available_memory = 805_306_368; // 768 MiB
        capability.multiprocessor_count = Some(16);
        capability.clock_rate = Some(1_000_000);
        capability.max_threads_per_block = Some(1024);
        capability.max_block_dims = Some([1024, 1024, 64]);
        capability.shared_memory_per_block = Some(32768);
    } else if output_str.contains("AMD") {
        capability.name = "AMD Integrated Graphics (Metal)".to_string();
        capability.total_memory = 2_147_483_648; // 2 GiB
        capability.available_memory = 1_610_612_736; // 1.5 GiB
        capability.multiprocessor_count = Some(32);
        capability.clock_rate = Some(1_200_000);
        capability.max_threads_per_block = Some(1024);
        capability.max_block_dims = Some([1024, 1024, 64]);
        capability.shared_memory_per_block = Some(65536);
    } else {
        capability.name = "Unknown Integrated Graphics (Metal)".to_string();
        capability.total_memory = 1_073_741_824; // 1 GiB
        capability.available_memory = 805_306_368; // 768 MiB
        capability.multiprocessor_count = Some(8);
        capability.clock_rate = Some(800_000);
        capability.max_threads_per_block = Some(512);
        capability.max_block_dims = Some([512, 512, 64]);
        capability.shared_memory_per_block = Some(16384);
    }
    Ok(capability)
}
/// Detect discrete AMD GPUs and Apple Silicon GPUs from
/// `system_profiler SPDisplaysDataType` output.
///
/// Capability numbers are hard-coded per model family — TODO: replace
/// with real Metal device queries. For Apple Silicon, unified memory is
/// reported as device memory.
///
/// # Errors
///
/// Fails only if `system_profiler` cannot be spawned; a non-zero exit is
/// treated as "no discrete devices".
#[cfg(all(target_os = "macos", feature = "metal"))]
#[allow(dead_code)]
fn detect_macos_discrete_gpus() -> NdimageResult<Vec<DeviceCapability>> {
    use std::process::Command;
    let mut devices = Vec::new();
    let output = Command::new("system_profiler")
        .arg("SPDisplaysDataType")
        .arg("-xml")
        .output()
        .map_err(|e| {
            // Fixed message: previously said "systemprofiler".
            NdimageError::ComputationError(format!("Failed to run system_profiler: {}", e))
        })?;
    if !output.status.success() {
        return Ok(devices);
    }
    let output_str = String::from_utf8_lossy(&output.stdout);
    // Discrete AMD cards.
    if output_str.contains("Radeon") || output_str.contains("RX ") {
        let mut capability = DeviceCapability::default();
        if output_str.contains("RX 6800") || output_str.contains("RX 6900") {
            capability.name = "AMD Radeon RX 6000 Series (Metal)".to_string();
            capability.total_memory = 17_179_869_184; // 16 GiB
            capability.available_memory = 15_032_385_536; // 14 GiB
            capability.multiprocessor_count = Some(80);
            capability.clock_rate = Some(2_300_000);
        } else if output_str.contains("RX 5") {
            capability.name = "AMD Radeon RX 5000 Series (Metal)".to_string();
            capability.total_memory = 8_589_934_592; // 8 GiB
            capability.available_memory = 7_516_192_768; // 7 GiB
            capability.multiprocessor_count = Some(64);
            capability.clock_rate = Some(1_900_000);
        } else {
            capability.name = "AMD Discrete Graphics (Metal)".to_string();
            capability.total_memory = 4_294_967_296; // 4 GiB
            capability.available_memory = 3_758_096_384; // 3.5 GiB
            capability.multiprocessor_count = Some(32);
            capability.clock_rate = Some(1_500_000);
        }
        capability.max_threads_per_block = Some(1024);
        capability.max_block_dims = Some([1024, 1024, 1024]);
        capability.shared_memory_per_block = Some(65536);
        devices.push(capability);
    }
    // Apple Silicon GPUs. Most specific model names must be checked first,
    // since e.g. "M1 Max" also contains the substring "M1".
    if output_str.contains("Apple M") {
        let mut capability = DeviceCapability::default();
        if output_str.contains("M1 Ultra") {
            // BUGFIX: this previously matched "M1 Advanced", which is not a
            // real Apple product name — system_profiler reports "M1 Ultra",
            // so Ultra machines fell through to the generic M1 profile.
            capability.name = "Apple M1 Ultra GPU (Metal)".to_string();
            capability.total_memory = 137_438_953_472; // 128 GiB unified
            capability.available_memory = 120_259_084_288;
            capability.multiprocessor_count = Some(64);
            capability.clock_rate = Some(1_300_000);
        } else if output_str.contains("M1 Max") {
            capability.name = "Apple M1 Max GPU (Metal)".to_string();
            capability.total_memory = 68_719_476_736; // 64 GiB unified
            capability.available_memory = 60_129_542_144;
            capability.multiprocessor_count = Some(32);
            capability.clock_rate = Some(1_300_000);
        } else if output_str.contains("M1 Pro") {
            capability.name = "Apple M1 Pro GPU (Metal)".to_string();
            capability.total_memory = 34_359_738_368; // 32 GiB unified
            capability.available_memory = 30_064_771_072;
            capability.multiprocessor_count = Some(16);
            capability.clock_rate = Some(1_300_000);
        } else if output_str.contains("M1") {
            capability.name = "Apple M1 GPU (Metal)".to_string();
            capability.total_memory = 17_179_869_184; // 16 GiB unified
            capability.available_memory = 15_032_385_536;
            capability.multiprocessor_count = Some(8);
            capability.clock_rate = Some(1_300_000);
        } else if output_str.contains("M2") {
            capability.name = "Apple M2 GPU (Metal)".to_string();
            capability.total_memory = 25_769_803_776; // 24 GiB unified
            capability.available_memory = 22_548_578_304;
            capability.multiprocessor_count = Some(10);
            capability.clock_rate = Some(1_400_000);
        } else {
            capability.name = "Apple Silicon GPU (Metal)".to_string();
            capability.total_memory = 8_589_934_592; // 8 GiB unified
            capability.available_memory = 7_516_192_768;
            capability.multiprocessor_count = Some(8);
            capability.clock_rate = Some(1_200_000);
        }
        capability.max_threads_per_block = Some(1024);
        capability.max_block_dims = Some([1024, 1024, 1024]);
        capability.shared_memory_per_block = Some(32768);
        devices.push(capability);
    }
    Ok(devices)
}
/// Bookkeeping for per-device GPU memory usage and optional limits,
/// keyed by (backend, device index). This tracks byte counts only; it
/// does not perform any real allocation.
pub struct MemoryManager {
    /// Bytes currently recorded as allocated per (backend, device).
    memory_usage: HashMap<(super::Backend, usize), usize>,
    /// Optional allocation caps per (backend, device); absent = unlimited.
    memory_limits: HashMap<(super::Backend, usize), usize>,
}
impl MemoryManager {
    /// Create a tracker with no recorded usage and no limits (a device
    /// without an explicit limit is treated as unlimited).
    pub fn new() -> Self {
        Self {
            memory_usage: HashMap::new(),
            memory_limits: HashMap::new(),
        }
    }

    /// Whether `size` more bytes would fit under the device's limit.
    ///
    /// Overflow of `usage + size` is treated as "does not fit" via
    /// `checked_add` instead of panicking (debug) or wrapping (release),
    /// which previously could falsely report that a huge request fits.
    pub fn can_allocate(&self, backend: super::Backend, device_id: usize, size: usize) -> bool {
        let key = (backend, device_id);
        let used = self.memory_usage.get(&key).copied().unwrap_or(0);
        let limit = self.memory_limits.get(&key).copied().unwrap_or(usize::MAX);
        used.checked_add(size).map_or(false, |total| total <= limit)
    }

    /// Record an allocation of `size` bytes on the given device.
    ///
    /// # Errors
    ///
    /// Returns a `ComputationError` when the allocation would exceed the
    /// device's configured limit (or overflow the usage counter).
    pub fn allocate(
        &mut self,
        backend: super::Backend,
        device_id: usize,
        size: usize,
    ) -> NdimageResult<()> {
        if !self.can_allocate(backend, device_id, size) {
            return Err(NdimageError::ComputationError(
                "Insufficient GPU memory for allocation".into(),
            ));
        }
        *self.memory_usage.entry((backend, device_id)).or_insert(0) += size;
        Ok(())
    }

    /// Record a deallocation; usage saturates at zero on over-release and
    /// untracked devices are ignored.
    pub fn deallocate(&mut self, backend: super::Backend, device_id: usize, size: usize) {
        if let Some(usage) = self.memory_usage.get_mut(&(backend, device_id)) {
            *usage = usage.saturating_sub(size);
        }
    }

    /// Set the byte limit consulted by `can_allocate` for one device.
    pub fn set_memory_limit(&mut self, backend: super::Backend, device_id: usize, limit: usize) {
        self.memory_limits.insert((backend, device_id), limit);
    }

    /// Currently tracked usage for one device, in bytes (0 if untracked).
    pub fn get_memory_usage(&self, backend: super::Backend, device_id: usize) -> usize {
        self.memory_usage
            .get(&(backend, device_id))
            .copied()
            .unwrap_or(0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // DeviceCapability::default() should be an inert placeholder device
    // with the sentinel name and zero memory.
    #[test]
    fn test_device_capability_default() {
        let cap = DeviceCapability::default();
        assert_eq!(cap.name, "Unknown Device");
        assert_eq!(cap.total_memory, 0);
    }

    // Exercises allocate/deallocate/limit bookkeeping in MemoryManager
    // against the always-available CPU backend: usage accumulates on
    // allocate, shrinks on deallocate, and can_allocate respects the
    // configured limit (500 used + 2000 requested > 2000 limit).
    #[test]
    fn test_memory_manager() {
        let mut manager = MemoryManager::new();
        manager
            .allocate(super::super::Backend::Cpu, 0, 1000)
            .expect("Operation failed");
        assert_eq!(
            manager.get_memory_usage(super::super::Backend::Cpu, 0),
            1000
        );
        manager.deallocate(super::super::Backend::Cpu, 0, 500);
        assert_eq!(manager.get_memory_usage(super::super::Backend::Cpu, 0), 500);
        manager.set_memory_limit(super::super::Backend::Cpu, 0, 2000);
        assert!(manager.can_allocate(super::super::Backend::Cpu, 0, 1000));
        assert!(!manager.can_allocate(super::super::Backend::Cpu, 0, 2000));
    }
}