1use std::future::Future;
2use std::time::Duration;
3
4use openraft_rt::AsyncRuntime;
5use openraft_rt::OptionalSend;
6
7mod instant;
8mod mpsc;
9mod mutex;
10mod oneshot;
11mod watch;
12
13pub use instant::TokioInstant;
14pub use mpsc::TokioMpsc;
15pub use mpsc::TokioMpscReceiver;
16pub use mpsc::TokioMpscSender;
17pub use mpsc::TokioMpscWeakSender;
18pub use mutex::TokioMutex;
19pub use oneshot::TokioOneshot;
20pub use oneshot::TokioOneshotSender;
21pub use watch::TokioWatch;
22pub use watch::TokioWatchReceiver;
23pub use watch::TokioWatchSender;
24
25pub struct TokioRuntime {
27 rt: tokio::runtime::Runtime,
28 #[cfg(feature = "single-threaded")]
29 local: tokio::task::LocalSet,
30}
31
32impl std::fmt::Debug for TokioRuntime {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 f.debug_struct("TokioRuntime").finish()
35 }
36}
37
38impl AsyncRuntime for TokioRuntime {
39 type JoinError = tokio::task::JoinError;
40 type JoinHandle<T: OptionalSend + 'static> = tokio::task::JoinHandle<T>;
41 type Sleep = tokio::time::Sleep;
42 type Instant = TokioInstant;
43 type TimeoutError = tokio::time::error::Elapsed;
44 type Timeout<R, T: Future<Output = R> + OptionalSend> = tokio::time::Timeout<T>;
45 type ThreadLocalRng = rand::rngs::ThreadRng;
46
47 #[inline]
48 fn spawn<T>(future: T) -> Self::JoinHandle<T::Output>
49 where
50 T: Future + OptionalSend + 'static,
51 T::Output: OptionalSend + 'static,
52 {
53 #[cfg(feature = "single-threaded")]
54 {
55 tokio::task::spawn_local(future)
56 }
57 #[cfg(not(feature = "single-threaded"))]
58 {
59 tokio::task::spawn(future)
60 }
61 }
62
63 #[inline]
64 fn sleep(duration: Duration) -> Self::Sleep {
65 tokio::time::sleep(duration)
66 }
67
68 #[inline]
69 fn sleep_until(deadline: Self::Instant) -> Self::Sleep {
70 tokio::time::sleep_until(deadline.0)
71 }
72
73 #[inline]
74 fn timeout<R, F: Future<Output = R> + OptionalSend>(duration: Duration, future: F) -> Self::Timeout<R, F> {
75 tokio::time::timeout(duration, future)
76 }
77
78 #[inline]
79 fn timeout_at<R, F: Future<Output = R> + OptionalSend>(deadline: Self::Instant, future: F) -> Self::Timeout<R, F> {
80 tokio::time::timeout_at(deadline.0, future)
81 }
82
83 #[inline]
84 fn is_panic(join_error: &Self::JoinError) -> bool {
85 join_error.is_panic()
86 }
87
88 #[inline]
89 fn thread_rng() -> Self::ThreadLocalRng {
90 rand::rng()
91 }
92
93 type Mpsc = TokioMpsc;
94 type Watch = TokioWatch;
95 type Oneshot = TokioOneshot;
96 type Mutex<T: OptionalSend + 'static> = TokioMutex<T>;
97
98 fn new(threads: usize) -> Self {
99 let rt = tokio::runtime::Builder::new_multi_thread()
100 .worker_threads(threads)
101 .enable_all()
102 .build()
103 .expect("Failed to create Tokio runtime");
104
105 TokioRuntime {
106 rt,
107 #[cfg(feature = "single-threaded")]
108 local: tokio::task::LocalSet::new(),
109 }
110 }
111
112 fn block_on<F, T>(&mut self, future: F) -> T
113 where
114 F: Future<Output = T>,
115 T: OptionalSend,
116 {
117 #[cfg(feature = "single-threaded")]
118 {
119 self.local.block_on(&self.rt, future)
120 }
121 #[cfg(not(feature = "single-threaded"))]
122 {
123 self.rt.block_on(future)
124 }
125 }
126}
127
#[cfg(test)]
mod tests {
    use openraft_rt::testing::Suite;

    use super::*;

    /// Runs openraft's shared runtime test suite against `TokioRuntime`.
    #[test]
    fn test_tokio_rt() {
        let all_tests = Suite::<TokioRuntime>::test_all();
        TokioRuntime::run(all_tests);
    }
}