use cpu::{Cursor, WaitStrategy};
/// Gate that lets a consumer check or wait until a single upstream [`Cursor`]
/// has reached a target sequence.
///
/// The cursor is held as a raw pointer rather than a reference, so the barrier
/// carries no lifetime; whoever calls [`SequenceBarrier::new`] (an `unsafe fn`)
/// is responsible for keeping the cursor alive while the barrier is in use.
pub struct SequenceBarrier<W: WaitStrategy> {
// Non-owning pointer to the tracked cursor; dereferenced by every query method.
cursor: *const Cursor,
// Strategy that `wait_for` delegates to; `signal` is forwarded to it as well.
wait_strategy: W,
}
// SAFETY: every dereference of `cursor` in this file produces a shared
// `&Cursor` (the pointer is `*const`), so the barrier never mutates the cursor
// directly; the pointee's validity is the contract of the `unsafe fn new`.
// NOTE(review): these impls are unconditional in `W` and also assume `Cursor`
// is `Sync`. They are only sound if `WaitStrategy` has `Send + Sync`
// supertraits (or equivalent) — confirm in the `cpu` crate.
unsafe impl<W: WaitStrategy> Send for SequenceBarrier<W> {}
unsafe impl<W: WaitStrategy> Sync for SequenceBarrier<W> {}
impl<W: WaitStrategy> SequenceBarrier<W> {
#[inline]
pub unsafe fn new(cursor: *const Cursor, wait_strategy: W) -> Self {
Self {
cursor,
wait_strategy,
}
}
#[inline]
pub fn wait_for(&self, sequence: i64) -> i64 {
let cursor = unsafe { &*self.cursor };
self.wait_strategy.wait_for(sequence, cursor.as_atomic())
}
#[inline]
pub fn is_available(&self, sequence: i64) -> bool {
let current = unsafe { (*self.cursor).value() };
current >= sequence
}
#[inline]
pub fn get_cursor(&self) -> i64 {
unsafe { (*self.cursor).value() }
}
#[inline]
pub fn get_cursor_relaxed(&self) -> i64 {
unsafe { (*self.cursor).value_relaxed() }
}
#[inline]
pub fn signal(&self) {
self.wait_strategy.signal();
}
}
impl<W: WaitStrategy + Clone> Clone for SequenceBarrier<W> {
fn clone(&self) -> Self {
Self {
cursor: self.cursor,
wait_strategy: self.wait_strategy.clone(),
}
}
}
impl<W: WaitStrategy + core::fmt::Debug> core::fmt::Debug for SequenceBarrier<W> {
    /// Renders a snapshot of the cursor value (taken with `get_cursor_relaxed`)
    /// next to the wait strategy; the raw pointer itself is not shown.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let cursor_value = self.get_cursor_relaxed();
        let mut out = f.debug_struct("SequenceBarrier");
        out.field("cursor_value", &cursor_value);
        out.field("wait_strategy", &self.wait_strategy);
        out.finish()
    }
}
/// Gate over several upstream [`Cursor`]s: a sequence counts as available only
/// once the slowest (minimum) of them has reached it.
///
/// Like [`SequenceBarrier`], the cursors are held as raw pointers with no
/// lifetime; the caller of the `unsafe fn` [`MultiSequenceBarrier::new`] must
/// keep every one of them alive while the barrier is in use.
pub struct MultiSequenceBarrier<W: WaitStrategy> {
// Non-owning pointers to the tracked cursors; each is dereferenced on every
// minimum-sequence query.
cursors: Vec<*const Cursor>,
// Only used by `signal` in this file; `wait_for` spins without consulting it.
wait_strategy: W,
}
// SAFETY: the pointers are `*const`, so every dereference in this file yields a
// shared `&Cursor`; pointee validity is the contract of the `unsafe fn new`.
// NOTE(review): unconditional in `W` and assumes `Cursor: Sync`; sound only if
// `WaitStrategy` carries `Send + Sync` supertraits — confirm in the `cpu` crate.
unsafe impl<W: WaitStrategy> Send for MultiSequenceBarrier<W> {}
unsafe impl<W: WaitStrategy> Sync for MultiSequenceBarrier<W> {}
impl<W: WaitStrategy> MultiSequenceBarrier<W> {
    /// Creates a barrier gated on the minimum of `cursors`.
    ///
    /// # Safety
    ///
    /// Every pointer in `cursors` must be non-null, properly aligned, and point
    /// to a live `Cursor` for as long as the returned barrier is used: the
    /// query methods dereference them without further checks.
    #[inline]
    pub unsafe fn new(cursors: Vec<*const Cursor>, wait_strategy: W) -> Self {
        Self {
            cursors,
            wait_strategy,
        }
    }

    /// Returns the smallest value among the tracked cursors, or `i64::MAX` when
    /// there are no cursors (an empty barrier therefore never blocks).
    ///
    /// Each cursor is read with `Cursor::value`, the same accessor
    /// `SequenceBarrier::is_available` uses, rather than `value_relaxed`. This
    /// result is what `wait_for`/`is_available` use to let a consumer go on and
    /// read data published upstream. A relaxed load would not order those later
    /// reads after the upstream writes. (Presumably `value` is an acquire load —
    /// confirm in the `cpu` crate.)
    #[inline]
    pub fn get_minimum_sequence(&self) -> i64 {
        self.cursors
            .iter()
            // SAFETY: `new` obliges its caller to keep every pointee alive and
            // valid for as long as this barrier exists.
            .map(|&cursor| unsafe { (*cursor).value() })
            .min()
            .unwrap_or(i64::MAX)
    }

    /// Busy-spins until the minimum tracked sequence reaches `sequence`, then
    /// returns that minimum.
    ///
    /// The wait strategy is not consulted here. Its `wait_for` takes a single
    /// cursor, while this barrier gates on several, so the loop just issues
    /// `core::hint::spin_loop` between polls.
    pub fn wait_for(&self, sequence: i64) -> i64 {
        loop {
            let min = self.get_minimum_sequence();
            if min >= sequence {
                return min;
            }
            core::hint::spin_loop();
        }
    }

    /// Returns `true` when every tracked cursor is at or beyond `sequence`.
    #[inline]
    pub fn is_available(&self, sequence: i64) -> bool {
        self.get_minimum_sequence() >= sequence
    }

    /// Number of cursors this barrier gates on.
    #[inline]
    pub fn cursor_count(&self) -> usize {
        self.cursors.len()
    }

    /// Forwards a wake-up signal to the wait strategy.
    #[inline]
    pub fn signal(&self) {
        self.wait_strategy.signal();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use cpu::SpinLoopHintWait;
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_sequence_barrier_immediate() {
        let cursor = Cursor::with_value(10);
        // SAFETY: `cursor` outlives `barrier` within this test.
        let barrier = unsafe { SequenceBarrier::new(&cursor, SpinLoopHintWait) };
        // The cursor is already past the requested sequence, so this returns at once.
        assert_eq!(barrier.wait_for(5), 10);
    }

    #[test]
    fn test_sequence_barrier_is_available() {
        let cursor = Cursor::with_value(10);
        // SAFETY: `cursor` outlives `barrier` within this test.
        let barrier = unsafe { SequenceBarrier::new(&cursor, SpinLoopHintWait) };
        for &(sequence, expected) in &[(5, true), (10, true), (11, false)] {
            assert_eq!(barrier.is_available(sequence), expected, "sequence {}", sequence);
        }
    }

    #[test]
    fn test_sequence_barrier_get_cursor() {
        let cursor = Cursor::with_value(42);
        // SAFETY: `cursor` outlives `barrier` within this test.
        let barrier = unsafe { SequenceBarrier::new(&cursor, SpinLoopHintWait) };
        assert_eq!(barrier.get_cursor(), 42);
        assert_eq!(barrier.get_cursor_relaxed(), 42);
    }

    #[test]
    fn test_sequence_barrier_with_producer() {
        let shared = Arc::new(Cursor::new());
        let writer = Arc::clone(&shared);
        // SAFETY: `shared` keeps the cursor alive until the end of this test.
        let barrier = unsafe { SequenceBarrier::new(Arc::as_ptr(&shared), SpinLoopHintWait) };
        let producer = thread::spawn(move || {
            thread::sleep(Duration::from_millis(10));
            writer.set(5);
        });
        let observed = barrier.wait_for(5);
        assert!(observed >= 5);
        producer.join().unwrap();
    }

    #[test]
    fn test_multi_sequence_barrier() {
        let first = Cursor::with_value(10);
        let slowest = Cursor::with_value(5);
        let fastest = Cursor::with_value(15);
        let pointers: Vec<*const Cursor> = [&first, &slowest, &fastest]
            .iter()
            .map(|c| *c as *const Cursor)
            .collect();
        // SAFETY: all three cursors outlive `barrier` within this test.
        let barrier = unsafe { MultiSequenceBarrier::new(pointers, SpinLoopHintWait) };
        assert_eq!(barrier.get_minimum_sequence(), 5);
        assert!(barrier.is_available(5));
        assert!(!barrier.is_available(6));
        assert_eq!(barrier.cursor_count(), 3);
    }

    #[test]
    fn test_multi_sequence_barrier_wait() {
        let ahead = Cursor::with_value(10);
        let behind = Cursor::with_value(8);
        let pointers = vec![&ahead as *const Cursor, &behind as *const Cursor];
        // SAFETY: both cursors outlive `barrier` within this test.
        let barrier = unsafe { MultiSequenceBarrier::new(pointers, SpinLoopHintWait) };
        // The minimum (8) already satisfies 5, so the spin loop exits immediately.
        assert_eq!(barrier.wait_for(5), 8);
    }

    #[test]
    fn test_barrier_debug() {
        let cursor = Cursor::with_value(42);
        // SAFETY: `cursor` outlives `barrier` within this test.
        let barrier = unsafe { SequenceBarrier::new(&cursor, SpinLoopHintWait) };
        let rendered = format!("{:?}", barrier);
        assert!(rendered.contains("SequenceBarrier"));
        assert!(rendered.contains("42"));
    }
}