Skip to main content

openraft_rt_compio/
lib.rs

1use std::any::Any;
2use std::fmt::Debug;
3use std::fmt::Display;
4use std::fmt::Error;
5use std::fmt::Formatter;
6use std::future::Future;
7use std::pin::Pin;
8use std::task::Context;
9use std::task::Poll;
10
11pub use compio;
12pub use futures;
13use futures::FutureExt;
14use openraft_rt::AsyncRuntime;
15use openraft_rt::OptionalSend;
16pub use rand;
17use rand::rngs::ThreadRng;
18
19use crate::mpsc::FlumeMpsc;
20use crate::oneshot::FuturesOneshot;
21use crate::watch::See;
22
23mod mpsc;
24mod mutex;
25mod oneshot;
26mod watch;
27
28/// Compio async runtime.
29pub struct CompioRuntime {
30    rt: compio::runtime::Runtime,
31}
32
33impl Debug for CompioRuntime {
34    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
35        f.debug_struct("CompioRuntime").finish()
36    }
37}
38
/// Error returned when awaiting a [`CompioJoinHandle`] whose task panicked.
///
/// Holds the panic payload reported by the compio runtime.
#[derive(Debug)]
pub struct CompioJoinError(Box<dyn Any + Send>);

impl Display for CompioJoinError {
    /// Formats as `Spawned task panicked`, followed by `: <message>` when the
    /// panic payload is a `&'static str` or a `String`. These are the payload
    /// types produced by `panic!` with a literal or a formatted message. Any
    /// other payload type falls back to the bare prefix.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
        // Deref the box explicitly so the downcasts inspect the payload itself,
        // not the `Box` wrapper.
        let payload: &(dyn Any + Send) = &*self.0;
        if let Some(msg) = payload.downcast_ref::<&'static str>() {
            write!(f, "Spawned task panicked: {msg}")
        } else if let Some(msg) = payload.downcast_ref::<String>() {
            write!(f, "Spawned task panicked: {msg}")
        } else {
            write!(f, "Spawned task panicked")
        }
    }
}
47
48pub struct CompioJoinHandle<T>(Option<compio::runtime::JoinHandle<T>>);
49
50impl<T> Drop for CompioJoinHandle<T> {
51    fn drop(&mut self) {
52        let Some(j) = self.0.take() else {
53            return;
54        };
55        j.detach();
56    }
57}
58
59impl<T> Future for CompioJoinHandle<T> {
60    type Output = Result<T, CompioJoinError>;
61
62    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
63        let this = self.get_mut();
64        let task = this.0.as_mut().expect("Task has been cancelled");
65        match task.poll_unpin(cx) {
66            Poll::Ready(Ok(v)) => Poll::Ready(Ok(v)),
67            Poll::Ready(Err(e)) => Poll::Ready(Err(CompioJoinError(e))),
68            Poll::Pending => Poll::Pending,
69        }
70    }
71}
72
/// A pinned, heap-allocated, type-erased future that yields `T`.
///
/// The trait object has no `Send` bound, so values of this type are not `Send`.
pub type BoxedFuture<T> = Pin<Box<dyn Future<Output = T>>>;

pin_project_lite::pin_project! {
    /// Future returned by `CompioRuntime::timeout` and `CompioRuntime::timeout_at`.
    ///
    /// Yields `Ok(output)` when `future` finishes, or `Err(Elapsed)` when `delay`
    /// fires. The `Future` impl decides which one wins when both are ready.
    pub struct CompioTimeout<F> {
        // The wrapped future; structurally pinned so it can be polled in place.
        #[pin]
        future: F,
        // Timer that completes at the deadline. It is already a `Pin<Box<..>>`
        // (and therefore `Unpin`), so it needs no structural pinning.
        delay: BoxedFuture<()>
    }
}
82
/// Error produced by [`CompioTimeout`] when its delay fires: the time has elapsed.
///
/// Carries no data; the private `()` field only prevents construction by struct
/// literal outside this crate (`Default` is still available).
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct Elapsed(());

impl Display for Elapsed {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
        f.write_str("Time has elapsed")
    }
}
92
93impl<F: Future> Future for CompioTimeout<F> {
94    type Output = Result<F::Output, Elapsed>;
95
96    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
97        let this = self.project();
98        match this.delay.poll_unpin(cx) {
99            Poll::Ready(()) => {
100                // The delay has elapsed, so we return an error.
101                Poll::Ready(Err(Elapsed(())))
102            }
103            Poll::Pending => {
104                // The delay has not yet elapsed, so we poll the future.
105                match this.future.poll(cx) {
106                    Poll::Ready(v) => Poll::Ready(Ok(v)),
107                    Poll::Pending => Poll::Pending,
108                }
109            }
110        }
111    }
112}
113
114impl AsyncRuntime for CompioRuntime {
115    type JoinError = CompioJoinError;
116    type JoinHandle<T: OptionalSend + 'static> = CompioJoinHandle<T>;
117    type Sleep = BoxedFuture<()>;
118    type Instant = std::time::Instant;
119    type TimeoutError = Elapsed;
120    type Timeout<R, T: Future<Output = R> + OptionalSend> = CompioTimeout<T>;
121    type ThreadLocalRng = ThreadRng;
122    type Mpsc = FlumeMpsc;
123    type Watch = See;
124    type Oneshot = FuturesOneshot;
125    type Mutex<T: OptionalSend + 'static> = mutex::FlumeMutex<T>;
126
127    fn spawn<T>(fut: T) -> Self::JoinHandle<T::Output>
128    where
129        T: Future + OptionalSend + 'static,
130        T::Output: OptionalSend + 'static,
131    {
132        CompioJoinHandle(Some(compio::runtime::spawn(fut)))
133    }
134
135    fn sleep(duration: std::time::Duration) -> Self::Sleep {
136        Box::pin(compio::time::sleep(duration))
137    }
138
139    fn sleep_until(deadline: Self::Instant) -> Self::Sleep {
140        Box::pin(compio::time::sleep_until(deadline))
141    }
142
143    fn timeout<R, F: Future<Output = R> + OptionalSend>(
144        duration: std::time::Duration,
145        future: F,
146    ) -> Self::Timeout<R, F> {
147        let delay = Box::pin(compio::time::sleep(duration));
148        CompioTimeout { future, delay }
149    }
150
151    fn timeout_at<R, F: Future<Output = R> + OptionalSend>(deadline: Self::Instant, future: F) -> Self::Timeout<R, F> {
152        let delay = Box::pin(compio::time::sleep_until(deadline));
153        CompioTimeout { future, delay }
154    }
155
156    fn is_panic(_: &Self::JoinError) -> bool {
157        // Task only returns `JoinError` if the spawned future panics.
158        true
159    }
160
161    fn thread_rng() -> Self::ThreadLocalRng {
162        rand::rng()
163    }
164
165    fn new(_threads: usize) -> Self {
166        // Compio is single-threaded, ignores threads parameter
167        let rt = compio::runtime::Runtime::new().expect("Failed to create Compio runtime");
168        CompioRuntime { rt }
169    }
170
171    fn block_on<F, T>(&mut self, future: F) -> T
172    where
173        F: Future<Output = T>,
174        T: OptionalSend,
175    {
176        self.rt.block_on(future)
177    }
178}
179
#[cfg(test)]
mod tests {
    use openraft_rt::testing::Suite;

    use super::*;

    /// Runs `Suite::<CompioRuntime>::test_all()` on a `CompioRuntime`.
    #[test]
    fn test_compio_rt() {
        let suite = Suite::<CompioRuntime>::test_all();
        CompioRuntime::run(suite);
    }
}
190}