// Gate the whole module on the `gpu` feature. The items below reference
// GPU-only types (`StreamPriority`, `MultiStreamGpuExecutor`), so gating
// only the imports would break non-GPU builds.
#![cfg(feature = "gpu")]

use crate::device::async_execution::AsyncExecutor;
use crate::gpu::multi_stream_executor::{MultiStreamGpuExecutor, StreamPriority};
use crate::{Device, Result};
use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum WorkPriority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}
#[derive(Debug, Clone)]
pub enum WorkType {
    BinaryOp {
        operation: String,
        input_size: usize,
        dtype: String,
    },
    Reduction {
        operation: String,
        input_size: usize,
        axis: Option<Vec<usize>>,
    },
    MatrixMultiplication {
        m: usize,
        n: usize,
        k: usize,
    },
    Convolution {
        input_shape: Vec<usize>,
        kernel_shape: Vec<usize>,
        stride: Vec<usize>,
    },
    DataTransfer {
        size: usize,
        from_device: Device,
        to_device: Device,
    },
}
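// A `WorkType` carries only a description of the computation (sizes, shapes,
// devices), never the tensor data itself; e.g. a square GEMM is described as
// `WorkType::MatrixMultiplication { m: 1024, n: 1024, k: 1024 }` and routed
// from that description alone.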
#[derive(Debug)]
pub struct WorkItem {
    pub id: u64,
    pub work_type: WorkType,
    pub priority: WorkPriority,
    pub estimated_duration: Duration,
    pub preferred_device: Option<Device>,
    pub created_at: Instant,
}
#[derive(Debug, Clone)]
pub enum ExecutionStrategy {
    CpuOnly,
    GpuOnly {
        stream_priority: StreamPriority,
    },
    CpuGpuOverlap {
        cpu_work: Vec<WorkType>,
        gpu_work: Vec<WorkType>,
    },
    Adaptive,
}
#[derive(Debug, Clone)]
pub struct ResourceMetrics {
    pub cpu_utilization: f32,
    pub gpu_utilization: f32,
    pub memory_usage: f32,
    pub pending_cpu_work: usize,
    pub pending_gpu_work: usize,
    pub last_updated: Instant,
}
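/// Routes `WorkItem`s to the CPU executor, the multi-stream GPU executor, or
/// an overlapped combination of both, choosing an `ExecutionStrategy` per
/// item from operation size, priority, and the latest `ResourceMetrics`
/// snapshot.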
pub struct HybridWorkScheduler {
    cpu_executor: Arc<AsyncExecutor>,
    gpu_executor: Arc<MultiStreamGpuExecutor>,
    work_queue: Arc<Mutex<VecDeque<WorkItem>>>,
    metrics: Arc<Mutex<ResourceMetrics>>,
    work_counter: Arc<Mutex<u64>>,
    scheduling_config: SchedulingConfig,
}
#[derive(Debug, Clone)]
pub struct SchedulingConfig {
    pub cpu_threshold: f32,
    pub gpu_threshold: f32,
    pub small_op_threshold: usize,
    pub overlap_factor: f32,
    pub adaptive_scheduling: bool,
}
impl Default for SchedulingConfig {
    fn default() -> Self {
        Self {
            cpu_threshold: 0.8,
            gpu_threshold: 0.8,
            small_op_threshold: 1024,
            overlap_factor: 0.7,
            adaptive_scheduling: true,
        }
    }
}
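// Usage sketch. The executor values are placeholders here, since constructing
// `AsyncExecutor` and `MultiStreamGpuExecutor` is outside this module:
//
//     let scheduler = HybridWorkScheduler::new(cpu_executor, gpu_executor);
//     let work = scheduler.create_binary_op_work("add", 1 << 20, "f32", WorkPriority::Normal);
//     scheduler.submit_work(work).await?;
//     scheduler.synchronize_all();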
impl HybridWorkScheduler {
    pub fn new(
        cpu_executor: Arc<AsyncExecutor>,
        gpu_executor: Arc<MultiStreamGpuExecutor>,
    ) -> Self {
        Self {
            cpu_executor,
            gpu_executor,
            work_queue: Arc::new(Mutex::new(VecDeque::new())),
            metrics: Arc::new(Mutex::new(ResourceMetrics {
                cpu_utilization: 0.0,
                gpu_utilization: 0.0,
                memory_usage: 0.0,
                pending_cpu_work: 0,
                pending_gpu_work: 0,
                last_updated: Instant::now(),
            })),
            work_counter: Arc::new(Mutex::new(0)),
            scheduling_config: SchedulingConfig::default(),
        }
    }

    pub fn with_config(
        cpu_executor: Arc<AsyncExecutor>,
        gpu_executor: Arc<MultiStreamGpuExecutor>,
        config: SchedulingConfig,
    ) -> Self {
        let mut scheduler = Self::new(cpu_executor, gpu_executor);
        scheduler.scheduling_config = config;
        scheduler
    }
    pub fn submit_work(&self, work: WorkItem) -> HybridWorkFuture<'_> {
        let work_id = work.id;
        {
            let mut queue = self.work_queue.lock().expect("lock should not be poisoned");
            queue.push_back(work);
        }
        self.schedule_pending_work();
        HybridWorkFuture {
            work_id,
            scheduler: self,
            completed: false,
        }
    }
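    // Note: `submit_work` drains and dispatches the queue synchronously before
    // returning; awaiting the returned `HybridWorkFuture` then waits for the
    // scheduler to report idle.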
    fn schedule_pending_work(&self) {
        // Snapshot the queue and metrics, then release both locks before
        // dispatching: the execute_* paths re-acquire the metrics mutex, and
        // `std::sync::Mutex` is not reentrant, so holding a guard across
        // `execute_work` would deadlock.
        let mut work_items: Vec<_> = {
            let mut queue = self.work_queue.lock().expect("lock should not be poisoned");
            queue.drain(..).collect()
        };
        let metrics = self
            .metrics
            .lock()
            .expect("lock should not be poisoned")
            .clone();
        // Dispatch highest-priority work first.
        work_items.sort_by_key(|item| std::cmp::Reverse(item.priority));
        for work in work_items {
            let strategy = self.determine_execution_strategy(&work, &metrics);
            self.execute_work(work, strategy);
        }
    }
    fn determine_execution_strategy(
        &self,
        work: &WorkItem,
        metrics: &ResourceMetrics,
    ) -> ExecutionStrategy {
        if !self.scheduling_config.adaptive_scheduling {
            // With adaptive scheduling disabled, fall back to a fixed GPU
            // placement. Returning `Adaptive` here would send `execute_work`
            // straight back into this function and recurse forever.
            return ExecutionStrategy::GpuOnly {
                stream_priority: self.priority_to_stream_priority(work.priority),
            };
        }
        match &work.work_type {
            WorkType::BinaryOp { input_size, .. } => {
                if *input_size < self.scheduling_config.small_op_threshold {
                    // Small elementwise ops are not worth a kernel launch.
                    ExecutionStrategy::CpuOnly
                } else if metrics.gpu_utilization < self.scheduling_config.gpu_threshold {
                    ExecutionStrategy::GpuOnly {
                        stream_priority: self.priority_to_stream_priority(work.priority),
                    }
                } else if metrics.cpu_utilization < self.scheduling_config.cpu_threshold {
                    ExecutionStrategy::CpuOnly
                } else {
                    // Both devices are busy: let the CPU preprocess half the
                    // input while the GPU runs the main kernel.
                    ExecutionStrategy::CpuGpuOverlap {
                        cpu_work: vec![WorkType::BinaryOp {
                            operation: "preprocessing".to_string(),
                            input_size: *input_size / 2,
                            dtype: "f32".to_string(),
                        }],
                        gpu_work: vec![work.work_type.clone()],
                    }
                }
            }
            WorkType::MatrixMultiplication { m, n, k } => {
                // Matmul cost scales with m * n * k (half the FLOP count);
                // only small problems stay on the CPU.
                let total_ops = m * n * k;
                if total_ops < self.scheduling_config.small_op_threshold * 100 {
                    ExecutionStrategy::CpuOnly
                } else {
                    ExecutionStrategy::GpuOnly {
                        stream_priority: StreamPriority::High,
                    }
                }
            }
            WorkType::Convolution { .. } => ExecutionStrategy::GpuOnly {
                stream_priority: StreamPriority::Normal,
            },
            WorkType::DataTransfer { .. } => ExecutionStrategy::GpuOnly {
                stream_priority: StreamPriority::Normal,
            },
            WorkType::Reduction { input_size, .. } => {
                if *input_size < self.scheduling_config.small_op_threshold {
                    ExecutionStrategy::CpuOnly
                } else {
                    ExecutionStrategy::GpuOnly {
                        stream_priority: StreamPriority::Normal,
                    }
                }
            }
        }
    }
    fn priority_to_stream_priority(&self, priority: WorkPriority) -> StreamPriority {
        match priority {
            WorkPriority::Low => StreamPriority::Low,
            WorkPriority::Normal => StreamPriority::Normal,
            WorkPriority::High => StreamPriority::High,
            WorkPriority::Critical => StreamPriority::Critical,
        }
    }
    fn execute_work(&self, work: WorkItem, strategy: ExecutionStrategy) {
        match strategy {
            ExecutionStrategy::CpuOnly => {
                self.execute_cpu_work(work);
            }
            ExecutionStrategy::GpuOnly { stream_priority } => {
                self.execute_gpu_work(work, stream_priority);
            }
            ExecutionStrategy::CpuGpuOverlap { cpu_work, gpu_work } => {
                self.execute_overlapped_work(work, cpu_work, gpu_work);
            }
            ExecutionStrategy::Adaptive => {
                // Resolve to a concrete strategy, dropping the metrics guard
                // before recursing so the execute_* paths can re-lock it.
                let metrics = self.metrics.lock().expect("lock should not be poisoned");
                let adaptive_strategy = self.determine_execution_strategy(&work, &metrics);
                drop(metrics);
                self.execute_work(work, adaptive_strategy);
            }
        }
    }
    fn execute_cpu_work(&self, work: WorkItem) {
        // Stub dispatch path: log the routing decision and count it as pending.
        println!("Executing work {} on CPU: {:?}", work.id, work.work_type);
        let mut metrics = self.metrics.lock().expect("lock should not be poisoned");
        metrics.pending_cpu_work += 1;
    }

    fn execute_gpu_work(&self, work: WorkItem, stream_priority: StreamPriority) {
        // Stub dispatch path: log the routing decision and count it as pending.
        println!(
            "Executing work {} on GPU (priority: {:?}): {:?}",
            work.id, stream_priority, work.work_type
        );
        let mut metrics = self.metrics.lock().expect("lock should not be poisoned");
        metrics.pending_gpu_work += 1;
    }
    fn execute_overlapped_work(
        &self,
        work: WorkItem,
        cpu_work: Vec<WorkType>,
        gpu_work: Vec<WorkType>,
    ) {
        println!("Executing work {} with CPU-GPU overlap", work.id);
        println!("  CPU work: {:?}", cpu_work);
        println!("  GPU work: {:?}", gpu_work);
        // Issue the CPU-side preprocessing, then the GPU-side computation; in
        // this stub the two halves are dispatched sequentially rather than
        // truly overlapped.
        for cpu_item in cpu_work {
            self.execute_cpu_preprocessing(cpu_item);
        }
        for gpu_item in gpu_work {
            self.execute_gpu_computation(gpu_item);
        }
        let mut metrics = self.metrics.lock().expect("lock should not be poisoned");
        metrics.pending_cpu_work += 1;
        metrics.pending_gpu_work += 1;
    }

    fn execute_cpu_preprocessing(&self, work: WorkType) {
        println!("CPU preprocessing: {:?}", work);
    }

    fn execute_gpu_computation(&self, work: WorkType) {
        println!("GPU computation: {:?}", work);
    }
    pub fn update_metrics(&self, cpu_util: f32, gpu_util: f32, memory_usage: f32) {
        let mut metrics = self.metrics.lock().expect("lock should not be poisoned");
        metrics.cpu_utilization = cpu_util;
        metrics.gpu_utilization = gpu_util;
        metrics.memory_usage = memory_usage;
        metrics.last_updated = Instant::now();
    }

    pub fn get_metrics(&self) -> ResourceMetrics {
        self.metrics
            .lock()
            .expect("lock should not be poisoned")
            .clone()
    }

    fn next_work_id(&self) -> u64 {
        let mut counter = self
            .work_counter
            .lock()
            .expect("lock should not be poisoned");
        *counter += 1;
        *counter
    }
    pub fn create_binary_op_work(
        &self,
        operation: &str,
        input_size: usize,
        dtype: &str,
        priority: WorkPriority,
    ) -> WorkItem {
        WorkItem {
            id: self.next_work_id(),
            work_type: WorkType::BinaryOp {
                operation: operation.to_string(),
                input_size,
                dtype: dtype.to_string(),
            },
            priority,
            // Rough cost model: about one element per nanosecond.
            estimated_duration: Duration::from_micros(input_size as u64 / 1000),
            preferred_device: None,
            created_at: Instant::now(),
        }
    }

    pub fn synchronize_all(&self) {
        self.gpu_executor.synchronize_all();
    }

    pub fn is_idle(&self) -> bool {
        let queue = self.work_queue.lock().expect("lock should not be poisoned");
        queue.is_empty() && self.gpu_executor.is_idle()
    }
}
/// Future returned by `HybridWorkScheduler::submit_work`. Resolves once the
/// scheduler has drained its queue and the GPU executor reports idle.
pub struct HybridWorkFuture<'a> {
    work_id: u64,
    scheduler: &'a HybridWorkScheduler,
    completed: bool,
}

impl<'a> Future for HybridWorkFuture<'a> {
    type Output = Result<()>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if self.completed {
            return Poll::Ready(Ok(()));
        }
        if self.scheduler.is_idle() {
            self.completed = true;
            Poll::Ready(Ok(()))
        } else {
            // No completion callback is wired up, so wake immediately and
            // re-poll rather than returning Pending without a registered
            // waker (which would stall the future forever).
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }
}
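
// Minimal sanity checks (a sketch: they deliberately avoid constructing the
// executors, whose APIs live outside this module).
#[cfg(test)]
mod tests {
    use super::*;

    // `schedule_pending_work` sorts by `Reverse(priority)`, so the derived
    // ordering must rank Critical above High above Normal above Low.
    #[test]
    fn priority_ordering_is_total() {
        assert!(WorkPriority::Critical > WorkPriority::High);
        assert!(WorkPriority::High > WorkPriority::Normal);
        assert!(WorkPriority::Normal > WorkPriority::Low);
    }

    // Defaults: 80% utilization thresholds, 1024-element small-op cutoff,
    // adaptive scheduling enabled.
    #[test]
    fn default_config_values() {
        let config = SchedulingConfig::default();
        assert!((config.cpu_threshold - 0.8).abs() < f32::EPSILON);
        assert!((config.gpu_threshold - 0.8).abs() < f32::EPSILON);
        assert_eq!(config.small_op_threshold, 1024);
        assert!(config.adaptive_scheduling);
    }
}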