use super::registry::{Kernel, OpRegistry, OpVersion, OP_REGISTRY};
use crate::{DType, Device, Result, TensorError};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Capability description for a single compute device.
///
/// Instances come from [`DeviceCapabilities::for_device`], which fills the
/// fields from a fixed per-backend table rather than querying hardware.
#[derive(Debug, Clone)]
pub struct DeviceCapabilities {
    /// Device these capabilities describe.
    pub device: Device,
    /// Available memory; `for_device` always reports `0`.
    pub available_memory: usize,
    /// Compute-capability pair for accelerator backends, `None` on the CPU.
    pub compute_capability: Option<(u32, u32)>,
    /// Largest workgroup size for accelerator backends, `None` on the CPU.
    pub max_workgroup_size: Option<usize>,
    /// Whether `DType::Float16` is usable on this device.
    pub supports_fp16: bool,
    /// Whether `DType::BFloat16` is usable on this device.
    pub supports_bf16: bool,
    /// Whether the device is flagged as having tensor cores.
    pub supports_tensor_cores: bool,
}

impl DeviceCapabilities {
    /// Builds the capability table entry for `device`.
    ///
    /// Every backend shares the CPU baseline (no reported memory, fp16 and
    /// bf16 enabled, no tensor cores); accelerator arms only override the
    /// fields that differ from it.
    pub fn for_device(device: Device) -> Self {
        let baseline = Self {
            device,
            available_memory: 0,
            compute_capability: None,
            max_workgroup_size: None,
            supports_fp16: true,
            supports_bf16: true,
            supports_tensor_cores: false,
        };

        match device {
            Device::Cpu => baseline,
            #[cfg(feature = "gpu")]
            Device::Gpu(_) => Self {
                compute_capability: Some((8, 0)),
                max_workgroup_size: Some(1024),
                supports_tensor_cores: true,
                ..baseline
            },
            #[cfg(feature = "rocm")]
            Device::Rocm(_) => Self {
                compute_capability: Some((9, 0)),
                max_workgroup_size: Some(1024),
                ..baseline
            },
        }
    }

    /// Returns whether `dtype` can be used on this device.
    ///
    /// Only the two half-precision types are gated by a capability flag;
    /// every other dtype is reported as supported.
    pub fn supports_dtype(&self, dtype: DType) -> bool {
        if matches!(dtype, DType::Float16) {
            return self.supports_fp16;
        }
        if matches!(dtype, DType::BFloat16) {
            return self.supports_bf16;
        }
        true
    }
}
/// Policy used by `EnhancedRegistry::find_optimal_device` to pick a device.
///
/// A new `EnhancedRegistry` starts with `Balanced`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KernelSelectionStrategy {
/// With the `gpu` feature: choose `Gpu(0)` when the data size exceeds
/// 10_000 and a GPU kernel is registered; otherwise the CPU.
Performance,
/// Always resolves to the CPU.
MemoryEfficient,
/// With the `gpu` feature: choose `Gpu(0)` when the data size exceeds
/// 100_000 and a GPU kernel is registered; otherwise the CPU.
Balanced,
/// Always resolves to the CPU.
Compatible,
}
/// Running execution statistics for one kernel.
#[derive(Debug, Clone, Default)]
pub struct KernelStats {
    /// Number of recorded runs, successful or not.
    pub execution_count: u64,
    /// Sum of the durations of all successful runs.
    pub total_time: Duration,
    /// Mean duration of a successful run; zero until the first success.
    pub avg_time: Duration,
    /// When the most recent successful run was recorded.
    pub last_execution: Option<Instant>,
    /// Number of runs recorded through [`KernelStats::record_success`].
    pub success_count: u64,
    /// Number of runs recorded through [`KernelStats::record_failure`].
    pub failure_count: u64,
}

impl KernelStats {
    /// Records a successful run that took `duration`.
    ///
    /// Updates the counters, the accumulated time, the mean and the
    /// last-execution timestamp.
    pub fn record_success(&mut self, duration: Duration) {
        self.execution_count += 1;
        self.success_count += 1;
        self.total_time += duration;
        // Failures carry no duration, so the mean is taken over successful
        // runs only; dividing by `execution_count` would pull the average
        // down by one "zero-length" sample per failure.
        self.avg_time = Self::mean(self.total_time, self.success_count);
        self.last_execution = Some(Instant::now());
    }

    /// Records a failed run. Timing fields are left untouched.
    pub fn record_failure(&mut self) {
        self.execution_count += 1;
        self.failure_count += 1;
    }

    /// Fraction of recorded runs that succeeded, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when nothing has been recorded yet.
    pub fn success_rate(&self) -> f64 {
        if self.execution_count == 0 {
            0.0
        } else {
            self.success_count as f64 / self.execution_count as f64
        }
    }

    /// Divides `total` by `samples` using the full 64-bit sample count.
    ///
    /// `Duration` only implements division by `u32`; narrowing a `u64`
    /// counter with `as u32` wraps to zero at 2^32 samples and panics, so the
    /// division is carried out on nanoseconds instead.
    fn mean(total: Duration, samples: u64) -> Duration {
        const NANOS_PER_SEC: u128 = 1_000_000_000;

        if samples == 0 {
            return Duration::ZERO;
        }
        let nanos = total.as_nanos() / u128::from(samples);
        // The quotient never exceeds `total`, so the seconds part fits in a
        // u64 and the remainder is always below one second.
        Duration::new(
            (nanos / NANOS_PER_SEC) as u64,
            (nanos % NANOS_PER_SEC) as u32,
        )
    }
}
/// Wrapper around the global operation registry that adds capability
/// caching, execution statistics, a device-selection strategy and a cache of
/// pre-resolved ("warmed") kernels. Each piece of mutable state sits behind
/// its own `Mutex`.
pub struct EnhancedRegistry {
// Underlying registry that kernel lookups are delegated to.
base: &'static OpRegistry,
// Capabilities per device, filled on first request.
device_capabilities: Mutex<HashMap<Device, DeviceCapabilities>>,
// Statistics keyed by the string "<op>_<device:?>_<dtype:?>".
kernel_stats: Mutex<HashMap<String, KernelStats>>,
// Strategy consulted by `find_optimal_device`.
strategy: Mutex<KernelSelectionStrategy>,
// Kernels resolved ahead of time, keyed by "<op>_warmed_<device:?>_<dtype:?>".
warmed_kernels: Mutex<HashMap<String, Arc<dyn Kernel>>>,
}
impl EnhancedRegistry {
pub fn new() -> Self {
Self {
base: &OP_REGISTRY,
device_capabilities: Mutex::new(HashMap::new()),
kernel_stats: Mutex::new(HashMap::new()),
strategy: Mutex::new(KernelSelectionStrategy::Balanced),
warmed_kernels: Mutex::new(HashMap::new()),
}
}
/// Returns the capabilities of `device`, computing and caching them on the
/// first request.
///
/// # Panics
/// Panics if the capability cache mutex is poisoned.
pub fn get_device_capabilities(&self, device: Device) -> DeviceCapabilities {
    let mut cache = self
        .device_capabilities
        .lock()
        .expect("lock should not be poisoned");

    if let Some(known) = cache.get(&device) {
        return known.clone();
    }

    let fresh = DeviceCapabilities::for_device(device);
    cache.insert(device, fresh.clone());
    fresh
}
/// Replaces the strategy used by `find_optimal_device`.
///
/// # Panics
/// Panics if the strategy mutex is poisoned.
pub fn set_strategy(&self, strategy: KernelSelectionStrategy) {
    let mut current = self.strategy.lock().expect("lock should not be poisoned");
    *current = strategy;
}
/// Resolves a kernel for `op_name`, falling back to the CPU when possible.
///
/// Lookup order:
/// 1. the kernel registered for `preferred_device` and `dtype`;
/// 2. if there is none and the device's capabilities reject `dtype`, an
///    unsupported-device error is returned;
/// 3. otherwise, when `preferred_device` is not the CPU, the CPU kernel for
///    the same operation and dtype.
///
/// # Errors
/// Returns an unsupported-device error in case 2, and a not-implemented
/// error when no kernel is found at all.
pub fn get_kernel_smart(
    &self,
    op_name: &str,
    preferred_device: Device,
    dtype: DType,
) -> Result<Arc<dyn Kernel>> {
    if let Some(exact) = self.base.get_kernel(op_name, preferred_device, dtype) {
        return Ok(exact);
    }

    let dtype_ok = self
        .get_device_capabilities(preferred_device)
        .supports_dtype(dtype);
    if !dtype_ok {
        return Err(TensorError::unsupported_device(
            op_name,
            &format!("{:?}", preferred_device),
            true,
        ));
    }

    // A CPU request that already missed above has nothing left to fall back to.
    let cpu_fallback = if preferred_device == Device::Cpu {
        None
    } else {
        self.base.get_kernel(op_name, Device::Cpu, dtype)
    };

    match cpu_fallback {
        Some(kernel) => Ok(kernel),
        None => Err(TensorError::not_implemented_simple(format!(
            "No kernel available for operation '{}' on {:?} with {:?}",
            op_name, preferred_device, dtype
        ))),
    }
}
/// Resolves each `(op, device, dtype)` triple against the base registry and
/// stores the kernels that exist in the warmed-kernel cache.
///
/// Triples with no registered kernel are skipped silently.
///
/// # Panics
/// Panics if the warmed-kernel mutex is poisoned.
pub fn warm_kernels(&self, ops: &[(String, Device, DType)]) {
    let mut cache = self
        .warmed_kernels
        .lock()
        .expect("lock should not be poisoned");

    let resolved = ops.iter().filter_map(|(op_name, device, dtype)| {
        let kernel = self.base.get_kernel(op_name, *device, *dtype)?;
        let cache_key = format!("{}_{}_{:?}_{:?}", op_name, "warmed", device, dtype);
        Some((cache_key, kernel))
    });
    cache.extend(resolved);
}
/// Returns the kernel previously cached by `warm_kernels` for this
/// operation, device and dtype, or `None` if it was never warmed.
///
/// # Panics
/// Panics if the warmed-kernel mutex is poisoned.
pub fn get_warmed_kernel(
    &self,
    op_name: &str,
    device: Device,
    dtype: DType,
) -> Option<Arc<dyn Kernel>> {
    let lookup_key = format!("{}_{}_{:?}_{:?}", op_name, "warmed", device, dtype);
    let cache = self
        .warmed_kernels
        .lock()
        .expect("lock should not be poisoned");
    match cache.get(&lookup_key) {
        Some(kernel) => Some(Arc::clone(kernel)),
        None => None,
    }
}
pub fn record_execution(
&self,
op_name: &str,
device: Device,
dtype: DType,
duration: Duration,
success: bool,
) {
let key = format!("{}_{:?}_{:?}", op_name, device, dtype);
let mut stats = self
.kernel_stats
.lock()
.expect("lock should not be poisoned");
let entry = stats.entry(key).or_insert_with(KernelStats::default);
if success {
entry.record_success(duration);
} else {
entry.record_failure();
}
}
pub fn get_kernel_stats(&self, op_name: &str, device: Device, dtype: DType) -> KernelStats {
let key = format!("{}_{:?}_{:?}", op_name, device, dtype);
let stats = self
.kernel_stats
.lock()
.expect("lock should not be poisoned");
stats.get(&key).cloned().unwrap_or_default()
}
/// Returns a snapshot copy of the whole statistics table.
///
/// # Panics
/// Panics if the statistics mutex is poisoned.
pub fn get_all_stats(&self) -> HashMap<String, KernelStats> {
    let table = self
        .kernel_stats
        .lock()
        .expect("lock should not be poisoned");
    (*table).clone()
}
/// Chooses the device to run `op_name` on, according to the current
/// selection strategy.
///
/// `MemoryEfficient` and `Compatible` always pick the CPU. `Performance` and
/// `Balanced` pick `Gpu(0)` when `data_size` exceeds 10_000 and 100_000
/// respectively, the `gpu` feature is enabled and a GPU kernel is registered
/// for the op and dtype; otherwise they pick the CPU as well.
///
/// # Panics
/// Panics if the strategy mutex is poisoned.
pub fn find_optimal_device(&self, op_name: &str, dtype: DType, data_size: usize) -> Device {
    let active = *self.strategy.lock().expect("lock should not be poisoned");
    match active {
        KernelSelectionStrategy::MemoryEfficient | KernelSelectionStrategy::Compatible => {
            Device::Cpu
        }
        KernelSelectionStrategy::Performance => {
            self.gpu_when_larger_than(op_name, dtype, data_size, 10_000)
        }
        KernelSelectionStrategy::Balanced => {
            self.gpu_when_larger_than(op_name, dtype, data_size, 100_000)
        }
    }
}

/// Returns `Gpu(0)` when `data_size` is strictly above `min_size` and a GPU
/// kernel exists for the op and dtype; the CPU in every other case.
// Without the `gpu` feature the body collapses to `Device::Cpu`, which leaves
// the parameters unused.
#[cfg_attr(not(feature = "gpu"), allow(unused_variables))]
fn gpu_when_larger_than(
    &self,
    op_name: &str,
    dtype: DType,
    data_size: usize,
    min_size: usize,
) -> Device {
    #[cfg(feature = "gpu")]
    if data_size > min_size
        && self
            .base
            .get_kernel(op_name, Device::Gpu(0), dtype)
            .is_some()
    {
        return Device::Gpu(0);
    }
    Device::Cpu
}
/// Scans the statistics table and returns human-readable tuning hints.
///
/// For each entry, up to three hints are produced, in this order:
/// - more than 100 executions: suggest warming the kernel;
/// - more than 10 failures and a success rate below 50%: suggest a CPU
///   fallback;
/// - key mentions `Cpu` and the average time exceeds 100 ms: suggest GPU
///   acceleration.
///
/// Entries are visited in the hash map's iteration order, so the order of
/// hints across kernels is unspecified.
///
/// # Panics
/// Panics if the statistics mutex is poisoned.
pub fn suggest_optimizations(&self) -> Vec<String> {
    let slow_on_cpu = Duration::from_millis(100);
    let mut hints = Vec::new();
    let table = self
        .kernel_stats
        .lock()
        .expect("lock should not be poisoned");

    for (kernel_key, entry) in table.iter() {
        // NOTE(review): statistics keys are built by `record_execution` as
        // "<op>_<device>_<dtype>" and never carry the "warmed" marker that
        // `warm_kernels` puts in its own cache keys, so this filter only
        // excludes ops whose name itself contains "warmed" — confirm intent.
        let is_hot = entry.execution_count > 100;
        if is_hot && !kernel_key.contains("warmed") {
            hints.push(format!(
                "Consider warming kernel '{}' (executed {} times)",
                kernel_key, entry.execution_count
            ));
        }

        let success_rate = entry.success_rate();
        if entry.failure_count > 10 && success_rate < 0.5 {
            hints.push(format!(
                "Kernel '{}' has high failure rate ({:.1}%), consider using CPU fallback",
                kernel_key,
                (1.0 - success_rate) * 100.0
            ));
        }

        if kernel_key.contains("Cpu") && entry.avg_time > slow_on_cpu {
            hints.push(format!(
                "Kernel '{}' is slow on CPU (avg {:.2}ms), consider GPU acceleration",
                kernel_key,
                entry.avg_time.as_secs_f64() * 1000.0
            ));
        }
    }

    hints
}
/// Discards every recorded statistic.
///
/// # Panics
/// Panics if the statistics mutex is poisoned.
pub fn reset_statistics(&self) {
    let mut table = self
        .kernel_stats
        .lock()
        .expect("lock should not be poisoned");
    table.clear();
}
/// Builds a [`PerformanceReport`] from the current statistics.
///
/// The report contains the global execution/success/failure totals, the ten
/// entries with the highest average time, the ten most executed entries and
/// the hints from `suggest_optimizations`.
///
/// The totals and rankings are computed from one locked snapshot; the hints
/// are computed right after the lock is released, so executions recorded by
/// other threads in between may be reflected in the hints only.
///
/// # Panics
/// Panics if the statistics mutex is poisoned.
pub fn generate_performance_report(&self) -> PerformanceReport {
    let stats = self
        .kernel_stats
        .lock()
        .expect("lock should not be poisoned");

    let total_executions: u64 = stats.values().map(|s| s.execution_count).sum();
    let total_successes: u64 = stats.values().map(|s| s.success_count).sum();
    let total_failures: u64 = stats.values().map(|s| s.failure_count).sum();

    let mut slowest_kernels: Vec<_> =
        stats.iter().map(|(k, s)| (k.clone(), s.avg_time)).collect();
    slowest_kernels.sort_by_key(|entry| std::cmp::Reverse(entry.1));
    slowest_kernels.truncate(10);

    let mut most_used: Vec<_> = stats
        .iter()
        .map(|(k, s)| (k.clone(), s.execution_count))
        .collect();
    most_used.sort_by_key(|entry| std::cmp::Reverse(entry.1));
    most_used.truncate(10);

    // `suggest_optimizations` locks `kernel_stats` itself and
    // `std::sync::Mutex` is not reentrant: calling it while this guard is
    // alive would deadlock (or panic) on every report. Everything gathered
    // above is owned data, so the guard can be released first.
    drop(stats);
    let optimization_suggestions = self.suggest_optimizations();

    let overall_success_rate = if total_executions > 0 {
        total_successes as f64 / total_executions as f64
    } else {
        0.0
    };

    PerformanceReport {
        total_executions,
        total_successes,
        total_failures,
        overall_success_rate,
        slowest_kernels,
        most_used_kernels: most_used,
        optimization_suggestions,
    }
}
}
/// Aggregated view of the registry's execution statistics.
#[derive(Debug, Clone)]
pub struct PerformanceReport {
    /// Number of recorded runs across all kernels.
    pub total_executions: u64,
    /// Number of successful runs across all kernels.
    pub total_successes: u64,
    /// Number of failed runs across all kernels.
    pub total_failures: u64,
    /// `total_successes / total_executions` as a fraction.
    pub overall_success_rate: f64,
    /// `(kernel key, average time)` pairs, slowest first.
    pub slowest_kernels: Vec<(String, Duration)>,
    /// `(kernel key, execution count)` pairs, most used first.
    pub most_used_kernels: Vec<(String, u64)>,
    /// Free-form tuning hints.
    pub optimization_suggestions: Vec<String>,
}

impl PerformanceReport {
    /// Writes the report to standard output in a fixed, human-readable
    /// layout. The suggestions section is omitted when there are none.
    pub fn print(&self) {
        println!("=== Registry Performance Report ===");

        println!("\nOverall Statistics:");
        println!(" Total Executions: {}", self.total_executions);
        println!(" Successes: {}", self.total_successes);
        println!(" Failures: {}", self.total_failures);
        let success_pct = self.overall_success_rate * 100.0;
        println!(" Success Rate: {:.2}%", success_pct);

        println!("\nTop 10 Slowest Kernels:");
        for (rank, (kernel, avg)) in (1usize..).zip(&self.slowest_kernels) {
            let avg_ms = avg.as_secs_f64() * 1000.0;
            println!(" {}: {} ({:.2}ms avg)", rank, kernel, avg_ms);
        }

        println!("\nTop 10 Most Used Kernels:");
        for (rank, (kernel, runs)) in (1usize..).zip(&self.most_used_kernels) {
            println!(" {}: {} ({} executions)", rank, kernel, runs);
        }

        let tips = &self.optimization_suggestions;
        if !tips.is_empty() {
            println!("\n💡 Optimization Suggestions:");
            tips.iter().for_each(|tip| println!(" • {}", tip));
        }

        println!("\n===================================");
    }
}
impl Default for EnhancedRegistry {
fn default() -> Self {
Self::new()
}
}
// Process-wide registry instance used by the free functions below; created
// with `EnhancedRegistry::new()` by `lazy_static` on first access.
lazy_static::lazy_static! {
pub static ref ENHANCED_REGISTRY: EnhancedRegistry = EnhancedRegistry::new();
}
pub fn get_kernel_smart(
op_name: &str,
preferred_device: Device,
dtype: DType,
) -> Result<Arc<dyn Kernel>> {
ENHANCED_REGISTRY.get_kernel_smart(op_name, preferred_device, dtype)
}
pub fn warm_kernels(ops: &[(String, Device, DType)]) {
ENHANCED_REGISTRY.warm_kernels(ops);
}
pub fn generate_performance_report() -> PerformanceReport {
ENHANCED_REGISTRY.generate_performance_report()
}
/// Generates the global performance report and prints it to standard output.
pub fn print_performance_report() {
    let report = generate_performance_report();
    report.print();
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The CPU entry advertises fp16 but no tensor cores.
    #[test]
    fn test_device_capabilities() {
        let caps = DeviceCapabilities::for_device(Device::Cpu);
        assert!(caps.supports_fp16);
        assert!(!caps.supports_tensor_cores);
    }

    /// Two successes and one failure are counted separately and give a
    /// success rate of about two thirds.
    #[test]
    fn test_kernel_stats() {
        let mut stats = KernelStats::default();
        for millis in [10, 20] {
            stats.record_success(Duration::from_millis(millis));
        }
        stats.record_failure();

        assert_eq!(stats.execution_count, 3);
        assert_eq!(stats.success_count, 2);
        assert_eq!(stats.failure_count, 1);

        let rate_error = (stats.success_rate() - 0.666).abs();
        assert!(rate_error < 0.01);
    }

    /// Under the `Performance` strategy a large input resolves to a valid
    /// device: CPU only without the `gpu` feature, CPU or GPU with it.
    #[test]
    fn test_enhanced_registry() {
        let registry = EnhancedRegistry::new();
        registry.set_strategy(KernelSelectionStrategy::Performance);

        let chosen = registry.find_optimal_device("matmul", DType::Float32, 1_000_000);

        #[cfg(feature = "gpu")]
        assert!(matches!(chosen, Device::Cpu | Device::Gpu(_)));
        #[cfg(not(feature = "gpu"))]
        assert!(matches!(chosen, Device::Cpu));
    }
}