use rlx_driver::Device;
use rlx_ir::Graph;
use crate::cost::fastest_device_for_with_policy;
use crate::device_ext::{DEVICE_PRIORITY, is_available, supports_graph};
use crate::device_parse::{device_label, parse_device, parse_device_list};
use crate::registry::backend_for;
/// Pick strategy stored in a `DevicePolicy` and returned by
/// `DevicePolicy::pick_strategy`.
///
/// `CostModel` is the default variant. `Benchmark` is produced by
/// `DevicePolicy::with_benchmark_pick` and by the `<PREFIX>_BENCHMARK_PICK`
/// environment variable; both clamp `runs` to at least 1.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DevicePickStrategy {
/// Default strategy.
/// NOTE(review): presumably ranks devices with a static cost model — confirm in `crate::cost`.
#[default]
CostModel,
/// Strategy carrying a run count.
/// NOTE(review): presumably times each candidate `runs` times — confirm in `crate::cost`.
Benchmark { runs: usize },
}
/// Allow / deny / prefer filter over devices, plus a pick strategy.
///
/// The default value has no allow-list, empty deny and prefer lists, and the
/// default `DevicePickStrategy`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DevicePolicy {
/// `None` means no allow-list restriction. `Some(list)` makes `apply` keep
/// only listed devices and makes `probe_set` return exactly `list`.
allow: Option<Vec<Device>>,
/// Devices `apply` always removes, even when they are also allow-listed.
deny: Vec<Device>,
/// Devices `apply` sorts to the front, in list order.
prefer: Vec<Device>,
/// Strategy returned by `pick_strategy`.
pick: DevicePickStrategy,
}
impl DevicePolicy {
    /// Returns the unrestricted policy; identical to `DevicePolicy::default()`.
    pub fn all() -> Self {
        Self::default()
    }

    /// Returns a policy whose allow-list is exactly `devices`; every other
    /// field keeps its default value.
    pub fn only(devices: impl IntoIterator<Item = Device>) -> Self {
        Self {
            allow: Some(devices.into_iter().collect()),
            ..Self::default()
        }
    }

    /// Appends `devices` to the deny-list and returns the updated policy.
    pub fn with_deny(mut self, devices: impl IntoIterator<Item = Device>) -> Self {
        self.deny.extend(devices);
        self
    }

    /// Appends `devices` to the preference list and returns the updated
    /// policy. Earlier entries rank ahead of later ones in [`Self::apply`].
    pub fn with_prefer(mut self, devices: impl IntoIterator<Item = Device>) -> Self {
        self.prefer.extend(devices);
        self
    }

    /// Switches the pick strategy to `Benchmark`, clamping `runs` to at
    /// least 1.
    pub fn with_benchmark_pick(mut self, runs: usize) -> Self {
        self.pick = DevicePickStrategy::Benchmark { runs: runs.max(1) };
        self
    }

    /// Returns the configured pick strategy.
    pub fn pick_strategy(&self) -> DevicePickStrategy {
        self.pick
    }

    /// Builds a policy from the `RLX_*` environment variables.
    ///
    /// See [`Self::from_env_key`] for the variables that are consulted.
    pub fn from_env() -> Self {
        Self::from_env_key("RLX")
    }

    /// Builds a policy from environment variables named after `prefix`.
    ///
    /// Variables read, in this order:
    /// - `{prefix}_DEVICES` — allow-list
    /// - `{prefix}_DENY_DEVICES` — deny-list
    /// - `{prefix}_PREFER_DEVICES` — preference list
    /// - `{prefix}_BENCHMARK_PICK` — benchmark run count (clamped to >= 1)
    ///
    /// A variable that is unset or fails to parse is ignored and the
    /// corresponding field keeps its default value.
    pub fn from_env_key(prefix: &str) -> Self {
        // Shared lookup for the three device-list variables; `None` covers
        // both "unset" and "unparsable".
        let device_list = |suffix: &str| -> Option<Vec<Device>> {
            let key = format!("{prefix}_{suffix}");
            rlx_ir::env::var(&key).and_then(|raw| parse_device_list(&raw).ok())
        };

        let allow = device_list("DEVICES");
        let deny = device_list("DENY_DEVICES").unwrap_or_default();
        let prefer = device_list("PREFER_DEVICES").unwrap_or_default();

        let bench_key = format!("{prefix}_BENCHMARK_PICK");
        let pick = rlx_ir::env::var(&bench_key)
            .and_then(|raw| raw.trim().parse::<usize>().ok())
            .map(|runs| DevicePickStrategy::Benchmark { runs: runs.max(1) })
            .unwrap_or_default();

        Self {
            allow,
            deny,
            prefer,
            pick,
        }
    }

    /// Returns the devices worth probing: the allow-list when one is set,
    /// otherwise every device from `Device::all()`.
    pub fn probe_set(&self) -> Vec<Device> {
        self.allow.clone().unwrap_or_else(|| Device::all().to_vec())
    }

    /// Filters and orders `candidates` according to this policy.
    ///
    /// Devices outside the allow-list (when one is set) and devices on the
    /// deny-list are removed. The survivors are stably sorted by preference
    /// position first and `DEVICE_PRIORITY` position second.
    pub fn apply(&self, mut candidates: Vec<Device>) -> Vec<Device> {
        if let Some(allow) = &self.allow {
            candidates.retain(|d| allow.contains(d));
        }
        candidates.retain(|d| !self.deny.contains(d));
        candidates.sort_by_key(|d| self.rank_key(*d));
        candidates
    }

    /// Sort key for `device`: `(index in prefer, index in DEVICE_PRIORITY)`.
    ///
    /// A device missing from a list gets `usize::MAX` for that component so
    /// it sorts after every listed device. The indices stay `usize` so a
    /// real position can never be truncated into, or collide with, the
    /// sentinel.
    fn rank_key(&self, device: Device) -> (usize, usize) {
        let prefer_rank = self
            .prefer
            .iter()
            .position(|d| *d == device)
            .unwrap_or(usize::MAX);
        let platform_rank = DEVICE_PRIORITY
            .iter()
            .position(|d| *d == device)
            .unwrap_or(usize::MAX);
        (prefer_rank, platform_rank)
    }
}
/// Reads the single-device override from the `RLX_DEVICE` environment
/// variable; shorthand for `device_from_env_key("RLX")`.
pub fn device_from_env() -> Option<Device> {
    const PREFIX: &str = "RLX";
    device_from_env_key(PREFIX)
}
/// Reads the single-device override from `{prefix}_DEVICE`.
///
/// Returns `None` when the variable is unset or its value does not parse as
/// a device.
pub fn device_from_env_key(prefix: &str) -> Option<Device> {
    let name = format!("{prefix}_DEVICE");
    let raw = rlx_ir::env::var(&name)?;
    parse_device(&raw).ok()
}
/// Lists the devices that can run `graph` under `policy`.
///
/// Starts from `crate::available_devices()`, keeps those for which
/// `supports_graph` holds, then lets `policy.apply` filter and order them.
pub fn devices_for_with_policy(graph: &Graph, policy: &DevicePolicy) -> Vec<Device> {
    let lowerable: Vec<Device> = crate::available_devices()
        .into_iter()
        .filter(|&device| supports_graph(device, graph))
        .collect();
    policy.apply(lowerable)
}
/// One row of the report produced by `device_report`: the status of a single
/// probed device for a given graph and policy.
#[derive(Debug, Clone, PartialEq)]
pub struct DeviceCandidate {
/// The device this row describes.
pub device: Device,
/// Label for the device, as returned by `device_label`.
pub label: &'static str,
/// Result of `is_available` for the device.
pub available: bool,
/// Whether `backend_for` returned a backend for the device.
pub registered: bool,
/// Whether the device is available and `supports_graph` holds for the graph.
pub supports_graph: bool,
/// Whether `device_report` flagged this device as its recommended pick.
pub recommended: bool,
/// Human-readable reason the device cannot be used, when `device_report` found one.
pub blocker: Option<String>,
}
/// Builds a per-device status report for `graph` under `policy`.
///
/// One [`DeviceCandidate`] is produced for every device in
/// `policy.probe_set()`, in that order. `blocker` holds the first reason
/// found, checked in this order: not available, graph not supported,
/// denied by policy, not on the allow-list.
///
/// A row is marked `recommended` only when it has no blocker and matches
/// the result of `fastest_device_for_with_policy`. That function always
/// returns some device, so without the blocker check a device that cannot
/// be used could still be flagged as the recommendation.
pub fn device_report(graph: &Graph, policy: &DevicePolicy) -> Vec<DeviceCandidate> {
    let recommended = fastest_device_for_with_policy(graph, policy);
    policy
        .probe_set()
        .into_iter()
        .map(|device| {
            let available = is_available(device);
            let registered = backend_for(device).is_some();
            let supports = available && supports_graph(device, graph);
            let blocker: Option<String> = if !available {
                Some("not available on this host or in this build".into())
            } else if !supports {
                // Always report a reason for an unsupported graph, even when
                // no single offending op can be pinpointed.
                Some(
                    crate::first_unsupported_op(device, graph)
                        .map(|(idx, op)| format!("unsupported op at node {idx}: {op:?}"))
                        .unwrap_or_else(|| String::from("graph is not supported by this backend")),
                )
            } else if policy.deny.contains(&device) {
                Some("denied by DevicePolicy".into())
            } else if policy
                .allow
                .as_ref()
                .is_some_and(|allow| !allow.contains(&device))
            {
                // Defensive: `probe_set()` currently yields only allow-listed
                // devices whenever an allow-list is set.
                Some("not in DevicePolicy allow-list".into())
            } else {
                None
            };
            DeviceCandidate {
                device,
                label: device_label(device),
                available,
                registered,
                supports_graph: supports,
                recommended: blocker.is_none() && device == recommended,
                blocker,
            }
        })
        .collect()
}
/// Picks the device to run `graph` on.
///
/// Resolution order:
/// 1. `hint`, when given — an error if it is not among the viable devices.
/// 2. The `RLX_DEVICE` environment override, when set and viable; a
///    non-viable override is ignored rather than reported.
/// 3. The result of `fastest_device_for_with_policy`.
///
/// # Errors
///
/// Returns a message when no device is viable under `policy`, or when
/// `hint` names a device that is not viable.
pub fn resolve_device(
    graph: &Graph,
    hint: Option<Device>,
    policy: &DevicePolicy,
) -> Result<Device, String> {
    let viable = devices_for_with_policy(graph, policy);
    if viable.is_empty() {
        return Err(String::from(
            "no backend can lower this graph under the current policy — \
             widen DevicePolicy or enable additional Cargo features",
        ));
    }

    match hint {
        Some(device) => pick_from_candidates(device, &viable, "hint"),
        None => {
            let env_pick = device_from_env()
                .and_then(|device| pick_from_candidates(device, &viable, "RLX_DEVICE").ok());
            Ok(env_pick.unwrap_or_else(|| fastest_device_for_with_policy(graph, policy)))
        }
    }
}
/// Accepts `device` only if it appears in `candidates`.
///
/// `source` names where the request came from and is used only in the error
/// message, which also lists the labels of all viable candidates.
fn pick_from_candidates(
    device: Device,
    candidates: &[Device],
    source: &str,
) -> Result<Device, String> {
    if !candidates.contains(&device) {
        let labels: Vec<_> = candidates.iter().map(|d| device_label(*d)).collect();
        return Err(format!(
            "{source} requested {device} but viable backends are [{}]",
            labels.join(", ")
        ));
    }
    Ok(device)
}
/// Reads the fallback chain from the `RLX_DEVICE_CHAIN` environment
/// variable; shorthand for `device_chain_from_env_key("RLX")`.
pub fn device_chain_from_env() -> Vec<Device> {
    const PREFIX: &str = "RLX";
    device_chain_from_env_key(PREFIX)
}
/// Reads the fallback chain from `{prefix}_DEVICE_CHAIN`.
///
/// Returns an empty list when the variable is unset or its value does not
/// parse as a device list.
pub fn device_chain_from_env_key(prefix: &str) -> Vec<Device> {
    let name = format!("{prefix}_DEVICE_CHAIN");
    match rlx_ir::env::var(&name) {
        Some(raw) => parse_device_list(&raw).unwrap_or_default(),
        None => Vec::new(),
    }
}
/// Returns the first device in `chain` that is viable for `graph` under
/// `policy`.
///
/// # Errors
///
/// Returns a message listing both the chain and the viable devices when no
/// chain entry is viable (including when `chain` is empty).
pub fn resolve_device_chain(
    graph: &Graph,
    chain: &[Device],
    policy: &DevicePolicy,
) -> Result<Device, String> {
    let viable = devices_for_with_policy(graph, policy);

    // Comma-separated labels, used for both lists in the error message.
    let labels = |devices: &[Device]| -> String {
        devices
            .iter()
            .map(|d| device_label(*d))
            .collect::<Vec<_>>()
            .join(", ")
    };

    chain
        .iter()
        .copied()
        .find(|device| viable.contains(device))
        .ok_or_else(|| {
            format!(
                "no device in chain [{}] can run this graph — viable: [{}]",
                labels(chain),
                labels(&viable[..])
            )
        })
}
/// Error returned by `run_with_fallback` when no device in the chain
/// produced a result.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceFallbackError {
/// One `(device, message)` entry per recorded failure, in the order recorded.
pub attempts: Vec<(Device, String)>,
}
/// Renders a header line followed by one `label: message` line per attempt.
impl std::fmt::Display for DeviceFallbackError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("all backends failed:")?;
        self.attempts
            .iter()
            .try_for_each(|(device, reason)| write!(f, "\n {}: {reason}", device_label(*device)))
    }
}
// Marker impl: the derived `Debug` and the hand-written `Display` satisfy the
// trait's requirements; the default methods are used, so no source error is exposed.
impl std::error::Error for DeviceFallbackError {}
/// Wraps a bare error message as a single failed attempt, recorded against
/// `Device::Cpu`.
impl From<String> for DeviceFallbackError {
    fn from(msg: String) -> Self {
        let attempts = vec![(Device::Cpu, msg)];
        Self { attempts }
    }
}
/// Tries `run` on each device of `chain` in order and returns the first
/// success together with the device that produced it.
///
/// Devices that are not viable for `graph` under `policy` are skipped
/// without calling `run`, but are still recorded as failed attempts.
///
/// # Errors
///
/// Returns a [`DeviceFallbackError`] listing every attempt when none
/// succeeded. When no attempt was recorded at all (empty `chain`), a single
/// `(Device::Cpu, "empty fallback chain")` entry is reported.
pub fn run_with_fallback<T, F>(
    graph: &Graph,
    policy: &DevicePolicy,
    chain: &[Device],
    mut run: F,
) -> Result<(Device, T), DeviceFallbackError>
where
    F: FnMut(Device) -> Result<T, String>,
{
    let viable = devices_for_with_policy(graph, policy);
    let mut attempts: Vec<(Device, String)> = Vec::new();

    for device in chain.iter().copied() {
        let outcome = if viable.contains(&device) {
            run(device)
        } else {
            Err("not viable for this graph under policy".into())
        };
        match outcome {
            Ok(value) => return Ok((device, value)),
            Err(reason) => attempts.push((device, reason)),
        }
    }

    if attempts.is_empty() {
        attempts.push((Device::Cpu, "empty fallback chain".into()));
    }
    Err(DeviceFallbackError { attempts })
}
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::{DType, Shape};

    /// Smallest possible graph: one `f32[2]` input wired straight to the output.
    fn tiny_graph() -> Graph {
        let mut g = Graph::new("tiny");
        let x = g.input("x", Shape::new(&[2], DType::F32));
        g.set_outputs(vec![x]);
        g
    }

    #[test]
    fn only_policy_restricts_devices_for() {
        let g = tiny_graph();
        let all = devices_for_with_policy(&g, &DevicePolicy::default());
        let cpu_only = devices_for_with_policy(&g, &DevicePolicy::only([Device::Cpu]));
        assert_eq!(cpu_only, vec![Device::Cpu]);
        assert!(all.contains(&Device::Cpu));
    }

    #[test]
    fn resolve_honors_hint_then_env() {
        let g = tiny_graph();
        let policy = DevicePolicy::only([Device::Cpu]);
        assert_eq!(
            resolve_device(&g, Some(Device::Cpu), &policy).unwrap(),
            Device::Cpu
        );

        // Capture the result and clear the override BEFORE asserting, so a
        // failing assertion cannot leave `RLX_DEVICE` set for other tests.
        rlx_ir::env::set("RLX_DEVICE", "cpu");
        let resolved = resolve_device(&g, None, &policy);
        rlx_ir::env::unset("RLX_DEVICE");
        assert_eq!(resolved.unwrap(), Device::Cpu);
    }

    #[test]
    fn device_report_marks_recommended() {
        let g = tiny_graph();
        let policy = DevicePolicy::only([Device::Cpu]);
        let rows = device_report(&g, &policy);
        assert_eq!(rows.len(), 1);
        assert!(rows[0].recommended);
        assert!(rows[0].supports_graph);
    }
}