rocketmq_common/common/thread/
thread_service_tokio.rs1use std::sync::atomic::AtomicBool;
18use std::sync::atomic::Ordering;
19use std::sync::Arc;
20use std::thread;
21
22use tokio::runtime::Handle;
23use tokio::sync::oneshot;
24use tokio::sync::Mutex;
25use tokio::sync::Notify;
26use tokio::task::JoinHandle;
27use tracing::info;
28use tracing::warn;
29
30use crate::common::thread::Runnable;
31
32pub struct ServiceThreadTokio {
33 name: String,
34 runnable: Arc<Mutex<dyn Runnable>>,
35 thread: Option<JoinHandle<()>>,
36 stopped: Arc<AtomicBool>,
37 started: Arc<AtomicBool>,
38 notified: Arc<Notify>,
39}
40
41impl ServiceThreadTokio {
42 pub fn new(name: String, runnable: Arc<Mutex<dyn Runnable>>) -> Self {
43 ServiceThreadTokio {
44 name,
45 runnable,
46 thread: None,
47 stopped: Arc::new(AtomicBool::new(false)),
48 started: Arc::new(AtomicBool::new(false)),
49 notified: Arc::new(Notify::new()),
50 }
51 }
52
53 pub fn start(&mut self) {
54 let started = self.started.clone();
55 let runnable = self.runnable.clone();
56 let name = self.name.clone();
57 if let Ok(value) =
58 started.compare_exchange(false, true, Ordering::SeqCst, Ordering::Relaxed)
59 {
60 if value {
61 return;
62 }
63 } else {
64 return;
65 }
66 let join_handle = tokio::spawn(async move {
67 info!("Starting service thread: {}", name);
68 let mut guard = runnable.lock().await;
69 guard.run();
70 });
71 self.thread = Some(join_handle);
72 }
73
74 pub fn make_stop(&mut self) {
75 if !self.started.load(Ordering::Acquire) {
76 return;
77 }
78 self.stopped.store(true, Ordering::Release);
79 }
80
81 pub fn is_stopped(&self) -> bool {
82 self.stopped.load(Ordering::Relaxed)
83 }
84
85 pub async fn shutdown(&mut self) {
86 self.shutdown_interrupt(false).await;
87 }
88
89 pub async fn shutdown_interrupt(&mut self, interrupt: bool) {
90 if let Ok(value) =
91 self.started
92 .compare_exchange(true, false, Ordering::SeqCst, Ordering::Relaxed)
93 {
94 if !value {
95 return;
96 }
97 } else {
98 return;
99 }
100 self.stopped.store(true, Ordering::Release);
101 if let Some(thread) = self.thread.take() {
102 info!("Shutting down service thread: {}", self.name);
103 if interrupt {
104 thread.abort();
105 } else {
106 thread.await.expect("Failed to join service thread");
107 }
108 } else {
109 warn!("Service thread not started: {}", self.name);
110 }
111 }
112
113 pub fn wakeup(&self) {
114 self.notified.notify_waiters();
115 }
116
117 pub async fn wait_for_running(&self, interval: u64) {
118 tokio::select! {
119 _ = self.notified.notified() => {}
120 _ = tokio::time::sleep(std::time::Duration::from_millis(interval)) => {}
121 }
122 }
123}
124
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::time;
    use tokio::time::timeout;

    use super::*;

    /// Minimal `Runnable` whose `run` only prints a marker line.
    struct MockTestRunnable;

    impl MockTestRunnable {
        fn new() -> MockTestRunnable {
            MockTestRunnable
        }
    }

    impl Runnable for MockTestRunnable {
        fn run(&mut self) {
            println!("MockTestRunnable run================")
        }
    }

    /// Builds a not-yet-started service around a fresh `MockTestRunnable`.
    fn new_service_thread() -> ServiceThreadTokio {
        ServiceThreadTokio::new(
            "TestServiceThread".to_string(),
            Arc::new(Mutex::new(MockTestRunnable::new())),
        )
    }

    #[tokio::test]
    async fn test_start_and_shutdown() {
        let mut service_thread = new_service_thread();

        service_thread.start();
        assert!(service_thread.started.load(Ordering::SeqCst));
        assert!(!service_thread.stopped.load(Ordering::SeqCst));

        // No sleep is needed here: a non-interrupting shutdown awaits the
        // spawned task, which drives it to completion.
        service_thread.shutdown_interrupt(false).await;
        assert!(!service_thread.started.load(Ordering::SeqCst));
        assert!(service_thread.stopped.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_make_stop() {
        let mut service_thread = new_service_thread();

        service_thread.start();
        service_thread.make_stop();
        assert!(service_thread.is_stopped());
    }

    #[tokio::test]
    async fn test_wait_for_running() {
        let mut service_thread = new_service_thread();

        service_thread.start();
        service_thread.wait_for_running(100).await;
        assert!(service_thread.started.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_wakeup() {
        let mut service_thread = new_service_thread();
        service_thread.start();

        // The wait interval is far longer than the timeout, so finishing in
        // time proves that `wakeup` released the waiter.
        let outcome = timeout(Duration::from_secs(5), async {
            tokio::join!(service_thread.wait_for_running(60_000), async {
                // Let the waiter register before the wake-up is sent;
                // a wake-up with no registered waiter would be lost.
                time::sleep(Duration::from_millis(20)).await;
                service_thread.wakeup();
            });
        })
        .await;
        assert!(outcome.is_ok(), "wakeup did not release the waiter");
    }
}