1use rlx_driver::Device;
24use rlx_ir::{Graph, Op};
25
26use crate::CompileOptions;
27
28pub(crate) const DEVICE_PRIORITY: &[Device] = &[
33 Device::Tpu,
34 Device::Cuda,
35 Device::Rocm,
36 Device::Mlx,
37 Device::Metal,
38 Device::Ane,
39 Device::Gpu,
40 Device::Vulkan,
41 Device::DirectX,
42 Device::OpenGl,
43 Device::WebGpu,
44 Device::Cpu,
45];
46
47pub fn supports_run_slots(device: Device) -> bool {
58 matches!(
59 device,
60 Device::Cpu | Device::Metal | Device::Mlx | Device::Cuda | Device::Rocm
61 )
62}
63
64pub fn is_available(device: Device) -> bool {
65 #[cfg(feature = "cuda")]
66 if device == Device::Cuda {
67 return rlx_cuda::is_available();
68 }
69 #[cfg(feature = "rocm")]
70 if device == Device::Rocm {
71 return rlx_rocm::is_available();
72 }
73 #[cfg(feature = "gpu")]
74 if device == Device::Gpu {
75 return rlx_wgpu::is_available();
76 }
77 #[cfg(feature = "vulkan")]
78 if device == Device::Vulkan {
79 return rlx_wgpu::is_vulkan_available();
80 }
81 #[cfg(feature = "tpu")]
82 if device == Device::Tpu {
83 return rlx_tpu::is_available();
84 }
85
86 let feature_gated = match device {
87 Device::Cpu => cfg!(feature = "cpu"),
88 Device::Metal => cfg!(feature = "metal"),
89 Device::Mlx => cfg!(feature = "mlx"),
90 Device::Ane => cfg!(feature = "ane"),
91 Device::Cuda => cfg!(feature = "cuda"),
92 Device::Rocm => cfg!(feature = "rocm"),
93 Device::Tpu => cfg!(feature = "tpu"),
94 Device::Gpu => cfg!(feature = "gpu"),
95 Device::Vulkan => cfg!(feature = "vulkan"),
96 Device::OpenGl => cfg!(feature = "opengl"),
97 Device::DirectX => cfg!(feature = "directx"),
98 Device::WebGpu => cfg!(feature = "webgpu"),
99 };
100 if feature_gated {
101 return true;
102 }
103 crate::registry::registered_devices().contains(&device)
104}
105
#[cfg(all(feature = "apple", target_os = "macos"))]
/// Returns the Apple-relevant devices (Metal, MLX, wgpu `Gpu`) that pass
/// [`is_available`], preserving that order.
pub fn available_apple_devices() -> Vec<Device> {
    let mut found = Vec::new();
    for candidate in [Device::Metal, Device::Mlx, Device::Gpu] {
        if is_available(candidate) {
            found.push(candidate);
        }
    }
    found
}
114
115pub fn available_devices() -> Vec<Device> {
118 Device::all()
119 .iter()
120 .copied()
121 .filter(|d| is_available(*d))
122 .collect()
123}
124
125pub fn devices_for(graph: &Graph) -> Vec<Device> {
128 crate::device_policy::devices_for_with_policy(graph, &crate::DevicePolicy::default())
129}
130
131pub fn fastest_device() -> Device {
138 fastest_among(&available_devices())
139}
140
141pub fn fastest_among(candidates: &[Device]) -> Device {
143 for &d in DEVICE_PRIORITY {
144 if candidates.contains(&d) {
145 return d;
146 }
147 }
148 candidates.first().copied().unwrap_or(Device::Cpu)
149}
150
151pub fn full_name(device: Device) -> &'static str {
156 if let Device::Cpu = device {
157 if cfg!(feature = "blas-accelerate") {
158 return "CPU (Accelerate)";
159 }
160 if cfg!(feature = "blas-mkl") {
161 return "CPU (MKL)";
162 }
163 if cfg!(feature = "blas-openblas") {
164 return "CPU (OpenBLAS)";
165 }
166 }
167 device.name()
168}
169
170pub fn supports(device: Device, op: &Op) -> bool {
194 if !is_available(device) {
195 return false;
196 }
197 match device {
198 Device::Cpu => true, Device::Mlx => mlx_supports(op),
200 Device::Metal => metal_supports(op),
201 Device::Gpu | Device::Cuda | Device::Rocm => gpu_family_supports(op),
202 _ => false,
206 }
207}
208
209pub fn supports_graph(device: Device, graph: &Graph) -> bool {
215 supports_graph_with_options(device, graph, &CompileOptions::default())
216}
217
218pub fn supports_graph_with_options(
220 device: Device,
221 graph: &Graph,
222 options: &CompileOptions,
223) -> bool {
224 if !is_available(device) {
225 return false;
226 }
227 if let Some(backend) = crate::registry::backend_for(device) {
228 let (_, report) = rlx_opt::prepare_graph_for_backend_with_report(
229 graph.clone(),
230 device.name(),
231 backend.supported_ops(),
232 options.kernel_dispatch,
233 );
234 return report.compile_ready;
235 }
236 graph.nodes().iter().all(|n| supports(device, &n.op))
237}
238
239pub fn legalize_graph_for_device(graph: Graph, device: Device) -> Result<Graph, String> {
248 let (graph, _report) = legalize_graph_for_device_with_report(graph, device)?;
249 Ok(graph)
250}
251
252pub fn legalize_graph_for_device_with_report(
254 graph: Graph,
255 device: Device,
256) -> Result<(Graph, rlx_opt::KernelDispatchReport), String> {
257 legalize_graph_for_device_with_options(graph, device, &CompileOptions::default())
258}
259
260pub fn legalize_graph_for_device_with_options(
263 graph: Graph,
264 device: Device,
265 options: &CompileOptions,
266) -> Result<(Graph, rlx_opt::KernelDispatchReport), String> {
267 let backend = crate::registry::backend_for(device).ok_or_else(|| {
268 format!(
269 "no backend registered for {device:?} — enable the matching \
270 `rlx-runtime` Cargo feature (e.g. `metal`, `gpu`, `cuda`)"
271 )
272 })?;
273 let ops = backend.supported_ops();
274 let (graph, report) = rlx_opt::prepare_graph_for_backend_with_report(
275 graph,
276 device.name(),
277 ops,
278 options.kernel_dispatch,
279 );
280 if !report.compile_ready {
281 return Err(format!(
282 "{}\n{}",
283 rlx_opt::format_legalize_error(device.name(), &report.still_unsupported),
284 rlx_opt::format_dispatch_report(&report)
285 ));
286 }
287 Ok((graph, report))
288}
289
290pub fn dispatch_report_for_device(
292 graph: &Graph,
293 device: Device,
294) -> Result<rlx_opt::KernelDispatchReport, String> {
295 dispatch_report_for_device_with_options(graph, device, &CompileOptions::default())
296}
297
298pub fn dispatch_report_for_device_with_options(
300 graph: &Graph,
301 device: Device,
302 options: &CompileOptions,
303) -> Result<rlx_opt::KernelDispatchReport, String> {
304 let backend = crate::registry::backend_for(device)
305 .ok_or_else(|| format!("no backend registered for {device:?}"))?;
306 Ok(rlx_opt::analyze_dispatch(
307 graph,
308 device.name(),
309 backend.supported_ops(),
310 options.kernel_dispatch,
311 ))
312}
313
314pub fn first_unsupported_op(device: Device, graph: &Graph) -> Option<(usize, &Op)> {
318 first_unsupported_op_with_options(device, graph, &CompileOptions::default())
319}
320
321pub fn first_unsupported_op_with_options<'a>(
323 device: Device,
324 graph: &'a Graph,
325 options: &CompileOptions,
326) -> Option<(usize, &'a Op)> {
327 if !is_available(device) {
328 return graph.nodes().first().map(|n| (0, &n.op));
329 }
330 if let Some(backend) = crate::registry::backend_for(device) {
331 let (_, report) = rlx_opt::prepare_graph_for_backend_with_report(
332 graph.clone(),
333 device.name(),
334 backend.supported_ops(),
335 options.kernel_dispatch,
336 );
337 if let Some((id, kind)) = report.still_unsupported.first() {
338 let idx = graph.nodes().iter().position(|n| n.id == *id).unwrap_or(0);
339 let op = graph
340 .nodes()
341 .iter()
342 .find(|n| n.id == *id)
343 .map(|n| &n.op)
344 .unwrap_or(&graph.nodes()[0].op);
345 let _ = kind;
346 return Some((idx, op));
347 }
348 return None;
349 }
350 graph
351 .nodes()
352 .iter()
353 .enumerate()
354 .find_map(|(i, n)| (!supports(device, &n.op)).then_some((i, &n.op)))
355}
356
357#[allow(unused_variables)]
358fn mlx_supports(op: &Op) -> bool {
359 true
364}
365
366#[allow(unused_variables)]
367fn metal_supports(op: &Op) -> bool {
368 let _ = op;
374 true
375}
376
377#[allow(unused_variables)]
378fn gpu_family_supports(op: &Op) -> bool {
379 let _ = op;
383 true
384}
385
386pub fn drain_device(device: Device) {
389 #[cfg(all(target_os = "macos", feature = "metal"))]
390 {
391 if device == Device::Metal {
392 rlx_metal::device::drain_command_queue();
393 }
394 }
395 #[cfg(not(all(target_os = "macos", feature = "metal")))]
396 let _ = device;
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::{Activation, BinaryOp};
    use rlx_ir::{DType, Graph, Shape};

    /// One-element `f32` shape shared by the test graphs.
    fn scalar_shape() -> Shape {
        Shape::new(&[1], DType::F32)
    }

    /// The CPU arm of `supports` accepts any op once the device is available.
    #[test]
    fn cpu_supports_everything_built_in() {
        assert!(supports(Device::Cpu, &Op::Activation(Activation::Sin)));
        assert!(supports(Device::Cpu, &Op::Activation(Activation::Cos)));
        assert!(supports(Device::Cpu, &Op::Activation(Activation::Exp)));
        assert!(supports(Device::Cpu, &Op::Binary(BinaryOp::Add)));
    }

    /// `supports` has no per-op predicate for OpenGL, so every op is rejected.
    #[test]
    fn unbuilt_device_supports_nothing() {
        assert!(!supports(Device::OpenGl, &Op::Activation(Activation::Relu)));
    }

    #[test]
    #[cfg(feature = "metal")]
    fn metal_supports_full_activation_set() {
        for act in [
            Activation::Sin,
            Activation::Cos,
            Activation::Tan,
            Activation::Atan,
            Activation::Exp,
        ] {
            assert!(
                supports(Device::Metal, &Op::Activation(act)),
                "Metal should support Activation::{act:?}"
            );
        }
    }

    /// A CPU graph made only of elementwise activations has no blocking op.
    #[test]
    fn cpu_graph_walk_reports_no_blocker() {
        let mut g = Graph::new("walk");
        let s = scalar_shape();
        let x = g.input("x", s.clone());
        let _e = g.activation(Activation::Exp, x, s.clone());
        let _sin = g.activation(Activation::Sin, x, s);
        assert!(supports_graph(Device::Cpu, &g));
        assert!(first_unsupported_op(Device::Cpu, &g).is_none());
    }

    /// `fastest_device` picks an available device, and it agrees with
    /// `fastest_among` over the available set.
    #[test]
    fn fastest_device_is_available_and_matches_fastest_among() {
        let pick = fastest_device();
        assert!(is_available(pick));
        assert_eq!(pick, fastest_among(&available_devices()));
    }

    #[test]
    fn fastest_among_respects_priority_order() {
        let pick = fastest_among(&[Device::Cpu, Device::Metal, Device::Mlx]);
        assert_eq!(pick, Device::Mlx);
    }

    /// The order of the candidates is irrelevant; `DEVICE_PRIORITY` decides.
    #[test]
    fn fastest_among_ignores_candidate_order() {
        assert_eq!(fastest_among(&[Device::Cpu, Device::Cuda]), Device::Cuda);
        assert_eq!(fastest_among(&[Device::Cuda, Device::Cpu]), Device::Cuda);
    }

    /// With no candidates at all, the CPU is the documented fallback.
    #[test]
    fn fastest_among_empty_falls_back_to_cpu() {
        assert_eq!(fastest_among(&[]), Device::Cpu);
    }

    #[test]
    fn devices_for_is_subset_of_available() {
        let mut g = Graph::new("id");
        let x = g.input("x", scalar_shape());
        g.set_outputs(vec![x]);
        for d in devices_for(&g) {
            assert!(is_available(d));
            assert!(supports_graph(d, &g));
        }
    }
}