memkit 0.2.0-beta.1

Deterministic, intent-driven memory allocation for systems requiring predictable performance
Documentation
//! Memory transfer handles.

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::sync::Condvar;
use std::sync::Mutex;
use std::time::{Duration, Instant};

/// A handle for tracking memory transfers.
///
/// This handle provides thread-safe methods to check and wait for the completion
/// of memory transfers (e.g., CPU-to-GPU transfers, cross-thread transfers).
///
/// Cloning a handle is cheap: every clone refers to the same shared transfer
/// state, so a completion signalled through one clone is observed by all of them.
pub struct MkTransferHandle {
    inner: Arc<TransferInner>,
}

/// State shared by every clone of one `MkTransferHandle`.
struct TransferInner {
    /// Unique transfer ID
    id: u64,
    /// Completion state; only ever flips from `false` to `true`.
    completed: AtomicBool,
    /// Wakers of pending `wait_async` futures. This mutex is also the one
    /// paired with `wait_condvar` for the blocking waits.
    wakers: Mutex<Vec<std::task::Waker>>,
    /// Notified (while holding the `wakers` lock) once `completed` is set.
    wait_condvar: Condvar,
}

impl TransferInner {
    /// Lock the waiter list, recovering the guard if the mutex was poisoned.
    ///
    /// The list is only ever pushed to or drained as a whole, so a panic in
    /// another thread cannot leave it in a state that is unsafe to keep using;
    /// propagating the poison would only turn one panic into many.
    fn lock_wakers(&self) -> std::sync::MutexGuard<'_, Vec<std::task::Waker>> {
        self.wakers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

impl MkTransferHandle {
    /// Create a new, not-yet-completed transfer handle with the given ID.
    pub fn new(id: u64) -> Self {
        Self {
            inner: Arc::new(TransferInner {
                id,
                completed: AtomicBool::new(false),
                wakers: Mutex::new(Vec::new()),
                wait_condvar: Condvar::new(),
            }),
        }
    }

    /// Get the transfer ID.
    pub fn id(&self) -> u64 {
        self.inner.id
    }

    /// Check if the transfer is complete.
    ///
    /// This method is non-blocking and returns immediately.
    pub fn is_complete(&self) -> bool {
        self.inner.completed.load(Ordering::Acquire)
    }

    /// Wait for the transfer to complete.
    ///
    /// This method blocks until the transfer is marked as complete.
    /// If the transfer is already complete, it returns immediately.
    pub fn wait(&self) {
        // Fast path: already complete, no need to touch the mutex.
        if self.is_complete() {
            return;
        }

        // Slow path. The flag is re-checked while holding the lock:
        // `mark_complete` sets the flag *before* taking the same lock to
        // notify, so either we observe the flag here or we are already parked
        // on the condvar when the notification is sent.
        let mut guard = self.inner.lock_wakers();
        while !self.is_complete() {
            guard = self
                .inner
                .wait_condvar
                .wait(guard)
                .unwrap_or_else(std::sync::PoisonError::into_inner);
        }
    }

    /// Wait for the transfer to complete with a timeout.
    ///
    /// Returns `true` if the transfer completed within the timeout,
    /// or `false` if the timeout expired first.
    ///
    /// A `timeout` so large that the deadline cannot be represented (for
    /// example `Duration::MAX`) is treated as "wait indefinitely" instead of
    /// panicking on the deadline computation.
    pub fn wait_timeout(&self, timeout: Duration) -> bool {
        // Fast path: already complete
        if self.is_complete() {
            return true;
        }

        let deadline = match Instant::now().checked_add(timeout) {
            Some(deadline) => deadline,
            None => {
                // Deadline overflowed: effectively an unbounded wait.
                self.wait();
                return true;
            }
        };

        let mut guard = self.inner.lock_wakers();
        loop {
            // Checked after every wake-up (including a timed-out one) so a
            // completion that raced with the timeout is still reported.
            if self.is_complete() {
                return true;
            }

            let now = Instant::now();
            if now >= deadline {
                return false;
            }

            guard = self
                .inner
                .wait_condvar
                .wait_timeout(guard, deadline - now)
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .0;
        }
    }

    /// Try to wait without blocking.
    ///
    /// Returns `true` if the transfer is complete, `false` otherwise.
    pub fn try_wait(&self) -> bool {
        self.is_complete()
    }

    /// Mark the transfer as complete.
    ///
    /// This method is typically called by the transfer system when the operation
    /// finishes (e.g., GPU command completes, data transfer finishes). It wakes
    /// every thread blocked in [`wait`](Self::wait) /
    /// [`wait_timeout`](Self::wait_timeout) and every task pending in
    /// [`wait_async`](Self::wait_async).
    ///
    /// # Notes
    ///
    /// This should only be called by the transfer system that created the handle.
    /// Multiple calls are safe and idempotent.
    pub fn mark_complete(&self) {
        // Publish the flag first; waiters re-check it under the lock below.
        self.inner.completed.store(true, Ordering::Release);

        let pending = {
            let mut wakers = self.inner.lock_wakers();
            self.inner.wait_condvar.notify_all();
            std::mem::take(&mut *wakers)
        };

        // Wake async waiters after releasing the lock so a waker that
        // re-enters this handle cannot deadlock against us.
        for waker in pending {
            waker.wake();
        }
    }

    /// Create a completed transfer handle (for testing or immediate completion).
    pub fn completed(id: u64) -> Self {
        let handle = Self::new(id);
        handle.mark_complete();
        handle
    }

    /// Wait asynchronously for the transfer to complete.
    ///
    /// The returned future resolves once the transfer has been marked
    /// complete. While pending it registers the task's waker with the handle
    /// and is woken by [`mark_complete`](Self::mark_complete); it does not
    /// busy-poll the executor.
    ///
    /// If the future is dropped before completion its waker stays registered
    /// until the transfer completes, which at worst causes one spurious
    /// wake-up of the task that owned it.
    pub async fn wait_async(&self) {
        use std::future::Future;
        use std::pin::Pin;
        use std::task::{Context, Poll};

        struct TransferFuture<'a> {
            handle: &'a MkTransferHandle,
        }

        impl<'a> Future for TransferFuture<'a> {
            type Output = ();

            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                let inner = &self.handle.inner;
                if inner.completed.load(Ordering::Acquire) {
                    return Poll::Ready(());
                }

                let mut wakers = inner.lock_wakers();

                // Re-check under the lock: `mark_complete` sets the flag
                // before it locks to drain the wakers, so either the flag is
                // visible here or the waker registered below gets drained.
                if inner.completed.load(Ordering::Acquire) {
                    return Poll::Ready(());
                }

                // Repeated polls from the same task must not pile up wakers.
                if !wakers.iter().any(|w| w.will_wake(cx.waker())) {
                    wakers.push(cx.waker().clone());
                }
                Poll::Pending
            }
        }

        TransferFuture { handle: self }.await
    }
}

impl Clone for MkTransferHandle {
    /// Create another handle to the same transfer (shares state, no copy).
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}

impl Drop for TransferInner {
    fn drop(&mut self) {
        // Defensive: flag the transfer as finished when the shared state goes
        // away. `&mut self` gives exclusive access, so the flag is written
        // through `get_mut` and no lock is taken -- a destructor must not be
        // able to panic on a poisoned mutex, since a panic during unwinding
        // aborts the process.
        if !*self.completed.get_mut() {
            *self.completed.get_mut() = true;
            self.wait_condvar.notify_all();
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// A waiter blocks until another thread signals completion.
    #[test]
    fn test_transfer_handle_basic() {
        let transfer = MkTransferHandle::new(42);
        assert_eq!(transfer.id(), 42);
        assert!(!transfer.is_complete());

        // Completion is signalled from a background thread after a short delay.
        let signaller = transfer.clone();
        thread::spawn(move || {
            thread::sleep(Duration::from_millis(50));
            signaller.mark_complete();
        });

        // Blocks until the background thread has marked the transfer done.
        transfer.wait();
        assert!(transfer.is_complete());
    }

    /// A timed wait on a transfer that never completes reports a timeout.
    #[test]
    fn test_transfer_handle_timeout() {
        let transfer = MkTransferHandle::new(100);

        let started = Instant::now();
        let finished_in_time = transfer.wait_timeout(Duration::from_millis(50));
        let waited = started.elapsed();

        assert!(!finished_in_time);
        // Slightly below the requested 50ms to tolerate timer variance.
        assert!(waited >= Duration::from_millis(45));
    }

    /// A handle created as completed never blocks.
    #[test]
    fn test_transfer_handle_completed() {
        let transfer = MkTransferHandle::completed(200);
        assert_eq!(transfer.id(), 200);
        assert!(transfer.is_complete());

        // Both waiting flavours return immediately.
        transfer.wait();
        assert!(transfer.wait_timeout(Duration::from_millis(100)));
    }

    /// One completion releases every blocked waiter.
    #[test]
    fn test_multiple_waiters() {
        let shared = Arc::new(MkTransferHandle::new(300));

        // Park five threads on the same transfer.
        let waiters: Vec<_> = (0..5)
            .map(|_| {
                let local = Arc::clone(&shared);
                thread::spawn(move || {
                    local.wait();
                    assert!(local.is_complete());
                })
            })
            .collect();

        // Give the waiters time to block, then complete the transfer.
        thread::sleep(Duration::from_millis(50));
        shared.mark_complete();

        // Every waiter must have been released.
        for waiter in waiters {
            waiter.join().unwrap();
        }
    }

    /// `try_wait` mirrors the completion flag without blocking.
    #[test]
    fn test_try_wait() {
        let transfer = MkTransferHandle::new(400);
        assert!(!transfer.try_wait());

        transfer.mark_complete();
        assert!(transfer.try_wait());
    }

    /// Completing a transfer twice is harmless.
    #[test]
    fn test_double_complete() {
        let transfer = MkTransferHandle::new(500);

        for _ in 0..2 {
            transfer.mark_complete();
            assert!(transfer.is_complete());
        }
    }
}