use std::collections::BinaryHeap;
use std::sync::Mutex;
use std::sync::mpsc::{Receiver, RecvTimeoutError, Sender, channel};
use std::thread::JoinHandle;
use std::time::{Duration, Instant};
use super::config::SignalDeliveryConfig;
use crate::error::EngineError;
/// One outstanding wake confirmation tracked by the worker thread.
///
/// Equality and ordering look at `due` only, and the ordering is inverted so
/// that a `BinaryHeap<WakeOrder>` (a max-heap) yields the earliest-due order
/// first.
struct WakeOrder {
    /// Instant at which the worker next evaluates this order.
    due: Instant,
    /// Delay that was used to schedule `due`; the worker grows it after each wake.
    backoff: Duration,
    /// Callback the worker runs when the order is due and `done` still reports `false`.
    wake: Box<dyn Fn() + Send>,
    /// Predicate the worker polls before waking; `true` retires the order.
    done: Box<dyn Fn() -> bool + Send>,
}

impl Ord for WakeOrder {
    /// Compares by `due` and flips the result: sooner deadlines rank higher.
    fn cmp(&self, rhs: &Self) -> std::cmp::Ordering {
        self.due.cmp(&rhs.due).reverse()
    }
}

impl PartialOrd for WakeOrder {
    /// Total order, so this always returns `Some` of [`Ord::cmp`].
    fn partial_cmp(&self, rhs: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(rhs))
    }
}

impl PartialEq for WakeOrder {
    /// Orders are equal exactly when they fall due at the same instant; the
    /// callbacks and backoff do not take part in the comparison.
    fn eq(&self, rhs: &Self) -> bool {
        self.cmp(rhs).is_eq()
    }
}

impl Eq for WakeOrder {}
/// Handle to the background `aion-wake-confirm` thread, which keeps re-issuing
/// a wake callback until the matching `done` predicate reports success.
pub(super) struct WakeConfirmer {
/// Sending half of the channel into the worker. `shutdown` takes it out and
/// leaves `None`, after which `confirm` silently discards new orders.
sender: Mutex<Option<Sender<WakeOrder>>>,
/// Join handle of the worker thread; `shutdown` takes it out to join it, so a
/// second `shutdown` finds `None` and does nothing.
worker: Mutex<Option<JoinHandle<()>>>,
/// Delivery policy supplied at construction; `confirm` reads `initial_backoff`
/// from it to schedule the first check of each order.
policy: SignalDeliveryConfig,
}
impl WakeConfirmer {
pub(super) fn new(policy: SignalDeliveryConfig) -> Result<Self, EngineError> {
let (sender, receiver) = channel();
let cap = policy.ready_timeout.max(policy.initial_backoff);
let worker = std::thread::Builder::new()
.name("aion-wake-confirm".to_owned())
.spawn(move || run_worker(&receiver, cap))
.map_err(|error| EngineError::Runtime {
reason: format!("failed to start the wake-confirmation worker: {error}"),
})?;
Ok(Self {
sender: Mutex::new(Some(sender)),
worker: Mutex::new(Some(worker)),
policy,
})
}
pub(super) fn confirm(
&self,
wake: impl Fn() + Send + 'static,
done: impl Fn() -> bool + Send + 'static,
) {
let initial = self.policy.initial_backoff.max(Duration::from_micros(50));
let order = WakeOrder {
due: Instant::now() + initial,
backoff: initial,
wake: Box::new(wake),
done: Box::new(done),
};
let guard = match self.sender.lock() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
if let Some(sender) = guard.as_ref() {
drop(sender.send(order));
}
}
pub(super) fn shutdown(&self) {
let sender = {
let mut guard = match self.sender.lock() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
guard.take()
};
drop(sender);
let worker = {
let mut guard = match self.worker.lock() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
guard.take()
};
if let Some(worker) = worker
&& worker.join().is_err()
{
tracing::error!("wake-confirmation worker panicked");
}
}
}
/// Dropping the confirmer stops and joins the worker thread.
impl Drop for WakeConfirmer {
    fn drop(&mut self) {
        // `shutdown` takes both the sender and the join handle out of their
        // slots, so this is harmless when the owner already shut down.
        Self::shutdown(self);
    }
}
/// Body of the wake-confirmation worker thread.
///
/// Keeps every outstanding [`WakeOrder`] in a heap keyed by due time and
/// alternates between two steps:
///
/// 1. Wait for a new order, but no longer than until the earliest pending
///    order falls due (or indefinitely when nothing is pending).
/// 2. Drain every order that is due: an order whose `done` predicate reports
///    `true` is dropped; otherwise its `wake` callback runs and the order is
///    re-armed with a doubled backoff, limited by `cap`.
///
/// The function returns as soon as the channel reports that every sender is
/// gone; orders still pending at that point are dropped without further wakes.
///
/// `cap` is the ceiling for backoff growth. It never shrinks an order's
/// backoff: when `cap` is below the backoff an order already has, that backoff
/// is kept as is.
fn run_worker(receiver: &Receiver<WakeOrder>, cap: Duration) {
    let mut pending: BinaryHeap<WakeOrder> = BinaryHeap::new();
    loop {
        let arrival = match pending.peek() {
            Some(next) => {
                // Sleep at most until the earliest order is due; an overdue
                // order yields a zero wait, i.e. a non-blocking poll.
                let wait = next.due.saturating_duration_since(Instant::now());
                match receiver.recv_timeout(wait) {
                    Ok(order) => Some(order),
                    Err(RecvTimeoutError::Timeout) => None,
                    Err(RecvTimeoutError::Disconnected) => return,
                }
            }
            None => match receiver.recv() {
                Ok(order) => Some(order),
                Err(_disconnected) => return,
            },
        };
        if let Some(order) = arrival {
            pending.push(order);
        }

        // `now` is sampled once per pass so that re-armed orders, which land
        // strictly after it, cannot be drained again in the same pass.
        let now = Instant::now();
        while pending.peek().is_some_and(|order| order.due <= now) {
            let Some(mut order) = pending.pop() else {
                break;
            };
            if (order.done)() {
                continue;
            }
            (order.wake)();
            // Double up to `cap`, but never drop below the current backoff.
            // `confirm` floors the initial backoff independently of how `cap`
            // was derived, so `cap` may be smaller than it (even zero).
            // Clamping straight to such a cap would re-arm the order at `now`
            // and this loop would spin forever without polling the channel
            // again, which also blocks shutdown.
            order.backoff = order
                .backoff
                .saturating_mul(2)
                .min(cap)
                .max(order.backoff);
            order.due = now + order.backoff;
            pending.push(order);
        }
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
    use std::time::{Duration, Instant};

    use super::WakeConfirmer;
    use crate::runtime::SignalDeliveryConfig;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Pause used before each of the two samples that check the wake count
    /// has stopped moving.
    const SETTLE: Duration = Duration::from_millis(80);

    /// Delivery policy shared by every test in this module.
    fn policy() -> SignalDeliveryConfig {
        SignalDeliveryConfig::new(
            Duration::from_millis(20),
            8,
            Duration::from_millis(1),
            Duration::from_millis(4),
        )
    }

    /// Wake callback that adds one to `counter` on every invocation.
    fn counting_wake(counter: Arc<AtomicU32>) -> impl Fn() + Send + 'static {
        move || {
            counter.fetch_add(1, Ordering::AcqRel);
        }
    }

    /// `done` predicate that reports the current value of `gate`.
    fn gate_reader(gate: Arc<AtomicBool>) -> impl Fn() -> bool + Send + 'static {
        move || gate.load(Ordering::Acquire)
    }

    /// Polls `wakes` every 2 ms until it reaches `target`.
    ///
    /// Returns `false` when ten seconds pass without getting there.
    fn wakes_reach(wakes: &AtomicU32, target: u32) -> bool {
        let deadline = Instant::now() + Duration::from_secs(10);
        while wakes.load(Ordering::Acquire) < target {
            if Instant::now() > deadline {
                return false;
            }
            std::thread::sleep(Duration::from_millis(2));
        }
        true
    }

    /// Samples `wakes` twice, pausing for [`SETTLE`] before each sample, and
    /// returns `(first, second)`.
    fn sample_twice(wakes: &AtomicU32) -> (u32, u32) {
        std::thread::sleep(SETTLE);
        let first = wakes.load(Ordering::Acquire);
        std::thread::sleep(SETTLE);
        let second = wakes.load(Ordering::Acquire);
        (first, second)
    }

    /// An unobserved order keeps producing wakes; once `done` flips to
    /// `true` the wake count stops growing.
    #[test]
    fn ladder_persists_until_the_target_is_observed() -> TestResult {
        let confirmer = WakeConfirmer::new(policy())?;
        let wakes = Arc::new(AtomicU32::new(0));
        let observed = Arc::new(AtomicBool::new(false));
        confirmer.confirm(
            counting_wake(Arc::clone(&wakes)),
            gate_reader(Arc::clone(&observed)),
        );

        if !wakes_reach(&wakes, 12) {
            return Err(format!(
                "ladder stopped early at {} wakes without observation",
                wakes.load(Ordering::Acquire)
            )
            .into());
        }

        observed.store(true, Ordering::Release);
        let (settled, later) = sample_twice(&wakes);
        assert_eq!(
            later, settled,
            "no wakes may follow a positive observation"
        );
        confirmer.shutdown();
        Ok(())
    }

    /// An order whose `done` is already `true` must never call `wake`.
    #[test]
    fn already_observed_orders_never_wake() -> TestResult {
        let confirmer = WakeConfirmer::new(policy())?;
        let wakes = Arc::new(AtomicU32::new(0));
        confirmer.confirm(counting_wake(Arc::clone(&wakes)), || true);

        std::thread::sleep(Duration::from_millis(40));
        assert_eq!(wakes.load(Ordering::Acquire), 0);
        confirmer.shutdown();
        Ok(())
    }

    /// Eight orders share one wake counter but have separate gates; after all
    /// gates are set the shared count must stop growing.
    #[test]
    fn concurrent_orders_stop_independently() -> TestResult {
        let confirmer = WakeConfirmer::new(policy())?;
        let wakes = Arc::new(AtomicU32::new(0));
        let gates: Vec<Arc<AtomicBool>> = (0..8)
            .map(|_| {
                let gate = Arc::new(AtomicBool::new(false));
                confirmer.confirm(
                    counting_wake(Arc::clone(&wakes)),
                    gate_reader(Arc::clone(&gate)),
                );
                gate
            })
            .collect();

        if !wakes_reach(&wakes, 16) {
            return Err("orders did not interleave".into());
        }

        gates
            .iter()
            .for_each(|gate| gate.store(true, Ordering::Release));
        let (settled, later) = sample_twice(&wakes);
        assert_eq!(later, settled);
        confirmer.shutdown();
        Ok(())
    }

    /// `shutdown` may be called repeatedly, and orders submitted afterwards
    /// are discarded rather than run.
    #[test]
    fn shutdown_is_idempotent_and_gates_new_orders() -> TestResult {
        let confirmer = WakeConfirmer::new(policy())?;
        confirmer.shutdown();
        confirmer.shutdown();

        let wakes = Arc::new(AtomicU32::new(0));
        confirmer.confirm(counting_wake(Arc::clone(&wakes)), || false);

        std::thread::sleep(Duration::from_millis(20));
        assert_eq!(
            wakes.load(Ordering::Acquire),
            0,
            "orders after shutdown must be dropped"
        );
        Ok(())
    }
}