#![allow(dead_code)]
use crate::error::{AccelError, AccelResult};
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
use std::time::{Duration, Instant};
/// Handle to the outcome of a task submitted through `AsyncComputeQueue::submit`.
///
/// The handle only holds a reference-counted pointer to the shared task state,
/// so cloning it yields another handle observing the same task.
pub struct AsyncHandle<T> {
// Shared state written by the worker thread and read by every handle clone.
inner: Arc<AsyncInner<T>>,
}
/// State shared between a worker thread and the handles that observe it.
struct AsyncInner<T> {
// `None` until the worker stores the task's outcome.
result: Mutex<Option<AccelResult<T>>>,
// Notified by the worker after it has stored the outcome in `result`.
cvar: Condvar,
// Captured when `submit` creates this state; used to report the handle's age.
submitted_at: Instant,
}
impl<T: Send + 'static> AsyncHandle<T> {
    /// Blocks until the task has produced an outcome and returns it.
    ///
    /// The outcome is *moved* out of the shared slot, so `T` does not need to
    /// be `Clone`. Because handles can be cloned, the slot is refilled with an
    /// "already consumed" error: a second waiter on a cloned handle returns
    /// that error instead of blocking forever, and [`AsyncHandle::is_ready`]
    /// keeps reporting `true` once the task has finished.
    ///
    /// # Errors
    ///
    /// * Whatever error the task itself returned.
    /// * `AccelError::Synchronization` if the result mutex is poisoned, the
    ///   condition-variable wait fails, or the result was already taken by
    ///   another clone of this handle.
    pub fn wait(self) -> AccelResult<T> {
        let inner = self.inner;
        let guard = inner
            .result
            .lock()
            .map_err(|e| AccelError::Synchronization(format!("mutex poisoned: {e}")))?;
        // `wait_while` re-checks the predicate after every wake-up, so spurious
        // wake-ups are handled for us.
        let mut guard = inner
            .cvar
            .wait_while(guard, |slot| slot.is_none())
            .map_err(|e| AccelError::Synchronization(format!("condvar wait failed: {e}")))?;
        // Marker left behind for any other clone that waits after us.
        let consumed = Err(AccelError::Synchronization(
            "async result already consumed by another handle".to_string(),
        ));
        guard
            .replace(consumed)
            .ok_or_else(|| AccelError::Synchronization("async result missing".to_string()))?
    }

    /// Returns `true` once the task has finished and stored its outcome.
    ///
    /// A poisoned result mutex is reported as "not ready" rather than as an
    /// error, since this is a non-blocking, best-effort probe.
    #[must_use]
    pub fn is_ready(&self) -> bool {
        self.inner
            .result
            .lock()
            .map_or(false, |slot| slot.is_some())
    }

    /// Time elapsed since the task was submitted.
    ///
    /// The clock keeps running after the task has finished; this is the age
    /// of the submission, not the task's execution time.
    #[must_use]
    pub fn age(&self) -> Duration {
        self.inner.submitted_at.elapsed()
    }
}
/// Cloning a handle is cheap: only the reference count of the shared state is
/// bumped, and both handles observe the same task.
impl<T> Clone for AsyncHandle<T> {
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self { inner }
    }
}
/// Counters describing the tasks that have passed through an `AsyncComputeQueue`.
#[derive(Debug, Clone, Default)]
pub struct AsyncQueueStats {
/// Number of tasks handed to `AsyncComputeQueue::submit`.
pub submitted: u64,
/// Number of tasks whose work closure returned `Ok`.
pub completed: u64,
/// Number of tasks whose work closure returned `Err`.
pub failed: u64,
}
/// Runs submitted closures on freshly spawned threads and tracks their outcomes.
pub struct AsyncComputeQueue {
// Shared with every worker thread so each can record its own outcome.
stats: Arc<Mutex<AsyncQueueStats>>,
}
impl AsyncComputeQueue {
#[must_use]
pub fn new() -> Self {
Self {
stats: Arc::new(Mutex::new(AsyncQueueStats::default())),
}
}
pub fn submit<T, F>(&self, work: F) -> AsyncHandle<T>
where
T: Send + 'static,
F: FnOnce() -> AccelResult<T> + Send + 'static,
{
let inner = Arc::new(AsyncInner {
result: Mutex::new(None),
cvar: Condvar::new(),
submitted_at: Instant::now(),
});
let handle = AsyncHandle {
inner: Arc::clone(&inner),
};
let stats = Arc::clone(&self.stats);
if let Ok(mut s) = stats.lock() {
s.submitted += 1;
}
thread::spawn(move || {
let result = work();
let is_err = result.is_err();
if let Ok(mut guard) = inner.result.lock() {
*guard = Some(result);
}
inner.cvar.notify_all();
if let Ok(mut s) = stats.lock() {
if is_err {
s.failed += 1;
} else {
s.completed += 1;
}
}
});
handle
}
pub fn stats(&self) -> AccelResult<AsyncQueueStats> {
self.stats
.lock()
.map(|g| g.clone())
.map_err(|e| AccelError::Synchronization(format!("stats mutex poisoned: {e}")))
}
}
impl Default for AsyncComputeQueue {
fn default() -> Self {
Self::new()
}
}
/// Submits every closure in `items` to `queue`, then waits for all of them.
///
/// All tasks are submitted before the first wait, so they run concurrently.
/// Results are returned in the same order as `items`.
///
/// # Errors
///
/// Returns the first error encountered while waiting, in submission order.
/// Handles after the failing one are dropped without being waited on; their
/// tasks keep running detached.
pub fn submit_batch<T, F>(queue: &AsyncComputeQueue, items: Vec<F>) -> AccelResult<Vec<T>>
where
    T: Send + 'static,
    F: FnOnce() -> AccelResult<T> + Send + 'static,
{
    // Phase 1: start everything.
    let mut pending: Vec<AsyncHandle<T>> = Vec::with_capacity(items.len());
    for job in items {
        pending.push(queue.submit(job));
    }

    // Phase 2: collect in order, bailing out on the first failure.
    let mut outputs = Vec::with_capacity(pending.len());
    for handle in pending {
        outputs.push(handle.wait()?);
    }
    Ok(outputs)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Polls the queue until every submitted task has been counted as
    /// completed or failed, or a generous deadline passes.
    ///
    /// The stats assertions must not depend on how quickly a worker thread
    /// gets to its counter update, so they poll instead of sleeping for a
    /// fixed interval.
    fn settled_stats(q: &AsyncComputeQueue) -> AsyncQueueStats {
        let deadline = Instant::now() + Duration::from_secs(5);
        loop {
            let stats = q.stats().expect("stats should succeed");
            let accounted = stats.completed + stats.failed;
            if accounted >= stats.submitted || Instant::now() >= deadline {
                return stats;
            }
            std::thread::sleep(Duration::from_millis(1));
        }
    }

    #[test]
    fn test_async_queue_submit_and_wait() {
        let q = AsyncComputeQueue::new();
        let h = q.submit(|| Ok::<u32, AccelError>(42));
        let result = h.wait().expect("async wait should succeed");
        assert_eq!(result, 42);
    }

    #[test]
    fn test_async_queue_submit_error() {
        let q = AsyncComputeQueue::new();
        let h = q.submit(|| Err::<u32, AccelError>(AccelError::OutOfMemory));
        let result = h.wait();
        assert!(result.is_err());
    }

    #[test]
    fn test_async_queue_is_ready_after_wait() {
        let q = AsyncComputeQueue::new();
        let h = q.submit(|| Ok::<i32, AccelError>(7));
        let h2 = h.clone();
        h.wait().expect("wait should succeed");
        assert!(h2.is_ready());
    }

    #[test]
    fn test_async_queue_stats_submitted() {
        let q = AsyncComputeQueue::new();
        let h1 = q.submit(|| Ok::<u8, AccelError>(1));
        let h2 = q.submit(|| Ok::<u8, AccelError>(2));
        h1.wait().expect("h1 wait should succeed");
        h2.wait().expect("h2 wait should succeed");
        let stats = settled_stats(&q);
        assert_eq!(stats.submitted, 2);
        assert_eq!(stats.completed, 2);
        assert_eq!(stats.failed, 0);
    }

    #[test]
    fn test_async_queue_stats_failed() {
        let q = AsyncComputeQueue::new();
        let h = q.submit(|| Err::<u8, AccelError>(AccelError::OutOfMemory));
        let _ = h.wait();
        let stats = settled_stats(&q);
        assert_eq!(stats.failed, 1);
    }

    #[test]
    fn test_async_queue_handle_age() {
        let q = AsyncComputeQueue::new();
        let h = q.submit(|| Ok::<(), AccelError>(()));
        h.clone().wait().expect("wait should succeed");
        assert!(h.age() < Duration::from_secs(5));
    }

    #[test]
    fn test_submit_batch_all_succeed() {
        let q = AsyncComputeQueue::new();
        let items: Vec<Box<dyn FnOnce() -> AccelResult<u32> + Send>> = vec![
            Box::new(|| Ok(1)),
            Box::new(|| Ok(2)),
            Box::new(|| Ok(3)),
        ];
        let results = submit_batch(&q, items).expect("batch should succeed");
        assert_eq!(results, vec![1, 2, 3]);
    }

    #[test]
    fn test_submit_batch_propagates_error() {
        let q = AsyncComputeQueue::new();
        let items: Vec<Box<dyn FnOnce() -> AccelResult<u32> + Send>> = vec![
            Box::new(|| Ok(1)),
            Box::new(|| Err(AccelError::OutOfMemory)),
            Box::new(|| Ok(3)),
        ];
        let result = submit_batch(&q, items);
        assert!(result.is_err());
    }

    #[test]
    fn test_async_queue_default() {
        let q = AsyncComputeQueue::default();
        let stats = q.stats().expect("stats should succeed");
        assert_eq!(stats.submitted, 0);
    }

    #[test]
    fn test_async_handle_not_ready_immediately() {
        let q = AsyncComputeQueue::new();
        let h = q.submit(|| {
            std::thread::sleep(Duration::from_millis(50));
            Ok::<u32, AccelError>(99)
        });
        // Only exercises the non-blocking probe; asserting `false` here would
        // make the test depend on thread scheduling.
        let _ = h.is_ready();
        let value = h.wait().expect("wait should succeed");
        assert_eq!(value, 99);
    }
}