Skip to main content

rlx_runtime/
graph_devices.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// Licensed under the GNU General Public License, version 3.

//! Multi-backend execution — compile once per device, run on any of them.

8use std::collections::HashMap;
9
10use rlx_driver::Device;
11use rlx_ir::{DType, Graph, Op};
12
13use crate::compiled::CompiledGraph;
14use crate::cost::fastest_device_for_with_policy;
15use crate::device_bench::{DeviceBenchResult, benchmark_devices, warm_all};
16use crate::device_ext::is_available;
17use crate::device_policy::{
18    DeviceCandidate, DeviceFallbackError, DevicePickStrategy, DevicePolicy, device_chain_from_env,
19    device_report, devices_for_with_policy, resolve_device, resolve_device_chain,
20};
21use crate::session::Session;
22
23/// Param names declared in `graph` (`Op::Param`).
24pub fn graph_param_names(graph: &Graph) -> Vec<String> {
25    graph
26        .nodes()
27        .iter()
28        .filter_map(|n| match &n.op {
29            Op::Param { name } => Some(name.clone()),
30            _ => None,
31        })
32        .collect()
33}
34
35#[derive(Debug, Clone)]
36enum CachedParam {
37    F32(Vec<f32>),
38    Typed { bytes: Vec<u8>, dtype: DType },
39}
40
41fn apply_cached_params(compiled: &mut CompiledGraph, params: &HashMap<String, CachedParam>) {
42    for (name, param) in params {
43        match param {
44            CachedParam::F32(data) => compiled.set_param(name, data),
45            CachedParam::Typed { bytes, dtype } => compiled.set_param_typed(name, bytes, *dtype),
46        }
47    }
48}
49
50/// A graph plus lazy per-device compiled executables.
51pub struct GraphDevices {
52    graph: Graph,
53    policy: DevicePolicy,
54    pick: DevicePickStrategy,
55    supported: Vec<Device>,
56    params: HashMap<String, CachedParam>,
57    benchmark_winner: Option<Device>,
58    cache: HashMap<Device, CompiledGraph>,
59}
60
61impl GraphDevices {
62    pub fn new(graph: Graph) -> Self {
63        Self::with_policy(graph, DevicePolicy::default())
64    }
65
66    pub fn with_policy(graph: Graph, policy: DevicePolicy) -> Self {
67        let pick = policy.pick_strategy();
68        let supported = devices_for_with_policy(&graph, &policy);
69        Self {
70            graph,
71            policy,
72            pick,
73            supported,
74            params: HashMap::new(),
75            benchmark_winner: None,
76            cache: HashMap::new(),
77        }
78    }
79
80    pub fn from_env(graph: Graph) -> Self {
81        Self::with_policy(graph, DevicePolicy::from_env())
82    }
83
84    pub fn policy(&self) -> &DevicePolicy {
85        &self.policy
86    }
87
88    pub fn graph(&self) -> &Graph {
89        &self.graph
90    }
91
92    pub fn devices(&self) -> &[Device] {
93        &self.supported
94    }
95
96    pub fn report(&self) -> Vec<DeviceCandidate> {
97        device_report(&self.graph, &self.policy)
98    }
99
100    pub fn fastest(&self) -> Device {
101        fastest_device_for_with_policy(&self.graph, &self.policy)
102    }
103
104    pub fn resolve(&self, hint: Option<Device>) -> Result<Device, String> {
105        resolve_device(&self.graph, hint, &self.policy)
106    }
107
108    /// Resolve using `RLX_DEVICE_CHAIN` when set, else [`Self::resolve`].
109    pub fn resolve_chain(&self, hint: Option<Device>) -> Result<Device, String> {
110        if let Some(device) = hint {
111            return self.resolve(Some(device));
112        }
113        let chain = device_chain_from_env();
114        if chain.is_empty() {
115            return self.resolve(None);
116        }
117        resolve_device_chain(&self.graph, &chain, &self.policy)
118    }
119
120    /// Upload a param to every cached executor and future compilations.
121    pub fn set_param(&mut self, name: &str, data: &[f32]) {
122        self.params
123            .insert(name.to_string(), CachedParam::F32(data.to_vec()));
124        for compiled in self.cache.values_mut() {
125            compiled.set_param(name, data);
126        }
127    }
128
129    /// Typed param upload — mirrored to all cached backends.
130    pub fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: DType) {
131        self.params.insert(
132            name.to_string(),
133            CachedParam::Typed {
134                bytes: data.to_vec(),
135                dtype,
136            },
137        );
138        for compiled in self.cache.values_mut() {
139            compiled.set_param_typed(name, data, dtype);
140        }
141    }
142
143    /// Re-apply stored params to all cached backends (after manual cache changes).
144    pub fn sync_params_to_all(&mut self) {
145        for compiled in self.cache.values_mut() {
146            apply_cached_params(compiled, &self.params);
147        }
148    }
149
150    /// Hint → env → cost model, or micro-benchmark when policy requests it.
151    pub fn resolve_with_inputs(
152        &mut self,
153        hint: Option<Device>,
154        inputs: &[(&str, &[f32])],
155    ) -> Result<Device, String> {
156        if hint.is_some() {
157            return self.resolve(hint);
158        }
159        match self.pick {
160            DevicePickStrategy::CostModel => self.resolve(None),
161            DevicePickStrategy::Benchmark { runs } => {
162                if let Some(device) = self.benchmark_winner {
163                    return Ok(device);
164                }
165                let ranked = self.benchmark(inputs, runs)?;
166                let device = ranked
167                    .first()
168                    .map(|r| r.device)
169                    .unwrap_or_else(|| self.fastest());
170                self.benchmark_winner = Some(device);
171                Ok(device)
172            }
173        }
174    }
175
176    pub fn compile(&mut self, device: Device) -> Result<&mut CompiledGraph, String> {
177        Self::ensure_supported(&self.supported, device)?;
178        if !self.cache.contains_key(&device) {
179            let mut compiled = Session::new(device).compile(self.graph.clone());
180            apply_cached_params(&mut compiled, &self.params);
181            self.cache.insert(device, compiled);
182        }
183        Ok(self.cache.get_mut(&device).expect("just inserted"))
184    }
185
186    pub fn compile_fastest(&mut self) -> Result<&mut CompiledGraph, String> {
187        self.compile(self.fastest())
188    }
189
190    pub fn compile_resolved(&mut self, hint: Option<Device>) -> Result<&mut CompiledGraph, String> {
191        self.compile(self.resolve(hint)?)
192    }
193
194    pub fn compile_chain(&mut self, hint: Option<Device>) -> Result<&mut CompiledGraph, String> {
195        self.compile(self.resolve_chain(hint)?)
196    }
197
198    pub fn warm_all(&mut self) -> Result<Vec<Device>, String> {
199        warm_all(self)
200    }
201
202    pub fn benchmark(
203        &mut self,
204        inputs: &[(&str, &[f32])],
205        runs: usize,
206    ) -> Result<Vec<DeviceBenchResult>, String> {
207        benchmark_devices(self, inputs, runs)
208    }
209
210    pub fn run(
211        &mut self,
212        device: Device,
213        inputs: &[(&str, &[f32])],
214    ) -> Result<Vec<Vec<f32>>, String> {
215        Ok(self.compile(device)?.run(inputs))
216    }
217
218    pub fn run_resolved(
219        &mut self,
220        hint: Option<Device>,
221        inputs: &[(&str, &[f32])],
222    ) -> Result<Vec<Vec<f32>>, String> {
223        Ok(self.compile_resolved(hint)?.run(inputs))
224    }
225
226    pub fn run_fastest(&mut self, inputs: &[(&str, &[f32])]) -> Result<Vec<Vec<f32>>, String> {
227        Ok(self.compile_fastest()?.run(inputs))
228    }
229
230    /// Try `chain` in order until one backend compiles and runs successfully.
231    pub fn run_try(
232        &mut self,
233        chain: &[Device],
234        inputs: &[(&str, &[f32])],
235    ) -> Result<(Device, Vec<Vec<f32>>), DeviceFallbackError> {
236        let viable: Vec<Device> = self.devices().to_vec();
237        let mut attempts = Vec::new();
238        for &device in chain {
239            if !viable.contains(&device) {
240                attempts.push((device, "not viable for this graph under policy".into()));
241                continue;
242            }
243            match self.run(device, inputs) {
244                Ok(value) => return Ok((device, value)),
245                Err(err) => attempts.push((device, err)),
246            }
247        }
248        if attempts.is_empty() {
249            attempts.push((Device::Cpu, "empty fallback chain".into()));
250        }
251        Err(DeviceFallbackError { attempts })
252    }
253
254    /// Like [`Self::run_try`] using `RLX_DEVICE_CHAIN` when set.
255    pub fn run_chain(
256        &mut self,
257        hint: Option<Device>,
258        inputs: &[(&str, &[f32])],
259    ) -> Result<(Device, Vec<Vec<f32>>), DeviceFallbackError> {
260        if let Some(device) = hint {
261            self.run(device, inputs)
262                .map(|v| (device, v))
263                .map_err(|e| DeviceFallbackError {
264                    attempts: vec![(device, e)],
265                })
266        } else {
267            let chain = device_chain_from_env();
268            if chain.is_empty() {
269                let device = self.resolve(None).map_err(|e| DeviceFallbackError {
270                    attempts: vec![(Device::Cpu, e)],
271                })?;
272                self.run(device, inputs)
273                    .map(|v| (device, v))
274                    .map_err(|e| DeviceFallbackError {
275                        attempts: vec![(device, e)],
276                    })
277            } else {
278                self.run_try(&chain, inputs)
279            }
280        }
281    }
282
283    pub fn compile_resolved_with_inputs(
284        &mut self,
285        hint: Option<Device>,
286        inputs: &[(&str, &[f32])],
287    ) -> Result<&mut CompiledGraph, String> {
288        let device = self.resolve_with_inputs(hint, inputs)?;
289        self.compile(device)
290    }
291
292    pub fn run_resolved_with_inputs(
293        &mut self,
294        hint: Option<Device>,
295        inputs: &[(&str, &[f32])],
296    ) -> Result<Vec<Vec<f32>>, String> {
297        Ok(self.compile_resolved_with_inputs(hint, inputs)?.run(inputs))
298    }
299
300    pub fn invalidate_cache(&mut self) {
301        self.cache.clear();
302        self.benchmark_winner = None;
303        self.supported = devices_for_with_policy(&self.graph, &self.policy);
304    }
305
306    pub fn set_policy(&mut self, policy: DevicePolicy) {
307        self.policy = policy.clone();
308        self.pick = policy.pick_strategy();
309        self.invalidate_cache();
310    }
311
312    fn ensure_supported(supported: &[Device], device: Device) -> Result<(), String> {
313        if !is_available(device) {
314            return Err(format!(
315                "device {device} is not available — enable the matching Cargo feature"
316            ));
317        }
318        if !supported.contains(&device) {
319            return Err(format!(
320                "device {device} cannot lower this graph under the active policy"
321            ));
322        }
323        Ok(())
324    }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::{DType, Shape};

    /// A single `[4]` f32 input wired straight to the output.
    fn identity_graph() -> Graph {
        let mut g = Graph::new("id");
        let x = g.input("x", Shape::new(&[4], DType::F32));
        g.set_outputs(vec![x]);
        g
    }

    /// `x + w` over `[2]` f32, where `w` is declared as a param.
    fn add_param_graph() -> Graph {
        let mut g = Graph::new("p");
        let x = g.input("x", Shape::new(&[2], DType::F32));
        let w = g.param("w", Shape::new(&[2], DType::F32));
        let y = g.binary(
            rlx_ir::op::BinaryOp::Add,
            x,
            w,
            Shape::new(&[2], DType::F32),
        );
        g.set_outputs(vec![y]);
        g
    }

    #[test]
    fn param_names_lists_declared_params() {
        assert_eq!(graph_param_names(&add_param_graph()), vec!["w".to_string()]);
        assert!(graph_param_names(&identity_graph()).is_empty());
    }

    #[test]
    fn set_param_applies_to_new_compile() {
        let mut runner = GraphDevices::new(add_param_graph());
        runner.set_param("w", &[1.0, 2.0]);
        let out = runner.run(Device::Cpu, &[("x", &[3.0, 4.0])]).unwrap();
        assert_eq!(out[0], vec![4.0, 6.0]);
    }

    #[test]
    fn set_param_updates_cached_executor() {
        let mut runner = GraphDevices::new(add_param_graph());
        runner.set_param("w", &[0.0, 0.0]);
        let before = runner.run(Device::Cpu, &[("x", &[3.0, 4.0])]).unwrap();
        assert_eq!(before[0], vec![3.0, 4.0]);

        // The CPU executor is cached now; the new value must still reach it.
        runner.set_param("w", &[1.0, 2.0]);
        let after = runner.run(Device::Cpu, &[("x", &[3.0, 4.0])]).unwrap();
        assert_eq!(after[0], vec![4.0, 6.0]);
    }

    #[test]
    fn run_on_cpu_roundtrip() {
        let mut runner = GraphDevices::new(identity_graph());
        let out = runner
            .run(Device::Cpu, &[("x", &[1.0, 2.0, 3.0, 4.0])])
            .expect("cpu run");
        assert_eq!(out[0], vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn run_try_falls_back_to_cpu() {
        let mut runner = GraphDevices::new(identity_graph());
        let chain = [Device::Cuda, Device::Cpu];
        let (dev, out) = runner
            .run_try(&chain, &[("x", &[1.0, 2.0, 3.0, 4.0])])
            .expect("fallback");
        assert_eq!(dev, Device::Cpu);
        assert_eq!(out[0][0], 1.0);
    }

    #[test]
    fn run_try_empty_chain_is_an_error() {
        let mut runner = GraphDevices::new(identity_graph());
        let err = runner.run_try(&[], &[]).unwrap_err();
        assert_eq!(err.attempts.len(), 1);
        assert_eq!(err.attempts[0].0, Device::Cpu);
    }
}