use std::sync::atomic::{AtomicU64, Ordering};
/// Which of the two transports a fair poll should try first.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PollOrder {
    /// Poll the direct transport before the relay transport.
    DirectFirst,
    /// Poll the relay transport before the direct transport.
    RelayFirst,
}

/// Round-robin arbiter that alternates which of two transports is polled first.
///
/// Every call to [`FairPoller::poll_order`] advances an internal tick. Even ticks
/// yield [`PollOrder::DirectFirst`] and odd ticks yield [`PollOrder::RelayFirst`].
/// A transport that is always ready therefore cannot permanently pre-empt the other.
///
/// The tick is an [`AtomicU64`], so a single poller can be shared across threads
/// (e.g. behind an `Arc`) through `&self`.
#[derive(Debug)]
pub struct FairPoller {
    /// Number of `poll_order` calls since construction or the last `reset`
    /// (wraps on overflow).
    counter: AtomicU64,
}

impl FairPoller {
    /// Creates a poller whose first [`poll_order`](Self::poll_order) returns
    /// [`PollOrder::DirectFirst`].
    ///
    /// This is a `const fn`, so a poller can be stored in a `static`.
    pub const fn new() -> Self {
        Self {
            counter: AtomicU64::new(0),
        }
    }

    /// Maps a tick value to its order: even → direct first, odd → relay first.
    ///
    /// 2^64 is even, so parity keeps alternating when the tick wraps from
    /// `u64::MAX` back to `0`.
    const fn order_for(tick: u64) -> PollOrder {
        if tick % 2 == 0 {
            PollOrder::DirectFirst
        } else {
            PollOrder::RelayFirst
        }
    }

    /// Returns the order for this poll and advances the tick by one.
    ///
    /// `Relaxed` ordering is sufficient here. The counter only arbitrates
    /// ordering and does not publish any other memory. An atomic `fetch_add` still
    /// hands each concurrent caller a distinct tick. The tick wraps on overflow.
    pub fn poll_order(&self) -> PollOrder {
        Self::order_for(self.counter.fetch_add(1, Ordering::Relaxed))
    }

    /// Returns the order the next [`poll_order`](Self::poll_order) would yield,
    /// without advancing the tick.
    ///
    /// Under concurrent use another thread may advance the tick between this call
    /// and a later `poll_order`, so the result is only advisory.
    #[must_use]
    pub fn peek_order(&self) -> PollOrder {
        Self::order_for(self.counter.load(Ordering::Relaxed))
    }

    /// Resets the tick to zero, so the next poll is [`PollOrder::DirectFirst`].
    pub fn reset(&self) {
        self.counter.store(0, Ordering::Relaxed);
    }

    /// Returns the current raw tick value.
    #[must_use]
    pub fn counter(&self) -> u64 {
        self.counter.load(Ordering::Relaxed)
    }

    /// Test hook: overwrites the tick with `value`.
    #[cfg(test)]
    pub fn set_counter(&self, value: u64) {
        self.counter.store(value, Ordering::Relaxed);
    }
}

impl Default for FairPoller {
    /// Equivalent to [`FairPoller::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Polls a direct and a relay transport in the order chosen by a `FairPoller`.
///
/// `$poller.poll_order()` is called exactly once per invocation. The transport
/// chosen to go first is evaluated. The other one is evaluated only if the first
/// yields `None`, and its value is then returned unchanged. `$direct` and
/// `$relay` must both be `Option<T>` expressions of the same `T`.
///
/// Each transport expression is expanded (textually) into both match arms, but
/// at runtime each one is evaluated at most once.
#[macro_export]
macro_rules! poll_transports_fair {
    ($poller:expr, $direct:expr, $relay:expr) => {{
        // Fully-qualified paths instead of a block-scoped `use`: `macro_rules!`
        // item names are unhygienic, so an import here would shadow any
        // `PollOrder` that the caller's `$direct` / `$relay` expressions refer to.
        match $poller.poll_order() {
            $crate::fair_polling::PollOrder::DirectFirst => {
                if let Some(result) = $direct {
                    Some(result)
                } else {
                    $relay
                }
            }
            $crate::fair_polling::PollOrder::RelayFirst => {
                if let Some(result) = $relay {
                    Some(result)
                } else {
                    $direct
                }
            }
        }
    }};
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_alternating_poll_order() {
        let poller = FairPoller::new();
        let order1 = poller.poll_order();
        let order2 = poller.poll_order();
        let order3 = poller.poll_order();
        let order4 = poller.poll_order();
        assert_eq!(order1, PollOrder::DirectFirst);
        assert_eq!(order2, PollOrder::RelayFirst);
        assert_eq!(order3, PollOrder::DirectFirst);
        assert_eq!(order4, PollOrder::RelayFirst);
    }

    #[test]
    fn test_counter_wraps() {
        let poller = FairPoller::new();
        poller.set_counter(u64::MAX);
        // u64::MAX is odd and the atomic increment wraps to 0, so the order must
        // keep alternating across the overflow instead of repeating a transport.
        assert_eq!(poller.poll_order(), PollOrder::RelayFirst);
        assert_eq!(poller.counter(), 0);
        assert_eq!(poller.poll_order(), PollOrder::DirectFirst);
        assert_eq!(poller.counter(), 1);
    }

    #[test]
    fn test_poll_order_is_deterministic() {
        let poller = FairPoller::new();
        poller.set_counter(0);
        assert_eq!(poller.poll_order(), PollOrder::DirectFirst);
        poller.set_counter(1);
        assert_eq!(poller.poll_order(), PollOrder::RelayFirst);
    }

    #[test]
    fn test_peek_does_not_increment() {
        let poller = FairPoller::new();
        let peek1 = poller.peek_order();
        let peek2 = poller.peek_order();
        let peek3 = poller.peek_order();
        assert_eq!(peek1, PollOrder::DirectFirst);
        assert_eq!(peek1, peek2);
        assert_eq!(peek2, peek3);
        assert_eq!(poller.counter(), 0);
        // A peek must predict the next poll, and then track the advanced tick.
        assert_eq!(poller.poll_order(), peek1);
        assert_eq!(poller.peek_order(), PollOrder::RelayFirst);
        assert_eq!(poller.counter(), 1);
    }

    #[test]
    fn test_reset() {
        let poller = FairPoller::new();
        for _ in 0..3 {
            let _ = poller.poll_order();
        }
        assert_eq!(poller.counter(), 3);
        poller.reset();
        assert_eq!(poller.counter(), 0);
        assert_eq!(poller.peek_order(), PollOrder::DirectFirst);
    }

    #[test]
    fn test_default() {
        let poller = FairPoller::default();
        assert_eq!(poller.counter(), 0);
        assert_eq!(poller.peek_order(), PollOrder::DirectFirst);
    }

    #[test]
    fn test_poll_transports_fair_macro_direct_first() {
        let poller = FairPoller::new();
        poller.set_counter(0);
        let direct = Some(1);
        let relay: Option<i32> = Some(2);
        let result = poll_transports_fair!(poller, direct, relay);
        assert_eq!(result, Some(1));
    }

    #[test]
    fn test_poll_transports_fair_macro_relay_first() {
        let poller = FairPoller::new();
        poller.set_counter(1);
        let direct: Option<i32> = Some(1);
        let relay = Some(2);
        let result = poll_transports_fair!(poller, direct, relay);
        assert_eq!(result, Some(2));
    }

    #[test]
    fn test_poll_transports_fair_macro_fallback() {
        let poller = FairPoller::new();
        poller.set_counter(0);
        let direct: Option<i32> = None;
        let relay = Some(2);
        let result = poll_transports_fair!(poller, direct, relay);
        assert_eq!(result, Some(2));
    }

    #[test]
    fn test_poll_transports_fair_macro_relay_first_fallback() {
        let poller = FairPoller::new();
        poller.set_counter(1);
        let direct = Some(1);
        let relay: Option<i32> = None;
        let result = poll_transports_fair!(poller, direct, relay);
        assert_eq!(result, Some(1));
    }

    #[test]
    fn test_poll_transports_fair_macro_both_none() {
        let poller = FairPoller::new();
        let direct: Option<i32> = None;
        let relay: Option<i32> = None;
        let result = poll_transports_fair!(poller, direct, relay);
        assert_eq!(result, None);
        // The macro consumes exactly one tick even when nothing is ready.
        assert_eq!(poller.counter(), 1);
    }

    #[test]
    fn test_poll_transports_fair_macro_skips_second_transport() {
        let poller = FairPoller::new();
        let mut relay_polls = 0;
        let result = poll_transports_fair!(poller, Some(1), {
            relay_polls += 1;
            Some(2)
        });
        assert_eq!(result, Some(1));
        assert_eq!(relay_polls, 0);
    }

    #[test]
    fn test_concurrent_access() {
        use std::sync::Arc;
        use std::thread;

        let poller = Arc::new(FairPoller::new());
        let mut handles = vec![];
        for _ in 0..10 {
            let p = Arc::clone(&poller);
            handles.push(thread::spawn(move || {
                let mut direct_first = 0usize;
                for _ in 0..100 {
                    if p.poll_order() == PollOrder::DirectFirst {
                        direct_first += 1;
                    }
                }
                direct_first
            }));
        }
        let direct_first: usize = handles
            .into_iter()
            .map(|handle| handle.join().expect("Thread panicked"))
            .sum();
        assert_eq!(poller.counter(), 1000);
        // Each atomic increment hands out a distinct tick in 0..1000, so exactly
        // half of all polls must have been direct-first regardless of interleaving.
        assert_eq!(direct_first, 500);
    }
}