compio_runtime/future/combinator/
cancel.rs1use std::{
2 pin::Pin,
3 task::{Context, Poll},
4};
5
6use futures_util::FutureExt;
7use pin_project_lite::pin_project;
8use synchrony::unsync::event::EventListener;
9
10use crate::{
11 CancelToken,
12 future::Ext,
13 waker::{ExtWaker, with_ext},
14};
15
pin_project! {
    /// A future that polls its inner future with a [`CancelToken`] attached to
    /// the waker's extension data.
    ///
    /// This wrapper never resolves early by itself: its output is exactly the
    /// inner future's output. Use [`WithCancel::fail_fast`] to obtain a future
    /// that instead resolves to `Err(Cancelled)` once the token's listener is
    /// ready.
    pub struct WithCancel<F: ?Sized> {
        // Token combined into the waker extension on every poll.
        cancel: CancelToken,
        // The wrapped future. It is structurally pinned and must remain the last
        // field because `F` may be unsized.
        #[pin]
        future: F,
    }
}
34
pin_project! {
    /// Fail-fast counterpart of [`WithCancel`], produced by
    /// [`WithCancel::fail_fast`].
    ///
    /// On every poll the cancellation listener is checked first. If it is
    /// ready, the future resolves to `Err(Cancelled)` without polling the inner
    /// future. Otherwise the inner [`WithCancel`] is polled and its output is
    /// wrapped in `Ok`. [`WithCancelFailFast::fail_slow`] converts back to the
    /// plain wrapper.
    pub struct WithCancelFailFast<F: ?Sized> {
        // Listener obtained from the wrapped token in `WithCancel::fail_fast`.
        listen: EventListener,
        // The fail-slow wrapper being driven. It is structurally pinned and kept
        // last because `F` may be unsized.
        #[pin]
        future: WithCancel<F>,
    }
}
49
/// Error produced by [`WithCancelFailFast`] when its cancellation listener is
/// ready before the inner future has finished.
///
/// This is a unit type and carries no data.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Cancelled;
53
54impl<F: ?Sized> WithCancel<F> {
55 pub fn new(future: F, cancel: CancelToken) -> Self
57 where
58 F: Sized,
59 {
60 Self { cancel, future }
61 }
62}
63
64impl<F> WithCancel<F> {
65 pub fn fail_fast(self) -> WithCancelFailFast<F> {
70 let listen = self.cancel.listen();
71
72 WithCancelFailFast {
73 listen,
74 future: self,
75 }
76 }
77}
78
79impl<F> WithCancelFailFast<F> {
80 pub fn fail_slow(self) -> WithCancel<F> {
84 self.future
85 }
86}
87
88impl std::fmt::Display for Cancelled {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 write!(f, "Cancelled")
91 }
92}
93
94impl std::error::Error for Cancelled {}
95
96impl<F: ?Sized> Future for WithCancel<F>
97where
98 F: Future,
99{
100 type Output = F::Output;
101
102 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
103 let this = self.project();
104
105 with_ext(cx.waker(), |waker, ext: &Ext| {
106 let ext = ext.with_cancel(this.cancel);
107 ExtWaker::new(waker, &ext).poll(this.future)
108 })
109 }
110}
111
112impl<F: ?Sized> Future for WithCancelFailFast<F>
113where
114 F: Future,
115{
116 type Output = Result<F::Output, Cancelled>;
117
118 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
119 let mut this = self.project();
120
121 if this.listen.poll_unpin(cx).is_ready() {
122 return Poll::Ready(Err(Cancelled));
123 }
124
125 this.future.poll_unpin(cx).map(Ok)
126 }
127}