1use candle_core::Device;
7use tracing::info;
8
/// Which compute backend a caller wants; turned into a concrete candle
/// [`Device`] by [`DeviceSelection::resolve`].
///
/// `Hash` is derived alongside `Eq` so selections can be used as
/// `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum DeviceSelection {
    /// Try CUDA device 0, then Metal device 0 (each only when its cargo
    /// feature is compiled in), and fall back to the CPU.
    #[default]
    Auto,
    /// Always the CPU device.
    Cpu,
    /// CUDA device with the given ordinal. Resolving fails when the crate
    /// was built without the `cuda` feature.
    Cuda(usize),
    /// Metal device with the given ordinal. Resolving fails when the crate
    /// was built without the `metal` feature.
    Metal(usize),
}
22
23impl DeviceSelection {
24 pub fn resolve(&self) -> candle_core::Result<Device> {
31 match self {
32 Self::Cpu => {
33 info!("Using CPU device");
34 Ok(Device::Cpu)
35 }
36 Self::Cuda(ordinal) => {
37 info!("Requesting CUDA device {}", ordinal);
38 new_cuda(*ordinal)
39 }
40 Self::Metal(ordinal) => {
41 info!("Requesting Metal device {}", ordinal);
42 new_metal(*ordinal)
43 }
44 Self::Auto => auto_select(),
45 }
46 }
47
48 pub fn label(&self) -> String {
50 match self {
51 Self::Auto => "auto".to_string(),
52 Self::Cpu => "cpu".to_string(),
53 Self::Cuda(ordinal) => format!("cuda:{ordinal}"),
54 Self::Metal(ordinal) => format!("metal:{ordinal}"),
55 }
56 }
57
58 pub fn preferred_runtime_candidates() -> Vec<Self> {
60 vec![
61 #[cfg(feature = "cuda")]
62 Self::Cuda(0),
63 #[cfg(feature = "metal")]
64 Self::Metal(0),
65 Self::Cpu,
66 ]
67 }
68
69 pub fn available_runtime_candidates() -> Vec<Self> {
71 let mut available = Vec::new();
72 for candidate in Self::preferred_runtime_candidates() {
73 if candidate.resolve().is_ok() {
74 available.push(candidate);
75 }
76 }
77
78 if available.is_empty() {
79 available.push(Self::Cpu);
80 }
81
82 available
83 }
84}
85
86fn new_cuda(ordinal: usize) -> candle_core::Result<Device> {
88 #[cfg(feature = "cuda")]
89 {
90 let device = Device::new_cuda(ordinal)?;
91 info!("CUDA device {} initialized", ordinal);
92 Ok(device)
93 }
94 #[cfg(not(feature = "cuda"))]
95 {
96 let _ = ordinal;
97 Err(candle_core::Error::Msg(
98 "CUDA feature not enabled at compile time".to_string(),
99 ))
100 }
101}
102
103fn new_metal(ordinal: usize) -> candle_core::Result<Device> {
105 #[cfg(feature = "metal")]
106 {
107 let device = Device::new_metal(ordinal)?;
108 info!("Metal device {} initialized", ordinal);
109 Ok(device)
110 }
111 #[cfg(not(feature = "metal"))]
112 {
113 let _ = ordinal;
114 Err(candle_core::Error::Msg(
115 "Metal feature not enabled at compile time".to_string(),
116 ))
117 }
118}
119
120fn auto_select() -> candle_core::Result<Device> {
122 #[cfg(feature = "cuda")]
124 {
125 match Device::new_cuda(0) {
126 Ok(device) => {
127 info!("Auto-selected CUDA device 0");
128 return Ok(device);
129 }
130 Err(e) => {
131 info!("CUDA not available: {}, falling back", e);
132 }
133 }
134 }
135
136 #[cfg(feature = "metal")]
138 {
139 match Device::new_metal(0) {
140 Ok(device) => {
141 info!("Auto-selected Metal device 0");
142 return Ok(device);
143 }
144 Err(e) => {
145 info!("Metal not available: {}, falling back", e);
146 }
147 }
148 }
149
150 info!("Auto-selected CPU device");
152 Ok(Device::Cpu)
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cpu_device_always_works() {
        let device = DeviceSelection::Cpu.resolve().unwrap();
        assert!(matches!(device, Device::Cpu));
    }

    #[test]
    fn test_auto_resolves_to_some_device() {
        // `auto_select` falls back to the CPU, so `Auto` must never error.
        let device = DeviceSelection::Auto
            .resolve()
            .expect("Auto must fall back to CPU");
        // Exhaustive on purpose: any of candle's backends is acceptable.
        match device {
            Device::Cpu => {}
            Device::Cuda(_) => {}
            Device::Metal(_) => {}
        }
    }

    #[test]
    fn test_default_is_auto() {
        assert_eq!(DeviceSelection::default(), DeviceSelection::Auto);
    }

    #[test]
    fn test_preferred_candidates_end_with_cpu() {
        let candidates = DeviceSelection::preferred_runtime_candidates();
        assert_eq!(candidates.last(), Some(&DeviceSelection::Cpu));
    }

    #[test]
    fn test_available_candidates_end_with_cpu() {
        // Cpu always resolves and is the last preferred candidate, so it
        // must survive filtering and stay last.
        let candidates = DeviceSelection::available_runtime_candidates();
        assert!(!candidates.is_empty());
        assert_eq!(candidates.last(), Some(&DeviceSelection::Cpu));
    }

    #[test]
    fn test_device_labels_are_stable() {
        assert_eq!(DeviceSelection::Auto.label(), "auto");
        assert_eq!(DeviceSelection::Cpu.label(), "cpu");
        assert_eq!(DeviceSelection::Cuda(0).label(), "cuda:0");
        assert_eq!(DeviceSelection::Cuda(3).label(), "cuda:3");
        assert_eq!(DeviceSelection::Metal(0).label(), "metal:0");
        assert_eq!(DeviceSelection::Metal(1).label(), "metal:1");
    }

    #[cfg(not(feature = "cuda"))]
    #[test]
    fn test_cuda_without_feature_is_error() {
        assert!(DeviceSelection::Cuda(0).resolve().is_err());
    }

    #[cfg(not(feature = "metal"))]
    #[test]
    fn test_metal_without_feature_is_error() {
        assert!(DeviceSelection::Metal(0).resolve().is_err());
    }
}