1use rlx_driver::Device;
19use rlx_ir::Graph;
20
21use crate::cost::fastest_device_for_with_policy;
22use crate::device_ext::{DEVICE_PRIORITY, is_available, supports_graph};
23use crate::device_parse::{device_label, parse_device, parse_device_list};
24use crate::registry::backend_for;
25
/// Strategy used to choose one device among the viable candidates.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum DevicePickStrategy {
    /// Pick via the cost model (the default strategy).
    #[default]
    CostModel,
    /// Pick by benchmarking candidates. The `DevicePolicy` builders clamp `runs` to at least 1.
    Benchmark { runs: usize },
}
35
36#[derive(Debug, Clone, Default, PartialEq, Eq)]
39pub struct DevicePolicy {
40 allow: Option<Vec<Device>>,
41 deny: Vec<Device>,
42 prefer: Vec<Device>,
43 pick: DevicePickStrategy,
44}
45
46impl DevicePolicy {
47 pub fn all() -> Self {
49 Self::default()
50 }
51
52 pub fn only(devices: impl IntoIterator<Item = Device>) -> Self {
54 Self {
55 allow: Some(devices.into_iter().collect()),
56 ..Self::default()
57 }
58 }
59
60 pub fn with_deny(mut self, devices: impl IntoIterator<Item = Device>) -> Self {
62 self.deny.extend(devices);
63 self
64 }
65
66 pub fn with_prefer(mut self, devices: impl IntoIterator<Item = Device>) -> Self {
68 self.prefer.extend(devices);
69 self
70 }
71
72 pub fn with_benchmark_pick(mut self, runs: usize) -> Self {
74 self.pick = DevicePickStrategy::Benchmark { runs: runs.max(1) };
75 self
76 }
77
78 pub fn pick_strategy(&self) -> DevicePickStrategy {
79 self.pick
80 }
81
82 pub fn from_env() -> Self {
84 Self::from_env_key("RLX")
85 }
86
87 pub fn from_env_key(prefix: &str) -> Self {
89 let mut policy = Self::default();
90 let devices_key = format!("{prefix}_DEVICES");
91 let deny_key = format!("{prefix}_DENY_DEVICES");
92 let prefer_key = format!("{prefix}_PREFER_DEVICES");
93
94 if let Some(raw) = rlx_ir::env::var(&devices_key) {
95 if let Ok(list) = parse_device_list(&raw) {
96 policy.allow = Some(list);
97 }
98 }
99 if let Some(raw) = rlx_ir::env::var(&deny_key) {
100 if let Ok(list) = parse_device_list(&raw) {
101 policy.deny = list;
102 }
103 }
104 if let Some(raw) = rlx_ir::env::var(&prefer_key) {
105 if let Ok(list) = parse_device_list(&raw) {
106 policy.prefer = list;
107 }
108 }
109 let bench_key = format!("{prefix}_BENCHMARK_PICK");
110 if let Some(raw) = rlx_ir::env::var(&bench_key) {
111 if let Ok(runs) = raw.trim().parse::<usize>() {
112 policy.pick = DevicePickStrategy::Benchmark { runs: runs.max(1) };
113 }
114 }
115 policy
116 }
117
118 pub fn probe_set(&self) -> Vec<Device> {
120 self.allow.clone().unwrap_or_else(|| Device::all().to_vec())
121 }
122
123 pub fn apply(&self, mut candidates: Vec<Device>) -> Vec<Device> {
125 if let Some(allow) = &self.allow {
126 candidates.retain(|d| allow.contains(d));
127 }
128 candidates.retain(|d| !self.deny.contains(d));
129 candidates.sort_by_key(|d| self.rank_key(*d));
130 candidates
131 }
132
133 fn rank_key(&self, device: Device) -> (u8, u8) {
134 let prefer = self
135 .prefer
136 .iter()
137 .position(|d| *d == device)
138 .map(|i| i as u8)
139 .unwrap_or(u8::MAX);
140 let platform = DEVICE_PRIORITY
141 .iter()
142 .position(|d| *d == device)
143 .map(|i| i as u8)
144 .unwrap_or(u8::MAX);
145 (prefer, platform)
146 }
147}
148
149pub fn device_from_env() -> Option<Device> {
151 device_from_env_key("RLX")
152}
153
154pub fn device_from_env_key(prefix: &str) -> Option<Device> {
156 let key = format!("{prefix}_DEVICE");
157 rlx_ir::env::var(&key).and_then(|raw| parse_device(&raw).ok())
158}
159
160pub fn devices_for_with_policy(graph: &Graph, policy: &DevicePolicy) -> Vec<Device> {
162 policy.apply(
163 crate::available_devices()
164 .into_iter()
165 .filter(|d| supports_graph(*d, graph))
166 .collect(),
167 )
168}
169
170#[derive(Debug, Clone, PartialEq)]
172pub struct DeviceCandidate {
173 pub device: Device,
174 pub label: &'static str,
175 pub available: bool,
176 pub registered: bool,
177 pub supports_graph: bool,
178 pub recommended: bool,
179 pub blocker: Option<String>,
180}
181
182pub fn device_report(graph: &Graph, policy: &DevicePolicy) -> Vec<DeviceCandidate> {
184 let recommended = fastest_device_for_with_policy(graph, policy);
185 policy
186 .probe_set()
187 .into_iter()
188 .map(|device| {
189 let available = is_available(device);
190 let registered = backend_for(device).is_some();
191 let supports = available && supports_graph(device, graph);
192 let blocker = if !available {
193 Some("not available on this host or in this build".into())
194 } else if !supports {
195 crate::first_unsupported_op(device, graph)
196 .map(|(idx, op)| format!("unsupported op at node {idx}: {op:?}"))
197 } else if policy.deny.contains(&device) {
198 Some("denied by DevicePolicy".into())
199 } else if policy
200 .allow
201 .as_ref()
202 .is_some_and(|allow| !allow.contains(&device))
203 {
204 Some("not in DevicePolicy allow-list".into())
205 } else {
206 None
207 };
208 DeviceCandidate {
209 device,
210 label: device_label(device),
211 available,
212 registered,
213 supports_graph: supports,
214 recommended: device == recommended,
215 blocker,
216 }
217 })
218 .collect()
219}
220
221pub fn resolve_device(
223 graph: &Graph,
224 hint: Option<Device>,
225 policy: &DevicePolicy,
226) -> Result<Device, String> {
227 let candidates = devices_for_with_policy(graph, policy);
228 if candidates.is_empty() {
229 return Err(
230 "no backend can lower this graph under the current policy — \
231 widen DevicePolicy or enable additional Cargo features"
232 .into(),
233 );
234 }
235
236 if let Some(device) = hint {
237 return pick_from_candidates(device, &candidates, "hint");
238 }
239 if let Some(device) = device_from_env() {
240 if let Ok(device) = pick_from_candidates(device, &candidates, "RLX_DEVICE") {
241 return Ok(device);
242 }
243 }
244 Ok(fastest_device_for_with_policy(graph, policy))
245}
246
247fn pick_from_candidates(
248 device: Device,
249 candidates: &[Device],
250 source: &str,
251) -> Result<Device, String> {
252 if candidates.contains(&device) {
253 return Ok(device);
254 }
255 Err(format!(
256 "{source} requested {device} but viable backends are [{}]",
257 candidates
258 .iter()
259 .map(|d| device_label(*d))
260 .collect::<Vec<_>>()
261 .join(", ")
262 ))
263}
264
265pub fn device_chain_from_env() -> Vec<Device> {
267 device_chain_from_env_key("RLX")
268}
269
270pub fn device_chain_from_env_key(prefix: &str) -> Vec<Device> {
272 let key = format!("{prefix}_DEVICE_CHAIN");
273 rlx_ir::env::var(&key)
274 .and_then(|raw| parse_device_list(&raw).ok())
275 .unwrap_or_default()
276}
277
278pub fn resolve_device_chain(
280 graph: &Graph,
281 chain: &[Device],
282 policy: &DevicePolicy,
283) -> Result<Device, String> {
284 let viable = devices_for_with_policy(graph, policy);
285 for &device in chain {
286 if viable.contains(&device) {
287 return Ok(device);
288 }
289 }
290 Err(format!(
291 "no device in chain [{}] can run this graph — viable: [{}]",
292 chain
293 .iter()
294 .map(|d| device_label(*d))
295 .collect::<Vec<_>>()
296 .join(", "),
297 viable
298 .iter()
299 .map(|d| device_label(*d))
300 .collect::<Vec<_>>()
301 .join(", ")
302 ))
303}
304
305#[derive(Debug, Clone, PartialEq, Eq)]
307pub struct DeviceFallbackError {
308 pub attempts: Vec<(Device, String)>,
309}
310
311impl std::fmt::Display for DeviceFallbackError {
312 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313 write!(f, "all backends failed:")?;
314 for (d, e) in &self.attempts {
315 write!(f, "\n {}: {e}", device_label(*d))?;
316 }
317 Ok(())
318 }
319}
320
321impl std::error::Error for DeviceFallbackError {}
322
323impl From<String> for DeviceFallbackError {
324 fn from(msg: String) -> Self {
325 Self {
326 attempts: vec![(Device::Cpu, msg)],
327 }
328 }
329}
330
331pub fn run_with_fallback<T, F>(
333 graph: &Graph,
334 policy: &DevicePolicy,
335 chain: &[Device],
336 mut run: F,
337) -> Result<(Device, T), DeviceFallbackError>
338where
339 F: FnMut(Device) -> Result<T, String>,
340{
341 let viable = devices_for_with_policy(graph, policy);
342 let mut attempts = Vec::new();
343 for &device in chain {
344 if !viable.contains(&device) {
345 attempts.push((device, "not viable for this graph under policy".into()));
346 continue;
347 }
348 match run(device) {
349 Ok(value) => return Ok((device, value)),
350 Err(err) => attempts.push((device, err)),
351 }
352 }
353 if attempts.is_empty() {
354 attempts.push((Device::Cpu, "empty fallback chain".into()));
355 }
356 Err(DeviceFallbackError { attempts })
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::{DType, Shape};

    /// Sets an environment override and removes it on drop. A failing assertion
    /// therefore cannot leak the variable into other tests.
    struct EnvOverride(&'static str);

    impl EnvOverride {
        fn set(key: &'static str, value: &'static str) -> Self {
            rlx_ir::env::set(key, value);
            Self(key)
        }
    }

    impl Drop for EnvOverride {
        fn drop(&mut self) {
            rlx_ir::env::unset(self.0);
        }
    }

    /// Single-input identity graph that every backend should be able to lower.
    fn tiny_graph() -> Graph {
        let mut g = Graph::new("tiny");
        let x = g.input("x", Shape::new(&[2], DType::F32));
        g.set_outputs(vec![x]);
        g
    }

    #[test]
    fn only_policy_restricts_devices_for() {
        let g = tiny_graph();
        let all = devices_for_with_policy(&g, &DevicePolicy::default());
        let cpu_only = devices_for_with_policy(&g, &DevicePolicy::only([Device::Cpu]));
        assert_eq!(cpu_only, vec![Device::Cpu]);
        assert!(all.contains(&Device::Cpu));
    }

    #[test]
    fn resolve_honors_hint_then_env() {
        let g = tiny_graph();
        let policy = DevicePolicy::only([Device::Cpu]);
        assert_eq!(
            resolve_device(&g, Some(Device::Cpu), &policy).unwrap(),
            Device::Cpu
        );

        // Bound to a named variable so the guard lives until the end of the test.
        let _env = EnvOverride::set("RLX_DEVICE", "cpu");
        assert_eq!(resolve_device(&g, None, &policy).unwrap(), Device::Cpu);
    }

    #[test]
    fn resolve_rejects_hint_outside_policy() {
        let g = tiny_graph();
        let policy = DevicePolicy::only([Device::Cpu]);
        if let Some(other) = Device::all().iter().copied().find(|d| *d != Device::Cpu) {
            assert!(resolve_device(&g, Some(other), &policy).is_err());
        }
    }

    #[test]
    fn deny_and_prefer_shape_candidate_order() {
        let all = Device::all().to_vec();
        let denied = DevicePolicy::all()
            .with_deny([Device::Cpu])
            .apply(all.clone());
        assert!(!denied.contains(&Device::Cpu));

        if all.len() >= 2 {
            let (first, second) = (all[0], all[1]);
            let ordered = DevicePolicy::all()
                .with_prefer([second, first])
                .apply(vec![first, second]);
            assert_eq!(ordered, vec![second, first]);
        }
    }

    #[test]
    fn benchmark_pick_clamps_runs() {
        assert_eq!(
            DevicePolicy::all().pick_strategy(),
            DevicePickStrategy::CostModel
        );
        assert_eq!(
            DevicePolicy::all().with_benchmark_pick(0).pick_strategy(),
            DevicePickStrategy::Benchmark { runs: 1 }
        );
    }

    #[test]
    fn resolve_device_chain_picks_first_viable() {
        let g = tiny_graph();
        let policy = DevicePolicy::only([Device::Cpu]);
        assert_eq!(
            resolve_device_chain(&g, &[Device::Cpu], &policy).unwrap(),
            Device::Cpu
        );
        assert!(resolve_device_chain(&g, &[], &policy).is_err());
    }

    #[test]
    fn run_with_fallback_reports_every_attempt() {
        let g = tiny_graph();
        let policy = DevicePolicy::only([Device::Cpu]);

        let ok = run_with_fallback(&g, &policy, &[Device::Cpu], |d| Ok::<_, String>(d)).unwrap();
        assert_eq!(ok, (Device::Cpu, Device::Cpu));

        let err = run_with_fallback(&g, &policy, &[Device::Cpu], |_| {
            Err::<(), _>("boom".to_string())
        })
        .unwrap_err();
        assert_eq!(err.attempts, vec![(Device::Cpu, "boom".to_string())]);

        let empty = run_with_fallback(&g, &policy, &[], |_| Ok::<(), String>(())).unwrap_err();
        assert_eq!(empty.attempts.len(), 1);
    }

    #[test]
    fn device_report_marks_recommended() {
        let g = tiny_graph();
        let policy = DevicePolicy::only([Device::Cpu]);
        let rows = device_report(&g, &policy);
        assert_eq!(rows.len(), 1);
        assert!(rows[0].recommended);
        assert!(rows[0].supports_graph);
        assert!(rows[0].blocker.is_none());
    }
}