use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Cooperative, cloneable cancellation flag with an optional human-readable reason.
///
/// Both pieces of state sit behind `Arc`, so every clone of a token shares them:
/// cancelling through any clone is observed by all the others.
#[derive(Clone, Debug)]
pub struct CancellationToken {
    // Lock ordering: whenever both locks are held at once, `cancelled` is taken first.
    cancelled: Arc<Mutex<bool>>,
    cancellation_reason: Arc<Mutex<Option<String>>>,
}

impl CancellationToken {
    /// Creates a token in the not-cancelled state with no reason recorded.
    pub fn new() -> Self {
        Self {
            cancelled: Arc::new(Mutex::new(false)),
            cancellation_reason: Arc::new(Mutex::new(None)),
        }
    }

    /// Marks the token (and every clone of it) as cancelled and stores `reason`.
    ///
    /// Calling this again replaces any previously stored reason.
    pub fn cancel(&self, reason: Option<String>) {
        // Keep the flag locked while the reason is written so that no observer can
        // see `cancelled == true` paired with a reason that has not been stored yet.
        let mut cancelled = Self::lock_unpoisoned(&self.cancelled);
        *Self::lock_unpoisoned(&self.cancellation_reason) = reason;
        *cancelled = true;
    }

    /// Returns `true` once [`cancel`](Self::cancel) has been called on this token or a clone.
    pub fn is_cancelled(&self) -> bool {
        *Self::lock_unpoisoned(&self.cancelled)
    }

    /// Returns a copy of the most recently stored cancellation reason, if any.
    pub fn cancellation_reason(&self) -> Option<String> {
        Self::lock_unpoisoned(&self.cancellation_reason).clone()
    }

    /// Returns `Ok(())` if the token has not been cancelled.
    ///
    /// # Errors
    ///
    /// If the token is cancelled, returns the stored reason, or
    /// `"Task was cancelled"` when no reason was given.
    pub fn check_cancelled(&self) -> Result<(), String> {
        // Read flag and reason under the same flag lock `cancel` holds, so the pair is consistent.
        let cancelled = Self::lock_unpoisoned(&self.cancelled);
        if !*cancelled {
            return Ok(());
        }
        let reason = Self::lock_unpoisoned(&self.cancellation_reason).clone();
        Err(reason.unwrap_or_else(|| "Task was cancelled".to_string()))
    }

    /// Locks `mutex`, recovering the guard if a previous holder panicked.
    ///
    /// Every critical section in this type is a single assignment or clone, so a
    /// poisoned lock can never expose a half-updated value.
    fn lock_unpoisoned<T>(mutex: &Mutex<T>) -> std::sync::MutexGuard<'_, T> {
        mutex.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }
}

impl Default for CancellationToken {
    fn default() -> Self {
        Self::new()
    }
}
/// Tracks a fixed time budget measured from the moment the manager is created.
#[derive(Debug, Clone)]
pub struct TimeoutManager {
    timeout: Duration,
    start_time: Instant,
}

impl TimeoutManager {
    /// Starts the clock now with a budget of `timeout`.
    pub fn new(timeout: Duration) -> Self {
        Self {
            timeout,
            start_time: Instant::now(),
        }
    }

    /// Returns `true` once the elapsed time has reached the budget.
    ///
    /// Uses `>=` so that this agrees with [`remaining`](Self::remaining): the budget is
    /// exhausted exactly when there is no time left (a zero budget is expired immediately).
    pub fn is_timed_out(&self) -> bool {
        self.start_time.elapsed() >= self.timeout
    }

    /// Returns the time left in the budget, or `Duration::ZERO` once it is used up.
    pub fn remaining(&self) -> Duration {
        self.timeout.saturating_sub(self.start_time.elapsed())
    }

    /// Returns the time elapsed since this manager was created.
    pub fn elapsed(&self) -> Duration {
        self.start_time.elapsed()
    }

    /// Returns `Ok(())` while the budget has not been exhausted.
    ///
    /// # Errors
    ///
    /// Returns `"Task timeout exceeded: <budget>s"` once [`is_timed_out`](Self::is_timed_out)
    /// is true.
    pub fn check_timeout(&self) -> Result<(), String> {
        if !self.is_timed_out() {
            return Ok(());
        }
        // `as_secs()` alone truncates sub-second budgets to "0s". Keep the integer form
        // for whole seconds (unchanged message) and use fractional seconds otherwise.
        let budget = if self.timeout.subsec_nanos() == 0 {
            self.timeout.as_secs().to_string()
        } else {
            self.timeout.as_secs_f64().to_string()
        };
        Err(format!("Task timeout exceeded: {}s", budget))
    }
}
/// Pairs a `CancellationToken` with an optional deadline so long-running work can
/// ask one question between steps: may it keep going?
pub struct ExecutionGuard {
    cancellation_token: CancellationToken,
    timeout_manager: Option<TimeoutManager>,
}

impl ExecutionGuard {
    /// Builds a guard around `cancellation_token`.
    ///
    /// When `timeout` is `Some`, the deadline clock starts at construction time.
    pub fn new(cancellation_token: CancellationToken, timeout: Option<Duration>) -> Self {
        let timeout_manager = timeout.map(TimeoutManager::new);
        Self {
            cancellation_token,
            timeout_manager,
        }
    }

    /// Returns `Ok(())` while work may proceed.
    ///
    /// # Errors
    ///
    /// Cancellation is checked first, so its message takes precedence. Otherwise the
    /// deadline's error is returned once it has expired. Without a deadline, only
    /// cancellation can stop the work.
    pub fn should_continue(&self) -> Result<(), String> {
        self.cancellation_token.check_cancelled()?;
        self.timeout_manager
            .as_ref()
            .map_or(Ok(()), TimeoutManager::check_timeout)
    }

    /// Borrows the token this guard observes.
    pub fn cancellation_token(&self) -> &CancellationToken {
        &self.cancellation_token
    }

    /// Returns the time left before the deadline, or `None` when no deadline was set.
    pub fn remaining_timeout(&self) -> Option<Duration> {
        self.timeout_manager
            .as_ref()
            .map(TimeoutManager::remaining)
    }
}