use super::kernel::GpuKernel;
use std::time::Duration;
/// Decides whether a unit of work should be offloaded to the GPU or kept on
/// the CPU.
///
/// The scheduler holds two pieces of state: a flag saying whether a GPU is
/// usable at all, and the most recently reported GPU utilization stored as a
/// fraction in `[0.0, 1.0]`.
#[derive(Debug, Clone)]
pub struct GpuScheduler {
    /// `true` while a GPU may be used for offloading.
    gpu_available: bool,
    /// Last accepted utilization reading, always within `[0.0, 1.0]`.
    gpu_utilization: f64,
}

impl GpuScheduler {
    /// Utilization above which the GPU is treated as too busy for more work.
    const MAX_UTILIZATION: f64 = 0.8;
    /// Assumed CPU cost of processing one input element, in nanoseconds.
    const CPU_NANOS_PER_ELEMENT: u64 = 10;
    /// Assumed transfer bandwidth, in bytes per second (10 GB/s).
    const TRANSFER_BYTES_PER_SEC: f64 = 10_000_000_000.0;

    /// Creates a scheduler with the GPU marked available and zero utilization.
    pub fn new() -> Self {
        Self {
            gpu_available: true,
            gpu_utilization: 0.0,
        }
    }

    /// Returns whether the GPU is currently marked as available.
    pub fn gpu_available(&self) -> bool {
        self.gpu_available
    }

    /// Marks the GPU as available (`true`) or unavailable (`false`).
    pub fn set_gpu_available(&mut self, available: bool) {
        self.gpu_available = available;
    }

    /// Returns the last recorded utilization as a fraction in `[0.0, 1.0]`.
    pub fn gpu_utilization(&self) -> f64 {
        self.gpu_utilization
    }

    /// Records a new utilization reading, clamped to `[0.0, 1.0]`.
    ///
    /// A NaN reading is ignored and the previous value is kept: `f64::clamp`
    /// passes NaN through unchanged, and a stored NaN would compare `false`
    /// against the busy threshold, silently disabling that guard.
    pub fn update_utilization(&mut self, utilization: f64) {
        if utilization.is_nan() {
            return;
        }
        self.gpu_utilization = utilization.clamp(0.0, 1.0);
    }

    /// Returns `true` when running `kernel` on the GPU is expected to pay off
    /// for an input of `input_size` elements.
    ///
    /// Offloading is rejected when the GPU is unavailable, when the kernel
    /// itself reports the input as not worthwhile, or when utilization exceeds
    /// 80%. Otherwise the kernel's estimated duration plus the estimated
    /// transfer time must be below two thirds of the estimated CPU time.
    pub fn should_offload<K: GpuKernel>(&self, kernel: &K, input_size: usize) -> bool {
        if !self.gpu_available {
            return false;
        }
        if !kernel.gpu_worthwhile(input_size) {
            return false;
        }
        if self.gpu_utilization > Self::MAX_UTILIZATION {
            return false;
        }

        let cpu_time = self.estimate_cpu_time(input_size);
        let gpu_time = kernel.estimate_duration(input_size);
        let transfer_overhead = self.estimate_transfer_time(input_size);

        // `Duration + Duration` panics on overflow; an unrepresentably long
        // GPU estimate simply means the GPU path cannot win.
        match gpu_time.checked_add(transfer_overhead) {
            Some(total_gpu_time) => total_gpu_time < cpu_time * 2 / 3,
            None => false,
        }
    }

    /// Estimates CPU processing time at a fixed cost per element.
    ///
    /// The multiplication saturates so that huge inputs yield the largest
    /// representable estimate instead of overflowing.
    fn estimate_cpu_time(&self, input_size: usize) -> Duration {
        Duration::from_nanos((input_size as u64).saturating_mul(Self::CPU_NANOS_PER_ELEMENT))
    }

    /// Estimates the time needed to transfer `input_size` `f32` values at the
    /// assumed bandwidth.
    ///
    /// The byte count is computed in `f64` so it cannot overflow `usize`;
    /// scaling by the element size (a power of two) is exact, so results match
    /// the integer computation wherever that one did not overflow.
    fn estimate_transfer_time(&self, input_size: usize) -> Duration {
        let bytes = input_size as f64 * std::mem::size_of::<f32>() as f64;
        Duration::from_secs_f64(bytes / Self::TRANSFER_BYTES_PER_SEC)
    }
}
impl Default for GpuScheduler {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gpu::kernel::VectorAddKernel;

    /// The availability flag starts out set and follows `set_gpu_available`.
    #[test]
    fn test_gpu_scheduler() {
        let mut sched = GpuScheduler::new();
        assert!(sched.gpu_available());

        sched.set_gpu_available(false);
        assert!(!sched.gpu_available());
    }

    /// A 100_000-element input is expected to be offloaded; a 100-element
    /// input is expected to stay on the CPU.
    #[test]
    fn test_should_offload() {
        let sched = GpuScheduler::new();
        let add_kernel = VectorAddKernel::new(100_000);

        assert!(sched.should_offload(&add_kernel, 100_000));
        assert!(!sched.should_offload(&add_kernel, 100));
    }

    /// In-range readings are stored as given; readings above 1.0 are clamped.
    #[test]
    fn test_utilization() {
        let mut sched = GpuScheduler::new();

        sched.update_utilization(0.5);
        assert_eq!(sched.gpu_utilization(), 0.5);

        sched.update_utilization(1.5);
        assert_eq!(sched.gpu_utilization(), 1.0);
    }
}