use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use thiserror::Error;
/// Direction of a data transfer submitted to an [`AsyncTransferPipeline`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransferDirection {
    /// Host memory to device memory.
    HostToDevice,
    /// Device memory to host memory.
    DeviceToHost,
    /// Device memory to device memory.
    DeviceToDevice,
}
/// A single queued transfer: the owned payload, its direction, and a clone of the
/// handle that was returned to the submitter.
struct AsyncTransfer<T> {
    // Payload and direction are stored with the transfer but are not read by the
    // pipeline code in this module, hence the per-field `dead_code` allowances.
    #[allow(dead_code)]
    data: Vec<T>,
    #[allow(dead_code)]
    direction: TransferDirection,
    /// Shares its completion flag with the caller's copy of the handle.
    handle: TransferHandle,
}
#[derive(Debug, Error)]
pub enum AsyncTransferError {
#[error("Pipeline full: {0} pending transfers")]
PipelineFull(usize),
#[error("Transfer ID {0} not found")]
NotFound(u64),
#[error("Failed to acquire pipeline lock")]
LockError,
}
/// Caller-side handle for a submitted transfer.
///
/// Cloning a handle shares the same completion flag, so every clone observes
/// completion at the same time.
#[derive(Clone, Debug)]
pub struct TransferHandle {
    /// Identifier assigned by the pipeline at submission time.
    pub id: u64,
    /// Completion flag shared with the pipeline's queued copy of this handle.
    completed: Arc<AtomicBool>,
}

impl TransferHandle {
    /// Returns `true` once the transfer has been marked complete.
    ///
    /// Uses an `Acquire` load so that writes made before the flag was set with
    /// `Release` are visible to the caller.
    pub fn is_complete(&self) -> bool {
        let flag: &AtomicBool = &self.completed;
        flag.load(Ordering::Acquire)
    }
}
/// A bounded queue of transfers, protected by a mutex.
///
/// The bound (`max_pending`) applies to transfers that are still in flight, not to
/// the total number of queued entries.
pub struct AsyncTransferPipeline<T> {
    /// Every submitted transfer, complete or not, until removed by `flush`.
    pending: Mutex<VecDeque<AsyncTransfer<T>>>,
    /// Maximum number of in-flight (incomplete) transfers allowed at submission time.
    max_pending: usize,
    /// Source of transfer IDs; incremented once per successful submission.
    id_counter: AtomicU64,
}
impl<T: Clone + Send + 'static> AsyncTransferPipeline<T> {
pub fn new(max_pending: usize) -> Self {
Self {
pending: Mutex::new(VecDeque::new()),
max_pending,
id_counter: AtomicU64::new(1),
}
}
pub fn submit(
&self,
data: Vec<T>,
direction: TransferDirection,
) -> Result<TransferHandle, AsyncTransferError> {
let mut queue = self
.pending
.lock()
.map_err(|_| AsyncTransferError::LockError)?;
let in_flight = queue.iter().filter(|t| !t.handle.is_complete()).count();
if in_flight >= self.max_pending {
return Err(AsyncTransferError::PipelineFull(in_flight));
}
let id = self.id_counter.fetch_add(1, Ordering::Relaxed);
let completed = Arc::new(AtomicBool::new(false));
let handle = TransferHandle {
id,
completed: Arc::clone(&completed),
};
completed.store(true, Ordering::Release);
queue.push_back(AsyncTransfer {
data,
direction,
handle: handle.clone(),
});
Ok(handle)
}
pub fn is_complete(&self, handle: &TransferHandle) -> bool {
handle.is_complete()
}
pub fn flush(&self) -> Result<(), AsyncTransferError> {
let mut queue = self
.pending
.lock()
.map_err(|_| AsyncTransferError::LockError)?;
queue.retain(|transfer| !transfer.handle.is_complete());
Ok(())
}
pub fn pending_count(&self) -> usize {
self.pending.lock().map(|q| q.len()).unwrap_or(0)
}
pub fn in_flight_count(&self) -> usize {
self.pending
.lock()
.map(|q| q.iter().filter(|t| !t.handle.is_complete()).count())
.unwrap_or(0)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_async_transfer_submit() {
        let pipeline: AsyncTransferPipeline<f32> = AsyncTransferPipeline::new(8);
        let data = vec![1.0_f32, 2.0, 3.0, 4.0];
        let handle = pipeline
            .submit(data, TransferDirection::HostToDevice)
            .expect("submit should succeed");
        assert!(
            handle.is_complete(),
            "handle should be complete immediately in CPU mode"
        );
        assert!(
            pipeline.is_complete(&handle),
            "pipeline.is_complete should match handle"
        );
    }

    #[test]
    fn test_async_transfer_pipeline_flush() {
        let pipeline: AsyncTransferPipeline<u8> = AsyncTransferPipeline::new(16);
        for i in 0..8_u8 {
            let data = vec![i; 64];
            pipeline
                .submit(data, TransferDirection::DeviceToHost)
                .expect("submit should succeed");
        }
        assert_eq!(pipeline.pending_count(), 8);
        pipeline.flush().expect("flush should succeed");
        assert_eq!(pipeline.pending_count(), 0);
    }

    #[test]
    fn test_async_transfer_pipeline_full() {
        let pipeline: AsyncTransferPipeline<f32> = AsyncTransferPipeline::new(0);
        let result = pipeline.submit(vec![0.0_f32; 4], TransferDirection::HostToDevice);
        match result {
            Err(AsyncTransferError::PipelineFull(count)) => {
                assert_eq!(count, 0, "should report 0 in-flight when cap is 0");
            }
            other => panic!("expected PipelineFull, got {:?}", other),
        }
    }

    #[test]
    fn test_async_transfer_device_to_device() {
        let pipeline: AsyncTransferPipeline<i32> = AsyncTransferPipeline::new(4);
        let handle = pipeline
            .submit(vec![42_i32; 32], TransferDirection::DeviceToDevice)
            .expect("submit should succeed");
        assert!(handle.is_complete());
    }

    #[test]
    fn test_async_transfer_multiple_handles() {
        let pipeline: AsyncTransferPipeline<f64> = AsyncTransferPipeline::new(16);
        let mut handles = Vec::new();
        for _ in 0..5 {
            let h = pipeline
                .submit(vec![1.0_f64; 8], TransferDirection::HostToDevice)
                .expect("submit should succeed");
            handles.push(h);
        }
        for (i, h) in handles.iter().enumerate() {
            assert!(h.is_complete(), "handle {} should be complete", i);
        }
        assert_eq!(pipeline.in_flight_count(), 0);
    }

    #[test]
    fn test_async_transfer_ids_are_strictly_increasing() {
        let pipeline: AsyncTransferPipeline<u8> = AsyncTransferPipeline::new(4);
        let ids: Vec<u64> = (0..4)
            .map(|_| {
                pipeline
                    .submit(vec![0_u8; 2], TransferDirection::HostToDevice)
                    .expect("submit should succeed")
                    .id
            })
            .collect();
        assert!(
            ids.windows(2).all(|w| w[0] < w[1]),
            "ids should be strictly increasing: {:?}",
            ids
        );
    }

    #[test]
    fn test_async_transfer_cap_counts_only_in_flight() {
        // Completed transfers remain queued but do not count against `max_pending`.
        let pipeline: AsyncTransferPipeline<u8> = AsyncTransferPipeline::new(1);
        for _ in 0..3 {
            pipeline
                .submit(vec![1_u8], TransferDirection::DeviceToHost)
                .expect("completed transfers should not fill the pipeline");
        }
        assert_eq!(pipeline.pending_count(), 3);
        assert_eq!(pipeline.in_flight_count(), 0);
    }
}