1use std::{ops::Not, str::FromStr};
2
3use thiserror::Error;
4use tokio::sync::watch;
5
/// The two positions a gate can be in.
///
/// A [`Lever`] sets the position and every paired [`Gate`] observes it.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum Gateway {
    /// The raised position.
    Raised,
    /// The lowered position.
    Lowered,
}

// Re-export the variants so they can be written unqualified (`Raised` / `Lowered`).
pub use Gateway::{Lowered, Raised};
12
13impl Not for Gateway {
14 type Output = Gateway;
15
16 fn not(self) -> Self::Output {
17 match self {
18 Raised => Lowered,
19 Lowered => Raised,
20 }
21 }
22}
23
24impl std::fmt::Display for Gateway {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 write!(
27 f,
28 "{}",
29 match self {
30 Raised => "Raised",
31 Lowered => "Lowered",
32 }
33 )
34 }
35}
36
37#[derive(Clone, Copy, Debug, Error)]
38#[error("failed to parse Gateway: provided string was not `Raised` or `Lowered`")]
39pub struct ParseGatewayError;
40
41impl FromStr for Gateway {
42 type Err = ParseGatewayError;
43
44 #[inline]
45 fn from_str(s: &str) -> Result<Self, Self::Err> {
46 match s {
47 "Raised" => Ok(Raised),
48 "Lowered" => Ok(Lowered),
49 _ => Err(ParseGatewayError),
50 }
51 }
52}
53
54#[derive(Clone, Copy, Debug, Error)]
56#[error("gate was {0} before dropping")]
57pub struct BeforeGateDropped(pub Gateway);
58
59#[derive(Clone, Copy, Debug, Error)]
61#[error("gate was dropped")]
62pub struct GateDropped;
63
64#[derive(Clone, Copy, Debug, Error)]
66#[error("lever was dropped while raised")]
67pub struct LeverDroppedWhileRaised;
68
69#[derive(Clone, Copy, Debug, Error)]
71#[error("lever was dropped while lowered")]
72pub struct LeverDroppedWhileLowered;
73
74#[derive(Clone, Debug)]
79pub struct Lever {
80 sender: watch::Sender<Gateway>,
81}
82
83impl Lever {
84 pub fn raise(&self) -> Result<(), GateDropped> {
92 if self.gate_was_dropped() {
93 Err(GateDropped)
94 } else {
95 self.sender.send_if_modified(|gateway| match gateway {
96 Raised => false,
97 Lowered => {
98 *gateway = Raised;
99 true
100 }
101 });
102
103 Ok(())
104 }
105 }
106
107 pub fn lower(&self) -> Result<(), GateDropped> {
115 if self.gate_was_dropped() {
116 Err(GateDropped)
117 } else {
118 self.sender.send_if_modified(|gateway| match gateway {
119 Lowered => false,
120 Raised => {
121 *gateway = Lowered;
122 true
123 }
124 });
125
126 Ok(())
127 }
128 }
129
130 pub fn is_raised(&self) -> Result<bool, BeforeGateDropped> {
135 let gateway = self.sender.borrow();
136
137 if self.gate_was_dropped() {
138 Err(BeforeGateDropped(*gateway))
139 } else {
140 let is_raised = matches!(*gateway, Raised);
141 Ok(is_raised)
142 }
143 }
144
145 pub fn is_lowered(&self) -> Result<bool, BeforeGateDropped> {
150 let gateway = self.sender.borrow();
151
152 if self.gate_was_dropped() {
153 Err(BeforeGateDropped(*gateway))
154 } else {
155 let is_lowered = matches!(*gateway, Lowered);
156 Ok(is_lowered)
157 }
158 }
159
160 #[must_use]
163 pub fn gate_was_dropped(&self) -> bool {
164 self.sender.is_closed()
165 }
166}
167
168#[derive(Clone, Debug)]
176pub struct Gate {
177 receiver: watch::Receiver<Gateway>,
178}
179
180impl Gate {
181 #[must_use]
183 pub fn is_raised(&self) -> bool {
184 matches!(*self.receiver.borrow(), Raised)
185 }
186
187 #[must_use]
189 pub fn is_lowered(&self) -> bool {
190 matches!(*self.receiver.borrow(), Lowered)
191 }
192
193 pub async fn raised(&mut self) -> Result<(), LeverDroppedWhileLowered> {
198 match self
199 .receiver
200 .wait_for(|gateway| matches!(*gateway, Raised))
201 .await
202 {
203 Ok(_) => Ok(()),
204 Err(_) => Err(LeverDroppedWhileLowered),
205 }
206 }
207
208 pub async fn lowered(&mut self) -> Result<(), LeverDroppedWhileRaised> {
213 match self
214 .receiver
215 .wait_for(|gateway| matches!(*gateway, Lowered))
216 .await
217 {
218 Ok(_) => Ok(()),
219 Err(_) => Err(LeverDroppedWhileRaised),
220 }
221 }
222
223 #[must_use]
226 pub fn lever_was_dropped(&self) -> bool {
227 self.receiver.has_changed().is_err()
228 }
229}
230
231#[must_use]
234#[inline]
235pub fn new(initial: Gateway) -> (Lever, Gate) {
236 let (sender, receiver) = watch::channel(initial);
237
238 let lever = Lever { sender };
239 let gate = Gate { receiver };
240
241 (lever, gate)
242}
243
244#[must_use]
247#[inline]
248pub fn new_raised() -> (Lever, Gate) {
249 new(Raised)
250}
251
252#[must_use]
255#[inline]
256pub fn new_lowered() -> (Lever, Gate) {
257 new(Lowered)
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// `new_raised` reports `Raised` from both halves.
    #[test]
    fn starts_raised_like_it_claims() {
        let (lever, gate) = new_raised();

        assert!(gate.is_raised());
        assert!(!gate.is_lowered());

        assert!(lever.is_raised().unwrap());
        assert!(!lever.is_lowered().unwrap());
    }

    /// `new_lowered` reports `Lowered` from both halves.
    #[test]
    fn starts_lowered_like_it_claims() {
        let (lever, gate) = new_lowered();

        assert!(gate.is_lowered());
        assert!(!gate.is_raised());

        assert!(lever.is_lowered().unwrap());
        assert!(!lever.is_raised().unwrap());
    }

    /// Waiting for the state a gate is already in completes on the first poll.
    #[test]
    fn resolves_instantly() {
        let (_raised_lever, mut raised_gate) = new_raised();
        let (_lowered_lever, mut lowered_gate) = new_lowered();

        tokio_test::assert_ready!(tokio_test::task::spawn(raised_gate.raised()).poll()).unwrap();
        tokio_test::assert_ready!(tokio_test::task::spawn(lowered_gate.lowered()).poll()).unwrap();
    }

    /// Waiting for the opposite state stays pending until the lever flips the gate.
    #[test]
    fn does_not_resolve_until_satisfied() {
        let (initially_raised_lever, mut initially_raised_gate) = new_raised();
        let (initially_lowered_lever, mut initially_lowered_gate) = new_lowered();

        let mut became_lowered = tokio_test::task::spawn(initially_raised_gate.lowered());
        let mut became_raised = tokio_test::task::spawn(initially_lowered_gate.raised());

        tokio_test::assert_pending!(became_lowered.poll());
        tokio_test::assert_pending!(became_raised.poll());

        initially_raised_lever.lower().unwrap();
        initially_lowered_lever.raise().unwrap();

        tokio_test::assert_ready_ok!(became_lowered.poll());
        tokio_test::assert_ready_ok!(became_raised.poll());
    }

    /// A lowered gate can never be raised once its lever is gone.
    #[test]
    fn lowered_gate_gives_err_on_raised_when_lever_dropped() {
        let (lever, mut gate) = new_lowered();

        drop(lever);

        assert!(gate.lever_was_dropped());

        tokio_test::assert_ready_err!(tokio_test::task::spawn(gate.raised()).poll());
    }

    /// A raised gate can never be lowered once its lever is gone.
    #[test]
    fn raised_gate_gives_err_on_lowered_when_lever_dropped() {
        let (lever, mut gate) = new_raised();

        drop(lever);

        assert!(gate.lever_was_dropped());

        tokio_test::assert_ready_err!(tokio_test::task::spawn(gate.lowered()).poll());
    }

    /// Waiting for the state the gate is already in succeeds even without a lever.
    #[test]
    fn ok_even_if_lever_dropped_for_matching_state() {
        let (raised_lever, mut raised_gate) = new_raised();
        let (lowered_lever, mut lowered_gate) = new_lowered();

        drop(raised_lever);
        drop(lowered_lever);

        tokio_test::assert_ready_ok!(tokio_test::task::spawn(lowered_gate.lowered()).poll());
        tokio_test::assert_ready_ok!(tokio_test::task::spawn(raised_gate.raised()).poll());
    }

    #[test]
    fn lever_can_check_gate_was_dropped() {
        let (lever, gate) = new_raised();

        assert!(!lever.gate_was_dropped());

        drop(gate);

        assert!(lever.gate_was_dropped());
    }

    #[test]
    fn gate_can_check_lever_was_dropped() {
        let (lever, gate) = new_raised();

        assert!(!gate.lever_was_dropped());

        drop(lever);

        assert!(gate.lever_was_dropped());
    }

    /// After the gate is dropped the lever still reports the last state, wrapped in an error.
    #[test]
    fn lever_can_retrieve_dropped_gate_state() {
        let (lever, gate) = new_lowered();

        assert!(lever.is_lowered().unwrap());
        assert!(!lever.is_raised().unwrap());

        drop(gate);

        assert!(matches!(
            lever.is_lowered().unwrap_err(),
            BeforeGateDropped(Lowered)
        ));
        assert!(matches!(
            lever.is_raised().unwrap_err(),
            BeforeGateDropped(Lowered)
        ));
    }

    /// `raise`/`lower` refuse to act once no gate is left to observe the change.
    #[test]
    fn lever_cannot_move_gate_after_gate_dropped() {
        let (lever, gate) = new_lowered();

        drop(gate);

        assert!(lever.raise().is_err());
        assert!(lever.lower().is_err());
    }

    /// Re-setting the current state is a no-op; only a real change notifies the gate.
    #[test]
    fn setting_current_state_does_not_notify() {
        let (lever, gate) = new_raised();

        lever.raise().unwrap();
        assert!(!gate.receiver.has_changed().unwrap());

        lever.lower().unwrap();
        assert!(gate.receiver.has_changed().unwrap());
    }

    #[test]
    fn not_flips_the_state() {
        assert_eq!(!Raised, Lowered);
        assert_eq!(!Lowered, Raised);
    }

    #[test]
    fn display_round_trips_through_from_str() {
        for gateway in [Raised, Lowered] {
            assert_eq!(gateway.to_string().parse::<Gateway>().unwrap(), gateway);
        }
    }

    #[test]
    fn from_str_is_exact_and_case_sensitive() {
        assert!("raised".parse::<Gateway>().is_err());
        assert!(" Lowered".parse::<Gateway>().is_err());
        assert!("".parse::<Gateway>().is_err());
    }
}