use crossbeam_channel as cb;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Thread-affinity requirement: which thread something is expected to run on.
///
/// `Any` is the `Default` and means "not pinned"; `Main` and `Named` are the
/// pinned variants (see [`ThreadAffinity::is_pinned`]). How the requirement
/// is enforced is up to the consumer — nothing in this type enforces it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ThreadAffinity {
/// Pinned to the main thread.
Main,
/// Pinned to the thread with this name.
Named(String),
/// No thread requirement; this is the variant returned by `Default`.
#[default]
Any,
}
impl ThreadAffinity {
    /// Returns `true` when the affinity names a specific thread (`Main` or
    /// `Named`), and `false` for [`ThreadAffinity::Any`].
    pub fn is_pinned(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Any => false,
            Self::Main | Self::Named(_) => true,
        }
    }

    /// Returns the thread name carried by [`ThreadAffinity::Named`].
    ///
    /// `Main` and `Any` carry no name, so both yield `None`.
    pub fn thread_name(&self) -> Option<&str> {
        if let Self::Named(name) = self {
            Some(name.as_str())
        } else {
            None
        }
    }
}
/// Summary of a single [`MainThreadPump::pump`] call.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct PumpStats {
/// Number of work items executed during this call.
pub processed: usize,
/// Wall-clock time spent inside the call (measured with `Instant`).
pub elapsed: Duration,
/// The pump's pending counter as read when the call returned; it may
/// include items dispatched concurrently from other threads.
pub remaining: usize,
}
/// A boxed one-shot closure queued for execution by whoever calls `pump`.
type WorkFn = Box<dyn FnOnce() + Send + 'static>;
/// State shared (via `Arc`) by every clone of a [`MainThreadPump`].
struct PumpInner {
/// Sending half of the work queue; `dispatch` pushes onto it.
tx: cb::Sender<WorkFn>,
/// Number of dispatched items that `pump` has not yet finished running.
pending: AtomicUsize,
/// Lifetime count of `dispatch` calls.
total_dispatched: AtomicU64,
/// Lifetime count of items that `pump` ran to completion.
total_processed: AtomicU64,
}
/// A work queue that any thread can `dispatch` closures onto. The queue is
/// drained by whichever caller invokes `pump`; the name suggests this is
/// intended to be the main thread, but that is not enforced here.
///
/// Cloning is cheap: clones share the same queue and counters through `Arc`.
#[derive(Clone)]
pub struct MainThreadPump {
/// Shared sender and counters.
inner: Arc<PumpInner>,
/// Receiving half of the work queue, shared between clones.
rx: Arc<cb::Receiver<WorkFn>>,
}
impl MainThreadPump {
    /// Creates an empty pump backed by an unbounded channel, with all
    /// counters at zero.
    pub fn new() -> Self {
        let (tx, rx) = cb::unbounded();
        Self {
            inner: Arc::new(PumpInner {
                tx,
                pending: AtomicUsize::new(0),
                total_dispatched: AtomicU64::new(0),
                total_processed: AtomicU64::new(0),
            }),
            rx: Arc::new(rx),
        }
    }

    /// Queues `f` to be run by a later call to [`pump`](Self::pump).
    ///
    /// May be called from any thread. The queue is unbounded, so this never
    /// waits for the pump to make room.
    ///
    /// # Panics
    ///
    /// Panics if the receiving half of the channel has been dropped. Every
    /// `MainThreadPump` holds the receiver, so this cannot happen while `self`
    /// is alive.
    pub fn dispatch<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        // Count the item before it becomes visible to `pump`, so the matching
        // decrement in `pump` can never run before this increment.
        self.inner.pending.fetch_add(1, Ordering::Relaxed);
        self.inner.total_dispatched.fetch_add(1, Ordering::Relaxed);
        self.inner
            .tx
            .send(Box::new(f))
            .expect("pump receiver dropped");
    }

    /// Runs queued work items on the calling thread until the queue is empty
    /// or `budget` has elapsed, and returns a summary of the call.
    ///
    /// The budget is checked *before* each item, so one slow item can overrun
    /// it, and a zero budget runs nothing. Items dispatched while pumping may
    /// be picked up by the same call.
    ///
    /// # Panics
    ///
    /// If a work item panics, the panic propagates to the caller. That item is
    /// still removed from [`pending`](Self::pending), because it has left the
    /// queue. It is not counted in [`total_processed`](Self::total_processed).
    /// Items after it remain queued for a later call.
    pub fn pump(&self, budget: Duration) -> PumpStats {
        // Decrements `pending` when dropped, including during unwinding, so a
        // panicking work item cannot leave the counter permanently inflated.
        struct PendingGuard<'a>(&'a AtomicUsize);

        impl Drop for PendingGuard<'_> {
            fn drop(&mut self) {
                self.0.fetch_sub(1, Ordering::Relaxed);
            }
        }

        let start = Instant::now();
        let mut processed = 0;
        while start.elapsed() < budget {
            match self.rx.try_recv() {
                Ok(work) => {
                    let guard = PendingGuard(&self.inner.pending);
                    work();
                    // Explicit drop keeps the original order on the normal path:
                    // `pending` is decremented before `total_processed` grows.
                    drop(guard);
                    self.inner.total_processed.fetch_add(1, Ordering::Relaxed);
                    processed += 1;
                }
                // `Empty`: nothing left to run. `Disconnected` cannot occur while
                // `self.inner` holds the sender, but stopping is the safe response.
                Err(cb::TryRecvError::Empty | cb::TryRecvError::Disconnected) => break,
            }
        }
        PumpStats {
            processed,
            elapsed: start.elapsed(),
            remaining: self.inner.pending.load(Ordering::Relaxed),
        }
    }

    /// Number of dispatched items that `pump` has not yet finished running.
    pub fn pending(&self) -> usize {
        self.inner.pending.load(Ordering::Relaxed)
    }

    /// Lifetime count of `dispatch` calls across all clones.
    pub fn total_dispatched(&self) -> u64 {
        self.inner.total_dispatched.load(Ordering::Relaxed)
    }

    /// Lifetime count of work items that ran to completion (without panicking).
    pub fn total_processed(&self) -> u64 {
        self.inner.total_processed.load(Ordering::Relaxed)
    }
}
impl Default for MainThreadPump {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicU32;
    use std::thread;

    #[test]
    fn test_thread_affinity_default_is_any() {
        assert_eq!(ThreadAffinity::default(), ThreadAffinity::Any);
    }

    #[test]
    fn test_thread_affinity_is_pinned() {
        assert!(!ThreadAffinity::Any.is_pinned());
        assert!(ThreadAffinity::Main.is_pinned());
        assert!(ThreadAffinity::Named("RenderThread".into()).is_pinned());
    }

    #[test]
    fn test_thread_affinity_thread_name() {
        assert!(ThreadAffinity::Any.thread_name().is_none());
        assert!(ThreadAffinity::Main.thread_name().is_none());
        assert_eq!(ThreadAffinity::Named("GT".into()).thread_name(), Some("GT"));
    }

    #[test]
    fn test_thread_affinity_serialization() {
        let variants = [
            ThreadAffinity::Any,
            ThreadAffinity::Main,
            ThreadAffinity::Named("Render".into()),
        ];
        for v in &variants {
            let json = serde_json::to_string(v).unwrap();
            let back: ThreadAffinity = serde_json::from_str(&json).unwrap();
            assert_eq!(&back, v);
        }
    }

    #[test]
    fn test_pump_basic() {
        let pump = MainThreadPump::new();
        let counter = Arc::new(AtomicU32::new(0));
        let c = Arc::clone(&counter);
        pump.dispatch(move || {
            c.fetch_add(1, Ordering::SeqCst);
        });
        assert_eq!(pump.pending(), 1);
        let stats = pump.pump(Duration::from_millis(100));
        assert_eq!(stats.processed, 1);
        assert_eq!(stats.remaining, 0);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_pump_multiple_items() {
        let pump = MainThreadPump::new();
        let counter = Arc::new(AtomicU32::new(0));
        for _ in 0..5 {
            let c = Arc::clone(&counter);
            pump.dispatch(move || {
                c.fetch_add(1, Ordering::SeqCst);
            });
        }
        let stats = pump.pump(Duration::from_millis(500));
        assert_eq!(stats.processed, 5);
        assert_eq!(counter.load(Ordering::SeqCst), 5);
        assert_eq!(pump.total_dispatched(), 5);
        assert_eq!(pump.total_processed(), 5);
    }

    #[test]
    fn test_pump_budget_zero_processes_nothing() {
        let pump = MainThreadPump::new();
        pump.dispatch(|| {});

        // A zero budget is exhausted before the first item is considered.
        let stats = pump.pump(Duration::ZERO);
        assert_eq!(stats.processed, 0);
        assert_eq!(stats.remaining, 1);
        assert_eq!(pump.total_processed(), 0);

        // The item stayed queued, so a real budget picks it up.
        let stats = pump.pump(Duration::from_millis(100));
        assert_eq!(stats.processed, 1);
        assert_eq!(stats.remaining, 0);
    }

    #[test]
    fn test_pump_cross_thread_dispatch() {
        let pump = MainThreadPump::new();
        let pump_worker = pump.clone();
        let counter = Arc::new(AtomicU32::new(0));
        let c = Arc::clone(&counter);
        let handle = thread::spawn(move || {
            pump_worker.dispatch(move || {
                c.fetch_add(42, Ordering::SeqCst);
            });
        });
        handle.join().unwrap();
        pump.pump(Duration::from_millis(100));
        assert_eq!(counter.load(Ordering::SeqCst), 42);
    }

    #[test]
    fn test_pump_stats_elapsed_is_reasonable() {
        let pump = MainThreadPump::new();
        pump.dispatch(|| thread::sleep(Duration::from_millis(20)));
        let stats = pump.pump(Duration::from_millis(500));
        assert_eq!(stats.processed, 1);
        // Slightly below the 20 ms sleep to tolerate timer granularity.
        assert!(stats.elapsed >= Duration::from_millis(15));
    }

    #[test]
    fn test_pump_clone_shares_queue() {
        let pump1 = MainThreadPump::new();
        let pump2 = pump1.clone();
        let counter = Arc::new(AtomicU32::new(0));
        let c = Arc::clone(&counter);
        pump2.dispatch(move || {
            c.fetch_add(1, Ordering::SeqCst);
        });
        let stats = pump1.pump(Duration::from_millis(100));
        assert_eq!(stats.processed, 1);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}