use std::sync::{Arc, Condvar, Mutex};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::thread;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
/// Error returned when a task observes that its cancellation token has fired.
///
/// `reason` carries the optional explanation supplied by whoever cancelled the
/// token (for example via `cancel_with`); it is `None` for a bare cancellation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Preempted {
    /// Why the task was preempted, if a reason was given.
    pub reason: Option<String>,
}

impl std::fmt::Display for Preempted {
    /// Renders `Task preempted`, followed by `: <reason>` when a reason is present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Task preempted")?;
        if let Some(why) = self.reason.as_deref() {
            write!(f, ": {}", why)?;
        }
        Ok(())
    }
}

/// `Preempted` is a leaf error: it has no underlying `source`.
impl std::error::Error for Preempted {}
/// Shared state behind every clone of a `PreemptiveCancellationToken`.
struct PreemptionState {
    /// Set when the token is cancelled; cleared by a reset.
    cancelled: AtomicBool,
    /// Optional human-readable reason attached to the current cancellation.
    reason: Mutex<Option<String>>,
    /// Armed deadline as milliseconds since the Unix epoch; 0 means "no deadline".
    deadline_unix_ms: AtomicU64,
    /// Notified on cancellation so sleeping watchdog threads can wake early.
    condvar: Condvar,
    /// Mutex paired with `condvar`; set to `true` by `cancel_with_reason`.
    wakeup: Mutex<bool>,
}

impl PreemptionState {
    /// Allocates a fresh, un-cancelled state with no reason and no deadline.
    fn new() -> Arc<Self> {
        Arc::new(Self {
            cancelled: AtomicBool::new(false),
            reason: Mutex::new(None),
            deadline_unix_ms: AtomicU64::new(0),
            condvar: Condvar::new(),
            wakeup: Mutex::new(false),
        })
    }

    /// Records `reason`, marks the state cancelled, and wakes every waiting watchdog.
    ///
    /// The reason is written before the flag is published with `Release`. A reader that
    /// observes `cancelled == true` with `Acquire` and then locks `reason` therefore sees
    /// this call's reason, unless another thread has changed it since. A later call
    /// overwrites the reason of an earlier one.
    ///
    /// Poisoned locks are recovered rather than skipped. Both guarded values are plain data
    /// that a panicking holder cannot leave half-updated. Skipping the poisoned `wakeup`
    /// lock would also skip `notify_all`, leaving watchdogs asleep until their deadline.
    fn cancel_with_reason(&self, reason: Option<String>) {
        *self.reason.lock().unwrap_or_else(|e| e.into_inner()) = reason;
        self.cancelled.store(true, Ordering::Release);

        let mut woken = self.wakeup.lock().unwrap_or_else(|e| e.into_inner());
        *woken = true;
        self.condvar.notify_all();
    }
}
/// A clonable cancellation flag that can be tripped manually, by a deadline, or
/// (on Linux) by a signal.
///
/// All clones share one `Arc`-backed state, so cancelling or resetting through any
/// clone is observed by every other clone.
#[derive(Clone)]
pub struct PreemptiveCancellationToken {
    state: Arc<PreemptionState>,
}

impl Default for PreemptiveCancellationToken {
    /// Same as `PreemptiveCancellationToken::new()`: a fresh, un-cancelled token.
    fn default() -> Self {
        PreemptiveCancellationToken::new()
    }
}

impl std::fmt::Debug for PreemptiveCancellationToken {
    /// Shows only the current cancellation flag; reason and deadline are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cancelled = self.state.cancelled.load(Ordering::Relaxed);
        f.debug_struct("PreemptiveCancellationToken")
            .field("cancelled", &cancelled)
            .finish()
    }
}
impl PreemptiveCancellationToken {
pub fn new() -> Self {
Self { state: PreemptionState::new() }
}
pub fn cancel(&self) {
self.state.cancel_with_reason(None);
}
pub fn cancel_with(&self, reason: impl Into<String>) {
self.state.cancel_with_reason(Some(reason.into()));
}
pub fn cancel_after(&self, after: Duration) {
let deadline = Instant::now() + after;
self.cancel_at_instant(deadline, None);
}
pub fn cancel_after_with(&self, after: Duration, reason: impl Into<String>) {
let deadline = Instant::now() + after;
self.cancel_at_instant(deadline, Some(reason.into()));
}
pub fn cancel_at(&self, deadline: Instant) {
self.cancel_at_instant(deadline, None);
}
pub fn deadline_guard(&self, budget: Duration) -> DeadlineGuard {
self.cancel_after(budget);
DeadlineGuard {
token: self.clone(),
start: Instant::now(),
budget,
}
}
#[inline(always)]
pub fn is_cancelled(&self) -> bool {
self.state.cancelled.load(Ordering::Acquire)
}
#[inline(always)]
pub fn check(&self) -> Result<(), Preempted> {
if self.is_cancelled() {
let reason = self.state.reason.lock()
.ok()
.and_then(|r| r.clone());
Err(Preempted { reason })
} else {
Ok(())
}
}
#[inline]
pub fn check_and_yield(&self) -> Result<(), Preempted> {
let result = self.check();
if result.is_ok() {
thread::yield_now();
}
result
}
pub fn reset(&self) {
self.state.deadline_unix_ms.store(0, Ordering::Release);
if let Ok(mut r) = self.state.reason.lock() {
*r = None;
}
self.state.cancelled.store(false, Ordering::Release);
}
pub fn reason(&self) -> Option<String> {
self.state.reason.lock().ok().and_then(|r| r.clone())
}
#[cfg(target_os = "linux")]
pub unsafe fn install_signal_handler() {
signal_preemption::install();
}
#[cfg(target_os = "linux")]
pub fn signal_preempt_thread(pthread_id: libc::pthread_t) {
signal_preemption::preempt(pthread_id);
}
#[cfg(target_os = "linux")]
#[inline]
pub fn check_signal() -> Result<(), Preempted> {
signal_preemption::check_flag()
}
fn cancel_at_instant(&self, deadline: Instant, reason: Option<String>) {
let unix_ms = unix_ms_from_instant(deadline);
self.state.deadline_unix_ms.store(unix_ms, Ordering::Release);
let state = self.state.clone();
thread::Builder::new()
.name("taskflow-watchdog".into())
.spawn(move || {
let now = Instant::now();
if deadline <= now {
state.cancel_with_reason(reason);
return;
}
let sleep_for = deadline - now;
let lock = state.wakeup.lock().unwrap();
let (guard, timed_out) = state.condvar
.wait_timeout(lock, sleep_for)
.unwrap();
let should_cancel =
timed_out.timed_out() && !state.cancelled.load(Ordering::Acquire);
drop(guard); if should_cancel {
state.cancel_with_reason(reason);
}
})
.expect("Failed to spawn watchdog thread");
}
}
/// Tracks a time budget started by `PreemptiveCancellationToken::deadline_guard`.
///
/// The guard lets the holder query the remaining budget. On drop, if the budget has run
/// out, it ensures the token is cancelled with a reason.
#[must_use = "the expiry check runs when the guard is dropped; bind it (e.g. `let _guard = ...`) so it lives for the guarded scope"]
pub struct DeadlineGuard {
    token: PreemptiveCancellationToken,
    start: Instant,
    budget: Duration,
}

impl DeadlineGuard {
    /// Time left in the budget, or `Duration::ZERO` once it is exhausted.
    pub fn remaining(&self) -> Duration {
        self.budget.saturating_sub(self.start.elapsed())
    }

    /// Returns `true` once at least `budget` has elapsed since the guard was created.
    pub fn is_expired(&self) -> bool {
        self.start.elapsed() >= self.budget
    }

    /// The token this guard was created from; it shares state with the original.
    pub fn token(&self) -> &PreemptiveCancellationToken {
        &self.token
    }
}

impl Drop for DeadlineGuard {
    /// If the budget has run out, cancels the token with reason `"deadline elapsed"`.
    ///
    /// A reason is only filled in when the token has none yet. A reason supplied by an
    /// explicit `cancel_with` is never overwritten; a bare cancellation from the watchdog
    /// still gets this reason.
    fn drop(&mut self) {
        if self.is_expired() && self.token.reason().is_none() {
            self.token.cancel_with("deadline elapsed");
        }
    }
}
/// Runs `f` with a fresh token that is cancelled automatically once `budget` elapses.
///
/// `f` is expected to poll the token, e.g. with `check()?`. The watchdog only sets the
/// flag; it cannot interrupt `f`.
///
/// When `f` returns or unwinds, the token is cancelled if it is not already. This wakes the
/// watchdog thread so it exits immediately instead of lingering until the deadline. It also
/// tells any clones `f` handed to other threads that the scope has ended.
///
/// # Errors
/// Returns whatever `f` returns — typically `Err(Preempted)` when the budget ran out.
///
/// # Panics
/// Panics if the watchdog thread cannot be spawned.
pub fn with_deadline<T, F>(budget: Duration, f: F) -> Result<T, Preempted>
where
    F: FnOnce(&PreemptiveCancellationToken) -> Result<T, Preempted>,
{
    /// Cancels the wrapped token on drop, so the watchdog is released even if `f` panics.
    struct CancelOnExit<'a>(&'a PreemptiveCancellationToken);

    impl Drop for CancelOnExit<'_> {
        fn drop(&mut self) {
            if !self.0.is_cancelled() {
                self.0.cancel();
            }
        }
    }

    let token = PreemptiveCancellationToken::new();
    token.cancel_after(budget);
    // Declared after `token`, so it is dropped first while the borrow is still valid.
    let _release = CancelOnExit(&token);
    f(&token)
}
/// Converts a monotonic `Instant` into an approximate wall-clock timestamp, in
/// milliseconds since the Unix epoch.
///
/// The offset from "now" is measured on the monotonic clock and added to the current
/// wall-clock time. A deadline that has already passed maps to the current time. Values
/// too large for `u64` saturate to `u64::MAX` instead of wrapping. If the system clock is
/// before the epoch, "now" is taken as 0.
fn unix_ms_from_instant(deadline: Instant) -> u64 {
    /// Whole milliseconds in `d`, saturating at `u64::MAX` (a bare `as u64` would wrap).
    fn clamp_ms(d: Duration) -> u64 {
        d.as_millis().min(u128::from(u64::MAX)) as u64
    }

    let now_instant = Instant::now();
    let now_unix_ms = clamp_ms(
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default(),
    );
    // Zero when the deadline is at or before `now_instant`.
    let offset_ms = clamp_ms(deadline.saturating_duration_since(now_instant));
    now_unix_ms.saturating_add(offset_ms)
}
/// Linux-only signal-based preemption.
///
/// Another thread sends `SIGUSR2` to a target thread. The handler sets a flag local to the
/// target thread, and the target polls that flag with `check_flag`.
#[cfg(target_os = "linux")]
mod signal_preemption {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Once;

    thread_local! {
        /// Per-thread flag, set by the `SIGUSR2` handler and consumed by `check_flag`.
        ///
        /// It is an atomic rather than a `Cell` because the handler writes it asynchronously
        /// with respect to the owning thread's code. A plain `Cell` read may be optimized on
        /// the assumption that nothing else modifies it. `const` initialization means the
        /// handler never runs lazy thread-local initialization.
        static SIGNAL_PREEMPTED: AtomicBool = const { AtomicBool::new(false) };
    }

    /// `true` once `sigaction` has successfully installed the handler.
    static HANDLER_INSTALLED: AtomicBool = AtomicBool::new(false);
    /// Ensures installation is attempted exactly once per process.
    static INSTALL: Once = Once::new();

    /// Installs the process-wide `SIGUSR2` handler. Only the first call does anything.
    ///
    /// Concurrent callers block until the first attempt finishes, so the handler is in
    /// place by the time any call returns. If `sigaction` fails, the handler stays
    /// uninstalled and `preempt` becomes a no-op.
    ///
    /// # Safety
    /// Replaces the process's existing `SIGUSR2` disposition. The caller must ensure that
    /// no other code in the process depends on `SIGUSR2`.
    pub(super) unsafe fn install() {
        extern "C" fn handler(_sig: libc::c_int) {
            // Keep the handler minimal: a single atomic store to a const-initialized thread-local.
            SIGNAL_PREEMPTED.with(|flag| flag.store(true, Ordering::Relaxed));
        }

        INSTALL.call_once(|| {
            // SAFETY: `libc::sigaction` is a plain C struct for which all-zero bytes is a valid
            // value. `sa_mask` is then emptied explicitly. `handler` has the
            // `extern "C" fn(c_int)` shape expected without `SA_SIGINFO`. The old-action
            // out-pointer may be null.
            let rc = unsafe {
                let mut sa: libc::sigaction = std::mem::zeroed();
                sa.sa_sigaction = handler as extern "C" fn(libc::c_int) as libc::sighandler_t;
                // Restartable syscalls resume after the handler, so a thread blocked in one only
                // sees the flag once the call returns.
                sa.sa_flags = libc::SA_RESTART;
                libc::sigemptyset(&mut sa.sa_mask);
                libc::sigaction(libc::SIGUSR2, &sa, std::ptr::null_mut())
            };
            HANDLER_INSTALLED.store(rc == 0, Ordering::Release);
        });
    }

    /// Sends `SIGUSR2` to `thread`, so its next `check_flag` reports preemption.
    ///
    /// Does nothing if the handler is not installed. The default `SIGUSR2` action
    /// terminates the whole process, which is never what a preemption request means.
    pub(super) fn preempt(thread: libc::pthread_t) {
        if !HANDLER_INSTALLED.load(Ordering::Acquire) {
            return;
        }
        // SAFETY: `pthread_kill` takes plain values and reports errors through its return code.
        // NOTE(review): the caller must pass the id of a thread whose lifetime has not ended;
        // POSIX leaves use of a stale id undefined and this wrapper cannot check it.
        unsafe {
            libc::pthread_kill(thread, libc::SIGUSR2);
        }
    }

    /// Returns `Err` once, and clears the flag, if this thread received a preemption signal
    /// since the last call.
    #[inline]
    pub(super) fn check_flag() -> Result<(), super::Preempted> {
        // `swap` reads and clears in one step, so a signal landing in between is not lost.
        if SIGNAL_PREEMPTED.with(|flag| flag.swap(false, Ordering::Relaxed)) {
            Err(super::Preempted {
                reason: Some("signal preemption".into()),
            })
        } else {
            Ok(())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering as AO};

    /// Polls `cond` every few milliseconds until it returns `true` or `timeout` elapses.
    ///
    /// This replaces a single fixed `sleep`, so timing-dependent tests do not flake when the
    /// watchdog thread is scheduled late on a loaded machine.
    fn eventually(timeout: Duration, mut cond: impl FnMut() -> bool) -> bool {
        let give_up_at = Instant::now() + timeout;
        loop {
            if cond() {
                return true;
            }
            if Instant::now() >= give_up_at {
                return false;
            }
            thread::sleep(Duration::from_millis(2));
        }
    }

    #[test]
    fn manual_cancel() {
        let token = PreemptiveCancellationToken::new();
        assert!(!token.is_cancelled());
        assert!(token.check().is_ok());
        token.cancel();
        assert!(token.is_cancelled());
        assert!(token.check().is_err());
    }

    #[test]
    fn cancel_with_reason() {
        let token = PreemptiveCancellationToken::new();
        token.cancel_with("budget exceeded");
        let err = token.check().unwrap_err();
        assert_eq!(err.reason.as_deref(), Some("budget exceeded"));
    }

    #[test]
    fn watchdog_fires_after_timeout() {
        let token = PreemptiveCancellationToken::new();
        token.cancel_after(Duration::from_millis(50));
        assert!(!token.is_cancelled());
        assert!(
            eventually(Duration::from_secs(5), || token.is_cancelled()),
            "watchdog never cancelled the token"
        );
    }

    #[test]
    fn manual_cancel_beats_watchdog() {
        let token = PreemptiveCancellationToken::new();
        token.cancel_after(Duration::from_millis(500));
        token.cancel_with("manual");
        assert!(token.is_cancelled());
        assert_eq!(token.check().unwrap_err().reason.as_deref(), Some("manual"));
    }

    #[test]
    fn reset_clears_state() {
        let token = PreemptiveCancellationToken::new();
        token.cancel_with("test");
        assert!(token.is_cancelled());
        token.reset();
        assert!(!token.is_cancelled());
        assert!(token.reason().is_none());
        assert!(token.check().is_ok());
    }

    #[test]
    fn deadline_guard_cancels_on_expiry() {
        let token = PreemptiveCancellationToken::new();
        {
            let _guard = token.deadline_guard(Duration::from_millis(10));
            thread::sleep(Duration::from_millis(30));
        }
        // The guard's `Drop` cancels deterministically once the budget is exhausted.
        assert!(token.is_cancelled());
    }

    #[test]
    fn with_deadline_respects_budget() {
        let counter = Arc::new(AtomicUsize::new(0));
        let c = counter.clone();
        let result = with_deadline(Duration::from_millis(80), |tok| {
            loop {
                tok.check()?;
                c.fetch_add(1, AO::Relaxed);
                thread::sleep(Duration::from_millis(20));
            }
        });
        assert!(result.is_err());
        let iterations = counter.load(AO::Relaxed);
        assert!(
            (1..=6).contains(&iterations),
            "expected 1-6 iterations, got {}",
            iterations
        );
    }

    #[test]
    fn clone_shares_state() {
        let token = PreemptiveCancellationToken::new();
        let clone = token.clone();
        clone.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn with_deadline_ok_path() {
        let result = with_deadline(Duration::from_secs(5), |tok| {
            tok.check()?;
            Ok(42u32)
        });
        assert_eq!(result.unwrap(), 42);
    }

    #[test]
    fn check_and_yield_reports_cancellation() {
        let token = PreemptiveCancellationToken::default();
        assert!(token.check_and_yield().is_ok());
        token.cancel_with("stop");
        let err = token.check_and_yield().unwrap_err();
        assert_eq!(err.reason.as_deref(), Some("stop"));
    }

    #[test]
    fn preempted_display() {
        assert_eq!(Preempted { reason: None }.to_string(), "Task preempted");
        assert_eq!(
            Preempted { reason: Some("x".into()) }.to_string(),
            "Task preempted: x"
        );
    }
}