openraft_rt_compio/
lib.rs1use std::any::Any;
2use std::fmt::Debug;
3use std::fmt::Display;
4use std::fmt::Error;
5use std::fmt::Formatter;
6use std::future::Future;
7use std::pin::Pin;
8use std::task::Context;
9use std::task::Poll;
10
11pub use compio;
12pub use futures;
13use futures::FutureExt;
14use openraft_rt::AsyncRuntime;
15use openraft_rt::OptionalSend;
16pub use rand;
17use rand::rngs::ThreadRng;
18
19use crate::mpsc::FlumeMpsc;
20use crate::oneshot::FuturesOneshot;
21use crate::watch::See;
22
23mod mpsc;
24mod mutex;
25mod oneshot;
26mod watch;
27
28pub struct CompioRuntime {
30 rt: compio::runtime::Runtime,
31}
32
33impl Debug for CompioRuntime {
34 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
35 f.debug_struct("CompioRuntime").finish()
36 }
37}
38
/// Error yielded by a [`CompioJoinHandle`] when its task did not complete normally.
///
/// Wraps the boxed payload reported by compio's join handle; this crate treats
/// that payload as a panic payload (see the `Display` message and `is_panic`).
#[derive(Debug)]
pub struct CompioJoinError(Box<dyn Any + Send>);

impl CompioJoinError {
    /// Consumes the error and returns the payload the task failed with.
    ///
    /// Mirrors tokio's `JoinError::into_panic`: the payload can be inspected with
    /// `downcast_ref` or re-raised with `std::panic::resume_unwind`.
    pub fn into_panic(self) -> Box<dyn Any + Send> {
        self.0
    }
}

impl Display for CompioJoinError {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
        write!(f, "Spawned task panicked")
    }
}

// Allows propagating a join failure with `?` into `Box<dyn std::error::Error>`-style errors.
impl std::error::Error for CompioJoinError {}
47
48pub struct CompioJoinHandle<T>(Option<compio::runtime::JoinHandle<T>>);
49
50impl<T> Drop for CompioJoinHandle<T> {
51 fn drop(&mut self) {
52 let Some(j) = self.0.take() else {
53 return;
54 };
55 j.detach();
56 }
57}
58
59impl<T> Future for CompioJoinHandle<T> {
60 type Output = Result<T, CompioJoinError>;
61
62 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
63 let this = self.get_mut();
64 let task = this.0.as_mut().expect("Task has been cancelled");
65 match task.poll_unpin(cx) {
66 Poll::Ready(Ok(v)) => Poll::Ready(Ok(v)),
67 Poll::Ready(Err(e)) => Poll::Ready(Err(CompioJoinError(e))),
68 Poll::Pending => Poll::Pending,
69 }
70 }
71}
72
/// Heap-allocated, type-erased future; used here for the runtime's sleep/timer futures.
pub type BoxedFuture<T> = Pin<Box<dyn Future<Output = T>>>;

pin_project_lite::pin_project! {
    /// Future returned by `CompioRuntime::timeout` and `CompioRuntime::timeout_at`.
    ///
    /// Resolves to `Ok(output)` with the wrapped future's output, or to
    /// `Err(Elapsed)` once `delay` fires.
    pub struct CompioTimeout<F> {
        // `F` carries no `Unpin` bound, so it is pinned and polled through the projection.
        #[pin]
        future: F,
        // Timer future; it is already boxed and pinned (hence `Unpin`), so it needs no `#[pin]`.
        delay: BoxedFuture<()>
    }
}
82
/// Error produced by [`CompioTimeout`] when its delay fires before the wrapped
/// future has yielded a value.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct Elapsed(());

impl std::fmt::Display for Elapsed {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Time has elapsed")
    }
}

// Lets callers propagate a timeout with `?` into `Box<dyn std::error::Error>`-style
// error types, as they can with tokio's `Elapsed`.
impl std::error::Error for Elapsed {}
92
93impl<F: Future> Future for CompioTimeout<F> {
94 type Output = Result<F::Output, Elapsed>;
95
96 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
97 let this = self.project();
98 match this.delay.poll_unpin(cx) {
99 Poll::Ready(()) => {
100 Poll::Ready(Err(Elapsed(())))
102 }
103 Poll::Pending => {
104 match this.future.poll(cx) {
106 Poll::Ready(v) => Poll::Ready(Ok(v)),
107 Poll::Pending => Poll::Pending,
108 }
109 }
110 }
111 }
112}
113
114impl AsyncRuntime for CompioRuntime {
115 type JoinError = CompioJoinError;
116 type JoinHandle<T: OptionalSend + 'static> = CompioJoinHandle<T>;
117 type Sleep = BoxedFuture<()>;
118 type Instant = std::time::Instant;
119 type TimeoutError = Elapsed;
120 type Timeout<R, T: Future<Output = R> + OptionalSend> = CompioTimeout<T>;
121 type ThreadLocalRng = ThreadRng;
122 type Mpsc = FlumeMpsc;
123 type Watch = See;
124 type Oneshot = FuturesOneshot;
125 type Mutex<T: OptionalSend + 'static> = mutex::FlumeMutex<T>;
126
127 fn spawn<T>(fut: T) -> Self::JoinHandle<T::Output>
128 where
129 T: Future + OptionalSend + 'static,
130 T::Output: OptionalSend + 'static,
131 {
132 CompioJoinHandle(Some(compio::runtime::spawn(fut)))
133 }
134
135 fn sleep(duration: std::time::Duration) -> Self::Sleep {
136 Box::pin(compio::time::sleep(duration))
137 }
138
139 fn sleep_until(deadline: Self::Instant) -> Self::Sleep {
140 Box::pin(compio::time::sleep_until(deadline))
141 }
142
143 fn timeout<R, F: Future<Output = R> + OptionalSend>(
144 duration: std::time::Duration,
145 future: F,
146 ) -> Self::Timeout<R, F> {
147 let delay = Box::pin(compio::time::sleep(duration));
148 CompioTimeout { future, delay }
149 }
150
151 fn timeout_at<R, F: Future<Output = R> + OptionalSend>(deadline: Self::Instant, future: F) -> Self::Timeout<R, F> {
152 let delay = Box::pin(compio::time::sleep_until(deadline));
153 CompioTimeout { future, delay }
154 }
155
156 fn is_panic(_: &Self::JoinError) -> bool {
157 true
159 }
160
161 fn thread_rng() -> Self::ThreadLocalRng {
162 rand::rng()
163 }
164
165 fn new(_threads: usize) -> Self {
166 let rt = compio::runtime::Runtime::new().expect("Failed to create Compio runtime");
168 CompioRuntime { rt }
169 }
170
171 fn block_on<F, T>(&mut self, future: F) -> T
172 where
173 F: Future<Output = T>,
174 T: OptionalSend,
175 {
176 self.rt.block_on(future)
177 }
178}
179
#[cfg(test)]
mod tests {
    use openraft_rt::testing::Suite;

    use super::*;

    /// Runs openraft_rt's shared runtime test suite against `CompioRuntime`.
    #[test]
    fn test_compio_rt() {
        let suite = Suite::<CompioRuntime>::test_all();
        CompioRuntime::run(suite);
    }
}