use std::{
collections::BinaryHeap,
sync::Arc,
thread,
time::{Duration, Instant},
};
use parking_lot::{Condvar, Mutex};
use crate::{
PlatformDispatcher, Priority, RunnableVariant, profiler,
queue::{PriorityQueueReceiver, PriorityQueueSender},
};
/// Lower bound on the number of background worker threads spawned by `BenchDispatcher::new`,
/// used when `available_parallelism` reports fewer threads or fails.
const MIN_THREADS: usize = 2;
/// A [`PlatformDispatcher`] for benchmarks that runs work on real threads in real time.
///
/// Background runnables run on a pool of worker threads, delayed runnables run on a
/// dedicated timer thread, and main-thread runnables are queued until the thread that
/// created the dispatcher drains them in [`BenchDispatcher::run_until_idle`].
pub struct BenchDispatcher {
/// Sending side of the queue the worker threads pop from.
background_sender: PriorityQueueSender<RunnableVariant>,
/// Sending side of the main-thread queue.
main_sender: PriorityQueueSender<RunnableVariant>,
/// Receiving side of the main-thread queue; only drained from `run_until_idle`, which
/// asserts it is on `main_thread_id`.
main_receiver: Mutex<PriorityQueueReceiver<RunnableVariant>>,
/// Pending delayed runnables, shared with the timer thread.
timers: Arc<TimerQueue>,
/// In-flight accounting shared with the worker and timer threads.
idle: Arc<IdleTracker>,
/// The thread that called `BenchDispatcher::new`.
main_thread_id: thread::ThreadId,
}
/// Counts runnables that have been handed to a worker or the timer thread but have not
/// finished, and lets the main thread sleep until that count changes.
#[derive(Default)]
struct IdleTracker {
    /// Number of runnables currently in flight.
    inflight: Mutex<usize>,
    /// Signalled when `inflight` reaches zero and by `notify_under_lock`.
    condvar: Condvar,
}

impl IdleTracker {
    /// Records one more in-flight runnable.
    fn increment(&self) {
        let mut count = self.inflight.lock();
        *count += 1;
    }

    /// Records that an in-flight runnable finished, waking all waiters once none remain.
    fn decrement(&self) {
        let mut count = self.inflight.lock();
        *count -= 1;
        let now_idle = *count == 0;
        if now_idle {
            self.condvar.notify_all();
        }
    }

    /// Returns a guard that calls `decrement` when it is dropped.
    fn decrement_on_drop(&self) -> impl Drop + '_ {
        rgpui_util::defer(move || self.decrement())
    }

    /// Wakes every waiter after briefly taking the `inflight` lock.
    ///
    /// A waiter that checks its wake condition while holding `inflight` is then either not
    /// yet checking, or already blocked in `wait`. The notification therefore cannot land
    /// between its check and its wait.
    fn notify_under_lock(&self) {
        let _guard = self.inflight.lock();
        self.condvar.notify_all();
    }
}
/// Delayed runnables waiting for the timer thread, plus the condvar that thread sleeps on.
struct TimerQueue {
/// Heap of pending timers and the next sequence number.
state: Mutex<TimerQueueState>,
/// Signalled by `dispatch_after` when a timer is added, so the timer thread can
/// re-evaluate how long to sleep.
condvar: Condvar,
}
/// Lock-protected contents of a `TimerQueue`.
struct TimerQueueState {
/// Pending timers; `TimerEntry`'s ordering puts the earliest-due entry on top.
heap: BinaryHeap<TimerEntry>,
/// Sequence number given to the next timer; breaks ties between equal `due` instants.
next_seq: u64,
}
/// A runnable scheduled by `dispatch_after`.
///
/// The ordering is reversed so that `BinaryHeap` (a max-heap) yields the entry with the
/// earliest `due` first. Ties go to the lowest `seq`, i.e. the timer scheduled first.
struct TimerEntry {
    /// When the runnable becomes eligible to run.
    due: Instant,
    /// Insertion sequence number, used as a tie-breaker.
    seq: u64,
    /// The work to run once `due` has passed.
    runnable: RunnableVariant,
}

impl PartialEq for TimerEntry {
    // `runnable` is deliberately ignored so equality agrees with `Ord` below.
    fn eq(&self, other: &Self) -> bool {
        (self.due, self.seq) == (other.due, other.seq)
    }
}

impl Eq for TimerEntry {}

impl PartialOrd for TimerEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for TimerEntry {
    /// Reversed lexicographic comparison on `(due, seq)`: an entry is "greater" when it is
    /// due earlier, or on a tie, was scheduled earlier.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        (other.due, other.seq).cmp(&(self.due, self.seq))
    }
}
impl Default for BenchDispatcher {
fn default() -> Self {
Self::new()
}
}
impl BenchDispatcher {
pub fn new() -> Self {
let (background_sender, background_receiver) = PriorityQueueReceiver::new();
let (main_sender, main_receiver) = PriorityQueueReceiver::new();
let idle = Arc::new(IdleTracker::default());
let thread_count =
thread::available_parallelism().map_or(MIN_THREADS, |i| i.get().max(MIN_THREADS));
for i in 0..thread_count {
let mut receiver: PriorityQueueReceiver<RunnableVariant> = background_receiver.clone();
let idle = idle.clone();
thread::Builder::new()
.name(format!("BenchWorker-{i}"))
.spawn(move || {
while let Ok(runnable) = receiver.pop() {
let _decrement = idle.decrement_on_drop();
let location = runnable.metadata().location;
let spawned = runnable.metadata().spawned;
profiler::update_running_task(spawned, location);
runnable.run();
profiler::save_task_timing();
}
})
.expect("failed to spawn benchmark worker thread");
}
drop(background_receiver);
let timers = Arc::new(TimerQueue {
state: Mutex::new(TimerQueueState {
heap: BinaryHeap::new(),
next_seq: 0,
}),
condvar: Condvar::new(),
});
{
let timers = timers.clone();
let idle = idle.clone();
thread::Builder::new()
.name("BenchTimer".to_owned())
.spawn(move || {
let mut state = timers.state.lock();
loop {
let Some(entry) = state.heap.peek() else {
timers.condvar.wait(&mut state);
continue;
};
let due = entry.due;
if due > Instant::now() {
timers.condvar.wait_until(&mut state, due);
continue;
}
let Some(entry) = state.heap.pop() else {
continue;
};
idle.increment();
drop(state);
{
let _decrement = idle.decrement_on_drop();
let location = entry.runnable.metadata().location;
let spawned = entry.runnable.metadata().spawned;
profiler::update_running_task(spawned, location);
entry.runnable.run();
profiler::save_task_timing();
}
state = timers.state.lock();
}
})
.expect("failed to spawn benchmark timer thread");
}
Self {
background_sender,
main_sender,
main_receiver: Mutex::new(main_receiver),
timers,
idle,
main_thread_id: thread::current().id(),
}
}
pub fn run_until_idle(&self) {
assert!(
self.is_main_thread(),
"run_until_idle must be called on the benchmark main thread"
);
loop {
if self.drain_main_queue() {
continue;
}
if self.has_due_timer() {
let mut inflight = self.idle.inflight.lock();
self.idle
.condvar
.wait_for(&mut inflight, Duration::from_millis(1));
continue;
}
let mut inflight = self.idle.inflight.lock();
if self.main_queue_has_work() {
continue;
}
if *inflight == 0 {
return;
}
self.idle.condvar.wait(&mut inflight);
}
}
pub fn forget_pending_timers(&self) {
let mut state = self.timers.state.lock();
for entry in state.heap.drain() {
std::mem::forget(entry.runnable);
}
}
fn has_due_timer(&self) -> bool {
let state = self.timers.state.lock();
state
.heap
.peek()
.is_some_and(|entry| entry.due <= Instant::now())
}
fn main_queue_has_work(&self) -> bool {
!self.main_receiver.lock().is_empty()
}
fn drain_main_queue(&self) -> bool {
let mut ran_any = false;
loop {
let runnable = self.main_receiver.lock().try_pop();
match runnable {
Ok(Some(runnable)) => {
let location = runnable.metadata().location;
let spawned = runnable.metadata().spawned;
profiler::update_running_task(spawned, location);
runnable.run();
profiler::save_task_timing();
ran_any = true;
}
Ok(None) | Err(_) => return ran_any,
}
}
}
}
impl PlatformDispatcher for BenchDispatcher {
    /// Returns whether the caller is the thread that constructed this dispatcher.
    fn is_main_thread(&self) -> bool {
        let current = thread::current().id();
        current == self.main_thread_id
    }

    /// Queues `runnable` for the worker pool.
    ///
    /// The runnable is counted as in flight before it is sent, so `run_until_idle` cannot
    /// observe an idle dispatcher while it sits in the queue.
    ///
    /// # Panics
    ///
    /// Panics if the worker pool's queue no longer accepts work.
    fn dispatch(&self, runnable: RunnableVariant, priority: Priority) {
        self.idle.increment();
        let outcome = self.background_sender.send(priority, runnable);
        if let Err(_unsent) = outcome {
            panic!("benchmark worker threads are no longer running");
        }
    }

    /// Queues `runnable` for the main thread and wakes `run_until_idle` if it is waiting.
    ///
    /// Main-thread work is not counted as in flight; `run_until_idle` inspects the main
    /// queue directly. A rejected runnable is leaked with `mem::forget` instead of dropped.
    fn dispatch_on_main_thread(&self, runnable: RunnableVariant, priority: Priority) {
        match self.main_sender.send(priority, runnable) {
            Ok(_) => self.idle.notify_under_lock(),
            Err(rejected) => std::mem::forget(rejected),
        }
    }

    /// Schedules `runnable` to run on the timer thread once `duration` has elapsed.
    ///
    /// # Panics
    ///
    /// Panics if `Instant::now() + duration` overflows.
    fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant) {
        let mut state = self.timers.state.lock();
        let seq = state.next_seq;
        state.next_seq = seq + 1;
        let due = Instant::now() + duration;
        state.heap.push(TimerEntry { due, seq, runnable });
        // Let the timer thread recompute how long to sleep.
        self.timers.condvar.notify_one();
    }

    /// Runs `f` on a new, detached thread named `BenchRealtime`.
    ///
    /// # Panics
    ///
    /// Panics if the thread cannot be spawned.
    fn spawn_realtime(&self, f: Box<dyn FnOnce() + Send>) {
        let builder = thread::Builder::new().name(String::from("BenchRealtime"));
        builder
            .spawn(f)
            .expect("failed to spawn benchmark realtime thread");
    }

    fn as_bench(&self) -> Option<&BenchDispatcher> {
        Some(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{BackgroundExecutor, ForegroundExecutor};
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Sleeps in 1 ms steps until `flag` is set or `timeout` has elapsed.
    fn wait_for_flag(flag: &AtomicBool, timeout: Duration) {
        let deadline = Instant::now() + timeout;
        while !flag.load(Ordering::SeqCst) && Instant::now() < deadline {
            thread::sleep(Duration::from_millis(1));
        }
    }

    #[test]
    fn run_until_idle_completes_background_to_main_handoffs() {
        let dispatcher = Arc::new(BenchDispatcher::new());
        let background = BackgroundExecutor::new(dispatcher.clone());
        let foreground = ForegroundExecutor::new(dispatcher.clone());
        let (tx, rx) = futures::channel::oneshot::channel();

        // A background task blocks briefly, then signals a main-thread task.
        background
            .spawn(async move {
                thread::sleep(Duration::from_millis(10));
                let _ = tx.send(());
            })
            .detach();

        let completed = Arc::new(AtomicBool::new(false));
        let completed_flag = Arc::clone(&completed);
        foreground
            .spawn(async move {
                let _ = rx.await;
                completed_flag.store(true, Ordering::SeqCst);
            })
            .detach();

        // Must not return before the main-thread task has seen the background signal.
        dispatcher.run_until_idle();
        assert!(completed.load(Ordering::SeqCst));
    }

    #[test]
    fn timers_fire_in_real_time() {
        let dispatcher = Arc::new(BenchDispatcher::new());
        let background = BackgroundExecutor::new(dispatcher);
        let fired = Arc::new(AtomicBool::new(false));
        let timer = background.timer(Duration::from_millis(10));
        let fired_flag = Arc::clone(&fired);
        background
            .spawn(async move {
                timer.await;
                fired_flag.store(true, Ordering::SeqCst);
            })
            .detach();

        wait_for_flag(&fired, Duration::from_secs(10));
        assert!(fired.load(Ordering::SeqCst));
    }

    #[test]
    fn forget_pending_timers_prevents_stale_timers_from_firing() {
        let dispatcher = Arc::new(BenchDispatcher::new());
        let background = BackgroundExecutor::new(dispatcher.clone());
        let fired = Arc::new(AtomicBool::new(false));
        let timer = background.timer(Duration::from_millis(250));
        let fired_flag = Arc::clone(&fired);
        background
            .spawn(async move {
                timer.await;
                fired_flag.store(true, Ordering::SeqCst);
            })
            .detach();

        // The 250 ms timer is normally not due yet, so this returns without waiting for it.
        dispatcher.run_until_idle();
        dispatcher.forget_pending_timers();
        // Sleep past the original deadline; a surviving timer would have fired by now.
        thread::sleep(Duration::from_millis(400));
        dispatcher.run_until_idle();
        assert!(!fired.load(Ordering::SeqCst));
    }
}