1use rlx_driver::Device;
24use rlx_ir::{Graph, Op};
25
26use crate::CompileOptions;
27
28pub fn is_available(device: Device) -> bool {
37 #[cfg(feature = "cuda")]
38 if device == Device::Cuda {
39 return rlx_cuda::is_available();
40 }
41 #[cfg(feature = "rocm")]
42 if device == Device::Rocm {
43 return rlx_rocm::is_available();
44 }
45 #[cfg(feature = "gpu")]
46 if device == Device::Gpu {
47 return rlx_wgpu::is_available();
48 }
49 #[cfg(feature = "vulkan")]
50 if device == Device::Vulkan {
51 return rlx_wgpu::is_vulkan_available();
52 }
53 #[cfg(feature = "tpu")]
54 if device == Device::Tpu {
55 return rlx_tpu::is_available();
56 }
57
58 let feature_gated = match device {
59 Device::Cpu => cfg!(feature = "cpu"),
60 Device::Metal => cfg!(feature = "metal"),
61 Device::Mlx => cfg!(feature = "mlx"),
62 Device::Ane => cfg!(feature = "ane"),
63 Device::Cuda => cfg!(feature = "cuda"),
64 Device::Rocm => cfg!(feature = "rocm"),
65 Device::Tpu => cfg!(feature = "tpu"),
66 Device::Gpu => cfg!(feature = "gpu"),
67 Device::Vulkan => cfg!(feature = "vulkan"),
68 Device::OpenGl => cfg!(feature = "opengl"),
69 Device::DirectX => cfg!(feature = "directx"),
70 Device::WebGpu => cfg!(feature = "webgpu"),
71 };
72 if feature_gated {
73 return true;
74 }
75 crate::registry::registered_devices().contains(&device)
76}
77
/// Lists which of `Metal`, `Mlx`, and `Gpu` pass [`is_available`], in that
/// order.
///
/// Only compiled when the `apple` feature is enabled on macOS.
#[cfg(all(feature = "apple", target_os = "macos"))]
pub fn available_apple_devices() -> Vec<Device> {
    let mut usable = Vec::new();
    for candidate in [Device::Metal, Device::Mlx, Device::Gpu] {
        if is_available(candidate) {
            usable.push(candidate);
        }
    }
    usable
}
86
87pub fn available_devices() -> Vec<Device> {
90 Device::all()
91 .iter()
92 .copied()
93 .filter(|d| is_available(*d))
94 .collect()
95}
96
97pub fn full_name(device: Device) -> &'static str {
102 if let Device::Cpu = device {
103 if cfg!(feature = "blas-accelerate") {
104 return "CPU (Accelerate)";
105 }
106 if cfg!(feature = "blas-mkl") {
107 return "CPU (MKL)";
108 }
109 if cfg!(feature = "blas-openblas") {
110 return "CPU (OpenBLAS)";
111 }
112 }
113 device.name()
114}
115
116pub fn supports(device: Device, op: &Op) -> bool {
140 if !is_available(device) {
141 return false;
142 }
143 match device {
144 Device::Cpu => true, Device::Mlx => mlx_supports(op),
146 Device::Metal => metal_supports(op),
147 Device::Gpu | Device::Cuda | Device::Rocm => gpu_family_supports(op),
148 _ => false,
152 }
153}
154
155pub fn supports_graph(device: Device, graph: &Graph) -> bool {
161 supports_graph_with_options(device, graph, &CompileOptions::default())
162}
163
164pub fn supports_graph_with_options(
166 device: Device,
167 graph: &Graph,
168 options: &CompileOptions,
169) -> bool {
170 if !is_available(device) {
171 return false;
172 }
173 if let Some(backend) = crate::registry::backend_for(device) {
174 let (_, report) = rlx_opt::prepare_graph_for_backend_with_report(
175 graph.clone(),
176 device.name(),
177 backend.supported_ops(),
178 options.kernel_dispatch,
179 );
180 return report.compile_ready;
181 }
182 graph.nodes().iter().all(|n| supports(device, &n.op))
183}
184
185pub fn legalize_graph_for_device(graph: Graph, device: Device) -> Result<Graph, String> {
194 let (graph, _report) = legalize_graph_for_device_with_report(graph, device)?;
195 Ok(graph)
196}
197
198pub fn legalize_graph_for_device_with_report(
200 graph: Graph,
201 device: Device,
202) -> Result<(Graph, rlx_opt::KernelDispatchReport), String> {
203 legalize_graph_for_device_with_options(graph, device, &CompileOptions::default())
204}
205
206pub fn legalize_graph_for_device_with_options(
209 graph: Graph,
210 device: Device,
211 options: &CompileOptions,
212) -> Result<(Graph, rlx_opt::KernelDispatchReport), String> {
213 let backend = crate::registry::backend_for(device).ok_or_else(|| {
214 format!(
215 "no backend registered for {device:?} — enable the matching \
216 `rlx-runtime` Cargo feature (e.g. `metal`, `gpu`, `cuda`)"
217 )
218 })?;
219 let ops = backend.supported_ops();
220 let (graph, report) = rlx_opt::prepare_graph_for_backend_with_report(
221 graph,
222 device.name(),
223 ops,
224 options.kernel_dispatch,
225 );
226 if !report.compile_ready {
227 return Err(format!(
228 "{}\n{}",
229 rlx_opt::format_legalize_error(device.name(), &report.still_unsupported),
230 rlx_opt::format_dispatch_report(&report)
231 ));
232 }
233 Ok((graph, report))
234}
235
236pub fn dispatch_report_for_device(
238 graph: &Graph,
239 device: Device,
240) -> Result<rlx_opt::KernelDispatchReport, String> {
241 dispatch_report_for_device_with_options(graph, device, &CompileOptions::default())
242}
243
244pub fn dispatch_report_for_device_with_options(
246 graph: &Graph,
247 device: Device,
248 options: &CompileOptions,
249) -> Result<rlx_opt::KernelDispatchReport, String> {
250 let backend = crate::registry::backend_for(device)
251 .ok_or_else(|| format!("no backend registered for {device:?}"))?;
252 Ok(rlx_opt::analyze_dispatch(
253 graph,
254 device.name(),
255 backend.supported_ops(),
256 options.kernel_dispatch,
257 ))
258}
259
260pub fn first_unsupported_op(device: Device, graph: &Graph) -> Option<(usize, &Op)> {
264 first_unsupported_op_with_options(device, graph, &CompileOptions::default())
265}
266
267pub fn first_unsupported_op_with_options<'a>(
269 device: Device,
270 graph: &'a Graph,
271 options: &CompileOptions,
272) -> Option<(usize, &'a Op)> {
273 if !is_available(device) {
274 return graph.nodes().first().map(|n| (0, &n.op));
275 }
276 if let Some(backend) = crate::registry::backend_for(device) {
277 let (_, report) = rlx_opt::prepare_graph_for_backend_with_report(
278 graph.clone(),
279 device.name(),
280 backend.supported_ops(),
281 options.kernel_dispatch,
282 );
283 if let Some((id, kind)) = report.still_unsupported.first() {
284 let idx = graph.nodes().iter().position(|n| n.id == *id).unwrap_or(0);
285 let op = graph
286 .nodes()
287 .iter()
288 .find(|n| n.id == *id)
289 .map(|n| &n.op)
290 .unwrap_or(&graph.nodes()[0].op);
291 let _ = kind;
292 return Some((idx, op));
293 }
294 return None;
295 }
296 graph
297 .nodes()
298 .iter()
299 .enumerate()
300 .find_map(|(i, n)| (!supports(device, &n.op)).then_some((i, &n.op)))
301}
302
303#[allow(unused_variables)]
304fn mlx_supports(op: &Op) -> bool {
305 true
310}
311
312#[allow(unused_variables)]
313fn metal_supports(op: &Op) -> bool {
314 let _ = op;
320 true
321}
322
323#[allow(unused_variables)]
324fn gpu_family_supports(op: &Op) -> bool {
325 let _ = op;
329 true
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::{Activation, BinaryOp};
    use rlx_ir::{DType, Graph, Shape};

    /// Builds the single-element `F32` shape used by the graph test.
    fn scalar_shape() -> Shape {
        Shape::new(&[1], DType::F32)
    }

    /// The CPU device must accept a sample of activation and binary ops.
    #[test]
    fn cpu_supports_everything_built_in() {
        let sample_ops = [
            Op::Activation(Activation::Sin),
            Op::Activation(Activation::Cos),
            Op::Activation(Activation::Exp),
            Op::Binary(BinaryOp::Add),
        ];
        for op in &sample_ops {
            assert!(supports(Device::Cpu, op));
        }
    }

    /// A device that is not available must reject any op.
    #[test]
    fn unbuilt_device_supports_nothing() {
        let relu = Op::Activation(Activation::Relu);
        let accepted = supports(Device::OpenGl, &relu);
        assert!(!accepted);
    }

    /// With the `metal` feature on, Metal must accept each listed activation.
    #[test]
    #[cfg(feature = "metal")]
    fn metal_supports_full_activation_set() {
        let activations = [
            Activation::Sin,
            Activation::Cos,
            Activation::Tan,
            Activation::Atan,
            Activation::Exp,
        ];
        for act in activations {
            assert!(
                supports(Device::Metal, &Op::Activation(act)),
                "Metal should support Activation::{act:?}"
            );
        }
    }

    /// A small graph of CPU-supported activations has no blocking op.
    #[test]
    fn graph_walk_reports_first_blocker() {
        let shape = scalar_shape();
        let mut graph = Graph::new("walk");
        let input = graph.input("x", shape.clone());
        let _exp = graph.activation(Activation::Exp, input, shape.clone());
        let _sin = graph.activation(Activation::Sin, input, shape);

        assert!(supports_graph(Device::Cpu, &graph));
        assert!(first_unsupported_op(Device::Cpu, &graph).is_none());
    }
}