nu_plugin_core/util/waitable.rs

1use std::sync::{
2    Arc, Condvar, Mutex, MutexGuard, PoisonError,
3    atomic::{AtomicBool, Ordering},
4};
5
6use nu_protocol::ShellError;
7
8/// A shared container that may be empty, and allows threads to block until it has a value.
9///
10/// This side is read-only - use [`WaitableMut`] on threads that might write a value.
11#[derive(Debug, Clone)]
12pub struct Waitable<T: Clone + Send> {
13    shared: Arc<WaitableShared<T>>,
14}
15
16#[derive(Debug)]
17pub struct WaitableMut<T: Clone + Send> {
18    shared: Arc<WaitableShared<T>>,
19}
20
21#[derive(Debug)]
22struct WaitableShared<T: Clone + Send> {
23    is_set: AtomicBool,
24    mutex: Mutex<SyncState<T>>,
25    condvar: Condvar,
26}
27
/// Mutable state protected by the mutex in `WaitableShared`.
#[derive(Debug)]
struct SyncState<T>
where
    T: Clone + Send,
{
    /// Number of live `WaitableMut` handles that could still set a value.
    writers: usize,
    /// The stored value, or `None` if nothing has been set yet.
    value: Option<T>,
}
33
34#[track_caller]
35fn fail_if_poisoned<'a, T>(
36    result: Result<MutexGuard<'a, T>, PoisonError<MutexGuard<'a, T>>>,
37) -> Result<MutexGuard<'a, T>, ShellError> {
38    match result {
39        Ok(guard) => Ok(guard),
40        Err(_) => Err(ShellError::NushellFailedHelp {
41            msg: "Waitable mutex poisoned".into(),
42            help: std::panic::Location::caller().to_string(),
43        }),
44    }
45}
46
47impl<T: Clone + Send> WaitableMut<T> {
48    /// Create a new empty `WaitableMut`. Call [`.reader()`](Self::reader) to get [`Waitable`].
49    pub fn new() -> WaitableMut<T> {
50        WaitableMut {
51            shared: Arc::new(WaitableShared {
52                is_set: AtomicBool::new(false),
53                mutex: Mutex::new(SyncState {
54                    writers: 1,
55                    value: None,
56                }),
57                condvar: Condvar::new(),
58            }),
59        }
60    }
61
62    pub fn reader(&self) -> Waitable<T> {
63        Waitable {
64            shared: self.shared.clone(),
65        }
66    }
67
68    /// Set the value and let waiting threads know.
69    #[track_caller]
70    pub fn set(&self, value: T) -> Result<(), ShellError> {
71        let mut sync_state = fail_if_poisoned(self.shared.mutex.lock())?;
72        self.shared.is_set.store(true, Ordering::SeqCst);
73        sync_state.value = Some(value);
74        self.shared.condvar.notify_all();
75        Ok(())
76    }
77}
78
79impl<T: Clone + Send> Default for WaitableMut<T> {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl<T: Clone + Send> Clone for WaitableMut<T> {
86    fn clone(&self) -> Self {
87        let shared = self.shared.clone();
88        shared
89            .mutex
90            .lock()
91            .expect("failed to lock mutex to increment writers")
92            .writers += 1;
93        WaitableMut { shared }
94    }
95}
96
97impl<T: Clone + Send> Drop for WaitableMut<T> {
98    fn drop(&mut self) {
99        // Decrement writers...
100        if let Ok(mut sync_state) = self.shared.mutex.lock() {
101            sync_state.writers = sync_state
102                .writers
103                .checked_sub(1)
104                .expect("would decrement writers below zero");
105        }
106        // and notify waiting threads so they have a chance to see it.
107        self.shared.condvar.notify_all();
108    }
109}
110
111impl<T: Clone + Send> Waitable<T> {
112    /// Wait for a value to be available and then clone it.
113    ///
114    /// Returns `Ok(None)` if there are no writers left that could possibly place a value.
115    #[track_caller]
116    pub fn get(&self) -> Result<Option<T>, ShellError> {
117        let sync_state = fail_if_poisoned(self.shared.mutex.lock())?;
118        if let Some(value) = sync_state.value.clone() {
119            Ok(Some(value))
120        } else if sync_state.writers == 0 {
121            // There can't possibly be a value written, so no point in waiting.
122            Ok(None)
123        } else {
124            let sync_state = fail_if_poisoned(
125                self.shared
126                    .condvar
127                    .wait_while(sync_state, |g| g.writers > 0 && g.value.is_none()),
128            )?;
129            Ok(sync_state.value.clone())
130        }
131    }
132
133    /// Clone the value if one is available, but don't wait if not.
134    #[track_caller]
135    pub fn try_get(&self) -> Result<Option<T>, ShellError> {
136        let sync_state = fail_if_poisoned(self.shared.mutex.lock())?;
137        Ok(sync_state.value.clone())
138    }
139
140    /// Returns true if value is available.
141    #[track_caller]
142    pub fn is_set(&self) -> bool {
143        self.shared.is_set.load(Ordering::SeqCst)
144    }
145}
146
#[test]
fn set_from_other_thread() -> Result<(), ShellError> {
    let writer = WaitableMut::new();
    let reader = writer.reader();

    // Nothing has been written yet.
    assert!(!reader.is_set());

    // Publish from another thread; `get` below must block until the value arrives.
    std::thread::spawn(move || writer.set(42).expect("error on set"));

    let received = reader.get()?;
    assert_eq!(Some(42), received);
    assert_eq!(Some(42), reader.try_get()?);
    assert!(reader.is_set());
    Ok(())
}
163
#[test]
fn dont_deadlock_if_waiting_without_writer() {
    use std::sync::mpsc;
    use std::time::Duration;

    // Build a reader and drop the only writer right away, so no value can ever arrive.
    let waitable = WaitableMut::<()>::new().reader();

    let (sender, receiver) = mpsc::channel();
    std::thread::spawn(move || {
        let _ = sender.send(waitable.get());
    });

    let outcome = receiver
        .recv_timeout(Duration::from_secs(10))
        .expect("timed out");
    let value = outcome.expect("error");
    assert!(value.is_none());
}