1use std::future::Future;
31use std::pin::Pin;
32use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
33use std::sync::Arc;
34use std::task::{Context, Poll};
35use std::time::{Duration, Instant};
36
37use pin_project_lite::pin_project;
38
/// A point in time after which an operation should stop.
///
/// Deadlines are ordered by their expiry instant: an earlier deadline compares as smaller.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Deadline {
    // Instant at which the deadline is reached.
    expires_at: Instant,
}

impl Deadline {
    /// Stand-in horizon (about 30 years) used when `now + duration` cannot be represented as
    /// an `Instant`, so that huge timeouts behave as "effectively never" instead of panicking.
    const FAR_FUTURE: Duration = Duration::from_secs(60 * 60 * 24 * 365 * 30);

    /// Creates a deadline `duration` from now.
    ///
    /// Durations too large to add to the current `Instant` (e.g. `Duration::MAX`) yield a
    /// deadline far in the future rather than panicking on overflow.
    pub fn from_duration(duration: Duration) -> Self {
        let now = Instant::now();
        Self {
            expires_at: now
                .checked_add(duration)
                .unwrap_or_else(|| now + Self::FAR_FUTURE),
        }
    }

    /// Creates a deadline that expires at `instant`.
    pub fn at(instant: Instant) -> Self {
        Self {
            expires_at: instant,
        }
    }

    /// Returns `true` once the current time has reached the expiry instant.
    pub fn is_expired(&self) -> bool {
        Instant::now() >= self.expires_at
    }

    /// Time left until expiry, or `None` if the deadline has been reached.
    pub fn remaining(&self) -> Option<Duration> {
        let now = Instant::now();
        if now >= self.expires_at {
            None
        } else {
            Some(self.expires_at - now)
        }
    }

    /// The instant at which this deadline expires.
    pub fn expires_at(&self) -> Instant {
        self.expires_at
    }

    /// Returns the earlier of this deadline and `timeout` from now.
    ///
    /// If `now + timeout` overflows `Instant`, it is necessarily later than this deadline,
    /// so this deadline is returned unchanged instead of panicking.
    pub fn with_timeout(&self, timeout: Duration) -> Self {
        match Instant::now().checked_add(timeout) {
            Some(timeout_deadline) => Self {
                expires_at: self.expires_at.min(timeout_deadline),
            },
            None => *self,
        }
    }
}
93
/// Context for a single operation: an optional name, an optional deadline, a cancellation
/// flag, and an optional parent context.
///
/// Cloning a context shares its cancellation flag (the flag is behind an `Arc`), so
/// cancelling a clone also cancels the original.
#[derive(Debug, Clone)]
pub struct OperationContext {
    /// Optional operation name, set via `with_name`.
    pub name: Option<String>,
    /// Deadline for the operation, if any.
    pub deadline: Option<Deadline>,
    /// This context's own cancellation flag; shared with clones and with tokens returned by
    /// `cancellation_token`.
    cancelled: Arc<AtomicBool>,
    /// Parent context; `is_cancelled` also checks the parent's cancellation.
    parent: Option<Arc<OperationContext>>,
    /// Creation time, used by `elapsed`.
    created_at: Instant,
}
112
113impl OperationContext {
114 pub fn new() -> Self {
116 Self {
117 name: None,
118 deadline: None,
119 cancelled: Arc::new(AtomicBool::new(false)),
120 parent: None,
121 created_at: Instant::now(),
122 }
123 }
124
125 pub fn with_name(mut self, name: impl Into<String>) -> Self {
127 self.name = Some(name.into());
128 self
129 }
130
131 pub fn with_deadline(mut self, deadline: Deadline) -> Self {
133 self.deadline = Some(deadline);
134 self
135 }
136
137 pub fn with_timeout(mut self, timeout: Duration) -> Self {
139 self.deadline = Some(Deadline::from_duration(timeout));
140 self
141 }
142
143 pub fn with_parent(mut self, parent: Arc<OperationContext>) -> Self {
145 if self.deadline.is_none() {
147 self.deadline = parent.deadline;
148 } else if let (Some(parent_deadline), Some(ref my_deadline)) =
149 (parent.deadline, &self.deadline)
150 {
151 if parent_deadline.expires_at < my_deadline.expires_at {
153 self.deadline = Some(parent_deadline);
154 }
155 }
156 self.parent = Some(parent);
157 self
158 }
159
160 pub fn child(&self) -> OperationContext {
162 OperationContext::new().with_parent(Arc::new(self.clone()))
163 }
164
165 pub fn child_with_timeout(&self, timeout: Duration) -> OperationContext {
167 let deadline = match self.deadline {
168 Some(d) => d.with_timeout(timeout),
169 None => Deadline::from_duration(timeout),
170 };
171 OperationContext::new()
172 .with_deadline(deadline)
173 .with_parent(Arc::new(self.clone()))
174 }
175
176 pub fn is_cancelled(&self) -> bool {
178 if self.cancelled.load(Ordering::Relaxed) {
179 return true;
180 }
181 if let Some(ref parent) = self.parent {
182 return parent.is_cancelled();
183 }
184 false
185 }
186
187 pub fn cancel(&self) {
189 self.cancelled.store(true, Ordering::Relaxed);
190 }
191
192 pub fn is_expired(&self) -> bool {
194 self.deadline.map(|d| d.is_expired()).unwrap_or(false)
195 }
196
197 pub fn remaining_time(&self) -> Option<Duration> {
199 self.deadline.and_then(|d| d.remaining())
200 }
201
202 pub fn should_continue(&self) -> bool {
204 !self.is_cancelled() && !self.is_expired()
205 }
206
207 pub fn elapsed(&self) -> Duration {
209 self.created_at.elapsed()
210 }
211
212 pub fn cancellation_token(&self) -> CancellationToken {
214 CancellationToken {
215 flag: self.cancelled.clone(),
216 }
217 }
218}
219
220impl Default for OperationContext {
221 fn default() -> Self {
222 Self::new()
223 }
224}
225
/// Shareable cancellation flag.
///
/// Every clone points at the same underlying flag, so cancelling one clone is visible
/// through all of them.
#[derive(Debug, Clone, Default)]
pub struct CancellationToken {
    // Shared flag; the derived `Default` starts it at `false` (not cancelled).
    flag: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Creates a fresh, not-yet-cancelled token.
    pub fn new() -> Self {
        Self::default()
    }

    /// Reports whether `cancel` has been called on this token or any clone of it.
    pub fn is_cancelled(&self) -> bool {
        let flag: &AtomicBool = &self.flag;
        flag.load(Ordering::Relaxed)
    }

    /// Marks the token, and every clone of it, as cancelled.
    pub fn cancel(&self) {
        let flag: &AtomicBool = &self.flag;
        flag.store(true, Ordering::Relaxed);
    }
}
260
/// Error reported when an operation does not finish within its time limit.
#[derive(Debug, Clone)]
pub struct TimeoutError {
    /// Name of the operation, when one was supplied.
    pub operation: Option<String>,
    /// The time limit that was exceeded.
    pub timeout: Duration,
    /// Time measured from the start of the operation until the timeout was reported.
    pub elapsed: Duration,
}

impl std::fmt::Display for TimeoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let TimeoutError {
            operation,
            timeout,
            elapsed,
        } = self;
        if let Some(name) = operation {
            write!(
                f,
                "Operation '{}' timed out after {:?} (limit: {:?})",
                name, elapsed, timeout
            )
        } else {
            write!(
                f,
                "Operation timed out after {:?} (limit: {:?})",
                elapsed, timeout
            )
        }
    }
}

impl std::error::Error for TimeoutError {}
294
295pin_project! {
296 pub struct Timeout<F> {
298 #[pin]
299 inner: F,
300 deadline: Deadline,
301 started_at: Instant,
302 operation_name: Option<String>,
303 }
304}
305
306impl<F: Future> Future for Timeout<F> {
307 type Output = Result<F::Output, TimeoutError>;
308
309 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
310 let this = self.project();
311
312 if this.deadline.is_expired() {
314 return Poll::Ready(Err(TimeoutError {
315 operation: this.operation_name.clone(),
316 timeout: this
317 .deadline
318 .expires_at()
319 .saturating_duration_since(*this.started_at),
320 elapsed: this.started_at.elapsed(),
321 }));
322 }
323
324 match this.inner.poll(cx) {
326 Poll::Ready(value) => Poll::Ready(Ok(value)),
327 Poll::Pending => {
328 Poll::Pending
332 }
333 }
334 }
335}
336
337pub fn timeout<F: Future>(duration: Duration, future: F) -> Timeout<F> {
339 Timeout {
340 inner: future,
341 deadline: Deadline::from_duration(duration),
342 started_at: Instant::now(),
343 operation_name: None,
344 }
345}
346
347pub fn timeout_named<F: Future>(
349 name: impl Into<String>,
350 duration: Duration,
351 future: F,
352) -> Timeout<F> {
353 Timeout {
354 inner: future,
355 deadline: Deadline::from_duration(duration),
356 started_at: Instant::now(),
357 operation_name: Some(name.into()),
358 }
359}
360
361pub async fn with_timeout<F, T>(duration: Duration, future: F) -> Result<T, TimeoutError>
363where
364 F: Future<Output = T>,
365{
366 let started_at = Instant::now();
367 match tokio::time::timeout(duration, future).await {
368 Ok(result) => Ok(result),
369 Err(_) => Err(TimeoutError {
370 operation: None,
371 timeout: duration,
372 elapsed: started_at.elapsed(),
373 }),
374 }
375}
376
377pub async fn with_timeout_named<F, T>(
379 name: impl Into<String>,
380 duration: Duration,
381 future: F,
382) -> Result<T, TimeoutError>
383where
384 F: Future<Output = T>,
385{
386 let name = name.into();
387 let started_at = Instant::now();
388 match tokio::time::timeout(duration, future).await {
389 Ok(result) => Ok(result),
390 Err(_) => Err(TimeoutError {
391 operation: Some(name),
392 timeout: duration,
393 elapsed: started_at.elapsed(),
394 }),
395 }
396}
397
/// Lock-free counters of how timed operations ended.
///
/// Each `record_*` call increments `total_operations` and exactly one outcome counter.
#[derive(Debug, Default)]
pub struct TimeoutStats {
    /// Number of recorded operations, regardless of outcome.
    pub total_operations: AtomicU64,
    /// Operations recorded via `record_completed`.
    pub completed: AtomicU64,
    /// Operations recorded via `record_timeout`.
    pub timeouts: AtomicU64,
    /// Operations recorded via `record_cancellation`.
    pub cancellations: AtomicU64,
}

impl TimeoutStats {
    /// Creates a set of counters, all starting at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records an operation that completed.
    pub fn record_completed(&self) {
        self.record(&self.completed);
    }

    /// Records an operation that timed out.
    pub fn record_timeout(&self) {
        self.record(&self.timeouts);
    }

    /// Records an operation that was cancelled.
    pub fn record_cancellation(&self) {
        self.record(&self.cancellations);
    }

    /// Counts one more operation and attributes it to the `outcome` counter.
    fn record(&self, outcome: &AtomicU64) {
        self.total_operations.fetch_add(1, Ordering::Relaxed);
        outcome.fetch_add(1, Ordering::Relaxed);
    }

    /// Fraction of recorded operations that timed out, always within `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when nothing has been recorded.
    pub fn timeout_rate(&self) -> f64 {
        let total = self.total_operations.load(Ordering::Relaxed);
        if total == 0 {
            return 0.0;
        }
        let timeouts = self.timeouts.load(Ordering::Relaxed);
        // The counters are read independently, so concurrent recording (or direct writes to
        // the public fields) can make `timeouts` exceed the `total` read above. Clamp so the
        // result stays a valid fraction.
        (timeouts as f64 / total as f64).min(1.0)
    }

    /// Copies every counter into a plain-value snapshot.
    ///
    /// Counters are loaded one at a time with relaxed ordering. Under concurrent recording,
    /// the snapshot may therefore be internally inconsistent; for example, `total_operations`
    /// may differ from the sum of the outcome counters.
    pub fn snapshot(&self) -> TimeoutStatsSnapshot {
        TimeoutStatsSnapshot {
            total_operations: self.total_operations.load(Ordering::Relaxed),
            completed: self.completed.load(Ordering::Relaxed),
            timeouts: self.timeouts.load(Ordering::Relaxed),
            cancellations: self.cancellations.load(Ordering::Relaxed),
        }
    }
}

/// Point-in-time copy of the counters in [`TimeoutStats`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TimeoutStatsSnapshot {
    /// Number of recorded operations, regardless of outcome.
    pub total_operations: u64,
    /// Operations recorded as completed.
    pub completed: u64,
    /// Operations recorded as timed out.
    pub timeouts: u64,
    /// Operations recorded as cancelled.
    pub cancellations: u64,
}
472
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_deadline() {
        let distant = Deadline::from_duration(Duration::from_secs(10));
        assert!(!distant.is_expired());
        assert!(distant.remaining().is_some());

        // Sleeping for 1ms guarantees that a 1ns deadline has passed.
        let imminent = Deadline::from_duration(Duration::from_nanos(1));
        std::thread::sleep(Duration::from_millis(1));
        assert!(imminent.is_expired());
        assert_eq!(imminent.remaining(), None);
    }

    #[test]
    fn test_deadline_with_timeout() {
        let outer = Deadline::from_duration(Duration::from_secs(60));
        let tightened = outer.with_timeout(Duration::from_secs(5));

        // The earlier of the two expiry instants wins.
        assert!(tightened.expires_at() < outer.expires_at());
    }

    #[test]
    fn test_operation_context() {
        let ctx = OperationContext::new()
            .with_name("test_op")
            .with_timeout(Duration::from_secs(30));

        assert!(!ctx.is_cancelled(), "fresh context must not be cancelled");
        assert!(!ctx.is_expired(), "30s deadline must not have expired yet");
        assert!(ctx.should_continue());
        assert!(ctx.remaining_time().is_some());
    }

    #[test]
    fn test_operation_context_cancellation() {
        let ctx = OperationContext::new();
        assert!(!ctx.is_cancelled());

        ctx.cancel();

        assert!(ctx.is_cancelled());
        assert!(!ctx.should_continue());
    }

    #[test]
    fn test_operation_context_parent() {
        let parent = OperationContext::new().with_timeout(Duration::from_secs(30));
        let child = parent.child();

        // The child inherits the parent's deadline...
        assert!(child.deadline.is_some());

        // ...and sees the parent's cancellation through the shared flag.
        parent.cancel();
        assert!(child.is_cancelled());
    }

    #[test]
    fn test_cancellation_token() {
        let original = CancellationToken::new();
        assert!(!original.is_cancelled());

        original.cancel();
        assert!(original.is_cancelled());

        // Clones share the same underlying flag.
        let copy = original.clone();
        assert!(copy.is_cancelled());
    }

    #[test]
    fn test_timeout_error_display() {
        let message = TimeoutError {
            operation: Some("send_message".to_string()),
            timeout: Duration::from_secs(5),
            elapsed: Duration::from_secs(5),
        }
        .to_string();

        assert!(message.contains("send_message"));
        assert!(message.contains("timed out"));
    }

    #[test]
    fn test_timeout_stats() {
        let stats = TimeoutStats::new();
        for _ in 0..2 {
            stats.record_completed();
        }
        stats.record_timeout();
        stats.record_cancellation();

        let TimeoutStatsSnapshot {
            total_operations,
            completed,
            timeouts,
            cancellations,
        } = stats.snapshot();
        assert_eq!(total_operations, 4);
        assert_eq!(completed, 2);
        assert_eq!(timeouts, 1);
        assert_eq!(cancellations, 1);
        assert!((stats.timeout_rate() - 0.25).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_with_timeout_success() {
        let outcome = with_timeout(Duration::from_secs(5), async {
            tokio::time::sleep(Duration::from_millis(10)).await;
            42
        })
        .await;

        assert!(matches!(outcome, Ok(42)));
    }

    #[tokio::test]
    async fn test_with_timeout_failure() {
        let outcome = with_timeout(Duration::from_millis(10), async {
            tokio::time::sleep(Duration::from_secs(60)).await;
            42
        })
        .await;

        match outcome {
            Err(error) => assert!(error.elapsed >= Duration::from_millis(10)),
            Ok(value) => panic!("expected a timeout, got {}", value),
        }
    }

    #[tokio::test]
    async fn test_with_timeout_named() {
        let outcome = with_timeout_named("test_operation", Duration::from_millis(10), async {
            tokio::time::sleep(Duration::from_secs(60)).await;
            42
        })
        .await;

        match outcome {
            Err(error) => assert_eq!(error.operation.as_deref(), Some("test_operation")),
            Ok(value) => panic!("expected a timeout, got {}", value),
        }
    }
}