use std::marker::PhantomData;
use std::sync::atomic::{AtomicU64, AtomicBool, Ordering};
use std::thread::ThreadId;
/// Identifier attached to each transfer handle.
///
/// Values are drawn from one process-wide counter, so ids created in the same
/// process are distinct (until the `u64` counter wraps around).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TransferId(u64);

impl TransferId {
    /// Takes the next value from the global counter.
    fn new() -> Self {
        // Only uniqueness matters here, not ordering relative to other memory, and
        // `fetch_add` is an atomic read-modify-write under any ordering.
        static NEXT_ID: AtomicU64 = AtomicU64::new(0);
        let raw = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        TransferId(raw)
    }
}
/// Lifecycle stage of a transfer handle.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransferState {
    /// Initial stage, set when the handle is constructed.
    Owned,
    /// Set by `mark_sent`.
    InFlight,
    /// Set by `receive` once mutable access to the payload has been handed out.
    Received,
    /// Set by `receive_owned` after the payload has been cloned out.
    Consumed,
}
/// Carries a raw pointer to a `T` across threads, together with bookkeeping about
/// where the transfer originated and how far it has progressed.
pub struct TransferHandle<T> {
    // --- payload ---
    /// Location of the payload; dereferenced when the handle is received.
    ptr: *mut T,
    /// Caller-supplied payload size, reported back verbatim.
    size: usize,
    /// Ties the handle to `T` at the type level.
    _marker: PhantomData<T>,

    // --- bookkeeping ---
    /// Unique id assigned at construction.
    id: TransferId,
    /// Thread that constructed the handle.
    origin_thread: ThreadId,
    /// Current lifecycle stage.
    state: TransferState,
    /// Flipped by the first `receive`; used to reject a second one.
    received: AtomicBool,
}

// SAFETY: the `*mut T` field makes the handle `!Send` by default. Sending the handle
// moves the ability to obtain `&mut T` (through `receive`) to another thread, which is
// only sound when `T: Send`, hence the bound. Soundness also relies on the crate code
// that builds handles (the constructor is `pub(crate)`) not touching `*ptr`
// concurrently from the sending side. `Sync` is not implemented, so the raw pointer
// keeps the handle `!Sync`.
unsafe impl<T: Send> Send for TransferHandle<T> {}
impl<T> TransferHandle<T> {
    /// Creates a handle for the value behind `ptr`, starting in [`TransferState::Owned`].
    ///
    /// A fresh [`TransferId`] is allocated and the calling thread is recorded as the
    /// origin. `size` is stored verbatim for [`size`](Self::size); it is not checked
    /// against `ptr` or `size_of::<T>()`.
    ///
    /// The handle never frees `ptr`. The caller keeps ownership of the allocation and
    /// must keep it valid and otherwise unaccessed while the handle can still hand out
    /// `&mut T`.
    pub(crate) fn new(ptr: *mut T, size: usize) -> Self {
        Self {
            ptr,
            size,
            id: TransferId::new(),
            origin_thread: std::thread::current().id(),
            state: TransferState::Owned,
            received: AtomicBool::new(false),
            _marker: PhantomData,
        }
    }

    /// Returns the unique id assigned when the handle was created.
    pub fn id(&self) -> TransferId {
        self.id
    }

    /// Returns the id of the thread that created the handle.
    pub fn origin_thread(&self) -> ThreadId {
        self.origin_thread
    }

    /// Returns the current lifecycle stage.
    pub fn state(&self) -> TransferState {
        self.state
    }

    /// Returns the size that was passed to the constructor.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Moves the handle to [`TransferState::InFlight`].
    ///
    /// Has no effect once the payload has been received. A received handle cannot be
    /// in flight again, and allowing it would make `state()` disagree with
    /// `is_received()`.
    pub fn mark_sent(&mut self) {
        if !self.is_received() {
            self.state = TransferState::InFlight;
        }
    }

    /// Grants mutable access to the payload and marks the handle as received.
    ///
    /// # Panics
    ///
    /// Panics if the handle wraps a null pointer, or if `receive` (directly or via
    /// [`receive_owned`](Self::receive_owned)) was already called on this handle.
    pub fn receive(&mut self) -> &mut T {
        // Checked before flipping `received` so a null handle still counts as
        // unreceived afterwards.
        assert!(
            !self.ptr.is_null(),
            "TransferHandle::receive called on a null pointer (id: {:?})",
            self.id
        );
        if self.received.swap(true, Ordering::SeqCst) {
            panic!("TransferHandle::receive called more than once");
        }
        self.state = TransferState::Received;
        // SAFETY: `ptr` is non-null (asserted above). Validity, alignment and exclusive
        // access are the responsibility of the crate code that built this handle via
        // `new`. The returned borrow is tied to `&mut self`, and the `received` flag
        // ensures `receive` produces this reference at most once per handle.
        unsafe { &mut *self.ptr }
    }

    /// Clones the payload out and consumes the handle.
    ///
    /// The value is cloned, not moved. The pointed-to allocation is left untouched and
    /// remains the responsibility of whoever created the handle.
    ///
    /// # Panics
    ///
    /// Same conditions as [`receive`](Self::receive).
    pub fn receive_owned(mut self) -> T
    where
        T: Clone,
    {
        let data = self.receive().clone();
        // Recorded before `self` is dropped at the end of this function.
        self.state = TransferState::Consumed;
        data
    }

    /// Returns `true` once `receive` has been called on this handle.
    pub fn is_received(&self) -> bool {
        self.received.load(Ordering::SeqCst)
    }

    /// Returns the raw payload pointer without touching the receive bookkeeping.
    ///
    /// # Safety
    ///
    /// Obtaining the pointer is harmless in itself. Any dereference must not overlap a
    /// live `&mut T` returned by [`receive`](Self::receive). The pointee must also still
    /// be valid, because the handle does not keep the allocation alive.
    pub unsafe fn as_ptr(&self) -> *mut T {
        self.ptr
    }
}
impl<T> Drop for TransferHandle<T> {
    /// Reports (only with the `debug` feature) a handle that is dropped before its
    /// payload was ever received. The pointed-to memory is not freed here.
    fn drop(&mut self) {
        let already_handled =
            self.is_received() || matches!(self.state, TransferState::Consumed);
        if !already_handled {
            // Compiled out entirely unless the `debug` feature is enabled.
            #[cfg(feature = "debug")]
            eprintln!(
                "TransferHandle dropped without being received (id: {:?}, origin: {:?})",
                self.id, self.origin_thread
            );
        }
    }
}
/// Snapshot of the counters kept by [`TransferRegistry`].
#[derive(Debug, Default, Clone)]
pub struct TransferStats {
    /// Number of [`TransferRegistry::record_initiated`] calls.
    pub transfers_initiated: u64,
    /// Number of [`TransferRegistry::record_completed`] calls.
    pub transfers_completed: u64,
    /// Number of [`TransferRegistry::record_dropped`] calls.
    pub transfers_dropped: u64,
    /// Sum of the sizes passed to [`TransferRegistry::record_initiated`]. This is
    /// counted when a transfer starts, not when it completes.
    pub bytes_transferred: u64,
}

/// Thread-safe accumulator of [`TransferStats`], guarded by a mutex.
#[derive(Default)]
pub struct TransferRegistry {
    stats: std::sync::Mutex<TransferStats>,
}

impl TransferRegistry {
    /// Creates a registry with all counters at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Locks the counters, recovering the guard if a previous holder panicked.
    ///
    /// The mutex only protects plain counters, so a poisoned lock is not a reason to
    /// make every later stats call panic as well.
    fn lock_stats(&self) -> std::sync::MutexGuard<'_, TransferStats> {
        self.stats
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Counts one started transfer and adds `size` to the byte total.
    ///
    /// Counters saturate at `u64::MAX` instead of overflowing.
    pub fn record_initiated(&self, size: usize) {
        // Clamp rather than truncate should `usize` ever be wider than `u64`.
        let size = u64::try_from(size).unwrap_or(u64::MAX);
        let mut stats = self.lock_stats();
        stats.transfers_initiated = stats.transfers_initiated.saturating_add(1);
        stats.bytes_transferred = stats.bytes_transferred.saturating_add(size);
    }

    /// Counts one completed transfer (saturating at `u64::MAX`).
    pub fn record_completed(&self) {
        let mut stats = self.lock_stats();
        stats.transfers_completed = stats.transfers_completed.saturating_add(1);
    }

    /// Counts one dropped transfer (saturating at `u64::MAX`).
    pub fn record_dropped(&self) {
        let mut stats = self.lock_stats();
        stats.transfers_dropped = stats.transfers_dropped.saturating_add(1);
    }

    /// Returns a copy of the current counters.
    pub fn stats(&self) -> TransferStats {
        self.lock_stats().clone()
    }

    /// Resets every counter to zero.
    pub fn reset_stats(&self) {
        *self.lock_stats() = TransferStats::default();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transfer_id_unique() {
        let id1 = TransferId::new();
        let id2 = TransferId::new();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_transfer_handle_state() {
        let data = Box::into_raw(Box::new(42u32));
        let mut handle = TransferHandle::new(data, std::mem::size_of::<u32>());
        assert_eq!(handle.state(), TransferState::Owned);
        assert!(!handle.is_received());
        handle.mark_sent();
        assert_eq!(handle.state(), TransferState::InFlight);
        let value = handle.receive();
        assert_eq!(*value, 42);
        assert_eq!(handle.state(), TransferState::Received);
        assert!(handle.is_received());
        // Drop the handle first so it never outlives the allocation it points at.
        drop(handle);
        // SAFETY: `data` came from `Box::into_raw` above, and the only handle that
        // referred to it has already been dropped.
        unsafe {
            drop(Box::from_raw(data));
        }
    }

    #[test]
    #[should_panic(expected = "more than once")]
    fn test_receive_twice_panics() {
        let mut payload = 7u32;
        let ptr: *mut u32 = &mut payload;
        let mut handle = TransferHandle::new(ptr, std::mem::size_of::<u32>());
        let _ = handle.receive();
        let _ = handle.receive();
    }

    #[test]
    fn test_receive_owned_clones_payload() {
        let mut payload = String::from("payload");
        let size = payload.len();
        let ptr: *mut String = &mut payload;
        let handle = TransferHandle::new(ptr, size);
        let out = handle.receive_owned();
        assert_eq!(out, "payload");
        // `receive_owned` clones rather than moves, so the source is left intact.
        assert_eq!(payload, "payload");
    }
}