use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use rocketmq_error::RocketMQResult;
use tokio::sync::Notify;
use tokio::sync::RwLock;
use tokio::time::timeout;
use tracing::info;
use tracing::warn;
/// Handle passed to [`ServiceTask::run`] that exposes the stop flag and the
/// wakeup/wait primitives shared with the owning [`ServiceManager`].
#[derive(Debug)]
pub struct ServiceContext {
    /// Notifier signalled by `wakeup` to cut a pending wait short.
    wait_point: Arc<Notify>,
    /// Set when a wakeup has been requested but not yet consumed by a wait.
    has_notified: Arc<AtomicBool>,
    /// Becomes `true` once a stop has been requested for the service.
    stopped: Arc<AtomicBool>,
}
impl ServiceContext {
    /// Builds a context from state shared with the owning [`ServiceManager`].
    ///
    /// * `wait_point` - notifier that [`ServiceContext::wakeup`] signals to cut a wait short.
    /// * `has_notified` - set by `wakeup` and consumed by [`ServiceContext::wait_for_running`],
    ///   so a wakeup requested while the service is busy is remembered.
    /// * `stopped` - flag reported by [`ServiceContext::is_stopped`].
    pub fn new(wait_point: Arc<Notify>, has_notified: Arc<AtomicBool>, stopped: Arc<AtomicBool>) -> Self {
        Self {
            wait_point,
            has_notified,
            stopped,
        }
    }

    /// Returns `true` once a stop has been requested for the service.
    pub fn is_stopped(&self) -> bool {
        self.stopped.load(Ordering::Acquire)
    }

    /// Waits for up to `interval`, returning early if [`ServiceContext::wakeup`] is called.
    ///
    /// A wakeup requested while the service was not waiting is remembered and makes this
    /// call return immediately, consuming that wakeup. Unlike
    /// `ServiceManager::wait_for_running`, this does not call `ServiceTask::on_wait_end`.
    ///
    /// Always returns `true`; the value does not distinguish a timeout from a wakeup.
    pub async fn wait_for_running(&self, interval: Duration) -> bool {
        if self
            .has_notified
            .compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire)
            .is_ok()
        {
            // `wakeup` also called `notify_one`; with nobody waiting, that stored a permit in
            // `wait_point`. Consume it here, otherwise this single wakeup would also end the
            // *next* wait immediately (Java's ServiceThread calls `waitPoint.reset()` for this).
            // tokio's `Timeout` polls the inner future before its deadline, so a stored permit
            // is taken even with a zero timeout; if none is stored this simply times out.
            let _ = timeout(Duration::ZERO, self.wait_point.notified()).await;
            return true;
        }
        // A notification and a timeout are handled identically, so the outcome is ignored.
        let _ = timeout(interval, self.wait_point.notified()).await;
        self.has_notified.store(false, Ordering::Release);
        true
    }

    /// Requests that the current (or next) [`ServiceContext::wait_for_running`] return early.
    ///
    /// Repeated calls before the wakeup is consumed are coalesced into a single notification.
    pub fn wakeup(&self) {
        if self
            .has_notified
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_ok()
        {
            self.wait_point.notify_one();
        }
    }
}
/// A unit of background work that a [`ServiceManager`] runs on a spawned tokio task.
///
/// Implementors supply the main loop in [`ServiceTask::run`]; the remaining methods have
/// defaults that may be overridden.
pub trait ServiceTask: Sync + Send {
    /// Name of the service, used in the manager's log messages.
    fn get_service_name(&self) -> String;

    /// Main body of the service. The spawned task awaits this once; when it returns the
    /// manager marks the service as stopped. Use `context` to check for stop requests and
    /// to wait between iterations.
    fn run(&self, context: &ServiceContext) -> impl ::core::future::Future<Output = ()> + Send;

    /// Hook awaited by `ServiceManager::wait_for_running` after every wait, whether it ended
    /// by wakeup or by timeout. The default implementation does nothing.
    fn on_wait_end(&self) -> impl ::core::future::Future<Output = ()> + Send {
        ::core::future::ready(())
    }

    /// How long a non-interrupting, non-daemon shutdown waits for `run` to finish.
    /// Defaults to 90 seconds.
    fn get_join_time(&self) -> Duration {
        Duration::from_secs(90)
    }
}
/// Owns a [`ServiceTask`] and drives it on a spawned tokio task, tracking its lifecycle
/// and offering start / shutdown / wakeup controls.
pub struct ServiceManager<T: ServiceTask + 'static> {
    // The wrapped service; an `Arc` clone is moved into the spawned task.
    service: Arc<T>,
    // Current lifecycle phase; written both by the manager and by the spawned task.
    state: Arc<RwLock<ServiceLifecycle>>,
    // `true` once a stop was requested (shutdown / make_stop) or the task's `run` returned.
    stopped: Arc<AtomicBool>,
    // `true` from a successful `start` until `shutdown` or until the task's `run` returns.
    started: Arc<AtomicBool>,
    // Pending-wakeup flag; `wakeup` only notifies when it flips this from false to true.
    has_notified: Arc<AtomicBool>,
    // Notifier used to interrupt a pending wait.
    wait_point: Arc<Notify>,
    // Join handle of the most recently spawned task; taken by a non-daemon shutdown.
    task_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
    // When `true`, shutdown does not wait for (or abort) the spawned task.
    is_daemon: AtomicBool,
}
impl<T: ServiceTask> AsRef<T> for ServiceManager<T> {
    /// Borrows the wrapped service (no reference-count change).
    fn as_ref(&self) -> &T {
        let inner: &T = &self.service;
        inner
    }
}
/// Lifecycle phases reported by `ServiceManager::get_lifecycle_state`.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq`, so phases can be used as map/set keys
/// and in contexts that require total equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ServiceLifecycle {
    /// The manager was created and `start` has not been called.
    NotStarted,
    /// `start` is in progress; the task may not have begun running yet.
    Starting,
    /// The service task has been spawned.
    Running,
    /// A shutdown has been requested and is waiting for the task.
    Stopping,
    /// The task's `run` returned or shutdown completed.
    Stopped,
}
impl<T: ServiceTask + 'static> ServiceManager<T> {
    /// Wraps `service` in a new manager in the [`ServiceLifecycle::NotStarted`] phase.
    pub fn new(service: T) -> Self {
        Self::new_arc(Arc::new(service))
    }

    /// Creates a manager around an already shared service instance.
    pub fn new_arc(service: Arc<T>) -> Self {
        Self {
            service,
            state: Arc::new(RwLock::new(ServiceLifecycle::NotStarted)),
            stopped: Arc::new(AtomicBool::new(false)),
            started: Arc::new(AtomicBool::new(false)),
            has_notified: Arc::new(AtomicBool::new(false)),
            wait_point: Arc::new(Notify::new()),
            task_handle: Arc::new(RwLock::new(None)),
            is_daemon: AtomicBool::new(false),
        }
    }

    /// Spawns the service's `run` loop on the current tokio runtime.
    ///
    /// If the service is already started, this only logs a warning.
    ///
    /// # Errors
    /// Currently always returns `Ok(())`.
    pub async fn start(&self) -> RocketMQResult<()> {
        let service_name = self.service.get_service_name();
        info!(
            "Try to start service thread: {} started: {} current_state: {:?}",
            service_name,
            self.started.load(Ordering::Acquire),
            self.get_lifecycle_state().await
        );
        if self
            .started
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_err()
        {
            warn!("Service thread {} is already started", service_name);
            return Ok(());
        }
        *self.state.write().await = ServiceLifecycle::Starting;
        self.stopped.store(false, Ordering::Release);

        let handle = tokio::spawn(Self::run_internal(
            self.service.clone(),
            self.state.clone(),
            self.stopped.clone(),
            self.started.clone(),
            self.has_notified.clone(),
            self.wait_point.clone(),
        ));
        *self.task_handle.write().await = Some(handle);

        // The spawned task may already have finished (state `Stopped`) or a shutdown may have
        // begun (`Stopping`); only promote to `Running` if we are still `Starting`.
        Self::mark_running(&self.state).await;
        info!(
            "Started service thread: {} started: {}",
            service_name,
            self.started.load(Ordering::Acquire)
        );
        Ok(())
    }

    /// Moves `state` from `Starting` to `Running`, leaving any other phase untouched.
    async fn mark_running(state: &RwLock<ServiceLifecycle>) {
        let mut guard = state.write().await;
        if *guard == ServiceLifecycle::Starting {
            *guard = ServiceLifecycle::Running;
        }
    }

    /// Body of the spawned task: runs the service, then marks it stopped and resets flags.
    async fn run_internal(
        service: Arc<T>,
        state: Arc<RwLock<ServiceLifecycle>>,
        stopped: Arc<AtomicBool>,
        started: Arc<AtomicBool>,
        has_notified: Arc<AtomicBool>,
        wait_point: Arc<Notify>,
    ) {
        let service_name = service.get_service_name();
        info!("Service thread {} is running", service_name);
        // Do not overwrite `Stopping` if a shutdown was requested before this task got to run.
        Self::mark_running(&state).await;

        let context = ServiceContext::new(wait_point, has_notified.clone(), stopped.clone());
        service.run(&context).await;

        started.store(false, Ordering::Release);
        stopped.store(true, Ordering::Release);
        has_notified.store(false, Ordering::Release);
        *state.write().await = ServiceLifecycle::Stopped;
        info!("Service thread {} has stopped", service_name);
    }

    /// Stops the service and waits (up to the service's join time) for it to finish.
    ///
    /// Equivalent to `shutdown_with_interrupt(false)`.
    pub async fn shutdown(&self) -> RocketMQResult<()> {
        self.shutdown_with_interrupt(false).await
    }

    /// Requests the service to stop and, unless it is a daemon, waits for its task.
    ///
    /// With `interrupt == true` the task is aborted instead of joined. Otherwise the join is
    /// bounded by [`ServiceTask::get_join_time`]; on timeout the task is left running and a
    /// warning is logged. If the service is not started, this only logs a warning.
    ///
    /// # Errors
    /// Currently always returns `Ok(())`; join timeouts and task panics are logged instead.
    pub async fn shutdown_with_interrupt(&self, interrupt: bool) -> RocketMQResult<()> {
        let service_name = self.service.get_service_name();
        info!(
            "Try to shutdown service thread: {} started: {} current_state: {:?}",
            service_name,
            self.started.load(Ordering::Acquire),
            self.get_lifecycle_state().await
        );
        if self
            .started
            .compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire)
            .is_err()
        {
            warn!("Service thread {} is not running", service_name);
            return Ok(());
        }
        *self.state.write().await = ServiceLifecycle::Stopping;
        self.stopped.store(true, Ordering::Release);
        info!("Shutdown thread[{}] interrupt={}", service_name, interrupt);
        self.wakeup();

        let begin_time = Instant::now();
        let join_time = self.service.get_join_time();
        if !self.is_daemon() {
            // Take the handle in its own statement so the write guard is released before the
            // potentially long join below.
            let handle = self.task_handle.write().await.take();
            if let Some(handle) = handle {
                if interrupt {
                    handle.abort();
                } else {
                    match timeout(join_time, handle).await {
                        Ok(Ok(())) => {}
                        Ok(Err(join_error)) => {
                            warn!(
                                "Service thread {} terminated abnormally: {}",
                                service_name, join_error
                            );
                        }
                        Err(_) => {
                            warn!("Service thread {} shutdown timeout", service_name);
                        }
                    }
                }
            }
        }
        let elapsed_time = begin_time.elapsed();
        info!(
            "Join thread[{}], elapsed time: {}ms, join time: {}ms",
            service_name,
            elapsed_time.as_millis(),
            join_time.as_millis()
        );
        *self.state.write().await = ServiceLifecycle::Stopped;
        Ok(())
    }

    /// Sets the stop flag without waking or joining the task; no-op if not started.
    pub fn make_stop(&self) {
        if !self.started.load(Ordering::Acquire) {
            return;
        }
        self.stopped.store(true, Ordering::Release);
        info!("Make stop thread[{}]", self.service.get_service_name());
    }

    /// Requests that the current (or next) wait return early.
    ///
    /// Repeated calls before the wakeup is consumed are coalesced into one notification.
    pub fn wakeup(&self) {
        if self
            .has_notified
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_ok()
        {
            self.wait_point.notify_one();
        }
    }

    /// Waits for up to `interval` (or until [`ServiceManager::wakeup`]), then awaits the
    /// service's `on_wait_end` hook.
    ///
    /// A wakeup requested while nobody was waiting makes this return immediately.
    pub async fn wait_for_running(&self, interval: Duration) {
        if self
            .has_notified
            .compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire)
            .is_ok()
        {
            // Drain the permit that `wakeup`'s `notify_one` stored, so this one wakeup does not
            // also end the next wait early (Java's ServiceThread uses `waitPoint.reset()`).
            // tokio's `Timeout` polls the inner future before its deadline, so a stored permit
            // is taken even with a zero timeout.
            let _ = timeout(Duration::ZERO, self.wait_point.notified()).await;
            self.service.on_wait_end().await;
            return;
        }
        // Notification and timeout are handled the same way, so the outcome is ignored.
        let _ = timeout(interval, self.wait_point.notified()).await;
        self.has_notified.store(false, Ordering::Release);
        self.service.on_wait_end().await;
    }

    /// Returns `true` once a stop was requested or the service's `run` returned.
    pub fn is_stopped(&self) -> bool {
        self.stopped.load(Ordering::Acquire)
    }

    /// Returns `true` if shutdown should skip joining the task.
    pub fn is_daemon(&self) -> bool {
        self.is_daemon.load(Ordering::Acquire)
    }

    /// Marks the service as a daemon: shutdown will neither join nor abort its task.
    pub fn set_daemon(&self, daemon: bool) {
        self.is_daemon.store(daemon, Ordering::Release);
    }

    /// Returns the current lifecycle phase.
    pub async fn get_lifecycle_state(&self) -> ServiceLifecycle {
        *self.state.read().await
    }

    /// Returns `true` between a successful `start` and shutdown / the task finishing.
    pub fn is_started(&self) -> bool {
        self.started.load(Ordering::Acquire)
    }
}
// Tests for ServiceManager lifecycle handling. They drive real tokio tasks and rely on
// wall-clock sleeps, so the exact number of loop iterations observed can vary between runs.
#[cfg(test)]
mod tests {
    use tokio::time::sleep;
    use tokio::time::Duration;

    use super::*;
    // `service_manager!` is defined elsewhere in the crate. Presumably it generates the
    // `create_service_task()` constructor used below, returning a `ServiceManager` — verify.
    use crate::service_manager;

    /// Example service whose `run` loop waits `check_interval` per iteration until stopped,
    /// and whose `on_wait_end` simulates a ~100ms check.
    pub struct ExampleTransactionCheckService {
        check_interval: Duration,
        transaction_timeout: Duration,
    }

    impl ExampleTransactionCheckService {
        pub fn new(check_interval: Duration, transaction_timeout: Duration) -> Self {
            Self {
                check_interval,
                transaction_timeout,
            }
        }
    }

    impl ServiceTask for ExampleTransactionCheckService {
        fn get_service_name(&self) -> String {
            "ExampleTransactionCheckService".to_string()
        }

        async fn run(&self, context: &ServiceContext) {
            info!("Start transaction check service thread!");
            // NOTE: `ServiceContext::wait_for_running` does not call `on_wait_end`, so this
            // loop only waits; the check in `on_wait_end` is not triggered from here.
            while !context.is_stopped() {
                context.wait_for_running(self.check_interval).await;
            }
            info!("End transaction check service thread!");
        }

        async fn on_wait_end(&self) {
            let begin = Instant::now();
            info!("Begin to check prepare message, begin time: {:?}", begin);
            self.perform_transaction_check().await;
            let elapsed = begin.elapsed();
            info!("End to check prepare message, consumed time: {}ms", elapsed.as_millis());
        }
    }

    impl ExampleTransactionCheckService {
        // Stand-in for real work: sleeps 100ms and logs the configured timeout.
        async fn perform_transaction_check(&self) {
            sleep(Duration::from_millis(100)).await;
            info!(
                "Transaction check completed with timeout: {:?}",
                self.transaction_timeout
            );
        }
    }

    impl Clone for ExampleTransactionCheckService {
        fn clone(&self) -> Self {
            Self {
                check_interval: self.check_interval,
                transaction_timeout: self.transaction_timeout,
            }
        }
    }

    service_manager!(ExampleTransactionCheckService);

    /// Test service whose `run` exits by itself after at most 5 iterations of 100ms waits.
    #[derive(Clone)]
    struct TestService {
        name: String,
        work_duration: Duration,
    }

    impl TestService {
        fn new(name: String, work_duration: Duration) -> Self {
            Self { name, work_duration }
        }
    }

    impl ServiceTask for TestService {
        fn get_service_name(&self) -> String {
            self.name.clone()
        }

        async fn run(&self, context: &ServiceContext) {
            println!("TestService {} starting {}", self.name, context.is_stopped());
            let mut counter = 0;
            // Bounded loop: stops on request or after 5 iterations, whichever comes first.
            while !context.is_stopped() && counter < 5 {
                context.wait_for_running(Duration::from_millis(100)).await;
                println!("TestService {} running iteration {}", self.name, counter);
                counter += 1;
            }
            println!("TestService {} finished after {} iterations", self.name, counter);
        }

        // Not reached via `run`, because `ServiceContext::wait_for_running` does not call it.
        async fn on_wait_end(&self) {
            println!("TestService {} performing work", self.name);
            sleep(self.work_duration).await;
            println!("TestService {} work completed", self.name);
        }
    }

    service_manager!(TestService);

    // Full start -> wakeup -> shutdown cycle. `run` may already have exited on its own
    // (5 x 100ms) before `shutdown` is called; the final assertions hold in either case,
    // because the task's own exit path also sets Stopped / not-started / stopped.
    #[tokio::test]
    async fn test_service_lifecycle() {
        let service = TestService::new("test-service".to_string(), Duration::from_millis(50));
        let service_thread = service.create_service_task();
        assert_eq!(service_thread.get_lifecycle_state().await, ServiceLifecycle::NotStarted);
        assert!(!service_thread.is_started());
        assert!(!service_thread.is_stopped());
        service_thread.start().await.unwrap();
        assert_eq!(service_thread.get_lifecycle_state().await, ServiceLifecycle::Running);
        assert!(service_thread.is_started());
        assert!(!service_thread.is_stopped());
        sleep(Duration::from_millis(300)).await;
        service_thread.wakeup();
        sleep(Duration::from_millis(100)).await;
        service_thread.shutdown().await.unwrap();
        assert_eq!(service_thread.get_lifecycle_state().await, ServiceLifecycle::Stopped);
        assert!(!service_thread.is_started());
        assert!(service_thread.is_stopped());
    }

    // Daemon services are not joined on shutdown; this only checks nothing fails.
    #[tokio::test]
    async fn test_daemon_service() {
        let service = TestService::new("daemon-service".to_string(), Duration::from_millis(10));
        let service_thread = service.create_service_task();
        service_thread.set_daemon(true);
        assert!(service_thread.is_daemon());
        service_thread.start().await.unwrap();
        sleep(Duration::from_millis(100)).await;
        service_thread.shutdown().await.unwrap();
    }

    // A second `start` on a started service is a no-op that still returns Ok.
    #[tokio::test]
    async fn test_multiple_start_attempts() {
        let service = TestService::new("multi-start-service".to_string(), Duration::from_millis(10));
        let service_thread = service.create_service_task();
        service_thread.start().await.unwrap();
        assert!(service_thread.is_started());
        service_thread.start().await.unwrap();
        assert!(service_thread.is_started());
        service_thread.shutdown().await.unwrap();
        assert!(!service_thread.is_started());
    }

    // `make_stop` only sets the stop flag; the task is neither woken nor joined here.
    #[tokio::test]
    async fn test_make_stop() {
        let service = TestService::new("stop-service".to_string(), Duration::from_millis(10));
        let service_thread = service.create_service_task();
        service_thread.start().await.unwrap();
        sleep(Duration::from_millis(50)).await;
        service_thread.make_stop();
        assert!(service_thread.is_stopped());
        sleep(Duration::from_millis(100)).await;
    }

    // Smoke test: the example service runs, is woken once, and shuts down without error.
    #[tokio::test]
    async fn test_example_transaction_service() {
        let service = ExampleTransactionCheckService::new(Duration::from_millis(100), Duration::from_millis(1000));
        let service_thread = service.create_service_task();
        service_thread.start().await.unwrap();
        sleep(Duration::from_millis(350)).await;
        service_thread.wakeup();
        sleep(Duration::from_millis(150)).await;
        service_thread.shutdown().await.unwrap();
    }
}