1use crate::array::Array;
17use crate::backend::dispatch::{BackendValidation, ElementwiseFn, MatmulFn, ReductionFn};
18use std::time::Instant;
19
/// One entry of an adaptive lookup table: a half-open problem-size range
/// `[min_size, max_size)` mapped to a kernel and a backend label.
#[derive(Debug, Clone, Copy)]
pub struct SizeRange<F: Copy> {
    /// Inclusive lower bound on the problem size covered by this range.
    pub min_size: usize,
    /// Exclusive upper bound; `None` means the range is open-ended.
    pub max_size: Option<usize>,
    /// Kernel (function pointer) invoked for sizes in this range.
    pub kernel: F,
    /// Human-readable backend label for logging (e.g. "cpu-simd", "blas").
    pub backend_name: &'static str,
}
32
/// Ordered list of size ranges used to dispatch an operation to a backend
/// chosen per problem size.
#[derive(Debug, Clone)]
pub struct AdaptiveLookupTable<F: Copy> {
    /// Ranges scanned in order; lookups fall back to the first range when
    /// nothing matches. Must be non-empty — `lookup` indexes `ranges[0]`
    /// and will panic on an empty table.
    pub ranges: Vec<SizeRange<F>>,
}
39
40impl<F: Copy> AdaptiveLookupTable<F> {
41 #[inline]
43 pub fn lookup(&self, size: usize) -> F {
44 for range in &self.ranges {
45 if size >= range.min_size {
46 if let Some(max) = range.max_size {
47 if size < max {
48 return range.kernel;
49 }
50 } else {
51 return range.kernel;
53 }
54 }
55 }
56
57 self.ranges[0].kernel
59 }
60
61 pub fn backend_name(&self, size: usize) -> &'static str {
63 for range in &self.ranges {
64 if size >= range.min_size {
65 if let Some(max) = range.max_size {
66 if size < max {
67 return range.backend_name;
68 }
69 } else {
70 return range.backend_name;
71 }
72 }
73 }
74 self.ranges[0].backend_name
75 }
76}
77
/// Configuration for runtime microbenchmark probing.
#[derive(Debug, Clone)]
pub struct BenchConfig {
    /// Whether probing runs at all; when `false`, static heuristics are used.
    pub enabled: bool,

    /// Timed iterations per kernel/size measurement.
    pub iterations: usize,

    /// Total wall-clock budget for all probing, in milliseconds.
    pub max_time_ms: u64,
}
90
91impl Default for BenchConfig {
92 fn default() -> Self {
93 Self {
94 enabled: false,
95 iterations: 3,
96 max_time_ms: 50,
97 }
98 }
99}
100
101impl BenchConfig {
102 pub fn from_env() -> Self {
104 let enabled = std::env::var("NUMRS_ENABLE_PROBING")
105 .map(|v| matches!(v.as_str(), "1" | "true" | "TRUE"))
106 .unwrap_or(false);
107
108 let iterations = std::env::var("NUMRS_BENCH_ITERATIONS")
109 .ok()
110 .and_then(|v| v.parse().ok())
111 .unwrap_or(3);
112
113 let max_time_ms = std::env::var("NUMRS_BENCH_TIMEOUT_MS")
114 .ok()
115 .and_then(|v| v.parse().ok())
116 .unwrap_or(50);
117
118 Self {
119 enabled,
120 iterations,
121 max_time_ms,
122 }
123 }
124}
125
/// Average wall-clock seconds per call of `f` over `iterations` runs.
///
/// Returns `0.0` when `iterations == 0` instead of the `NaN` that the
/// naive `total / 0.0` division would produce.
#[inline]
fn bench_fn<F>(f: F, iterations: usize) -> f64
where
    F: Fn(),
{
    if iterations == 0 {
        return 0.0;
    }

    let mut total = 0.0;

    for _ in 0..iterations {
        let start = Instant::now();
        f();
        total += start.elapsed().as_secs_f64();
    }

    total / iterations as f64
}
142
143pub fn benchmark_matmul(
145 validation: &BackendValidation,
146 config: &BenchConfig,
147) -> AdaptiveLookupTable<MatmulFn> {
148 if !config.enabled {
149 #[cfg(debug_assertions)]
151 eprintln!("[numrs-bench] Matmul: using static heuristic (probing disabled)");
152
153 let mut ranges = Vec::new();
154
155 if validation.simd_validated {
157 ranges.push(SizeRange {
158 min_size: 0,
159 max_size: Some(4_096),
160 kernel: crate::backend::dispatch::kernel_matmul_simd as MatmulFn,
161 backend_name: "cpu-simd",
162 });
163 }
164
165 if validation.blas_validated {
167 ranges.push(SizeRange {
168 min_size: 4_096,
169 max_size: Some(262_144),
170 kernel: crate::backend::dispatch::kernel_matmul_blas_direct as MatmulFn,
171 backend_name: "blas",
172 });
173 }
174
175 if validation.metal_validated {
177 ranges.push(SizeRange {
178 min_size: 262_144,
179 max_size: None,
180 kernel: crate::backend::dispatch::kernel_matmul_metal as MatmulFn,
181 backend_name: "metal",
182 });
183 } else if validation.webgpu_validated {
184 ranges.push(SizeRange {
185 min_size: 262_144,
186 max_size: None,
187 kernel: crate::backend::dispatch::kernel_matmul_webgpu as MatmulFn,
188 backend_name: "webgpu",
189 });
190 } else if validation.blas_validated {
191 ranges.push(SizeRange {
193 min_size: 262_144,
194 max_size: None,
195 kernel: crate::backend::dispatch::kernel_matmul_blas_direct as MatmulFn,
196 backend_name: "blas",
197 });
198 } else if validation.simd_validated {
199 ranges.push(SizeRange {
201 min_size: 262_144,
202 max_size: None,
203 kernel: crate::backend::dispatch::kernel_matmul_simd as MatmulFn,
204 backend_name: "cpu-simd",
205 });
206 }
207
208 if ranges.is_empty() {
210 ranges.push(SizeRange {
211 min_size: 0,
212 max_size: None,
213 kernel: crate::backend::dispatch::kernel_matmul_scalar as MatmulFn,
214 backend_name: "cpu-scalar",
215 });
216 }
217
218 return AdaptiveLookupTable { ranges };
219 }
220
221 eprintln!("[numrs-bench] Running matmul microbenchmarks...");
223
224 let test_sizes = vec![32, 64, 128, 256, 512];
226 let mut results: Vec<(usize, &'static str, MatmulFn, f64)> = Vec::new();
227
228 let start_total = Instant::now();
229
230 for &size in &test_sizes {
231 if start_total.elapsed().as_millis() > config.max_time_ms as u128 {
232 eprintln!("[numrs-bench] Timeout reached");
233 break;
234 }
235
236 let output_size = size * size;
237 let a = Array::new(vec![size, size], vec![1.0f32; size * size]);
238 let b = Array::new(vec![size, size], vec![1.0f32; size * size]);
239
240 let mut times = Vec::new();
241
242 if validation.simd_validated {
244 let time = bench_fn(
245 || {
246 let _ = crate::backend::dispatch::kernel_matmul_simd(&a, &b);
247 },
248 config.iterations.min(2),
249 );
250 times.push((
251 "cpu-simd",
252 crate::backend::dispatch::kernel_matmul_simd as MatmulFn,
253 time,
254 ));
255 }
256
257 if validation.blas_validated {
259 let time = bench_fn(
260 || {
261 let _ = crate::backend::dispatch::kernel_matmul_blas_direct(&a, &b);
262 },
263 config.iterations.min(2),
264 );
265 times.push((
266 "blas",
267 crate::backend::dispatch::kernel_matmul_blas_direct as MatmulFn,
268 time,
269 ));
270 }
271
272 if size >= 128 {
274 if validation.metal_validated {
275 if let Ok(_) = crate::backend::dispatch::kernel_matmul_metal(&a, &b) {
276 let time = bench_fn(
277 || {
278 let _ = crate::backend::dispatch::kernel_matmul_metal(&a, &b);
279 },
280 1,
281 );
282 times.push((
283 "metal",
284 crate::backend::dispatch::kernel_matmul_metal as MatmulFn,
285 time,
286 ));
287 }
288 }
289
290 if validation.webgpu_validated {
291 if let Ok(_) = crate::backend::dispatch::kernel_matmul_webgpu(&a, &b) {
292 let time = bench_fn(
293 || {
294 let _ = crate::backend::dispatch::kernel_matmul_webgpu(&a, &b);
295 },
296 1,
297 );
298 times.push((
299 "webgpu",
300 crate::backend::dispatch::kernel_matmul_webgpu as MatmulFn,
301 time,
302 ));
303 }
304 }
305 }
306
307 if let Some((name, kernel, time)) =
309 times.iter().min_by(|a, b| a.2.partial_cmp(&b.2).unwrap())
310 {
311 results.push((output_size, *name, *kernel, *time));
312 eprintln!(
313 "[numrs-bench] {}x{} (size={}): {} ({:.3}ms)",
314 size,
315 size,
316 output_size,
317 name,
318 time * 1000.0
319 );
320 }
321 }
322
323 build_lookup_table_from_results(results, validation)
325}
326
327fn build_lookup_table_from_results(
329 results: Vec<(usize, &'static str, MatmulFn, f64)>,
330 _validation: &BackendValidation,
331) -> AdaptiveLookupTable<MatmulFn> {
332 let mut ranges = Vec::new();
333
334 if results.is_empty() {
335 let disabled_config = BenchConfig {
337 enabled: false,
338 iterations: 3,
339 max_time_ms: 50,
340 };
341 return benchmark_matmul(_validation, &disabled_config);
342 }
343
344 let mut current_backend = results[0].1;
346 let mut current_kernel = results[0].2;
347 let mut range_start = 0;
348
349 for (i, (size, backend, kernel, _)) in results.iter().enumerate() {
350 if *backend != current_backend || i == results.len() - 1 {
351 let range_end = if i == results.len() - 1 {
353 None
354 } else {
355 Some(*size)
356 };
357
358 ranges.push(SizeRange {
359 min_size: range_start,
360 max_size: range_end,
361 kernel: current_kernel,
362 backend_name: current_backend,
363 });
364
365 if i < results.len() - 1 {
366 current_backend = *backend;
367 current_kernel = *kernel;
368 range_start = *size;
369 }
370 }
371 }
372
373 if let Some(last_range) = ranges.last_mut() {
375 if last_range.max_size.is_some() {
376 last_range.max_size = None;
377 }
378 }
379
380 eprintln!(
381 "[numrs-bench] Matmul lookup table created with {} ranges",
382 ranges.len()
383 );
384 for range in &ranges {
385 let max_str = range
386 .max_size
387 .map(|m| m.to_string())
388 .unwrap_or_else(|| "∞".to_string());
389 eprintln!(
390 "[numrs-bench] [{}, {}): {}",
391 range.min_size, max_str, range.backend_name
392 );
393 }
394
395 AdaptiveLookupTable { ranges }
396}
397
398pub fn benchmark_elementwise(
400 _validation: &BackendValidation,
401 _config: &BenchConfig,
402) -> AdaptiveLookupTable<ElementwiseFn> {
403 use crate::backend::dispatch::kernel_elementwise_simd;
404
405 #[cfg(debug_assertions)]
406 eprintln!("[numrs-bench] Elementwise: using static fallback (not implemented yet)");
407
408 AdaptiveLookupTable {
409 ranges: vec![SizeRange {
410 min_size: 0,
411 max_size: None,
412 kernel: kernel_elementwise_simd as ElementwiseFn,
413 backend_name: "cpu-simd",
414 }],
415 }
416}
417
418pub fn benchmark_reduction(
420 _validation: &BackendValidation,
421 _config: &BenchConfig,
422) -> AdaptiveLookupTable<ReductionFn> {
423 use crate::backend::dispatch::kernel_reduction_simd;
424
425 #[cfg(debug_assertions)]
426 eprintln!("[numrs-bench] Reduction: using static fallback (not implemented yet)");
427
428 AdaptiveLookupTable {
429 ranges: vec![SizeRange {
430 min_size: 0,
431 max_size: None,
432 kernel: kernel_reduction_simd as ReductionFn,
433 backend_name: "cpu-simd",
434 }],
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    #[test]
    fn test_lookup_table() {
        // Matmul-shaped stand-in kernel so the table can be built without
        // touching any real backend.
        fn dummy_kernel(_a: &Array, _b: &Array) -> Result<Array> {
            Ok(Array::new(vec![1], vec![0.0]))
        }

        let table = AdaptiveLookupTable {
            ranges: vec![
                SizeRange {
                    min_size: 0,
                    max_size: Some(100),
                    kernel: dummy_kernel as MatmulFn,
                    backend_name: "small",
                },
                SizeRange {
                    min_size: 100,
                    max_size: Some(1000),
                    kernel: dummy_kernel as MatmulFn,
                    backend_name: "medium",
                },
                SizeRange {
                    min_size: 1000,
                    max_size: None,
                    kernel: dummy_kernel as MatmulFn,
                    backend_name: "large",
                },
            ],
        };

        // Interior points.
        assert_eq!(table.backend_name(50), "small");
        assert_eq!(table.backend_name(500), "medium");
        assert_eq!(table.backend_name(5000), "large");

        // Boundaries are half-open: min_size inclusive, max_size exclusive.
        assert_eq!(table.backend_name(0), "small");
        assert_eq!(table.backend_name(99), "small");
        assert_eq!(table.backend_name(100), "medium");
        assert_eq!(table.backend_name(999), "medium");
        assert_eq!(table.backend_name(1000), "large");
    }
}