1use crate::Device;
28use rlx_ir::{Graph, Node, Op};
29
30pub trait BackendCostModel: Send + Sync {
32 fn device(&self) -> Device;
34
35 fn sgemm_gflops(&self, m: usize, k: usize, n: usize) -> f64;
39
40 fn dispatch_overhead_ns(&self) -> f64;
42
43 fn roundtrip_overhead_ns(&self) -> f64;
46
47 fn memory_bw(&self) -> f64;
49
50 fn num_threads(&self) -> usize;
52}
53
54pub fn estimate_graph_cost(graph: &Graph, model: &dyn BackendCostModel) -> f64 {
58 let mut total = model.roundtrip_overhead_ns();
59 for node in graph.nodes() {
60 total += node_cost(node, graph, model);
61 }
62 total
63}
64
65fn node_cost(node: &Node, graph: &Graph, model: &dyn BackendCostModel) -> f64 {
66 let dispatch = model.dispatch_overhead_ns();
67 match &node.op {
68 Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => 0.0,
69 Op::MatMul | Op::FusedMatMulBiasAct { .. } => {
70 let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
71 let total = node.shape.num_elements().unwrap_or(0);
72 let m = total / n.max(1);
73 let a_total = graph.node(node.inputs[0]).shape.num_elements().unwrap_or(0);
74 let k = a_total / m.max(1);
75 let flops = 2.0 * m as f64 * k as f64 * n as f64;
76 flops / (model.sgemm_gflops(m, k, n) + 1.0) + dispatch
77 }
78 Op::Attention {
79 num_heads,
80 head_dim,
81 ..
82 } => {
83 let q_shape = &graph.node(node.inputs[0]).shape;
84 let seq = q_shape.dim(q_shape.rank() - 2).unwrap_static();
85 let batch = q_shape.num_elements().unwrap_or(0) / (seq * num_heads * head_dim).max(1);
86 let flops = (batch * num_heads * seq * seq * head_dim * 2) as f64;
87 flops / (model.sgemm_gflops(seq, *head_dim, seq) + 1.0) + dispatch
88 }
89 _ => {
91 let bytes = node.shape.num_elements().unwrap_or(0) * 4;
92 (bytes as f64) / model.memory_bw().max(1.0) + dispatch
93 }
94 }
95}
96
97pub fn pick_best_device(graph: &Graph, models: &[&dyn BackendCostModel]) -> Device {
99 let mut best = (Device::Cpu, f64::INFINITY);
100 for &m in models {
101 let cost = estimate_graph_cost(graph, m);
102 if cost < best.1 {
103 best = (m.device(), cost);
104 }
105 }
106 best.0
107}
108
/// CPU cost model wrapping an `rlx_cpu::cost::HwModel`.
#[cfg(feature = "cpu")]
pub struct CpuCostModel(rlx_cpu::cost::HwModel);

#[cfg(feature = "cpu")]
impl CpuCostModel {
    /// Builds the hardware model from `rlx_cpu::config::RuntimeConfig::global()`.
    pub fn new() -> Self {
        use rlx_cpu::{config::RuntimeConfig, cost::HwModel};

        let hardware = HwModel::from_config(RuntimeConfig::global());
        CpuCostModel(hardware)
    }
}

#[cfg(feature = "cpu")]
impl Default for CpuCostModel {
    /// Same as [`CpuCostModel::new`].
    fn default() -> Self {
        CpuCostModel::new()
    }
}
134
#[cfg(feature = "cpu")]
impl BackendCostModel for CpuCostModel {
    fn device(&self) -> Device {
        Device::Cpu
    }

    /// Reports the throughput of the faster of the NEON and BLAS SGEMM figures
    /// in the hardware model.
    ///
    /// Algebraically this is roughly `max(neon_flops, blas_flops) / 1e9`
    /// (GFLOP/s if the `HwModel` rates are FLOP/s — TODO confirm). The shape
    /// only matters in that an empty product (any of `m`, `k`, `n` == 0)
    /// yields `0.0`.
    fn sgemm_gflops(&self, m: usize, k: usize, n: usize) -> f64 {
        let hw = &self.0;
        let work = 2.0 * (m as f64) * (k as f64) * (n as f64);
        let fastest_time = f64::min(
            work / hw.neon_flops.max(1.0),
            work / hw.blas_flops.max(1.0),
        );
        match fastest_time {
            t if t > 0.0 => work / (t * 1e9),
            _ => 0.0,
        }
    }

    /// Per-node overhead, taken from the hardware model's `blas_overhead_ns`.
    fn dispatch_overhead_ns(&self) -> f64 {
        let hw = &self.0;
        hw.blas_overhead_ns
    }

    /// Per-graph overhead, taken from the hardware model's `par_for_overhead_ns`.
    fn roundtrip_overhead_ns(&self) -> f64 {
        let hw = &self.0;
        hw.par_for_overhead_ns
    }

    fn memory_bw(&self) -> f64 {
        let hw = &self.0;
        hw.mem_bw
    }

    fn num_threads(&self) -> usize {
        let hw = &self.0;
        hw.num_threads
    }
}
165
/// Metal cost model built from `rlx_metal` calibration data.
#[cfg(feature = "metal")]
pub struct MetalCostModel {
    // Maximum of the calibrated SGEMM kernel throughputs. Despite the `_avg`
    // suffix, this is not an average.
    sgemm_gflops_avg: f64,
    // Calibrated round-trip overhead in nanoseconds.
    roundtrip_ns: f64,
    // Memory bandwidth (GB/s). Not calibrated: fixed at 200.0 in `new`.
    // NOTE(review): placeholder value — confirm against the target GPU.
    memory_bw: f64,
}

#[cfg(feature = "metal")]
impl MetalCostModel {
    /// Obtains calibration via `Calibration::load_or_measure()` and keeps the
    /// fastest of the three calibrated SGEMM variants.
    pub fn new() -> Self {
        let cal = rlx_metal::calibrate::Calibration::load_or_measure();
        let simd_best = cal.sgemm_simd_4x4_flops.max(cal.sgemm_simd_flops);
        let overall_best = simd_best.max(cal.sgemm_padded_flops);
        MetalCostModel {
            sgemm_gflops_avg: overall_best,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw: 200.0,
        }
    }
}

#[cfg(feature = "metal")]
impl Default for MetalCostModel {
    /// Same as [`MetalCostModel::new`].
    fn default() -> Self {
        MetalCostModel::new()
    }
}
203
#[cfg(feature = "metal")]
impl BackendCostModel for MetalCostModel {
    fn device(&self) -> Device {
        Device::Metal
    }

    /// Shape-independent: always returns the stored SGEMM throughput.
    fn sgemm_gflops(&self, _m: usize, _k: usize, _n: usize) -> f64 {
        let stored = self.sgemm_gflops_avg;
        stored
    }

    /// Fixed 2 µs per dispatched node (not calibrated).
    fn dispatch_overhead_ns(&self) -> f64 {
        const PER_DISPATCH_NS: f64 = 2_000.0;
        PER_DISPATCH_NS
    }

    fn roundtrip_overhead_ns(&self) -> f64 {
        let stored = self.roundtrip_ns;
        stored
    }

    fn memory_bw(&self) -> f64 {
        let stored = self.memory_bw;
        stored
    }

    /// Always reports a single thread.
    fn num_threads(&self) -> usize {
        const REPORTED_THREADS: usize = 1;
        REPORTED_THREADS
    }
}
226
/// MLX cost model built from `rlx_mlx` calibration data.
#[cfg(all(feature = "mlx", target_os = "macos"))]
pub struct MlxCostModel {
    // Calibrated SGEMM throughput used for large products.
    sgemm_large_flops: f64,
    // Calibrated SGEMM throughput used for small products.
    sgemm_small_flops: f64,
    // Calibrated round-trip overhead in nanoseconds.
    roundtrip_ns: f64,
    // Memory bandwidth in GB/s (calibrated, with a fallback in `new`).
    memory_bw: f64,
}

#[cfg(all(feature = "mlx", target_os = "macos"))]
impl MlxCostModel {
    /// Obtains calibration via `Calibration::load_or_measure()`. Falls back to
    /// 200 GB/s when the calibrated bandwidth is not positive (including NaN).
    pub fn new() -> Self {
        // Bandwidth assumed when calibration did not produce a usable value.
        const FALLBACK_MEMORY_BW_GBPS: f64 = 200.0;

        let cal = rlx_mlx::calibrate::Calibration::load_or_measure();
        let memory_bw = Some(cal.memory_bw_gbps)
            .filter(|&measured| measured > 0.0)
            .unwrap_or(FALLBACK_MEMORY_BW_GBPS);
        MlxCostModel {
            sgemm_large_flops: cal.sgemm_large_flops,
            sgemm_small_flops: cal.sgemm_small_flops,
            roundtrip_ns: cal.roundtrip_overhead_ns,
            memory_bw,
        }
    }
}

#[cfg(all(feature = "mlx", target_os = "macos"))]
impl Default for MlxCostModel {
    /// Same as [`MlxCostModel::new`].
    fn default() -> Self {
        MlxCostModel::new()
    }
}
267
#[cfg(all(feature = "mlx", target_os = "macos"))]
impl BackendCostModel for MlxCostModel {
    fn device(&self) -> Device {
        Device::Mlx
    }

    /// Chooses the small- or large-GEMM calibration by the product volume
    /// `m · k · n`.
    fn sgemm_gflops(&self, m: usize, k: usize, n: usize) -> f64 {
        // Products with fewer than this many multiply-adds use the small figure.
        const SMALL_GEMM_VOLUME: f64 = 32_768.0;
        let volume = (m as f64) * (k as f64) * (n as f64);
        if volume >= SMALL_GEMM_VOLUME {
            self.sgemm_large_flops
        } else {
            self.sgemm_small_flops
        }
    }

    /// Fixed 2 µs per dispatched node (not calibrated).
    fn dispatch_overhead_ns(&self) -> f64 {
        const PER_DISPATCH_NS: f64 = 2_000.0;
        PER_DISPATCH_NS
    }

    fn roundtrip_overhead_ns(&self) -> f64 {
        let stored = self.roundtrip_ns;
        stored
    }

    fn memory_bw(&self) -> f64 {
        let stored = self.memory_bw;
        stored
    }

    /// Always reports a single thread.
    fn num_threads(&self) -> usize {
        const REPORTED_THREADS: usize = 1;
        REPORTED_THREADS
    }
}