rocketmq_common/common/thread/
thread_service_tokio.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17use std::sync::atomic::AtomicBool;
18use std::sync::atomic::Ordering;
19use std::sync::Arc;
20use std::thread;
21
22use tokio::runtime::Handle;
23use tokio::sync::oneshot;
24use tokio::sync::Mutex;
25use tokio::sync::Notify;
26use tokio::task::JoinHandle;
27use tracing::info;
28use tracing::warn;
29
30use crate::common::thread::Runnable;
31
32pub struct ServiceThreadTokio {
33    name: String,
34    runnable: Arc<Mutex<dyn Runnable>>,
35    thread: Option<JoinHandle<()>>,
36    stopped: Arc<AtomicBool>,
37    started: Arc<AtomicBool>,
38    notified: Arc<Notify>,
39}
40
41impl ServiceThreadTokio {
42    pub fn new(name: String, runnable: Arc<Mutex<dyn Runnable>>) -> Self {
43        ServiceThreadTokio {
44            name,
45            runnable,
46            thread: None,
47            stopped: Arc::new(AtomicBool::new(false)),
48            started: Arc::new(AtomicBool::new(false)),
49            notified: Arc::new(Notify::new()),
50        }
51    }
52
53    pub fn start(&mut self) {
54        let started = self.started.clone();
55        let runnable = self.runnable.clone();
56        let name = self.name.clone();
57        if let Ok(value) =
58            started.compare_exchange(false, true, Ordering::SeqCst, Ordering::Relaxed)
59        {
60            if value {
61                return;
62            }
63        } else {
64            return;
65        }
66        let join_handle = tokio::spawn(async move {
67            info!("Starting service thread: {}", name);
68            let mut guard = runnable.lock().await;
69            guard.run();
70        });
71        self.thread = Some(join_handle);
72    }
73
74    pub fn make_stop(&mut self) {
75        if !self.started.load(Ordering::Acquire) {
76            return;
77        }
78        self.stopped.store(true, Ordering::Release);
79    }
80
81    pub fn is_stopped(&self) -> bool {
82        self.stopped.load(Ordering::Relaxed)
83    }
84
85    pub async fn shutdown(&mut self) {
86        self.shutdown_interrupt(false).await;
87    }
88
89    pub async fn shutdown_interrupt(&mut self, interrupt: bool) {
90        if let Ok(value) =
91            self.started
92                .compare_exchange(true, false, Ordering::SeqCst, Ordering::Relaxed)
93        {
94            if !value {
95                return;
96            }
97        } else {
98            return;
99        }
100        self.stopped.store(true, Ordering::Release);
101        if let Some(thread) = self.thread.take() {
102            info!("Shutting down service thread: {}", self.name);
103            if interrupt {
104                thread.abort();
105            } else {
106                thread.await.expect("Failed to join service thread");
107            }
108        } else {
109            warn!("Service thread not started: {}", self.name);
110        }
111    }
112
113    pub fn wakeup(&self) {
114        self.notified.notify_waiters();
115    }
116
117    pub async fn wait_for_running(&self, interval: u64) {
118        tokio::select! {
119            _ = self.notified.notified() => {}
120            _ = tokio::time::sleep(std::time::Duration::from_millis(interval)) => {}
121        }
122    }
123}
124
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::time;
    use tokio::time::timeout;

    use super::*;

    /// Minimal `Runnable` whose `run` only prints a marker line.
    struct MockTestRunnable;

    impl MockTestRunnable {
        fn new() -> MockTestRunnable {
            MockTestRunnable
        }
    }

    impl Runnable for MockTestRunnable {
        fn run(&mut self) {
            println!("MockTestRunnable run================")
        }
    }

    /// Builds a service named `TestServiceThread` around a fresh mock runnable.
    fn new_service_thread() -> ServiceThreadTokio {
        ServiceThreadTokio::new(
            "TestServiceThread".to_string(),
            Arc::new(Mutex::new(MockTestRunnable::new())),
        )
    }

    #[tokio::test]
    async fn test_start_and_shutdown() {
        let mut service_thread = new_service_thread();

        service_thread.start();
        assert!(service_thread.started.load(Ordering::SeqCst));
        assert!(!service_thread.stopped.load(Ordering::SeqCst));

        time::sleep(Duration::from_millis(50)).await;
        service_thread.shutdown_interrupt(false).await;
        assert!(!service_thread.started.load(Ordering::SeqCst));
        assert!(service_thread.stopped.load(Ordering::SeqCst));
        assert!(service_thread.thread.is_none());
    }

    #[tokio::test]
    async fn test_start_is_idempotent() {
        let mut service_thread = new_service_thread();

        service_thread.start();
        service_thread.start();
        assert!(service_thread.started.load(Ordering::SeqCst));
        assert!(service_thread.thread.is_some());

        service_thread.shutdown().await;
        assert!(!service_thread.started.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_shutdown_without_start_is_noop() {
        let mut service_thread = new_service_thread();

        service_thread.shutdown().await;
        assert!(!service_thread.started.load(Ordering::SeqCst));
        assert!(!service_thread.is_stopped());
    }

    #[tokio::test]
    async fn test_make_stop() {
        let mut service_thread = new_service_thread();

        // Before `start`, `make_stop` must not raise the flag.
        service_thread.make_stop();
        assert!(!service_thread.is_stopped());

        service_thread.start();
        service_thread.make_stop();
        assert!(service_thread.is_stopped());

        service_thread.shutdown().await;
    }

    #[tokio::test]
    async fn test_wait_for_running() {
        let mut service_thread = new_service_thread();

        service_thread.start();
        service_thread.wait_for_running(100).await;
        assert!(service_thread.started.load(Ordering::SeqCst));

        service_thread.shutdown().await;
    }

    #[tokio::test]
    async fn test_wakeup() {
        let mut service_thread = new_service_thread();
        service_thread.start();

        // Park a waiter with a very long interval, then wake it: it must return long
        // before the interval elapses.
        let waiter = service_thread.wait_for_running(60_000);
        let waker = async {
            time::sleep(Duration::from_millis(20)).await;
            service_thread.wakeup();
        };
        let woken = timeout(Duration::from_secs(5), async { tokio::join!(waiter, waker) }).await;
        assert!(woken.is_ok(), "wakeup did not release wait_for_running");

        service_thread.shutdown().await;
    }
}
202}