use std::{
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
thread,
};
use tokio::{
runtime::Handle,
sync::{oneshot, Mutex, Notify},
task::JoinHandle,
};
use crate::common::thread::Runnable;
/// Runs a [`Runnable`] as a Tokio task with a start/stop lifecycle, modeled
/// after a long-lived "service thread".
///
/// Lifecycle state is kept in shared atomic flags so it can be queried
/// cheaply (`started`, `stopped`), and a [`Notify`] is used to interrupt
/// timed waits (`notified`).
pub struct ServiceThreadTokio {
    /// Name of the service; used in log messages.
    name: String,
    /// The work executed when the service is started.
    runnable: Arc<Mutex<dyn Runnable>>,
    /// Handle of the spawned task; `None` before `start` and after `shutdown`.
    thread: Option<JoinHandle<()>>,
    /// `true` between a successful `start` and the matching `shutdown`.
    started: Arc<AtomicBool>,
    /// `true` once a stop has been requested via `make_stop` or `shutdown`.
    stopped: Arc<AtomicBool>,
    /// Signalled by `wakeup`; awaited (with a timeout) in `wait_for_running`.
    notified: Arc<Notify>,
}
impl ServiceThreadTokio {
pub fn new(name: String, runnable: Arc<Mutex<dyn Runnable>>) -> Self {
ServiceThreadTokio {
name,
runnable,
thread: None,
stopped: Arc::new(AtomicBool::new(false)),
started: Arc::new(AtomicBool::new(false)),
notified: Arc::new(Notify::new()),
}
}
pub fn start(&mut self) {
let started = self.started.clone();
let runnable = self.runnable.clone();
let name = self.name.clone();
if let Ok(value) =
started.compare_exchange(false, true, Ordering::SeqCst, Ordering::Relaxed)
{
if value {
return;
}
} else {
return;
}
let join_handle = tokio::spawn(async move {
log::info!("Starting service thread: {}", name);
let mut guard = runnable.lock().await;
guard.run();
});
self.thread = Some(join_handle);
}
pub fn make_stop(&mut self) {
if !self.started.load(Ordering::Acquire) {
return;
}
self.stopped.store(true, Ordering::Release);
}
pub fn is_stopped(&self) -> bool {
self.stopped.load(Ordering::Relaxed)
}
pub async fn shutdown(&mut self) {
self.shutdown_interrupt(false).await;
}
pub async fn shutdown_interrupt(&mut self, interrupt: bool) {
if let Ok(value) =
self.started
.compare_exchange(true, false, Ordering::SeqCst, Ordering::Relaxed)
{
if !value {
return;
}
} else {
return;
}
self.stopped.store(true, Ordering::Release);
if let Some(thread) = self.thread.take() {
log::info!("Shutting down service thread: {}", self.name);
if interrupt {
thread.abort();
} else {
thread.await.expect("Failed to join service thread");
}
} else {
log::warn!("Service thread not started: {}", self.name);
}
}
pub fn wakeup(&self) {
self.notified.notify_waiters();
}
pub async fn wait_for_running(&self, interval: u64) {
tokio::select! {
_ = self.notified.notified() => {}
_ = tokio::time::sleep(std::time::Duration::from_millis(interval)) => {}
}
}
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::time;

    use super::*;

    /// Trivial `Runnable` that prints a line and returns immediately.
    struct MockTestRunnable;

    impl MockTestRunnable {
        fn new() -> MockTestRunnable {
            MockTestRunnable
        }
    }

    impl Runnable for MockTestRunnable {
        fn run(&mut self) {
            println!("MockTestRunnable run================")
        }
    }

    /// Builds a not-yet-started service around a fresh `MockTestRunnable`.
    fn new_service() -> ServiceThreadTokio {
        ServiceThreadTokio::new(
            "TestServiceThread".to_string(),
            Arc::new(Mutex::new(MockTestRunnable::new())),
        )
    }

    #[tokio::test]
    async fn test_start_and_shutdown() {
        let mut service_thread = new_service();
        service_thread.start();
        assert!(service_thread.started.load(Ordering::SeqCst));
        assert!(!service_thread.stopped.load(Ordering::SeqCst));
        // Brief pause so the runnable gets a chance to execute before shutdown.
        time::sleep(Duration::from_millis(50)).await;
        service_thread.shutdown_interrupt(false).await;
        assert!(!service_thread.started.load(Ordering::SeqCst));
        assert!(service_thread.stopped.load(Ordering::SeqCst));
        assert!(service_thread.thread.is_none());
    }

    #[tokio::test]
    async fn test_start_twice_is_idempotent() {
        let mut service_thread = new_service();
        service_thread.start();
        service_thread.start();
        assert!(service_thread.started.load(Ordering::SeqCst));
        assert!(service_thread.thread.is_some());
        service_thread.shutdown().await;
        assert!(!service_thread.started.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_shutdown_without_start_is_noop() {
        let mut service_thread = new_service();
        service_thread.shutdown().await;
        assert!(!service_thread.started.load(Ordering::SeqCst));
        assert!(!service_thread.is_stopped());
    }

    #[tokio::test]
    async fn test_make_stop() {
        let mut service_thread = new_service();
        service_thread.start();
        service_thread.make_stop();
        assert!(service_thread.is_stopped());
    }

    #[tokio::test]
    async fn test_make_stop_before_start_is_ignored() {
        let mut service_thread = new_service();
        service_thread.make_stop();
        assert!(!service_thread.is_stopped());
    }

    #[tokio::test]
    async fn test_wait_for_running() {
        let mut service_thread = new_service();
        service_thread.start();
        service_thread.wait_for_running(100).await;
        assert!(service_thread.started.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_wakeup() {
        let mut service_thread = new_service();
        service_thread.start();
        // Waking with no waiter must not panic, and a later wait must still complete.
        service_thread.wakeup();
        service_thread.wait_for_running(10).await;
        assert!(service_thread.started.load(Ordering::SeqCst));
    }
}