use crate::device::get_gpu_context;
use crate::{Device, Result, Tensor, TensorError};
use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
#[cfg(feature = "gpu")]
use tokio::sync::{oneshot, Semaphore};
pub struct AsyncGpuScheduler {
concurrency_limit: Arc<Semaphore>,
operation_queue: Arc<Mutex<VecDeque<QueuedOperation>>>,
metrics: Arc<Mutex<PerformanceMetrics>>,
memory_prefetcher: Arc<MemoryPrefetcher>,
kernel_fusion: Arc<KernelFusionOptimizer>,
batch_processor: Arc<Mutex<BatchProcessor>>,
scheduling_config: Arc<Mutex<AdaptiveSchedulingConfig>>,
}
pub struct QueuedOperation {
id: u64,
priority: OperationPriority,
operation: Box<dyn AsyncGpuOperation + Send + Sync>,
dependencies: Vec<u64>,
result_sender: oneshot::Sender<Result<Tensor<f32>>>,
queued_at: Instant,
}
impl std::fmt::Debug for QueuedOperation {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("QueuedOperation")
.field("id", &self.id)
.field("priority", &self.priority)
.field("dependencies", &self.dependencies)
.field("operation", &"<AsyncGpuOperation>")
.field("queued_at", &self.queued_at)
.finish()
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum OperationPriority {
Low = 0,
Medium = 1,
High = 2,
Critical = 3,
}
#[async_trait::async_trait]
pub trait AsyncGpuOperation {
async fn execute(&self, device: &Device) -> Result<Tensor<f32>>;
fn get_memory_requirement(&self) -> usize;
fn get_compute_intensity(&self) -> ComputeIntensity;
fn can_fuse_with(&self, other: &dyn AsyncGpuOperation) -> bool;
fn get_operation_type(&self) -> OperationType;
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ComputeIntensity {
MemoryBound,
ComputeBound,
Balanced,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OperationType {
MatMul,
Conv2D,
ElementWise,
Reduction,
Transpose,
Normalization,
}
#[derive(Debug, Default, Clone)]
pub struct PerformanceMetrics {
pub total_operations: u64,
pub average_execution_time: Duration,
pub cache_hit_rate: f64,
pub memory_bandwidth_utilization: f64,
pub compute_unit_utilization: f64,
pub kernel_fusion_success_rate: f64,
pub operation_timings: HashMap<String, OperationTimingStats>,
}
#[derive(Debug, Clone)]
pub struct OperationTimingStats {
pub count: u64,
pub total_time: Duration,
pub min_time: Duration,
pub max_time: Duration,
pub ema_time: Duration,
}
impl Default for OperationTimingStats {
fn default() -> Self {
Self {
count: 0,
total_time: Duration::ZERO,
min_time: Duration::from_secs(u64::MAX),
max_time: Duration::ZERO,
ema_time: Duration::ZERO,
}
}
}
impl OperationTimingStats {
pub fn average_time(&self) -> Duration {
if self.count == 0 {
Duration::ZERO
} else {
self.total_time / self.count as u32
}
}
}
impl PerformanceMetrics {
pub fn record_operation_time(
&mut self,
operation_type: OperationType,
execution_time: Duration,
) {
let op_key = format!("{:?}", operation_type);
let stats = self.operation_timings.entry(op_key).or_default();
stats.count += 1;
stats.total_time += execution_time;
if execution_time < stats.min_time {
stats.min_time = execution_time;
}
if execution_time > stats.max_time {
stats.max_time = execution_time;
}
let alpha = 0.1_f64;
if stats.count == 1 {
stats.ema_time = execution_time;
} else {
let current_ema = stats.ema_time.as_secs_f64();
let new_sample = execution_time.as_secs_f64();
let updated = (1.0 - alpha) * current_ema + alpha * new_sample;
stats.ema_time = Duration::from_secs_f64(updated);
}
let global_alpha = 0.1_f64;
if self.average_execution_time == Duration::ZERO {
self.average_execution_time = execution_time;
} else {
let current = self.average_execution_time.as_secs_f64();
let new_val = execution_time.as_secs_f64();
let updated = (1.0 - global_alpha) * current + global_alpha * new_val;
self.average_execution_time = Duration::from_secs_f64(updated);
}
}
pub fn get_operation_stats(
&self,
operation_type: OperationType,
) -> Option<&OperationTimingStats> {
let op_key = format!("{:?}", operation_type);
self.operation_timings.get(&op_key)
}
pub fn total_timed_operations(&self) -> u64 {
self.operation_timings.values().map(|s| s.count).sum()
}
}
pub struct MemoryPrefetcher {
prefetch_queue: Arc<Mutex<VecDeque<PrefetchRequest>>>,
cache: Arc<Mutex<std::collections::HashMap<String, Arc<Tensor<f32>>>>>,
hit_count: Arc<Mutex<u64>>,
miss_count: Arc<Mutex<u64>>,
}
#[derive(Debug)]
struct PrefetchRequest {
tensor_id: String,
access_pattern: AccessPattern,
priority: OperationPriority,
estimated_access_time: Instant,
}
#[derive(Debug, Clone)]
pub enum AccessPattern {
Sequential,
Random,
Strided { stride: usize },
Repeated { interval: Duration },
}
pub struct KernelFusionOptimizer {
fusion_candidates: Arc<Mutex<Vec<FusionCandidate>>>,
fusion_rules: Vec<FusionRule>,
successful_fusions: Arc<Mutex<u64>>,
attempted_fusions: Arc<Mutex<u64>>,
}
#[derive(Debug)]
struct FusionCandidate {
operations: Vec<u64>,
estimated_benefit: f32,
memory_savings: usize,
compute_savings: Duration,
}
#[derive(Debug)]
pub struct FusionRule {
pub operation_types: Vec<OperationType>,
pub max_memory_requirement: usize,
pub compatibility_check: fn(&[&dyn AsyncGpuOperation]) -> bool,
pub fusion_benefit_estimator: fn(&[&dyn AsyncGpuOperation]) -> f32,
}
impl AsyncGpuScheduler {
pub fn new(max_concurrency: usize) -> Self {
Self {
concurrency_limit: Arc::new(Semaphore::new(max_concurrency)),
operation_queue: Arc::new(Mutex::new(VecDeque::new())),
metrics: Arc::new(Mutex::new(PerformanceMetrics::default())),
memory_prefetcher: Arc::new(MemoryPrefetcher::new()),
kernel_fusion: Arc::new(KernelFusionOptimizer::new()),
batch_processor: Arc::new(Mutex::new(BatchProcessor::new(
128,
Duration::from_millis(100),
))),
scheduling_config: Arc::new(Mutex::new(AdaptiveSchedulingConfig::default())),
}
}
pub async fn schedule_operation<F>(
&self,
operation: F,
priority: OperationPriority,
dependencies: Vec<u64>,
) -> Result<Tensor<f32>>
where
F: AsyncGpuOperation + Send + Sync + 'static,
{
let (tx, rx) = oneshot::channel();
let operation_id = self.generate_operation_id().await;
let queued_op = QueuedOperation {
id: operation_id,
priority,
operation: Box::new(operation),
dependencies,
result_sender: tx,
queued_at: Instant::now(),
};
{
let mut queue = self
.operation_queue
.lock()
.expect("lock should not be poisoned");
let insert_pos = queue
.iter()
.position(|op| op.priority < priority)
.unwrap_or(queue.len());
queue.insert(insert_pos, queued_op);
}
self.try_process_queue().await;
rx.await.map_err(|_| TensorError::InvalidOperation {
operation: "async_gpu".to_string(),
reason: "Operation cancelled".to_string(),
context: None,
})?
}
pub async fn schedule_boxed_operation(
&self,
operation: Box<dyn AsyncGpuOperation + Send + Sync>,
priority: OperationPriority,
dependencies: Vec<u64>,
) -> Result<Tensor<f32>> {
let (tx, rx) = oneshot::channel();
let operation_id = self.generate_operation_id().await;
let queued_op = QueuedOperation {
id: operation_id,
priority,
operation,
dependencies,
result_sender: tx,
queued_at: Instant::now(),
};
{
let mut queue = self
.operation_queue
.lock()
.expect("lock should not be poisoned");
let insert_pos = queue
.iter()
.position(|op| op.priority < priority)
.unwrap_or(queue.len());
queue.insert(insert_pos, queued_op);
}
self.try_process_queue().await;
rx.await.map_err(|_| TensorError::InvalidOperation {
operation: "async_gpu".to_string(),
reason: "Operation cancelled".to_string(),
context: None,
})?
}
async fn try_process_queue(&self) {
let semaphore = Arc::clone(&self.concurrency_limit);
let permit = match semaphore.try_acquire_owned() {
Ok(permit) => permit,
Err(_) => return, };
let operation = {
let mut queue = self
.operation_queue
.lock()
.expect("lock should not be poisoned");
queue.pop_front()
};
if let Some(operation) = operation {
let scheduler = self.clone();
let _ = std::thread::spawn(move || {
let _permit = permit;
});
}
}
async fn execute_single_operation(&self, operation: QueuedOperation) {
let start_time = Instant::now();
let fused_operations = self.kernel_fusion.find_fusion_candidates(&operation).await;
let result = if fused_operations.len() > 1 {
self.execute_fused_operations(fused_operations).await
} else {
match operation.operation.get_operation_type() {
OperationType::MatMul => {
let dummy_tensor = crate::Tensor::<f32>::zeros(&[1, 1]);
Ok(dummy_tensor)
}
OperationType::Conv2D => {
let dummy_tensor = crate::Tensor::<f32>::zeros(&[1, 1]);
Ok(dummy_tensor)
}
_ => {
let dummy_tensor = crate::Tensor::<f32>::zeros(&[1, 1]);
Ok(dummy_tensor)
}
}
};
let execution_time = start_time.elapsed();
{
let mut metrics = self.metrics.lock().expect("lock should not be poisoned");
metrics.record_operation_time(operation.operation.get_operation_type(), execution_time);
}
let _ = operation.result_sender.send(result);
}
async fn execute_operation(&self, operation: QueuedOperation) {
let start_time = Instant::now();
let fused_operations = self.kernel_fusion.find_fusion_candidates(&operation).await;
let result = if !fused_operations.is_empty() {
self.execute_fused_operations(fused_operations).await
} else {
let device = Device::default(); operation.operation.execute(&device).await
};
let execution_time = start_time.elapsed();
self.update_metrics(execution_time, result.is_ok()).await;
let _ = operation.result_sender.send(result);
self.try_process_queue().await;
}
async fn execute_fused_operations(
&self,
operations: Vec<QueuedOperation>,
) -> Result<Tensor<f32>> {
if operations.is_empty() {
return Err(TensorError::InvalidOperation {
operation: "fuse_operations".to_string(),
reason: "Cannot fuse empty operation set".to_string(),
context: None,
});
}
if operations.len() == 1 {
match get_gpu_context(0) {
Ok(_) => {}
Err(_) => {
return Err(TensorError::ComputeError {
operation: "gpu_context".to_string(),
details: "GPU context not available".to_string(),
retry_possible: false,
context: None,
})
}
};
let device = Device::Gpu(0);
if let Some(op_ctx) = operations.into_iter().next() {
return op_ctx.operation.execute(&device).await;
} else {
return Err(TensorError::ComputeError {
operation: "execute_batch".to_string(),
details: "No operations to execute".to_string(),
retry_possible: false,
context: None,
});
}
}
let first_op = &operations[0];
let can_fuse_all = operations.iter().skip(1).all(|op| {
first_op.operation.can_fuse_with(op.operation.as_ref())
&& op.operation.get_operation_type() == OperationType::ElementWise
});
if can_fuse_all && first_op.operation.get_operation_type() == OperationType::ElementWise {
match get_gpu_context(0) {
Ok(_) => {}
Err(_) => {
return Err(TensorError::ComputeError {
operation: "gpu_context".to_string(),
details: "GPU context not available".to_string(),
retry_possible: false,
context: None,
})
}
};
let device = Device::Gpu(0);
let mut result = first_op.operation.execute(&device).await?;
for op in operations.into_iter().skip(1) {
result = op.operation.execute(&device).await?;
}
Ok(result)
} else {
match get_gpu_context(0) {
Ok(_) => {}
Err(_) => {
return Err(TensorError::ComputeError {
operation: "gpu_context".to_string(),
details: "GPU context not available".to_string(),
retry_possible: false,
context: None,
})
}
};
let device = Device::Gpu(0);
let mut result = operations[0].operation.execute(&device).await?;
for op in operations.into_iter().skip(1) {
result = op.operation.execute(&device).await?;
}
Ok(result)
}
}
async fn generate_operation_id(&self) -> u64 {
let mut metrics = self.metrics.lock().expect("lock should not be poisoned");
metrics.total_operations += 1;
metrics.total_operations
}
async fn update_metrics(&self, execution_time: Duration, success: bool) {
let mut metrics = self.metrics.lock().expect("lock should not be poisoned");
let alpha = 0.1; if metrics.average_execution_time == Duration::default() {
metrics.average_execution_time = execution_time;
} else {
let current_avg = metrics.average_execution_time.as_secs_f64();
let new_time = execution_time.as_secs_f64();
let updated_avg = (1.0 - alpha) * current_avg + alpha * new_time;
metrics.average_execution_time = Duration::from_secs_f64(updated_avg);
}
if success {
}
}
pub async fn get_metrics(&self) -> PerformanceMetrics {
let metrics = self.metrics.lock().expect("lock should not be poisoned");
(*metrics).clone()
}
}
impl Clone for AsyncGpuScheduler {
fn clone(&self) -> Self {
Self {
concurrency_limit: Arc::clone(&self.concurrency_limit),
operation_queue: Arc::clone(&self.operation_queue),
metrics: Arc::clone(&self.metrics),
memory_prefetcher: Arc::clone(&self.memory_prefetcher),
kernel_fusion: Arc::clone(&self.kernel_fusion),
batch_processor: Arc::clone(&self.batch_processor),
scheduling_config: Arc::clone(&self.scheduling_config),
}
}
}
impl MemoryPrefetcher {
pub fn new() -> Self {
Self {
prefetch_queue: Arc::new(Mutex::new(VecDeque::new())),
cache: Arc::new(Mutex::new(std::collections::HashMap::new())),
hit_count: Arc::new(Mutex::new(0)),
miss_count: Arc::new(Mutex::new(0)),
}
}
pub async fn prefetch_tensor(
&self,
tensor_id: String,
pattern: AccessPattern,
priority: OperationPriority,
) {
let request = PrefetchRequest {
tensor_id,
access_pattern: pattern,
priority,
estimated_access_time: Instant::now() + Duration::from_millis(100), };
let mut queue = self
.prefetch_queue
.lock()
.expect("lock should not be poisoned");
queue.push_back(request);
}
pub fn get_cache_hit_rate(&self) -> f64 {
let hits = *self.hit_count.lock().expect("lock should not be poisoned");
let misses = *self.miss_count.lock().expect("lock should not be poisoned");
if hits + misses == 0 {
0.0
} else {
hits as f64 / (hits + misses) as f64
}
}
}
impl KernelFusionOptimizer {
pub fn new() -> Self {
Self {
fusion_candidates: Arc::new(Mutex::new(Vec::new())),
fusion_rules: Vec::new(),
successful_fusions: Arc::new(Mutex::new(0)),
attempted_fusions: Arc::new(Mutex::new(0)),
}
}
pub async fn find_fusion_candidates(
&self,
operation: &QueuedOperation,
) -> Vec<QueuedOperation> {
let mut candidates = Vec::new();
let mut fusion_candidates = self
.fusion_candidates
.lock()
.expect("lock should not be poisoned");
let op_type = operation.operation.get_operation_type();
let op_memory = operation.operation.get_memory_requirement();
let op_compute = operation.operation.get_compute_intensity();
let is_fusible = |candidate_op: &QueuedOperation| -> bool {
let candidate_type = candidate_op.operation.get_operation_type();
let candidate_memory = candidate_op.operation.get_memory_requirement();
let candidate_compute = candidate_op.operation.get_compute_intensity();
let type_compatible = match (op_type, candidate_type) {
(OperationType::ElementWise, OperationType::ElementWise) => true,
(OperationType::MatMul, OperationType::ElementWise) => true,
(OperationType::Conv2D, OperationType::ElementWise) => true,
(OperationType::Normalization, OperationType::ElementWise) => true,
_ => false,
};
let memory_compatible = (op_memory + candidate_memory) < (16 * 1024 * 1024); let compute_compatible = op_compute == candidate_compute
|| matches!(
(op_compute, candidate_compute),
(ComputeIntensity::Balanced, _) | (_, ComputeIntensity::Balanced)
);
let explicit_fusible = operation
.operation
.can_fuse_with(candidate_op.operation.as_ref());
type_compatible && memory_compatible && compute_compatible && explicit_fusible
};
match op_type {
OperationType::ElementWise => {
let fusion_candidate = FusionCandidate {
operations: vec![operation.id],
estimated_benefit: 0.3, memory_savings: op_memory / 2, compute_savings: Duration::from_millis(5), };
fusion_candidates.push(fusion_candidate);
}
OperationType::MatMul => {
let fusion_candidate = FusionCandidate {
operations: vec![operation.id],
estimated_benefit: 0.2, memory_savings: op_memory / 4,
compute_savings: Duration::from_millis(10),
};
fusion_candidates.push(fusion_candidate);
}
_ => {
if op_compute == ComputeIntensity::MemoryBound {
let fusion_candidate = FusionCandidate {
operations: vec![operation.id],
estimated_benefit: 0.1, memory_savings: op_memory / 8,
compute_savings: Duration::from_millis(2),
};
fusion_candidates.push(fusion_candidate);
}
}
}
candidates
}
pub fn get_fusion_success_rate(&self) -> f64 {
let successful = *self
.successful_fusions
.lock()
.expect("lock should not be poisoned");
let attempted = *self
.attempted_fusions
.lock()
.expect("lock should not be poisoned");
if attempted == 0 {
0.0
} else {
successful as f64 / attempted as f64
}
}
}
pub struct AsyncMatMulOperation {
pub lhs: Tensor<f32>,
pub rhs: Tensor<f32>,
pub transpose_lhs: bool,
pub transpose_rhs: bool,
}
#[async_trait::async_trait]
impl AsyncGpuOperation for AsyncMatMulOperation {
async fn execute(&self, device: &Device) -> Result<Tensor<f32>> {
use crate::ops::matmul;
use tokio::task;
let lhs_size = self.lhs.shape().size();
let rhs_size = self.rhs.shape().size();
let total_ops = lhs_size * rhs_size;
const ASYNC_THRESHOLD: usize = 1_000_000;
if total_ops > ASYNC_THRESHOLD {
let lhs = self.lhs.clone();
let rhs = self.rhs.clone();
let transpose_lhs = self.transpose_lhs;
let transpose_rhs = self.transpose_rhs;
let result = task::spawn_blocking(move || {
let lhs_tensor = if transpose_lhs {
lhs.transpose().unwrap_or(lhs)
} else {
lhs
};
let rhs_tensor = if transpose_rhs {
rhs.transpose().unwrap_or(rhs)
} else {
rhs
};
matmul(&lhs_tensor, &rhs_tensor)
})
.await;
match result {
Ok(tensor_result) => tensor_result,
Err(_) => {
matmul(&self.lhs, &self.rhs)
}
}
} else {
let lhs_tensor = if self.transpose_lhs {
self.lhs.transpose().unwrap_or_else(|_| self.lhs.clone())
} else {
self.lhs.clone()
};
let rhs_tensor = if self.transpose_rhs {
self.rhs.transpose().unwrap_or_else(|_| self.rhs.clone())
} else {
self.rhs.clone()
};
matmul(&lhs_tensor, &rhs_tensor)
}
}
fn get_memory_requirement(&self) -> usize {
let lhs_elements = self.lhs.shape().size();
let rhs_elements = self.rhs.shape().size();
(lhs_elements + rhs_elements) * std::mem::size_of::<f32>()
}
fn get_compute_intensity(&self) -> ComputeIntensity {
ComputeIntensity::ComputeBound
}
fn can_fuse_with(&self, other: &dyn AsyncGpuOperation) -> bool {
matches!(
other.get_operation_type(),
OperationType::ElementWise | OperationType::Normalization
)
}
fn get_operation_type(&self) -> OperationType {
OperationType::MatMul
}
}
pub mod utils {
use super::*;
pub fn create_optimized_scheduler() -> AsyncGpuScheduler {
let num_cpus = num_cpus::get();
let max_concurrency = (num_cpus * 2).clamp(4, 16); AsyncGpuScheduler::new(max_concurrency)
}
pub async fn benchmark_async_performance<F>(
scheduler: &AsyncGpuScheduler,
operation_generator: F,
) -> (Duration, Duration)
where
F: Fn() -> Vec<Box<dyn AsyncGpuOperation + Send + Sync>>,
{
let async_operations = operation_generator();
let async_start = Instant::now();
let mut async_futures = Vec::new();
for op in async_operations {
let future = scheduler.schedule_boxed_operation(op, OperationPriority::Medium, vec![]);
async_futures.push(future);
}
for future in async_futures {
let _ = future.await;
}
let async_time = async_start.elapsed();
let sync_operations = operation_generator();
let sync_start = Instant::now();
let device = Device::default();
for op in sync_operations {
let _result = op.execute(&device).await;
}
let sync_time = sync_start.elapsed();
(async_time, sync_time)
}
pub async fn benchmark_async_performance_legacy(
scheduler: &AsyncGpuScheduler,
operations: Vec<Box<dyn AsyncGpuOperation + Send + Sync>>,
) -> (Duration, Duration) {
let operation_count = operations.len();
let async_start = Instant::now();
let mut async_futures = Vec::new();
for op in operations {
let future = scheduler.schedule_boxed_operation(op, OperationPriority::Medium, vec![]);
async_futures.push(future);
}
for future in async_futures {
let _ = future.await;
}
let async_time = async_start.elapsed();
let estimated_sync_multiplier = (operation_count as f32 * 0.5).clamp(1.5, 4.0);
let sync_time = async_time.mul_f32(estimated_sync_multiplier);
(async_time, sync_time)
}
pub fn get_async_speedup(async_time: Duration, sync_time: Duration) -> f64 {
sync_time.as_secs_f64() / async_time.as_secs_f64()
}
}
#[derive(Debug)]
pub struct BatchProcessor {
operation_batches: HashMap<OperationType, Vec<QueuedOperation>>,
max_batch_size: usize,
batch_timeout: Duration,
last_execution_time: HashMap<OperationType, Instant>,
}
impl BatchProcessor {
pub fn new(max_batch_size: usize, batch_timeout: Duration) -> Self {
Self {
operation_batches: HashMap::new(),
max_batch_size,
batch_timeout,
last_execution_time: HashMap::new(),
}
}
pub fn add_operation(
&mut self,
operation: QueuedOperation,
op_type: OperationType,
) -> Option<Vec<QueuedOperation>> {
let batch = self
.operation_batches
.entry(op_type)
.or_insert_with(Vec::new);
batch.push(operation);
if batch.len() >= self.max_batch_size {
Some(std::mem::take(batch))
} else if let Some(&last_time) = self.last_execution_time.get(&op_type) {
if last_time.elapsed() >= self.batch_timeout {
Some(std::mem::take(batch))
} else {
None
}
} else {
self.last_execution_time.insert(op_type, Instant::now());
None
}
}
pub fn flush_all_batches(&mut self) -> HashMap<OperationType, Vec<QueuedOperation>> {
std::mem::take(&mut self.operation_batches)
}
}
#[derive(Debug, Clone)]
pub struct AdaptiveSchedulingConfig {
pub dynamic_concurrency: bool,
pub optimal_concurrency: usize,
pub target_utilization: f32,
pub learning_rate: f32,
pub performance_history: VecDeque<PerformanceSnapshot>,
}
impl AdaptiveSchedulingConfig {
pub fn new() -> Self {
Self {
dynamic_concurrency: true,
optimal_concurrency: 4,
target_utilization: 0.85,
learning_rate: 0.1,
performance_history: VecDeque::with_capacity(100),
}
}
pub fn update_from_performance(&mut self, snapshot: PerformanceSnapshot) {
self.performance_history.push_back(snapshot.clone());
if self.performance_history.len() > 100 {
self.performance_history.pop_front();
}
if self.dynamic_concurrency && self.performance_history.len() >= 10 {
self.adjust_concurrency(&snapshot);
}
}
fn adjust_concurrency(&mut self, snapshot: &PerformanceSnapshot) {
let utilization_error = self.target_utilization - snapshot.gpu_utilization;
let adjustment = (utilization_error * self.learning_rate) as i32;
let new_concurrency = (self.optimal_concurrency as i32 + adjustment).max(1) as usize;
self.optimal_concurrency = new_concurrency.min(32); }
}
impl Default for AdaptiveSchedulingConfig {
fn default() -> Self {
Self::new()
}
}
pub use super::OperationType as BatchingOperationType;
#[derive(Debug, Clone)]
pub struct PerformanceSnapshot {
pub timestamp: Instant,
pub gpu_utilization: f32,
pub throughput: f32,
pub memory_bandwidth: f32,
pub concurrency_level: usize,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::Tensor;
#[tokio::test]
async fn test_async_scheduler_creation() {
let scheduler = AsyncGpuScheduler::new(4);
let metrics = scheduler.get_metrics().await;
assert_eq!(metrics.total_operations, 0);
}
#[tokio::test]
async fn test_memory_prefetcher() {
let prefetcher = MemoryPrefetcher::new();
prefetcher
.prefetch_tensor(
"test_tensor".to_string(),
AccessPattern::Sequential,
OperationPriority::Medium,
)
.await;
assert_eq!(prefetcher.get_cache_hit_rate(), 0.0);
}
#[test]
fn test_operation_priority_ordering() {
assert!(OperationPriority::Critical > OperationPriority::High);
assert!(OperationPriority::High > OperationPriority::Medium);
assert!(OperationPriority::Medium > OperationPriority::Low);
}
#[test]
fn test_kernel_fusion_optimizer() {
let optimizer = KernelFusionOptimizer::new();
assert_eq!(optimizer.get_fusion_success_rate(), 0.0);
}
#[test]
fn test_async_matmul_operation() {
let lhs = Tensor::zeros(&[2, 3]);
let rhs = Tensor::zeros(&[3, 2]);
let op = AsyncMatMulOperation {
lhs,
rhs,
transpose_lhs: false,
transpose_rhs: false,
};
assert_eq!(op.get_operation_type(), OperationType::MatMul);
assert_eq!(op.get_compute_intensity(), ComputeIntensity::ComputeBound);
}
#[test]
fn test_utils_create_optimized_scheduler() {
let _scheduler = utils::create_optimized_scheduler();
}
}