use super::DeviceType;
use crate::error::{RusTorchError, RusTorchResult};
use std::collections::HashMap;
use std::sync::Arc;
/// Uniform, object-safe interface over a compute device (the host CPU or a
/// GPU backend), so heterogeneous devices can be stored together as
/// `Arc<dyn GpuDevice>`.
///
/// Implementors must be `Send + Sync` so device handles can be shared
/// across threads.
pub trait GpuDevice: Send + Sync {
    // ---- identity ----------------------------------------------------------

    /// Numeric identifier reported by the device.
    ///
    /// Not unique across backends: the implementations in this module report
    /// `0` for the CPU, for CUDA ordinal 0 and for the Metal device alike.
    fn id(&self) -> usize;

    /// Human-readable device name.
    fn name(&self) -> String;

    /// Lower-case backend tag; the implementations in this module return
    /// `"cpu"`, `"cuda"`, `"metal"` or `"opencl"`.
    fn device_type(&self) -> String;

    /// Whether this device is the host CPU.
    fn is_cpu(&self) -> bool;

    // ---- status & capabilities ---------------------------------------------

    /// Whether the device is reported as usable (the CPU always returns
    /// `true`; Metal returns `true` only when compiled for macOS).
    fn is_available(&self) -> bool;

    /// `(major, minor)` compute capability, or `None` when the backend has no
    /// such notion.
    fn compute_capability(&self) -> Option<(u32, u32)>;

    // ---- memory -------------------------------------------------------------

    /// Total device memory (presumably bytes — implementations use multiples
    /// of 1024^3; confirm before relying on the unit).
    fn total_memory(&self) -> usize;

    /// Memory currently reported as allocated, in the same unit as
    /// [`total_memory`](Self::total_memory).
    fn allocated_memory(&self) -> usize;

    // ---- execution ----------------------------------------------------------

    /// Device-wide synchronisation point; every implementation in this module
    /// is currently a no-op.
    fn synchronize(&self);

    /// Creates a new execution stream bound to this device.
    fn create_stream(&self) -> Arc<dyn GpuStream>;
}

/// Execution stream handed out by [`GpuDevice::create_stream`].
pub trait GpuStream: Send + Sync {
    /// Stream identifier. The implementations in this module return a
    /// per-backend constant (CPU 0, CUDA 1, Metal 2, OpenCL 3) rather than a
    /// unique handle.
    fn id(&self) -> usize;

    /// Stream synchronisation point; currently a no-op in every
    /// implementation in this module.
    fn synchronize(&self);
}
/// Ordered collection of the compute devices available to this build.
///
/// The CPU device is always present at index 0. GPU devices follow, one entry
/// per device whose constructor succeeded, and only for backends whose Cargo
/// feature (`cuda`, `metal`, `opencl`) is enabled.
pub struct GpuBackend {
    devices: Vec<Arc<dyn GpuDevice>>,
}

impl GpuBackend {
    /// Number of device ordinals probed for each multi-device backend.
    #[cfg(any(feature = "cuda", feature = "opencl"))]
    const MAX_PROBED_DEVICES: usize = 8;

    /// Discovers devices and builds the backend.
    ///
    /// - CPU: always added first.
    /// - `cuda`: ordinals `0..8` are each passed to `CudaDevice::new`.
    /// - `metal`: a single `MetalDevice::new()`.
    /// - `opencl`: ordinals `0..8` on platform 0 via `OpenCLDevice::new`.
    ///
    /// Devices whose constructor returns `Err` are skipped.
    pub fn new() -> Self {
        // `mut` is only exercised when a GPU feature is enabled; in a
        // CPU-only build every push below is compiled out.
        #[allow(unused_mut)]
        let mut devices: Vec<Arc<dyn GpuDevice>> = vec![Arc::new(CpuDevice::new())];

        // Each ordinal is probed exactly once and kept on success, so every
        // device is constructed only once and a failing low ordinal cannot
        // hide a working higher one.
        #[cfg(feature = "cuda")]
        {
            for device_id in 0..Self::MAX_PROBED_DEVICES {
                if let Ok(device) = CudaDevice::new(device_id) {
                    devices.push(Arc::new(device));
                }
            }
        }

        #[cfg(feature = "metal")]
        {
            if let Ok(device) = MetalDevice::new() {
                devices.push(Arc::new(device));
            }
        }

        #[cfg(feature = "opencl")]
        {
            for device_id in 0..Self::MAX_PROBED_DEVICES {
                if let Ok(device) = OpenCLDevice::new(0, device_id) {
                    devices.push(Arc::new(device));
                }
            }
        }

        Self { devices }
    }

    /// All discovered devices, CPU first.
    pub fn list_devices(&self) -> &[Arc<dyn GpuDevice>] {
        &self.devices
    }

    /// Returns the device at position `id` in [`list_devices`](Self::list_devices).
    ///
    /// `id` is a list index, not [`GpuDevice::id`]: the CPU and the first
    /// CUDA device, for example, both report an id of `0`.
    pub fn get_device(&self, id: usize) -> Option<&Arc<dyn GpuDevice>> {
        self.devices.get(id)
    }
}

impl Default for GpuBackend {
    fn default() -> Self {
        Self::new()
    }
}
/// The host CPU exposed through the [`GpuDevice`] interface.
///
/// All reported values are fixed: the device is always available, reports no
/// allocations, and its memory size is a constant rather than an OS query.
#[derive(Debug, Clone)]
pub struct CpuDevice {
    id: usize,
}

impl CpuDevice {
    /// Creates the CPU device; its id is always `0`.
    pub fn new() -> Self {
        Self { id: 0 }
    }
}

impl Default for CpuDevice {
    fn default() -> Self {
        Self::new()
    }
}

impl GpuDevice for CpuDevice {
    fn id(&self) -> usize {
        self.id
    }

    fn name(&self) -> String {
        "CPU".to_string()
    }

    fn device_type(&self) -> String {
        "cpu".to_string()
    }

    fn is_available(&self) -> bool {
        true
    }

    /// Fixed placeholder of 8 * 1024^3, saturated to `usize::MAX` on targets
    /// whose `usize` cannot hold it.
    fn total_memory(&self) -> usize {
        // Computed in u64: as a plain `usize` expression, 8 * 1024^3 is a
        // compile-time overflow error on 32-bit targets (e.g. wasm32).
        const REPORTED_TOTAL: u64 = 8 * 1024 * 1024 * 1024;
        // Lossless: `usize::MAX` always fits in u64 on supported targets.
        REPORTED_TOTAL.min(usize::MAX as u64) as usize
    }

    /// Always `0`; CPU allocations are not tracked.
    fn allocated_memory(&self) -> usize {
        0
    }

    fn compute_capability(&self) -> Option<(u32, u32)> {
        None
    }

    fn is_cpu(&self) -> bool {
        true
    }

    /// No-op: there is no asynchronous CPU work to wait for.
    fn synchronize(&self) {}

    fn create_stream(&self) -> Arc<dyn GpuStream> {
        Arc::new(CpuStream::new())
    }
}

/// Stream handle for [`CpuDevice`]; synchronisation is a no-op and its id is
/// always `0`.
#[derive(Debug, Clone, Copy)]
pub struct CpuStream;

impl CpuStream {
    /// Creates a CPU stream handle.
    pub fn new() -> Self {
        Self
    }
}

impl Default for CpuStream {
    fn default() -> Self {
        Self::new()
    }
}

impl GpuStream for CpuStream {
    fn synchronize(&self) {}

    fn id(&self) -> usize {
        0
    }
}
/// Placeholder CUDA device descriptor.
///
/// No CUDA driver is queried: the name, memory size and compute capability
/// are synthesized in [`CudaDevice::new`].
#[cfg(feature = "cuda")]
#[derive(Debug)]
pub struct CudaDevice {
    device_id: usize,
    name: String,
    total_memory: usize,
    compute_capability: (u32, u32),
}

#[cfg(feature = "cuda")]
impl CudaDevice {
    /// Creates a descriptor for CUDA ordinal `device_id` with fixed
    /// placeholder properties (8 * 1024^3 memory, compute capability 7.5).
    ///
    /// Currently always returns `Ok`, for any ordinal.
    pub fn new(device_id: usize) -> RusTorchResult<Self> {
        Ok(Self {
            device_id,
            name: format!("CUDA Device {}", device_id),
            total_memory: 8 * 1024 * 1024 * 1024,
            compute_capability: (7, 5),
        })
    }
}

#[cfg(feature = "cuda")]
impl GpuDevice for CudaDevice {
    fn id(&self) -> usize {
        self.device_id
    }

    fn name(&self) -> String {
        self.name.clone()
    }

    fn device_type(&self) -> String {
        "cuda".to_string()
    }

    fn is_available(&self) -> bool {
        true
    }

    fn total_memory(&self) -> usize {
        self.total_memory
    }

    /// Fixed placeholder (1024^3); allocations are not tracked.
    fn allocated_memory(&self) -> usize {
        1024 * 1024 * 1024
    }

    fn compute_capability(&self) -> Option<(u32, u32)> {
        Some(self.compute_capability)
    }

    fn is_cpu(&self) -> bool {
        false
    }

    /// Currently a no-op.
    fn synchronize(&self) {}

    fn create_stream(&self) -> Arc<dyn GpuStream> {
        Arc::new(CudaStream::new())
    }
}

/// Stream handle for [`CudaDevice`]; synchronisation is currently a no-op.
///
/// `Default` is provided alongside `new()`, matching `CpuStream`.
#[cfg(feature = "cuda")]
#[derive(Debug, Default)]
pub struct CudaStream;

#[cfg(feature = "cuda")]
impl CudaStream {
    /// Creates a CUDA stream handle.
    pub fn new() -> Self {
        Self
    }
}

#[cfg(feature = "cuda")]
impl GpuStream for CudaStream {
    fn synchronize(&self) {}

    /// Constant backend tag (`1`), not a unique stream handle.
    fn id(&self) -> usize {
        1
    }
}
/// Placeholder Metal GPU descriptor.
///
/// Nothing is queried from Metal: the name and memory size are fixed in
/// [`MetalDevice::new`], and [`GpuDevice::is_available`] is decided at
/// compile time from the target OS.
#[cfg(feature = "metal")]
#[derive(Debug)]
pub struct MetalDevice {
    name: String,
    total_memory: usize,
}

#[cfg(feature = "metal")]
impl MetalDevice {
    /// Creates the Metal device descriptor (fixed name, 16 * 1024^3 memory).
    ///
    /// Currently always returns `Ok`.
    pub fn new() -> RusTorchResult<Self> {
        Ok(Self {
            name: "Apple M-Series GPU".to_string(),
            total_memory: 16 * 1024 * 1024 * 1024,
        })
    }
}

#[cfg(feature = "metal")]
impl GpuDevice for MetalDevice {
    /// Always `0`: only one Metal device is modelled. Note this collides with
    /// the CPU device's id.
    fn id(&self) -> usize {
        0
    }

    fn name(&self) -> String {
        self.name.clone()
    }

    fn device_type(&self) -> String {
        "metal".to_string()
    }

    /// `true` only when compiled for macOS.
    fn is_available(&self) -> bool {
        cfg!(target_os = "macos")
    }

    fn total_memory(&self) -> usize {
        self.total_memory
    }

    /// Fixed placeholder (2 * 1024^3); allocations are not tracked.
    fn allocated_memory(&self) -> usize {
        2 * 1024 * 1024 * 1024
    }

    fn compute_capability(&self) -> Option<(u32, u32)> {
        None
    }

    fn is_cpu(&self) -> bool {
        false
    }

    /// Currently a no-op.
    fn synchronize(&self) {}

    fn create_stream(&self) -> Arc<dyn GpuStream> {
        Arc::new(MetalStream::new())
    }
}

/// Stream handle for [`MetalDevice`]; synchronisation is currently a no-op.
///
/// `Default` is provided alongside `new()`, matching `CpuStream`.
#[cfg(feature = "metal")]
#[derive(Debug, Default)]
pub struct MetalStream;

#[cfg(feature = "metal")]
impl MetalStream {
    /// Creates a Metal stream handle.
    pub fn new() -> Self {
        Self
    }
}

#[cfg(feature = "metal")]
impl GpuStream for MetalStream {
    fn synchronize(&self) {}

    /// Constant backend tag (`2`), not a unique stream handle.
    fn id(&self) -> usize {
        2
    }
}
/// Placeholder OpenCL device descriptor identified by a (platform, device)
/// pair.
///
/// No OpenCL runtime is queried: the name and memory size are synthesized in
/// [`OpenCLDevice::new`].
#[cfg(feature = "opencl")]
#[derive(Debug)]
pub struct OpenCLDevice {
    platform_id: usize,
    device_id: usize,
    name: String,
    total_memory: usize,
}

#[cfg(feature = "opencl")]
impl OpenCLDevice {
    /// Creates a descriptor for device `device_id` on platform `platform_id`
    /// with a fixed placeholder memory size of 4 * 1024^3 (saturated to
    /// `usize::MAX` on targets whose `usize` cannot hold it).
    ///
    /// Currently always returns `Ok`.
    pub fn new(platform_id: usize, device_id: usize) -> RusTorchResult<Self> {
        // Computed in u64: as a plain `usize` expression, 4 * 1024^3 is a
        // compile-time overflow error on 32-bit targets.
        const PLACEHOLDER_TOTAL: u64 = 4 * 1024 * 1024 * 1024;
        Ok(Self {
            platform_id,
            device_id,
            name: format!("OpenCL Device {}:{}", platform_id, device_id),
            // Lossless: `usize::MAX` always fits in u64 on supported targets.
            total_memory: PLACEHOLDER_TOTAL.min(usize::MAX as u64) as usize,
        })
    }

    /// OpenCL platform index this device was created for.
    ///
    /// [`GpuDevice::id`] reports only the device index, so this accessor is
    /// the way to recover the platform; the field is otherwise never read.
    pub fn platform_id(&self) -> usize {
        self.platform_id
    }
}

#[cfg(feature = "opencl")]
impl GpuDevice for OpenCLDevice {
    /// Device index within its platform (the platform is not encoded).
    fn id(&self) -> usize {
        self.device_id
    }

    fn name(&self) -> String {
        self.name.clone()
    }

    fn device_type(&self) -> String {
        "opencl".to_string()
    }

    fn is_available(&self) -> bool {
        true
    }

    fn total_memory(&self) -> usize {
        self.total_memory
    }

    /// Fixed placeholder (512 * 1024^2); allocations are not tracked.
    fn allocated_memory(&self) -> usize {
        512 * 1024 * 1024
    }

    fn compute_capability(&self) -> Option<(u32, u32)> {
        None
    }

    fn is_cpu(&self) -> bool {
        false
    }

    /// Currently a no-op.
    fn synchronize(&self) {}

    fn create_stream(&self) -> Arc<dyn GpuStream> {
        Arc::new(OpenCLStream::new())
    }
}

/// Stream handle for [`OpenCLDevice`]; synchronisation is currently a no-op.
///
/// `Default` is provided alongside `new()`, matching `CpuStream`.
#[cfg(feature = "opencl")]
#[derive(Debug, Default)]
pub struct OpenCLStream;

#[cfg(feature = "opencl")]
impl OpenCLStream {
    /// Creates an OpenCL stream handle.
    pub fn new() -> Self {
        Self
    }
}

#[cfg(feature = "opencl")]
impl GpuStream for OpenCLStream {
    fn synchronize(&self) {}

    /// Constant backend tag (`3`), not a unique stream handle.
    fn id(&self) -> usize {
        3
    }
}
/// Static capability table describing a compute device.
///
/// Memory fields are presumably in bytes (callers in this module fill them
/// with multiples of 1024^3) — confirm before relying on the unit.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceCapabilities {
    /// Human-readable device name.
    pub name: String,
    /// Total device memory.
    pub total_memory: usize,
    /// Memory reported as free; may be set independently of `total_memory`.
    pub available_memory: usize,
    /// Major compute-capability version (0 when not applicable).
    pub compute_major: u32,
    /// Minor compute-capability version (0 when not applicable).
    pub compute_minor: u32,
    /// Maximum threads in a single block / threadgroup.
    pub max_threads_per_block: u32,
    /// Maximum block extent per dimension `[x, y, z]`.
    pub max_block_dims: [u32; 3],
    /// Maximum grid extent per dimension `[x, y, z]`.
    pub max_grid_dims: [u32; 3],
    /// Shared / threadgroup memory available to one block.
    pub shared_memory_per_block: u32,
    /// Threads executed in lock-step (warp / SIMD-group width).
    pub warp_size: u32,
    /// Whether 64-bit floating point is supported.
    pub supports_double: bool,
    /// Whether 16-bit floating point is supported.
    pub supports_half: bool,
    /// Whether tensor-core style matrix units are available.
    pub supports_tensor_cores: bool,
}

impl Default for DeviceCapabilities {
    /// Conservative single-threaded CPU profile: named `"CPU"`, zero memory
    /// reported, 1-thread blocks and grids, double precision only.
    fn default() -> Self {
        DeviceCapabilities {
            name: "CPU".to_string(),
            total_memory: 0,
            available_memory: 0,
            compute_major: 0,
            compute_minor: 0,
            max_threads_per_block: 1,
            max_block_dims: [1, 1, 1],
            max_grid_dims: [1, 1, 1],
            shared_memory_per_block: 0,
            warp_size: 1,
            supports_double: true,
            supports_half: false,
            supports_tensor_cores: false,
        }
    }
}
/// Static description of a compute device: its [`DeviceType`], a
/// [`DeviceCapabilities`] table, and whether it is usable on this host.
pub struct DeviceInfo {
    device_type: DeviceType,
    capabilities: DeviceCapabilities,
    is_available: bool,
}

impl DeviceInfo {
    /// Descriptor for the host CPU, built from
    /// [`DeviceCapabilities::default`]; always reported as available.
    pub fn cpu() -> Self {
        DeviceInfo {
            device_type: DeviceType::Cpu,
            capabilities: DeviceCapabilities::default(),
            is_available: true,
        }
    }

    /// Descriptor for CUDA device `_device_id`.
    ///
    /// The capability table is a fixed placeholder (compute capability 7.5,
    /// 8 * 1024^3 total / 7 * 1024^3 available memory); nothing is queried
    /// from a CUDA driver.
    ///
    /// # Errors
    ///
    /// Returns a GPU error when the crate is built without the `cuda` feature.
    pub fn cuda(_device_id: usize) -> RusTorchResult<Self> {
        #[cfg(feature = "cuda")]
        {
            let capabilities = DeviceCapabilities {
                name: format!("CUDA Device {}", _device_id),
                total_memory: 8 * 1024 * 1024 * 1024,
                available_memory: 7 * 1024 * 1024 * 1024,
                compute_major: 7,
                compute_minor: 5,
                max_threads_per_block: 1024,
                max_block_dims: [1024, 1024, 64],
                max_grid_dims: [2147483647, 65535, 65535],
                shared_memory_per_block: 49152,
                warp_size: 32,
                supports_double: true,
                supports_half: true,
                supports_tensor_cores: true,
            };
            Ok(DeviceInfo {
                device_type: DeviceType::Cuda(_device_id),
                capabilities,
                is_available: true,
            })
        }
        #[cfg(not(feature = "cuda"))]
        {
            Err(RusTorchError::gpu("CUDA not supported"))
        }
    }

    /// Descriptor for Metal device `_device_id`.
    ///
    /// The capability table is a fixed placeholder (16 * 1024^3 total /
    /// 14 * 1024^3 available memory, no double precision). The device is
    /// reported as available only when compiled for macOS, matching
    /// `MetalDevice::is_available`.
    ///
    /// # Errors
    ///
    /// Returns a GPU error when the crate is built without the `metal` feature.
    pub fn metal(_device_id: usize) -> RusTorchResult<Self> {
        #[cfg(feature = "metal")]
        {
            let capabilities = DeviceCapabilities {
                name: format!("Metal Device {}", _device_id),
                total_memory: 16 * 1024 * 1024 * 1024,
                available_memory: 14 * 1024 * 1024 * 1024,
                compute_major: 3,
                compute_minor: 0,
                max_threads_per_block: 1024,
                max_block_dims: [1024, 1024, 1024],
                max_grid_dims: [65535, 65535, 65535],
                shared_memory_per_block: 32768,
                warp_size: 32,
                supports_double: false,
                supports_half: true,
                supports_tensor_cores: false,
            };
            Ok(DeviceInfo {
                device_type: DeviceType::Metal(_device_id),
                capabilities,
                // Metal only exists on macOS; a Metal-feature build for
                // another OS must not advertise a usable device.
                is_available: cfg!(target_os = "macos"),
            })
        }
        #[cfg(not(feature = "metal"))]
        {
            Err(RusTorchError::gpu("Metal not supported"))
        }
    }

    /// Kind of device this descriptor refers to.
    pub fn device_type(&self) -> DeviceType {
        self.device_type
    }

    /// Capability table of the device.
    pub fn capabilities(&self) -> &DeviceCapabilities {
        &self.capabilities
    }

    /// Whether the device is usable on this host.
    pub fn is_available(&self) -> bool {
        self.is_available
    }

    /// Heuristic 1-D block (threadgroup) size for `problem_size` elements.
    ///
    /// CUDA uses 128 / 256 / 512 and Metal 64 / 128 / 256 threads for
    /// problems below 256, below 1024, and larger, respectively. OpenCL
    /// always gets 64, `Auto` 128, the CPU 1.
    pub fn optimal_block_size(&self, problem_size: usize) -> (u32, u32, u32) {
        match self.device_type {
            DeviceType::Cpu => (1, 1, 1),
            DeviceType::Cuda(_) => {
                let threads_per_block = if problem_size < 256 {
                    128
                } else if problem_size < 1024 {
                    256
                } else {
                    512
                };
                (threads_per_block, 1, 1)
            }
            DeviceType::Metal(_) => {
                let threads_per_group = if problem_size < 256 {
                    64
                } else if problem_size < 1024 {
                    128
                } else {
                    256
                };
                (threads_per_group, 1, 1)
            }
            DeviceType::OpenCL(_) => (64, 1, 1),
            #[cfg(feature = "coreml")]
            DeviceType::CoreML(_) => (1, 1, 1),
            DeviceType::Auto => (128, 1, 1),
            #[cfg(feature = "mac-hybrid")]
            DeviceType::MacHybrid => (256, 1, 1),
        }
    }

    /// 1-D grid size covering `problem_size` elements with `block_size`
    /// threads per block (at least one block).
    ///
    /// The CPU always gets `(1, 1, 1)`. A block with a zero dimension is
    /// treated as a single thread instead of dividing by zero. Block counts
    /// that do not fit in `u32` saturate at `u32::MAX` rather than wrapping.
    pub fn optimal_grid_size(
        &self,
        problem_size: usize,
        block_size: (u32, u32, u32),
    ) -> (u32, u32, u32) {
        if matches!(self.device_type, DeviceType::Cpu) {
            return (1, 1, 1);
        }
        // u64 with saturation: the product of three u32 dimensions can
        // overflow u32.
        let threads_per_block = u64::from(block_size.0)
            .saturating_mul(u64::from(block_size.1))
            .saturating_mul(u64::from(block_size.2))
            .max(1);
        // Lossless: `usize` is at most 64 bits on every supported target.
        let num_blocks = (problem_size as u64).div_ceil(threads_per_block).max(1);
        let num_blocks = num_blocks.min(u64::from(u32::MAX)) as u32;
        (num_blocks, 1, 1)
    }

    /// Whether the device can run `operation`.
    ///
    /// `"matmul"`, `"conv2d"`, `"batchnorm"`, `"activation"` and
    /// `"reduction"` are always supported. `"double_precision"`,
    /// `"half_precision"` and `"tensor_cores"` are answered from the
    /// capability table. Any other name returns `false`.
    pub fn supports_operation(&self, operation: &str) -> bool {
        match operation {
            "matmul" => true,
            "conv2d" => true,
            "batchnorm" => true,
            "activation" => true,
            "reduction" => true,
            "double_precision" => self.capabilities.supports_double,
            "half_precision" => self.capabilities.supports_half,
            "tensor_cores" => self.capabilities.supports_tensor_cores,
            _ => false,
        }
    }

    /// Returns `(total, available, usage_percent)`.
    ///
    /// `usage_percent` is `(total - available) / total * 100`, or `0.0` when
    /// `total` is zero. If `available` exceeds `total` (the public capability
    /// fields are not validated), usage is clamped to `0.0` instead of
    /// underflowing.
    pub fn memory_info(&self) -> (usize, usize, f32) {
        let total = self.capabilities.total_memory;
        let available = self.capabilities.available_memory;
        let usage_percent = if total > 0 {
            (total.saturating_sub(available) as f32 / total as f32) * 100.0
        } else {
            0.0
        };
        (total, available, usage_percent)
    }
}
/// Registry of [`DeviceInfo`] descriptors keyed by [`DeviceType`].
///
/// The CPU is always registered. GPU entries are added only for backends
/// compiled in via Cargo features.
pub struct DeviceRegistry {
    devices: HashMap<DeviceType, DeviceInfo>,
}

impl DeviceRegistry {
    /// Builds the registry.
    ///
    /// It registers the CPU, then the CUDA ordinals `0..get_cuda_device_count()`
    /// (feature `cuda`) and Metal device 0 (feature `metal`) whose descriptors
    /// construct successfully.
    pub fn new() -> Self {
        let mut registry = DeviceRegistry {
            devices: HashMap::new(),
        };
        registry.devices.insert(DeviceType::Cpu, DeviceInfo::cpu());
        #[cfg(feature = "cuda")]
        {
            for device_id in 0..Self::get_cuda_device_count() {
                if let Ok(device_info) = DeviceInfo::cuda(device_id) {
                    registry
                        .devices
                        .insert(DeviceType::Cuda(device_id), device_info);
                }
            }
        }
        #[cfg(feature = "metal")]
        {
            if let Ok(device_info) = DeviceInfo::metal(0) {
                registry.devices.insert(DeviceType::Metal(0), device_info);
            }
        }
        registry
    }

    /// Descriptor registered for `device_type`, if any.
    pub fn get_device(&self, device_type: DeviceType) -> Option<&DeviceInfo> {
        self.devices.get(&device_type)
    }

    /// All registered device types, in unspecified (hash-map) order.
    pub fn list_devices(&self) -> Vec<DeviceType> {
        self.devices.keys().copied().collect()
    }

    /// Picks the highest-scoring registered device for `operation`.
    ///
    /// Devices that are not available on this host, or that do not support
    /// `operation`, are skipped. Scoring works as follows:
    ///
    /// - Base score per backend: CPU 1, CUDA 10, Metal 8, OpenCL 6, CoreML 12,
    ///   MacHybrid 15, Auto 0.5.
    /// - When `data_size < 1000`, the CPU score is doubled and every other
    ///   score halved.
    /// - Available memory in units of 1024^3, capped at 10, adds up to +100%.
    ///
    /// Returns [`DeviceType::Cpu`] when no candidate scores above zero.
    ///
    /// NOTE(review): exact score ties (e.g. several identical CUDA devices)
    /// are resolved by hash-map iteration order, which is not deterministic.
    pub fn best_device_for_operation(&self, operation: &str, data_size: usize) -> DeviceType {
        let mut best_device = DeviceType::Cpu;
        let mut best_score = 0.0f32;
        for (device_type, device_info) in &self.devices {
            // A registered-but-unusable device (e.g. Metal built for a
            // non-macOS target) must never be selected.
            if !device_info.is_available() || !device_info.supports_operation(operation) {
                continue;
            }
            let mut score = match device_type {
                DeviceType::Cpu => 1.0,
                DeviceType::Cuda(_) => 10.0,
                DeviceType::Metal(_) => 8.0,
                DeviceType::OpenCL(_) => 6.0,
                #[cfg(feature = "coreml")]
                DeviceType::CoreML(_) => 12.0,
                DeviceType::Auto => 0.5,
                #[cfg(feature = "mac-hybrid")]
                DeviceType::MacHybrid => 15.0,
            };
            // Small workloads are dominated by launch/transfer overhead, so
            // favour the CPU.
            if data_size < 1000 {
                if matches!(device_type, DeviceType::Cpu) {
                    score *= 2.0;
                } else {
                    score *= 0.5;
                }
            }
            let (_, available_memory, _) = device_info.memory_info();
            if available_memory > 0 {
                let memory_score = (available_memory as f32 / (1024.0 * 1024.0 * 1024.0)).min(10.0);
                score *= 1.0 + memory_score * 0.1;
            }
            if score > best_score {
                best_score = score;
                best_device = *device_type;
            }
        }
        best_device
    }

    /// Number of CUDA devices to register.
    ///
    /// Placeholder: always `0`, so no CUDA entries are registered even with
    /// the `cuda` feature enabled.
    #[cfg(feature = "cuda")]
    fn get_cuda_device_count() -> usize {
        0
    }
}

impl Default for DeviceRegistry {
    fn default() -> Self {
        Self::new()
    }
}
// Unit tests for the device descriptors and the device registry. Expectations
// that depend on which GPU backends are compiled in are gated with `cfg`.
#[cfg(test)]
mod tests {
    use super::*;
    // The default capability table is the CPU profile: double precision only.
    #[test]
    fn test_device_capabilities_default() {
        let caps = DeviceCapabilities::default();
        assert_eq!(caps.name, "CPU");
        assert!(caps.supports_double);
        assert!(!caps.supports_half);
    }
    // The CPU descriptor is always available and supports the core ops.
    #[test]
    fn test_device_info_cpu() {
        let device = DeviceInfo::cpu();
        assert_eq!(device.device_type(), DeviceType::Cpu);
        assert!(device.is_available());
        assert!(device.supports_operation("matmul"));
    }
    // The CPU always gets a 1x1x1 block regardless of problem size.
    #[test]
    fn test_optimal_block_size() {
        let device = DeviceInfo::cpu();
        let block_size = device.optimal_block_size(1000);
        assert_eq!(block_size, (1, 1, 1));
    }
    // The registry always contains the CPU entry.
    #[test]
    fn test_device_registry() {
        let registry = DeviceRegistry::new();
        assert!(registry.get_device(DeviceType::Cpu).is_some());
        let devices = registry.list_devices();
        assert!(!devices.is_empty());
        assert!(devices.contains(&DeviceType::Cpu));
    }
    // Device selection for small and large workloads. With the `cuda` feature,
    // `DeviceRegistry::get_cuda_device_count` returns 0, so no CUDA entries are
    // registered; that is why non-Metal configurations expect the CPU.
    #[test]
    fn test_best_device_selection() {
        let registry = DeviceRegistry::new();
        let devices = registry.list_devices();
        println!("Registered devices: {:?}", devices);
        // Small workload (100 elements): below the 1000-element threshold the
        // CPU score is doubled and GPU scores halved, but Metal's memory bonus
        // may still let it win, so either is accepted on macOS + Metal.
        let device = registry.best_device_for_operation("matmul", 100);
        println!("Selected device for small operation: {:?}", device);
        #[cfg(all(target_os = "macos", feature = "metal"))]
        {
            if devices.contains(&DeviceType::Metal(0)) {
                assert!(device == DeviceType::Cpu || device == DeviceType::Metal(0));
            } else {
                assert_eq!(device, DeviceType::Cpu);
            }
        }
        #[cfg(not(all(target_os = "macos", feature = "metal")))]
        {
            assert_eq!(device, DeviceType::Cpu);
        }
        // Large workload (1,000,000 elements): Metal must win when it is
        // registered on macOS; everywhere else the CPU is expected.
        let device = registry.best_device_for_operation("matmul", 1000000);
        println!("Selected device for large operation: {:?}", device);
        #[cfg(all(target_os = "macos", feature = "metal"))]
        {
            if devices.contains(&DeviceType::Metal(0)) {
                assert_eq!(device, DeviceType::Metal(0));
            } else {
                assert_eq!(device, DeviceType::Cpu);
            }
        }
        #[cfg(not(all(target_os = "macos", feature = "metal")))]
        {
            assert_eq!(device, DeviceType::Cpu);
        }
    }
    // The CPU capability table reports zero total memory, so `memory_info`
    // takes the zero-total branch and reports 0.0% usage.
    #[test]
    fn test_memory_info() {
        let device = DeviceInfo::cpu();
        let (total, available, usage) = device.memory_info();
        assert_eq!(total, 0);
        assert_eq!(available, 0);
        assert_eq!(usage, 0.0);
    }
}