use super::types::{
AttrValue, BatchOperation, Kernel, KernelKey, OpDef, OpKey, OpRegistry, OpVersion,
RegistryMetrics, UltraKernelScheduler,
};
use crate::{DType, Device, Result, TensorError};
use rayon::prelude::*;
use scirs2_core::metrics::{Counter, Histogram, Timer};
use std::any::Any;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
impl OpRegistry {
/// Creates an empty registry with freshly initialized metrics and scheduler state.
pub fn new() -> Self {
    Self {
        ops: RwLock::new(HashMap::new()),
        kernels: RwLock::new(HashMap::new()),
        latest_versions: RwLock::new(HashMap::new()),
        op_cache: RwLock::new(HashMap::new()),
        kernel_cache: RwLock::new(HashMap::new()),
        metrics: RegistryMetrics {
            op_lookups: Counter::new("registry.op_lookups".to_string()),
            kernel_executions: Counter::new("registry.kernel_executions".to_string()),
            cache_hit_ratio: Histogram::new("registry.cache_hit_ratio".to_string()),
            execution_timer: Timer::new("registry.execution_time".to_string()),
            batch_operations: Counter::new("registry.batch_operations".to_string()),
            simd_accelerated_ops: Counter::new("registry.simd_accelerated".to_string()),
        },
        batch_queue: RwLock::new(Vec::new()),
        scheduler: RwLock::new(UltraKernelScheduler {
            execution_history: HashMap::new(),
            cpu_utilization: AtomicU64::new(0),
            gpu_utilization: AtomicU64::new(0),
            optimal_batch_sizes: HashMap::new(),
            hot_operations: HashMap::new(),
        }),
    }
}
/// Registers a new operation definition.
///
/// Fails if the exact (name, version) pair is already present. When the new
/// version is newer than anything previously seen for this name, the
/// latest-version map is advanced and any cached `OpDef` for the name is
/// invalidated, so subsequent `get_op` calls observe the new latest version
/// (previously the name-keyed cache could serve a stale older "latest").
pub fn register_op(&self, op_def: OpDef) -> Result<()> {
    let op_key = OpKey {
        name: op_def.name.clone(),
        version: op_def.version.clone(),
    };
    let mut ops = self.ops.write().expect("write lock should not be poisoned");
    let mut latest_versions = self
        .latest_versions
        .write()
        .expect("write lock should not be poisoned");
    if ops.contains_key(&op_key) {
        return Err(TensorError::invalid_argument(format!(
            "Operation '{}' version {} already registered",
            op_def.name, op_def.version
        )));
    }
    let is_newer = latest_versions
        .get(&op_def.name)
        .map(|existing| op_def.version > *existing)
        .unwrap_or(true);
    if is_newer {
        latest_versions.insert(op_def.name.clone(), op_def.version.clone());
        // Drop any stale cached definition: get_op caches by name only, so a
        // previously cached older "latest" would otherwise be served forever.
        // op_cache is never held while waiting on ops/latest_versions
        // elsewhere in this file, so acquiring it here cannot deadlock.
        self.op_cache
            .write()
            .expect("write lock should not be poisoned")
            .remove(&op_def.name);
    }
    ops.insert(op_key, op_def);
    Ok(())
}
/// Registers `kernel` against the latest known version of `op_name`.
///
/// Returns an error when the operation has never been registered.
pub fn register_kernel(
    &self,
    op_name: &str,
    device: Device,
    dtype: DType,
    kernel: Arc<dyn Kernel>,
) -> Result<()> {
    let guard = self
        .latest_versions
        .read()
        .expect("read lock should not be poisoned");
    let version = guard.get(op_name).cloned().ok_or_else(|| {
        TensorError::invalid_argument(format!("Operation '{op_name}' not registered"))
    })?;
    // Release the read lock before delegating; the callee takes its own locks.
    drop(guard);
    self.register_kernel_version(op_name, &version, device, dtype, kernel)
}
/// Registers `kernel` for a specific operation version / device / dtype triple.
///
/// Errors when the operation version is unknown, or when a kernel is already
/// bound to the same (op, version, device, dtype) key.
pub fn register_kernel_version(
    &self,
    op_name: &str,
    version: &OpVersion,
    device: Device,
    dtype: DType,
    kernel: Arc<dyn Kernel>,
) -> Result<()> {
    // The operation definition must exist before a kernel may be attached.
    let op_key = OpKey {
        name: op_name.to_string(),
        version: version.clone(),
    };
    let known = self
        .ops
        .read()
        .expect("read lock should not be poisoned")
        .contains_key(&op_key);
    if !known {
        return Err(TensorError::invalid_argument(format!(
            "Operation '{op_name}' version {version} not registered"
        )));
    }
    let key = KernelKey {
        op: op_name.to_string(),
        version: version.clone(),
        device,
        dtype,
    };
    let mut kernels = self
        .kernels
        .write()
        .expect("write lock should not be poisoned");
    if kernels.contains_key(&key) {
        return Err(TensorError::invalid_argument(format!(
            "Kernel for '{op_name}' v{version} on {device:?} with {dtype:?} already registered"
        )));
    }
    kernels.insert(key, kernel);
    Ok(())
}
/// Looks up the latest version of operation `name`, consulting a by-name
/// cache first.
///
/// Records lookup-count and cache-hit metrics. Returns `None` when the name
/// was never registered.
pub fn get_op(&self, name: &str) -> Option<OpDef> {
    self.metrics.op_lookups.inc();
    let _timer = self.metrics.execution_timer.start();
    // Fast path: the cache is keyed by name only and holds the latest OpDef.
    let cached = self
        .op_cache
        .read()
        .expect("read lock should not be poisoned")
        .get(name)
        .map(|arc| (**arc).clone());
    if let Some(op) = cached {
        self.metrics.cache_hit_ratio.observe(1.0);
        return Some(op);
    }
    self.metrics.cache_hit_ratio.observe(0.0);
    // Slow path: resolve the latest version, fetch its definition, and cache it.
    let latest = self
        .latest_versions
        .read()
        .expect("read lock should not be poisoned")
        .get(name)
        .cloned()?;
    let op_def = self.get_op_version(name, &latest)?;
    self.op_cache
        .write()
        .expect("write lock should not be poisoned")
        .insert(name.to_string(), Arc::new(op_def.clone()));
    Some(op_def)
}
/// Fetches the definition of operation `name` at exactly `version`.
pub fn get_op_version(&self, name: &str, version: &OpVersion) -> Option<OpDef> {
    let key = OpKey {
        name: name.to_string(),
        version: version.clone(),
    };
    self.ops
        .read()
        .expect("read lock should not be poisoned")
        .get(&key)
        .cloned()
}
/// Returns the highest-versioned definition of `name` that is compatible
/// with `required_version`, or `None` when no compatible version exists.
pub fn get_op_compatible(&self, name: &str, required_version: &OpVersion) -> Option<OpDef> {
    let ops = self.ops.read().expect("read lock should not be poisoned");
    // Single O(n) pass for the maximum compatible version — no need to
    // collect and sort all candidates just to take the first. Keys are
    // unique per (name, version), so there are no ties to break.
    ops.keys()
        .filter(|key| key.name == name && key.version.is_compatible_with(required_version))
        .max_by(|a, b| a.version.cmp(&b.version))
        .and_then(|key| ops.get(key).cloned())
}
/// Resolves the kernel for the latest version of `op_name` on the given
/// device/dtype, using a string-keyed cache for repeat lookups.
///
/// NOTE(review): the cache key embeds the literal "latest" rather than the
/// resolved version, and entries are never invalidated when newer kernels
/// are registered — a cached kernel can go stale. Confirm whether
/// registration after the first lookup needs to be supported.
pub fn get_kernel(
    &self,
    op_name: &str,
    device: Device,
    dtype: DType,
) -> Option<Arc<dyn Kernel>> {
    self.metrics.kernel_executions.inc();
    let _timer = self.metrics.execution_timer.start();
    let cache_key = format!("{}_{}_{:?}_{:?}", op_name, "latest", device, dtype);
    // Fast path: cached kernel.
    let hit = self
        .kernel_cache
        .read()
        .expect("read lock should not be poisoned")
        .get(&cache_key)
        .cloned();
    if let Some(kernel) = hit {
        self.metrics.cache_hit_ratio.observe(1.0);
        self.track_hot_operation(op_name);
        return Some(kernel);
    }
    self.metrics.cache_hit_ratio.observe(0.0);
    // Slow path: resolve the latest version, look the kernel up, and cache it.
    let latest = self
        .latest_versions
        .read()
        .expect("read lock should not be poisoned")
        .get(op_name)
        .cloned()?;
    let kernel = self.get_kernel_version(op_name, &latest, device, dtype)?;
    self.kernel_cache
        .write()
        .expect("write lock should not be poisoned")
        .insert(cache_key, kernel.clone());
    Some(kernel)
}
/// Fetches the kernel registered for exactly (op, version, device, dtype).
pub fn get_kernel_version(
    &self,
    op_name: &str,
    version: &OpVersion,
    device: Device,
    dtype: DType,
) -> Option<Arc<dyn Kernel>> {
    let key = KernelKey {
        op: op_name.to_string(),
        version: version.clone(),
        device,
        dtype,
    };
    self.kernels
        .read()
        .expect("read lock should not be poisoned")
        .get(&key)
        .cloned()
}
/// Returns the highest-versioned kernel for `op_name` on (device, dtype)
/// that is compatible with `required_version`.
pub fn get_kernel_compatible(
    &self,
    op_name: &str,
    required_version: &OpVersion,
    device: Device,
    dtype: DType,
) -> Option<Arc<dyn Kernel>> {
    let kernels = self
        .kernels
        .read()
        .expect("read lock should not be poisoned");
    // One linear scan for the maximum compatible version instead of
    // collecting and sorting every candidate. Kernel keys are unique per
    // (op, version, device, dtype), so there are no ties.
    kernels
        .keys()
        .filter(|key| key.op == op_name && key.device == device && key.dtype == dtype)
        .filter(|key| key.version.is_compatible_with(required_version))
        .max_by(|a, b| a.version.cmp(&b.version))
        .and_then(|key| kernels.get(key).cloned())
}
/// Lists the names of all registered operations (unordered).
pub fn list_ops(&self) -> Vec<String> {
    self.latest_versions
        .read()
        .expect("read lock should not be poisoned")
        .keys()
        .cloned()
        .collect()
}
/// Lists every registered version of operation `name`, sorted ascending.
pub fn list_op_versions(&self, name: &str) -> Vec<OpVersion> {
    let ops = self.ops.read().expect("read lock should not be poisoned");
    let mut versions: Vec<OpVersion> = ops
        .keys()
        .filter_map(|key| (key.name == name).then(|| key.version.clone()))
        .collect();
    versions.sort();
    versions
}
/// Returns the newest registered version of `name`, if any.
pub fn get_latest_version(&self, name: &str) -> Option<OpVersion> {
    self.latest_versions
        .read()
        .expect("read lock should not be poisoned")
        .get(name)
        .cloned()
}
/// Executes a batch of operations, highest priority first (ties broken by
/// cheapest estimated cost), processing the sorted queue in parallel chunks
/// of 32 via rayon. Fails fast on the first operation error.
pub fn execute_batch_operations(&self, operations: Vec<BatchOperation>) -> Result<Vec<String>> {
    let _timer = self.metrics.execution_timer.start();
    self.metrics.batch_operations.add(operations.len() as u64);
    // Order: descending priority, then ascending estimated cost. NaN costs
    // compare as equal rather than panicking.
    let mut queue = operations;
    queue.sort_by(|a, b| {
        b.priority.cmp(&a.priority).then_with(|| {
            a.estimated_cost
                .partial_cmp(&b.estimated_cost)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
    });
    let chunk_results: Result<Vec<Vec<String>>> = queue
        .par_chunks(32)
        .map(|chunk| {
            chunk
                .iter()
                .map(|op| self.execute_single_batch_operation(op))
                .collect()
        })
        .collect();
    chunk_results.map(|chunks| chunks.into_iter().flatten().collect())
}
/// Placeholder executor for a single batch operation.
///
/// NOTE(review): this does not dispatch to any kernel — it only formats a
/// status string. Real execution appears to be pending or wired elsewhere;
/// confirm before relying on batch results.
fn execute_single_batch_operation(&self, operation: &BatchOperation) -> Result<String> {
    Ok(format!("Executed batch operation: {}", operation.op_name))
}
/// Bumps the hot-operation counter for `op_name`.
///
/// Fast path takes only a read lock: the per-op counters are `AtomicU64`,
/// so an existing entry can be incremented without exclusive access. The
/// previous code took the scheduler write lock on every cache hit, which
/// serialized all hot-path lookups; the write lock is now needed only the
/// first time an operation is seen.
fn track_hot_operation(&self, op_name: &str) {
    {
        let scheduler = self
            .scheduler
            .read()
            .expect("read lock should not be poisoned");
        if let Some(counter) = scheduler.hot_operations.get(op_name) {
            counter.fetch_add(1, Ordering::Relaxed);
            return;
        }
    }
    // First sighting: upgrade to a write lock. entry() handles the race
    // where another thread inserted between the two lock acquisitions.
    let mut scheduler = self
        .scheduler
        .write()
        .expect("write lock should not be poisoned");
    scheduler
        .hot_operations
        .entry(op_name.to_string())
        .or_insert_with(|| AtomicU64::new(0))
        .fetch_add(1, Ordering::Relaxed);
}
/// Maps each op name to a SIMD-tagged status string in parallel and records
/// the SIMD-usage metric.
pub fn simd_execute_vectorized_ops(&self, ops: &[String]) -> Result<Vec<String>> {
    self.metrics.simd_accelerated_ops.add(ops.len() as u64);
    Ok(ops
        .par_iter()
        .map(|name| format!("SIMD-accelerated: {name}"))
        .collect())
}
/// Produces a point-in-time snapshot of registry performance counters.
///
/// The scheduler read guard is dropped before building recommendations:
/// `generate_optimization_recommendations` re-acquires the same
/// `std::sync::RwLock` for read, and recursive read acquisition on one
/// thread can deadlock when a writer is queued in between. The previous
/// code held the guard across that call.
pub fn get_performance_analytics(&self) -> RegistryAnalytics {
    let hot_operations: HashMap<String, u64> = {
        let scheduler = self
            .scheduler
            .read()
            .expect("read lock should not be poisoned");
        scheduler
            .hot_operations
            .iter()
            .map(|(name, count)| (name.clone(), count.load(Ordering::Relaxed)))
            .collect()
    }; // guard dropped here, before generate_optimization_recommendations runs
    RegistryAnalytics {
        total_op_lookups: self.metrics.op_lookups.get(),
        total_kernel_executions: self.metrics.kernel_executions.get(),
        cache_efficiency: self.calculate_cache_efficiency(),
        hot_operations,
        recommended_optimizations: self.generate_optimization_recommendations(),
        simd_acceleration_usage: self.metrics.simd_accelerated_ops.get(),
    }
}
/// Returns the cache efficiency estimate in [0, 1].
///
/// NOTE(review): hard-coded placeholder — this does not consult
/// `metrics.cache_hit_ratio`, so analytics always report a constant 85%.
/// Replace with a real computation once the histogram exposes an aggregate.
fn calculate_cache_efficiency(&self) -> f64 {
    0.85 }
/// Builds human-readable tuning suggestions from hot-operation counts and
/// the current cache-efficiency estimate.
fn generate_optimization_recommendations(&self) -> Vec<String> {
    let scheduler = self
        .scheduler
        .read()
        .expect("read lock should not be poisoned");
    // Flag operations executed often enough to justify SIMD work.
    let mut recommendations: Vec<String> = scheduler
        .hot_operations
        .iter()
        .filter_map(|(op_name, count)| {
            let executed = count.load(Ordering::Relaxed);
            (executed > 1000).then(|| {
                format!(
                    "Consider SIMD optimization for hot operation '{}' (executed {} times)",
                    op_name, executed
                )
            })
        })
        .collect();
    let cache_efficiency = self.calculate_cache_efficiency();
    if cache_efficiency < 0.8 {
        recommendations.push(format!(
            "Low cache efficiency ({:.2}%). Consider increasing cache size or improving locality.",
            cache_efficiency * 100.0
        ));
    }
    recommendations
}
/// Empties both lookup caches; subsequent lookups repopulate them lazily.
/// Each write guard is a statement temporary, so the two locks are never
/// held at the same time.
pub fn clear_caches(&self) {
    self.op_cache
        .write()
        .expect("write lock should not be poisoned")
        .clear();
    self.kernel_cache
        .write()
        .expect("write lock should not be poisoned")
        .clear();
}
}
/// Point-in-time snapshot of registry performance, as returned by
/// `OpRegistry::get_performance_analytics`.
#[derive(Debug, Clone)]
pub struct RegistryAnalytics {
    /// Total op-definition lookups recorded by the metrics counter.
    pub total_op_lookups: u64,
    /// Total kernel resolutions recorded by the metrics counter.
    pub total_kernel_executions: u64,
    /// Cache efficiency estimate in [0, 1].
    pub cache_efficiency: f64,
    /// Per-operation execution counts from the scheduler's hot-op tracker.
    pub hot_operations: HashMap<String, u64>,
    /// Human-readable tuning suggestions.
    pub recommended_optimizations: Vec<String>,
    /// Count of operations routed through the SIMD path.
    pub simd_acceleration_usage: u64,
}
impl UltraKernelScheduler {
/// Appends an observed execution time to the per-op history, trimming the
/// buffer so it never grows unbounded (at most 100 samples; on overflow the
/// oldest 50 are discarded). Time units are whatever the caller records —
/// TODO confirm (seconds vs. milliseconds is not established here).
#[allow(dead_code)]
pub(super) fn record_execution(&mut self, op_name: &str, execution_time: f64) {
    // Single map lookup via the entry API — the previous code looked the
    // key up twice (entry().push() and then get_mut() to trim).
    let history = self
        .execution_history
        .entry(op_name.to_string())
        .or_default();
    history.push(execution_time);
    if history.len() > 100 {
        // Keep only the most recent 50 samples.
        history.drain(0..50);
    }
}
/// Predicts the next execution time for `op_name` via an exponential moving
/// average (alpha = 0.3) over recorded history, oldest sample first.
/// Returns 1.0 when no history exists for the operation.
#[allow(dead_code)]
pub(super) fn predict_execution_time(&self, op_name: &str) -> f64 {
    const ALPHA: f64 = 0.3;
    match self.execution_history.get(op_name) {
        Some(history) if !history.is_empty() => history[1..]
            .iter()
            .fold(history[0], |ema, &time| ALPHA * time + (1.0 - ALPHA) * ema),
        _ => 1.0,
    }
}
/// Returns the learned batch size for `op_name`, falling back to 32 when
/// nothing has been recorded for it.
#[allow(dead_code)]
pub(super) fn get_optimal_batch_size(&self, op_name: &str) -> usize {
    self.optimal_batch_sizes.get(op_name).copied().unwrap_or(32)
}
/// Stores the latest CPU utilization sample (units/scale set by the caller —
/// TODO confirm whether this is a percentage or a raw counter).
#[allow(dead_code)]
pub(super) fn update_cpu_utilization(&self, utilization: u64) {
    self.cpu_utilization.store(utilization, Ordering::Relaxed);
}
/// Stores the latest GPU utilization sample (same unit caveat as CPU).
#[allow(dead_code)]
pub(super) fn update_gpu_utilization(&self, utilization: u64) {
    self.gpu_utilization.store(utilization, Ordering::Relaxed);
}
}
/// `Default` simply delegates to [`OpRegistry::new`], allowing construction
/// via `OpRegistry::default()` and struct-update syntax.
impl Default for OpRegistry {
    fn default() -> Self {
        Self::new()
    }
}
/// Extended kernel interface with optional SIMD and GPU execution paths.
///
/// The default `compute_simd` / `compute_gpu` implementations fall back to
/// the scalar `compute`, so implementors only override the accelerated
/// paths they actually support — and should flip the corresponding
/// `supports_*` flag when they do.
pub trait UltraKernel: Send + Sync {
    /// SIMD execution path; defaults to the scalar `compute`.
    fn compute_simd(
        &self,
        inputs: &[&dyn Any],
        attrs: &HashMap<String, AttrValue>,
    ) -> Result<Vec<Box<dyn Any>>> {
        self.compute(inputs, attrs)
    }
    /// GPU execution path; defaults to the scalar `compute`.
    fn compute_gpu(
        &self,
        inputs: &[&dyn Any],
        attrs: &HashMap<String, AttrValue>,
    ) -> Result<Vec<Box<dyn Any>>> {
        self.compute(inputs, attrs)
    }
    /// Required scalar execution path.
    fn compute(
        &self,
        inputs: &[&dyn Any],
        attrs: &HashMap<String, AttrValue>,
    ) -> Result<Vec<Box<dyn Any>>>;
    /// Device this kernel executes on.
    fn device(&self) -> Device;
    /// Element type this kernel operates on.
    fn dtype(&self) -> DType;
    /// Whether `compute_simd` is genuinely accelerated (default: no).
    fn supports_simd(&self) -> bool {
        false
    }
    /// Whether `compute_gpu` is genuinely accelerated (default: no).
    fn supports_gpu(&self) -> bool {
        false
    }
    /// Rough cost model: total input size scaled by 1e-6. The unit of
    /// `input_sizes` (elements vs. bytes) is set by callers — TODO confirm.
    fn estimated_cost(&self, input_sizes: &[usize]) -> f64 {
        input_sizes.iter().sum::<usize>() as f64 * 1e-6
    }
}
/// Blanket impl: every ordinary `Kernel` is automatically an `UltraKernel`
/// that uses the scalar path everywhere (the trait's default methods supply
/// the `compute_simd` / `compute_gpu` fallbacks).
impl<T: Kernel> UltraKernel for T {
    fn compute(
        &self,
        inputs: &[&dyn Any],
        attrs: &HashMap<String, AttrValue>,
    ) -> Result<Vec<Box<dyn Any>>> {
        // Fully-qualified call disambiguates between the `compute` methods
        // of the two traits.
        <Self as Kernel>::compute(self, inputs, attrs)
    }
    fn device(&self) -> Device {
        <Self as Kernel>::device(self)
    }
    fn dtype(&self) -> DType {
        <Self as Kernel>::dtype(self)
    }
}
lazy_static::lazy_static! {
pub static ref OP_REGISTRY: OpRegistry = {
let registry = OpRegistry::new();
super::builtin::register_builtin_ops(®istry);
registry
};
}