1use crate::resource::Resource;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll, Waker};
5
/// Outcome of a future run under a guard via `with_guard`.
///
/// `Ok` carries the wrapped future's output. `InvalidState` means the running guard
/// reported `is_safe() == false` before the wrapped future completed, so no output
/// was produced.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GuardOutput<T> {
    /// The wrapped future completed with this value.
    Ok(T),
    /// The guard became unsafe before the wrapped future completed.
    InvalidState,
}
11
12pub trait Guard {
13 type ResourceState;
14 type RunningGuard: RunningGuard;
15
16 fn start(self) -> (Self::ResourceState, Self::RunningGuard);
17}
18
/// The active half of a started `Guard`.
///
/// The executor returned by `with_guard` does the following on every poll:
/// - forwards the current task's waker through `set_waker`;
/// - checks `is_safe` before polling the wrapped future;
/// - calls `stop` once that future completes or is abandoned.
pub trait RunningGuard {
    /// Final state produced by `stop`; handed back to the caller of `with_guard`.
    type StoppedResourceState;

    /// When this returns `false`, the executor stops polling the wrapped future and
    /// resolves with `GuardOutput::InvalidState`.
    fn is_safe(&self) -> bool;

    /// Stores the waker of the task driving the guarded future.
    ///
    /// NOTE(review): presumably implementations wake it when `is_safe` flips to
    /// `false`. Otherwise a pending guarded future might never be re-polled to
    /// observe the change. Confirm against implementors.
    fn set_waker(&mut self, waker: Waker);

    /// Stops the guard and returns its final state.
    ///
    /// # Safety
    ///
    /// The precise contract is implementation-defined and is not visible here. The
    /// executor behind `with_guard` calls this when the wrapped future has completed,
    /// or when it has been abandoned because `is_safe` returned `false`.
    ///
    /// TODO(review): confirm whether calling it more than once, or while the guarded
    /// future is still live, is allowed. Implementations should document any
    /// stronger requirement.
    unsafe fn stop(&mut self) -> Self::StoppedResourceState;
}
27
28pub fn with_guard<T, S, F, G, RG, FUT>(
29 guard: G,
30 initial_state: S,
31 fut_const: F,
32) -> impl Future<Output = (S, RG::StoppedResourceState, GuardOutput<T>)>
33where
34 S: Resource,
35 RG: RunningGuard,
36 G: Guard<RunningGuard = RG>,
37 FUT: Future<Output = T>,
38 F: FnOnce(S, G::ResourceState) -> FUT,
39{
40 let inner_state = unsafe {
41 let mut inner_state = initial_state.clone_state();
42 inner_state.set_cleanup_enabled(false);
43 inner_state
44 };
45
46 let (guard_resource, running_guard) = guard.start();
47
48 let future = fut_const(inner_state, guard_resource);
49 GuardExecutor {
50 running_guard,
51 future,
52 initial_state: Some(initial_state),
53 }
54}
55
/// Future returned by `with_guard`: polls `future` while `running_guard` stays safe,
/// then stops the guard and hands `initial_state` back.
// Fields drop in declaration order (guard, then future, then state). Keep this order
// unless it is known not to matter for the guard/resource types involved.
struct GuardExecutor<RG, S, F> {
    /// Active half of the guard; receives the waker, is checked with `is_safe`, and
    /// is stopped when the executor resolves.
    running_guard: RG,
    /// The wrapped future. It is structurally pinned: `poll` only accesses it through
    /// a `Pin<&mut F>` projection, so it must never be moved out once pinned.
    future: F,
    /// The caller's original state. It is `Some` until the executor resolves, then
    /// taken and returned as the first element of the output tuple.
    initial_state: Option<S>,
}
61
62impl<T, RG, S, F> Future for GuardExecutor<RG, S, F>
63where
64 RG: RunningGuard,
65 F: Future<Output = T>,
66 S: Resource,
67{
68 type Output = (S, RG::StoppedResourceState, GuardOutput<T>);
69
70 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
71 let waker = cx.waker().clone();
72 unsafe {
74 self.as_mut()
75 .get_unchecked_mut()
76 .running_guard
77 .set_waker(waker);
78 }
79
80 if !self.running_guard.is_safe() {
81 unsafe {
87 let original_state = self.as_mut().get_unchecked_mut().running_guard.stop();
88 let initial_state = self
89 .as_mut()
90 .get_unchecked_mut()
91 .initial_state
92 .take()
93 .unwrap();
94 return Poll::Ready((initial_state, original_state, GuardOutput::InvalidState));
95 }
96 }
97
98 let fut = unsafe { self.as_mut().map_unchecked_mut(|s| &mut s.future) };
100 match fut.poll(cx) {
101 Poll::Ready(val) => {
102 unsafe {
104 let original_state = self.as_mut().get_unchecked_mut().running_guard.stop();
105 let initial_state = self
106 .as_mut()
107 .get_unchecked_mut()
108 .initial_state
109 .take()
110 .unwrap();
111 Poll::Ready((initial_state, original_state, GuardOutput::Ok(val)))
112 }
113 }
114 Poll::Pending => Poll::Pending,
115 }
116 }
117}