use super::GpuDevice;
use crate::error::{GpuAdvancedError, Result};
use crossbeam_channel::{Receiver, Sender, bounded};
use parking_lot::Mutex;
use std::sync::Arc;
use std::thread;
use std::time::Instant;
/// Boxed job executed by a `WorkQueue` worker against that queue's device.
type WorkItem = Box<dyn FnOnce(&GpuDevice) -> Result<()> + Send>;

/// Sending half of the per-job reply channel the worker answers on.
type ResultSender = Sender<Result<()>>;

/// Execution queue bound to a single GPU device and served by one worker
/// thread that runs submitted jobs one at a time.
pub struct WorkQueue {
    /// Number of jobs submitted but not yet finished by the worker.
    pending_tasks: Arc<Mutex<usize>>,
    /// Device every job on this queue runs against.
    device: Arc<GpuDevice>,
    /// Submission channel; `None` once the queue is being dropped.
    work_sender: Option<Sender<(WorkItem, ResultSender)>>,
    /// Join handle of the worker thread, taken when the queue is dropped.
    worker_handle: Option<Arc<Mutex<Option<thread::JoinHandle<()>>>>>,
}
impl WorkQueue {
pub fn new(device: Arc<GpuDevice>) -> Self {
let (work_sender, work_receiver) = bounded::<(WorkItem, ResultSender)>(256);
let device_clone = device.clone();
let pending_tasks = Arc::new(Mutex::new(0));
let pending_clone = pending_tasks.clone();
let handle = thread::spawn(move || {
Self::worker_loop(device_clone, work_receiver, pending_clone);
});
Self {
device,
work_sender: Some(work_sender),
worker_handle: Some(Arc::new(Mutex::new(Some(handle)))),
pending_tasks,
}
}
fn worker_loop(
device: Arc<GpuDevice>,
work_receiver: Receiver<(WorkItem, ResultSender)>,
pending_tasks: Arc<Mutex<usize>>,
) {
while let Ok((work, result_sender)) = work_receiver.recv() {
let start = Instant::now();
device.set_workload(1.0);
let result = work(&device);
device.set_workload(0.0);
let _ = result_sender.send(result);
{
let mut pending = pending_tasks.lock();
*pending = pending.saturating_sub(1);
}
let duration = start.elapsed();
tracing::debug!(
"Task completed on GPU {} in {:?}",
device.info.index,
duration
);
}
}
pub async fn submit_work<F, T>(&self, work: F) -> Result<T>
where
F: FnOnce(&GpuDevice) -> Result<T> + Send + 'static,
T: Send + 'static,
{
let (result_sender, result_receiver) = bounded(1);
let result_arc = Arc::new(Mutex::new(None));
let result_clone = result_arc.clone();
let work_wrapper: WorkItem = Box::new(move |device| {
let result = work(device);
match result {
Ok(value) => {
*result_clone.lock() = Some(Ok(value));
Ok(())
}
Err(e) => {
*result_clone.lock() = Some(Err(e));
Ok(())
}
}
});
{
let mut pending = self.pending_tasks.lock();
*pending = pending.saturating_add(1);
}
self.work_sender
.as_ref()
.ok_or_else(|| GpuAdvancedError::WorkStealingError("Work queue is closed".to_string()))?
.send((work_wrapper, result_sender))
.map_err(|e| {
GpuAdvancedError::WorkStealingError(format!("Failed to send work: {}", e))
})?;
let _ = result_receiver
.recv()
.map_err(|e| GpuAdvancedError::SyncError(format!("Failed to receive result: {}", e)))?;
result_arc
.lock()
.take()
.ok_or_else(|| GpuAdvancedError::SyncError("Result not available".to_string()))?
}
pub fn pending_count(&self) -> usize {
*self.pending_tasks.lock()
}
pub fn is_empty(&self) -> bool {
self.pending_count() == 0
}
pub fn device(&self) -> Arc<GpuDevice> {
self.device.clone()
}
}
impl Drop for WorkQueue {
fn drop(&mut self) {
drop(self.work_sender.take());
if let Some(handle_arc) = self.worker_handle.take() {
if let Some(handle) = handle_arc.lock().take() {
let _ = handle.join();
}
}
}
}
/// Mutex-guarded stack of work items. Once it grows beyond a configurable
/// threshold, other workers may steal part of it.
pub struct WorkStealingQueue {
    /// Items held by this queue; the most recently pushed item is last.
    local_queue: Arc<Mutex<Vec<WorkItem>>>,
    /// `steal` hands out items only while the queue is longer than this.
    steal_threshold: usize,
}

impl WorkStealingQueue {
    /// Creates an empty queue that permits stealing once it holds more than
    /// `steal_threshold` items.
    pub fn new(steal_threshold: usize) -> Self {
        let local_queue = Arc::new(Mutex::new(Vec::new()));
        Self {
            local_queue,
            steal_threshold,
        }
    }

    /// Appends `work` to the end of the queue.
    pub fn push(&self, work: WorkItem) {
        self.local_queue.lock().push(work);
    }

    /// Removes and returns the most recently pushed item, if any.
    pub fn pop(&self) -> Option<WorkItem> {
        self.local_queue.lock().pop()
    }

    /// Removes and returns the newest `len / 2` items when the queue is longer
    /// than the threshold; otherwise returns an empty `Vec`.
    ///
    /// The stolen items come from the end of the queue, the same end `pop`
    /// takes from, and keep their original relative order.
    pub fn steal(&self) -> Vec<WorkItem> {
        let mut items = self.local_queue.lock();
        let total = items.len();
        if total > self.steal_threshold {
            let keep = total - total / 2;
            items.split_off(keep)
        } else {
            Vec::new()
        }
    }

    /// Number of items currently in the queue.
    pub fn len(&self) -> usize {
        self.local_queue.lock().len()
    }

    /// Returns `true` when the queue holds no items.
    pub fn is_empty(&self) -> bool {
        self.local_queue.lock().is_empty()
    }

    /// Returns `true` when `steal` would currently hand out items.
    pub fn should_allow_stealing(&self) -> bool {
        self.steal_threshold < self.len()
    }
}
pub struct BatchSubmitter {
queues: Vec<Arc<WorkQueue>>,
current_index: Mutex<usize>,
}
impl BatchSubmitter {
pub fn new(queues: Vec<Arc<WorkQueue>>) -> Self {
Self {
queues,
current_index: Mutex::new(0),
}
}
pub async fn submit_batch<F, T>(&self, work_items: Vec<F>) -> Result<Vec<T>>
where
F: FnOnce(&GpuDevice) -> Result<T> + Send + 'static,
T: Send + 'static,
{
if self.queues.is_empty() {
return Err(GpuAdvancedError::WorkStealingError(
"No work queues available".to_string(),
));
}
let mut futures = Vec::new();
for work in work_items {
let queue_index = {
let mut index = self.current_index.lock();
let current = *index;
*index = (*index + 1) % self.queues.len();
current
};
let queue = &self.queues[queue_index];
let future = queue.submit_work(work);
futures.push(future);
}
let mut results = Vec::new();
for future in futures {
results.push(future.await?);
}
Ok(results)
}
pub async fn submit_batch_to_devices<F, T>(&self, work_items: Vec<(usize, F)>) -> Result<Vec<T>>
where
F: FnOnce(&GpuDevice) -> Result<T> + Send + 'static,
T: Send + 'static,
{
let mut futures = Vec::new();
for (device_index, work) in work_items {
let queue = self
.queues
.get(device_index)
.ok_or(GpuAdvancedError::InvalidGpuIndex {
index: device_index,
total: self.queues.len(),
})?;
let future = queue.submit_work(work);
futures.push(future);
}
let mut results = Vec::new();
for future in futures {
results.push(future.await?);
}
Ok(results)
}
pub fn total_pending(&self) -> usize {
self.queues.iter().map(|q| q.pending_count()).sum()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a work item that does nothing; these tests only count items.
    fn noop() -> WorkItem {
        Box::new(|_device| Ok(()))
    }

    #[test]
    fn test_work_stealing_queue() {
        let queue = WorkStealingQueue::new(10);
        assert!(queue.is_empty());
        queue.push(noop());
        assert_eq!(queue.len(), 1);
        assert!(queue.pop().is_some());
        assert!(queue.is_empty());
        assert!(queue.pop().is_none());
    }

    #[test]
    fn test_work_stealing_threshold() {
        let queue = WorkStealingQueue::new(5);
        for _ in 0..4 {
            queue.push(noop());
        }
        assert!(!queue.should_allow_stealing());
        for _ in 0..3 {
            queue.push(noop());
        }
        assert!(queue.should_allow_stealing());
    }

    #[test]
    fn test_steal_at_threshold_returns_nothing() {
        let queue = WorkStealingQueue::new(3);
        for _ in 0..3 {
            queue.push(noop());
        }
        assert!(queue.steal().is_empty());
        assert_eq!(queue.len(), 3);
    }

    #[test]
    fn test_steal_takes_half_rounded_down() {
        let queue = WorkStealingQueue::new(2);
        for _ in 0..5 {
            queue.push(noop());
        }
        assert_eq!(queue.steal().len(), 2);
        assert_eq!(queue.len(), 3);
        // 3 is still above the threshold of 2, so one more item can be stolen.
        assert_eq!(queue.steal().len(), 1);
        assert_eq!(queue.len(), 2);
        assert!(queue.steal().is_empty());
        assert!(!queue.should_allow_stealing());
    }
}