use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::sync::Condvar;
use std::sync::Mutex;
use std::sync::PoisonError;
use std::task::Waker;
use std::time::{Duration, Instant};
/// Handle to a transfer identified by a `u64` id that completes at most once.
///
/// Clones share one completion flag. Once any clone calls
/// [`mark_complete`](Self::mark_complete), every waiter on every clone is released,
/// whether it is blocked in [`wait`](Self::wait) / [`wait_timeout`](Self::wait_timeout)
/// or awaiting [`wait_async`](Self::wait_async).
pub struct MkTransferHandle {
    inner: Arc<TransferInner>,
}

/// State shared by all clones of a [`MkTransferHandle`].
struct TransferInner {
    /// Identifier given at construction; never modified.
    id: u64,
    /// Starts `false`; only ever set to `true`.
    completed: AtomicBool,
    /// Guards no data. Blocking waiters check `completed` while holding it, and
    /// `mark_complete` takes it before notifying. This ordering means a notification
    /// cannot slip in between a waiter's check and its sleep on `wait_condvar`.
    wait_mutex: Mutex<()>,
    /// Notified by `mark_complete` to release blocking waiters.
    wait_condvar: Condvar,
    /// Wakers registered by pending `wait_async` futures; drained and woken by
    /// `mark_complete`.
    async_wakers: Mutex<Vec<Waker>>,
}

impl MkTransferHandle {
    /// Creates a new, not-yet-complete handle with the given `id`.
    pub fn new(id: u64) -> Self {
        Self {
            inner: Arc::new(TransferInner {
                id,
                completed: AtomicBool::new(false),
                wait_mutex: Mutex::new(()),
                wait_condvar: Condvar::new(),
                async_wakers: Mutex::new(Vec::new()),
            }),
        }
    }

    /// Returns the identifier this handle was created with.
    pub fn id(&self) -> u64 {
        self.inner.id
    }

    /// Returns `true` once [`mark_complete`](Self::mark_complete) has been called
    /// on this handle or any clone of it.
    pub fn is_complete(&self) -> bool {
        self.inner.completed.load(Ordering::Acquire)
    }

    /// Blocks the current thread until the transfer is complete.
    ///
    /// Returns immediately if it already is.
    pub fn wait(&self) {
        if self.is_complete() {
            return;
        }
        // The mutex protects no data, so a poisoned lock carries no broken
        // invariant and is safe to keep using.
        let guard = self
            .inner
            .wait_mutex
            .lock()
            .unwrap_or_else(PoisonError::into_inner);
        // `wait_while` re-checks the flag under the lock before sleeping and after
        // every wakeup. That covers both spurious wakeups and a racing `mark_complete`.
        let _guard = self
            .inner
            .wait_condvar
            .wait_while(guard, |_| !self.is_complete())
            .unwrap_or_else(PoisonError::into_inner);
    }

    /// Blocks until the transfer completes or `timeout` elapses.
    ///
    /// Returns `true` if the transfer is complete and `false` if the timeout expired
    /// first. A `timeout` too large to turn into a deadline (e.g. [`Duration::MAX`])
    /// is treated as "no timeout".
    pub fn wait_timeout(&self, timeout: Duration) -> bool {
        if self.is_complete() {
            return true;
        }
        // `Instant + Duration` panics on overflow. A deadline that cannot be
        // represented is effectively "never", so fall back to an untimed wait.
        let deadline = match Instant::now().checked_add(timeout) {
            Some(deadline) => deadline,
            None => {
                self.wait();
                return true;
            }
        };
        let mut guard = self
            .inner
            .wait_mutex
            .lock()
            .unwrap_or_else(PoisonError::into_inner);
        // The flag is checked before the deadline on every iteration. So a completion
        // that races with the timeout is still reported as success.
        while !self.is_complete() {
            let now = Instant::now();
            if now >= deadline {
                return false;
            }
            guard = self
                .inner
                .wait_condvar
                .wait_timeout(guard, deadline - now)
                .unwrap_or_else(PoisonError::into_inner)
                .0;
        }
        true
    }

    /// Non-blocking completion check; same result as [`is_complete`](Self::is_complete).
    pub fn try_wait(&self) -> bool {
        self.is_complete()
    }

    /// Marks the transfer complete and releases every blocking and async waiter.
    ///
    /// Calling it again is harmless: the flag stays `true` and there is nothing
    /// left to wake.
    pub fn mark_complete(&self) {
        self.inner.completed.store(true, Ordering::Release);
        {
            // Holding the lock while notifying closes the window in which a
            // blocking waiter has seen `false` but has not yet started sleeping.
            let _guard = self
                .inner
                .wait_mutex
                .lock()
                .unwrap_or_else(PoisonError::into_inner);
            self.inner.wait_condvar.notify_all();
        }
        // Drain under the lock but wake outside it. Woken tasks then never contend
        // on this mutex with us.
        let pending = std::mem::take(
            &mut *self
                .inner
                .async_wakers
                .lock()
                .unwrap_or_else(PoisonError::into_inner),
        );
        for waker in pending {
            waker.wake();
        }
    }

    /// Creates a handle that is already complete.
    pub fn completed(id: u64) -> Self {
        let handle = Self::new(id);
        handle.mark_complete();
        handle
    }

    /// Waits asynchronously until the transfer is complete.
    ///
    /// The future registers its task's waker and is woken by
    /// [`mark_complete`](Self::mark_complete); it does not busy-poll. A waker from a
    /// future dropped before completion stays registered until completion. Waking it
    /// then is harmless.
    pub async fn wait_async(&self) {
        use std::future::Future;
        use std::pin::Pin;
        use std::task::{Context, Poll};

        struct TransferFuture<'a> {
            handle: &'a MkTransferHandle,
        }

        impl Future for TransferFuture<'_> {
            type Output = ();

            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                let inner = &self.handle.inner;
                if inner.completed.load(Ordering::Acquire) {
                    return Poll::Ready(());
                }
                let mut wakers = inner
                    .async_wakers
                    .lock()
                    .unwrap_or_else(PoisonError::into_inner);
                // Re-check under the lock. `mark_complete` sets the flag before it
                // drains this list, so either we observe `true` here or it will
                // observe (and wake) the waker we register below.
                if inner.completed.load(Ordering::Acquire) {
                    return Poll::Ready(());
                }
                // Avoid piling up duplicates when the same task re-polls us.
                if !wakers.iter().any(|w| w.will_wake(cx.waker())) {
                    wakers.push(cx.waker().clone());
                }
                Poll::Pending
            }
        }

        TransferFuture { handle: self }.await
    }
}
impl Clone for MkTransferHandle {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
impl Drop for TransferInner {
    /// Runs when the last handle to this shared state is dropped and leaves the flag set.
    ///
    /// `&mut self` guarantees no other reference exists, so no thread can be blocked
    /// on `wait_condvar` and nothing needs notifying. The mutex is deliberately not
    /// locked: `lock().unwrap()` here would panic inside `drop` on a poisoned mutex,
    /// and a panic in `drop` during unwinding aborts the process.
    fn drop(&mut self) {
        *self.completed.get_mut() = true;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::sync::Arc;
    use std::task::{Context, Poll, Wake, Waker};
    use std::thread;

    /// Waker that unparks the thread running `block_on`.
    struct ThreadWaker(thread::Thread);

    impl Wake for ThreadWaker {
        fn wake(self: Arc<Self>) {
            self.0.unpark();
        }
    }

    /// Minimal executor: polls `fut` on the current thread, parking between polls.
    fn block_on<F: Future>(fut: F) -> F::Output {
        let mut fut = Box::pin(fut);
        let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
        let mut cx = Context::from_waker(&waker);
        loop {
            if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
                return out;
            }
            thread::park();
        }
    }

    #[test]
    fn test_transfer_handle_basic() {
        let handle = MkTransferHandle::new(42);
        assert!(!handle.is_complete());
        assert_eq!(handle.id(), 42);
        let completer = handle.clone();
        let worker = thread::spawn(move || {
            thread::sleep(Duration::from_millis(50));
            completer.mark_complete();
        });
        handle.wait();
        assert!(handle.is_complete());
        worker.join().unwrap();
    }

    #[test]
    fn test_transfer_handle_timeout() {
        let handle = MkTransferHandle::new(100);
        let start = Instant::now();
        let result = handle.wait_timeout(Duration::from_millis(50));
        let elapsed = start.elapsed();
        assert!(!result);
        assert!(elapsed >= Duration::from_millis(45));
    }

    #[test]
    fn test_wait_timeout_completes_before_deadline() {
        let handle = MkTransferHandle::new(700);
        let completer = handle.clone();
        let worker = thread::spawn(move || {
            thread::sleep(Duration::from_millis(20));
            completer.mark_complete();
        });
        assert!(handle.wait_timeout(Duration::from_secs(10)));
        worker.join().unwrap();
    }

    #[test]
    fn test_transfer_handle_completed() {
        let handle = MkTransferHandle::completed(200);
        assert!(handle.is_complete());
        assert_eq!(handle.id(), 200);
        handle.wait();
        assert!(handle.wait_timeout(Duration::from_millis(100)));
    }

    #[test]
    fn test_multiple_waiters() {
        let handle = MkTransferHandle::new(300);
        let waiters: Vec<_> = (0..5)
            .map(|_| {
                let h = handle.clone();
                thread::spawn(move || {
                    h.wait();
                    assert!(h.is_complete());
                })
            })
            .collect();
        thread::sleep(Duration::from_millis(50));
        handle.mark_complete();
        for waiter in waiters {
            waiter.join().unwrap();
        }
    }

    #[test]
    fn test_try_wait() {
        let handle = MkTransferHandle::new(400);
        assert!(!handle.try_wait());
        handle.mark_complete();
        assert!(handle.try_wait());
    }

    #[test]
    fn test_double_complete() {
        let handle = MkTransferHandle::new(500);
        handle.mark_complete();
        assert!(handle.is_complete());
        handle.mark_complete();
        assert!(handle.is_complete());
    }

    #[test]
    fn test_wait_async() {
        // Already complete: resolves on the first poll.
        block_on(MkTransferHandle::completed(601).wait_async());

        // Completed later from another thread.
        let handle = MkTransferHandle::new(600);
        let completer = handle.clone();
        let worker = thread::spawn(move || {
            thread::sleep(Duration::from_millis(20));
            completer.mark_complete();
        });
        block_on(handle.wait_async());
        assert!(handle.is_complete());
        worker.join().unwrap();
    }
}