1use crate::Device;
28use rlx_ir::{Graph, Node, Op};
29
30pub trait BackendCostModel: Send + Sync {
32 fn device(&self) -> Device;
34
35 fn sgemm_gflops(&self, m: usize, k: usize, n: usize) -> f64;
39
40 fn dispatch_overhead_ns(&self) -> f64;
42
43 fn roundtrip_overhead_ns(&self) -> f64;
46
47 fn memory_bw(&self) -> f64;
49
50 fn num_threads(&self) -> usize;
52}
53
54pub fn estimate_graph_cost(graph: &Graph, model: &dyn BackendCostModel) -> f64 {
58 let mut total = model.roundtrip_overhead_ns();
59 for node in graph.nodes() {
60 total += node_cost(node, graph, model);
61 }
62 total
63}
64
65fn node_cost(node: &Node, graph: &Graph, model: &dyn BackendCostModel) -> f64 {
66 let dispatch = model.dispatch_overhead_ns();
67 match &node.op {
68 Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => 0.0,
69 Op::MatMul | Op::FusedMatMulBiasAct { .. } => {
70 let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
71 let total = node.shape.num_elements().unwrap_or(0);
72 let m = total / n.max(1);
73 let a_total = graph.node(node.inputs[0]).shape.num_elements().unwrap_or(0);
74 let k = a_total / m.max(1);
75 let flops = 2.0 * m as f64 * k as f64 * n as f64;
76 flops / (model.sgemm_gflops(m, k, n) + 1.0) + dispatch
77 }
78 Op::Attention {
79 num_heads,
80 head_dim,
81 ..
82 } => {
83 let q_shape = &graph.node(node.inputs[0]).shape;
84 let seq = q_shape.dim(q_shape.rank() - 2).unwrap_static();
85 let batch = q_shape.num_elements().unwrap_or(0) / (seq * num_heads * head_dim).max(1);
86 let flops = (batch * num_heads * seq * seq * head_dim * 2) as f64;
87 flops / (model.sgemm_gflops(seq, *head_dim, seq) + 1.0) + dispatch
88 }
89 _ => {
91 let bytes = node.shape.num_elements().unwrap_or(0) * 4;
92 (bytes as f64) / model.memory_bw().max(1.0) + dispatch
93 }
94 }
95}
96
97pub fn pick_best_device(graph: &Graph, models: &[&dyn BackendCostModel]) -> Device {
99 let mut best = (Device::Cpu, f64::INFINITY);
100 for &m in models {
101 let cost = estimate_graph_cost(graph, m);
102 if cost < best.1 {
103 best = (m.device(), cost);
104 }
105 }
106 best.0
107}
108
109pub fn fastest_device_for(graph: &Graph) -> Device {
111 fastest_device_for_with_policy(graph, &crate::device_policy::DevicePolicy::default())
112}
113
114pub fn fastest_device_for_with_policy(
116 graph: &Graph,
117 policy: &crate::device_policy::DevicePolicy,
118) -> Device {
119 let candidates = crate::device_policy::devices_for_with_policy(graph, policy);
120 if candidates.is_empty() {
121 return crate::device_ext::fastest_among(&policy.apply(crate::available_devices()));
122 }
123
124 #[cfg(feature = "cpu")]
125 let cpu = CpuCostModel::new();
126 #[cfg(feature = "metal")]
127 let metal = MetalCostModel::new();
128 #[cfg(all(feature = "mlx", rlx_mlx_host))]
129 let mlx = MlxCostModel::new();
130 #[cfg(feature = "cuda")]
131 let cuda = CudaCostModel::new();
132 #[cfg(feature = "rocm")]
133 let rocm = RocmCostModel::new();
134 #[cfg(feature = "gpu")]
135 let wgpu = WgpuCostModel::new();
136
137 let mut models: Vec<&dyn BackendCostModel> = Vec::new();
138 #[cfg(feature = "cpu")]
139 if candidates.contains(&Device::Cpu) {
140 models.push(&cpu);
141 }
142 #[cfg(feature = "metal")]
143 if candidates.contains(&Device::Metal) {
144 models.push(&metal);
145 }
146 #[cfg(all(feature = "mlx", rlx_mlx_host))]
147 if candidates.contains(&Device::Mlx) {
148 models.push(&mlx);
149 }
150 #[cfg(feature = "cuda")]
151 if candidates.contains(&Device::Cuda) {
152 models.push(&cuda);
153 }
154 #[cfg(feature = "rocm")]
155 if candidates.contains(&Device::Rocm) {
156 models.push(&rocm);
157 }
158 #[cfg(feature = "gpu")]
159 if candidates.contains(&Device::Gpu) {
160 models.push(&wgpu);
161 }
162
163 if models.len() >= 2 {
164 pick_best_device(graph, &models)
165 } else if let Some(m) = models.first() {
166 m.device()
167 } else {
168 crate::device_ext::fastest_among(&candidates)
169 }
170}
171
/// Cost model for the CPU backend; a thin wrapper around
/// `rlx_cpu::cost::HwModel`.
#[cfg(feature = "cpu")]
pub struct CpuCostModel(rlx_cpu::cost::HwModel);

#[cfg(feature = "cpu")]
impl CpuCostModel {
    /// Builds the model from `rlx_cpu::config::RuntimeConfig::global()`.
    pub fn new() -> Self {
        let hw = rlx_cpu::cost::HwModel::from_config(rlx_cpu::config::RuntimeConfig::global());
        Self(hw)
    }
}

#[cfg(feature = "cpu")]
impl Default for CpuCostModel {
    /// Same as [`CpuCostModel::new`].
    fn default() -> Self {
        Self::new()
    }
}
197
#[cfg(feature = "cpu")]
impl BackendCostModel for CpuCostModel {
    fn device(&self) -> Device {
        Device::Cpu
    }

    /// Throughput of the faster of the two CPU GEMM paths (NEON or BLAS).
    ///
    /// The time for `2 * m * k * n` FLOPs is computed for both paths (each
    /// rate clamped to at least `1.0`), the shorter one is kept, and the
    /// work is divided by that time scaled by `1e9`. Returns `0.0` when the
    /// problem has no work.
    fn sgemm_gflops(&self, m: usize, k: usize, n: usize) -> f64 {
        let hw = &self.0;
        let work = 2.0 * m as f64 * k as f64 * n as f64;
        let via_neon = work / hw.neon_flops.max(1.0);
        let via_blas = work / hw.blas_flops.max(1.0);
        let fastest = via_neon.min(via_blas);
        if fastest > 0.0 {
            work / (fastest * 1e9)
        } else {
            0.0
        }
    }

    /// Per-node overhead: the hardware model's BLAS call overhead.
    fn dispatch_overhead_ns(&self) -> f64 {
        self.0.blas_overhead_ns
    }

    /// Per-graph overhead: the hardware model's parallel-for overhead.
    fn roundtrip_overhead_ns(&self) -> f64 {
        self.0.par_for_overhead_ns
    }

    /// Memory bandwidth as reported by the hardware model.
    ///
    /// NOTE(review): `node_cost` treats this as GB/s — confirm that
    /// `HwModel::mem_bw` uses that unit.
    fn memory_bw(&self) -> f64 {
        self.0.mem_bw
    }

    fn num_threads(&self) -> usize {
        self.0.num_threads
    }
}
228
/// Cost model for the Metal backend.
#[cfg(feature = "metal")]
pub struct MetalCostModel {
    /// Throughput returned by `sgemm_gflops` for every problem size.
    sgemm_gflops_avg: f64,
    /// Per-graph round-trip overhead taken from calibration.
    roundtrip_ns: f64,
    /// Memory bandwidth; a fixed constant rather than a measured value.
    memory_bw: f64,
}

#[cfg(feature = "metal")]
impl MetalCostModel {
    /// Loads (or measures) the Metal calibration and keeps the highest of its
    /// three SGEMM throughput figures.
    ///
    /// NOTE(review): the calibration fields are named `*_flops` but end up
    /// behind a GFLOP/s accessor — confirm the units agree.
    pub fn new() -> Self {
        let cal = rlx_metal::calibrate::Calibration::load_or_measure();
        let simd_peak = cal.sgemm_simd_4x4_flops.max(cal.sgemm_simd_flops);
        let peak = simd_peak.max(cal.sgemm_padded_flops);
        Self {
            sgemm_gflops_avg: peak,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw: 200.0,
        }
    }
}

#[cfg(feature = "metal")]
impl Default for MetalCostModel {
    /// Same as [`MetalCostModel::new`].
    fn default() -> Self {
        Self::new()
    }
}
266
#[cfg(feature = "metal")]
impl BackendCostModel for MetalCostModel {
    fn device(&self) -> Device {
        Device::Metal
    }

    /// Size-independent: always the peak figure chosen at construction.
    fn sgemm_gflops(&self, _m: usize, _k: usize, _n: usize) -> f64 {
        self.sgemm_gflops_avg
    }

    /// Fixed estimate of 2000 ns per dispatched node.
    fn dispatch_overhead_ns(&self) -> f64 {
        2_000.0
    }

    /// Calibrated per-graph round-trip overhead.
    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    /// Always reports a single thread.
    fn num_threads(&self) -> usize {
        1
    }
}
289
/// Cost model for the MLX backend, with separate throughput figures for
/// small and large matrix products.
#[cfg(all(feature = "mlx", rlx_mlx_host))]
pub struct MlxCostModel {
    /// Throughput used for products at or above the small-problem threshold.
    sgemm_large_flops: f64,
    /// Throughput used for products below the small-problem threshold.
    sgemm_small_flops: f64,
    /// Per-graph round-trip overhead taken from calibration.
    roundtrip_ns: f64,
    /// Calibrated memory bandwidth, or a fixed default if calibration had none.
    memory_bw: f64,
}

#[cfg(all(feature = "mlx", rlx_mlx_host))]
impl MlxCostModel {
    /// Loads (or measures) the MLX calibration.
    ///
    /// A non-positive calibrated memory bandwidth is replaced by `200.0`.
    pub fn new() -> Self {
        let cal = rlx_mlx::calibrate::Calibration::load_or_measure();
        let memory_bw = match cal.memory_bw_gbps {
            measured if measured > 0.0 => measured,
            _ => 200.0,
        };
        Self {
            sgemm_large_flops: cal.sgemm_large_flops,
            sgemm_small_flops: cal.sgemm_small_flops,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw,
        }
    }
}

#[cfg(all(feature = "mlx", rlx_mlx_host))]
impl Default for MlxCostModel {
    /// Same as [`MlxCostModel::new`].
    fn default() -> Self {
        Self::new()
    }
}
330
#[cfg(all(feature = "mlx", rlx_mlx_host))]
impl BackendCostModel for MlxCostModel {
    fn device(&self) -> Device {
        Device::Mlx
    }

    /// Picks the small- or large-problem throughput based on `m * k * n`.
    fn sgemm_gflops(&self, m: usize, k: usize, n: usize) -> f64 {
        // Products below this volume use the small-problem figure.
        const SMALL_PROBLEM_LIMIT: f64 = 32_768.0;

        let volume = m as f64 * k as f64 * n as f64;
        if volume < SMALL_PROBLEM_LIMIT {
            self.sgemm_small_flops
        } else {
            self.sgemm_large_flops
        }
    }

    /// Fixed estimate of 2000 ns per dispatched node.
    fn dispatch_overhead_ns(&self) -> f64 {
        2_000.0
    }

    /// Calibrated per-graph round-trip overhead.
    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    /// Always reports a single thread.
    fn num_threads(&self) -> usize {
        1
    }
}
362
/// Cost model for the CUDA backend.
#[cfg(feature = "cuda")]
pub struct CudaCostModel {
    /// Size-independent GEMM throughput.
    sgemm_gflops: f64,
    /// Per-graph round-trip overhead.
    roundtrip_ns: f64,
    /// Memory bandwidth.
    memory_bw: f64,
}

#[cfg(feature = "cuda")]
impl CudaCostModel {
    /// Uses calibrated figures when a CUDA device is available, and fixed
    /// fallback constants otherwise.
    pub fn new() -> Self {
        if !crate::is_available(crate::Device::Cuda) {
            // No device to calibrate against: use built-in estimates.
            return Self {
                sgemm_gflops: 12_000.0,
                roundtrip_ns: 35_000.0,
                memory_bw: 900.0,
            };
        }
        let cal = rlx_cuda::calibrate::Calibration::load_or_measure();
        Self {
            sgemm_gflops: cal.sgemm_gflops,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw: cal.memory_bw_gbps,
        }
    }
}

#[cfg(feature = "cuda")]
impl Default for CudaCostModel {
    /// Same as [`CudaCostModel::new`].
    fn default() -> Self {
        Self::new()
    }
}
396
#[cfg(feature = "cuda")]
impl BackendCostModel for CudaCostModel {
    fn device(&self) -> Device {
        Device::Cuda
    }

    /// Size-independent: the problem dimensions are ignored.
    fn sgemm_gflops(&self, _m: usize, _k: usize, _n: usize) -> f64 {
        self.sgemm_gflops
    }

    /// Fixed estimate of 3000 ns per dispatched node.
    fn dispatch_overhead_ns(&self) -> f64 {
        3_000.0
    }

    /// Per-graph round-trip overhead chosen at construction.
    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    /// Always reports a single thread.
    fn num_threads(&self) -> usize {
        1
    }
}
418
/// Cost model for the ROCm backend.
#[cfg(feature = "rocm")]
pub struct RocmCostModel {
    /// Size-independent GEMM throughput.
    sgemm_gflops: f64,
    /// Per-graph round-trip overhead.
    roundtrip_ns: f64,
    /// Memory bandwidth.
    memory_bw: f64,
}

#[cfg(feature = "rocm")]
impl RocmCostModel {
    /// Uses calibrated figures when a ROCm device is available, and fixed
    /// fallback constants otherwise.
    pub fn new() -> Self {
        if !crate::is_available(crate::Device::Rocm) {
            // No device to calibrate against: use built-in estimates.
            return Self {
                sgemm_gflops: 10_000.0,
                roundtrip_ns: 40_000.0,
                memory_bw: 800.0,
            };
        }
        let cal = rlx_rocm::calibrate::Calibration::load_or_measure();
        Self {
            sgemm_gflops: cal.sgemm_gflops,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw: cal.memory_bw_gbps,
        }
    }
}

#[cfg(feature = "rocm")]
impl Default for RocmCostModel {
    /// Same as [`RocmCostModel::new`].
    fn default() -> Self {
        Self::new()
    }
}
452
#[cfg(feature = "rocm")]
impl BackendCostModel for RocmCostModel {
    fn device(&self) -> Device {
        Device::Rocm
    }

    /// Size-independent: the problem dimensions are ignored.
    fn sgemm_gflops(&self, _m: usize, _k: usize, _n: usize) -> f64 {
        self.sgemm_gflops
    }

    /// Fixed estimate of 3000 ns per dispatched node.
    fn dispatch_overhead_ns(&self) -> f64 {
        3_000.0
    }

    /// Per-graph round-trip overhead chosen at construction.
    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    /// Always reports a single thread.
    fn num_threads(&self) -> usize {
        1
    }
}
474
/// Cost model for the wgpu backend (`Device::Gpu`).
#[cfg(feature = "gpu")]
pub struct WgpuCostModel {
    /// Size-independent GEMM throughput.
    sgemm_gflops: f64,
    /// Per-graph round-trip overhead.
    roundtrip_ns: f64,
    /// Memory bandwidth.
    memory_bw: f64,
}

#[cfg(feature = "gpu")]
impl WgpuCostModel {
    /// Uses calibrated figures when `rlx_wgpu` reports itself available, and
    /// fixed fallback constants otherwise.
    pub fn new() -> Self {
        if !rlx_wgpu::is_available() {
            // No adapter to calibrate against: use built-in estimates.
            return Self {
                sgemm_gflops: 2_500.0,
                roundtrip_ns: 80_000.0,
                memory_bw: 120.0,
            };
        }
        let cal = rlx_wgpu::calibrate::Calibration::load_or_measure();
        Self {
            sgemm_gflops: cal.sgemm_gflops,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw: cal.memory_bw_gbps,
        }
    }
}

#[cfg(feature = "gpu")]
impl Default for WgpuCostModel {
    /// Same as [`WgpuCostModel::new`].
    fn default() -> Self {
        Self::new()
    }
}
508
#[cfg(feature = "gpu")]
impl BackendCostModel for WgpuCostModel {
    fn device(&self) -> Device {
        Device::Gpu
    }

    /// Size-independent: the problem dimensions are ignored.
    fn sgemm_gflops(&self, _m: usize, _k: usize, _n: usize) -> f64 {
        self.sgemm_gflops
    }

    /// Fixed estimate of 5000 ns per dispatched node.
    fn dispatch_overhead_ns(&self) -> f64 {
        5_000.0
    }

    /// Per-graph round-trip overhead chosen at construction.
    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    /// Always reports a single thread.
    fn num_threads(&self) -> usize {
        1
    }
}
530
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::{DType, Graph, Shape};

    /// For a single 4x4 matmul, the chosen device must be available on this
    /// host and be one of the devices reported as able to run the graph.
    #[test]
    fn fastest_device_for_falls_back_to_cpu_for_simple_graph() {
        let mut graph = Graph::new("mm");
        let lhs = graph.input("x", Shape::new(&[4, 4], DType::F32));
        let rhs = graph.param("w", Shape::new(&[4, 4], DType::F32));
        let product = graph.matmul(lhs, rhs, Shape::new(&[4, 4], DType::F32));
        graph.set_outputs(vec![product]);

        let chosen = fastest_device_for(&graph);

        assert!(crate::is_available(chosen));
        assert!(crate::devices_for(&graph).contains(&chosen));
    }
}