nu_plugin_core/util/
waitable.rs1use std::sync::{
2 Arc, Condvar, Mutex, MutexGuard, PoisonError,
3 atomic::{AtomicBool, Ordering},
4};
5
6use nu_protocol::ShellError;
7
8#[derive(Debug, Clone)]
12pub struct Waitable<T: Clone + Send> {
13 shared: Arc<WaitableShared<T>>,
14}
15
16#[derive(Debug)]
17pub struct WaitableMut<T: Clone + Send> {
18 shared: Arc<WaitableShared<T>>,
19}
20
21#[derive(Debug)]
22struct WaitableShared<T: Clone + Send> {
23 is_set: AtomicBool,
24 mutex: Mutex<SyncState<T>>,
25 condvar: Condvar,
26}
27
/// Data protected by the mutex inside `WaitableShared`.
#[derive(Debug)]
struct SyncState<T>
where
    T: Clone + Send,
{
    /// Number of live `WaitableMut` handles; `Waitable::get` stops waiting once this
    /// reaches zero.
    writers: usize,
    /// The stored value, `None` until `WaitableMut::set` is called.
    value: Option<T>,
}
33
34#[track_caller]
35fn fail_if_poisoned<'a, T>(
36 result: Result<MutexGuard<'a, T>, PoisonError<MutexGuard<'a, T>>>,
37) -> Result<MutexGuard<'a, T>, ShellError> {
38 match result {
39 Ok(guard) => Ok(guard),
40 Err(_) => Err(ShellError::NushellFailedHelp {
41 msg: "Waitable mutex poisoned".into(),
42 help: std::panic::Location::caller().to_string(),
43 }),
44 }
45}
46
47impl<T: Clone + Send> WaitableMut<T> {
48 pub fn new() -> WaitableMut<T> {
50 WaitableMut {
51 shared: Arc::new(WaitableShared {
52 is_set: AtomicBool::new(false),
53 mutex: Mutex::new(SyncState {
54 writers: 1,
55 value: None,
56 }),
57 condvar: Condvar::new(),
58 }),
59 }
60 }
61
62 pub fn reader(&self) -> Waitable<T> {
63 Waitable {
64 shared: self.shared.clone(),
65 }
66 }
67
68 #[track_caller]
70 pub fn set(&self, value: T) -> Result<(), ShellError> {
71 let mut sync_state = fail_if_poisoned(self.shared.mutex.lock())?;
72 self.shared.is_set.store(true, Ordering::SeqCst);
73 sync_state.value = Some(value);
74 self.shared.condvar.notify_all();
75 Ok(())
76 }
77}
78
79impl<T: Clone + Send> Default for WaitableMut<T> {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
85impl<T: Clone + Send> Clone for WaitableMut<T> {
86 fn clone(&self) -> Self {
87 let shared = self.shared.clone();
88 shared
89 .mutex
90 .lock()
91 .expect("failed to lock mutex to increment writers")
92 .writers += 1;
93 WaitableMut { shared }
94 }
95}
96
97impl<T: Clone + Send> Drop for WaitableMut<T> {
98 fn drop(&mut self) {
99 if let Ok(mut sync_state) = self.shared.mutex.lock() {
101 sync_state.writers = sync_state
102 .writers
103 .checked_sub(1)
104 .expect("would decrement writers below zero");
105 }
106 self.shared.condvar.notify_all();
108 }
109}
110
111impl<T: Clone + Send> Waitable<T> {
112 #[track_caller]
116 pub fn get(&self) -> Result<Option<T>, ShellError> {
117 let sync_state = fail_if_poisoned(self.shared.mutex.lock())?;
118 if let Some(value) = sync_state.value.clone() {
119 Ok(Some(value))
120 } else if sync_state.writers == 0 {
121 Ok(None)
123 } else {
124 let sync_state = fail_if_poisoned(
125 self.shared
126 .condvar
127 .wait_while(sync_state, |g| g.writers > 0 && g.value.is_none()),
128 )?;
129 Ok(sync_state.value.clone())
130 }
131 }
132
133 #[track_caller]
135 pub fn try_get(&self) -> Result<Option<T>, ShellError> {
136 let sync_state = fail_if_poisoned(self.shared.mutex.lock())?;
137 Ok(sync_state.value.clone())
138 }
139
140 #[track_caller]
142 pub fn is_set(&self) -> bool {
143 self.shared.is_set.load(Ordering::SeqCst)
144 }
145}
146
#[test]
fn set_from_other_thread() -> Result<(), ShellError> {
    let waitable_mut = WaitableMut::new();
    let waitable = waitable_mut.reader();

    assert!(!waitable.is_set());

    // Keep the join handle. If the setter panics, its writer is dropped and `get`
    // returns `None`; joining reports the real panic instead of only a value mismatch.
    let setter = std::thread::spawn(move || {
        waitable_mut.set(42).expect("error on set");
    });

    // `get` may run before or after `set`; it must block until the value arrives.
    let value = waitable.get()?;
    setter.join().expect("setter thread panicked");

    assert_eq!(Some(42), value);
    assert_eq!(Some(42), waitable.try_get()?);
    assert!(waitable.is_set());
    Ok(())
}
163
#[test]
fn dont_deadlock_if_waiting_without_writer() {
    use std::time::Duration;

    let (sender, receiver) = std::sync::mpsc::channel();
    let reader = {
        // The only writer goes out of scope at the end of this block, leaving zero
        // writers before the reader ever waits.
        let writer = WaitableMut::<()>::new();
        writer.reader()
    };

    std::thread::spawn(move || {
        let _ = sender.send(reader.get());
    });

    let outcome = receiver
        .recv_timeout(Duration::from_secs(10))
        .expect("timed out");
    let value = outcome.expect("error");
    assert!(value.is_none());
}