use candle_core::Device;
#[allow(unused_imports)]
use log::{debug, info, warn};
/// Which compute backend [`select_device`] should try to use.
///
/// The default is [`DeviceSelection::Auto`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum DeviceSelection {
    /// Probe for an accelerator at runtime (Metal, then CUDA, when compiled in),
    /// falling back to the CPU.
    #[default]
    Auto,
    /// Always run on the CPU.
    Cpu,
    /// Request Metal device 0; falls back to the CPU when Metal support is not
    /// compiled in or the process runs on the iOS Simulator.
    Metal,
    /// Request the CUDA device with the given ordinal; falls back to the CPU when
    /// CUDA support is not compiled in.
    Cuda(usize),
}
/// Best-effort check for whether the process is running inside the iOS Simulator.
///
/// On any target other than `target_os = "ios"` this always returns `false`.
///
/// On iOS the checks are, in order:
/// 1. If the `SIMULATOR_DEVICE_NAME` environment variable is set, return `true`.
/// 2. Otherwise read `hw.model` via `sysctlbyname`. If the model string contains
///    `"Mac"` or `"x86"`, return `true`.
///    NOTE(review): this presumes the simulator reports the host Mac's model; confirm
///    against the supported Xcode versions.
/// 3. Any failure (sysctl error, non-UTF-8 model string) falls through to `false`.
fn is_ios_simulator() -> bool {
    #[cfg(target_os = "ios")]
    {
        if std::env::var("SIMULATOR_DEVICE_NAME").is_ok() {
            return true;
        }
        #[cfg(target_vendor = "apple")]
        {
            use std::os::raw::c_char;
            extern "C" {
                fn sysctlbyname(
                    name: *const c_char,
                    oldp: *mut u8,
                    oldlenp: *mut usize,
                    newp: *const u8,
                    newlen: usize,
                ) -> i32;
            }
            let mut buffer = [0u8; 256];
            let mut size = buffer.len();
            let name = b"hw.model\0";
            // SAFETY: `name` is NUL-terminated, `buffer` is valid for `size` writable
            // bytes, `size` is a valid in/out length pointer, and no new value is set
            // (`newp` is null with `newlen` 0).
            let status = unsafe {
                sysctlbyname(
                    name.as_ptr() as *const c_char,
                    buffer.as_mut_ptr(),
                    &mut size,
                    std::ptr::null(),
                    0,
                )
            };
            if status == 0 {
                // Trust only the bytes reported as written, and stop at the first NUL.
                // Unlike `CStr::from_ptr`, this never reads past `buffer` if the
                // terminator is missing.
                let written = &buffer[..size.min(buffer.len())];
                let model_bytes = written.split(|&b| b == 0).next().unwrap_or_default();
                if let Ok(model) = std::str::from_utf8(model_bytes) {
                    if model.contains("Mac") || model.contains("x86") {
                        return true;
                    }
                }
            }
        }
        false
    }
    #[cfg(not(target_os = "ios"))]
    {
        false
    }
}
/// Resolves a [`DeviceSelection`] preference into a concrete candle [`Device`].
///
/// * `Cpu` always yields [`Device::Cpu`].
/// * `Metal`:
///   - Opens Metal device 0 when the `candle-metal` feature is compiled in and the
///     process is not on the iOS Simulator.
///   - Otherwise logs a warning and returns the CPU device.
/// * `Cuda(ordinal)`:
///   - Opens that CUDA device when the `candle-cuda` feature is compiled in.
///   - Otherwise logs a warning and returns the CPU device.
/// * `Auto`:
///   - Returns the CPU device on the iOS Simulator.
///   - Otherwise tries Metal device 0, then CUDA device 0 (each only if its feature is
///     compiled in).
///   - Finally falls back to the CPU device.
///
/// # Errors
///
/// Only explicit `Metal` / `Cuda` requests can fail. They return the error from
/// `Device::new_metal` / `Device::new_cuda` unchanged. `Auto` logs backend failures at
/// debug level and always succeeds.
pub fn select_device(preference: DeviceSelection) -> candle_core::Result<Device> {
    match preference {
        DeviceSelection::Cpu => Ok(Device::Cpu),
        DeviceSelection::Metal => metal_or_cpu(),
        DeviceSelection::Cuda(ordinal) => cuda_or_cpu(ordinal),
        DeviceSelection::Auto => auto_detect_device(),
    }
}

/// Handles an explicit Metal request for [`select_device`].
fn metal_or_cpu() -> candle_core::Result<Device> {
    #[cfg(not(feature = "candle-metal"))]
    {
        warn!("Metal requested but candle-metal feature not enabled, falling back to CPU");
        Ok(Device::Cpu)
    }
    #[cfg(feature = "candle-metal")]
    {
        if is_ios_simulator() {
            warn!("Metal not supported on iOS Simulator, falling back to CPU");
            Ok(Device::Cpu)
        } else {
            Device::new_metal(0)
        }
    }
}

/// Handles an explicit CUDA request for [`select_device`].
fn cuda_or_cpu(ordinal: usize) -> candle_core::Result<Device> {
    #[cfg(not(feature = "candle-cuda"))]
    {
        // The ordinal is only consumed when CUDA support is compiled in.
        let _ = ordinal;
        warn!("CUDA requested but candle-cuda feature not enabled, falling back to CPU");
        Ok(Device::Cpu)
    }
    #[cfg(feature = "candle-cuda")]
    {
        Device::new_cuda(ordinal)
    }
}

/// Probes compiled-in accelerators in priority order and falls back to the CPU.
fn auto_detect_device() -> candle_core::Result<Device> {
    if is_ios_simulator() {
        info!("Running on iOS Simulator, using CPU device");
        return Ok(Device::Cpu);
    }

    #[cfg(feature = "candle-metal")]
    {
        match Device::new_metal(0) {
            Ok(metal) => {
                info!("Auto-selected Metal device");
                return Ok(metal);
            }
            Err(err) => debug!("Metal device not available: {}, trying alternatives", err),
        }
    }

    #[cfg(feature = "candle-cuda")]
    {
        match Device::new_cuda(0) {
            Ok(cuda) => {
                info!("Auto-selected CUDA device 0");
                return Ok(cuda);
            }
            Err(err) => debug!("CUDA device not available: {}, falling back to CPU", err),
        }
    }

    info!("Using CPU device");
    Ok(Device::Cpu)
}
/// Returns a short, human-readable label for the backend behind `device`.
///
/// The label is always one of `"CPU"`, `"CUDA"` or `"Metal"`.
pub fn device_name(device: &Device) -> &'static str {
    match *device {
        Device::Metal(_) => "Metal",
        Device::Cuda(_) => "CUDA",
        Device::Cpu => "CPU",
    }
}
/// Unit tests for device selection.
///
/// Feature-gated tests cover the CPU fallbacks that only exist when the
/// corresponding accelerator feature is disabled.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cpu_selection() {
        let device = select_device(DeviceSelection::Cpu).unwrap();
        assert!(matches!(device, Device::Cpu));
    }

    #[test]
    fn test_default_is_auto() {
        assert_eq!(DeviceSelection::default(), DeviceSelection::Auto);
    }

    #[test]
    fn test_auto_selection() {
        // Auto never returns an error: every probe failure falls back to the CPU.
        let device = select_device(DeviceSelection::Auto).unwrap();
        let _ = device_name(&device);
    }

    #[test]
    #[cfg(not(any(feature = "candle-metal", feature = "candle-cuda")))]
    fn test_auto_without_accelerator_features_is_cpu() {
        let device = select_device(DeviceSelection::Auto).unwrap();
        assert!(matches!(device, Device::Cpu));
    }

    #[test]
    #[cfg(not(feature = "candle-metal"))]
    fn test_metal_without_feature_falls_back_to_cpu() {
        let device = select_device(DeviceSelection::Metal).unwrap();
        assert!(matches!(device, Device::Cpu));
    }

    #[test]
    #[cfg(not(feature = "candle-cuda"))]
    fn test_cuda_without_feature_falls_back_to_cpu() {
        let device = select_device(DeviceSelection::Cuda(3)).unwrap();
        assert!(matches!(device, Device::Cpu));
    }

    #[test]
    fn test_device_name() {
        assert_eq!(device_name(&Device::Cpu), "CPU");
    }
}