1use std::collections::HashMap;
9
10use rlx_driver::Device;
11use rlx_ir::{DType, Graph, Op};
12
13use crate::compiled::CompiledGraph;
14use crate::cost::fastest_device_for_with_policy;
15use crate::device_bench::{DeviceBenchResult, benchmark_devices, warm_all};
16use crate::device_ext::is_available;
17use crate::device_policy::{
18 DeviceCandidate, DeviceFallbackError, DevicePickStrategy, DevicePolicy, device_chain_from_env,
19 device_report, devices_for_with_policy, resolve_device, resolve_device_chain,
20};
21use crate::session::Session;
22
23pub fn graph_param_names(graph: &Graph) -> Vec<String> {
25 graph
26 .nodes()
27 .iter()
28 .filter_map(|n| match &n.op {
29 Op::Param { name } => Some(name.clone()),
30 _ => None,
31 })
32 .collect()
33}
34
/// A parameter value retained by the runner so it can be (re-)applied to
/// compiled graphs.
#[derive(Debug, Clone)]
enum CachedParam {
    /// Parameter supplied as a slice of `f32` values.
    F32(Vec<f32>),
    /// Parameter supplied as raw bytes together with the `DType` they encode.
    Typed { bytes: Vec<u8>, dtype: DType },
}
40
41fn apply_cached_params(compiled: &mut CompiledGraph, params: &HashMap<String, CachedParam>) {
42 for (name, param) in params {
43 match param {
44 CachedParam::F32(data) => compiled.set_param(name, data),
45 CachedParam::Typed { bytes, dtype } => compiled.set_param_typed(name, bytes, *dtype),
46 }
47 }
48}
49
/// A graph together with its device policy, cached parameters, and one
/// compiled instance per device it has been compiled for.
pub struct GraphDevices {
    /// The graph that gets compiled for each device.
    graph: Graph,
    /// Active device policy.
    policy: DevicePolicy,
    /// Pick strategy obtained from `policy.pick_strategy()`.
    pick: DevicePickStrategy,
    /// Devices returned by `devices_for_with_policy` for `graph` and `policy`.
    supported: Vec<Device>,
    /// Parameter values by name, replayed onto every newly compiled graph.
    params: HashMap<String, CachedParam>,
    /// Device chosen by the last benchmark run, if any; cleared on invalidation.
    benchmark_winner: Option<Device>,
    /// Compiled graphs keyed by the device they were compiled for.
    cache: HashMap<Device, CompiledGraph>,
}
60
61impl GraphDevices {
62 pub fn new(graph: Graph) -> Self {
63 Self::with_policy(graph, DevicePolicy::default())
64 }
65
66 pub fn with_policy(graph: Graph, policy: DevicePolicy) -> Self {
67 let pick = policy.pick_strategy();
68 let supported = devices_for_with_policy(&graph, &policy);
69 Self {
70 graph,
71 policy,
72 pick,
73 supported,
74 params: HashMap::new(),
75 benchmark_winner: None,
76 cache: HashMap::new(),
77 }
78 }
79
80 pub fn from_env(graph: Graph) -> Self {
81 Self::with_policy(graph, DevicePolicy::from_env())
82 }
83
84 pub fn policy(&self) -> &DevicePolicy {
85 &self.policy
86 }
87
88 pub fn graph(&self) -> &Graph {
89 &self.graph
90 }
91
92 pub fn devices(&self) -> &[Device] {
93 &self.supported
94 }
95
96 pub fn report(&self) -> Vec<DeviceCandidate> {
97 device_report(&self.graph, &self.policy)
98 }
99
100 pub fn fastest(&self) -> Device {
101 fastest_device_for_with_policy(&self.graph, &self.policy)
102 }
103
104 pub fn resolve(&self, hint: Option<Device>) -> Result<Device, String> {
105 resolve_device(&self.graph, hint, &self.policy)
106 }
107
108 pub fn resolve_chain(&self, hint: Option<Device>) -> Result<Device, String> {
110 if let Some(device) = hint {
111 return self.resolve(Some(device));
112 }
113 let chain = device_chain_from_env();
114 if chain.is_empty() {
115 return self.resolve(None);
116 }
117 resolve_device_chain(&self.graph, &chain, &self.policy)
118 }
119
120 pub fn set_param(&mut self, name: &str, data: &[f32]) {
122 self.params
123 .insert(name.to_string(), CachedParam::F32(data.to_vec()));
124 for compiled in self.cache.values_mut() {
125 compiled.set_param(name, data);
126 }
127 }
128
129 pub fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: DType) {
131 self.params.insert(
132 name.to_string(),
133 CachedParam::Typed {
134 bytes: data.to_vec(),
135 dtype,
136 },
137 );
138 for compiled in self.cache.values_mut() {
139 compiled.set_param_typed(name, data, dtype);
140 }
141 }
142
143 pub fn sync_params_to_all(&mut self) {
145 for compiled in self.cache.values_mut() {
146 apply_cached_params(compiled, &self.params);
147 }
148 }
149
150 pub fn resolve_with_inputs(
152 &mut self,
153 hint: Option<Device>,
154 inputs: &[(&str, &[f32])],
155 ) -> Result<Device, String> {
156 if hint.is_some() {
157 return self.resolve(hint);
158 }
159 match self.pick {
160 DevicePickStrategy::CostModel => self.resolve(None),
161 DevicePickStrategy::Benchmark { runs } => {
162 if let Some(device) = self.benchmark_winner {
163 return Ok(device);
164 }
165 let ranked = self.benchmark(inputs, runs)?;
166 let device = ranked
167 .first()
168 .map(|r| r.device)
169 .unwrap_or_else(|| self.fastest());
170 self.benchmark_winner = Some(device);
171 Ok(device)
172 }
173 }
174 }
175
176 pub fn compile(&mut self, device: Device) -> Result<&mut CompiledGraph, String> {
177 Self::ensure_supported(&self.supported, device)?;
178 if !self.cache.contains_key(&device) {
179 let mut compiled = Session::new(device).compile(self.graph.clone());
180 apply_cached_params(&mut compiled, &self.params);
181 self.cache.insert(device, compiled);
182 }
183 Ok(self.cache.get_mut(&device).expect("just inserted"))
184 }
185
186 pub fn compile_fastest(&mut self) -> Result<&mut CompiledGraph, String> {
187 self.compile(self.fastest())
188 }
189
190 pub fn compile_resolved(&mut self, hint: Option<Device>) -> Result<&mut CompiledGraph, String> {
191 self.compile(self.resolve(hint)?)
192 }
193
194 pub fn compile_chain(&mut self, hint: Option<Device>) -> Result<&mut CompiledGraph, String> {
195 self.compile(self.resolve_chain(hint)?)
196 }
197
198 pub fn warm_all(&mut self) -> Result<Vec<Device>, String> {
199 warm_all(self)
200 }
201
202 pub fn benchmark(
203 &mut self,
204 inputs: &[(&str, &[f32])],
205 runs: usize,
206 ) -> Result<Vec<DeviceBenchResult>, String> {
207 benchmark_devices(self, inputs, runs)
208 }
209
210 pub fn run(
211 &mut self,
212 device: Device,
213 inputs: &[(&str, &[f32])],
214 ) -> Result<Vec<Vec<f32>>, String> {
215 Ok(self.compile(device)?.run(inputs))
216 }
217
218 pub fn run_resolved(
219 &mut self,
220 hint: Option<Device>,
221 inputs: &[(&str, &[f32])],
222 ) -> Result<Vec<Vec<f32>>, String> {
223 Ok(self.compile_resolved(hint)?.run(inputs))
224 }
225
226 pub fn run_fastest(&mut self, inputs: &[(&str, &[f32])]) -> Result<Vec<Vec<f32>>, String> {
227 Ok(self.compile_fastest()?.run(inputs))
228 }
229
230 pub fn run_try(
232 &mut self,
233 chain: &[Device],
234 inputs: &[(&str, &[f32])],
235 ) -> Result<(Device, Vec<Vec<f32>>), DeviceFallbackError> {
236 let viable: Vec<Device> = self.devices().to_vec();
237 let mut attempts = Vec::new();
238 for &device in chain {
239 if !viable.contains(&device) {
240 attempts.push((device, "not viable for this graph under policy".into()));
241 continue;
242 }
243 match self.run(device, inputs) {
244 Ok(value) => return Ok((device, value)),
245 Err(err) => attempts.push((device, err)),
246 }
247 }
248 if attempts.is_empty() {
249 attempts.push((Device::Cpu, "empty fallback chain".into()));
250 }
251 Err(DeviceFallbackError { attempts })
252 }
253
254 pub fn run_chain(
256 &mut self,
257 hint: Option<Device>,
258 inputs: &[(&str, &[f32])],
259 ) -> Result<(Device, Vec<Vec<f32>>), DeviceFallbackError> {
260 if let Some(device) = hint {
261 self.run(device, inputs)
262 .map(|v| (device, v))
263 .map_err(|e| DeviceFallbackError {
264 attempts: vec![(device, e)],
265 })
266 } else {
267 let chain = device_chain_from_env();
268 if chain.is_empty() {
269 let device = self.resolve(None).map_err(|e| DeviceFallbackError {
270 attempts: vec![(Device::Cpu, e)],
271 })?;
272 self.run(device, inputs)
273 .map(|v| (device, v))
274 .map_err(|e| DeviceFallbackError {
275 attempts: vec![(device, e)],
276 })
277 } else {
278 self.run_try(&chain, inputs)
279 }
280 }
281 }
282
283 pub fn compile_resolved_with_inputs(
284 &mut self,
285 hint: Option<Device>,
286 inputs: &[(&str, &[f32])],
287 ) -> Result<&mut CompiledGraph, String> {
288 let device = self.resolve_with_inputs(hint, inputs)?;
289 self.compile(device)
290 }
291
292 pub fn run_resolved_with_inputs(
293 &mut self,
294 hint: Option<Device>,
295 inputs: &[(&str, &[f32])],
296 ) -> Result<Vec<Vec<f32>>, String> {
297 Ok(self.compile_resolved_with_inputs(hint, inputs)?.run(inputs))
298 }
299
300 pub fn invalidate_cache(&mut self) {
301 self.cache.clear();
302 self.benchmark_winner = None;
303 self.supported = devices_for_with_policy(&self.graph, &self.policy);
304 }
305
306 pub fn set_policy(&mut self, policy: DevicePolicy) {
307 self.policy = policy.clone();
308 self.pick = policy.pick_strategy();
309 self.invalidate_cache();
310 }
311
312 fn ensure_supported(supported: &[Device], device: Device) -> Result<(), String> {
313 if !is_available(device) {
314 return Err(format!(
315 "device {device} is not available — enable the matching Cargo feature"
316 ));
317 }
318 if !supported.contains(&device) {
319 return Err(format!(
320 "device {device} cannot lower this graph under the active policy"
321 ));
322 }
323 Ok(())
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::{DType, Shape};

    /// Builds a graph named "id" whose only output is its input `x`
    /// (declared with shape `[4]`, `F32`).
    fn identity_graph() -> Graph {
        let mut graph = Graph::new("id");
        let input = graph.input("x", Shape::new(&[4], DType::F32));
        graph.set_outputs(vec![input]);
        graph
    }

    /// A parameter set before the first compile must reach the compiled graph.
    #[test]
    fn set_param_applies_to_new_compile() {
        let pair = || Shape::new(&[2], DType::F32);

        let mut graph = Graph::new("p");
        let input = graph.input("x", pair());
        let weight = graph.param("w", pair());
        let sum = graph.binary(rlx_ir::op::BinaryOp::Add, input, weight, pair());
        graph.set_outputs(vec![sum]);

        let mut devices = GraphDevices::new(graph);
        devices.set_param("w", &[1.0, 2.0]);
        let out = devices.run(Device::Cpu, &[("x", &[3.0, 4.0])]).unwrap();
        assert_eq!(out[0], vec![4.0, 6.0]);
    }

    /// The identity graph returns its input unchanged on the CPU.
    #[test]
    fn run_on_cpu_roundtrip() {
        let mut devices = GraphDevices::new(identity_graph());
        let out = devices
            .run(Device::Cpu, &[("x", &[1.0, 2.0, 3.0, 4.0])])
            .expect("cpu run");
        assert_eq!(out[0], vec![1.0, 2.0, 3.0, 4.0]);
    }

    /// With `[Cuda, Cpu]` as the chain, the run is expected to land on the CPU.
    #[test]
    fn run_try_falls_back_to_cpu() {
        let mut devices = GraphDevices::new(identity_graph());
        let chain = [Device::Cuda, Device::Cpu];
        let (used, out) = devices
            .run_try(&chain, &[("x", &[1.0, 2.0, 3.0, 4.0])])
            .expect("fallback");
        assert_eq!(used, Device::Cpu);
        assert_eq!(out[0][0], 1.0);
    }
}