1use serde::{Deserialize, Serialize};
40use std::fmt;
41
42#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
44pub enum Backend {
45 Scalar,
47 Simd,
49 Gpu,
51}
52
53impl fmt::Display for Backend {
54 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55 match self {
56 Backend::Scalar => write!(f, "Scalar"),
57 Backend::Simd => write!(f, "SIMD"),
58 Backend::Gpu => write!(f, "GPU"),
59 }
60 }
61}
62
63impl Default for Backend {
64 fn default() -> Self {
65 Backend::Scalar
66 }
67}
68
69#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
71pub enum OpComplexity {
72 Low,
74 Medium,
76 High,
78}
79
80impl Default for OpComplexity {
81 fn default() -> Self {
82 OpComplexity::Low
83 }
84}
85
/// Heuristic backend selector.
///
/// It combines two strategies:
/// - a transfer-vs-compute cost model, used by `select_backend`,
///   `select_for_matmul` and `select_for_vector_op`, which chooses between
///   SIMD and GPU;
/// - per-complexity element-count thresholds, used by `select_with_moe`,
///   `select_for_elementwise` and `selection_stats`.
///
/// Construct with `BackendSelector::new()` / `Default` and tune the cost model
/// with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct BackendSelector {
    /// Transfer bandwidth in bytes per second.
    /// Transfer time is estimated as `data_bytes / pcie_bandwidth`.
    pcie_bandwidth: f64,
    /// GPU throughput used to estimate compute time as `flops / gpu_gflops`.
    /// Despite the name, the unit is FLOP/s, not GFLOP/s (default `20e12`).
    gpu_gflops: f64,
    /// The GPU is chosen only when estimated compute time is strictly greater
    /// than this multiple of estimated transfer time.
    min_dispatch_ratio: f64,
    /// `Low` complexity: SIMD strictly above this many elements, else scalar.
    simd_threshold_low: usize,
    /// `Medium` complexity: SIMD strictly above this many elements.
    simd_threshold_medium: usize,
    /// `Medium` complexity: GPU strictly above this many elements.
    gpu_threshold_medium: usize,
    /// `High` complexity: SIMD strictly above this many elements.
    simd_threshold_high: usize,
    /// `High` complexity: GPU strictly above this many elements.
    gpu_threshold_high: usize,
}
108
109impl Default for BackendSelector {
110 fn default() -> Self {
111 Self {
112 pcie_bandwidth: 32e9, gpu_gflops: 20e12, min_dispatch_ratio: 5.0, simd_threshold_low: 1_000_000,
116 simd_threshold_medium: 10_000,
117 gpu_threshold_medium: 100_000,
118 simd_threshold_high: 1_000,
119 gpu_threshold_high: 10_000,
120 }
121 }
122}
123
124impl BackendSelector {
125 #[must_use]
127 pub fn new() -> Self {
128 Self::default()
129 }
130
131 #[must_use]
133 pub fn with_pcie_bandwidth(mut self, bandwidth: f64) -> Self {
134 self.pcie_bandwidth = bandwidth;
135 self
136 }
137
138 #[must_use]
140 pub fn with_gpu_gflops(mut self, gflops: f64) -> Self {
141 self.gpu_gflops = gflops;
142 self
143 }
144
145 #[must_use]
147 pub fn with_min_dispatch_ratio(mut self, ratio: f64) -> Self {
148 self.min_dispatch_ratio = ratio;
149 self
150 }
151
152 #[must_use]
173 pub fn select_backend(&self, data_bytes: usize, flops: u64) -> Backend {
174 let transfer_s = data_bytes as f64 / self.pcie_bandwidth;
176
177 let compute_s = flops as f64 / self.gpu_gflops;
179
180 if compute_s > self.min_dispatch_ratio * transfer_s {
182 Backend::Gpu
183 } else {
184 Backend::Simd
186 }
187 }
188
189 #[must_use]
198 pub fn select_for_matmul(&self, m: usize, n: usize, k: usize) -> Backend {
199 let data_bytes = (m * k + k * n + m * n) * 4;
201
202 let flops = (2 * m * n * k) as u64;
204
205 self.select_backend(data_bytes, flops)
206 }
207
208 #[must_use]
214 pub fn select_for_vector_op(&self, n: usize, ops_per_element: u64) -> Backend {
215 let data_bytes = n * 3 * 4;
217
218 let flops = n as u64 * ops_per_element;
220
221 self.select_backend(data_bytes, flops)
222 }
223
224 #[must_use]
228 pub fn select_for_elementwise(&self, n: usize) -> Backend {
229 if n > self.simd_threshold_low {
232 Backend::Simd
233 } else {
234 Backend::Scalar
235 }
236 }
237
238 #[must_use]
252 pub fn select_with_moe(&self, complexity: OpComplexity, data_size: usize) -> Backend {
253 match complexity {
254 OpComplexity::Low => {
255 if data_size > self.simd_threshold_low {
257 Backend::Simd
258 } else {
259 Backend::Scalar
260 }
261 }
262 OpComplexity::Medium => {
263 if data_size > self.gpu_threshold_medium {
265 Backend::Gpu
266 } else if data_size > self.simd_threshold_medium {
267 Backend::Simd
268 } else {
269 Backend::Scalar
270 }
271 }
272 OpComplexity::High => {
273 if data_size > self.gpu_threshold_high {
275 Backend::Gpu
276 } else if data_size > self.simd_threshold_high {
277 Backend::Simd
278 } else {
279 Backend::Scalar
280 }
281 }
282 }
283 }
284
285 #[must_use]
287 pub fn selection_stats(&self, complexity: OpComplexity, data_size: usize) -> SelectionStats {
288 let backend = self.select_with_moe(complexity, data_size);
289
290 let speedup = match backend {
292 Backend::Scalar => 1.0,
293 Backend::Simd => {
294 match complexity {
296 OpComplexity::Low => 4.0,
297 OpComplexity::Medium => 6.0,
298 OpComplexity::High => 8.0,
299 }
300 }
301 Backend::Gpu => {
302 let base = match complexity {
304 OpComplexity::Low => 1.0, OpComplexity::Medium => 10.0,
306 OpComplexity::High => 50.0,
307 };
308 base * (data_size as f64 / 10_000.0).min(10.0)
310 }
311 };
312
313 SelectionStats {
314 backend,
315 complexity,
316 data_size,
317 estimated_speedup: speedup,
318 }
319 }
320}
321
/// Result of `BackendSelector::selection_stats`: the chosen backend, the
/// inputs it was chosen for, and a heuristic speedup estimate.
#[derive(Debug, Clone)]
pub struct SelectionStats {
    /// Backend chosen by `select_with_moe`.
    pub backend: Backend,
    /// Complexity class the selection was made for.
    pub complexity: OpComplexity,
    /// Number of elements the selection was made for.
    pub data_size: usize,
    /// Estimated speedup relative to the scalar path (scalar is `1.0`).
    /// Comes from a fixed heuristic table, not a measurement.
    pub estimated_speedup: f64,
}
334
335impl fmt::Display for SelectionStats {
336 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
337 write!(
338 f,
339 "{} backend for {:?} complexity ({} elements) - ~{:.1}x speedup",
340 self.backend, self.complexity, self.data_size, self.estimated_speedup
341 )
342 }
343}
344
/// Batch-level configuration that recommends a backend for a fixed batch size
/// and operation complexity.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Selector that makes the recommendation.
    /// `BatchConfig::new` always uses `BackendSelector::new()`.
    selector: BackendSelector,
    /// Number of elements per batch; passed to the selector as the data size.
    pub batch_size: usize,
    /// Complexity class of the batched operation (`Low` when built via `new`).
    pub complexity: OpComplexity,
}
355
356impl BatchConfig {
357 #[must_use]
359 pub fn new(batch_size: usize) -> Self {
360 Self {
361 selector: BackendSelector::new(),
362 batch_size,
363 complexity: OpComplexity::Low,
364 }
365 }
366
367 #[must_use]
369 pub fn with_complexity(mut self, complexity: OpComplexity) -> Self {
370 self.complexity = complexity;
371 self
372 }
373
374 #[must_use]
376 pub fn recommended_backend(&self) -> Backend {
377 self.selector
378 .select_with_moe(self.complexity, self.batch_size)
379 }
380
381 #[must_use]
383 pub fn should_use_gpu(&self) -> bool {
384 self.recommended_backend() == Backend::Gpu
385 }
386
387 #[must_use]
389 pub fn should_use_simd(&self) -> bool {
390 matches!(self.recommended_backend(), Backend::Simd | Backend::Gpu)
391 }
392}
393
394impl Default for BatchConfig {
395 fn default() -> Self {
396 Self::new(1000)
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks every `(data_size, expected)` pair for one complexity tier
    /// against a default selector.
    fn assert_moe(complexity: OpComplexity, cases: &[(usize, Backend)]) {
        let selector = BackendSelector::new();
        for &(size, expected) in cases {
            assert_eq!(
                selector.select_with_moe(complexity, size),
                expected,
                "complexity {:?}, size {}",
                complexity,
                size
            );
        }
    }

    #[test]
    fn test_backend_display() {
        let cases = [
            (Backend::Scalar, "Scalar"),
            (Backend::Simd, "SIMD"),
            (Backend::Gpu, "GPU"),
        ];
        for &(backend, expected) in &cases {
            assert_eq!(backend.to_string(), expected);
        }
    }

    #[test]
    fn test_backend_default() {
        assert_eq!(Backend::default(), Backend::Scalar);
    }

    #[test]
    fn test_op_complexity_ordering() {
        assert!(OpComplexity::Low < OpComplexity::Medium);
        assert!(OpComplexity::Medium < OpComplexity::High);
    }

    #[test]
    fn test_selector_default() {
        assert_eq!(BackendSelector::new().min_dispatch_ratio, 5.0);
    }

    #[test]
    fn test_select_elementwise_small() {
        assert_eq!(
            BackendSelector::new().select_for_elementwise(100),
            Backend::Scalar
        );
    }

    #[test]
    fn test_select_elementwise_large() {
        assert_eq!(
            BackendSelector::new().select_for_elementwise(10_000_000),
            Backend::Simd
        );
    }

    #[test]
    fn test_moe_low_complexity() {
        assert_moe(
            OpComplexity::Low,
            &[(100, Backend::Scalar), (10_000_000, Backend::Simd)],
        );
    }

    #[test]
    fn test_moe_medium_complexity() {
        assert_moe(
            OpComplexity::Medium,
            &[
                (100, Backend::Scalar),
                (50_000, Backend::Simd),
                (500_000, Backend::Gpu),
            ],
        );
    }

    #[test]
    fn test_moe_high_complexity() {
        assert_moe(
            OpComplexity::High,
            &[
                (100, Backend::Scalar),
                (5_000, Backend::Simd),
                (50_000, Backend::Gpu),
            ],
        );
    }

    #[test]
    fn test_select_matmul_small() {
        // A 10x10x10 matmul never amortizes the transfer cost.
        let selector = BackendSelector::new();
        assert_eq!(selector.select_for_matmul(10, 10, 10), Backend::Simd);
    }

    #[test]
    fn test_select_matmul_large() {
        // 1000^3 at the default 5x ratio: ~0.1 ms compute vs ~0.375 ms transfer.
        let default_selector = BackendSelector::new();
        assert_eq!(
            default_selector.select_for_matmul(1000, 1000, 1000),
            Backend::Simd
        );

        // 10000^3 at a 2x ratio: ~0.1 s compute beats 2 x ~0.0375 s transfer.
        let fast_gpu_selector = BackendSelector::new().with_min_dispatch_ratio(2.0);
        assert_eq!(
            fast_gpu_selector.select_for_matmul(10000, 10000, 10000),
            Backend::Gpu
        );
    }

    #[test]
    fn test_selection_stats() {
        let stats = BackendSelector::new().selection_stats(OpComplexity::High, 100_000);

        assert_eq!(stats.backend, Backend::Gpu);
        assert!(stats.estimated_speedup > 1.0);
        assert!(stats.to_string().contains("GPU"));
    }

    #[test]
    fn test_batch_config() {
        let config = BatchConfig::new(50_000).with_complexity(OpComplexity::Medium);

        assert_eq!(config.batch_size, 50_000);
        assert!(config.should_use_simd());
        assert!(!config.should_use_gpu());
    }

    #[test]
    fn test_batch_config_gpu() {
        let config = BatchConfig::new(500_000).with_complexity(OpComplexity::Medium);
        assert!(config.should_use_gpu());
    }

    #[test]
    fn test_batch_config_default() {
        assert_eq!(BatchConfig::default().batch_size, 1000);
    }

    #[test]
    fn test_custom_thresholds() {
        let selector = BackendSelector::new()
            .with_pcie_bandwidth(64e9)
            .with_gpu_gflops(80e12)
            .with_min_dispatch_ratio(3.0);

        assert!(selector.pcie_bandwidth > 32e9);
        assert!(selector.gpu_gflops > 20e12);
    }

    #[test]
    fn test_vector_op_selection() {
        let selector = BackendSelector::new();

        // Tiny and large-but-cheap vector ops both stay on SIMD.
        assert_eq!(selector.select_for_vector_op(100, 2), Backend::Simd);
        assert_eq!(selector.select_for_vector_op(10_000_000, 2), Backend::Simd);

        // A permissive dispatch ratio plus heavy per-element work tips to GPU.
        let fast_gpu_selector = BackendSelector::new()
            .with_min_dispatch_ratio(0.1)
            .with_gpu_gflops(1e12);
        assert_eq!(
            fast_gpu_selector.select_for_vector_op(10_000_000, 100),
            Backend::Gpu
        );
    }

    #[test]
    fn test_op_complexity_default() {
        assert_eq!(OpComplexity::default(), OpComplexity::Low);
    }

    #[test]
    fn test_backend_serialization() {
        let json = serde_json::to_string(&Backend::Gpu).expect("serialize");
        assert_eq!(json, "\"Gpu\"");

        let parsed: Backend = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed, Backend::Gpu);
    }

    #[test]
    fn test_complexity_serialization() {
        let json = serde_json::to_string(&OpComplexity::High).expect("serialize");

        let parsed: OpComplexity = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed, OpComplexity::High);
    }
}