use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use wgpu::{Device, DeviceType};
/// Chooses an [`Algorithm`] for an operation, consulting recorded timings
/// before falling back to built-in heuristics.
///
/// The device handle and the performance history are held behind [`Arc`], so
/// a clone shares both with the selector it was cloned from.
#[derive(Clone)]
pub struct AdaptiveSelector {
    /// Device the selector was created for. Not read anywhere in this type,
    /// hence the lint allowance.
    #[allow(dead_code)]
    device: Arc<Device>,
    /// Hardware estimates computed once at construction.
    device_info: DeviceInfo,
    /// Lock-protected log of past executions, shared between clones.
    performance_history: Arc<RwLock<PerformanceHistory>>,
    /// Tuning configuration. Not read anywhere in this type, hence the lint
    /// allowance.
    #[allow(dead_code)]
    config: AdaptiveConfig,
}
impl AdaptiveSelector {
    /// Creates a selector for `device` with an empty performance history and
    /// the default [`AdaptiveConfig`].
    ///
    /// `device_type` feeds the hardware estimates stored in the selector's
    /// [`DeviceInfo`].
    pub fn new(device: Arc<Device>, device_type: DeviceType) -> Self {
        let device_info = DeviceInfo::from_device(&device, device_type);
        Self {
            device,
            device_info,
            performance_history: Arc::new(RwLock::new(PerformanceHistory::new())),
            config: AdaptiveConfig::default(),
        }
    }

    /// Picks an algorithm for `operation` on `workload`.
    ///
    /// A recorded run on a similar workload takes precedence. When the history
    /// has no match, the choice falls back to the static heuristics.
    pub fn select_algorithm(&self, operation: &str, workload: &WorkloadInfo) -> Algorithm {
        // The read guard is a temporary of this statement, so the lock is
        // released before the heuristics run.
        let remembered = self
            .performance_history
            .read()
            .get_best_algorithm(operation, workload);
        remembered.unwrap_or_else(|| self.select_by_heuristics(operation, workload))
    }

    /// Dispatches to the per-operation heuristic. Unknown operation names get
    /// [`Algorithm::default`].
    fn select_by_heuristics(&self, operation: &str, workload: &WorkloadInfo) -> Algorithm {
        match operation {
            "matrix_multiply" => self.select_matmul_algorithm(workload),
            "convolution" => self.select_convolution_algorithm(workload),
            "reduction" => self.select_reduction_algorithm(workload),
            "sort" => self.select_sort_algorithm(workload),
            "fft" => self.select_fft_algorithm(workload),
            _ => Algorithm::default(),
        }
    }

    /// Matrix multiply: naive below 1 Mi, tiled below 16 Mi, hierarchical
    /// from there on (thresholds are compared against `data_size`).
    fn select_matmul_algorithm(&self, workload: &WorkloadInfo) -> Algorithm {
        let size = workload.data_size;
        if size < 1024 * 1024 {
            Algorithm {
                name: "matmul_naive".to_string(),
                workgroup_size: (8, 8, 1),
                strategy: ExecutionStrategy::Direct,
                tuning_params: TuningParams::default(),
            }
        } else if size < 16 * 1024 * 1024 {
            Algorithm {
                name: "matmul_tiled".to_string(),
                workgroup_size: (16, 16, 1),
                strategy: ExecutionStrategy::Tiled { tile_size: 32 },
                tuning_params: TuningParams {
                    use_shared_memory: true,
                    unroll_factor: 4,
                    vectorize: true,
                },
            }
        } else {
            Algorithm {
                name: "matmul_hierarchical".to_string(),
                workgroup_size: (16, 16, 1),
                strategy: ExecutionStrategy::Hierarchical { levels: 2 },
                tuning_params: TuningParams {
                    use_shared_memory: true,
                    unroll_factor: 8,
                    vectorize: true,
                },
            }
        }
    }

    /// Convolution: direct for kernels up to 3, im2col otherwise.
    ///
    /// The first entry of `workload.dimensions` is read as the kernel size;
    /// an empty dimension list is treated as a kernel size of 3.
    fn select_convolution_algorithm(&self, workload: &WorkloadInfo) -> Algorithm {
        let kernel_size = workload.dimensions.first().copied().unwrap_or(3);
        if kernel_size <= 3 {
            Algorithm {
                name: "conv_direct".to_string(),
                workgroup_size: (8, 8, 1),
                strategy: ExecutionStrategy::Direct,
                tuning_params: TuningParams {
                    use_shared_memory: true,
                    unroll_factor: 1,
                    vectorize: false,
                },
            }
        } else {
            Algorithm {
                name: "conv_im2col".to_string(),
                workgroup_size: (16, 16, 1),
                strategy: ExecutionStrategy::Transform,
                tuning_params: TuningParams {
                    use_shared_memory: true,
                    unroll_factor: 4,
                    vectorize: true,
                },
            }
        }
    }

    /// Reduction: always hierarchical, with `max(1, ceil(log2(data_size)) / 8)`
    /// levels and an unroll factor scaled by the device's compute units
    /// (8 units are assumed when no estimate is available).
    fn select_reduction_algorithm(&self, workload: &WorkloadInfo) -> Algorithm {
        let compute_units = self.device_info.compute_units.unwrap_or(8);

        // Exact integer ceil(log2(n)). The previous f32 round-trip lost
        // precision for sizes just above a power of two (2^23 + 1 already
        // rounded down to 23) and produced -inf for an empty workload.
        let size_bits = match workload.data_size {
            0 | 1 => 0,
            n => (64 - (n - 1).leading_zeros()) as usize,
        };
        // Small inputs used to yield `levels: 0`; a reduction needs at least
        // one pass, so the count is clamped.
        // NOTE(review): the division still truncates (9 bits -> 1 level);
        // confirm with the kernel author whether rounding up is intended.
        let levels = (size_bits / 8).max(1);

        Algorithm {
            name: "reduce_hierarchical".to_string(),
            workgroup_size: (256, 1, 1),
            strategy: ExecutionStrategy::Hierarchical { levels },
            tuning_params: TuningParams {
                use_shared_memory: true,
                unroll_factor: (compute_units / 8).max(1),
                vectorize: true,
            },
        }
    }

    /// Sort: insertion below 1024, bitonic below 1 Mi, radix from there on
    /// (thresholds are compared against `data_size`).
    fn select_sort_algorithm(&self, workload: &WorkloadInfo) -> Algorithm {
        if workload.data_size < 1024 {
            Algorithm {
                name: "sort_insertion".to_string(),
                workgroup_size: (64, 1, 1),
                strategy: ExecutionStrategy::Direct,
                tuning_params: TuningParams::default(),
            }
        } else if workload.data_size < 1024 * 1024 {
            Algorithm {
                name: "sort_bitonic".to_string(),
                workgroup_size: (128, 1, 1),
                strategy: ExecutionStrategy::Parallel,
                tuning_params: TuningParams {
                    use_shared_memory: true,
                    unroll_factor: 2,
                    vectorize: false,
                },
            }
        } else {
            Algorithm {
                name: "sort_radix".to_string(),
                workgroup_size: (256, 1, 1),
                strategy: ExecutionStrategy::Hierarchical { levels: 4 },
                tuning_params: TuningParams {
                    use_shared_memory: true,
                    unroll_factor: 4,
                    vectorize: true,
                },
            }
        }
    }

    /// FFT: Cooley-Tukey with `log2(data_size)` levels for power-of-two
    /// sizes, Bluestein for everything else.
    ///
    /// A size of zero is not a power of two and therefore takes the Bluestein
    /// branch. The former `size & (size - 1)` test underflowed on zero.
    fn select_fft_algorithm(&self, workload: &WorkloadInfo) -> Algorithm {
        let size = workload.data_size;
        if size.is_power_of_two() {
            Algorithm {
                name: "fft_cooley_tukey".to_string(),
                workgroup_size: (256, 1, 1),
                strategy: ExecutionStrategy::Hierarchical {
                    // For a power of two the trailing-zero count is exactly log2.
                    levels: size.trailing_zeros() as usize,
                },
                tuning_params: TuningParams {
                    use_shared_memory: true,
                    unroll_factor: 4,
                    vectorize: true,
                },
            }
        } else {
            Algorithm {
                name: "fft_bluestein".to_string(),
                workgroup_size: (128, 1, 1),
                strategy: ExecutionStrategy::Transform,
                tuning_params: TuningParams {
                    use_shared_memory: false,
                    unroll_factor: 2,
                    vectorize: false,
                },
            }
        }
    }

    /// Appends one measured run of `algorithm` on `workload` to the history
    /// kept for `operation`.
    pub fn record_performance(
        &self,
        operation: &str,
        workload: &WorkloadInfo,
        algorithm: &Algorithm,
        duration: Duration,
    ) {
        self.performance_history
            .write()
            .record(operation, workload, algorithm, duration);
    }

    /// Returns aggregate timing statistics for `operation`, or `None` when
    /// the history holds no statistics for it.
    pub fn get_statistics(&self, operation: &str) -> Option<AlgorithmStats> {
        self.performance_history.read().get_stats(operation)
    }
}
/// Coarse description of the hardware a selector is tuning for.
#[derive(Clone, Debug)]
pub struct DeviceInfo {
    /// Kind of device, as passed in by the caller.
    pub device_type: DeviceType,
    /// Estimated number of compute units, when an estimate exists.
    pub compute_units: Option<u32>,
    /// Memory bandwidth, when known. The unit is not established in this file.
    pub memory_bandwidth: Option<f32>,
    /// Peak floating-point throughput, when known. The unit is not
    /// established in this file.
    pub peak_flops: Option<f64>,
}
impl DeviceInfo {
fn from_device(_device: &Device, device_type: DeviceType) -> Self {
let compute_units = match device_type {
DeviceType::DiscreteGpu => Some(64),
DeviceType::IntegratedGpu => Some(16),
_ => None,
};
Self {
device_type,
compute_units,
memory_bandwidth: None,
peak_flops: None,
}
}
}
/// Shape of the data an operation is about to process.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct WorkloadInfo {
    /// Overall size of the workload. The selection heuristics compare it
    /// against fixed thresholds; whether it counts elements or bytes is not
    /// established in this file.
    pub data_size: u64,
    /// Per-axis extents. The convolution heuristic reads the first entry as
    /// the kernel size, and workload similarity compares only the number of
    /// entries.
    pub dimensions: Vec<u32>,
    /// Not read in this file; presumably the size of one element in bytes.
    /// TODO confirm against callers.
    pub element_size: u32,
}
/// A concrete algorithm choice: which kernel to run and how to launch it.
#[derive(Clone, Debug)]
pub struct Algorithm {
    /// Identifier of the algorithm, e.g. `"matmul_tiled"` or `"default"`.
    pub name: String,
    /// Workgroup extents as an `(x, y, z)` triple.
    pub workgroup_size: (u32, u32, u32),
    /// How the work is organised.
    pub strategy: ExecutionStrategy,
    /// Low-level tuning switches for the kernel.
    pub tuning_params: TuningParams,
}
impl Default for Algorithm {
fn default() -> Self {
Self {
name: "default".to_string(),
workgroup_size: (8, 8, 1),
strategy: ExecutionStrategy::Direct,
tuning_params: TuningParams::default(),
}
}
}
/// How an [`Algorithm`] organises its work.
///
/// The variants are tags consumed elsewhere; this file only selects them.
#[derive(Clone, Copy, Debug)]
pub enum ExecutionStrategy {
    /// No additional structure. Used by the default algorithm.
    Direct,
    /// Work split into tiles.
    Tiled {
        /// Tile extent. The unit is not established in this file.
        tile_size: u32,
    },
    /// Work carried out over several levels.
    Hierarchical {
        /// Number of levels.
        levels: usize,
    },
    /// Selected for the `conv_im2col` and `fft_bluestein` algorithms.
    Transform,
    /// Selected for the `sort_bitonic` algorithm.
    Parallel,
}
/// Low-level tuning switches attached to an algorithm choice.
///
/// This file only sets these values; how a kernel interprets them is decided
/// by the consumer.
#[derive(Clone, Debug)]
pub struct TuningParams {
    /// Shared-memory switch; off by default.
    pub use_shared_memory: bool,
    /// Loop unroll factor; 1 by default.
    pub unroll_factor: u32,
    /// Vectorisation switch; off by default.
    pub vectorize: bool,
}
impl Default for TuningParams {
fn default() -> Self {
Self {
use_shared_memory: false,
unroll_factor: 1,
vectorize: false,
}
}
}
/// Bounded log of measured executions, grouped by operation name.
struct PerformanceHistory {
    /// Measurements per operation, oldest first.
    records: HashMap<String, Vec<PerformanceRecord>>,
    /// Upper bound on the number of measurements kept for one operation.
    max_records_per_operation: usize,
}
impl PerformanceHistory {
    /// Creates an empty history that keeps at most 100 records per operation.
    fn new() -> Self {
        Self {
            records: HashMap::new(),
            max_records_per_operation: 100,
        }
    }

    /// Appends one measurement for `operation`, then evicts the oldest
    /// entries so that no more than `max_records_per_operation` remain.
    ///
    /// Only the algorithm's name is stored; its workgroup size, strategy and
    /// tuning parameters are not.
    fn record(
        &mut self,
        operation: &str,
        workload: &WorkloadInfo,
        algorithm: &Algorithm,
        duration: Duration,
    ) {
        let capacity = self.max_records_per_operation;
        let records = self.records.entry(operation.to_string()).or_default();
        records.push(PerformanceRecord {
            workload: workload.clone(),
            algorithm_name: algorithm.name.clone(),
            duration,
            timestamp: Instant::now(),
        });
        // Records are kept oldest-first, so the excess sits at the front.
        let excess = records.len().saturating_sub(capacity);
        records.drain(..excess);
    }

    /// Returns the algorithm of the fastest recorded run of `operation` on a
    /// workload similar to `workload`, or `None` when there is no such run.
    ///
    /// When several runs tie for fastest, the oldest one wins. Only the name
    /// comes from the history; every other field of the returned value is
    /// taken from [`Algorithm::default`].
    fn get_best_algorithm(&self, operation: &str, workload: &WorkloadInfo) -> Option<Algorithm> {
        // `min_by_key` returns the first minimum, matching what a stable sort
        // followed by taking index 0 produced, without the intermediate Vec.
        let fastest = self
            .records
            .get(operation)?
            .iter()
            .filter(|r| Self::is_similar_workload(&r.workload, workload))
            .min_by_key(|r| r.duration)?;
        Some(Algorithm {
            name: fastest.algorithm_name.clone(),
            ..Algorithm::default()
        })
    }

    /// Reports whether two workloads are close enough to share timing data:
    /// the same number of dimensions, and `w1`'s size strictly between 80%
    /// and 120% of `w2`'s.
    ///
    /// Two empty workloads (`data_size == 0`) are similar to each other, and
    /// an empty workload is never similar to a non-empty one.
    fn is_similar_workload(w1: &WorkloadInfo, w2: &WorkloadInfo) -> bool {
        if w1.dimensions.len() != w2.dimensions.len() {
            return false;
        }
        if w1.data_size == 0 || w2.data_size == 0 {
            // The ratio is undefined here: 0/0 is NaN, which used to make two
            // identical empty workloads compare as dissimilar.
            return w1.data_size == w2.data_size;
        }
        let size_ratio = (w1.data_size as f64) / (w2.data_size as f64);
        size_ratio > 0.8 && size_ratio < 1.2
    }

    /// Aggregates every retained record of `operation`, across all
    /// algorithms, into a count, a total and an average duration.
    ///
    /// Returns `None` when nothing has been recorded for the operation.
    fn get_stats(&self, operation: &str) -> Option<AlgorithmStats> {
        let records = self.records.get(operation)?;
        if records.is_empty() {
            return None;
        }
        let total_duration: Duration = records.iter().map(|r| r.duration).sum();
        // `record` bounds the list by `max_records_per_operation` (100 as set
        // in `new`), so this cast does not truncate.
        let count = records.len() as u32;
        Some(AlgorithmStats {
            total_executions: count,
            average_duration: total_duration / count,
            total_duration,
        })
    }
}
/// One measured execution of an algorithm.
#[derive(Clone, Debug)]
struct PerformanceRecord {
    /// Workload the run processed; used to find runs on similar workloads.
    workload: WorkloadInfo,
    /// Name of the algorithm that ran.
    algorithm_name: String,
    /// How long the run took.
    duration: Duration,
    /// When the measurement was recorded. Not read anywhere in this file,
    /// hence the lint allowance.
    #[allow(dead_code)]
    timestamp: Instant,
}
/// Timing summary for one operation, aggregated over every retained record
/// regardless of which algorithm produced it.
#[derive(Clone, Debug)]
pub struct AlgorithmStats {
    /// Number of records that went into the summary.
    pub total_executions: u32,
    /// `total_duration` divided by `total_executions`.
    pub average_duration: Duration,
    /// Sum of the recorded durations.
    pub total_duration: Duration,
}
/// Settings for adaptive selection.
///
/// Neither field is consulted anywhere in this file, so the descriptions
/// below are inferred from the names.
#[derive(Clone, Debug)]
pub struct AdaptiveConfig {
    /// Presumably enables automatic tuning; on by default. TODO confirm.
    pub auto_tune: bool,
    /// Presumably the number of measurements required before the history is
    /// trusted; 3 by default. TODO confirm.
    pub min_samples: usize,
}
impl Default for AdaptiveConfig {
fn default() -> Self {
Self {
auto_tune: true,
min_samples: 3,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a two-dimensional workload of 4-byte elements.
    fn workload(data_size: u64, dimensions: [u32; 2]) -> WorkloadInfo {
        WorkloadInfo {
            data_size,
            dimensions: dimensions.to_vec(),
            element_size: 4,
        }
    }

    /// A 10% size difference counts as similar; a 2x difference does not.
    #[test]
    fn test_workload_similarity() {
        let base = workload(1000, [100, 10]);
        let ten_percent_larger = workload(1100, [110, 10]);
        let twice_as_large = workload(2000, [100, 20]);

        assert!(PerformanceHistory::is_similar_workload(
            &base,
            &ten_percent_larger
        ));
        assert!(!PerformanceHistory::is_similar_workload(
            &base,
            &twice_as_large
        ));
    }

    /// The default algorithm is named "default" and uses an 8x8x1 workgroup.
    #[test]
    fn test_algorithm_default() {
        let Algorithm {
            name,
            workgroup_size,
            ..
        } = Algorithm::default();
        assert_eq!(name, "default");
        assert_eq!(workgroup_size, (8, 8, 1));
    }
}