rlx_runtime/
device_router.rs1use rlx_driver::Device;
9use rlx_ir::{DType, Graph};
10
11use crate::device_parse::device_label;
12use crate::device_policy::{DeviceFallbackError, DevicePolicy};
13use crate::graph_devices::GraphDevices;
14use crate::hwinfo::HwSnapshot;
15
16pub struct DeviceRouter {
18 runner: GraphDevices,
19 rebench_on_throttle: bool,
20}
21
22impl DeviceRouter {
23 pub fn new(graph: Graph, policy: DevicePolicy) -> Result<Self, String> {
24 let mut runner = GraphDevices::with_policy(graph, policy);
25 runner.warm_all()?;
26 Ok(Self {
27 runner,
28 rebench_on_throttle: true,
29 })
30 }
31
32 pub fn from_env(graph: Graph) -> Result<Self, String> {
33 Self::new(graph, DevicePolicy::from_env())
34 }
35
36 pub fn with_rebench_on_throttle(mut self, enabled: bool) -> Self {
37 self.rebench_on_throttle = enabled;
38 self
39 }
40
41 pub fn set_rebench_on_throttle(&mut self, enabled: bool) {
42 self.rebench_on_throttle = enabled;
43 }
44
45 pub fn runner(&self) -> &GraphDevices {
46 &self.runner
47 }
48
49 pub fn runner_mut(&mut self) -> &mut GraphDevices {
50 &mut self.runner
51 }
52
53 fn maybe_rewarm(&mut self) -> Result<(), String> {
54 if self.rebench_on_throttle && HwSnapshot::collect().is_throttled() {
55 self.runner.invalidate_cache();
56 self.runner.warm_all()?;
57 }
58 Ok(())
59 }
60
61 pub fn run_on(
63 &mut self,
64 device: Device,
65 inputs: &[(&str, &[f32])],
66 ) -> Result<Vec<Vec<f32>>, String> {
67 self.maybe_rewarm()?;
68 self.runner.run(device, inputs)
69 }
70
71 pub fn run(
73 &mut self,
74 inputs: &[(&str, &[f32])],
75 hint: Option<Device>,
76 ) -> Result<(Device, Vec<Vec<f32>>), String> {
77 self.maybe_rewarm()?;
78 let device = self.runner.resolve_with_inputs(hint, inputs)?;
79 let out = self.runner.run(device, inputs)?;
80 Ok((device, out))
81 }
82
83 pub fn run_chain(
85 &mut self,
86 inputs: &[(&str, &[f32])],
87 hint: Option<Device>,
88 ) -> Result<(Device, Vec<Vec<f32>>), DeviceFallbackError> {
89 self.maybe_rewarm()?;
90 self.runner.run_chain(hint, inputs)
91 }
92
93 pub fn set_param(&mut self, name: &str, data: &[f32]) {
94 self.runner.set_param(name, data);
95 }
96
97 pub fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: DType) {
98 self.runner.set_param_typed(name, data, dtype);
99 }
100
101 pub fn devices(&self) -> Vec<&'static str> {
102 self.runner
103 .devices()
104 .iter()
105 .map(|d| device_label(*d))
106 .collect()
107 }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::{DType, Shape};

    /// A graph whose only output is its input must hand the input back
    /// unchanged when the policy allows nothing but the CPU.
    #[test]
    fn router_runs_identity_on_cpu() {
        let data = [1.0_f32, 2.0, 3.0, 4.0];

        // Identity graph: a single 4-element f32 input wired straight to the output.
        let mut graph = Graph::new("id");
        let input = graph.input("x", Shape::new(&[4], DType::F32));
        graph.set_outputs(vec![input]);

        let cpu_only = DevicePolicy::only([Device::Cpu]);
        let mut router = DeviceRouter::new(graph, cpu_only).expect("router");

        let (chosen, outputs) = router.run(&[("x", &data[..])], None).expect("run");

        assert_eq!(chosen, Device::Cpu);
        assert_eq!(outputs[0], data.to_vec());
    }
}
127}