use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use pin_project_lite::pin_project;
/// A fixed point on the monotonic clock after which work is considered overdue.
///
/// `Deadline` is a thin, `Copy` wrapper around an [`Instant`]; all queries
/// compare against `Instant::now()` at the time of the call.
#[derive(Debug, Clone, Copy)]
pub struct Deadline {
    expires_at: Instant,
}

impl Deadline {
    /// Creates a deadline `duration` from now.
    ///
    /// If `now + duration` cannot be represented by [`Instant`] (for example
    /// `Duration::MAX`), the deadline saturates to a far-future instant
    /// instead of panicking.
    pub fn from_duration(duration: Duration) -> Self {
        Self {
            expires_at: Self::saturating_after(Instant::now(), duration),
        }
    }

    /// Creates a deadline that expires exactly at `instant`.
    pub fn at(instant: Instant) -> Self {
        Self {
            expires_at: instant,
        }
    }

    /// Returns `true` once the current time has reached or passed the deadline.
    pub fn is_expired(&self) -> bool {
        Instant::now() >= self.expires_at
    }

    /// Returns the time left until expiry, or `None` if the deadline has
    /// already been reached.
    pub fn remaining(&self) -> Option<Duration> {
        // `checked_duration_since` yields `Some(0)` when the two instants are
        // equal; an exactly-reached deadline counts as expired, so drop it.
        self.expires_at
            .checked_duration_since(Instant::now())
            .filter(|left| !left.is_zero())
    }

    /// Returns the instant at which this deadline expires.
    pub fn expires_at(&self) -> Instant {
        self.expires_at
    }

    /// Returns the earlier of this deadline and `now + timeout`.
    ///
    /// A `timeout` too large to be added to the current instant can never be
    /// the tighter bound, so in that case this deadline is returned unchanged.
    pub fn with_timeout(&self, timeout: Duration) -> Self {
        match Instant::now().checked_add(timeout) {
            Some(candidate) => Self::at(self.expires_at.min(candidate)),
            None => *self,
        }
    }

    /// Adds `duration` to `start`, shrinking the duration until the sum is
    /// representable rather than panicking on overflow.
    fn saturating_after(start: Instant, duration: Duration) -> Instant {
        let mut step = duration;
        loop {
            if let Some(instant) = start.checked_add(step) {
                return instant;
            }
            // Halving always terminates: a zero step is always representable.
            step /= 2;
        }
    }
}
/// Per-operation state: an optional name, an optional deadline, a shared
/// cancellation flag and an optional parent context.
///
/// Cloning a context shares the cancellation flag with the original, so
/// cancelling either one is observed by both.
#[derive(Debug, Clone)]
pub struct OperationContext {
    /// Optional human-readable label for the operation.
    pub name: Option<String>,
    /// Instant by which the operation should finish, if any.
    ///
    /// The builder methods keep this no later than the parent's deadline;
    /// assigning the field directly bypasses that clamping.
    pub deadline: Option<Deadline>,
    // Flag shared by every clone of this context and by tokens handed out
    // through `cancellation_token`.
    cancelled: Arc<AtomicBool>,
    parent: Option<Arc<OperationContext>>,
    created_at: Instant,
}

impl OperationContext {
    /// Creates a root context with no name, no deadline and no parent.
    pub fn new() -> Self {
        Self {
            name: None,
            deadline: None,
            cancelled: Arc::new(AtomicBool::new(false)),
            parent: None,
            created_at: Instant::now(),
        }
    }

    /// Sets the operation name.
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Sets the deadline, clamped so it never extends past the deadline of a
    /// parent that has already been attached.
    ///
    /// Without the clamp the result depended on builder order:
    /// `with_parent(p).with_deadline(d)` silently dropped the parent's
    /// earlier deadline while `with_deadline(d).with_parent(p)` kept it.
    pub fn with_deadline(mut self, deadline: Deadline) -> Self {
        self.deadline = Some(self.capped_by_parent(deadline));
        self
    }

    /// Sets the deadline to `timeout` from now, subject to the same parent
    /// clamping as [`with_deadline`](Self::with_deadline).
    pub fn with_timeout(self, timeout: Duration) -> Self {
        self.with_deadline(Deadline::from_duration(timeout))
    }

    /// Attaches `parent`, inheriting its deadline when this context has none
    /// or when the parent's deadline is earlier.
    pub fn with_parent(mut self, parent: Arc<OperationContext>) -> Self {
        self.deadline = match (self.deadline, parent.deadline) {
            (Some(own), Some(inherited)) if inherited.expires_at() < own.expires_at() => {
                Some(inherited)
            }
            (own, inherited) => own.or(inherited),
        };
        self.parent = Some(parent);
        self
    }

    /// Creates a child context whose parent is a clone of `self`.
    ///
    /// The clone shares this context's cancellation flag, so cancelling
    /// `self` afterwards is visible through the child's `is_cancelled`.
    pub fn child(&self) -> OperationContext {
        OperationContext::new().with_parent(Arc::new(self.clone()))
    }

    /// Creates a child context limited to `timeout` from now, or to this
    /// context's deadline if that is earlier.
    pub fn child_with_timeout(&self, timeout: Duration) -> OperationContext {
        // `with_parent` keeps the earlier of the two deadlines.
        OperationContext::new()
            .with_timeout(timeout)
            .with_parent(Arc::new(self.clone()))
    }

    /// Returns `true` if this context or any ancestor has been cancelled.
    pub fn is_cancelled(&self) -> bool {
        // Walk the parent chain iteratively rather than recursing.
        let mut current = Some(self);
        while let Some(ctx) = current {
            if ctx.cancelled.load(Ordering::Relaxed) {
                return true;
            }
            current = ctx.parent.as_deref();
        }
        false
    }

    /// Marks this context (and every clone sharing its flag) as cancelled.
    pub fn cancel(&self) {
        self.cancelled.store(true, Ordering::Relaxed);
    }

    /// Returns `true` if a deadline is set and has been reached.
    pub fn is_expired(&self) -> bool {
        matches!(self.deadline, Some(deadline) if deadline.is_expired())
    }

    /// Returns the time left before the deadline, or `None` when there is no
    /// deadline or it has already been reached.
    pub fn remaining_time(&self) -> Option<Duration> {
        self.deadline.and_then(|deadline| deadline.remaining())
    }

    /// Returns `true` while the operation is neither cancelled nor expired.
    pub fn should_continue(&self) -> bool {
        !(self.is_cancelled() || self.is_expired())
    }

    /// Returns the time elapsed since this context was created.
    pub fn elapsed(&self) -> Duration {
        self.created_at.elapsed()
    }

    /// Returns a token sharing this context's own cancellation flag.
    ///
    /// Note: the token only reflects this context's flag; it does not observe
    /// cancellation of ancestor contexts the way `is_cancelled` does.
    pub fn cancellation_token(&self) -> CancellationToken {
        CancellationToken {
            flag: Arc::clone(&self.cancelled),
        }
    }

    /// Returns `deadline`, or the attached parent's deadline if that is earlier.
    fn capped_by_parent(&self, deadline: Deadline) -> Deadline {
        match self.parent.as_deref().and_then(|parent| parent.deadline) {
            Some(limit) if limit.expires_at() < deadline.expires_at() => limit,
            _ => deadline,
        }
    }
}

impl Default for OperationContext {
    /// Equivalent to [`OperationContext::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Cloneable handle around a shared boolean cancellation flag.
///
/// Every clone refers to the same flag, so cancelling through one handle is
/// visible through all of them.
#[derive(Debug, Clone)]
pub struct CancellationToken {
    flag: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Creates a token whose flag starts out not cancelled.
    pub fn new() -> Self {
        let flag = Arc::new(AtomicBool::new(false));
        Self { flag }
    }

    /// Reports whether `cancel` has been called on this token or a clone of it.
    pub fn is_cancelled(&self) -> bool {
        let shared: &AtomicBool = &self.flag;
        shared.load(Ordering::Relaxed)
    }

    /// Sets the shared flag; calling it again leaves the flag set.
    pub fn cancel(&self) {
        let shared: &AtomicBool = &self.flag;
        shared.store(true, Ordering::Relaxed);
    }
}

impl Default for CancellationToken {
    /// Same as [`CancellationToken::new`].
    fn default() -> Self {
        CancellationToken::new()
    }
}
/// Error describing an operation that exceeded its time limit.
#[derive(Debug, Clone)]
pub struct TimeoutError {
    /// Name of the operation, when one was supplied.
    pub operation: Option<String>,
    /// The time limit that applied to the operation.
    pub timeout: Duration,
    /// How long the operation had been running when the error was produced.
    pub elapsed: Duration,
}

impl std::fmt::Display for TimeoutError {
    /// Writes a one-line message, including the operation name if present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (elapsed, limit) = (self.elapsed, self.timeout);

        // Named operations get the quoted name in the message.
        if let Some(name) = self.operation.as_deref() {
            return write!(
                f,
                "Operation '{}' timed out after {:?} (limit: {:?})",
                name, elapsed, limit
            );
        }

        write!(
            f,
            "Operation timed out after {:?} (limit: {:?})",
            elapsed, limit
        )
    }
}

impl std::error::Error for TimeoutError {}
pin_project! {
pub struct Timeout<F> {
#[pin]
inner: F,
deadline: Deadline,
started_at: Instant,
operation_name: Option<String>,
}
}
impl<F: Future> Future for Timeout<F> {
type Output = Result<F::Output, TimeoutError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
if this.deadline.is_expired() {
return Poll::Ready(Err(TimeoutError {
operation: this.operation_name.clone(),
timeout: this
.deadline
.expires_at()
.saturating_duration_since(*this.started_at),
elapsed: this.started_at.elapsed(),
}));
}
match this.inner.poll(cx) {
Poll::Ready(value) => Poll::Ready(Ok(value)),
Poll::Pending => {
Poll::Pending
}
}
}
}
pub fn timeout<F: Future>(duration: Duration, future: F) -> Timeout<F> {
Timeout {
inner: future,
deadline: Deadline::from_duration(duration),
started_at: Instant::now(),
operation_name: None,
}
}
pub fn timeout_named<F: Future>(
name: impl Into<String>,
duration: Duration,
future: F,
) -> Timeout<F> {
Timeout {
inner: future,
deadline: Deadline::from_duration(duration),
started_at: Instant::now(),
operation_name: Some(name.into()),
}
}
/// Awaits `future` under `tokio::time::timeout`, translating the runtime's
/// timeout signal into this module's [`TimeoutError`].
///
/// On timeout the error carries no operation name, `duration` as the limit,
/// and the time measured from the start of this call.
pub async fn with_timeout<F, T>(duration: Duration, future: F) -> Result<T, TimeoutError>
where
    F: Future<Output = T>,
{
    let started_at = Instant::now();
    tokio::time::timeout(duration, future)
        .await
        .map_err(|_elapsed| TimeoutError {
            operation: None,
            timeout: duration,
            elapsed: started_at.elapsed(),
        })
}
/// Awaits `future` under `tokio::time::timeout`, attaching `name` to the
/// [`TimeoutError`] produced if the limit is hit.
///
/// The name is converted to a `String` before the future is awaited; the
/// error's `elapsed` is measured from just after that conversion.
pub async fn with_timeout_named<F, T>(
    name: impl Into<String>,
    duration: Duration,
    future: F,
) -> Result<T, TimeoutError>
where
    F: Future<Output = T>,
{
    let operation: Option<String> = Some(name.into());
    let started_at = Instant::now();
    tokio::time::timeout(duration, future)
        .await
        .map_err(move |_elapsed| TimeoutError {
            operation,
            timeout: duration,
            elapsed: started_at.elapsed(),
        })
}
/// Atomic counters tracking how operations ended.
#[derive(Debug, Default)]
pub struct TimeoutStats {
    /// Number of outcomes recorded through any `record_*` method.
    pub total_operations: AtomicU64,
    /// Number of outcomes recorded through `record_completed`.
    pub completed: AtomicU64,
    /// Number of outcomes recorded through `record_timeout`.
    pub timeouts: AtomicU64,
    /// Number of outcomes recorded through `record_cancellation`.
    pub cancellations: AtomicU64,
}

impl TimeoutStats {
    /// Creates a set of counters, all starting at zero.
    pub fn new() -> Self {
        Self {
            total_operations: AtomicU64::new(0),
            completed: AtomicU64::new(0),
            timeouts: AtomicU64::new(0),
            cancellations: AtomicU64::new(0),
        }
    }

    /// Counts one operation that finished normally.
    pub fn record_completed(&self) {
        self.tally(&self.completed);
    }

    /// Counts one operation that timed out.
    pub fn record_timeout(&self) {
        self.tally(&self.timeouts);
    }

    /// Counts one operation that was cancelled.
    pub fn record_cancellation(&self) {
        self.tally(&self.cancellations);
    }

    /// Returns the fraction of recorded operations that timed out, or `0.0`
    /// when nothing has been recorded.
    pub fn timeout_rate(&self) -> f64 {
        match self.total_operations.load(Ordering::Relaxed) {
            0 => 0.0,
            total => self.timeouts.load(Ordering::Relaxed) as f64 / total as f64,
        }
    }

    /// Copies the current counter values into a plain-data snapshot.
    ///
    /// Each counter is loaded separately, one after another.
    pub fn snapshot(&self) -> TimeoutStatsSnapshot {
        let total_operations = self.total_operations.load(Ordering::Relaxed);
        let completed = self.completed.load(Ordering::Relaxed);
        let timeouts = self.timeouts.load(Ordering::Relaxed);
        let cancellations = self.cancellations.load(Ordering::Relaxed);
        TimeoutStatsSnapshot {
            total_operations,
            completed,
            timeouts,
            cancellations,
        }
    }

    /// Bumps the overall total first, then the counter for the given outcome.
    fn tally(&self, outcome: &AtomicU64) {
        self.total_operations.fetch_add(1, Ordering::Relaxed);
        outcome.fetch_add(1, Ordering::Relaxed);
    }
}

/// Plain-value copy of the counters in [`TimeoutStats`].
#[derive(Debug, Clone)]
pub struct TimeoutStatsSnapshot {
    /// Value of the total counter when the snapshot was taken.
    pub total_operations: u64,
    /// Value of the completed counter when the snapshot was taken.
    pub completed: u64,
    /// Value of the timeout counter when the snapshot was taken.
    pub timeouts: u64,
    /// Value of the cancellation counter when the snapshot was taken.
    pub cancellations: u64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A long deadline is still live; a one-nanosecond deadline has lapsed
    /// after a short sleep.
    #[test]
    fn test_deadline() {
        let live = Deadline::from_duration(Duration::from_secs(10));
        assert!(!live.is_expired());
        assert!(live.remaining().is_some());

        let lapsed = Deadline::from_duration(Duration::from_nanos(1));
        std::thread::sleep(Duration::from_millis(1));
        assert!(lapsed.is_expired());
        assert!(lapsed.remaining().is_none());
    }

    /// Tightening a 60 s deadline with a 5 s timeout moves expiry earlier.
    #[test]
    fn test_deadline_with_timeout() {
        let original = Deadline::from_duration(Duration::from_secs(60));
        let tightened = original.with_timeout(Duration::from_secs(5));
        assert!(tightened.expires_at() < original.expires_at());
    }

    /// A fresh context with a generous timeout is neither cancelled nor expired.
    #[test]
    fn test_operation_context() {
        let timeout = Duration::from_secs(30);
        let ctx = OperationContext::new()
            .with_name("test_op")
            .with_timeout(timeout);

        assert!(!ctx.is_cancelled());
        assert!(!ctx.is_expired());
        assert!(ctx.should_continue());
        assert!(ctx.remaining_time().is_some());
    }

    /// Cancelling flips `is_cancelled` and stops `should_continue`.
    #[test]
    fn test_operation_context_cancellation() {
        let ctx = OperationContext::default();
        assert!(!ctx.is_cancelled());

        ctx.cancel();

        assert!(ctx.is_cancelled());
        assert!(!ctx.should_continue());
    }

    /// A child inherits the parent's deadline and observes its cancellation.
    #[test]
    fn test_operation_context_parent() {
        let parent = OperationContext::new().with_timeout(Duration::from_secs(30));
        let child = parent.child();
        assert!(child.deadline.is_some());

        parent.cancel();
        assert!(child.is_cancelled());
    }

    /// Cancellation is visible through the token and through clones of it.
    #[test]
    fn test_cancellation_token() {
        let token = CancellationToken::new();
        assert!(!token.is_cancelled());

        token.cancel();
        assert!(token.is_cancelled());

        let cloned = token.clone();
        assert!(cloned.is_cancelled());
    }

    /// The rendered message names the operation and says it timed out.
    #[test]
    fn test_timeout_error_display() {
        let limit = Duration::from_secs(5);
        let error = TimeoutError {
            operation: Some("send_message".to_string()),
            timeout: limit,
            elapsed: limit,
        };

        let rendered = error.to_string();
        assert!(rendered.contains("send_message"));
        assert!(rendered.contains("timed out"));
    }

    /// Each `record_*` call bumps the total and its own counter.
    #[test]
    fn test_timeout_stats() {
        let stats = TimeoutStats::new();
        stats.record_completed();
        stats.record_completed();
        stats.record_timeout();
        stats.record_cancellation();

        let snapshot = stats.snapshot();
        assert_eq!(snapshot.total_operations, 4);
        assert_eq!(snapshot.completed, 2);
        assert_eq!(snapshot.timeouts, 1);
        assert_eq!(snapshot.cancellations, 1);

        let rate_error = (stats.timeout_rate() - 0.25).abs();
        assert!(rate_error < 0.01);
    }

    /// A future that finishes well inside the limit yields its value.
    #[tokio::test]
    async fn test_with_timeout_success() {
        let work = async {
            tokio::time::sleep(Duration::from_millis(10)).await;
            42
        };

        let outcome = with_timeout(Duration::from_secs(5), work).await;
        assert_eq!(outcome.ok(), Some(42));
    }

    /// A future that outlasts the limit yields an error whose elapsed time
    /// is at least the limit.
    #[tokio::test]
    async fn test_with_timeout_failure() {
        let limit = Duration::from_millis(10);
        let work = async {
            tokio::time::sleep(Duration::from_secs(60)).await;
            42
        };

        let error = with_timeout(limit, work).await.unwrap_err();
        assert!(error.elapsed >= limit);
    }

    /// The operation name passed in is carried by the timeout error.
    #[tokio::test]
    async fn test_with_timeout_named() {
        let work = async {
            tokio::time::sleep(Duration::from_secs(60)).await;
            42
        };

        let error = with_timeout_named("test_operation", Duration::from_millis(10), work)
            .await
            .unwrap_err();
        assert_eq!(error.operation.as_deref(), Some("test_operation"));
    }
}