Skip to main content

romm_cli/core/
interrupt.rs

1use std::fmt;
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::sync::Arc;
4
5use anyhow::Error;
6use tokio::sync::Notify;
7
/// Marker error signalling that the user asked for the running operation to stop.
///
/// The type carries no data; its `Display` output is always the fixed text
/// `operation cancelled by user`.
#[derive(Debug)]
pub struct CancelledByUser;

impl fmt::Display for CancelledByUser {
    /// Writes the fixed, human-readable cancellation message.
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        out.write_str("operation cancelled by user")
    }
}

// No methods are overridden, so the trait's defaults apply (e.g. no `source`).
impl std::error::Error for CancelledByUser {}
18
19#[derive(Clone, Debug)]
20pub struct InterruptContext {
21    cancelled: Arc<AtomicBool>,
22    notify: Arc<Notify>,
23}
24
25impl InterruptContext {
26    pub fn new() -> Self {
27        let this = Self {
28            cancelled: Arc::new(AtomicBool::new(false)),
29            notify: Arc::new(Notify::new()),
30        };
31        let watcher = this.clone();
32        tokio::spawn(async move {
33            if tokio::signal::ctrl_c().await.is_ok() {
34                watcher.cancel();
35            }
36        });
37        this
38    }
39
40    pub fn cancel(&self) {
41        self.cancelled.store(true, Ordering::SeqCst);
42        self.notify.notify_waiters();
43    }
44
45    pub fn is_cancelled(&self) -> bool {
46        self.cancelled.load(Ordering::SeqCst)
47    }
48
49    pub async fn cancelled(&self) {
50        if self.is_cancelled() {
51            return;
52        }
53        self.notify.notified().await;
54    }
55}
56
57impl Default for InterruptContext {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63pub fn cancelled_error() -> Error {
64    Error::new(CancelledByUser)
65}
66
67pub fn is_cancelled_error(err: &Error) -> bool {
68    err.downcast_ref::<CancelledByUser>().is_some()
69}
70
#[cfg(test)]
mod tests {
    use super::*;

    /// An error built by `cancelled_error` is recognised by `is_cancelled_error`.
    #[test]
    fn cancelled_error_is_classified() {
        assert!(is_cancelled_error(&cancelled_error()));
    }

    /// `cancel` flips the value reported by `is_cancelled` from `false` to `true`.
    #[tokio::test]
    async fn context_cancel_sets_flag() {
        let interrupt = InterruptContext::new();
        let flag_before = interrupt.is_cancelled();
        assert!(!flag_before);

        interrupt.cancel();
        let flag_after = interrupt.is_cancelled();
        assert!(flag_after);
    }
}