use serde::{Deserialize, Serialize};
use std::fmt;
/// Execution backend that the selectors in this file can recommend.
///
/// Derives serde `Serialize`/`Deserialize`; no rename attributes are applied,
/// so the variant identifiers are used as written.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Backend {
/// Fallback returned by the threshold-based selectors when the input does
/// not exceed the SIMD threshold.
Scalar,
/// Returned for inputs above the SIMD threshold, and by the cost-model
/// selectors whenever the GPU is rejected.
Simd,
/// Returned for inputs above the GPU threshold, or when the cost model finds
/// that compute time sufficiently exceeds transfer time.
Gpu,
}
impl fmt::Display for Backend {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Backend::Scalar => write!(f, "Scalar"),
Backend::Simd => write!(f, "SIMD"),
Backend::Gpu => write!(f, "GPU"),
}
}
}
impl Default for Backend {
fn default() -> Self {
Backend::Scalar
}
}
/// Complexity class of an operation, used to pick which size thresholds apply
/// when choosing a backend.
///
/// The derived ordering follows declaration order: `Low < Medium < High`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum OpComplexity {
/// Only ever mapped to the scalar or SIMD backend by the threshold selector.
Low,
/// Uses the `*_threshold_medium` limits; may be mapped to any backend.
Medium,
/// Uses the `*_threshold_high` limits; may be mapped to any backend.
High,
}
impl Default for OpComplexity {
fn default() -> Self {
OpComplexity::Low
}
}
/// Configuration used to choose an execution [`Backend`] for a workload.
///
/// Two independent strategies read these fields: a transfer-versus-compute
/// cost comparison (`pcie_bandwidth`, `gpu_gflops`, `min_dispatch_ratio`) and
/// per-complexity element-count thresholds (the `*_threshold_*` fields).
#[derive(Debug, Clone)]
pub struct BackendSelector {
/// Divisor applied to a byte count to estimate transfer time (default `32e9`).
/// Presumably bytes per second — TODO confirm the unit.
pcie_bandwidth: f64,
/// Divisor applied to a FLOP count to estimate compute time (default `20e12`).
/// NOTE(review): despite the name, the default magnitude suggests raw FLOP/s
/// rather than GFLOP/s — confirm.
gpu_gflops: f64,
/// The GPU is chosen by the cost model only when the estimated compute time
/// is strictly greater than this factor times the transfer time (default `5.0`).
min_dispatch_ratio: f64,
/// Low complexity: sizes strictly above this use SIMD (default `1_000_000`).
simd_threshold_low: usize,
/// Medium complexity: sizes strictly above this use SIMD (default `10_000`).
simd_threshold_medium: usize,
/// Medium complexity: sizes strictly above this use the GPU (default `100_000`).
gpu_threshold_medium: usize,
/// High complexity: sizes strictly above this use SIMD (default `1_000`).
simd_threshold_high: usize,
/// High complexity: sizes strictly above this use the GPU (default `10_000`).
gpu_threshold_high: usize,
}
impl Default for BackendSelector {
    /// Builds the baseline configuration: a `32e9` transfer-rate divisor, a
    /// `20e12` compute-rate divisor, a dispatch ratio of `5.0`, and the
    /// per-complexity size thresholds listed below.
    fn default() -> Self {
        // Inputs of the transfer-versus-compute cost comparison.
        let pcie_bandwidth = 32e9;
        let gpu_gflops = 20e12;
        let min_dispatch_ratio = 5.0;

        // Size thresholds per complexity class. Low complexity has no GPU tier.
        let simd_threshold_low = 1_000_000;
        let (simd_threshold_medium, gpu_threshold_medium) = (10_000, 100_000);
        let (simd_threshold_high, gpu_threshold_high) = (1_000, 10_000);

        Self {
            pcie_bandwidth,
            gpu_gflops,
            min_dispatch_ratio,
            simd_threshold_low,
            simd_threshold_medium,
            gpu_threshold_medium,
            simd_threshold_high,
            gpu_threshold_high,
        }
    }
}
impl BackendSelector {
    /// Creates a selector with the default cost model and thresholds.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the transfer-rate divisor used to turn a byte count into an
    /// estimated transfer time. The value is stored without validation.
    #[must_use]
    pub fn with_pcie_bandwidth(mut self, bandwidth: f64) -> Self {
        self.pcie_bandwidth = bandwidth;
        self
    }

    /// Replaces the compute-rate divisor used to turn a FLOP count into an
    /// estimated compute time. The value is stored without validation.
    ///
    /// NOTE(review): the default (`20e12`) suggests the unit is raw FLOP/s
    /// rather than GFLOP/s, despite the parameter name — confirm.
    #[must_use]
    pub fn with_gpu_gflops(mut self, gflops: f64) -> Self {
        self.gpu_gflops = gflops;
        self
    }

    /// Replaces the factor by which estimated compute time must exceed
    /// estimated transfer time before the GPU is chosen.
    #[must_use]
    pub fn with_min_dispatch_ratio(mut self, ratio: f64) -> Self {
        self.min_dispatch_ratio = ratio;
        self
    }

    /// Cost-model core shared by the `select_backend`, `select_for_matmul`
    /// and `select_for_vector_op` entry points.
    ///
    /// Works on `f64` quantities so callers can describe workloads whose byte
    /// or FLOP counts do not fit in an integer.
    fn select_by_cost(&self, data_bytes: f64, flops: f64) -> Backend {
        let transfer_s = data_bytes / self.pcie_bandwidth;
        let compute_s = flops / self.gpu_gflops;
        if compute_s > self.min_dispatch_ratio * transfer_s {
            Backend::Gpu
        } else {
            Backend::Simd
        }
    }

    /// Threshold ladder shared by every complexity class: strictly above the
    /// GPU threshold (when there is one) gives GPU, strictly above the SIMD
    /// threshold gives SIMD, anything else gives scalar.
    fn select_by_thresholds(
        data_size: usize,
        simd_threshold: usize,
        gpu_threshold: Option<usize>,
    ) -> Backend {
        match gpu_threshold {
            Some(limit) if data_size > limit => Backend::Gpu,
            _ if data_size > simd_threshold => Backend::Simd,
            _ => Backend::Scalar,
        }
    }

    /// Chooses between GPU and SIMD from the amount of data to move and the
    /// amount of arithmetic to perform.
    ///
    /// Returns [`Backend::Gpu`] when `flops / gpu_gflops` is strictly greater
    /// than `min_dispatch_ratio * data_bytes / pcie_bandwidth`, otherwise
    /// [`Backend::Simd`]. This method never returns [`Backend::Scalar`].
    #[must_use]
    pub fn select_backend(&self, data_bytes: usize, flops: u64) -> Backend {
        self.select_by_cost(data_bytes as f64, flops as f64)
    }

    /// Chooses a backend for an `m x k` by `k x n` matrix product.
    ///
    /// The workload is modelled as `(m*k + k*n + m*n) * 4` bytes and
    /// `2*m*n*k` FLOPs. Both are computed in `f64`, so arbitrarily large
    /// dimensions neither panic in debug builds nor wrap in release builds.
    /// Results are unchanged whenever the integer products fit exactly in an
    /// `f64` (below 2^53).
    #[must_use]
    pub fn select_for_matmul(&self, m: usize, n: usize, k: usize) -> Backend {
        let (rows, cols, inner) = (m as f64, n as f64, k as f64);
        let data_bytes = (rows * inner + inner * cols + rows * cols) * 4.0;
        let flops = 2.0 * rows * cols * inner;
        self.select_by_cost(data_bytes, flops)
    }

    /// Chooses a backend for a vector operation over `n` elements that
    /// performs `ops_per_element` FLOPs per element.
    ///
    /// The workload is modelled as `n * 3 * 4` bytes and `n * ops_per_element`
    /// FLOPs, computed in `f64` so that large inputs cannot overflow.
    #[must_use]
    pub fn select_for_vector_op(&self, n: usize, ops_per_element: u64) -> Backend {
        let elements = n as f64;
        let data_bytes = elements * 3.0 * 4.0;
        let flops = elements * ops_per_element as f64;
        self.select_by_cost(data_bytes, flops)
    }

    /// Chooses a backend for an element-wise operation over `n` elements.
    ///
    /// Equivalent to [`Self::select_with_moe`] with [`OpComplexity::Low`]:
    /// SIMD when `n` is strictly above the low-complexity SIMD threshold,
    /// scalar otherwise.
    #[must_use]
    pub fn select_for_elementwise(&self, n: usize) -> Backend {
        self.select_with_moe(OpComplexity::Low, n)
    }

    /// Chooses a backend from the operation's complexity class and data size.
    ///
    /// Each class has its own thresholds, and all comparisons are strict (`>`).
    /// Low complexity has no GPU tier; Medium and High pick GPU above their
    /// GPU threshold, SIMD above their SIMD threshold, and scalar otherwise.
    #[must_use]
    pub fn select_with_moe(&self, complexity: OpComplexity, data_size: usize) -> Backend {
        let (simd_threshold, gpu_threshold) = match complexity {
            OpComplexity::Low => (self.simd_threshold_low, None),
            OpComplexity::Medium => (
                self.simd_threshold_medium,
                Some(self.gpu_threshold_medium),
            ),
            OpComplexity::High => (self.simd_threshold_high, Some(self.gpu_threshold_high)),
        };
        Self::select_by_thresholds(data_size, simd_threshold, gpu_threshold)
    }

    /// Runs [`Self::select_with_moe`] and attaches a hard-coded speedup estimate.
    ///
    /// The estimate is `1.0` for scalar; `4.0`/`6.0`/`8.0` for SIMD at
    /// Low/Medium/High complexity; and for GPU a base of `1.0`/`10.0`/`50.0`
    /// multiplied by `data_size / 10_000`, with that multiplier capped at `10.0`.
    #[must_use]
    pub fn selection_stats(&self, complexity: OpComplexity, data_size: usize) -> SelectionStats {
        let backend = self.select_with_moe(complexity, data_size);
        let estimated_speedup = match backend {
            Backend::Scalar => 1.0,
            Backend::Simd => match complexity {
                OpComplexity::Low => 4.0,
                OpComplexity::Medium => 6.0,
                OpComplexity::High => 8.0,
            },
            Backend::Gpu => {
                let base = match complexity {
                    // Kept for exhaustiveness: the selector never picks the
                    // GPU for low-complexity work.
                    OpComplexity::Low => 1.0,
                    OpComplexity::Medium => 10.0,
                    OpComplexity::High => 50.0,
                };
                let size_factor = (data_size as f64 / 10_000.0).min(10.0);
                base * size_factor
            }
        };
        SelectionStats {
            backend,
            complexity,
            data_size,
            estimated_speedup,
        }
    }
}
/// Outcome of a threshold-based backend selection together with a rough
/// speedup estimate, as produced by `BackendSelector::selection_stats`.
#[derive(Debug, Clone)]
pub struct SelectionStats {
/// Backend chosen for the request.
pub backend: Backend,
/// Complexity class the request was made with.
pub complexity: OpComplexity,
/// Data size the request was made with (displayed as "elements").
pub data_size: usize,
/// Hard-coded speedup estimate for the chosen backend; `1.0` for scalar.
pub estimated_speedup: f64,
}
impl fmt::Display for SelectionStats {
    /// Renders a one-line summary: the backend label, the complexity (via its
    /// `Debug` form), the element count, and the speedup to one decimal place.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            backend,
            complexity,
            data_size,
            estimated_speedup,
        } = self;
        write!(
            f,
            "{} backend for {:?} complexity ({} elements) - ~{:.1}x speedup",
            backend, complexity, data_size, estimated_speedup
        )
    }
}
/// Pairs a batch size and complexity class with a private [`BackendSelector`]
/// so a backend recommendation can be queried for the batch.
#[derive(Debug, Clone)]
pub struct BatchConfig {
/// Selector consulted for recommendations; not exposed publicly.
selector: BackendSelector,
/// Batch size, passed to the selector as the data size.
pub batch_size: usize,
/// Complexity class passed to the selector.
pub complexity: OpComplexity,
}
impl BatchConfig {
    /// Creates a configuration for `batch_size` items with a default selector
    /// and [`OpComplexity::Low`].
    #[must_use]
    pub fn new(batch_size: usize) -> Self {
        let selector = BackendSelector::new();
        let complexity = OpComplexity::Low;
        Self {
            selector,
            batch_size,
            complexity,
        }
    }

    /// Returns the configuration with its complexity class replaced.
    #[must_use]
    pub fn with_complexity(self, complexity: OpComplexity) -> Self {
        Self { complexity, ..self }
    }

    /// Asks the selector which backend suits this batch size and complexity.
    #[must_use]
    pub fn recommended_backend(&self) -> Backend {
        let Self {
            selector,
            batch_size,
            complexity,
        } = self;
        selector.select_with_moe(*complexity, *batch_size)
    }

    /// Returns `true` only when the recommended backend is the GPU.
    #[must_use]
    pub fn should_use_gpu(&self) -> bool {
        matches!(self.recommended_backend(), Backend::Gpu)
    }

    /// Returns `true` when the recommended backend is SIMD *or* GPU, i.e.
    /// anything other than scalar.
    #[must_use]
    pub fn should_use_simd(&self) -> bool {
        match self.recommended_backend() {
            Backend::Simd | Backend::Gpu => true,
            Backend::Scalar => false,
        }
    }
}
impl Default for BatchConfig {
    /// Equivalent to `BatchConfig::new(1000)`.
    fn default() -> Self {
        const DEFAULT_BATCH_SIZE: usize = 1000;
        Self::new(DEFAULT_BATCH_SIZE)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two floats agree to within a tight absolute tolerance.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_backend_display() {
        assert_eq!(Backend::Scalar.to_string(), "Scalar");
        assert_eq!(Backend::Simd.to_string(), "SIMD");
        assert_eq!(Backend::Gpu.to_string(), "GPU");
    }

    #[test]
    fn test_backend_default() {
        assert_eq!(Backend::default(), Backend::Scalar);
    }

    #[test]
    fn test_op_complexity_ordering() {
        assert!(OpComplexity::Low < OpComplexity::Medium);
        assert!(OpComplexity::Medium < OpComplexity::High);
    }

    #[test]
    fn test_op_complexity_default() {
        assert_eq!(OpComplexity::default(), OpComplexity::Low);
    }

    #[test]
    fn test_selector_default() {
        let selector = BackendSelector::new();
        assert_close(selector.min_dispatch_ratio, 5.0);
    }

    #[test]
    fn test_select_elementwise_small() {
        let selector = BackendSelector::new();
        assert_eq!(selector.select_for_elementwise(100), Backend::Scalar);
    }

    #[test]
    fn test_select_elementwise_large() {
        let selector = BackendSelector::new();
        assert_eq!(selector.select_for_elementwise(10_000_000), Backend::Simd);
    }

    #[test]
    fn test_select_elementwise_boundary() {
        // The threshold comparison is strict: exactly at the limit stays scalar.
        let selector = BackendSelector::new();
        assert_eq!(selector.select_for_elementwise(1_000_000), Backend::Scalar);
        assert_eq!(selector.select_for_elementwise(1_000_001), Backend::Simd);
    }

    #[test]
    fn test_moe_low_complexity() {
        let selector = BackendSelector::new();
        assert_eq!(
            selector.select_with_moe(OpComplexity::Low, 100),
            Backend::Scalar
        );
        assert_eq!(
            selector.select_with_moe(OpComplexity::Low, 10_000_000),
            Backend::Simd
        );
    }

    #[test]
    fn test_moe_medium_complexity() {
        let selector = BackendSelector::new();
        assert_eq!(
            selector.select_with_moe(OpComplexity::Medium, 100),
            Backend::Scalar
        );
        assert_eq!(
            selector.select_with_moe(OpComplexity::Medium, 50_000),
            Backend::Simd
        );
        assert_eq!(
            selector.select_with_moe(OpComplexity::Medium, 500_000),
            Backend::Gpu
        );
    }

    #[test]
    fn test_moe_high_complexity() {
        let selector = BackendSelector::new();
        assert_eq!(
            selector.select_with_moe(OpComplexity::High, 100),
            Backend::Scalar
        );
        assert_eq!(
            selector.select_with_moe(OpComplexity::High, 5_000),
            Backend::Simd
        );
        assert_eq!(
            selector.select_with_moe(OpComplexity::High, 50_000),
            Backend::Gpu
        );
    }

    #[test]
    fn test_moe_threshold_boundaries() {
        // Every tier switches strictly above its threshold, never at it.
        let selector = BackendSelector::new();
        let cases = [
            (OpComplexity::Low, 1_000_000, Backend::Scalar),
            (OpComplexity::Low, 1_000_001, Backend::Simd),
            (OpComplexity::Medium, 10_000, Backend::Scalar),
            (OpComplexity::Medium, 10_001, Backend::Simd),
            (OpComplexity::Medium, 100_000, Backend::Simd),
            (OpComplexity::Medium, 100_001, Backend::Gpu),
            (OpComplexity::High, 1_000, Backend::Scalar),
            (OpComplexity::High, 1_001, Backend::Simd),
            (OpComplexity::High, 10_000, Backend::Simd),
            (OpComplexity::High, 10_001, Backend::Gpu),
        ];
        for (complexity, data_size, expected) in cases.iter().copied() {
            assert_eq!(
                selector.select_with_moe(complexity, data_size),
                expected,
                "complexity {:?}, data_size {}",
                complexity,
                data_size
            );
        }
    }

    #[test]
    fn test_select_matmul_small() {
        let selector = BackendSelector::new();
        assert_eq!(selector.select_for_matmul(10, 10, 10), Backend::Simd);
    }

    #[test]
    fn test_select_matmul_large() {
        let selector = BackendSelector::new();
        assert_eq!(selector.select_for_matmul(1000, 1000, 1000), Backend::Simd);

        let fast_gpu_selector = BackendSelector::new().with_min_dispatch_ratio(2.0);
        assert_eq!(
            fast_gpu_selector.select_for_matmul(10000, 10000, 10000),
            Backend::Gpu
        );
    }

    #[test]
    fn test_vector_op_selection() {
        let selector = BackendSelector::new();
        assert_eq!(selector.select_for_vector_op(100, 2), Backend::Simd);
        assert_eq!(selector.select_for_vector_op(10_000_000, 2), Backend::Simd);

        let fast_gpu_selector = BackendSelector::new()
            .with_min_dispatch_ratio(0.1)
            .with_gpu_gflops(1e12);
        assert_eq!(
            fast_gpu_selector.select_for_vector_op(10_000_000, 100),
            Backend::Gpu
        );
    }

    #[test]
    fn test_selection_stats() {
        let selector = BackendSelector::new();
        let stats = selector.selection_stats(OpComplexity::High, 100_000);
        assert_eq!(stats.backend, Backend::Gpu);
        assert_eq!(stats.complexity, OpComplexity::High);
        assert_eq!(stats.data_size, 100_000);
        assert!(stats.estimated_speedup > 1.0);
        // Base 50 for High on GPU, times the size factor capped at 10.
        assert_close(stats.estimated_speedup, 500.0);
        assert!(stats.to_string().contains("GPU"));
    }

    #[test]
    fn test_selection_stats_scalar_and_simd() {
        let selector = BackendSelector::new();

        let scalar = selector.selection_stats(OpComplexity::Low, 100);
        assert_eq!(scalar.backend, Backend::Scalar);
        assert_close(scalar.estimated_speedup, 1.0);

        let simd = selector.selection_stats(OpComplexity::Medium, 50_000);
        assert_eq!(simd.backend, Backend::Simd);
        assert_close(simd.estimated_speedup, 6.0);
    }

    #[test]
    fn test_batch_config() {
        let config = BatchConfig::new(50_000).with_complexity(OpComplexity::Medium);
        assert_eq!(config.batch_size, 50_000);
        assert!(config.should_use_simd());
        assert!(!config.should_use_gpu());
    }

    #[test]
    fn test_batch_config_gpu() {
        let config = BatchConfig::new(500_000).with_complexity(OpComplexity::Medium);
        assert!(config.should_use_gpu());
        // A GPU recommendation also reports SIMD as usable.
        assert!(config.should_use_simd());
    }

    #[test]
    fn test_batch_config_default() {
        let config = BatchConfig::default();
        assert_eq!(config.batch_size, 1000);
        assert_eq!(config.complexity, OpComplexity::Low);
        assert_eq!(config.recommended_backend(), Backend::Scalar);
        assert!(!config.should_use_simd());
        assert!(!config.should_use_gpu());
    }

    #[test]
    fn test_custom_thresholds() {
        let selector = BackendSelector::new()
            .with_pcie_bandwidth(64e9)
            .with_gpu_gflops(80e12)
            .with_min_dispatch_ratio(3.0);
        assert!(selector.pcie_bandwidth > 32e9);
        assert!(selector.gpu_gflops > 20e12);
        assert_close(selector.min_dispatch_ratio, 3.0);
    }

    #[test]
    fn test_backend_serialization() {
        let backend = Backend::Gpu;
        let json = serde_json::to_string(&backend).expect("serialize");
        assert_eq!(json, "\"Gpu\"");
        let parsed: Backend = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed, Backend::Gpu);
    }

    #[test]
    fn test_complexity_serialization() {
        let complexity = OpComplexity::High;
        let json = serde_json::to_string(&complexity).expect("serialize");
        assert_eq!(json, "\"High\"");
        let parsed: OpComplexity = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed, OpComplexity::High);
    }
}