romm_cli/core/
interrupt.rs1use std::fmt;
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::sync::Arc;
4
5use anyhow::Error;
6use tokio::sync::Notify;
7
/// Error value signalling that an operation was cancelled by the user.
///
/// [`cancelled_error`] wraps it in an [`anyhow::Error`], and
/// [`is_cancelled_error`] recognises it again further up the call stack.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CancelledByUser;

impl fmt::Display for CancelledByUser {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The message is a fixed string, so no formatting machinery is needed.
        f.write_str("operation cancelled by user")
    }
}

/// Marker impl: a cancellation has no underlying cause, so `source()` keeps its
/// default `None`.
impl std::error::Error for CancelledByUser {}
18
/// Shared cancellation state: an atomic "cancelled" flag plus a notifier used to
/// wake async tasks waiting for cancellation.
///
/// Cloning is cheap and every clone observes the same state. Both fields are
/// `Arc`-shared, so cancelling through one clone cancels them all.
#[derive(Clone, Debug)]
pub struct InterruptContext {
    // Becomes `true` once cancellation is requested; nothing in this file resets it.
    cancelled: Arc<AtomicBool>,
    // Wakes tasks parked waiting for the flag to flip.
    notify: Arc<Notify>,
}
24
25impl InterruptContext {
26 pub fn new() -> Self {
27 let this = Self {
28 cancelled: Arc::new(AtomicBool::new(false)),
29 notify: Arc::new(Notify::new()),
30 };
31 let watcher = this.clone();
32 tokio::spawn(async move {
33 if tokio::signal::ctrl_c().await.is_ok() {
34 watcher.cancel();
35 }
36 });
37 this
38 }
39
40 pub fn cancel(&self) {
41 self.cancelled.store(true, Ordering::SeqCst);
42 self.notify.notify_waiters();
43 }
44
45 pub fn is_cancelled(&self) -> bool {
46 self.cancelled.load(Ordering::SeqCst)
47 }
48
49 pub async fn cancelled(&self) {
50 if self.is_cancelled() {
51 return;
52 }
53 self.notify.notified().await;
54 }
55}
56
57impl Default for InterruptContext {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63pub fn cancelled_error() -> Error {
64 Error::new(CancelledByUser)
65}
66
67pub fn is_cancelled_error(err: &Error) -> bool {
68 err.downcast_ref::<CancelledByUser>().is_some()
69}
70
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cancelled_error_is_classified() {
        let err = cancelled_error();
        assert!(is_cancelled_error(&err));
    }

    #[test]
    fn cancelled_error_is_classified_through_context() {
        let err = cancelled_error().context("while syncing");
        assert!(is_cancelled_error(&err));
    }

    #[test]
    fn unrelated_error_is_not_classified() {
        let err = anyhow::anyhow!("network unreachable");
        assert!(!is_cancelled_error(&err));
    }

    #[tokio::test]
    async fn context_cancel_sets_flag() {
        let ctx = InterruptContext::new();
        assert!(!ctx.is_cancelled());
        ctx.cancel();
        assert!(ctx.is_cancelled());
    }

    #[tokio::test]
    async fn cancel_is_shared_between_clones() {
        let ctx = InterruptContext::new();
        let other = ctx.clone();
        other.cancel();
        assert!(ctx.is_cancelled());
    }

    #[tokio::test]
    async fn cancelled_returns_immediately_when_already_cancelled() {
        let ctx = InterruptContext::new();
        ctx.cancel();
        ctx.cancelled().await;
    }

    #[tokio::test]
    async fn cancelled_wakes_pending_waiter() {
        let ctx = InterruptContext::new();
        let waiter = {
            let ctx = ctx.clone();
            tokio::spawn(async move { ctx.cancelled().await })
        };
        // Give the waiter a chance to start and park on the notifier first.
        tokio::task::yield_now().await;
        ctx.cancel();
        waiter.await.expect("waiter task panicked");
    }
}