Skip to main content

any_tts/
device.rs

1//! Device selection utilities.
2//!
3//! Automatically selects the best available compute device based on compiled
4//! feature flags and hardware availability.
5
6use candle_core::Device;
7use tracing::info;
8
/// How the caller wants the compute backend to be chosen.
///
/// The default value is [`DeviceSelection::Auto`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum DeviceSelection {
    /// Let the library probe for an accelerator: CUDA is tried first, then
    /// Metal, and the CPU is used when neither can be obtained.
    #[default]
    Auto,
    /// Always run on the CPU.
    Cpu,
    /// Require the CUDA GPU identified by the given ordinal.
    Cuda(usize),
    /// Require the Metal GPU identified by the given ordinal.
    Metal(usize),
}
22
23impl DeviceSelection {
24    /// Resolve this selection into a concrete candle [`Device`].
25    ///
26    /// For [`DeviceSelection::Auto`], the priority is:
27    /// 1. CUDA (if `cuda` feature enabled and device available)
28    /// 2. Metal (if `metal` feature enabled and device available)
29    /// 3. CPU (always available)
30    pub fn resolve(&self) -> candle_core::Result<Device> {
31        match self {
32            Self::Cpu => {
33                info!("Using CPU device");
34                Ok(Device::Cpu)
35            }
36            Self::Cuda(ordinal) => {
37                info!("Requesting CUDA device {}", ordinal);
38                new_cuda(*ordinal)
39            }
40            Self::Metal(ordinal) => {
41                info!("Requesting Metal device {}", ordinal);
42                new_metal(*ordinal)
43            }
44            Self::Auto => auto_select(),
45        }
46    }
47
48    /// Human-readable backend label.
49    pub fn label(&self) -> String {
50        match self {
51            Self::Auto => "auto".to_string(),
52            Self::Cpu => "cpu".to_string(),
53            Self::Cuda(ordinal) => format!("cuda:{ordinal}"),
54            Self::Metal(ordinal) => format!("metal:{ordinal}"),
55        }
56    }
57
58    /// Preferred runtime candidates for the current binary, ordered fastest-first.
59    pub fn preferred_runtime_candidates() -> Vec<Self> {
60        vec![
61            #[cfg(feature = "cuda")]
62            Self::Cuda(0),
63            #[cfg(feature = "metal")]
64            Self::Metal(0),
65            Self::Cpu,
66        ]
67    }
68
69    /// Runtime candidates that successfully resolve on the current machine.
70    pub fn available_runtime_candidates() -> Vec<Self> {
71        let mut available = Vec::new();
72        for candidate in Self::preferred_runtime_candidates() {
73            if candidate.resolve().is_ok() {
74                available.push(candidate);
75            }
76        }
77
78        if available.is_empty() {
79            available.push(Self::Cpu);
80        }
81
82        available
83    }
84}
85
86/// Attempt to create a CUDA device.
87fn new_cuda(ordinal: usize) -> candle_core::Result<Device> {
88    #[cfg(feature = "cuda")]
89    {
90        let device = Device::new_cuda(ordinal)?;
91        info!("CUDA device {} initialized", ordinal);
92        Ok(device)
93    }
94    #[cfg(not(feature = "cuda"))]
95    {
96        let _ = ordinal;
97        Err(candle_core::Error::Msg(
98            "CUDA feature not enabled at compile time".to_string(),
99        ))
100    }
101}
102
103/// Attempt to create a Metal device.
104fn new_metal(ordinal: usize) -> candle_core::Result<Device> {
105    #[cfg(feature = "metal")]
106    {
107        let device = Device::new_metal(ordinal)?;
108        info!("Metal device {} initialized", ordinal);
109        Ok(device)
110    }
111    #[cfg(not(feature = "metal"))]
112    {
113        let _ = ordinal;
114        Err(candle_core::Error::Msg(
115            "Metal feature not enabled at compile time".to_string(),
116        ))
117    }
118}
119
120/// Auto-select the best device based on compiled features and availability.
121fn auto_select() -> candle_core::Result<Device> {
122    // Try CUDA first
123    #[cfg(feature = "cuda")]
124    {
125        match Device::new_cuda(0) {
126            Ok(device) => {
127                info!("Auto-selected CUDA device 0");
128                return Ok(device);
129            }
130            Err(e) => {
131                info!("CUDA not available: {}, falling back", e);
132            }
133        }
134    }
135
136    // Try Metal
137    #[cfg(feature = "metal")]
138    {
139        match Device::new_metal(0) {
140            Ok(device) => {
141                info!("Auto-selected Metal device 0");
142                return Ok(device);
143            }
144            Err(e) => {
145                info!("Metal not available: {}, falling back", e);
146            }
147        }
148    }
149
150    // Fallback to CPU
151    info!("Auto-selected CPU device");
152    Ok(Device::Cpu)
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Explicitly asking for the CPU must resolve to the CPU device.
    #[test]
    fn test_cpu_device_always_works() {
        let resolved = DeviceSelection::Cpu.resolve();
        assert!(matches!(resolved, Ok(Device::Cpu)));
    }

    /// `Auto` must always produce a device; in the worst case, the CPU.
    #[test]
    fn test_auto_resolves_to_some_device() {
        // The match is deliberately exhaustive with no wildcard arm, so any
        // of the known device kinds is accepted.
        match DeviceSelection::Auto.resolve().unwrap() {
            Device::Cpu | Device::Cuda(_) | Device::Metal(_) => {}
        }
    }

    /// The derived `Default` must be the `Auto` strategy.
    #[test]
    fn test_default_is_auto() {
        assert_eq!(DeviceSelection::Auto, DeviceSelection::default());
    }

    /// The CPU must be the final entry in the preference list.
    #[test]
    fn test_preferred_candidates_end_with_cpu() {
        let last = DeviceSelection::preferred_runtime_candidates().pop();
        assert_eq!(last, Some(DeviceSelection::Cpu));
    }

    /// Labels are relied on as stable identifiers; pin their exact text.
    #[test]
    fn test_device_labels_are_stable() {
        let expected = [
            (DeviceSelection::Cpu, "cpu"),
            (DeviceSelection::Cuda(0), "cuda:0"),
            (DeviceSelection::Metal(0), "metal:0"),
        ];
        for (selection, label) in expected {
            assert_eq!(selection.label(), label);
        }
    }
}