memkit 0.2.0-beta.1

Deterministic, intent-driven memory allocation for systems requiring predictable performance
Documentation
//! Memory barriers for thread synchronization.

use std::sync::atomic::{AtomicUsize, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

/// A reusable barrier that lets a fixed number of threads rendezvous.
///
/// Each call to [`MkBarrier::wait`] (or a successful [`MkBarrier::wait_timeout`] /
/// [`MkBarrier::try_wait`]) registers one arrival. When the `thread_count`-th arrival
/// happens, every thread blocked in the current cycle is released, the waiting count is
/// reset to zero and the generation counter advances, so the barrier can be reused
/// immediately for the next cycle. Both blocking and timeout-based waiting are supported.
pub struct MkBarrier {
    inner: Arc<BarrierInner>,
}

/// Shared state behind every handle to an [`MkBarrier`].
///
/// `waiting_count` and `generation` are only ever *modified* while `mutex` is held, which
/// makes the "count this arrival, then possibly trip and reset" sequence atomic. They are
/// atomics only so the diagnostic getters can read them without taking the lock.
struct BarrierInner {
    /// Number of arrivals required to trip the barrier.
    thread_count: usize,
    /// Arrivals registered in the current (not yet tripped) cycle.
    waiting_count: AtomicUsize,
    /// Number of completed cycles; waiters block until it moves past the value they saw.
    generation: AtomicU64,
    /// Serializes all updates to `waiting_count`/`generation` and backs `condvar`.
    mutex: std::sync::Mutex<()>,
    /// Notified (while `mutex` is held) whenever the barrier trips.
    condvar: std::sync::Condvar,
}

impl MkBarrier {
    /// Creates a barrier that trips once `thread_count` threads have arrived.
    ///
    /// Like [`std::sync::Barrier`], a `thread_count` of `0` behaves like `1`: every
    /// arrival trips the barrier on its own and never blocks.
    pub fn new(thread_count: usize) -> Self {
        Self {
            inner: Arc::new(BarrierInner {
                thread_count,
                waiting_count: AtomicUsize::new(0),
                generation: AtomicU64::new(0),
                mutex: std::sync::Mutex::new(()),
                condvar: std::sync::Condvar::new(),
            }),
        }
    }

    /// Blocks until `thread_count` threads (including this one) have reached the barrier.
    ///
    /// Returns the generation number of the cycle that was completed; every thread
    /// released by the same cycle returns the same value.
    pub fn wait(&self) -> u64 {
        let inner = &*self.inner;
        let mut guard = self.lock_state();
        let (generation, tripped) = self.arrive(&guard);
        if tripped {
            return generation;
        }

        // Loop to ride out spurious wake-ups: only a generation change releases us.
        while inner.generation.load(Ordering::Acquire) == generation {
            guard = inner
                .condvar
                .wait(guard)
                .unwrap_or_else(|poisoned| poisoned.into_inner());
        }
        generation
    }

    /// Waits on the barrier for at most `timeout`.
    ///
    /// Returns `Some(generation)` if the barrier tripped before the timeout elapsed (a
    /// barrier that trips concurrently with the timeout counts as success), or `None`
    /// if the timeout expired first. On `None` this thread's arrival is withdrawn, so it
    /// no longer counts towards the current cycle. A `timeout` too large to be expressed
    /// as a deadline waits without a time limit instead of panicking.
    pub fn wait_timeout(&self, timeout: Duration) -> Option<u64> {
        // `None` means `now + timeout` is not representable: treat it as "no limit".
        let deadline = Instant::now().checked_add(timeout);
        let inner = &*self.inner;
        let mut guard = self.lock_state();
        let (generation, tripped) = self.arrive(&guard);
        if tripped {
            return Some(generation);
        }

        loop {
            // Checked before the deadline so a trip racing the timeout still succeeds.
            if inner.generation.load(Ordering::Acquire) != generation {
                return Some(generation);
            }
            guard = match deadline {
                None => inner
                    .condvar
                    .wait(guard)
                    .unwrap_or_else(|poisoned| poisoned.into_inner()),
                Some(deadline) => {
                    let remaining = deadline.saturating_duration_since(Instant::now());
                    if remaining.is_zero() {
                        // We hold the lock and the generation has not moved, so our arrival
                        // is still counted in this cycle (count >= 1): withdraw it.
                        inner.waiting_count.fetch_sub(1, Ordering::AcqRel);
                        return None;
                    }
                    inner
                        .condvar
                        .wait_timeout(guard, remaining)
                        .unwrap_or_else(|poisoned| poisoned.into_inner())
                        .0
                }
            };
        }
    }

    /// Completes the current cycle if that can be done without blocking.
    ///
    /// If this call would be the final arrival, the barrier trips and `Some(generation)`
    /// is returned. Otherwise `None` is returned immediately and no arrival is
    /// registered, so a failed attempt leaves the barrier untouched.
    pub fn try_wait(&self) -> Option<u64> {
        let guard = self.lock_state();
        // Checking and arriving under the same lock acquisition means no other thread
        // can arrive in between, so `arrive` is guaranteed to trip the barrier here.
        let waiting = self.inner.waiting_count.load(Ordering::Acquire);
        if waiting + 1 >= self.inner.thread_count {
            let (generation, _tripped) = self.arrive(&guard);
            Some(generation)
        } else {
            None
        }
    }

    /// Get the number of threads this barrier is waiting for.
    pub fn thread_count(&self) -> usize {
        self.inner.thread_count
    }

    /// Get a snapshot of the number of threads currently waiting at the barrier.
    ///
    /// The value is read without locking and may be stale by the time it is used.
    pub fn waiting_count(&self) -> usize {
        self.inner.waiting_count.load(Ordering::Acquire)
    }

    /// Get the current generation number (the number of completed cycles).
    pub fn generation(&self) -> u64 {
        self.inner.generation.load(Ordering::Acquire)
    }

    /// Acquires the internal lock. The mutex guards `()`, so a lock poisoned by a
    /// panicking thread protects no broken data and is simply recovered.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, ()> {
        self.inner
            .mutex
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Registers one arrival in the current cycle. The guard reference is taken only as
    /// proof that the caller holds the lock.
    ///
    /// Returns the generation current at arrival and whether this arrival tripped the
    /// barrier (in which case the count was reset, the generation advanced and all
    /// waiters were notified).
    fn arrive(&self, _held: &std::sync::MutexGuard<'_, ()>) -> (u64, bool) {
        let inner = &*self.inner;
        let generation = inner.generation.load(Ordering::Acquire);
        // Cannot overflow: the count is always below `thread_count` between cycles.
        let arrived = inner.waiting_count.load(Ordering::Acquire) + 1;
        // `>=` rather than `==` so a `thread_count` of 0 trips instead of never tripping.
        if arrived >= inner.thread_count {
            inner.waiting_count.store(0, Ordering::Release);
            inner
                .generation
                .store(generation.wrapping_add(1), Ordering::Release);
            inner.condvar.notify_all();
            (generation, true)
        } else {
            inner.waiting_count.store(arrived, Ordering::Release);
            (generation, false)
        }
    }
}

impl Default for MkBarrier {
    fn default() -> Self {
        Self::new(2) // Default to 2 threads for common use case
    }
}

impl Clone for MkBarrier {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_barrier_basic() {
        let barrier = MkBarrier::new(3);

        let handles: Vec<_> = (0..3u64)
            .map(|i| {
                let b = barrier.clone();
                thread::spawn(move || {
                    // Stagger arrivals so threads reach the barrier at different times.
                    thread::sleep(Duration::from_millis(10 * i));
                    b.wait()
                })
            })
            .collect();

        // All threads should complete and return the same generation
        let generations: Vec<u64> = handles.into_iter().map(|h| h.join().unwrap()).collect();
        assert_eq!(generations.len(), 3);
        assert!(generations.windows(2).all(|w| w[0] == w[1]));
    }

    #[test]
    fn test_barrier_timeout() {
        let barrier = MkBarrier::new(3);

        // Only 2 of the 3 required threads arrive, so both must time out.
        let handles: Vec<_> = (0..2)
            .map(|_| {
                let b = barrier.clone();
                thread::spawn(move || b.wait_timeout(Duration::from_millis(50)))
            })
            .collect();

        for h in handles {
            assert!(h.join().unwrap().is_none());
        }
        // Threads that time out withdraw their arrival.
        assert_eq!(barrier.waiting_count(), 0);
    }

    #[test]
    fn test_barrier_try_wait() {
        let barrier = MkBarrier::new(2);

        // First try_wait should return None (not enough threads)
        assert!(barrier.try_wait().is_none());

        let barrier2 = barrier.clone();
        let waiter = thread::spawn(move || barrier2.wait());

        // Wait until the other thread has registered its arrival, instead of guessing
        // with a fixed sleep.
        while barrier.waiting_count() == 0 {
            thread::yield_now();
        }
        assert!(barrier.try_wait().is_some());
        waiter.join().unwrap();
    }
}