rlx_runtime/
device_bench.rs1use rlx_driver::Device;
9use rlx_ir::Tick;
10
11use crate::device_parse::device_label;
12use crate::graph_devices::GraphDevices;
13
14#[derive(Debug, Clone, PartialEq)]
16pub struct DeviceBenchResult {
17 pub device: Device,
18 pub label: &'static str,
19 pub compile_ns: u64,
20 pub median_exec_ns: u64,
21}
22
23pub fn warm_all(runner: &mut GraphDevices) -> Result<Vec<Device>, String> {
25 let devices: Vec<Device> = runner.devices().to_vec();
26 for device in &devices {
27 runner.compile(*device)?;
28 }
29 Ok(devices)
30}
31
32pub fn benchmark_devices(
34 runner: &mut GraphDevices,
35 inputs: &[(&str, &[f32])],
36 runs: usize,
37) -> Result<Vec<DeviceBenchResult>, String> {
38 let mut results = Vec::new();
39 let devices: Vec<Device> = runner.devices().to_vec();
40 for device in devices {
41 let t0 = Tick::now();
42 runner.compile(device)?;
43 let compile_ns = Tick::now().elapsed_ns(t0);
44 let mut samples = Vec::with_capacity(runs.max(1));
45 for _ in 0..runs.max(1) {
46 let t1 = Tick::now();
47 runner.run(device, inputs)?;
48 samples.push(Tick::now().elapsed_ns(t1));
49 }
50 samples.sort_unstable();
51 let median_exec_ns = samples[samples.len() / 2];
52 results.push(DeviceBenchResult {
53 device,
54 label: device_label(device),
55 compile_ns,
56 median_exec_ns,
57 });
58 }
59 results.sort_by_key(|r| r.median_exec_ns);
60 Ok(results)
61}