use crate::device::implementations::DeviceFactory;
use crate::device::{Device, DeviceCapabilities, DeviceType};
use crate::error::Result;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock};
/// Thread-safe device discovery and selection service.
///
/// Scans the compiled-in backends for usable devices, caches live device
/// handles keyed by [`DeviceType`], and keeps a bounded history of past
/// selections that feeds a small "familiarity" bonus into future scoring.
#[derive(Debug)]
pub struct DeviceDiscovery {
// Results of the most recent scan; replaced wholesale by `scan_devices`.
discovered_devices: RwLock<Vec<DiscoveredDevice>>,
// Live device handles for available discovered devices, one per type.
device_cache: RwLock<HashMap<DeviceType, Arc<dyn Device>>>,
// Bounded log of past selections (capped at 1000 in `record_selection`).
selection_history: RwLock<Vec<SelectionRecord>>,
// Controls which backends are scanned; set at construction time.
config: DiscoveryConfig,
}
impl DeviceDiscovery {
pub fn new() -> Self {
Self {
discovered_devices: RwLock::new(Vec::new()),
device_cache: RwLock::new(HashMap::new()),
selection_history: RwLock::new(Vec::new()),
config: DiscoveryConfig::default(),
}
}
pub fn with_config(config: DiscoveryConfig) -> Self {
Self {
discovered_devices: RwLock::new(Vec::new()),
device_cache: RwLock::new(HashMap::new()),
selection_history: RwLock::new(Vec::new()),
config,
}
}
pub fn scan_devices(&self) -> Result<usize> {
let mut discovered = Vec::new();
if self.config.scan_cpu {
discovered.extend(self.scan_cpu_devices()?);
}
if self.config.scan_cuda {
discovered.extend(self.scan_cuda_devices()?);
}
if self.config.scan_metal {
discovered.extend(self.scan_metal_devices()?);
}
if self.config.scan_wgpu {
discovered.extend(self.scan_wgpu_devices()?);
}
let count = discovered.len();
{
let mut devices = self
.discovered_devices
.write()
.expect("lock should not be poisoned");
*devices = discovered;
}
self.populate_device_cache()?;
Ok(count)
}
pub fn get_discovered_devices(&self) -> Vec<DiscoveredDevice> {
let devices = self
.discovered_devices
.read()
.expect("lock should not be poisoned");
devices.clone()
}
pub fn select_optimal_device(
&self,
workload: &WorkloadProfile,
) -> Result<Option<Arc<dyn Device>>> {
let devices = self
.discovered_devices
.read()
.expect("lock should not be poisoned");
let cache = self
.device_cache
.read()
.expect("lock should not be poisoned");
if devices.is_empty() {
return Ok(None);
}
let mut best_device = None;
let mut best_score = 0.0;
for discovered in devices.iter() {
if !discovered.is_available {
continue;
}
if !self.is_workload_compatible(discovered, workload)? {
continue;
}
let score = self.calculate_fitness_score(discovered, workload)?;
if score > best_score {
best_score = score;
best_device = cache.get(&discovered.device_type).cloned();
}
}
if let Some(ref device) = best_device {
self.record_selection(device.device_type(), workload.clone(), best_score);
}
Ok(best_device)
}
pub fn select_devices_for_distributed_workload(
&self,
workload: &WorkloadProfile,
target_count: usize,
) -> Result<Vec<Arc<dyn Device>>> {
let devices = self
.discovered_devices
.read()
.expect("lock should not be poisoned");
let cache = self
.device_cache
.read()
.expect("lock should not be poisoned");
let mut candidates: Vec<_> = devices
.iter()
.filter(|d| d.is_available && self.is_workload_compatible(d, workload).unwrap_or(false))
.collect();
candidates.sort_by(|a, b| {
let score_a = self.calculate_fitness_score(a, workload).unwrap_or(0.0);
let score_b = self.calculate_fitness_score(b, workload).unwrap_or(0.0);
score_b
.partial_cmp(&score_a)
.unwrap_or(std::cmp::Ordering::Equal)
});
let selected: Vec<_> = candidates
.into_iter()
.take(target_count)
.filter_map(|d| cache.get(&d.device_type).cloned())
.collect();
Ok(selected)
}
pub fn get_devices_by_capabilities(
&self,
requirements: &CapabilityRequirements,
) -> Result<Vec<Arc<dyn Device>>> {
let devices = self
.discovered_devices
.read()
.expect("lock should not be poisoned");
let cache = self
.device_cache
.read()
.expect("lock should not be poisoned");
let mut matching_devices = Vec::new();
for discovered in devices.iter() {
if !discovered.is_available {
continue;
}
if self.meets_capability_requirements(discovered, requirements)? {
if let Some(device) = cache.get(&discovered.device_type) {
matching_devices.push(device.clone());
}
}
}
Ok(matching_devices)
}
pub fn recommend_device(&self, use_case: UseCase) -> Result<DeviceRecommendation> {
let workload = match use_case {
UseCase::Training => WorkloadProfile::training_large(),
UseCase::Inference => WorkloadProfile::inference(),
UseCase::Development => WorkloadProfile::development(),
UseCase::Benchmarking => WorkloadProfile::benchmarking(),
UseCase::Research => WorkloadProfile::research(),
};
let devices = self
.discovered_devices
.read()
.expect("lock should not be poisoned");
let cache = self
.device_cache
.read()
.expect("lock should not be poisoned");
let mut recommendations = Vec::new();
for discovered in devices.iter() {
if !discovered.is_available {
continue;
}
let score = self.calculate_fitness_score(discovered, &workload)?;
let reasoning = self.generate_recommendation_reasoning(discovered, &workload, score);
if let Some(device) = cache.get(&discovered.device_type) {
recommendations.push(DeviceOption {
device: device.clone(),
score,
reasoning,
estimated_performance: self.estimate_performance(discovered, &workload)?,
});
}
}
recommendations.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
Ok(DeviceRecommendation {
use_case,
options: recommendations,
workload_profile: workload,
})
}
pub fn get_statistics(&self) -> DiscoveryStatistics {
let devices = self
.discovered_devices
.read()
.expect("lock should not be poisoned");
let history = self
.selection_history
.read()
.expect("lock should not be poisoned");
let total_devices = devices.len();
let available_devices = devices.iter().filter(|d| d.is_available).count();
let device_types: HashSet<_> = devices.iter().map(|d| d.device_type).collect();
let unique_device_types = device_types.len();
let total_memory: u64 = devices.iter().map(|d| d.capabilities.total_memory()).sum();
DiscoveryStatistics {
total_devices,
available_devices,
unique_device_types,
total_memory_gb: total_memory / (1024 * 1024 * 1024),
total_selections: history.len(),
}
}
fn scan_cpu_devices(&self) -> Result<Vec<DiscoveredDevice>> {
let mut devices = Vec::new();
let device_type = DeviceType::Cpu;
if DeviceFactory::is_device_type_available(device_type) {
let capabilities = DeviceCapabilities::detect(device_type)?;
let platform_info = self.detect_cpu_platform_info();
devices.push(DiscoveredDevice {
device_type,
capabilities,
is_available: true,
platform_info,
discovery_time: std::time::Instant::now(),
});
}
Ok(devices)
}
fn scan_cuda_devices(&self) -> Result<Vec<DiscoveredDevice>> {
#[allow(unused_mut)] let mut devices = Vec::new();
#[cfg(feature = "cuda")]
{
for index in 0..2 {
let device_type = DeviceType::Cuda(index);
if DeviceFactory::is_device_type_available(device_type) {
if let Ok(capabilities) = DeviceCapabilities::detect(device_type) {
let platform_info = self.detect_cuda_platform_info(index);
devices.push(DiscoveredDevice {
device_type,
capabilities,
is_available: true,
platform_info,
discovery_time: std::time::Instant::now(),
});
}
}
}
}
Ok(devices)
}
fn scan_metal_devices(&self) -> Result<Vec<DiscoveredDevice>> {
#[allow(unused_mut)] let mut devices = Vec::new();
#[cfg(target_os = "macos")]
{
let device_type = DeviceType::Metal(0);
if DeviceFactory::is_device_type_available(device_type) {
if let Ok(capabilities) = DeviceCapabilities::detect(device_type) {
let platform_info = self.detect_metal_platform_info();
devices.push(DiscoveredDevice {
device_type,
capabilities,
is_available: true,
platform_info,
discovery_time: std::time::Instant::now(),
});
}
}
}
Ok(devices)
}
fn scan_wgpu_devices(&self) -> Result<Vec<DiscoveredDevice>> {
#[allow(unused_mut)] let mut devices = Vec::new();
#[cfg(feature = "wgpu")]
{
let device_type = DeviceType::Wgpu(0);
if DeviceFactory::is_device_type_available(device_type) {
if let Ok(capabilities) = DeviceCapabilities::detect(device_type) {
let platform_info = self.detect_wgpu_platform_info();
devices.push(DiscoveredDevice {
device_type,
capabilities,
is_available: true,
platform_info,
discovery_time: std::time::Instant::now(),
});
}
}
}
Ok(devices)
}
fn populate_device_cache(&self) -> Result<()> {
let devices = self
.discovered_devices
.read()
.expect("lock should not be poisoned");
let mut cache = self
.device_cache
.write()
.expect("lock should not be poisoned");
cache.clear();
for discovered in devices.iter() {
if discovered.is_available {
if let Ok(device) = DeviceFactory::create_device(discovered.device_type) {
let arc_device: Arc<dyn Device> = device.into();
cache.insert(discovered.device_type, arc_device);
}
}
}
Ok(())
}
fn is_workload_compatible(
&self,
device: &DiscoveredDevice,
workload: &WorkloadProfile,
) -> Result<bool> {
if device.capabilities.available_memory() < workload.min_memory_bytes {
return Ok(false);
}
if device.capabilities.compute_units() < workload.min_compute_units {
return Ok(false);
}
if workload.requires_fp64 && !device.capabilities.supports_double_precision() {
return Ok(false);
}
if workload.requires_fp16 && !device.capabilities.supports_half_precision() {
return Ok(false);
}
match workload.device_preference {
DevicePreference::GpuOnly => {
if device.device_type.is_cpu() {
return Ok(false);
}
}
DevicePreference::CpuOnly => {
if !device.device_type.is_cpu() {
return Ok(false);
}
}
DevicePreference::CudaOnly => {
if !device.device_type.is_cuda() {
return Ok(false);
}
}
DevicePreference::Any => {}
}
Ok(true)
}
fn calculate_fitness_score(
&self,
device: &DiscoveredDevice,
workload: &WorkloadProfile,
) -> Result<f64> {
let mut score = 0.0;
let perf_score = device.capabilities.compute_score() as f64 / 1_000_000.0;
score += perf_score * workload.performance_weight;
let memory_ratio =
device.capabilities.available_memory() as f64 / workload.min_memory_bytes as f64;
let memory_score = memory_ratio.min(2.0); score += memory_score * workload.memory_weight;
let efficiency_score = self.estimate_efficiency(device)?;
score += efficiency_score * workload.efficiency_weight;
if device.device_type.is_gpu() && workload.prefers_gpu {
score += 0.5;
}
if device.capabilities.supports_double_precision() && workload.requires_fp64 {
score += 0.3;
}
if device.capabilities.supports_half_precision() && workload.requires_fp16 {
score += 0.2;
}
let history_bonus = self.get_history_bonus(device.device_type, workload);
score += history_bonus;
Ok(score)
}
fn meets_capability_requirements(
&self,
device: &DiscoveredDevice,
requirements: &CapabilityRequirements,
) -> Result<bool> {
if let Some(min_memory) = requirements.min_memory_gb {
if device.capabilities.total_memory_mb() < min_memory * 1024 {
return Ok(false);
}
}
if let Some(min_cores) = requirements.min_compute_units {
if device.capabilities.compute_units() < min_cores {
return Ok(false);
}
}
if requirements.requires_gpu && device.device_type.is_cpu() {
return Ok(false);
}
if requirements.requires_fp64 && !device.capabilities.supports_double_precision() {
return Ok(false);
}
if requirements.requires_fp16 && !device.capabilities.supports_half_precision() {
return Ok(false);
}
for feature in &requirements.required_features {
if !device.capabilities.supports_feature(feature) {
return Ok(false);
}
}
Ok(true)
}
fn estimate_efficiency(&self, device: &DiscoveredDevice) -> Result<f64> {
match device.device_type {
DeviceType::Cpu => Ok(0.5), DeviceType::Cuda(_) => Ok(0.9), DeviceType::Metal(_) => Ok(0.8), DeviceType::Wgpu(_) => Ok(0.6), }
}
fn estimate_performance(
&self,
device: &DiscoveredDevice,
workload: &WorkloadProfile,
) -> Result<PerformanceEstimate> {
let base_throughput = device.capabilities.compute_score() as f64;
let workload_multiplier = match workload.workload_type {
WorkloadType::Training => 1.0,
WorkloadType::Inference => 1.2, WorkloadType::Validation => 1.1,
WorkloadType::Benchmarking => 0.9, };
let estimated_throughput = base_throughput * workload_multiplier;
let estimated_latency_ms = match device.device_type {
DeviceType::Cpu => 10.0,
DeviceType::Cuda(_) => 2.0,
DeviceType::Metal(_) => 3.0,
DeviceType::Wgpu(_) => 5.0,
};
Ok(PerformanceEstimate {
throughput: estimated_throughput,
latency_ms: estimated_latency_ms,
memory_bandwidth_gbps: device.capabilities.peak_bandwidth_gbps().unwrap_or(1.0),
})
}
fn generate_recommendation_reasoning(
&self,
device: &DiscoveredDevice,
workload: &WorkloadProfile,
score: f64,
) -> String {
let mut reasons = Vec::new();
if device.device_type.is_gpu() && workload.prefers_gpu {
reasons.push("GPU acceleration preferred for this workload".to_string());
}
if device.capabilities.total_memory_mb() > workload.min_memory_bytes / (1024 * 1024) {
reasons.push(format!(
"Sufficient memory ({:.1}GB available)",
device.capabilities.total_memory_mb() as f64 / 1024.0
));
}
if device.capabilities.supports_half_precision() && workload.requires_fp16 {
reasons.push("Supports required half-precision operations".to_string());
}
if score > 2.0 {
reasons.push("High performance score for workload requirements".to_string());
} else if score > 1.0 {
reasons.push("Good performance score for workload requirements".to_string());
}
if reasons.is_empty() {
"Meets basic requirements".to_string()
} else {
reasons.join("; ")
}
}
fn get_history_bonus(&self, device_type: DeviceType, workload: &WorkloadProfile) -> f64 {
let history = self
.selection_history
.read()
.expect("lock should not be poisoned");
let successful_selections = history
.iter()
.filter(|record| {
record.device_type == device_type
&& record.workload.workload_type == workload.workload_type
&& record.success_score > 1.0
})
.count();
(successful_selections as f64) * 0.1
}
fn record_selection(&self, device_type: DeviceType, workload: WorkloadProfile, score: f64) {
let mut history = self
.selection_history
.write()
.expect("lock should not be poisoned");
history.push(SelectionRecord {
device_type,
workload,
success_score: score,
timestamp: std::time::Instant::now(),
});
if history.len() > 1000 {
history.remove(0);
}
}
fn detect_cpu_platform_info(&self) -> PlatformInfo {
PlatformInfo {
vendor: "Unknown".to_string(),
architecture: std::env::consts::ARCH.to_string(),
features: vec!["sse2".to_string(), "avx".to_string()],
driver_version: None,
}
}
#[allow(dead_code)]
fn detect_cuda_platform_info(&self, _index: usize) -> PlatformInfo {
PlatformInfo {
vendor: "NVIDIA".to_string(),
architecture: "CUDA".to_string(),
features: vec!["compute_capability_8_6".to_string()],
driver_version: Some("12.0".to_string()),
}
}
#[allow(dead_code)] fn detect_metal_platform_info(&self) -> PlatformInfo {
PlatformInfo {
vendor: "Apple".to_string(),
architecture: "Apple Silicon".to_string(),
features: vec!["unified_memory".to_string(), "tile_shaders".to_string()],
driver_version: Some("Metal 3.0".to_string()),
}
}
#[allow(dead_code)]
fn detect_wgpu_platform_info(&self) -> PlatformInfo {
PlatformInfo {
vendor: "WebGPU".to_string(),
architecture: "WebGPU".to_string(),
features: vec!["compute_shaders".to_string()],
driver_version: Some("1.0".to_string()),
}
}
}
/// A device found during a scan, together with its detected capabilities
/// and platform metadata.
#[derive(Debug, Clone)]
pub struct DiscoveredDevice {
// Backend/type identifier; also the key into the device cache.
pub device_type: DeviceType,
// Capabilities detected at scan time.
pub capabilities: DeviceCapabilities,
// Whether the device can currently be used (scans only push `true`).
pub is_available: bool,
// Vendor/architecture/feature description of the platform.
pub platform_info: PlatformInfo,
// When the scan recorded this device.
pub discovery_time: std::time::Instant,
}
/// Vendor/architecture metadata for a discovered device's platform.
#[derive(Debug, Clone)]
pub struct PlatformInfo {
// Hardware vendor name (e.g. "NVIDIA", "Apple"); "Unknown" when not probed.
pub vendor: String,
// Architecture label (e.g. "CUDA", or the host arch for CPUs).
pub architecture: String,
// Feature strings advertised for this platform.
pub features: Vec<String>,
// Driver version string, when one applies to the backend.
pub driver_version: Option<String>,
}
/// Describes a workload's hard requirements and scoring weights, used to
/// filter and rank devices during selection.
#[derive(Debug, Clone)]
pub struct WorkloadProfile {
// Category of work; also drives the performance-estimate multiplier.
pub workload_type: WorkloadType,
// Hard floor on available device memory, in bytes.
pub min_memory_bytes: u64,
// Hard floor on compute units.
pub min_compute_units: u32,
// Hard requirement for double-precision support.
pub requires_fp64: bool,
// Hard requirement for half-precision support.
pub requires_fp16: bool,
// Soft preference: GPUs get a scoring bonus but CPUs are not excluded.
pub prefers_gpu: bool,
// Hard constraint on the device category (Any/GpuOnly/CpuOnly/CudaOnly).
pub device_preference: DevicePreference,
// Weight applied to the raw compute score.
pub performance_weight: f64,
// Weight applied to the memory-headroom score.
pub memory_weight: f64,
// Weight applied to the backend efficiency estimate.
pub efficiency_weight: f64,
}
impl WorkloadProfile {
    /// One gibibyte, to keep the memory floors below readable.
    const GIB: u64 = 1024 * 1024 * 1024;

    /// Large-scale training: GPU-only, fp16-capable, at least 8 GiB of
    /// memory and 32 compute units; raw performance weighted highest.
    pub fn training_large() -> Self {
        Self {
            workload_type: WorkloadType::Training,
            device_preference: DevicePreference::GpuOnly,
            min_memory_bytes: 8 * Self::GIB,
            min_compute_units: 32,
            requires_fp64: false,
            requires_fp16: true,
            prefers_gpu: true,
            performance_weight: 1.0,
            memory_weight: 0.8,
            efficiency_weight: 0.6,
        }
    }

    /// Inference: modest requirements (2 GiB, 8 units, fp16), any device
    /// type allowed, with efficiency weighted highest.
    pub fn inference() -> Self {
        Self {
            workload_type: WorkloadType::Inference,
            device_preference: DevicePreference::Any,
            min_memory_bytes: 2 * Self::GIB,
            min_compute_units: 8,
            requires_fp64: false,
            requires_fp16: true,
            prefers_gpu: true,
            performance_weight: 0.8,
            memory_weight: 0.6,
            efficiency_weight: 1.0,
        }
    }

    /// Development: smallest footprint (1 GiB, 4 units), no precision
    /// requirements, no GPU preference.
    pub fn development() -> Self {
        Self {
            workload_type: WorkloadType::Training,
            device_preference: DevicePreference::Any,
            min_memory_bytes: Self::GIB,
            min_compute_units: 4,
            requires_fp64: false,
            requires_fp16: false,
            prefers_gpu: false,
            performance_weight: 0.5,
            memory_weight: 0.7,
            efficiency_weight: 0.3,
        }
    }

    /// Benchmarking: requires both fp64 and fp16 (4 GiB, 16 units) so
    /// results cover the full precision range.
    pub fn benchmarking() -> Self {
        Self {
            workload_type: WorkloadType::Benchmarking,
            device_preference: DevicePreference::Any,
            min_memory_bytes: 4 * Self::GIB,
            min_compute_units: 16,
            requires_fp64: true,
            requires_fp16: true,
            prefers_gpu: true,
            performance_weight: 1.0,
            memory_weight: 0.5,
            efficiency_weight: 0.5,
        }
    }

    /// Research: the heaviest profile — 16 GiB, 64 units, both precisions
    /// required, performance and memory weighted equally at maximum.
    pub fn research() -> Self {
        Self {
            workload_type: WorkloadType::Training,
            device_preference: DevicePreference::Any,
            min_memory_bytes: 16 * Self::GIB,
            min_compute_units: 64,
            requires_fp64: true,
            requires_fp16: true,
            prefers_gpu: true,
            performance_weight: 1.0,
            memory_weight: 1.0,
            efficiency_weight: 0.4,
        }
    }
}
/// Category of work a device will run; used for matching history records
/// and for the throughput multiplier in performance estimates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkloadType {
Training,
Inference,
Validation,
Benchmarking,
}
/// Hard constraint on which device category a workload may run on;
/// enforced in `is_workload_compatible`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DevicePreference {
// No constraint on device category.
Any,
// Exclude CPU devices.
GpuOnly,
// Only CPU devices.
CpuOnly,
// Only CUDA devices.
CudaOnly,
}
/// High-level use case accepted by `recommend_device`; each maps onto one
/// of the canned `WorkloadProfile` constructors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UseCase {
Training,
Inference,
Development,
Benchmarking,
Research,
}
/// Explicit device-capability filter used by `get_devices_by_capabilities`;
/// every `Some`/`true` field is a hard requirement.
#[derive(Debug, Clone)]
pub struct CapabilityRequirements {
// Minimum total memory in gigabytes, if constrained.
pub min_memory_gb: Option<u64>,
// Minimum compute units, if constrained.
pub min_compute_units: Option<u32>,
// Reject CPU devices when true.
pub requires_gpu: bool,
// Require double-precision support.
pub requires_fp64: bool,
// Require half-precision support.
pub requires_fp16: bool,
// Named features the device must advertise (all must be supported).
pub required_features: Vec<String>,
}
impl CapabilityRequirements {
    /// No constraints at all: every available device qualifies.
    pub fn basic() -> Self {
        Self {
            required_features: Vec::new(),
            min_memory_gb: None,
            min_compute_units: None,
            requires_gpu: false,
            requires_fp64: false,
            requires_fp16: false,
        }
    }

    /// GPU training preset: at least 4 GB of memory and 16 compute
    /// units, half-precision and tensor cores required. Built on top of
    /// [`Self::basic`] via functional record update.
    pub fn gpu_training() -> Self {
        Self {
            min_memory_gb: Some(4),
            min_compute_units: Some(16),
            requires_gpu: true,
            requires_fp16: true,
            required_features: vec!["tensor_cores".to_string()],
            ..Self::basic()
        }
    }
}
/// Result of `recommend_device`: ranked options plus the workload profile
/// the use case was translated into.
#[derive(Debug)]
pub struct DeviceRecommendation {
// The use case the recommendation was produced for.
pub use_case: UseCase,
// Candidate devices, best score first.
pub options: Vec<DeviceOption>,
// The profile derived from `use_case` and used for scoring.
pub workload_profile: WorkloadProfile,
}
/// One candidate in a `DeviceRecommendation`.
#[derive(Debug)]
pub struct DeviceOption {
// Live handle to the device.
pub device: Arc<dyn Device>,
// Fitness score for the recommendation's workload profile.
pub score: f64,
// Human-readable justification for this option.
pub reasoning: String,
// Rough throughput/latency/bandwidth estimate.
pub estimated_performance: PerformanceEstimate,
}
/// Rough, heuristic performance figures for a device/workload pair;
/// derived from fixed multipliers, not measured.
#[derive(Debug, Clone)]
pub struct PerformanceEstimate {
// Estimated throughput (compute score scaled by a workload multiplier).
pub throughput: f64,
// Estimated per-operation latency in milliseconds.
pub latency_ms: f64,
// Peak memory bandwidth in GB/s (1.0 when the device does not report one).
pub memory_bandwidth_gbps: f64,
}
/// One entry in the selection history; used to grant a familiarity bonus
/// to device types that previously scored well for the same workload type.
#[derive(Debug, Clone)]
struct SelectionRecord {
// Device type that was selected.
device_type: DeviceType,
// The workload the selection was made for (only `workload_type` is
// consulted when computing the history bonus).
workload: WorkloadProfile,
// The fitness score at selection time; > 1.0 counts as "successful".
success_score: f64,
// When the selection happened; kept for future use.
#[allow(dead_code)] timestamp: std::time::Instant,
}
/// Configuration for `DeviceDiscovery`; the `scan_*` flags choose which
/// backends `scan_devices` probes.
#[derive(Debug, Clone)]
pub struct DiscoveryConfig {
// Probe for a CPU device.
pub scan_cpu: bool,
// Probe for CUDA devices (only meaningful with the `cuda` feature).
pub scan_cuda: bool,
// Probe for Metal devices (only meaningful on macOS).
pub scan_metal: bool,
// Probe for WebGPU devices (only meaningful with the `wgpu` feature).
pub scan_wgpu: bool,
// Intended to control caching of discovery results.
// NOTE(review): exact semantics are not evident from this file — confirm.
pub cache_discoveries: bool,
// Whether selection decisions should be recorded in the history.
pub track_selection_history: bool,
}
impl Default for DiscoveryConfig {
fn default() -> Self {
Self {
scan_cpu: true,
scan_cuda: cfg!(feature = "cuda"),
scan_metal: cfg!(target_os = "macos"),
scan_wgpu: cfg!(feature = "wgpu"),
cache_discoveries: true,
track_selection_history: true,
}
}
}
/// Summary of the discovery state returned by `get_statistics`.
#[derive(Debug, Clone)]
pub struct DiscoveryStatistics {
// Number of devices found by the last scan.
pub total_devices: usize,
// How many of those are marked available.
pub available_devices: usize,
// Number of distinct `DeviceType` values among them.
pub unique_device_types: usize,
// Sum of all devices' total memory, truncated to whole gigabytes.
pub total_memory_gb: u64,
// Number of entries currently in the selection history.
pub total_selections: usize,
}
impl Default for DeviceDiscovery {
fn default() -> Self {
Self::new()
}
}
/// Convenience helpers that bundle construction, scanning, and selection
/// into single calls.
pub mod utils {
    use super::*;

    /// Builds a `DeviceDiscovery` with defaults and runs a scan before
    /// returning it.
    pub fn create_and_scan() -> Result<DeviceDiscovery> {
        let discovery = DeviceDiscovery::new();
        discovery.scan_devices()?;
        Ok(discovery)
    }

    /// Scans and picks the best device for the large-training profile.
    pub fn quick_select_for_training() -> Result<Option<Arc<dyn Device>>> {
        create_and_scan()?.select_optimal_device(&WorkloadProfile::training_large())
    }

    /// Scans and picks the best device for the inference profile.
    pub fn quick_select_for_inference() -> Result<Option<Arc<dyn Device>>> {
        create_and_scan()?.select_optimal_device(&WorkloadProfile::inference())
    }

    /// Returns the first GPU with at least 1 GB of memory and 8 compute
    /// units, or `None` when no such GPU exists.
    pub fn get_best_gpu() -> Result<Option<Arc<dyn Device>>> {
        let requirements = CapabilityRequirements {
            min_memory_gb: Some(1),
            min_compute_units: Some(8),
            requires_gpu: true,
            requires_fp64: false,
            requires_fp16: false,
            required_features: Vec::new(),
        };
        let matches = create_and_scan()?.get_devices_by_capabilities(&requirements)?;
        Ok(matches.into_iter().next())
    }

    /// Scans and returns one human-readable line per discovered device.
    pub fn create_device_summary() -> Result<Vec<String>> {
        let discovery = create_and_scan()?;
        let mut summary = Vec::new();
        for entry in discovery.get_discovered_devices() {
            let status = if entry.is_available {
                "Available"
            } else {
                "Unavailable"
            };
            summary.push(format!(
                "{:?} - {:.1}GB, {} cores, {}",
                entry.device_type,
                entry.capabilities.total_memory_mb() as f64 / 1024.0,
                entry.capabilities.compute_units(),
                status
            ));
        }
        Ok(summary)
    }

    /// Reports whether any GPU with at least 8 GB of memory, 32 compute
    /// units, and fp16 support is present.
    pub fn has_high_performance_devices() -> Result<bool> {
        let requirements = CapabilityRequirements {
            min_memory_gb: Some(8),
            min_compute_units: Some(32),
            requires_gpu: true,
            requires_fp64: false,
            requires_fp16: true,
            required_features: Vec::new(),
        };
        let matches = create_and_scan()?.get_devices_by_capabilities(&requirements)?;
        Ok(!matches.is_empty())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A scan must always find at least the CPU device.
    #[test]
    fn test_device_discovery() {
        let discovery = DeviceDiscovery::new();
        let found = discovery
            .scan_devices()
            .expect("scan_devices should succeed");
        assert!(found > 0);
        let devices = discovery.get_discovered_devices();
        assert!(!devices.is_empty());
        assert!(devices.iter().any(|d| d.device_type == DeviceType::Cpu));
    }

    /// The canned workload profiles carry sensible relative requirements.
    #[test]
    fn test_workload_profiles() {
        let training = WorkloadProfile::training_large();
        let inference = WorkloadProfile::inference();
        let development = WorkloadProfile::development();
        assert_eq!(training.workload_type, WorkloadType::Training);
        assert!(training.prefers_gpu);
        assert!(training.requires_fp16);
        assert_eq!(inference.workload_type, WorkloadType::Inference);
        // Inference should need less memory than large-scale training.
        assert!(inference.min_memory_bytes < training.min_memory_bytes);
        assert!(!development.prefers_gpu);
    }

    /// `basic` is unconstrained; `gpu_training` carries hard constraints.
    #[test]
    fn test_capability_requirements() {
        let basic = CapabilityRequirements::basic();
        assert!(!basic.requires_gpu);
        assert!(basic.required_features.is_empty());
        let gpu = CapabilityRequirements::gpu_training();
        assert!(gpu.requires_gpu);
        assert!(gpu.min_memory_gb.is_some());
        assert!(!gpu.required_features.is_empty());
    }

    /// Development workloads should always find at least one device,
    /// both for single and distributed selection.
    #[test]
    fn test_device_selection() {
        let discovery = DeviceDiscovery::new();
        discovery
            .scan_devices()
            .expect("scan_devices should succeed");
        let profile = WorkloadProfile::development();
        let single = discovery
            .select_optimal_device(&profile)
            .expect("select_optimal_device should succeed");
        assert!(single.is_some());
        let multi = discovery
            .select_devices_for_distributed_workload(&profile, 2)
            .expect("select_devices_for_distributed_workload should succeed");
        assert!(!multi.is_empty());
    }

    /// Recommendations echo the use case and carry scored options with
    /// non-empty reasoning.
    #[test]
    fn test_device_recommendation() {
        let discovery = DeviceDiscovery::new();
        discovery
            .scan_devices()
            .expect("scan_devices should succeed");
        let recommendation = discovery
            .recommend_device(UseCase::Development)
            .expect("recommend_device should succeed");
        assert_eq!(recommendation.use_case, UseCase::Development);
        assert!(!recommendation.options.is_empty());
        for option in &recommendation.options {
            assert!(option.score >= 0.0);
            assert!(!option.reasoning.is_empty());
        }
    }

    /// Statistics after a scan report at least one available device.
    #[test]
    fn test_discovery_statistics() {
        let discovery = DeviceDiscovery::new();
        discovery
            .scan_devices()
            .expect("scan_devices should succeed");
        let stats = discovery.get_statistics();
        assert!(stats.total_devices > 0);
        assert!(stats.available_devices > 0);
        assert!(stats.unique_device_types > 0);
    }

    /// Smoke-tests every helper in the `utils` module.
    #[test]
    fn test_utils_functions() {
        let discovery = utils::create_and_scan().expect("create_and_scan should succeed");
        assert!(!discovery.get_discovered_devices().is_empty());
        let _training_device =
            utils::quick_select_for_training().expect("quick_select_for_training should succeed");
        let _inference_device =
            utils::quick_select_for_inference().expect("quick_select_for_inference should succeed");
        let summary = utils::create_device_summary().expect("create_device_summary should succeed");
        assert!(!summary.is_empty());
        let _has_hp = utils::has_high_performance_devices()
            .expect("has_high_performance_devices should succeed");
    }

    /// CPU platform info must report a vendor and an architecture.
    #[test]
    fn test_platform_info() {
        let discovery = DeviceDiscovery::new();
        let cpu_info = discovery.detect_cpu_platform_info();
        assert!(!cpu_info.vendor.is_empty());
        assert!(!cpu_info.architecture.is_empty());
    }

    /// Performance estimates for a real device are strictly positive.
    #[test]
    fn test_performance_estimate() {
        let discovery = DeviceDiscovery::new();
        discovery
            .scan_devices()
            .expect("scan_devices should succeed");
        let devices = discovery.get_discovered_devices();
        if let Some(first) = devices.first() {
            let profile = WorkloadProfile::development();
            let estimate = discovery
                .estimate_performance(first, &profile)
                .expect("estimate_performance should succeed");
            assert!(estimate.throughput > 0.0);
            assert!(estimate.latency_ms > 0.0);
        }
    }
}