1use crate::Device;
28use rlx_ir::{Graph, Node, Op};
29
30pub trait BackendCostModel: Send + Sync {
32 fn device(&self) -> Device;
34
35 fn sgemm_gflops(&self, m: usize, k: usize, n: usize) -> f64;
39
40 fn dispatch_overhead_ns(&self) -> f64;
42
43 fn roundtrip_overhead_ns(&self) -> f64;
46
47 fn memory_bw(&self) -> f64;
49
50 fn host_readback_bw(&self) -> f64 {
52 self.memory_bw()
53 }
54
55 fn unified_memory(&self) -> bool {
57 false
58 }
59
60 fn num_threads(&self) -> usize;
62}
63
64pub fn estimate_graph_cost(graph: &Graph, model: &dyn BackendCostModel) -> f64 {
68 estimate_graph_cost_with_io(graph, model, &crate::graph_io::profile_graph_io(graph))
69}
70
71pub fn estimate_graph_cost_with_io(
73 graph: &Graph,
74 model: &dyn BackendCostModel,
75 io: &crate::graph_io::GraphIoProfile,
76) -> f64 {
77 let mut total = model.roundtrip_overhead_ns();
78 for node in graph.nodes() {
79 total += node_cost(node, graph, model);
80 }
81 total += io.device_traffic_bytes as f64 / model.memory_bw().max(1.0);
82 total +=
83 io.host_readback_bytes(model.unified_memory()) as f64 / model.host_readback_bw().max(1.0);
84 total += io.sync_points as f64 * model.roundtrip_overhead_ns();
85 total
86}
87
88fn node_cost(node: &Node, graph: &Graph, model: &dyn BackendCostModel) -> f64 {
89 let dispatch = model.dispatch_overhead_ns();
90 match &node.op {
91 Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => 0.0,
92 Op::MatMul | Op::FusedMatMulBiasAct { .. } => {
93 let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
94 let total = node.shape.num_elements().unwrap_or(0);
95 let m = total / n.max(1);
96 let a_total = graph.node(node.inputs[0]).shape.num_elements().unwrap_or(0);
97 let k = a_total / m.max(1);
98 let flops = 2.0 * m as f64 * k as f64 * n as f64;
99 flops / (model.sgemm_gflops(m, k, n) + 1.0) + dispatch
100 }
101 Op::Attention {
102 num_heads,
103 head_dim,
104 ..
105 } => {
106 let q_shape = &graph.node(node.inputs[0]).shape;
107 let seq = q_shape.dim(q_shape.rank() - 2).unwrap_static();
108 let batch = q_shape.num_elements().unwrap_or(0) / (seq * num_heads * head_dim).max(1);
109 let flops = (batch * num_heads * seq * seq * head_dim * 2) as f64;
110 flops / (model.sgemm_gflops(seq, *head_dim, seq) + 1.0) + dispatch
111 }
112 _ => {
114 let bytes = node.shape.num_elements().unwrap_or(0) * 4;
115 (bytes as f64) / model.memory_bw().max(1.0) + dispatch
116 }
117 }
118}
119
120pub fn pick_best_device(graph: &Graph, models: &[&dyn BackendCostModel]) -> Device {
122 let mut best = (Device::Cpu, f64::INFINITY);
123 for &m in models {
124 let cost = estimate_graph_cost(graph, m);
125 if cost < best.1 {
126 best = (m.device(), cost);
127 }
128 }
129 best.0
130}
131
132pub fn fastest_device_for(graph: &Graph) -> Device {
134 fastest_device_for_with_policy(graph, &crate::device_policy::DevicePolicy::default())
135}
136
137pub fn fastest_device_for_with_policy(
139 graph: &Graph,
140 policy: &crate::device_policy::DevicePolicy,
141) -> Device {
142 let candidates = crate::device_policy::devices_for_with_policy(graph, policy);
143 if candidates.is_empty() {
144 return crate::device_ext::fastest_among(&policy.apply(crate::available_devices()));
145 }
146
147 #[cfg(feature = "cpu")]
148 let cpu = CpuCostModel::new();
149 #[cfg(feature = "metal")]
150 let metal = MetalCostModel::new();
151 #[cfg(all(feature = "mlx", rlx_mlx_host))]
152 let mlx = MlxCostModel::new();
153 #[cfg(feature = "cuda")]
154 let cuda = CudaCostModel::new();
155 #[cfg(feature = "rocm")]
156 let rocm = RocmCostModel::new();
157 #[cfg(feature = "gpu")]
158 let wgpu = WgpuCostModel::new();
159
160 let mut models: Vec<&dyn BackendCostModel> = Vec::new();
161 #[cfg(feature = "cpu")]
162 if candidates.contains(&Device::Cpu) {
163 models.push(&cpu);
164 }
165 #[cfg(feature = "metal")]
166 if candidates.contains(&Device::Metal) {
167 models.push(&metal);
168 }
169 #[cfg(all(feature = "mlx", rlx_mlx_host))]
170 if candidates.contains(&Device::Mlx) {
171 models.push(&mlx);
172 }
173 #[cfg(feature = "cuda")]
174 if candidates.contains(&Device::Cuda) {
175 models.push(&cuda);
176 }
177 #[cfg(feature = "rocm")]
178 if candidates.contains(&Device::Rocm) {
179 models.push(&rocm);
180 }
181 #[cfg(feature = "gpu")]
182 if candidates.contains(&Device::Gpu) {
183 models.push(&wgpu);
184 }
185
186 if models.len() >= 2 {
187 pick_best_device(graph, &models)
188 } else if let Some(m) = models.first() {
189 m.device()
190 } else {
191 crate::device_ext::fastest_among(&candidates)
192 }
193}
194
/// CPU cost model: a thin wrapper around `rlx_cpu::cost::HwModel`.
#[cfg(feature = "cpu")]
pub struct CpuCostModel(rlx_cpu::cost::HwModel);

#[cfg(feature = "cpu")]
impl CpuCostModel {
    /// Builds the model from the global `rlx_cpu` runtime configuration.
    pub fn new() -> Self {
        let hw = rlx_cpu::cost::HwModel::from_config(rlx_cpu::config::RuntimeConfig::global());
        Self(hw)
    }
}

#[cfg(feature = "cpu")]
impl Default for CpuCostModel {
    /// Same as [`CpuCostModel::new`].
    fn default() -> Self {
        CpuCostModel::new()
    }
}
220
#[cfg(feature = "cpu")]
impl BackendCostModel for CpuCostModel {
    fn device(&self) -> Device {
        Device::Cpu
    }

    /// Throughput of the faster of the two CPU paths (`neon_flops` vs
    /// `blas_flops`) for an `m x k x n` product.
    ///
    /// Each rate is clamped to at least `1.0`. The shorter of the two
    /// resulting times is converted back to a rate and scaled by `1e9`.
    /// Returns `0.0` when the problem has no work (zero time).
    fn sgemm_gflops(&self, m: usize, k: usize, n: usize) -> f64 {
        let hw = &self.0;
        let work = 2.0 * m as f64 * k as f64 * n as f64;
        let via_neon = work / hw.neon_flops.max(1.0);
        let via_blas = work / hw.blas_flops.max(1.0);
        let fastest = via_neon.min(via_blas);
        if fastest > 0.0 {
            work / (fastest * 1e9)
        } else {
            0.0
        }
    }

    /// Per-node dispatch cost: the hardware model's `blas_overhead_ns`.
    fn dispatch_overhead_ns(&self) -> f64 {
        self.0.blas_overhead_ns
    }

    /// Round-trip cost: the hardware model's `par_for_overhead_ns`.
    fn roundtrip_overhead_ns(&self) -> f64 {
        self.0.par_for_overhead_ns
    }

    /// Memory bandwidth: the hardware model's `mem_bw`.
    fn memory_bw(&self) -> f64 {
        self.0.mem_bw
    }

    /// Thread count: the hardware model's `num_threads`.
    fn num_threads(&self) -> usize {
        self.0.num_threads
    }
}
251
/// Metal cost model backed by `rlx_metal` calibration data.
#[cfg(feature = "metal")]
pub struct MetalCostModel {
    /// Best SGEMM rate among the calibrated kernels.
    sgemm_gflops_avg: f64,
    /// Calibrated round-trip overhead.
    roundtrip_ns: f64,
    /// Memory bandwidth; a fixed value rather than a calibrated one.
    memory_bw: f64,
}

#[cfg(feature = "metal")]
impl MetalCostModel {
    /// Loads (or measures) the Metal calibration and keeps the highest of the
    /// three SGEMM kernel rates it reports.
    pub fn new() -> Self {
        let cal = rlx_metal::calibrate::Calibration::load_or_measure();
        let simd_best = f64::max(cal.sgemm_simd_4x4_flops, cal.sgemm_simd_flops);
        let sgemm_gflops_avg = f64::max(simd_best, cal.sgemm_padded_flops);
        Self {
            sgemm_gflops_avg,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            // Not taken from calibration; presumably GB/s — TODO confirm.
            memory_bw: 200.0,
        }
    }
}

#[cfg(feature = "metal")]
impl Default for MetalCostModel {
    /// Same as [`MetalCostModel::new`].
    fn default() -> Self {
        MetalCostModel::new()
    }
}
289
#[cfg(feature = "metal")]
impl BackendCostModel for MetalCostModel {
    fn device(&self) -> Device {
        Device::Metal
    }

    /// Always `1`.
    fn num_threads(&self) -> usize {
        1
    }

    /// Metal reports unified memory, overriding the trait default of `false`.
    fn unified_memory(&self) -> bool {
        true
    }

    /// Shape-independent: the same stored rate is returned for every problem size.
    fn sgemm_gflops(&self, _m: usize, _k: usize, _n: usize) -> f64 {
        self.sgemm_gflops_avg
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    /// Fixed per-node dispatch cost.
    fn dispatch_overhead_ns(&self) -> f64 {
        2_000.0
    }

    /// Round-trip cost taken from calibration.
    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }
}
315
/// MLX cost model backed by `rlx_mlx` calibration data.
#[cfg(all(feature = "mlx", rlx_mlx_host))]
pub struct MlxCostModel {
    /// Calibrated SGEMM rate used for large problems.
    sgemm_large_flops: f64,
    /// Calibrated SGEMM rate used for small problems.
    sgemm_small_flops: f64,
    /// Calibrated round-trip overhead.
    roundtrip_ns: f64,
    /// Calibrated memory bandwidth, or `200.0` when calibration gave none.
    memory_bw: f64,
}

#[cfg(all(feature = "mlx", rlx_mlx_host))]
impl MlxCostModel {
    /// Loads (or measures) the MLX calibration.
    ///
    /// A non-positive (or NaN) calibrated bandwidth is replaced by `200.0`.
    pub fn new() -> Self {
        let cal = rlx_mlx::calibrate::Calibration::load_or_measure();
        let memory_bw = match cal.memory_bw_gbps {
            measured if measured > 0.0 => measured,
            _ => 200.0,
        };
        Self {
            sgemm_large_flops: cal.sgemm_large_flops,
            sgemm_small_flops: cal.sgemm_small_flops,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw,
        }
    }
}

#[cfg(all(feature = "mlx", rlx_mlx_host))]
impl Default for MlxCostModel {
    /// Same as [`MlxCostModel::new`].
    fn default() -> Self {
        MlxCostModel::new()
    }
}
356
// NOTE(review): `unified_memory` is not overridden here, so MLX uses the trait
// default of `false` while the Metal model returns `true` — confirm this is intended.
#[cfg(all(feature = "mlx", rlx_mlx_host))]
impl BackendCostModel for MlxCostModel {
    fn device(&self) -> Device {
        Device::Mlx
    }

    /// Always `1`.
    fn num_threads(&self) -> usize {
        1
    }

    /// Chooses between the small- and large-problem rates by the volume `m * k * n`.
    fn sgemm_gflops(&self, m: usize, k: usize, n: usize) -> f64 {
        // Volumes strictly below this use the small-problem rate.
        const SMALL_PROBLEM_LIMIT: f64 = 32_768.0;
        let volume = m as f64 * k as f64 * n as f64;
        if volume < SMALL_PROBLEM_LIMIT {
            self.sgemm_small_flops
        } else {
            self.sgemm_large_flops
        }
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    /// Fixed per-node dispatch cost.
    fn dispatch_overhead_ns(&self) -> f64 {
        2_000.0
    }

    /// Round-trip cost taken from calibration.
    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }
}
388
/// CUDA cost model: calibrated when the device is available, otherwise built
/// from fixed fallback figures.
#[cfg(feature = "cuda")]
pub struct CudaCostModel {
    /// SGEMM rate reported for every problem size.
    sgemm_gflops: f64,
    /// Round-trip overhead.
    roundtrip_ns: f64,
    /// Memory bandwidth.
    memory_bw: f64,
}

#[cfg(feature = "cuda")]
impl CudaCostModel {
    /// Uses `rlx_cuda` calibration when `crate::is_available` reports CUDA,
    /// and hard-coded defaults otherwise.
    pub fn new() -> Self {
        if !crate::is_available(crate::Device::Cuda) {
            // No device to calibrate against: fall back to fixed estimates.
            return Self {
                sgemm_gflops: 12_000.0,
                roundtrip_ns: 35_000.0,
                memory_bw: 900.0,
            };
        }
        let cal = rlx_cuda::calibrate::Calibration::load_or_measure();
        Self {
            sgemm_gflops: cal.sgemm_gflops,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw: cal.memory_bw_gbps,
        }
    }
}

#[cfg(feature = "cuda")]
impl Default for CudaCostModel {
    /// Same as [`CudaCostModel::new`].
    fn default() -> Self {
        CudaCostModel::new()
    }
}
422
#[cfg(feature = "cuda")]
impl BackendCostModel for CudaCostModel {
    fn device(&self) -> Device {
        Device::Cuda
    }

    /// Always `1`.
    fn num_threads(&self) -> usize {
        1
    }

    /// Shape-independent: the same stored rate is returned for every problem size.
    fn sgemm_gflops(&self, _m: usize, _k: usize, _n: usize) -> f64 {
        self.sgemm_gflops
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    /// Fixed per-node dispatch cost.
    fn dispatch_overhead_ns(&self) -> f64 {
        3_000.0
    }

    /// Round-trip cost stored at construction time.
    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }
}
444
/// ROCm cost model: calibrated when the device is available, otherwise built
/// from fixed fallback figures.
#[cfg(feature = "rocm")]
pub struct RocmCostModel {
    /// SGEMM rate reported for every problem size.
    sgemm_gflops: f64,
    /// Round-trip overhead.
    roundtrip_ns: f64,
    /// Memory bandwidth.
    memory_bw: f64,
}

#[cfg(feature = "rocm")]
impl RocmCostModel {
    /// Uses `rlx_rocm` calibration when `crate::is_available` reports ROCm,
    /// and hard-coded defaults otherwise.
    pub fn new() -> Self {
        if !crate::is_available(crate::Device::Rocm) {
            // No device to calibrate against: fall back to fixed estimates.
            return Self {
                sgemm_gflops: 10_000.0,
                roundtrip_ns: 40_000.0,
                memory_bw: 800.0,
            };
        }
        let cal = rlx_rocm::calibrate::Calibration::load_or_measure();
        Self {
            sgemm_gflops: cal.sgemm_gflops,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw: cal.memory_bw_gbps,
        }
    }
}

#[cfg(feature = "rocm")]
impl Default for RocmCostModel {
    /// Same as [`RocmCostModel::new`].
    fn default() -> Self {
        RocmCostModel::new()
    }
}
478
#[cfg(feature = "rocm")]
impl BackendCostModel for RocmCostModel {
    fn device(&self) -> Device {
        Device::Rocm
    }

    /// Always `1`.
    fn num_threads(&self) -> usize {
        1
    }

    /// Shape-independent: the same stored rate is returned for every problem size.
    fn sgemm_gflops(&self, _m: usize, _k: usize, _n: usize) -> f64 {
        self.sgemm_gflops
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    /// Fixed per-node dispatch cost.
    fn dispatch_overhead_ns(&self) -> f64 {
        3_000.0
    }

    /// Round-trip cost stored at construction time.
    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }
}
500
/// wgpu cost model: calibrated when the backend is available, otherwise built
/// from fixed fallback figures.
#[cfg(feature = "gpu")]
pub struct WgpuCostModel {
    /// SGEMM rate reported for every problem size.
    sgemm_gflops: f64,
    /// Round-trip overhead.
    roundtrip_ns: f64,
    /// Memory bandwidth.
    memory_bw: f64,
}

#[cfg(feature = "gpu")]
impl WgpuCostModel {
    /// Uses `rlx_wgpu` calibration when `rlx_wgpu::is_available()` is true,
    /// and hard-coded defaults otherwise.
    pub fn new() -> Self {
        if !rlx_wgpu::is_available() {
            // No adapter to calibrate against: fall back to fixed estimates.
            return Self {
                sgemm_gflops: 2_500.0,
                roundtrip_ns: 80_000.0,
                memory_bw: 120.0,
            };
        }
        let cal = rlx_wgpu::calibrate::Calibration::load_or_measure();
        Self {
            sgemm_gflops: cal.sgemm_gflops,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw: cal.memory_bw_gbps,
        }
    }
}

#[cfg(feature = "gpu")]
impl Default for WgpuCostModel {
    /// Same as [`WgpuCostModel::new`].
    fn default() -> Self {
        WgpuCostModel::new()
    }
}
534
#[cfg(feature = "gpu")]
impl BackendCostModel for WgpuCostModel {
    fn device(&self) -> Device {
        Device::Gpu
    }

    /// Always `1`.
    fn num_threads(&self) -> usize {
        1
    }

    /// Shape-independent: the same stored rate is returned for every problem size.
    fn sgemm_gflops(&self, _m: usize, _k: usize, _n: usize) -> f64 {
        self.sgemm_gflops
    }

    fn memory_bw(&self) -> f64 {
        self.memory_bw
    }

    /// Fixed per-node dispatch cost.
    fn dispatch_overhead_ns(&self) -> f64 {
        5_000.0
    }

    /// Round-trip cost stored at construction time.
    fn roundtrip_overhead_ns(&self) -> f64 {
        self.roundtrip_ns
    }
}
556
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::{DType, Graph, Shape};

    /// A single 4x4 matmul must resolve to a device that is both available and
    /// listed by `crate::devices_for` for that graph.
    ///
    /// Despite the name, the assertions do not require the pick to be the CPU.
    #[test]
    fn fastest_device_for_falls_back_to_cpu_for_simple_graph() {
        let square = || Shape::new(&[4, 4], DType::F32);

        let mut graph = Graph::new("mm");
        let lhs = graph.input("x", square());
        let rhs = graph.param("w", square());
        let product = graph.matmul(lhs, rhs, square());
        graph.set_outputs(vec![product]);

        let chosen = fastest_device_for(&graph);
        assert!(crate::is_available(chosen));
        assert!(crate::devices_for(&graph).contains(&chosen));
    }
}