Skip to main content

rlx_runtime/
device_policy.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Backend allowlists, env-driven defaults, and selection introspection.
17
18use rlx_driver::Device;
19use rlx_ir::Graph;
20
21use crate::cost::fastest_device_for_with_policy;
22use crate::device_ext::{DEVICE_PRIORITY, is_available, supports_graph};
23use crate::device_parse::{device_label, parse_device, parse_device_list};
24use crate::registry::backend_for;
25
/// How [`GraphDevices::resolve_with_inputs`] picks a backend when no hint is set.
///
/// Defaults to [`DevicePickStrategy::CostModel`].
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum DevicePickStrategy {
    /// Rank via calibrated cost models + platform priority (default).
    #[default]
    CostModel,

    /// Run a short [`crate::benchmark_devices`] once and cache the winner.
    Benchmark {
        /// Benchmark repetition count.
        runs: usize,
    },
}
35
36/// Which backends a process may use — intersected with compile-time features
37/// and runtime availability.
38#[derive(Debug, Clone, Default, PartialEq, Eq)]
39pub struct DevicePolicy {
40    allow: Option<Vec<Device>>,
41    deny: Vec<Device>,
42    prefer: Vec<Device>,
43    pick: DevicePickStrategy,
44}
45
46impl DevicePolicy {
47    /// Allow every compiled-in backend (default).
48    pub fn all() -> Self {
49        Self::default()
50    }
51
52    /// Restrict to an explicit backend set the developer ships.
53    pub fn only(devices: impl IntoIterator<Item = Device>) -> Self {
54        Self {
55            allow: Some(devices.into_iter().collect()),
56            ..Self::default()
57        }
58    }
59
60    /// Exclude specific backends from consideration.
61    pub fn with_deny(mut self, devices: impl IntoIterator<Item = Device>) -> Self {
62        self.deny.extend(devices);
63        self
64    }
65
66    /// Prefer these backends when cost models tie or are unavailable.
67    pub fn with_prefer(mut self, devices: impl IntoIterator<Item = Device>) -> Self {
68        self.prefer.extend(devices);
69        self
70    }
71
72    /// Pick the fastest backend via a one-time micro-benchmark (needs inputs at resolve time).
73    pub fn with_benchmark_pick(mut self, runs: usize) -> Self {
74        self.pick = DevicePickStrategy::Benchmark { runs: runs.max(1) };
75        self
76    }
77
78    pub fn pick_strategy(&self) -> DevicePickStrategy {
79        self.pick
80    }
81
82    /// Read policy from process env (see [`Self::from_env_key`]).
83    pub fn from_env() -> Self {
84        Self::from_env_key("RLX")
85    }
86
87    /// Read `PREFIX_DEVICES`, `PREFIX_DENY_DEVICES`, `PREFIX_PREFER_DEVICES`.
88    pub fn from_env_key(prefix: &str) -> Self {
89        let mut policy = Self::default();
90        let devices_key = format!("{prefix}_DEVICES");
91        let deny_key = format!("{prefix}_DENY_DEVICES");
92        let prefer_key = format!("{prefix}_PREFER_DEVICES");
93
94        if let Some(raw) = rlx_ir::env::var(&devices_key) {
95            if let Ok(list) = parse_device_list(&raw) {
96                policy.allow = Some(list);
97            }
98        }
99        if let Some(raw) = rlx_ir::env::var(&deny_key) {
100            if let Ok(list) = parse_device_list(&raw) {
101                policy.deny = list;
102            }
103        }
104        if let Some(raw) = rlx_ir::env::var(&prefer_key) {
105            if let Ok(list) = parse_device_list(&raw) {
106                policy.prefer = list;
107            }
108        }
109        let bench_key = format!("{prefix}_BENCHMARK_PICK");
110        if let Some(raw) = rlx_ir::env::var(&bench_key) {
111            if let Ok(runs) = raw.trim().parse::<usize>() {
112                policy.pick = DevicePickStrategy::Benchmark { runs: runs.max(1) };
113            }
114        }
115        policy
116    }
117
118    /// Devices to show in reports when no allow-list is set.
119    pub fn probe_set(&self) -> Vec<Device> {
120        self.allow.clone().unwrap_or_else(|| Device::all().to_vec())
121    }
122
123    /// Filter and order `candidates` according to this policy.
124    pub fn apply(&self, mut candidates: Vec<Device>) -> Vec<Device> {
125        if let Some(allow) = &self.allow {
126            candidates.retain(|d| allow.contains(d));
127        }
128        candidates.retain(|d| !self.deny.contains(d));
129        candidates.sort_by_key(|d| self.rank_key(*d));
130        candidates
131    }
132
133    fn rank_key(&self, device: Device) -> (u8, u8) {
134        let prefer = self
135            .prefer
136            .iter()
137            .position(|d| *d == device)
138            .map(|i| i as u8)
139            .unwrap_or(u8::MAX);
140        let platform = DEVICE_PRIORITY
141            .iter()
142            .position(|d| *d == device)
143            .map(|i| i as u8)
144            .unwrap_or(u8::MAX);
145        (prefer, platform)
146    }
147}
148
149/// Default device hint from `RLX_DEVICE` (or `PREFIX_DEVICE` via [`device_from_env_key`]).
150pub fn device_from_env() -> Option<Device> {
151    device_from_env_key("RLX")
152}
153
154/// Read `PREFIX_DEVICE` env var.
155pub fn device_from_env_key(prefix: &str) -> Option<Device> {
156    let key = format!("{prefix}_DEVICE");
157    rlx_ir::env::var(&key).and_then(|raw| parse_device(&raw).ok())
158}
159
160/// Backends on this host that can lower `graph`, filtered by `policy`.
161pub fn devices_for_with_policy(graph: &Graph, policy: &DevicePolicy) -> Vec<Device> {
162    policy.apply(
163        crate::available_devices()
164            .into_iter()
165            .filter(|d| supports_graph(*d, graph))
166            .collect(),
167    )
168}
169
170/// One row of backend introspection for UIs and logs.
171#[derive(Debug, Clone, PartialEq)]
172pub struct DeviceCandidate {
173    pub device: Device,
174    pub label: &'static str,
175    pub available: bool,
176    pub registered: bool,
177    pub supports_graph: bool,
178    pub recommended: bool,
179    pub blocker: Option<String>,
180}
181
182/// Explain which backends are viable for `graph` under `policy`.
183pub fn device_report(graph: &Graph, policy: &DevicePolicy) -> Vec<DeviceCandidate> {
184    let recommended = fastest_device_for_with_policy(graph, policy);
185    policy
186        .probe_set()
187        .into_iter()
188        .map(|device| {
189            let available = is_available(device);
190            let registered = backend_for(device).is_some();
191            let supports = available && supports_graph(device, graph);
192            let blocker = if !available {
193                Some("not available on this host or in this build".into())
194            } else if !supports {
195                crate::first_unsupported_op(device, graph)
196                    .map(|(idx, op)| format!("unsupported op at node {idx}: {op:?}"))
197            } else if policy.deny.contains(&device) {
198                Some("denied by DevicePolicy".into())
199            } else if policy
200                .allow
201                .as_ref()
202                .is_some_and(|allow| !allow.contains(&device))
203            {
204                Some("not in DevicePolicy allow-list".into())
205            } else {
206                None
207            };
208            DeviceCandidate {
209                device,
210                label: device_label(device),
211                available,
212                registered,
213                supports_graph: supports,
214                recommended: device == recommended,
215                blocker,
216            }
217        })
218        .collect()
219}
220
221/// Resolve the backend to use: explicit hint → env → fastest for `graph`.
222pub fn resolve_device(
223    graph: &Graph,
224    hint: Option<Device>,
225    policy: &DevicePolicy,
226) -> Result<Device, String> {
227    let candidates = devices_for_with_policy(graph, policy);
228    if candidates.is_empty() {
229        return Err(
230            "no backend can lower this graph under the current policy — \
231             widen DevicePolicy or enable additional Cargo features"
232                .into(),
233        );
234    }
235
236    if let Some(device) = hint {
237        return pick_from_candidates(device, &candidates, "hint");
238    }
239    if let Some(device) = device_from_env() {
240        if let Ok(device) = pick_from_candidates(device, &candidates, "RLX_DEVICE") {
241            return Ok(device);
242        }
243    }
244    Ok(fastest_device_for_with_policy(graph, policy))
245}
246
247fn pick_from_candidates(
248    device: Device,
249    candidates: &[Device],
250    source: &str,
251) -> Result<Device, String> {
252    if candidates.contains(&device) {
253        return Ok(device);
254    }
255    Err(format!(
256        "{source} requested {device} but viable backends are [{}]",
257        candidates
258            .iter()
259            .map(|d| device_label(*d))
260            .collect::<Vec<_>>()
261            .join(", ")
262    ))
263}
264
265/// Ordered fallback chain from `RLX_DEVICE_CHAIN` (`cuda,gpu,cpu`).
266pub fn device_chain_from_env() -> Vec<Device> {
267    device_chain_from_env_key("RLX")
268}
269
270/// Read `PREFIX_DEVICE_CHAIN`.
271pub fn device_chain_from_env_key(prefix: &str) -> Vec<Device> {
272    let key = format!("{prefix}_DEVICE_CHAIN");
273    rlx_ir::env::var(&key)
274        .and_then(|raw| parse_device_list(&raw).ok())
275        .unwrap_or_default()
276}
277
278/// First device in `chain` that is viable under `policy` for `graph`.
279pub fn resolve_device_chain(
280    graph: &Graph,
281    chain: &[Device],
282    policy: &DevicePolicy,
283) -> Result<Device, String> {
284    let viable = devices_for_with_policy(graph, policy);
285    for &device in chain {
286        if viable.contains(&device) {
287            return Ok(device);
288        }
289    }
290    Err(format!(
291        "no device in chain [{}] can run this graph — viable: [{}]",
292        chain
293            .iter()
294            .map(|d| device_label(*d))
295            .collect::<Vec<_>>()
296            .join(", "),
297        viable
298            .iter()
299            .map(|d| device_label(*d))
300            .collect::<Vec<_>>()
301            .join(", ")
302    ))
303}
304
305/// Errors collected when every backend in the chain fails at run time.
306#[derive(Debug, Clone, PartialEq, Eq)]
307pub struct DeviceFallbackError {
308    pub attempts: Vec<(Device, String)>,
309}
310
311impl std::fmt::Display for DeviceFallbackError {
312    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313        write!(f, "all backends failed:")?;
314        for (d, e) in &self.attempts {
315            write!(f, "\n  {}: {e}", device_label(*d))?;
316        }
317        Ok(())
318    }
319}
320
321impl std::error::Error for DeviceFallbackError {}
322
323impl From<String> for DeviceFallbackError {
324    fn from(msg: String) -> Self {
325        Self {
326            attempts: vec![(Device::Cpu, msg)],
327        }
328    }
329}
330
331/// Try `chain` in order; return the first successful result from `run`.
332pub fn run_with_fallback<T, F>(
333    graph: &Graph,
334    policy: &DevicePolicy,
335    chain: &[Device],
336    mut run: F,
337) -> Result<(Device, T), DeviceFallbackError>
338where
339    F: FnMut(Device) -> Result<T, String>,
340{
341    let viable = devices_for_with_policy(graph, policy);
342    let mut attempts = Vec::new();
343    for &device in chain {
344        if !viable.contains(&device) {
345            attempts.push((device, "not viable for this graph under policy".into()));
346            continue;
347        }
348        match run(device) {
349            Ok(value) => return Ok((device, value)),
350            Err(err) => attempts.push((device, err)),
351        }
352    }
353    if attempts.is_empty() {
354        attempts.push((Device::Cpu, "empty fallback chain".into()));
355    }
356    Err(DeviceFallbackError { attempts })
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::{DType, Shape};

    fn tiny_graph() -> Graph {
        let mut g = Graph::new("tiny");
        let x = g.input("x", Shape::new(&[2], DType::F32));
        g.set_outputs(vec![x]);
        g
    }

    /// Sets an `rlx_ir::env` override and removes it when dropped, so a
    /// failing assertion cannot leak the override into later tests.
    struct EnvOverride(&'static str);

    impl EnvOverride {
        fn set(key: &'static str, value: &'static str) -> Self {
            rlx_ir::env::set(key, value);
            Self(key)
        }
    }

    impl Drop for EnvOverride {
        fn drop(&mut self) {
            rlx_ir::env::unset(self.0);
        }
    }

    #[test]
    fn only_policy_restricts_devices_for() {
        let g = tiny_graph();
        let all = devices_for_with_policy(&g, &DevicePolicy::default());
        let cpu_only = devices_for_with_policy(&g, &DevicePolicy::only([Device::Cpu]));
        assert_eq!(cpu_only, vec![Device::Cpu]);
        assert!(all.contains(&Device::Cpu));
    }

    #[test]
    fn resolve_honors_hint_then_env() {
        let g = tiny_graph();
        let policy = DevicePolicy::only([Device::Cpu]);
        assert_eq!(
            resolve_device(&g, Some(Device::Cpu), &policy).unwrap(),
            Device::Cpu
        );

        // Unset automatically on scope exit, even if the assertion panics.
        let _env = EnvOverride::set("RLX_DEVICE", "cpu");
        assert_eq!(resolve_device(&g, None, &policy).unwrap(), Device::Cpu);
    }

    #[test]
    fn resolve_rejects_hint_outside_policy() {
        let g = tiny_graph();
        let policy = DevicePolicy::only([Device::Cpu]);
        if let Some(other) = Device::all().iter().copied().find(|d| *d != Device::Cpu) {
            assert!(resolve_device(&g, Some(other), &policy).is_err());
        }
    }

    #[test]
    fn prefer_moves_device_to_front() {
        let ordered = DevicePolicy::all()
            .with_prefer([Device::Cpu])
            .apply(Device::all().to_vec());
        assert_eq!(ordered.first(), Some(&Device::Cpu));
    }

    #[test]
    fn deny_removes_device() {
        let kept = DevicePolicy::all()
            .with_deny([Device::Cpu])
            .apply(vec![Device::Cpu]);
        assert!(kept.is_empty());
    }

    #[test]
    fn device_report_marks_recommended() {
        let g = tiny_graph();
        let policy = DevicePolicy::only([Device::Cpu]);
        let rows = device_report(&g, &policy);
        assert_eq!(rows.len(), 1);
        assert!(rows[0].recommended);
        assert!(rows[0].supports_graph);
    }
}