Skip to main content

rlx_runtime/
device_router.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// Licensed under the GNU General Public License, version 3.
5
6//! Serving-oriented wrapper: warm-all backends, fallback execution, throttle-aware re-warm.
7
8use rlx_driver::Device;
9use rlx_ir::{DType, Graph};
10
11use crate::device_parse::device_label;
12use crate::device_policy::{DeviceFallbackError, DevicePolicy};
13use crate::graph_devices::GraphDevices;
14use crate::hwinfo::HwSnapshot;
15
16/// Production helper: compile all viable backends up front, run with fallback chain.
17pub struct DeviceRouter {
18    runner: GraphDevices,
19    rebench_on_throttle: bool,
20}
21
22impl DeviceRouter {
23    pub fn new(graph: Graph, policy: DevicePolicy) -> Result<Self, String> {
24        let mut runner = GraphDevices::with_policy(graph, policy);
25        runner.warm_all()?;
26        Ok(Self {
27            runner,
28            rebench_on_throttle: true,
29        })
30    }
31
32    pub fn from_env(graph: Graph) -> Result<Self, String> {
33        Self::new(graph, DevicePolicy::from_env())
34    }
35
36    pub fn with_rebench_on_throttle(mut self, enabled: bool) -> Self {
37        self.rebench_on_throttle = enabled;
38        self
39    }
40
41    pub fn set_rebench_on_throttle(&mut self, enabled: bool) {
42        self.rebench_on_throttle = enabled;
43    }
44
45    pub fn runner(&self) -> &GraphDevices {
46        &self.runner
47    }
48
49    pub fn runner_mut(&mut self) -> &mut GraphDevices {
50        &mut self.runner
51    }
52
53    fn maybe_rewarm(&mut self) -> Result<(), String> {
54        if self.rebench_on_throttle && HwSnapshot::collect().is_throttled() {
55            self.runner.invalidate_cache();
56            self.runner.warm_all()?;
57        }
58        Ok(())
59    }
60
61    /// Run on explicit device.
62    pub fn run_on(
63        &mut self,
64        device: Device,
65        inputs: &[(&str, &[f32])],
66    ) -> Result<Vec<Vec<f32>>, String> {
67        self.maybe_rewarm()?;
68        self.runner.run(device, inputs)
69    }
70
71    /// Resolve hint / env / benchmark policy, then execute.
72    pub fn run(
73        &mut self,
74        inputs: &[(&str, &[f32])],
75        hint: Option<Device>,
76    ) -> Result<(Device, Vec<Vec<f32>>), String> {
77        self.maybe_rewarm()?;
78        let device = self.runner.resolve_with_inputs(hint, inputs)?;
79        let out = self.runner.run(device, inputs)?;
80        Ok((device, out))
81    }
82
83    /// Fallback chain (`RLX_DEVICE_CHAIN` or explicit).
84    pub fn run_chain(
85        &mut self,
86        inputs: &[(&str, &[f32])],
87        hint: Option<Device>,
88    ) -> Result<(Device, Vec<Vec<f32>>), DeviceFallbackError> {
89        self.maybe_rewarm()?;
90        self.runner.run_chain(hint, inputs)
91    }
92
93    pub fn set_param(&mut self, name: &str, data: &[f32]) {
94        self.runner.set_param(name, data);
95    }
96
97    pub fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: DType) {
98        self.runner.set_param_typed(name, data, dtype);
99    }
100
101    pub fn devices(&self) -> Vec<&'static str> {
102        self.runner
103            .devices()
104            .iter()
105            .map(|d| device_label(*d))
106            .collect()
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::{DType, Shape};

    /// A CPU-only router over an identity graph must report `Device::Cpu` and echo its input.
    #[test]
    fn router_runs_identity_on_cpu() {
        let samples: [f32; 4] = [1.0, 2.0, 3.0, 4.0];

        // Graph whose only output is its only input.
        let mut graph = Graph::new("id");
        let input = graph.input("x", Shape::new(&[4], DType::F32));
        graph.set_outputs(vec![input]);

        let policy = DevicePolicy::only([Device::Cpu]);
        let mut router = DeviceRouter::new(graph, policy).expect("router");

        let (used, outputs) = router.run(&[("x", &samples[..])], None).expect("run");

        assert_eq!(used, Device::Cpu);
        assert_eq!(outputs[0], samples.to_vec());
    }
}
127}