use crate::device::DeviceType;
use crate::error::Result;
use std::collections::HashMap;
#[derive(Debug, Clone)]
pub struct DeviceCapabilities {
device_type: DeviceType,
total_memory: u64,
available_memory: u64,
memory_bandwidth: Option<u64>,
compute_units: u32,
clock_rate: Option<u32>,
simd_features: SimdFeatures,
hardware_features: HashMap<String, bool>,
driver_version: Option<String>,
device_name: String,
pci_info: Option<PciInfo>,
thermal_info: Option<ThermalInfo>,
}
impl DeviceCapabilities {
pub fn detect(device_type: DeviceType) -> Result<Self> {
match device_type {
DeviceType::Cpu => Self::detect_cpu_capabilities(),
DeviceType::Cuda(index) => Self::detect_cuda_capabilities(index),
DeviceType::Metal(index) => Self::detect_metal_capabilities(index),
DeviceType::Wgpu(index) => Self::detect_wgpu_capabilities(index),
}
}
pub fn device_type(&self) -> DeviceType {
self.device_type
}
pub fn total_memory(&self) -> u64 {
self.total_memory
}
pub fn total_memory_mb(&self) -> u64 {
self.total_memory / (1024 * 1024)
}
pub fn available_memory(&self) -> u64 {
self.available_memory
}
pub fn available_memory_mb(&self) -> u64 {
self.available_memory / (1024 * 1024)
}
pub fn memory_bandwidth(&self) -> Option<u64> {
self.memory_bandwidth
}
pub fn compute_units(&self) -> u32 {
self.compute_units
}
pub fn clock_rate(&self) -> Option<u32> {
self.clock_rate
}
pub fn simd_features(&self) -> &SimdFeatures {
&self.simd_features
}
pub fn hardware_features(&self) -> &HashMap<String, bool> {
&self.hardware_features
}
pub fn driver_version(&self) -> Option<&str> {
self.driver_version.as_deref()
}
pub fn device_name(&self) -> &str {
&self.device_name
}
pub fn pci_info(&self) -> Option<&PciInfo> {
self.pci_info.as_ref()
}
pub fn thermal_info(&self) -> Option<&ThermalInfo> {
self.thermal_info.as_ref()
}
pub fn supports_feature(&self, feature: &str) -> bool {
self.hardware_features
.get(feature)
.copied()
.unwrap_or(false)
}
pub fn supports_double_precision(&self) -> bool {
match self.device_type {
DeviceType::Cpu => true,
DeviceType::Cuda(_) => self.supports_feature("double_precision"),
DeviceType::Metal(_) => false, DeviceType::Wgpu(_) => self.supports_feature("double_precision"),
}
}
pub fn supports_half_precision(&self) -> bool {
match self.device_type {
DeviceType::Cpu => self.simd_features.supports_f16(),
DeviceType::Cuda(_) => self.supports_feature("half_precision"),
DeviceType::Metal(_) => true, DeviceType::Wgpu(_) => self.supports_feature("half_precision"),
}
}
pub fn supports_unified_memory(&self) -> bool {
match self.device_type {
DeviceType::Cpu => true, DeviceType::Cuda(_) => self.supports_feature("unified_memory"),
DeviceType::Metal(_) => true, DeviceType::Wgpu(_) => false, }
}
pub fn memory_utilization(&self) -> f64 {
if self.total_memory == 0 {
return 0.0;
}
1.0 - (self.available_memory as f64 / self.total_memory as f64)
}
pub fn peak_bandwidth_gbps(&self) -> Option<f64> {
self.memory_bandwidth
.map(|bw| bw as f64 / (1024.0 * 1024.0 * 1024.0))
}
pub fn compute_score(&self) -> u64 {
let base_score = self.compute_units as u64 * self.clock_rate.unwrap_or(1000) as u64;
match self.device_type {
DeviceType::Cpu => base_score,
DeviceType::Cuda(_) => base_score * 10, DeviceType::Metal(_) => base_score * 8, DeviceType::Wgpu(_) => base_score * 6, }
}
fn detect_cpu_capabilities() -> Result<Self> {
Ok(DeviceCapabilities {
device_type: DeviceType::Cpu,
total_memory: Self::get_system_memory(),
available_memory: Self::get_available_memory(),
memory_bandwidth: Self::estimate_cpu_bandwidth(),
compute_units: Self::get_cpu_cores(),
clock_rate: Self::get_cpu_frequency(),
simd_features: SimdFeatures::detect_cpu(),
hardware_features: Self::detect_cpu_features(),
driver_version: None,
device_name: Self::get_cpu_name(),
pci_info: None,
thermal_info: Self::get_cpu_thermal_info(),
})
}
fn detect_cuda_capabilities(index: usize) -> Result<Self> {
#[cfg(feature = "cuda")]
{
Ok(DeviceCapabilities {
device_type: DeviceType::Cuda(index),
total_memory: 8 * 1024 * 1024 * 1024, available_memory: 7 * 1024 * 1024 * 1024, memory_bandwidth: Some(900 * 1024 * 1024 * 1024), compute_units: 108, clock_rate: Some(1755), simd_features: SimdFeatures::cuda_default(),
hardware_features: Self::detect_cuda_features(index),
driver_version: Some("12.0".to_string()),
device_name: format!("CUDA Device {}", index),
pci_info: Some(PciInfo::mock_cuda()),
thermal_info: Some(ThermalInfo::mock_gpu()),
})
}
#[cfg(not(feature = "cuda"))]
{
Err(crate::error::TorshError::General(
crate::error::GeneralError::DeviceError(format!(
"CUDA device {} not available (CUDA support not compiled)",
index
)),
))
}
}
fn detect_metal_capabilities(index: usize) -> Result<Self> {
#[cfg(target_os = "macos")]
{
Ok(DeviceCapabilities {
device_type: DeviceType::Metal(index),
total_memory: Self::get_system_memory(), available_memory: Self::get_available_memory(),
memory_bandwidth: Some(400 * 1024 * 1024 * 1024), compute_units: 32, clock_rate: Some(1398), simd_features: SimdFeatures::metal_default(),
hardware_features: Self::detect_metal_features(),
driver_version: Some("Metal 3.0".to_string()),
device_name: format!("Metal Device {}", index),
pci_info: None, thermal_info: Some(ThermalInfo::mock_integrated()),
})
}
#[cfg(not(target_os = "macos"))]
{
Err(crate::error::TorshError::General(
crate::error::GeneralError::DeviceError(format!(
"Metal device {} not available (not running on macOS)",
index
)),
))
}
}
fn detect_wgpu_capabilities(index: usize) -> Result<Self> {
#[cfg(feature = "wgpu")]
{
Ok(DeviceCapabilities {
device_type: DeviceType::Wgpu(index),
total_memory: 4 * 1024 * 1024 * 1024, available_memory: 3 * 1024 * 1024 * 1024, memory_bandwidth: Some(300 * 1024 * 1024 * 1024), compute_units: 24, clock_rate: Some(1200), simd_features: SimdFeatures::wgpu_default(),
hardware_features: Self::detect_wgpu_features(),
driver_version: Some("WebGPU 1.0".to_string()),
device_name: format!("WebGPU Device {}", index),
pci_info: None,
thermal_info: None,
})
}
#[cfg(not(feature = "wgpu"))]
{
Err(crate::error::TorshError::General(
crate::error::GeneralError::DeviceError(format!(
"WebGPU device {} not available (WebGPU support not compiled)",
index
)),
))
}
}
fn get_system_memory() -> u64 {
#[cfg(target_os = "linux")]
{
std::fs::read_to_string("/proc/meminfo")
.ok()
.and_then(|content| {
content
.lines()
.find(|line| line.starts_with("MemTotal:"))
.and_then(|line| {
line.split_whitespace()
.nth(1)
.and_then(|s| s.parse::<u64>().ok())
})
})
.map(|kb| kb * 1024)
.unwrap_or(8 * 1024 * 1024 * 1024) }
#[cfg(target_os = "macos")]
{
16 * 1024 * 1024 * 1024 }
#[cfg(target_os = "windows")]
{
16 * 1024 * 1024 * 1024 }
#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
{
8 * 1024 * 1024 * 1024 }
}
fn get_available_memory() -> u64 {
Self::get_system_memory() * 80 / 100 }
fn estimate_cpu_bandwidth() -> Option<u64> {
Some(50 * 1024 * 1024 * 1024) }
fn get_cpu_cores() -> u32 {
std::thread::available_parallelism()
.map(|p| p.get() as u32)
.unwrap_or(4)
}
fn get_cpu_frequency() -> Option<u32> {
Some(3000) }
fn get_cpu_name() -> String {
#[cfg(target_os = "linux")]
{
std::fs::read_to_string("/proc/cpuinfo")
.ok()
.and_then(|content| {
content
.lines()
.find(|line| line.starts_with("model name"))
.and_then(|line| line.split(':').nth(1))
.map(|name| name.trim().to_string())
})
.unwrap_or_else(|| "Unknown CPU".to_string())
}
#[cfg(not(target_os = "linux"))]
{
"Unknown CPU".to_string()
}
}
fn get_cpu_thermal_info() -> Option<ThermalInfo> {
Some(ThermalInfo {
current_temp: 45.0,
max_temp: 85.0,
thermal_throttling: false,
})
}
fn detect_cpu_features() -> HashMap<String, bool> {
let mut features = HashMap::new();
features.insert("sse".to_string(), true);
features.insert("sse2".to_string(), true);
features.insert("sse3".to_string(), true);
features.insert("ssse3".to_string(), true);
features.insert("sse4_1".to_string(), true);
features.insert("sse4_2".to_string(), true);
features.insert("avx".to_string(), cfg!(target_feature = "avx"));
features.insert("avx2".to_string(), cfg!(target_feature = "avx2"));
features.insert("avx512f".to_string(), cfg!(target_feature = "avx512f"));
features.insert("fma".to_string(), cfg!(target_feature = "fma"));
features.insert("bmi1".to_string(), cfg!(target_feature = "bmi1"));
features.insert("bmi2".to_string(), cfg!(target_feature = "bmi2"));
features.insert("f32".to_string(), true);
features.insert("f64".to_string(), true);
features.insert("i8".to_string(), true);
features.insert("i16".to_string(), true);
features.insert("i32".to_string(), true);
features.insert("i64".to_string(), true);
features.insert("u8".to_string(), true);
features.insert("u32".to_string(), true);
features.insert("u64".to_string(), true);
features.insert("bool".to_string(), true);
#[cfg(feature = "half")]
{
features.insert("f16".to_string(), true);
features.insert("bf16".to_string(), true);
}
features.insert("c64".to_string(), true);
features.insert("c128".to_string(), true);
features.insert("qint8".to_string(), true);
features.insert("quint8".to_string(), true);
features
}
#[allow(dead_code)]
fn detect_cuda_features(_index: usize) -> HashMap<String, bool> {
let mut features = HashMap::new();
#[cfg(all(feature = "gpu", scirs2_gpu_available))]
{
use crate::gpu;
if let Ok(device) = gpu::GpuDevice::new(index) {
features.insert("double_precision".to_string(), device.supports_f64());
features.insert("half_precision".to_string(), device.supports_f16());
features.insert("tensor_cores".to_string(), device.has_tensor_cores());
features.insert(
"unified_memory".to_string(),
device.supports_unified_memory(),
);
features.insert("peer_to_peer".to_string(), device.supports_p2p());
features.insert(
"concurrent_kernels".to_string(),
device.supports_concurrent_kernels(),
);
features.insert("async_copy".to_string(), device.supports_async_copy());
return features;
}
}
features.insert("double_precision".to_string(), true);
features.insert("half_precision".to_string(), true);
features.insert("tensor_cores".to_string(), true); features.insert("unified_memory".to_string(), true);
features.insert("peer_to_peer".to_string(), true);
features.insert("concurrent_kernels".to_string(), true);
features.insert("async_copy".to_string(), true);
features.insert("dynamic_parallelism".to_string(), true);
features.insert("cooperative_groups".to_string(), true);
features.insert("bf16".to_string(), true); features.insert("tf32".to_string(), true); features.insert("sparse_tensor_cores".to_string(), false); features.insert("mma_operations".to_string(), true);
features
}
#[cfg(target_os = "macos")]
fn detect_metal_features() -> HashMap<String, bool> {
let mut features = HashMap::new();
#[cfg(all(feature = "gpu", scirs2_gpu_available, target_os = "macos"))]
{
use crate::gpu;
if let Ok(device) = gpu::GpuDevice::new(0) {
features.insert("half_precision".to_string(), device.supports_f16());
features.insert("unified_memory".to_string(), true); features.insert("tile_shaders".to_string(), device.supports_tile_shaders());
features.insert("compute_shaders".to_string(), true); features.insert(
"indirect_command_buffers".to_string(),
device.supports_indirect_command_buffers(),
);
return features;
}
}
features.insert("half_precision".to_string(), true);
features.insert("unified_memory".to_string(), true);
features.insert("tile_shaders".to_string(), true);
features.insert("compute_shaders".to_string(), true);
features.insert("indirect_command_buffers".to_string(), true);
features.insert("argument_buffers".to_string(), true);
features.insert("raster_order_groups".to_string(), true);
features.insert("imageblocks".to_string(), true);
features.insert("threadgroup_sharing".to_string(), true);
#[cfg(target_os = "macos")]
{
features.insert("mesh_shaders".to_string(), true);
features.insert("ray_tracing".to_string(), true);
features.insert("function_pointers".to_string(), true);
}
features
}
#[allow(dead_code)]
fn detect_wgpu_features() -> HashMap<String, bool> {
let mut features = HashMap::new();
#[cfg(all(feature = "gpu", scirs2_gpu_available, feature = "wgpu"))]
{
use crate::gpu;
if let Ok(device) = gpu::GpuDevice::new(0) {
features.insert(
"compute_shaders".to_string(),
device.supports_compute_shaders(),
);
features.insert(
"storage_buffers".to_string(),
device.supports_storage_buffers(),
);
features.insert(
"push_constants".to_string(),
device.supports_push_constants(),
);
features.insert("half_precision".to_string(), device.supports_f16());
features.insert("subgroups".to_string(), device.supports_subgroups());
return features;
}
}
features.insert("compute_shaders".to_string(), true);
features.insert("storage_buffers".to_string(), true);
features.insert("push_constants".to_string(), false); features.insert("half_precision".to_string(), false); features.insert("timestamp_queries".to_string(), true);
features.insert("indirect_dispatch".to_string(), true);
features.insert("shader_f16".to_string(), false);
features.insert("subgroups".to_string(), false); features.insert("bgra8unorm_storage".to_string(), false);
features.insert("depth32float_stencil8".to_string(), true);
features.insert("texture_compression_bc".to_string(), false); features.insert("texture_compression_etc2".to_string(), false);
features.insert("texture_compression_astc".to_string(), false);
features
}
pub fn query_gpu_memory(_device_index: usize) -> Option<GpuMemoryInfo> {
#[cfg(all(feature = "gpu", scirs2_gpu_available))]
{
use crate::gpu;
if let Ok(device) = gpu::GpuDevice::new(device_index) {
return Some(GpuMemoryInfo {
total_memory: device.total_memory(),
free_memory: device.free_memory(),
used_memory: device.used_memory(),
supports_unified_memory: device.supports_unified_memory(),
memory_clock_rate: device.memory_clock_rate(),
memory_bus_width: device.memory_bus_width(),
});
}
}
None
}
pub fn query_compute_capability(_device_index: usize) -> Option<ComputeCapability> {
#[cfg(all(feature = "gpu", scirs2_gpu_available))]
{
use crate::gpu;
if let Ok(device) = gpu::GpuDevice::new(device_index) {
return Some(ComputeCapability {
major: device.compute_capability_major(),
minor: device.compute_capability_minor(),
max_threads_per_block: device.max_threads_per_block(),
max_block_dimensions: device.max_block_dimensions(),
max_grid_dimensions: device.max_grid_dimensions(),
warp_size: device.warp_size(),
max_shared_memory_per_block: device.max_shared_memory_per_block(),
});
}
}
None
}
}
#[derive(Debug, Clone)]
pub struct GpuMemoryInfo {
pub total_memory: usize,
pub free_memory: usize,
pub used_memory: usize,
pub supports_unified_memory: bool,
pub memory_clock_rate: Option<u32>,
pub memory_bus_width: Option<u32>,
}
#[derive(Debug, Clone)]
pub struct ComputeCapability {
pub major: u32,
pub minor: u32,
pub max_threads_per_block: u32,
pub max_block_dimensions: [u32; 3],
pub max_grid_dimensions: [u32; 3],
pub warp_size: u32,
pub max_shared_memory_per_block: usize,
}
#[derive(Debug, Clone, Default)]
pub struct SimdFeatures {
pub sse: bool,
pub sse2: bool,
pub sse3: bool,
pub ssse3: bool,
pub sse4_1: bool,
pub sse4_2: bool,
pub avx: bool,
pub avx2: bool,
pub avx512f: bool,
pub fma: bool,
pub neon: bool, pub sve: bool, pub f16: bool, }
impl SimdFeatures {
pub fn detect_cpu() -> Self {
Self {
sse: true, sse2: true,
sse3: true,
ssse3: true,
sse4_1: true,
sse4_2: true,
avx: cfg!(target_feature = "avx"),
avx2: cfg!(target_feature = "avx2"),
avx512f: cfg!(target_feature = "avx512f"),
fma: cfg!(target_feature = "fma"),
neon: cfg!(target_feature = "neon"),
sve: false, f16: cfg!(target_feature = "f16c"),
}
}
pub fn cuda_default() -> Self {
Self {
sse: false,
sse2: false,
sse3: false,
ssse3: false,
sse4_1: false,
sse4_2: false,
avx: false,
avx2: false,
avx512f: false,
fma: true, neon: false,
sve: false,
f16: true, }
}
pub fn metal_default() -> Self {
Self {
sse: false,
sse2: false,
sse3: false,
ssse3: false,
sse4_1: false,
sse4_2: false,
avx: false,
avx2: false,
avx512f: false,
fma: true,
neon: cfg!(target_arch = "aarch64"),
sve: false,
f16: true, }
}
pub fn wgpu_default() -> Self {
Self {
sse: false,
sse2: false,
sse3: false,
ssse3: false,
sse4_1: false,
sse4_2: false,
avx: false,
avx2: false,
avx512f: false,
fma: false,
neon: false,
sve: false,
f16: false,
}
}
pub fn has_simd(&self) -> bool {
self.sse || self.avx || self.neon
}
pub fn has_advanced_simd(&self) -> bool {
self.avx2 || self.avx512f || self.sve
}
pub fn supports_f16(&self) -> bool {
self.f16
}
pub fn max_vector_width(&self) -> u32 {
if self.avx512f {
512
} else if self.avx2 || self.avx {
256
} else if self.sse || self.neon {
128
} else {
64 }
}
pub fn optimal_chunk_size<T>(&self) -> usize {
let element_size = std::mem::size_of::<T>();
let vector_bytes = self.max_vector_width() as usize / 8;
std::cmp::max(1, vector_bytes / element_size)
}
}
#[derive(Debug, Clone)]
pub struct PciInfo {
pub vendor_id: u16,
pub device_id: u16,
pub subsystem_vendor_id: u16,
pub subsystem_device_id: u16,
pub bus: u8,
pub device: u8,
pub function: u8,
}
impl PciInfo {
pub fn mock_cuda() -> Self {
Self {
vendor_id: 0x10de, device_id: 0x2684, subsystem_vendor_id: 0x10de,
subsystem_device_id: 0x1658,
bus: 1,
device: 0,
function: 0,
}
}
}
#[derive(Debug, Clone)]
pub struct ThermalInfo {
pub current_temp: f32,
pub max_temp: f32,
pub thermal_throttling: bool,
}
impl ThermalInfo {
pub fn mock_gpu() -> Self {
Self {
current_temp: 65.0,
max_temp: 83.0,
thermal_throttling: false,
}
}
pub fn mock_integrated() -> Self {
Self {
current_temp: 45.0,
max_temp: 100.0,
thermal_throttling: false,
}
}
pub fn is_temperature_safe(&self) -> bool {
self.current_temp < self.max_temp * 0.9 }
pub fn temperature_ratio(&self) -> f32 {
self.current_temp / self.max_temp
}
}
pub mod utils {
use super::*;
pub fn compare_capabilities(
a: &DeviceCapabilities,
b: &DeviceCapabilities,
) -> std::cmp::Ordering {
a.compute_score().cmp(&b.compute_score())
}
pub fn find_best_device(capabilities: &[DeviceCapabilities]) -> Option<&DeviceCapabilities> {
capabilities
.iter()
.max_by(|a, b| compare_capabilities(a, b))
}
pub fn filter_by_memory(
capabilities: &[DeviceCapabilities],
min_memory_mb: u64,
) -> Vec<&DeviceCapabilities> {
capabilities
.iter()
.filter(|cap| cap.total_memory_mb() >= min_memory_mb)
.collect()
}
pub fn filter_by_simd(
capabilities: &[DeviceCapabilities],
require_advanced: bool,
) -> Vec<&DeviceCapabilities> {
capabilities
.iter()
.filter(|cap| {
if require_advanced {
cap.simd_features().has_advanced_simd()
} else {
cap.simd_features().has_simd()
}
})
.collect()
}
pub fn capabilities_summary(cap: &DeviceCapabilities) -> String {
format!(
"{} - {} MB, {} cores, SIMD: {}",
cap.device_name(),
cap.total_memory_mb(),
cap.compute_units(),
if cap.simd_features().has_advanced_simd() {
"Advanced"
} else if cap.simd_features().has_simd() {
"Basic"
} else {
"None"
}
)
}
pub fn is_suitable_for_training(cap: &DeviceCapabilities) -> bool {
cap.total_memory_mb() >= 4096 && cap.compute_units() >= 16 && match cap.device_type() {
DeviceType::Cpu => cap.simd_features().has_simd(),
DeviceType::Cuda(_) => true, DeviceType::Metal(_) => true,
DeviceType::Wgpu(_) => cap.total_memory_mb() >= 8192, }
}
pub fn estimate_training_performance(cap: &DeviceCapabilities) -> f64 {
let memory_score = (cap.total_memory_mb() as f64).log2() / 10.0; let compute_score = cap.compute_score() as f64 / 1_000_000.0; let bandwidth_score = cap.peak_bandwidth_gbps().unwrap_or(1.0) / 100.0;
memory_score + compute_score + bandwidth_score
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cpu_capability_detection() {
let cap = DeviceCapabilities::detect(DeviceType::Cpu).expect("detect should succeed");
assert_eq!(cap.device_type(), DeviceType::Cpu);
assert!(cap.total_memory() > 0);
assert!(cap.compute_units() > 0);
}
#[test]
fn test_simd_features() {
let features = SimdFeatures::detect_cpu();
assert!(features.has_simd() || !features.has_simd()); assert!(features.max_vector_width() >= 64);
assert!(features.optimal_chunk_size::<f32>() > 0);
}
#[test]
fn test_memory_calculations() {
let cap = DeviceCapabilities::detect(DeviceType::Cpu).expect("detect should succeed");
assert_eq!(cap.total_memory_mb(), cap.total_memory() / (1024 * 1024));
assert!(cap.memory_utilization() >= 0.0 && cap.memory_utilization() <= 1.0);
}
#[test]
fn test_capability_comparison() {
let cap1 = DeviceCapabilities::detect(DeviceType::Cpu).expect("detect should succeed");
let cap2 = DeviceCapabilities::detect(DeviceType::Cpu).expect("detect should succeed");
let _ordering = utils::compare_capabilities(&cap1, &cap2);
let devices = [cap1, cap2];
let best = utils::find_best_device(&devices);
assert!(best.is_some());
}
#[test]
fn test_device_filtering() {
let cap = DeviceCapabilities::detect(DeviceType::Cpu).expect("detect should succeed");
let caps = vec![cap];
let filtered = utils::filter_by_memory(&caps, 1024); assert!(!filtered.is_empty());
let simd_filtered = utils::filter_by_simd(&caps, false);
assert!(!simd_filtered.is_empty());
}
#[test]
fn test_training_suitability() {
let cap = DeviceCapabilities::detect(DeviceType::Cpu).expect("detect should succeed");
let _suitable = utils::is_suitable_for_training(&cap);
let _performance = utils::estimate_training_performance(&cap);
}
#[test]
fn test_thermal_info() {
let thermal = ThermalInfo::mock_gpu();
assert!(thermal.is_temperature_safe());
assert!(thermal.temperature_ratio() >= 0.0 && thermal.temperature_ratio() <= 1.0);
}
#[test]
fn test_precision_support() {
let cap = DeviceCapabilities::detect(DeviceType::Cpu).expect("detect should succeed");
assert!(cap.supports_double_precision()); }
#[cfg(feature = "cuda")]
#[test]
fn test_cuda_capabilities() {
if let Ok(cap) = DeviceCapabilities::detect(DeviceType::Cuda(0)) {
assert_eq!(cap.device_type(), DeviceType::Cuda(0));
assert!(cap.supports_half_precision());
assert!(cap.supports_feature("tensor_cores"));
}
}
#[cfg(target_os = "macos")]
#[test]
fn test_metal_capabilities() {
if let Ok(cap) = DeviceCapabilities::detect(DeviceType::Metal(0)) {
assert_eq!(cap.device_type(), DeviceType::Metal(0));
assert!(cap.supports_unified_memory());
assert!(cap.supports_half_precision());
}
}
}