use super::unified_kernel::{UnifiedKernelExecutor, KernelOp, KernelParams, KernelMetrics};
use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use num_traits::Float;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
use std::time::Duration;
/// In-memory store of the latest [`KernelMetrics`] seen for each
/// `(device, operation, input size)` combination.
pub struct PerformanceDatabase {
/// Measurements keyed by `(device, kernel operation, input size)`.
/// `record` inserts into this map, so a new measurement for an existing key
/// replaces the old one.
records: HashMap<(DeviceType, KernelOp, usize), KernelMetrics>,
}
impl PerformanceDatabase {
    /// Creates a database that holds no measurements yet.
    pub fn new() -> Self {
        let records = HashMap::new();
        Self { records }
    }

    /// Stores `metrics` under the `(device, op, input_size)` key.
    ///
    /// A measurement already stored under the same key is replaced.
    pub fn record(&mut self, device: DeviceType, op: KernelOp, input_size: usize, metrics: KernelMetrics) {
        self.records.insert((device, op, input_size), metrics);
    }

    /// Returns the device with the smallest recorded `execution_time` for `op`
    /// at exactly `input_size`, together with its metrics.
    ///
    /// Returns `None` when nothing has been recorded for that operation and
    /// size on any device.
    pub fn get_best_performance(&self, op: KernelOp, input_size: usize) -> Option<(DeviceType, &KernelMetrics)> {
        self.records
            .iter()
            .filter_map(|(key, metrics)| {
                let (device, stored_op, stored_size) = key;
                if *stored_op == op && *stored_size == input_size {
                    Some((*device, metrics))
                } else {
                    None
                }
            })
            .min_by(|(_, lhs), (_, rhs)| lhs.execution_time.cmp(&rhs.execution_time))
    }

    /// Looks up the metrics recorded for one specific device, operation and
    /// input size.
    pub fn get_performance(&self, device: DeviceType, op: KernelOp, input_size: usize) -> Option<&KernelMetrics> {
        self.records.get(&(device, op, input_size))
    }
}
impl Default for PerformanceDatabase {
fn default() -> Self {
Self::new()
}
}
/// Size and heuristic characteristics of one kernel invocation, as computed
/// by `WorkloadProfile::analyze`.
///
/// The type is plain data, so it derives `Debug`, `Clone` and `PartialEq` to
/// make profiles loggable, copyable and comparable in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct WorkloadProfile {
    /// Sum of the element counts of all input tensors.
    pub total_elements: usize,
    /// `total_elements` multiplied by the byte size of the tensor element type.
    pub memory_requirement: usize,
    /// Heuristic estimate of arithmetic work per input element for the
    /// operation; larger values mean a more compute-heavy kernel.
    pub compute_intensity: f64,
    /// Heuristic estimate of how well the operation parallelizes; the
    /// estimator in this file produces values between 0.7 and 1.0.
    pub parallelization: f64,
}
impl WorkloadProfile {
    /// Builds the profile of running `op` on `inputs`.
    ///
    /// `total_elements` is the sum of `size()` over all inputs and
    /// `memory_requirement` is that count multiplied by `size_of::<T>()`.
    /// The two heuristic scores come from the private estimators below.
    pub fn analyze<T: Float>(inputs: &[&Tensor<T>], op: KernelOp) -> Self {
        // The target type of `sum` is spelled out so it does not depend on
        // inference from the struct fields further down.
        let total_elements: usize = inputs.iter().map(|tensor| tensor.size()).sum();
        let memory_requirement = total_elements * std::mem::size_of::<T>();
        Self {
            total_elements,
            memory_requirement,
            compute_intensity: Self::estimate_compute_intensity(op, inputs),
            parallelization: Self::estimate_parallelization(op, inputs),
        }
    }

    /// Estimates arithmetic work per input element for `op`.
    ///
    /// Element-wise arithmetic scores 1.0, `Conv2D` 10.0, `BatchNorm` 3.0 and
    /// any other operation 2.0. `MatMul` is estimated as `m * n * k` divided
    /// by the combined element count of the first two inputs, where `m` and
    /// `k` are the first input's dimensions 0 and 1 and `n` is the second
    /// input's dimension 1 (a missing dimension counts as 1).
    ///
    /// `MatMul` falls back to 1.0 when fewer than two inputs are given or when
    /// both inputs are empty, so the result is never NaN or infinite.
    fn estimate_compute_intensity<T: Float>(op: KernelOp, inputs: &[&Tensor<T>]) -> f64 {
        match op {
            KernelOp::Add | KernelOp::Sub | KernelOp::Mul | KernelOp::Div => 1.0,
            KernelOp::MatMul => match inputs {
                [lhs, rhs, ..] => {
                    let m = lhs.shape().get(0).copied().unwrap_or(1);
                    let k = lhs.shape().get(1).copied().unwrap_or(1);
                    let n = rhs.shape().get(1).copied().unwrap_or(1);
                    // Multiply in f64 so very large dimensions cannot overflow usize.
                    let work = m as f64 * n as f64 * k as f64;
                    let elements = (lhs.size() + rhs.size()) as f64;
                    if elements > 0.0 {
                        work / elements
                    } else {
                        // Empty operands: avoid 0/0 and use the neutral fallback.
                        1.0
                    }
                }
                _ => 1.0,
            },
            KernelOp::Conv2D => 10.0,
            KernelOp::BatchNorm => 3.0,
            _ => 2.0,
        }
    }

    /// Estimates how well `op` parallelizes, as a fixed score per operation.
    ///
    /// The inputs are currently not consulted; the parameter is kept so both
    /// estimators share the same call shape.
    fn estimate_parallelization<T: Float>(op: KernelOp, _inputs: &[&Tensor<T>]) -> f64 {
        match op {
            KernelOp::Add | KernelOp::Sub | KernelOp::Mul | KernelOp::Div => 1.0,
            KernelOp::MatMul => 0.9,
            KernelOp::Conv2D => 0.95,
            KernelOp::ReduceSum | KernelOp::ReduceMean => 0.7,
            KernelOp::BatchNorm => 0.8,
            _ => 0.8,
        }
    }
}
pub enum SelectionStrategy {
FastestDevice,
EnergyEfficient,
Balanced,
PreferDevice(DeviceType),
}
/// Picks which registered [`UnifiedKernelExecutor`] runs a kernel, according
/// to a [`SelectionStrategy`], and records the metrics of kernels it executes.
pub struct KernelSelector {
/// Registered backends, in the order they were passed to `add_executor`.
executors: Vec<Box<dyn UnifiedKernelExecutor>>,
/// Measurements written by `execute`; `performance_database` hands out
/// clones of this shared handle.
performance_db: Arc<RwLock<PerformanceDatabase>>,
/// Policy consulted by `select_executor`.
strategy: SelectionStrategy,
/// Per-device score written by `benchmark_devices` and read by
/// `get_benchmark_score`.
benchmark_cache: Arc<Mutex<HashMap<DeviceType, f64>>>,
}
impl KernelSelector {
pub fn new(strategy: SelectionStrategy) -> Self {
Self {
executors: Vec::new(),
performance_db: Arc::new(RwLock::new(PerformanceDatabase::new())),
strategy,
benchmark_cache: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn add_executor(&mut self, executor: Box<dyn UnifiedKernelExecutor>) {
self.executors.push(executor);
}
pub fn select_executor<T: Float>(
&self,
op: KernelOp,
inputs: &[&Tensor<T>],
params: &KernelParams
) -> RusTorchResult<&dyn UnifiedKernelExecutor> {
if self.executors.is_empty() {
return Err(RusTorchError::DeviceNotFound(0));
}
let workload = WorkloadProfile::analyze(inputs, op);
match self.strategy {
SelectionStrategy::FastestDevice => self.select_fastest(op, &workload),
SelectionStrategy::EnergyEfficient => self.select_energy_efficient(op, &workload),
SelectionStrategy::Balanced => self.select_balanced(op, &workload),
SelectionStrategy::PreferDevice(device_type) => self.select_preferred(device_type, op),
}
}
pub fn execute<T: Float + 'static + Send + Sync>(
&self,
op: KernelOp,
inputs: &[&Tensor<T>],
params: &KernelParams
) -> RusTorchResult<Tensor<T>> {
let executor = self.select_executor(op, inputs, params)?;
let result = executor.execute(op, inputs, params)?;
let metrics = executor.get_metrics();
let workload = WorkloadProfile::analyze(inputs, op);
if let Ok(mut db) = self.performance_db.write() {
db.record(executor.device_type(), op, workload.total_elements, metrics);
}
Ok(result)
}
fn select_fastest<T: Float>(
&self,
op: KernelOp,
workload: &WorkloadProfile
) -> RusTorchResult<&dyn UnifiedKernelExecutor> {
if let Ok(db) = self.performance_db.read() {
if let Some((best_device, _)) = db.get_best_performance(op, workload.total_elements) {
if let Some(executor) = self.find_executor(best_device) {
return Ok(executor);
}
}
}
self.select_by_heuristics(op, workload)
}
fn select_energy_efficient<T: Float>(
&self,
op: KernelOp,
workload: &WorkloadProfile
) -> RusTorchResult<&dyn UnifiedKernelExecutor> {
if workload.total_elements < 1000 {
if let Some(executor) = self.find_executor(DeviceType::Cpu) {
if executor.supports_operation(op) {
return Ok(executor);
}
}
}
if let Some(executor) = self.find_executor(DeviceType::Metal(0)) {
if executor.supports_operation(op) {
return Ok(executor);
}
}
self.select_fastest::<T>(op, workload)
}
fn select_balanced<T: Float>(
&self,
op: KernelOp,
workload: &WorkloadProfile
) -> RusTorchResult<&dyn UnifiedKernelExecutor> {
let score_threshold = workload.compute_intensity * workload.parallelization;
if score_threshold > 5.0 {
for device in &[DeviceType::Cuda(0), DeviceType::Metal(0), DeviceType::OpenCL(0)] {
if let Some(executor) = self.find_executor(*device) {
if executor.supports_operation(op) {
return Ok(executor);
}
}
}
}
if let Some(executor) = self.find_executor(DeviceType::Cpu) {
if executor.supports_operation(op) {
return Ok(executor);
}
}
self.select_any_available(op)
}
fn select_preferred(
&self,
preferred_device: DeviceType,
op: KernelOp
) -> RusTorchResult<&dyn UnifiedKernelExecutor> {
if let Some(executor) = self.find_executor(preferred_device) {
if executor.supports_operation(op) {
return Ok(executor);
}
}
self.select_any_available(op)
}
fn select_by_heuristics<T: Float>(
&self,
op: KernelOp,
workload: &WorkloadProfile
) -> RusTorchResult<&dyn UnifiedKernelExecutor> {
match op {
KernelOp::MatMul if workload.total_elements > 10000 => {
if let Some(executor) = self.find_executor(DeviceType::Cuda(0)) {
if executor.supports_operation(op) {
return Ok(executor);
}
}
},
KernelOp::Conv2D => {
if let Some(executor) = self.find_executor(DeviceType::Cuda(0)) {
if executor.supports_operation(op) {
return Ok(executor);
}
}
},
KernelOp::Add | KernelOp::Mul if workload.total_elements < 1000 => {
if let Some(executor) = self.find_executor(DeviceType::Cpu) {
if executor.supports_operation(op) {
return Ok(executor);
}
}
},
_ => {}
}
self.select_any_available(op)
}
fn select_any_available(&self, op: KernelOp) -> RusTorchResult<&dyn UnifiedKernelExecutor> {
for executor in &self.executors {
if executor.supports_operation(op) {
return Ok(executor.as_ref());
}
}
Err(RusTorchError::UnsupportedOperation(format!("No executor supports operation {:?}", op)))
}
fn find_executor(&self, device_type: DeviceType) -> Option<&dyn UnifiedKernelExecutor> {
self.executors
.iter()
.find(|e| e.device_type() == device_type)
.map(|e| e.as_ref())
}
pub fn benchmark_devices(&self) -> RusTorchResult<()> {
let benchmark_op = KernelOp::MatMul;
let size = 100;
let test_data: Vec<f32> = (0..size*size).map(|i| i as f32).collect();
let a = Tensor::from_vec(test_data.clone(), vec![size, size]);
let b = Tensor::from_vec(test_data, vec![size, size]);
let params = KernelParams::default();
let mut benchmark_results = HashMap::new();
for executor in &self.executors {
if executor.supports_operation(benchmark_op) {
let mut total_time = Duration::ZERO;
let iterations = 5;
for _ in 0..iterations {
let start = std::time::Instant::now();
let _result = executor.execute(benchmark_op, &[&a, &b], ¶ms);
total_time += start.elapsed();
}
let avg_time = total_time / iterations as u32;
let performance_score = 1000.0 / avg_time.as_millis() as f64;
benchmark_results.insert(executor.device_type(), performance_score);
}
}
if let Ok(mut cache) = self.benchmark_cache.lock() {
*cache = benchmark_results;
}
Ok(())
}
pub fn get_benchmark_score(&self, device: DeviceType) -> Option<f64> {
self.benchmark_cache
.lock()
.ok()
.and_then(|cache| cache.get(&device).copied())
}
pub fn set_strategy(&mut self, strategy: SelectionStrategy) {
self.strategy = strategy;
}
pub fn strategy(&self) -> SelectionStrategy {
self.strategy
}
pub fn performance_database(&self) -> Arc<RwLock<PerformanceDatabase>> {
Arc::clone(&self.performance_db)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gpu::unified_kernel::CpuFallbackExecutor;

    /// A recorded measurement can be read back under the same key.
    #[test]
    fn test_performance_database() {
        let mut db = PerformanceDatabase::new();
        let metrics = KernelMetrics {
            execution_time: Duration::from_millis(10),
            memory_bandwidth: 100.0,
            occupancy: 80.0,
            flops: 1000.0,
        };
        db.record(DeviceType::Cpu, KernelOp::Add, 1000, metrics.clone());
        let retrieved = db.get_performance(DeviceType::Cpu, KernelOp::Add, 1000);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().execution_time, Duration::from_millis(10));
    }

    /// Among devices recorded for the same operation and size, the one with
    /// the shortest execution time is reported as best.
    #[test]
    fn test_best_performance_prefers_fastest_device() {
        let mut db = PerformanceDatabase::new();
        let slow = KernelMetrics {
            execution_time: Duration::from_millis(10),
            memory_bandwidth: 100.0,
            occupancy: 80.0,
            flops: 1000.0,
        };
        let mut fast = slow.clone();
        fast.execution_time = Duration::from_millis(2);

        db.record(DeviceType::Cpu, KernelOp::Add, 1000, slow);
        db.record(DeviceType::Cuda(0), KernelOp::Add, 1000, fast);

        let (device, metrics) = db.get_best_performance(KernelOp::Add, 1000).unwrap();
        assert!(device == DeviceType::Cuda(0));
        assert_eq!(metrics.execution_time, Duration::from_millis(2));
        assert!(db.get_best_performance(KernelOp::Add, 2000).is_none());
    }

    /// Two 10x10 `f32` tensors added element-wise: 200 elements, 4 bytes each.
    #[test]
    fn test_workload_analysis() {
        let a = Tensor::from_vec(vec![1.0f32; 100], vec![10, 10]);
        let b = Tensor::from_vec(vec![1.0f32; 100], vec![10, 10]);
        let workload = WorkloadProfile::analyze(&[&a, &b], KernelOp::Add);
        assert_eq!(workload.total_elements, 200);
        assert_eq!(workload.memory_requirement, 200 * 4);
        assert_eq!(workload.compute_intensity, 1.0);
        assert_eq!(workload.parallelization, 1.0);
    }

    #[test]
    fn test_kernel_selector_creation() {
        let mut selector = KernelSelector::new(SelectionStrategy::FastestDevice);
        let executor = Box::new(CpuFallbackExecutor::new());
        selector.add_executor(executor);
        assert_eq!(selector.executors.len(), 1);
        assert_eq!(selector.strategy(), SelectionStrategy::FastestDevice);
    }

    #[test]
    fn test_selection_strategies() {
        let strategies = [
            SelectionStrategy::FastestDevice,
            SelectionStrategy::EnergyEfficient,
            SelectionStrategy::Balanced,
            SelectionStrategy::PreferDevice(DeviceType::Cpu),
        ];
        for strategy in &strategies {
            let selector = KernelSelector::new(*strategy);
            assert_eq!(selector.strategy(), *strategy);
        }
    }

    /// End-to-end: the selector picks the CPU fallback and the addition
    /// result is correct.
    #[test]
    fn test_cpu_execution_through_selector() {
        let mut selector = KernelSelector::new(SelectionStrategy::FastestDevice);
        selector.add_executor(Box::new(CpuFallbackExecutor::new()));
        let a = Tensor::from_vec(vec![1.0f32, 2.0, 3.0], vec![3]);
        let b = Tensor::from_vec(vec![4.0f32, 5.0, 6.0], vec![3]);
        let params = KernelParams::default();
        let result = selector.execute(KernelOp::Add, &[&a, &b], &params).unwrap();
        let expected = vec![5.0f32, 7.0, 9.0];
        assert_eq!(result.as_slice().unwrap(), &expected);
    }
}