import platform
import torch
import sys
from pathlib import Path
# Per-OS device-selection settings, keyed by platform.system() values
# ('Linux', 'Darwin', 'Windows'). Unknown platforms fall back to the
# 'Linux' entry (see EnvironmentAwareDeviceSelector.__init__).
# Each entry:
#   primary_device  - preferred torch device string to probe for
#   fallback_device - device used when the primary is unavailable
#   note            - human-readable description shown in summaries
#   requires        - informational list of expected system components
#   setup_note      - hint printed to help the user enable the primary device
#   optimization    - expected benefit text (speedup figures are claims
#                     from the original author, not measured here)
PLATFORM_CONFIG = {
    'Linux': {
        'primary_device': 'cuda', 'fallback_device': 'cpu',
        'note': 'AMD Radeon GPU with ROCm 5.7 (requires kernel modules: amdgpu, amdkfd)',
        'requires': ['rocm-smi', 'hip', 'rocm-libs'],
        'setup_note': 'sudo modprobe amdgpu && sudo modprobe amdkfd',
        'optimization': 'AMD ROCm HIP kernels (2-5x speedup on tensor ops)',
    },
    'Darwin': {
        'primary_device': 'mps', 'fallback_device': 'cpu',
        'note': 'Apple Metal Performance Shaders (GPU acceleration for Apple Silicon)',
        'requires': ['Metal', 'PyTorch 1.12+'],
        'setup_note': 'Metal support built-in on Apple Silicon (M1/M2/M3+)',
        'optimization': 'Apple GPU via Metal (2-3x speedup on compatible ops)',
    },
    'Windows': {
        'primary_device': 'cuda', 'fallback_device': 'cpu',
        'note': 'NVIDIA CUDA GPU (if available) or Intel oneAPI (if available)',
        'requires': ['NVIDIA GPU', 'CUDA Toolkit', 'cuDNN'],
        'setup_note': 'Install NVIDIA CUDA Toolkit and drivers',
        'optimization': 'NVIDIA CUDA kernels (3-8x speedup)',
    }
}
class EnvironmentAwareDeviceSelector:
    """Pick the best available torch device for the current OS.

    Linux probes for a ROCm/CUDA GPU, macOS for Apple Metal (MPS), and
    Windows for NVIDIA CUDA; every platform falls back to CPU. Detection
    details are recorded in ``device_info`` for later reporting via
    ``get_device_info()`` / ``print_device_summary()``.
    """

    def __init__(self):
        self.system = platform.system()
        # Unknown platforms reuse the Linux settings as a sane default.
        self.config = PLATFORM_CONFIG.get(self.system, PLATFORM_CONFIG['Linux'])
        self.device = None        # torch.device, set by detect_device()
        self.device_info = {}     # detection details, set by detect_device()

    def detect_device(self):
        """Detect and return the torch.device for this platform.

        Returns:
            torch.device: The selected device. The result is also stored
            on ``self.device`` so the reporting helpers can use it.
        """
        primary = self.config['primary_device']
        fallback = self.config['fallback_device']
        self.device_info = {
            'platform': self.system,
            'primary_device': primary,
            'fallback_device': fallback,
            'available': False,
            'note': self.config['note'],
        }
        if self.system == 'Linux':
            device = self._detect_linux_device()
        elif self.system == 'Darwin':
            device = self._detect_macos_device()
        elif self.system == 'Windows':
            device = self._detect_windows_device()
        else:
            # Unknown platform: skip GPU probing, use plain CPU.
            self.device_info['available'] = True
            self.device_info['device_type'] = 'CPU'
            device = torch.device('cpu')
        # BUG FIX: the detected device was returned but never stored, so
        # get_device_info()['device'] always reported the string 'None'.
        self.device = device
        return device

    def _detect_linux_device(self):
        """Probe for a GPU on Linux, falling back to CPU.

        NOTE(review): torch.cuda.is_available() is also True for NVIDIA
        CUDA builds on Linux; the 'AMD ROCm' label assumes this torch was
        built against ROCm/HIP — confirm against the deployed wheel.
        """
        print("\n" + "="*80)
        print("🐧 LINUX PLATFORM DETECTED")
        print("="*80)
        print("\n1. Checking for AMD Radeon GPU (ROCm)...")
        if torch.cuda.is_available():
            device_name = torch.cuda.get_device_name(0)
            print(f" ✅ GPU Found: {device_name}")
            self.device_info['available'] = True
            self.device_info['device_type'] = 'AMD ROCm'
            self.device_info['warning'] = "⚠️ AMD GPU CONFIG IS LOCKED TO LINUX. Do not modify when on MacBook."
            return torch.device('cuda', 0)
        else:
            print(" ℹ️ No AMD GPU detected (or HIP kernels not loaded)")
            print(" To enable: sudo modprobe amdgpu && sudo modprobe amdkfd")
        # CPU fallback path.
        print("\n2. Using CPU (NumPy + Intel MKL optimizations)")
        print(" ✅ CPU Mode: Stable, <200ms per query")
        self.device_info['available'] = True
        self.device_info['device_type'] = 'CPU'
        return torch.device('cpu')

    def _detect_macos_device(self):
        """Probe for Apple Metal (MPS) on macOS, falling back to CPU."""
        print("\n" + "="*80)
        print("🍎 MACOS PLATFORM DETECTED")
        print("="*80)
        print("\n⚠️ IMPORTANT: AMD ROCm config is preserved for Linux only!")
        print(" This MacBook will use Apple Metal GPU instead.")
        print("\n1. Checking for Apple Metal GPU (M1/M2/M3+)...")
        try:
            # hasattr guards older torch builds without the mps backend.
            if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                print(" ✅ Apple Metal Found (GPU Acceleration Available)")
                self.device_info['available'] = True
                self.device_info['device_type'] = 'Apple Metal'
                self.device_info['warning'] = "ℹ️ Using Apple Metal. AMD ROCm config preserved for Linux."
                return torch.device('mps')
            else:
                print(" ℹ️ Apple Metal not available on this macOS version")
                print(" (Requires macOS 12.3+ with Apple Silicon)")
        # BUG FIX: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit. Narrowed to Exception.
        except Exception:
            print(" ℹ️ Apple Metal not available")
        # CPU fallback path.
        print("\n2. Using CPU (Intel/ARM optimizations)")
        print(" ✅ CPU Mode: Reliable, <200ms per query")
        self.device_info['available'] = True
        self.device_info['device_type'] = 'CPU'
        return torch.device('cpu')

    def _detect_windows_device(self):
        """Probe for an NVIDIA CUDA GPU on Windows, falling back to CPU."""
        print("\n" + "="*80)
        print("🪟 WINDOWS PLATFORM DETECTED")
        print("="*80)
        print("\n1. Checking for NVIDIA CUDA GPU...")
        if torch.cuda.is_available():
            device_name = torch.cuda.get_device_name(0)
            print(f" ✅ GPU Found: {device_name}")
            self.device_info['available'] = True
            self.device_info['device_type'] = 'NVIDIA CUDA'
            return torch.device('cuda', 0)
        else:
            print(" ℹ️ No NVIDIA GPU detected")
        # CPU fallback path.
        print("\n2. Using CPU (Intel MKL optimizations)")
        print(" ✅ CPU Mode: Stable, <200ms per query")
        self.device_info['available'] = True
        self.device_info['device_type'] = 'CPU'
        return torch.device('cpu')

    def get_device_info(self):
        """Return a flat dict summarizing the detection result.

        Must be called after detect_device(); otherwise 'device' is the
        string 'None' and 'device_type' is 'Unknown'.
        """
        return {
            'platform': self.system,
            'device': str(self.device),
            'device_type': self.device_info.get('device_type', 'Unknown'),
            'available': self.device_info.get('available', False),
            'note': self.config['note'],
            'warning': self.device_info.get('warning', ''),
            'setup_note': self.config['setup_note'],
            'optimization': self.config['optimization'],
        }

    def print_device_summary(self):
        """Print a human-readable summary of the detected configuration."""
        info = self.get_device_info()
        print("\n" + "="*80)
        print("📊 DEVICE CONFIGURATION SUMMARY")
        print("="*80)
        print(f"\nPlatform: {info['platform']}")
        print(f"Device: {info['device']}")
        print(f"Device Type: {info['device_type']}")
        print(f"Available: {'✅ Yes' if info['available'] else '❌ No'}")
        print(f"\nNote: {info['note']}")
        if info['warning']:
            print(f"⚠️ WARNING: {info['warning']}")
        print(f"\nOptimization: {info['optimization']}")
        print(f"Setup: {info['setup_note']}")
        print("\n" + "="*80)
class ProtectedDeviceConfig:
    """Persist the last-seen platform name and warn when it changes.

    The platform name is stored in a small sidecar file next to this
    module so that OS-specific GPU settings (Linux ROCm vs. macOS Metal)
    are not accidentally carried between machines.
    """

    # Sidecar file holding the platform name recorded on the previous run.
    CONFIG_FILE = Path(__file__).parent / '.device_config'

    @classmethod
    def load_last_platform(cls):
        """Return the platform name from the previous run, or None if unset."""
        if cls.CONFIG_FILE.exists():
            # FIX: read with an explicit encoding so the result does not
            # depend on the process locale (notably on Windows).
            return cls.CONFIG_FILE.read_text(encoding='utf-8').strip()
        return None

    @classmethod
    def save_platform(cls, platform_name):
        """Record *platform_name* as the current platform."""
        cls.CONFIG_FILE.write_text(platform_name, encoding='utf-8')

    @classmethod
    def check_platform_change(cls):
        """Warn on stdout if the OS changed since the last run, then persist.

        Always rewrites the sidecar file with the current platform name.
        """
        current = platform.system()
        last = cls.load_last_platform()
        if last and last != current:
            print("\n" + "⚠️ "*40)
            print("\n🚨 PLATFORM CHANGE DETECTED!")
            print(f"\n Last Platform: {last}")
            print(f" Current Platform: {current}")
            print("\n⚠️ IMPORTANT NOTES:")
            print(" - AMD GPU (ROCm) config is LOCKED to Linux only")
            print(" - MacBook will use Apple Metal instead (automatic)")
            print(" - Device settings are platform-specific and will not conflict")
            print(" - AMD ROCm setup is preserved for when you return to Linux")
            print("\n" + "⚠️ "*40 + "\n")
        cls.save_platform(current)
def get_device(force_device=None):
    """Resolve the torch device for this process.

    Args:
        force_device: Optional device string (e.g. 'cpu', 'cuda') that
            bypasses autodetection — intended for tests. Falsy values
            (None, '') trigger normal detection.

    Returns:
        torch.device: The forced or auto-detected device.
    """
    if not force_device:
        # Normal path: warn about OS changes, then autodetect and report.
        ProtectedDeviceConfig.check_platform_change()
        chooser = EnvironmentAwareDeviceSelector()
        selected = chooser.detect_device()
        chooser.print_device_summary()
        return selected
    print(f"\n⚠️ FORCING DEVICE: {force_device} (test mode)")
    return torch.device(force_device)
# Manual smoke test: run this module directly to detect the device and
# verify a small tensor can be allocated on it.
if __name__ == '__main__':
    print("\n" + "="*80)
    print("DEVICE CONFIGURATION TEST")
    print("="*80)
    device = get_device()
    print(f"\n✅ Selected Device: {device}")
    print("\nTesting tensor operations...")
    # Allocating directly on the device proves it is actually usable,
    # not just reported as available.
    tensor = torch.randn(10, 10, device=device)
    print(f"✅ Tensor created on {device}")
    print(f" Shape: {tensor.shape}")
    print(f" Device: {tensor.device}")