1use std::{sync::Arc, task::Waker};
2
3use parking_lot::Mutex;
4
/// Lock-protected contents of a [`FutureSharedState`].
struct FutureSharedInner<Output> {
    /// Value handed over by `resolve`; moved out again by `take_result`.
    result: Option<Output>,
    /// Waker parked by `register_waker` while no result is available;
    /// `resolve` takes it out and wakes it.
    waker: Option<Waker>,
    /// Set by `abandon`. Once true, `resolve` and `register_waker` return
    /// without touching the other fields.
    dropped: bool,
}
10
11pub struct FutureSharedState<Output> {
13 inner: Mutex<FutureSharedInner<Output>>,
14}
15
16impl<Output> FutureSharedInner<Output> {
17 pub fn new() -> Self {
18 Self {
19 result: None,
20 waker: None,
21 dropped: false,
22 }
23 }
24}
25
26impl<Output> FutureSharedState<Output> {
27 pub fn new() -> Arc<Self> {
28 Arc::new(Self {
29 inner: Mutex::new(FutureSharedInner::new()),
30 })
31 }
32
33 pub fn resolve(self: &Arc<Self>, result: Output) {
35 let mut inner = self.inner.lock();
36 if inner.dropped {
37 return;
38 }
39
40 inner.result = Some(result);
41 if let Some(waker) = inner.waker.take() {
42 waker.wake();
43 }
44 }
45
46 pub fn register_waker(self: &Arc<Self>, waker: Waker) {
48 let mut inner = self.inner.lock();
49 if inner.dropped {
50 return;
51 }
52
53 inner.waker = Some(waker);
54 if inner.result.is_some()
55 && let Some(waker) = inner.waker.take()
56 {
57 waker.wake();
58 }
59 }
60
61 pub fn take_result(self: &Arc<Self>) -> Option<Output> {
63 let mut inner = self.inner.lock();
64 inner.result.take()
65 }
66
67 pub fn abandon(self: &Arc<Self>) {
69 let mut inner = self.inner.lock();
70 inner.dropped = true;
71 inner.result = None;
72 inner.waker = None;
73 }
74}
75
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
        },
        task::Waker,
    };

    use futures_util::task::ArcWake;

    use super::*;
    use crate::guest_data::GuestResult;

    /// Waker that records whether it has been woken.
    struct FlagWaker {
        flag: Arc<AtomicBool>,
    }

    impl ArcWake for FlagWaker {
        fn wake_by_ref(arc_self: &Arc<Self>) {
            arc_self.flag.store(true, Ordering::SeqCst);
        }
    }

    /// Builds an owned `Waker` together with the flag it sets when woken.
    fn flag_waker() -> (Arc<AtomicBool>, Waker) {
        let flag = Arc::new(AtomicBool::new(false));
        let waker = futures_util::task::waker(Arc::new(FlagWaker {
            flag: Arc::clone(&flag),
        }));
        (flag, waker)
    }

    #[test]
    fn resolve_notifies_registered_waker() {
        let state = FutureSharedState::<GuestResult<Vec<u8>>>::new();
        let (flag, waker) = flag_waker();

        state.register_waker(waker);
        assert!(!flag.load(Ordering::SeqCst), "woken before any result");

        state.resolve(Ok(Vec::new()));

        assert!(flag.load(Ordering::SeqCst));
        assert!(state.take_result().is_some());
    }

    #[test]
    fn register_after_resolve_wakes_immediately() {
        let state = FutureSharedState::<u32>::new();
        let (flag, waker) = flag_waker();

        state.resolve(7);
        state.register_waker(waker);

        assert!(flag.load(Ordering::SeqCst));
        assert_eq!(state.take_result(), Some(7));
    }

    #[test]
    fn take_result_is_one_shot() {
        let state = FutureSharedState::<u32>::new();

        assert_eq!(state.take_result(), None);
        state.resolve(1);
        assert_eq!(state.take_result(), Some(1));
        assert_eq!(state.take_result(), None);
    }

    #[test]
    fn abandon_discards_result_and_ignores_later_calls() {
        let state = FutureSharedState::<u32>::new();
        state.resolve(1);
        state.abandon();
        assert_eq!(state.take_result(), None);

        let (flag, waker) = flag_waker();
        state.register_waker(waker);
        state.resolve(2);

        assert!(!flag.load(Ordering::SeqCst));
        assert_eq!(state.take_result(), None);
    }
}