1use rlx_driver::Device;
24use rlx_ir::{Graph, Op};
25
26use crate::CompileOptions;
27
28pub(crate) const DEVICE_PRIORITY: &[Device] = &[
33 Device::Tpu,
34 Device::Cuda,
35 Device::Rocm,
36 Device::Mlx,
37 Device::Metal,
38 Device::Ane,
39 Device::Gpu,
40 Device::Vulkan,
41 Device::DirectX,
42 Device::OpenGl,
43 Device::WebGpu,
44 Device::Cpu,
45];
46
47pub fn supports_run_slots(device: Device) -> bool {
58 matches!(
59 device,
60 Device::Cpu | Device::Metal | Device::Mlx | Device::Cuda | Device::Rocm
61 )
62}
63
64pub fn is_available(device: Device) -> bool {
65 #[cfg(feature = "cuda")]
66 if device == Device::Cuda {
67 return rlx_cuda::is_available();
68 }
69 #[cfg(feature = "rocm")]
70 if device == Device::Rocm {
71 return rlx_rocm::is_available();
72 }
73 #[cfg(feature = "gpu")]
74 if device == Device::Gpu {
75 return rlx_wgpu::is_available();
76 }
77 #[cfg(feature = "vulkan")]
78 if device == Device::Vulkan {
79 return rlx_wgpu::is_vulkan_available();
80 }
81 #[cfg(feature = "tpu")]
82 if device == Device::Tpu {
83 return rlx_tpu::is_available();
84 }
85
86 let feature_gated = match device {
87 Device::Cpu => cfg!(feature = "cpu"),
88 Device::Metal => cfg!(feature = "metal"),
89 Device::Mlx => cfg!(feature = "mlx"),
90 Device::Ane => cfg!(any(feature = "coreml", feature = "ane")),
91 Device::Cuda => cfg!(feature = "cuda"),
92 Device::Rocm => cfg!(feature = "rocm"),
93 Device::Tpu => cfg!(feature = "tpu"),
94 Device::Gpu => cfg!(feature = "gpu"),
95 Device::Vulkan => cfg!(feature = "vulkan"),
96 Device::OpenGl => cfg!(feature = "opengl"),
97 Device::DirectX => cfg!(feature = "directx"),
98 Device::WebGpu => cfg!(feature = "webgpu"),
99 };
100 if feature_gated {
101 return true;
102 }
103 crate::registry::registered_devices().contains(&device)
104}
105
/// Lists which of Metal, MLX, GPU and ANE report as available (via
/// [`is_available`]), in that order. Only compiled on macOS with the
/// `apple` feature.
#[cfg(all(feature = "apple", target_os = "macos"))]
pub fn available_apple_devices() -> Vec<Device> {
    let mut found = Vec::new();
    for device in [Device::Metal, Device::Mlx, Device::Gpu, Device::Ane] {
        if is_available(device) {
            found.push(device);
        }
    }
    found
}
115
116pub fn available_devices() -> Vec<Device> {
119 Device::all()
120 .iter()
121 .copied()
122 .filter(|d| is_available(*d))
123 .collect()
124}
125
126pub fn devices_for(graph: &Graph) -> Vec<Device> {
129 crate::device_policy::devices_for_with_policy(graph, &crate::DevicePolicy::default())
130}
131
132pub fn fastest_device() -> Device {
139 fastest_among(&available_devices())
140}
141
142pub fn fastest_among(candidates: &[Device]) -> Device {
144 for &d in DEVICE_PRIORITY {
145 if candidates.contains(&d) {
146 return d;
147 }
148 }
149 candidates.first().copied().unwrap_or(Device::Cpu)
150}
151
152pub fn full_name(device: Device) -> &'static str {
157 if let Device::Cpu = device {
158 if cfg!(feature = "blas-accelerate") {
159 return "CPU (Accelerate)";
160 }
161 if cfg!(feature = "blas-mkl") {
162 return "CPU (MKL)";
163 }
164 if cfg!(feature = "blas-openblas") {
165 return "CPU (OpenBLAS)";
166 }
167 }
168 device.name()
169}
170
171pub fn supports(device: Device, op: &Op) -> bool {
195 if !is_available(device) {
196 return false;
197 }
198 match device {
199 Device::Cpu => true, Device::Mlx => mlx_supports(op),
201 Device::Metal => metal_supports(op),
202 Device::Ane => coreml_supports(op),
203 Device::Gpu | Device::Cuda | Device::Rocm => gpu_family_supports(op),
204 _ => false,
208 }
209}
210
211pub fn supports_graph(device: Device, graph: &Graph) -> bool {
217 supports_graph_with_options(device, graph, &CompileOptions::default())
218}
219
220pub fn supports_graph_with_options(
222 device: Device,
223 graph: &Graph,
224 options: &CompileOptions,
225) -> bool {
226 if !is_available(device) {
227 return false;
228 }
229 if let Some(backend) = crate::registry::backend_for(device) {
230 let (_, report) = rlx_opt::prepare_graph_for_backend_with_report(
231 graph.clone(),
232 device.name(),
233 backend.supported_ops(),
234 options.kernel_dispatch,
235 );
236 return report.compile_ready;
237 }
238 graph.nodes().iter().all(|n| supports(device, &n.op))
239}
240
241pub fn legalize_graph_for_device(graph: Graph, device: Device) -> Result<Graph, String> {
250 let (graph, _report) = legalize_graph_for_device_with_report(graph, device)?;
251 Ok(graph)
252}
253
254pub fn legalize_graph_for_device_with_report(
256 graph: Graph,
257 device: Device,
258) -> Result<(Graph, rlx_opt::KernelDispatchReport), String> {
259 legalize_graph_for_device_with_options(graph, device, &CompileOptions::default())
260}
261
262pub fn legalize_graph_for_device_with_options(
265 graph: Graph,
266 device: Device,
267 options: &CompileOptions,
268) -> Result<(Graph, rlx_opt::KernelDispatchReport), String> {
269 let backend = crate::registry::backend_for(device).ok_or_else(|| {
270 format!(
271 "no backend registered for {device:?} — enable the matching \
272 `rlx-runtime` Cargo feature (e.g. `metal`, `gpu`, `cuda`)"
273 )
274 })?;
275 let ops = backend.supported_ops();
276 let (graph, report) = rlx_opt::prepare_graph_for_backend_with_report(
277 graph,
278 device.name(),
279 ops,
280 options.kernel_dispatch,
281 );
282 if !report.compile_ready {
283 return Err(format!(
284 "{}\n{}",
285 rlx_opt::format_legalize_error(device.name(), &report.still_unsupported),
286 rlx_opt::format_dispatch_report(&report)
287 ));
288 }
289 Ok((graph, report))
290}
291
292pub fn dispatch_report_for_device(
294 graph: &Graph,
295 device: Device,
296) -> Result<rlx_opt::KernelDispatchReport, String> {
297 dispatch_report_for_device_with_options(graph, device, &CompileOptions::default())
298}
299
300pub fn dispatch_report_for_device_with_options(
302 graph: &Graph,
303 device: Device,
304 options: &CompileOptions,
305) -> Result<rlx_opt::KernelDispatchReport, String> {
306 let backend = crate::registry::backend_for(device)
307 .ok_or_else(|| format!("no backend registered for {device:?}"))?;
308 Ok(rlx_opt::analyze_dispatch(
309 graph,
310 device.name(),
311 backend.supported_ops(),
312 options.kernel_dispatch,
313 ))
314}
315
316pub fn first_unsupported_op(device: Device, graph: &Graph) -> Option<(usize, &Op)> {
320 first_unsupported_op_with_options(device, graph, &CompileOptions::default())
321}
322
323pub fn first_unsupported_op_with_options<'a>(
325 device: Device,
326 graph: &'a Graph,
327 options: &CompileOptions,
328) -> Option<(usize, &'a Op)> {
329 if !is_available(device) {
330 return graph.nodes().first().map(|n| (0, &n.op));
331 }
332 if let Some(backend) = crate::registry::backend_for(device) {
333 let (_, report) = rlx_opt::prepare_graph_for_backend_with_report(
334 graph.clone(),
335 device.name(),
336 backend.supported_ops(),
337 options.kernel_dispatch,
338 );
339 if let Some((id, kind)) = report.still_unsupported.first() {
340 let idx = graph.nodes().iter().position(|n| n.id == *id).unwrap_or(0);
341 let op = graph
342 .nodes()
343 .iter()
344 .find(|n| n.id == *id)
345 .map(|n| &n.op)
346 .unwrap_or(&graph.nodes()[0].op);
347 let _ = kind;
348 return Some((idx, op));
349 }
350 return None;
351 }
352 graph
353 .nodes()
354 .iter()
355 .enumerate()
356 .find_map(|(i, n)| (!supports(device, &n.op)).then_some((i, &n.op)))
357}
358
359#[allow(unused_variables)]
360fn mlx_supports(op: &Op) -> bool {
361 true
366}
367
368#[allow(unused_variables)]
369fn metal_supports(op: &Op) -> bool {
370 let _ = op;
376 true
377}
378
379fn coreml_supports(op: &Op) -> bool {
384 crate::backend::COREML_SUPPORTED_OPS.contains(&op.kind())
385}
386
387#[allow(unused_variables)]
388fn gpu_family_supports(op: &Op) -> bool {
389 let _ = op;
393 true
394}
395
396pub fn drain_device(device: Device) {
399 #[cfg(all(target_os = "macos", feature = "metal"))]
400 {
401 if device == Device::Metal {
402 rlx_metal::device::drain_command_queue();
403 }
404 }
405 #[cfg(not(all(target_os = "macos", feature = "metal")))]
406 let _ = device;
407}
408
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::{Activation, BinaryOp};
    use rlx_ir::{DType, Graph, Shape};

    /// One-element `F32` shape shared by the graph-building tests.
    fn scalar_shape() -> Shape {
        Shape::new(&[1], DType::F32)
    }

    #[test]
    fn cpu_supports_everything_built_in() {
        assert!(supports(Device::Cpu, &Op::Activation(Activation::Sin)));
        assert!(supports(Device::Cpu, &Op::Activation(Activation::Cos)));
        assert!(supports(Device::Cpu, &Op::Activation(Activation::Exp)));
        assert!(supports(Device::Cpu, &Op::Binary(BinaryOp::Add)));
    }

    #[test]
    fn unbuilt_device_supports_nothing() {
        // `supports` has no per-op predicate for OpenGL, so it returns `false`
        // whether or not the backend is available.
        assert!(!supports(Device::OpenGl, &Op::Activation(Activation::Relu)));
    }

    #[test]
    #[cfg(feature = "metal")]
    fn metal_supports_full_activation_set() {
        for act in [
            Activation::Sin,
            Activation::Cos,
            Activation::Tan,
            Activation::Atan,
            Activation::Exp,
        ] {
            assert!(
                supports(Device::Metal, &Op::Activation(act)),
                "Metal should support Activation::{act:?}"
            );
        }
    }

    #[test]
    fn cpu_graph_walk_finds_no_blocker() {
        let mut g = Graph::new("walk");
        let s = scalar_shape();
        let x = g.input("x", s.clone());
        let _e = g.activation(Activation::Exp, x, s.clone());
        let _sin = g.activation(Activation::Sin, x, s);
        assert!(supports_graph(Device::Cpu, &g));
        assert!(first_unsupported_op(Device::Cpu, &g).is_none());
    }

    #[test]
    fn fastest_device_is_available_and_agrees_with_fastest_among() {
        let pick = fastest_device();
        assert!(is_available(pick));
        assert_eq!(pick, fastest_among(&available_devices()));
    }

    #[test]
    fn fastest_among_respects_priority_order() {
        let pick = fastest_among(&[Device::Cpu, Device::Metal, Device::Mlx]);
        assert_eq!(pick, Device::Mlx);
        // Priority, not slice position, decides the winner.
        assert_eq!(fastest_among(&[Device::Cpu, Device::Cuda]), Device::Cuda);
    }

    #[test]
    fn fastest_among_empty_falls_back_to_cpu() {
        assert_eq!(fastest_among(&[]), Device::Cpu);
    }

    #[test]
    fn devices_for_is_subset_of_available() {
        let mut g = Graph::new("id");
        let x = g.input("x", scalar_shape());
        g.set_outputs(vec![x]);
        for d in devices_for(&g) {
            assert!(is_available(d));
            assert!(supports_graph(d, &g));
        }
    }
}