use torsh_core::device::DeviceType as CoreDeviceType;
#[cfg(not(feature = "std"))]
use alloc::{string::String, vec::Vec};
/// A concrete compute device known to the backend layer.
///
/// Pairs a core (`torsh_core`) device type with a backend-local id, a
/// human-readable name, and a [`DeviceInfo`] capability record.
#[derive(Debug, Clone, PartialEq)]
pub struct Device {
/// Backend-local device id (`DeviceBuilder` defaults it to 0).
pub id: usize,
/// Core device type. GPU variants carry a payload — presumably the device
/// index; TODO confirm against `torsh_core::device::DeviceType`.
pub device_type: CoreDeviceType,
/// Human-readable device name.
pub name: String,
/// Capabilities and performance characteristics of the device.
pub info: DeviceInfo,
}
impl Device {
pub fn new(id: usize, device_type: CoreDeviceType, name: String, info: DeviceInfo) -> Self {
Self {
id,
device_type,
name,
info,
}
}
pub fn builder() -> DeviceBuilder {
DeviceBuilder::new()
}
pub const fn id(&self) -> usize {
self.id
}
pub const fn device_type(&self) -> CoreDeviceType {
self.device_type
}
pub fn name(&self) -> &str {
&self.name
}
pub fn info(&self) -> &DeviceInfo {
&self.info
}
pub fn supports_feature(&self, feature: DeviceFeature) -> bool {
self.info.features.contains(&feature)
}
pub fn cpu() -> crate::BackendResult<Self> {
DeviceBuilder::new()
.with_device_type(CoreDeviceType::Cpu)
.with_name("CPU".to_string())
.with_vendor("Generic".to_string())
.with_compute_units(num_cpus::get())
.build()
}
}
/// Capability and performance description of a [`Device`].
///
/// Usually populated through [`DeviceBuilder`] setters on top of
/// [`DeviceInfo::default`].
#[derive(Debug, Clone, PartialEq)]
pub struct DeviceInfo {
/// Vendor name ("Unknown" by default, "Generic" for `Device::cpu`).
pub vendor: String,
/// Driver version string ("Unknown" by default).
pub driver_version: String,
/// Total device memory. NOTE(review): presumably bytes — confirm with backends.
pub total_memory: usize,
/// Currently available device memory, in the same unit as `total_memory`.
pub available_memory: usize,
/// Number of compute units (for `Device::cpu`, the `num_cpus::get()` count).
pub compute_units: usize,
/// Maximum work-group size (presumably total invocations per group — confirm).
pub max_work_group_size: usize,
/// Maximum work-group extent, one entry per dimension (default `[256, 1, 1]`).
pub max_work_group_dimensions: Vec<usize>,
/// Clock frequency in MHz (per the field name).
pub clock_frequency_mhz: u32,
/// Memory bandwidth in GB/s (per the field name).
pub memory_bandwidth_gbps: f32,
/// Peak compute throughput in GFLOPS (per the field name).
pub peak_gflops: f32,
/// Optional features the device advertises; queried by `Device::supports_feature`.
pub features: Vec<DeviceFeature>,
/// Free-form key/value properties, kept in insertion order.
pub properties: Vec<(String, String)>,
}
impl Default for DeviceInfo {
fn default() -> Self {
Self {
vendor: "Unknown".to_string(),
driver_version: "Unknown".to_string(),
total_memory: 0,
available_memory: 0,
compute_units: 1,
max_work_group_size: 256,
max_work_group_dimensions: vec![256, 1, 1],
clock_frequency_mhz: 1000,
memory_bandwidth_gbps: 10.0,
peak_gflops: 100.0,
features: Vec::new(),
properties: Vec::new(),
}
}
}
/// An optional capability a device may advertise in `DeviceInfo::features`.
///
/// Many variant names mirror WebGPU/wgpu feature flags (e.g. `TimestampQuery`,
/// `MappableBuffers`, `ShaderF16`). NOTE(review): presumably the WebGPU backend
/// maps them from wgpu — confirm; the mapping is not in this file.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DeviceFeature {
DoublePrecision,
HalfPrecision,
UnifiedMemory,
AtomicOperations,
SubGroups,
Printf,
Profiling,
PeerToPeer,
ConcurrentExecution,
AsyncMemory,
ImageSupport,
FastMath,
TimestampQuery,
TimestampQueryInsideEncoders,
PipelineStatistics,
MappableBuffers,
BufferArrays,
StorageArrays,
UnsizedBindingArray,
IndirectFirstInstance,
ShaderF16,
ShaderI16,
ShaderPrimitiveIndex,
ShaderEarlyDepthTest,
MultiDrawIndirect,
MultiDrawIndirectCount,
Multisampling,
ClearTexture,
SpirvShaderPassthrough,
/// Backend-defined feature identified by name. Two `Custom` values are equal
/// exactly when their names are equal (derived `PartialEq`).
Custom(String),
}
/// Incremental builder for [`Device`].
///
/// `device_type` and `name` are mandatory; `build` fails if either is unset.
/// All `DeviceInfo` fields start from `DeviceInfo::default()`.
#[derive(Debug, Clone)]
pub struct DeviceBuilder {
// Backend-local id; 0 unless overridden.
id: usize,
// Required; `None` until `with_device_type` is called.
device_type: Option<CoreDeviceType>,
// Required; `None` until `with_name` is called.
name: Option<String>,
// Capability record being assembled.
info: DeviceInfo,
}
impl DeviceBuilder {
pub fn new() -> Self {
Self {
id: 0,
device_type: None,
name: None,
info: DeviceInfo::default(),
}
}
pub fn with_id(mut self, id: usize) -> Self {
self.id = id;
self
}
pub fn with_device_type(mut self, device_type: CoreDeviceType) -> Self {
self.device_type = Some(device_type);
self
}
pub fn with_name(mut self, name: String) -> Self {
self.name = Some(name);
self
}
pub fn with_vendor(mut self, vendor: String) -> Self {
self.info.vendor = vendor;
self
}
pub fn with_driver_version(mut self, version: String) -> Self {
self.info.driver_version = version;
self
}
pub fn with_memory(mut self, total: usize, available: usize) -> Self {
self.info.total_memory = total;
self.info.available_memory = available;
self
}
pub fn with_compute_units(mut self, units: usize) -> Self {
self.info.compute_units = units;
self
}
pub fn with_performance(mut self, gflops: f32, bandwidth_gbps: f32) -> Self {
self.info.peak_gflops = gflops;
self.info.memory_bandwidth_gbps = bandwidth_gbps;
self
}
pub fn with_feature(mut self, feature: DeviceFeature) -> Self {
self.info.features.push(feature);
self
}
pub fn with_property(mut self, key: String, value: String) -> Self {
self.info.properties.push((key, value));
self
}
pub fn build(self) -> crate::BackendResult<Device> {
let device_type = self.device_type.ok_or_else(|| {
torsh_core::error::TorshError::BackendError("Device type is required".to_string())
})?;
let name = self.name.ok_or_else(|| {
torsh_core::error::TorshError::BackendError("Device name is required".to_string())
})?;
Ok(Device {
id: self.id,
device_type,
name,
info: self.info,
})
}
}
impl Default for DeviceBuilder {
fn default() -> Self {
Self::new()
}
}
/// Backend-family classification of a device, without any per-device payload.
///
/// Unlike the core device type, this also names `OpenCl`, `Vulkan` and `Custom`.
/// Those three have no core counterpart and convert to `CoreDeviceType::Cpu`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
Cpu,
Cuda,
Metal,
WebGpu,
OpenCl,
Vulkan,
Custom,
}
impl From<CoreDeviceType> for DeviceType {
    /// Collapses a core device type to its backend family.
    ///
    /// The payload carried by the GPU variants is discarded.
    fn from(core_type: CoreDeviceType) -> Self {
        match core_type {
            CoreDeviceType::Cuda(_) => Self::Cuda,
            CoreDeviceType::Metal(_) => Self::Metal,
            CoreDeviceType::Wgpu(_) => Self::WebGpu,
            CoreDeviceType::Cpu => Self::Cpu,
        }
    }
}
impl From<DeviceType> for CoreDeviceType {
    /// Maps a backend family back to a core device type.
    ///
    /// GPU families always get payload `0`. `OpenCl`, `Vulkan` and `Custom` have
    /// no core variant and fall back to `Cpu`. The conversion is therefore lossy
    /// and is not an inverse of `From<CoreDeviceType> for DeviceType`.
    fn from(device_type: DeviceType) -> Self {
        match device_type {
            DeviceType::Cuda => Self::Cuda(0),
            DeviceType::Metal => Self::Metal(0),
            DeviceType::WebGpu => Self::Wgpu(0),
            DeviceType::Cpu
            | DeviceType::OpenCl
            | DeviceType::Vulkan
            | DeviceType::Custom => Self::Cpu,
        }
    }
}
/// Declarative filter over [`Device`]s; every criterion that is set must hold.
///
/// The default selector has no criteria. `Debug`/`Clone` are not derived because
/// `custom_filter` holds a boxed closure.
#[derive(Default)]
pub struct DeviceSelector {
/// Required backend family, if any.
pub device_type: Option<DeviceType>,
/// Minimum `DeviceInfo::total_memory`, if any.
pub min_memory: Option<usize>,
/// Minimum `DeviceInfo::compute_units`, if any.
pub min_compute_units: Option<usize>,
/// Features that must all appear in `DeviceInfo::features`.
pub required_features: Vec<DeviceFeature>,
/// Despite the name, treated as a hard requirement: exact, case-sensitive
/// match against `DeviceInfo::vendor`.
pub preferred_vendor: Option<String>,
/// Optional predicate, checked after all other criteria.
// The boxed-closure type trips clippy's type_complexity lint; it is spelled out
// inline rather than via a type alias.
#[allow(clippy::type_complexity)]
pub custom_filter: Option<Box<dyn Fn(&Device) -> bool + Send + Sync>>,
}
impl DeviceSelector {
pub fn new() -> Self {
Self::default()
}
pub fn with_device_type(mut self, device_type: DeviceType) -> Self {
self.device_type = Some(device_type);
self
}
pub fn with_min_memory(mut self, min_memory: usize) -> Self {
self.min_memory = Some(min_memory);
self
}
pub fn with_min_compute_units(mut self, min_compute_units: usize) -> Self {
self.min_compute_units = Some(min_compute_units);
self
}
pub fn with_feature(mut self, feature: DeviceFeature) -> Self {
self.required_features.push(feature);
self
}
pub fn with_vendor(mut self, vendor: String) -> Self {
self.preferred_vendor = Some(vendor);
self
}
pub fn matches(&self, device: &Device) -> bool {
if let Some(required_type) = &self.device_type {
if device.device_type != (*required_type).into() {
return false;
}
}
if let Some(min_memory) = self.min_memory {
if device.info.total_memory < min_memory {
return false;
}
}
if let Some(min_compute_units) = self.min_compute_units {
if device.info.compute_units < min_compute_units {
return false;
}
}
for feature in &self.required_features {
if !device.supports_feature(feature.clone()) {
return false;
}
}
if let Some(ref preferred_vendor) = self.preferred_vendor {
if device.info.vendor != *preferred_vendor {
return false;
}
}
if let Some(ref filter) = self.custom_filter {
if !filter(device) {
return false;
}
}
true
}
}
/// Per-backend service for enumerating and inspecting devices.
///
/// Implementors must be `Send + Sync` (supertrait bounds). Device ids are plain
/// `usize` indices. NOTE(review): presumably the same ids as `Device::id` — confirm.
pub trait DeviceManager: Send + Sync {
/// Lists the devices this manager can see.
fn enumerate_devices(&self) -> crate::BackendResult<Vec<Device>>;
/// Returns the capability record for `device_id`.
/// NOTE(review): behavior for unknown ids is implementation-defined here — confirm.
fn get_device_info(&self, device_id: usize) -> crate::BackendResult<DeviceInfo>;
/// Reports support for each entry of `features`.
/// Presumably the result is parallel to `features` (one `bool` per entry) — confirm.
fn check_device_features(
&self,
device_id: usize,
features: &[DeviceFeature],
) -> crate::BackendResult<Vec<bool>>;
/// Returns recommended launch/allocation settings for `device_id`.
fn get_optimal_device_config(
&self,
device_id: usize,
) -> crate::BackendResult<DeviceConfiguration>;
/// Checks whether `device_id` is usable.
/// NOTE(review): the split between `Ok(false)` and `Err` is not specified here.
fn validate_device(&self, device_id: usize) -> crate::BackendResult<bool>;
/// Returns bandwidth/throughput/latency and optional thermal/power data.
fn get_performance_info(&self, device_id: usize)
-> crate::BackendResult<DevicePerformanceInfo>;
}
/// Recommended settings for running work on a device, as returned by
/// `DeviceManager::get_optimal_device_config`.
#[derive(Debug, Clone)]
pub struct DeviceConfiguration {
/// Preferred allocation size. NOTE(review): presumably bytes — confirm.
pub optimal_allocation_size: usize,
/// Suggested (x, y, z) work-group dimensions; same shape as
/// `DeviceUtils::get_optimal_workgroup_size` returns.
pub workgroup_size: (u32, u32, u32),
/// Memory alignment. NOTE(review): presumably bytes — confirm.
pub memory_alignment: usize,
/// Maximum number of operations to keep in flight concurrently.
pub max_concurrent_operations: u32,
/// Free-form backend-specific settings keyed by name.
// NOTE(review): this uses `std::collections::HashMap`, so the type requires `std`
// even though the file imports `alloc` types under `not(feature = "std")`.
pub backend_specific: std::collections::HashMap<String, crate::backend::CapabilityValue>,
}
/// Performance characteristics of a device, as returned by
/// `DeviceManager::get_performance_info`.
#[derive(Debug, Clone)]
pub struct DevicePerformanceInfo {
/// Memory bandwidth in GB/s (per the field name).
pub memory_bandwidth_gbps: f32,
/// Compute throughput in GFLOPS (per the field name).
pub compute_throughput_gflops: f32,
/// Memory latency in nanoseconds (per the field name).
pub memory_latency_ns: f32,
/// Cache levels. NOTE(review): presumably ordered L1 first — confirm.
pub cache_hierarchy: Vec<CacheLevel>,
/// Thermal data, when the backend can report it.
pub thermal_info: Option<ThermalInfo>,
/// Power data, when the backend can report it.
pub power_info: Option<PowerInfo>,
}
/// Description of one level of a device's cache hierarchy.
#[derive(Debug, Clone)]
pub struct CacheLevel {
/// Cache level number (e.g. 1 for L1).
pub level: u8,
/// Capacity in bytes (per the field name).
pub size_bytes: usize,
/// Cache line size in bytes (per the field name).
pub line_size_bytes: usize,
/// Set associativity. NOTE(review): presumably `None` when unknown or fully associative — confirm.
pub associativity: Option<usize>,
}
/// Thermal state of a device.
#[derive(Debug, Clone)]
pub struct ThermalInfo {
/// Current temperature in degrees Celsius.
pub current_temperature_celsius: f32,
/// Maximum rated temperature in degrees Celsius.
pub max_temperature_celsius: f32,
/// Whether the device currently reports thermal throttling.
pub thermal_throttling_active: bool,
}
/// Power state of a device.
#[derive(Debug, Clone)]
pub struct PowerInfo {
/// Current power draw in watts.
pub current_power_watts: f32,
/// Maximum power draw in watts.
pub max_power_watts: f32,
/// Configured power limit in watts.
pub power_limit_watts: f32,
}
/// Stateless namespace for device scoring, requirement checks, and launch-size heuristics.
pub struct DeviceUtils;
impl DeviceUtils {
    /// Returns `true` when `device_id` is a valid index into `max_devices` devices.
    pub const fn validate_device_id(device_id: usize, max_devices: usize) -> bool {
        device_id < max_devices
    }

    /// Scores how well `device` fits `requirements`; higher is better.
    ///
    /// Returns `0.0` as soon as a hard requirement fails: minimum memory, minimum
    /// compute units, or any required feature. `DeviceDiscovery::find_best_device`
    /// treats `0.0` as unsuitable. Otherwise the score is the sum of:
    ///
    /// * `20.0` plus `5.0 * (total_memory / min_memory - 1)` when `min_memory` is set;
    /// * `15.0` plus `3.0 * (compute_units / min_compute_units - 1)` when
    ///   `min_compute_units` is set;
    /// * `10.0` per required feature;
    /// * `peak_gflops / 1000` and `memory_bandwidth_gbps / 100`;
    /// * a family bonus: CUDA 15, Metal 10, WebGPU 5, CPU 1, others 0.
    ///
    /// A minimum of zero is trivially satisfied and earns only the base bonus. The
    /// headroom ratio is undefined there: it would be `inf`, or `NaN` when the actual
    /// value is also zero, and would poison the whole score.
    ///
    /// `preferred_backend`, `max_power_consumption` and `max_temperature` are not
    /// considered here; see [`Self::meets_requirements`] for the backend check.
    pub fn calculate_device_score(device: &Device, requirements: &DeviceRequirements) -> f32 {
        let info = &device.info;
        let mut score = 0.0_f32;

        if let Some(min_memory) = requirements.min_memory {
            if info.total_memory < min_memory {
                return 0.0;
            }
            score += 20.0;
            score += Self::headroom_bonus(info.total_memory, min_memory, 5.0);
        }

        if let Some(min_compute_units) = requirements.min_compute_units {
            if info.compute_units < min_compute_units {
                return 0.0;
            }
            score += 15.0;
            score += Self::headroom_bonus(info.compute_units, min_compute_units, 3.0);
        }

        for required_feature in &requirements.required_features {
            if !info.features.contains(required_feature) {
                return 0.0;
            }
            score += 10.0;
        }

        // Raw throughput, scaled down relative to the requirement bonuses above.
        score += info.peak_gflops / 1000.0;
        score += info.memory_bandwidth_gbps / 100.0;

        // Exhaustive on purpose: adding a `DeviceType` variant must revisit the bonus table.
        score += match DeviceType::from(device.device_type) {
            DeviceType::Cuda => 15.0,
            DeviceType::Metal => 10.0,
            DeviceType::WebGpu => 5.0,
            DeviceType::Cpu => 1.0,
            DeviceType::OpenCl | DeviceType::Vulkan | DeviceType::Custom => 0.0,
        };

        score
    }

    /// Bonus for exceeding a minimum: `(actual / minimum - 1) * weight`.
    ///
    /// Returns `0.0` when `minimum` is zero, where the ratio is undefined.
    fn headroom_bonus(actual: usize, minimum: usize, weight: f32) -> f32 {
        if minimum == 0 {
            return 0.0;
        }
        (actual as f32 / minimum as f32 - 1.0) * weight
    }

    /// Returns `true` when `device` satisfies every hard requirement.
    ///
    /// The checks are minimum memory, minimum compute units, and all required
    /// features. When `preferred_backend` is set, the device's backend family must
    /// also equal it; this is a hard filter here, unlike in
    /// [`Self::calculate_device_score`]. `max_power_consumption` and
    /// `max_temperature` are not checked.
    pub fn meets_requirements(device: &Device, requirements: &DeviceRequirements) -> bool {
        if let Some(min_memory) = requirements.min_memory {
            if device.info.total_memory < min_memory {
                return false;
            }
        }
        if let Some(min_compute_units) = requirements.min_compute_units {
            if device.info.compute_units < min_compute_units {
                return false;
            }
        }
        for required_feature in &requirements.required_features {
            if !device.info.features.contains(required_feature) {
                return false;
            }
        }
        if let Some(preferred_backend) = requirements.preferred_backend {
            let device_backend = match DeviceType::from(device.device_type) {
                DeviceType::Cpu => crate::backend::BackendType::Cpu,
                DeviceType::Cuda => crate::backend::BackendType::Cuda,
                DeviceType::Metal => crate::backend::BackendType::Metal,
                DeviceType::WebGpu => crate::backend::BackendType::WebGpu,
                // No `BackendType` counterpart, so a backend preference can never match.
                DeviceType::OpenCl | DeviceType::Vulkan | DeviceType::Custom => return false,
            };
            if device_backend != preferred_backend {
                return false;
            }
        }
        true
    }

    /// Heuristic (x, y, z) work-group size for `operation_type` on `device`.
    ///
    /// Recognised operation names are `"matrix_mul"`, `"element_wise"` and
    /// `"reduction"`; any other name gets the per-family default. CPU and other
    /// non-GPU families always get `(1, 1, 1)`.
    pub fn get_optimal_workgroup_size(device: &Device, operation_type: &str) -> (u32, u32, u32) {
        match DeviceType::from(device.device_type) {
            DeviceType::Cuda => match operation_type {
                "matrix_mul" => (16, 16, 1),
                "element_wise" => (256, 1, 1),
                "reduction" => (512, 1, 1),
                _ => (32, 32, 1),
            },
            DeviceType::Metal => match operation_type {
                "matrix_mul" => (16, 16, 1),
                "element_wise" => (256, 1, 1),
                "reduction" => (256, 1, 1),
                _ => (32, 32, 1),
            },
            DeviceType::WebGpu => match operation_type {
                "matrix_mul" => (8, 8, 1),
                "element_wise" => (64, 1, 1),
                "reduction" => (64, 1, 1),
                _ => (8, 8, 1),
            },
            DeviceType::Cpu | DeviceType::OpenCl | DeviceType::Vulkan | DeviceType::Custom => {
                (1, 1, 1)
            }
        }
    }
}
/// Stateless namespace for enumerating devices across the backends compiled into this build.
pub struct DeviceDiscovery;
impl DeviceDiscovery {
    /// Enumerates devices for every backend compiled into this build.
    ///
    /// The CPU entry is included whenever CPU discovery succeeds. CUDA, Metal
    /// (macOS only) and WebGPU entries are included only when their cargo feature
    /// is enabled, discovery succeeds, and at least one device is found. Per-backend
    /// discovery errors are ignored, so this currently always returns `Ok`.
    pub fn discover_all() -> crate::BackendResult<Vec<(crate::backend::BackendType, Vec<Device>)>> {
        let mut discovered = Vec::new();

        if let Ok(devices) = Self::discover_cpu_devices() {
            discovered.push((crate::backend::BackendType::Cpu, devices));
        }

        #[cfg(feature = "cuda")]
        if let Some(devices) = Self::discover_cuda_devices()
            .ok()
            .filter(|found| !found.is_empty())
        {
            discovered.push((crate::backend::BackendType::Cuda, devices));
        }

        #[cfg(all(feature = "metal", target_os = "macos"))]
        if let Some(devices) = Self::discover_metal_devices()
            .ok()
            .filter(|found| !found.is_empty())
        {
            discovered.push((crate::backend::BackendType::Metal, devices));
        }

        #[cfg(feature = "webgpu")]
        if let Some(devices) = Self::discover_webgpu_devices()
            .ok()
            .filter(|found| !found.is_empty())
        {
            discovered.push((crate::backend::BackendType::WebGpu, devices));
        }

        Ok(discovered)
    }

    /// Returns the highest-scoring discovered device together with its backend.
    ///
    /// Devices are ranked by [`DeviceUtils::calculate_device_score`]. Only scores
    /// strictly above `0.0` qualify, and on ties the first device discovered wins.
    ///
    /// # Errors
    /// Returns a `BackendError` when no discovered device scores above zero, or
    /// propagates a failure from [`Self::discover_all`].
    pub fn find_best_device(
        requirements: &DeviceRequirements,
    ) -> crate::BackendResult<(crate::backend::BackendType, Device)> {
        let mut winner: Option<(f32, crate::backend::BackendType, Device)> = None;

        for (backend_type, devices) in Self::discover_all()? {
            for device in devices {
                let score = Self::score_device(&device, requirements);
                let bar = winner.as_ref().map_or(0.0, |(best, _, _)| *best);
                if score > bar {
                    winner = Some((score, backend_type, device));
                }
            }
        }

        winner
            .map(|(_, backend_type, device)| (backend_type, device))
            .ok_or_else(|| {
                torsh_core::error::TorshError::BackendError(
                    "No suitable device found for requirements".to_string(),
                )
            })
    }

    /// Ranking function used by [`Self::find_best_device`].
    fn score_device(device: &Device, requirements: &DeviceRequirements) -> f32 {
        DeviceUtils::calculate_device_score(device, requirements)
    }

    /// Describes the host CPU as a single device (id 0, `num_cpus::get()` units).
    fn discover_cpu_devices() -> crate::BackendResult<Vec<Device>> {
        let host_cpu = crate::cpu::CpuDevice::new(0, num_cpus::get())?;
        Ok(vec![host_cpu.to_device()])
    }

    /// Placeholder: CUDA enumeration is not implemented yet and reports no devices.
    #[cfg(feature = "cuda")]
    fn discover_cuda_devices() -> crate::BackendResult<Vec<Device>> {
        Ok(Vec::new())
    }

    /// Placeholder: Metal enumeration is not implemented yet and reports no devices.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    fn discover_metal_devices() -> crate::BackendResult<Vec<Device>> {
        Ok(Vec::new())
    }

    /// Placeholder: WebGPU enumeration is not implemented yet and reports no devices.
    #[cfg(feature = "webgpu")]
    fn discover_webgpu_devices() -> crate::BackendResult<Vec<Device>> {
        Ok(Vec::new())
    }
}
/// Constraints used to rank and filter devices (see `DeviceUtils` and
/// `DeviceDiscovery::find_best_device`). All fields default to "no constraint".
#[derive(Debug, Clone, Default)]
pub struct DeviceRequirements {
/// Minimum `DeviceInfo::total_memory`.
pub min_memory: Option<usize>,
/// Minimum `DeviceInfo::compute_units`.
pub min_compute_units: Option<usize>,
/// Features that must all be supported.
pub required_features: Vec<DeviceFeature>,
/// Backend family to prefer. `DeviceUtils::meets_requirements` enforces it as a
/// hard filter; `DeviceUtils::calculate_device_score` ignores it.
pub preferred_backend: Option<crate::backend::BackendType>,
/// Power ceiling. NOTE(review): not consulted by `DeviceUtils` in this file;
/// presumably watts, like `PowerInfo` — confirm.
pub max_power_consumption: Option<f32>,
/// Temperature ceiling. NOTE(review): not consulted by `DeviceUtils` in this file;
/// presumably degrees Celsius, like `ThermalInfo` — confirm.
pub max_temperature: Option<f32>,
}
impl DeviceRequirements {
    /// Creates requirements with no constraints.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the minimum total device memory.
    pub fn with_min_memory(mut self, memory: usize) -> Self {
        self.min_memory = Some(memory);
        self
    }

    /// Sets the minimum number of compute units.
    pub fn with_min_compute_units(mut self, units: usize) -> Self {
        self.min_compute_units = Some(units);
        self
    }

    /// Adds a feature that must be supported (duplicates are not filtered).
    pub fn with_feature(mut self, feature: DeviceFeature) -> Self {
        self.required_features.push(feature);
        self
    }

    /// Sets the preferred backend family.
    pub fn with_preferred_backend(mut self, backend: crate::backend::BackendType) -> Self {
        self.preferred_backend = Some(backend);
        self
    }

    /// Sets `max_power_consumption`, which previously had no builder setter.
    ///
    /// NOTE(review): presumably watts, matching `PowerInfo`; confirm. The
    /// `DeviceUtils` scoring/matching code does not read this field yet.
    pub fn with_max_power_consumption(mut self, max_power: f32) -> Self {
        self.max_power_consumption = Some(max_power);
        self
    }

    /// Sets `max_temperature`, which previously had no builder setter.
    ///
    /// NOTE(review): presumably degrees Celsius, matching `ThermalInfo`; confirm.
    /// The `DeviceUtils` scoring/matching code does not read this field yet.
    pub fn with_max_temperature(mut self, max_temperature: f32) -> Self {
        self.max_temperature = Some(max_temperature);
        self
    }
}
// `Eq` is asserted manually because `DeviceInfo` contains `f32` fields
// (`memory_bandwidth_gbps`, `peak_gflops`), so it cannot be derived.
// NOTE(review): a NaN in either field makes a device unequal to itself, which
// breaks `Eq`'s reflexivity requirement; backends should never store NaN there.
impl Eq for Device {}
impl std::hash::Hash for Device {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.id.hash(state);
self.device_type.hash(state);
self.name.hash(state);
self.info.vendor.hash(state);
self.info.driver_version.hash(state);
self.info.total_memory.hash(state);
self.info.available_memory.hash(state);
self.info.compute_units.hash(state);
self.info.max_work_group_size.hash(state);
self.info.max_work_group_dimensions.hash(state);
self.info.clock_frequency_mhz.hash(state);
self.info.features.hash(state);
self.info.properties.hash(state);
}
}
#[cfg(test)]
mod tests {
use super::*;
/// GPU-like fixture: 8 GiB total / 6 GiB available memory, 32 compute units,
/// three features, and two free-form properties.
// NOTE(review): `8 * 1024 * 1024 * 1024` exceeds `u32::MAX`, so this fixture
// only compiles where `usize` is 64-bit.
fn create_test_device_info() -> DeviceInfo {
DeviceInfo {
vendor: "Test Vendor".to_string(),
driver_version: "1.0.0".to_string(),
total_memory: 8 * 1024 * 1024 * 1024, available_memory: 6 * 1024 * 1024 * 1024, compute_units: 32,
max_work_group_size: 1024,
max_work_group_dimensions: vec![1024, 1024, 64],
clock_frequency_mhz: 1500,
memory_bandwidth_gbps: 500.0,
peak_gflops: 10000.0,
features: vec![
DeviceFeature::DoublePrecision,
DeviceFeature::UnifiedMemory,
DeviceFeature::AtomicOperations,
],
properties: vec![
("compute_capability".to_string(), "7.5".to_string()),
("warp_size".to_string(), "32".to_string()),
],
}
}
// `Device::new` stores its arguments verbatim and the getters expose them.
#[test]
fn test_device_creation() {
let info = create_test_device_info();
let device = Device::new(
0,
CoreDeviceType::Cuda(0),
"Test GPU".to_string(),
info.clone(),
);
assert_eq!(device.id(), 0);
assert_eq!(device.name(), "Test GPU");
assert_eq!(device.device_type(), CoreDeviceType::Cuda(0));
assert_eq!(device.info().vendor, "Test Vendor");
assert_eq!(device.info().compute_units, 32);
}
// `supports_feature` reports exactly the features listed in the fixture.
#[test]
fn test_device_feature_support() {
let info = create_test_device_info();
let device = Device::new(1, CoreDeviceType::Cpu, "Test CPU".to_string(), info);
assert!(device.supports_feature(DeviceFeature::DoublePrecision));
assert!(device.supports_feature(DeviceFeature::UnifiedMemory));
assert!(device.supports_feature(DeviceFeature::AtomicOperations));
assert!(!device.supports_feature(DeviceFeature::HalfPrecision));
assert!(!device.supports_feature(DeviceFeature::SubGroups));
}
// Pins every value produced by `DeviceInfo::default`.
#[test]
fn test_device_info_default() {
let info = DeviceInfo::default();
assert_eq!(info.vendor, "Unknown");
assert_eq!(info.driver_version, "Unknown");
assert_eq!(info.total_memory, 0);
assert_eq!(info.available_memory, 0);
assert_eq!(info.compute_units, 1);
assert_eq!(info.max_work_group_size, 256);
assert_eq!(info.max_work_group_dimensions, vec![256, 1, 1]);
assert_eq!(info.clock_frequency_mhz, 1000);
assert_eq!(info.memory_bandwidth_gbps, 10.0);
assert_eq!(info.peak_gflops, 100.0);
assert!(info.features.is_empty());
assert!(info.properties.is_empty());
}
// Both `From` directions. GPU families map back with payload 0; OpenCl, Vulkan
// and Custom have no core counterpart and collapse to `Cpu`.
#[test]
fn test_device_type_conversion() {
assert_eq!(DeviceType::from(CoreDeviceType::Cpu), DeviceType::Cpu);
assert_eq!(DeviceType::from(CoreDeviceType::Cuda(0)), DeviceType::Cuda);
assert_eq!(
DeviceType::from(CoreDeviceType::Metal(0)),
DeviceType::Metal
);
assert_eq!(
DeviceType::from(CoreDeviceType::Wgpu(0)),
DeviceType::WebGpu
);
assert_eq!(CoreDeviceType::from(DeviceType::Cpu), CoreDeviceType::Cpu);
assert_eq!(
CoreDeviceType::from(DeviceType::Cuda),
CoreDeviceType::Cuda(0)
);
assert_eq!(
CoreDeviceType::from(DeviceType::Metal),
CoreDeviceType::Metal(0)
);
assert_eq!(
CoreDeviceType::from(DeviceType::WebGpu),
CoreDeviceType::Wgpu(0)
);
assert_eq!(
CoreDeviceType::from(DeviceType::OpenCl),
CoreDeviceType::Cpu
);
assert_eq!(
CoreDeviceType::from(DeviceType::Vulkan),
CoreDeviceType::Cpu
);
assert_eq!(
CoreDeviceType::from(DeviceType::Custom),
CoreDeviceType::Cpu
);
}
// Pairwise-distinctness over a subset of the variants (derived `PartialEq`).
#[test]
fn test_device_feature_variants() {
let features = [
DeviceFeature::DoublePrecision,
DeviceFeature::HalfPrecision,
DeviceFeature::UnifiedMemory,
DeviceFeature::AtomicOperations,
DeviceFeature::SubGroups,
DeviceFeature::Printf,
DeviceFeature::Profiling,
DeviceFeature::PeerToPeer,
DeviceFeature::ConcurrentExecution,
DeviceFeature::AsyncMemory,
DeviceFeature::ImageSupport,
DeviceFeature::FastMath,
DeviceFeature::Custom("CustomFeature".to_string()),
];
for (i, feature1) in features.iter().enumerate() {
for (j, feature2) in features.iter().enumerate() {
if i != j {
assert_ne!(feature1, feature2);
}
}
}
}
// A fresh selector has no criteria set.
#[test]
fn test_device_selector_creation() {
let selector = DeviceSelector::new();
assert_eq!(selector.device_type, None);
assert_eq!(selector.min_memory, None);
assert_eq!(selector.min_compute_units, None);
assert!(selector.required_features.is_empty());
assert_eq!(selector.preferred_vendor, None);
assert!(selector.custom_filter.is_none());
}
// Builder methods store each criterion; features accumulate.
#[test]
fn test_device_selector_builder() {
let selector = DeviceSelector::new()
.with_device_type(DeviceType::Cuda)
.with_min_memory(4 * 1024 * 1024 * 1024) .with_min_compute_units(16)
.with_feature(DeviceFeature::DoublePrecision)
.with_feature(DeviceFeature::AtomicOperations)
.with_vendor("NVIDIA".to_string());
assert_eq!(selector.device_type, Some(DeviceType::Cuda));
assert_eq!(selector.min_memory, Some(4 * 1024 * 1024 * 1024));
assert_eq!(selector.min_compute_units, Some(16));
assert_eq!(selector.required_features.len(), 2);
assert!(selector
.required_features
.contains(&DeviceFeature::DoublePrecision));
assert!(selector
.required_features
.contains(&DeviceFeature::AtomicOperations));
assert_eq!(selector.preferred_vendor, Some("NVIDIA".to_string()));
}
// One selector satisfied by every criterion, then three that each violate a
// single criterion: memory, feature, and vendor.
#[test]
fn test_device_selector_matching() {
let mut info = create_test_device_info();
info.vendor = "NVIDIA".to_string();
info.total_memory = 8 * 1024 * 1024 * 1024; info.compute_units = 32;
let device = Device::new(0, CoreDeviceType::Cuda(0), "RTX 4090".to_string(), info);
let selector1 = DeviceSelector::new()
.with_device_type(DeviceType::Cuda)
.with_min_memory(4 * 1024 * 1024 * 1024) .with_min_compute_units(16)
.with_feature(DeviceFeature::DoublePrecision)
.with_vendor("NVIDIA".to_string());
assert!(selector1.matches(&device));
let selector2 = DeviceSelector::new().with_min_memory(16 * 1024 * 1024 * 1024);
assert!(!selector2.matches(&device));
let selector3 = DeviceSelector::new().with_feature(DeviceFeature::HalfPrecision);
assert!(!selector3.matches(&device));
let selector4 = DeviceSelector::new().with_vendor("AMD".to_string());
assert!(!selector4.matches(&device));
}
// `Custom` features compare by name, and `supports_feature` honours that.
#[test]
fn test_custom_device_feature() {
let custom_feature1 = DeviceFeature::Custom("TensorCores".to_string());
let custom_feature2 = DeviceFeature::Custom("TensorCores".to_string());
let custom_feature3 = DeviceFeature::Custom("RTCores".to_string());
assert_eq!(custom_feature1, custom_feature2);
assert_ne!(custom_feature1, custom_feature3);
let mut info = DeviceInfo::default();
info.features.push(custom_feature1.clone());
let device = Device::new(0, CoreDeviceType::Cuda(0), "Custom GPU".to_string(), info);
assert!(device.supports_feature(custom_feature1));
assert!(!device.supports_feature(custom_feature3));
}
}