1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
// Copyright 2019 kennytm. Licensed under MIT OR Apache-2.0.

//! `async-ctrlc` is an async wrapper of the `ctrlc` crate.

use ctrlc::{set_handler, Error};
use std::{
    future::Future,
    pin::Pin,
    ptr::null_mut,
    sync::atomic::{AtomicBool, AtomicPtr, Ordering},
    task::{Context, Poll, Waker},
};

/// Set to `true` once the Ctrl+C handler has run; nothing in this file ever
/// resets it, so a resolved `CtrlC` stays resolved.
static ACTIVE: AtomicBool = AtomicBool::new(false);

// TODO: Replace this with `AtomicOptionBox<Waker>`
// after https://github.com/jorendorff/atomicbox/pull/3 is merged.
/// Heap-allocated waker of the task that most recently polled `CtrlC`, or null.
///
/// Every non-null value stored here comes from `Box::into_raw`; whichever side
/// swaps a pointer out becomes its owner and must rebuild the `Box` from it.
static WAKER: AtomicPtr<Waker> = AtomicPtr::new(null_mut());

/// A future which is fulfilled when the program receives the Ctrl+C signal.
///
/// Obtain one with `CtrlC::new()`, which installs the process-wide handler.
/// Once the signal has been received the future resolves immediately on every
/// subsequent poll.
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[derive(Debug)]
pub struct CtrlC {
    // Prevents construction outside this module so `new()` is the only entry point.
    _private: (),
}

impl Future for CtrlC {
    type Output = ();
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if ACTIVE.load(Ordering::SeqCst) {
            Poll::Ready(())
        } else {
            let new_waker = Box::new(cx.waker().clone());
            let old_waker_ptr = WAKER.swap(Box::into_raw(new_waker), Ordering::SeqCst);
            if !old_waker_ptr.is_null() {
                unsafe { Box::from_raw(old_waker_ptr) };
            }
            Poll::Pending
        }
    }
}

impl CtrlC {
    /// Creates a new `CtrlC` future.
    ///
    /// This installs a process-wide Ctrl+C handler through `ctrlc::set_handler`.
    /// When the handler runs, it marks the signal as received and wakes the task
    /// that most recently polled the future, if any.
    ///
    /// There should be at most one `CtrlC` instance in the whole program.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `ctrlc::set_handler`. In particular, the
    /// second call to `CtrlC::new()` returns an error.
    pub fn new() -> Result<Self, Error> {
        set_handler(|| {
            // Publish the flag before taking the waker, so any poll that runs
            // after the waker has been taken observes `ACTIVE == true`.
            ACTIVE.store(true, Ordering::SeqCst);
            let waker_ptr = WAKER.swap(null_mut(), Ordering::SeqCst);
            if !waker_ptr.is_null() {
                // SAFETY: non-null pointers in `WAKER` come from `Box::into_raw`,
                // and the atomic `swap` gave this closure exclusive ownership.
                unsafe { Box::from_raw(waker_ptr) }.wake();
            }
        })?;
        Ok(CtrlC { _private: () })
    }
}

#[cfg(unix)]
#[test]
fn test_unix() {
    use async_std::{future::timeout, task::block_on};
    use libc::{getpid, kill, SIGINT};
    use std::{
        thread::{sleep, spawn},
        time::Duration,
    };

    // Install the handler *before* the signal can possibly be sent. Otherwise a
    // slow main thread could let SIGINT hit the default disposition and kill
    // the whole test process.
    let c = CtrlC::new().expect("no other Ctrl+C handler should be installed");

    let thread = spawn(|| {
        sleep(Duration::from_millis(100));
        // SAFETY: `getpid` and `kill` take and return plain integers; sending
        // SIGINT to our own process is exactly what this test exercises.
        let rc = unsafe { kill(getpid(), SIGINT) };
        assert_eq!(rc, 0, "kill(getpid(), SIGINT) failed");
    });

    let result = block_on(async move {
        let i = 1;
        timeout(Duration::from_millis(500), c)
            .await
            .expect("CtrlC should resolve after SIGINT is delivered");
        i + 2
    });
    assert_eq!(result, 3);

    thread.join().expect("signal-sending thread panicked");
}