1use std::sync::atomic::AtomicBool;
16use std::sync::atomic::Ordering;
17use std::sync::Arc;
18use std::time::Duration;
19use std::time::Instant;
20
21use rocketmq_error::RocketMQResult;
22use tokio::sync::Notify;
23use tokio::sync::RwLock;
24use tokio::time::timeout;
25use tracing::info;
26use tracing::warn;
27
28pub struct ServiceContext {
31 wait_point: Arc<Notify>,
33 has_notified: Arc<AtomicBool>,
35 stopped: Arc<AtomicBool>,
37}
38
39impl ServiceContext {
40 pub fn new(wait_point: Arc<Notify>, has_notified: Arc<AtomicBool>, stopped: Arc<AtomicBool>) -> Self {
41 Self {
42 wait_point,
43 has_notified,
44 stopped,
45 }
46 }
47
48 pub fn is_stopped(&self) -> bool {
50 self.stopped.load(Ordering::Acquire)
51 }
52
53 pub async fn wait_for_running(&self, interval: Duration) -> bool {
55 if self
57 .has_notified
58 .compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire)
59 .is_ok()
60 {
61 return true; }
63
64 match timeout(interval, self.wait_point.notified()).await {
66 Ok(_) => {
67 }
69 Err(_) => {
70 }
72 }
73 self.has_notified.store(false, Ordering::Release);
75 true }
77
78 pub fn wakeup(&self) {
79 if self
80 .has_notified
81 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
82 .is_ok()
83 {
84 self.wait_point.notify_one();
85 }
86 }
87}
88
89pub trait ServiceTask: Sync + Send {
90 fn get_service_name(&self) -> String;
92
93 fn run(&self, context: &ServiceContext) -> impl ::core::future::Future<Output = ()> + Send;
95
96 fn on_wait_end(&self) -> impl ::core::future::Future<Output = ()> + Send {
98 async {
99 }
101 }
102
103 fn get_join_time(&self) -> Duration {
105 Duration::from_millis(90_000)
106 }
107}
108
109pub struct ServiceManager<T: ServiceTask + 'static> {
111 service: Arc<T>,
113
114 state: Arc<RwLock<ServiceLifecycle>>,
116
117 stopped: Arc<AtomicBool>,
119
120 started: Arc<AtomicBool>,
122
123 has_notified: Arc<AtomicBool>,
125
126 wait_point: Arc<Notify>,
128
129 task_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
131
132 is_daemon: AtomicBool,
134}
135
136impl<T: ServiceTask> AsRef<T> for ServiceManager<T> {
137 fn as_ref(&self) -> &T {
138 &self.service
139 }
140}
141
/// Lifecycle phases a managed service moves through.
///
/// Equality on this enum is total, so `Eq` and `Hash` are derived alongside
/// `PartialEq`; the default value is [`ServiceLifecycle::NotStarted`], the state
/// a freshly constructed manager starts in.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ServiceLifecycle {
    /// The service has been created but `start` has not taken effect yet.
    #[default]
    NotStarted,
    /// `start` has been accepted and the task is being spawned.
    Starting,
    /// The service task is running.
    Running,
    /// A shutdown has been requested and is in progress.
    Stopping,
    /// The service task has finished or a shutdown has completed.
    Stopped,
}
151
152impl<T: ServiceTask + 'static> ServiceManager<T> {
153 pub fn new(service: T) -> Self {
155 Self {
156 service: Arc::new(service),
157 state: Arc::new(RwLock::new(ServiceLifecycle::NotStarted)),
158 stopped: Arc::new(AtomicBool::new(false)),
159 started: Arc::new(AtomicBool::new(false)),
160 has_notified: Arc::new(AtomicBool::new(false)),
161 wait_point: Arc::new(Notify::new()),
162 task_handle: Arc::new(RwLock::new(None)),
163 is_daemon: AtomicBool::new(false),
164 }
165 }
166
167 pub fn new_arc(service: Arc<T>) -> Self {
168 Self {
169 service,
170 state: Arc::new(RwLock::new(ServiceLifecycle::NotStarted)),
171 stopped: Arc::new(AtomicBool::new(false)),
172 started: Arc::new(AtomicBool::new(false)),
173 has_notified: Arc::new(AtomicBool::new(false)),
174 wait_point: Arc::new(Notify::new()),
175 task_handle: Arc::new(RwLock::new(None)),
176 is_daemon: AtomicBool::new(false),
177 }
178 }
179
180 pub async fn start(&self) -> RocketMQResult<()> {
182 let service_name = self.service.get_service_name();
183
184 info!(
185 "Try to start service thread: {} started: {} current_state: {:?}",
186 service_name,
187 self.started.load(Ordering::Acquire),
188 self.get_lifecycle_state().await
189 );
190
191 if self
193 .started
194 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
195 .is_err()
196 {
197 warn!("Service thread {} is already started", service_name);
198 return Ok(());
199 }
200
201 {
203 let mut state = self.state.write().await;
204 *state = ServiceLifecycle::Starting;
205 }
206
207 self.stopped.store(false, Ordering::Release);
209
210 let service = self.service.clone();
212 let state = self.state.clone();
213 let stopped = self.stopped.clone();
214 let started = self.started.clone();
215 let has_notified = self.has_notified.clone();
216 let wait_point = self.wait_point.clone();
217 let task_handle = self.task_handle.clone();
218
219 let handle = tokio::spawn(async move {
221 Self::run_internal(service, state, stopped, started, has_notified, wait_point).await;
222 });
223
224 {
226 let mut handle_guard = task_handle.write().await;
227 *handle_guard = Some(handle);
228 }
229
230 {
232 let mut state = self.state.write().await;
233 *state = ServiceLifecycle::Running;
234 }
235
236 info!(
237 "Started service thread: {} started: {}",
238 service_name,
239 self.started.load(Ordering::Acquire)
240 );
241
242 Ok(())
243 }
244
245 async fn run_internal(
247 service: Arc<T>,
248 state: Arc<RwLock<ServiceLifecycle>>,
249 stopped: Arc<AtomicBool>,
250 started: Arc<AtomicBool>,
251 has_notified: Arc<AtomicBool>,
252 wait_point: Arc<Notify>,
253 ) {
254 let service_name = service.get_service_name();
255 info!("Service thread {} is running", service_name);
256
257 {
259 let mut state_guard = state.write().await;
260 *state_guard = ServiceLifecycle::Running;
261 }
262 let context = ServiceContext::new(wait_point.clone(), has_notified.clone(), stopped.clone());
264 service.run(&context).await;
266
267 started.store(false, Ordering::Release);
269 stopped.store(true, Ordering::Release);
270 has_notified.store(false, Ordering::Release);
271
272 {
273 let mut state_guard = state.write().await;
274 *state_guard = ServiceLifecycle::Stopped;
275 }
276
277 info!("Service thread {} has stopped", service_name);
278 }
279
280 pub async fn shutdown(&self) -> RocketMQResult<()> {
282 self.shutdown_with_interrupt(false).await
283 }
284
285 pub async fn shutdown_with_interrupt(&self, interrupt: bool) -> RocketMQResult<()> {
287 let service_name = self.service.get_service_name();
288
289 info!(
290 "Try to shutdown service thread: {} started: {} current_state: {:?}",
291 service_name,
292 self.started.load(Ordering::Acquire),
293 self.get_lifecycle_state().await
294 );
295
296 if self
298 .started
299 .compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire)
300 .is_err()
301 {
302 warn!("Service thread {} is not running", service_name);
303 return Ok(());
304 }
305
306 {
308 let mut state = self.state.write().await;
309 *state = ServiceLifecycle::Stopping;
310 }
311
312 self.stopped.store(true, Ordering::Release);
314
315 info!("Shutdown thread[{}] interrupt={}", service_name, interrupt);
316
317 self.wakeup();
319
320 let begin_time = Instant::now();
321
322 let join_time = self.service.get_join_time();
324 let result = if !self.is_daemon() {
325 let mut handle_guard = self.task_handle.write().await;
326 if let Some(handle) = handle_guard.take() {
327 if interrupt {
328 handle.abort();
329 Ok(())
330 } else {
331 match timeout(join_time, handle).await {
332 Ok(_) => Ok(()),
333 Err(_) => {
334 warn!("Service thread {} shutdown timeout", service_name);
335 Ok(())
336 }
337 }
338 }
339 } else {
340 Ok(())
341 }
342 } else {
343 Ok(())
344 };
345
346 let elapsed_time = begin_time.elapsed();
347 info!(
348 "Join thread[{}], elapsed time: {}ms, join time: {}ms",
349 service_name,
350 elapsed_time.as_millis(),
351 join_time.as_millis()
352 );
353
354 {
356 let mut state = self.state.write().await;
357 *state = ServiceLifecycle::Stopped;
358 }
359
360 result
361 }
362
363 pub fn make_stop(&self) {
365 if !self.started.load(Ordering::Acquire) {
366 return;
367 }
368
369 self.stopped.store(true, Ordering::Release);
370 info!("Make stop thread[{}]", self.service.get_service_name());
371 }
372
373 pub fn wakeup(&self) {
375 if self
376 .has_notified
377 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
378 .is_ok()
379 {
380 self.wait_point.notify_one();
381 }
382 }
383
384 pub async fn wait_for_running(&self, interval: Duration) {
386 if self
388 .has_notified
389 .compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire)
390 .is_ok()
391 {
392 self.service.on_wait_end().await;
393 return;
394 }
395
396 let wait_result = timeout(interval, self.wait_point.notified()).await;
398
399 self.has_notified.store(false, Ordering::Release);
401
402 self.service.on_wait_end().await;
404
405 if wait_result.is_err() {
406 }
408 }
409
410 pub fn is_stopped(&self) -> bool {
412 self.stopped.load(Ordering::Acquire)
413 }
414
415 pub fn is_daemon(&self) -> bool {
417 self.is_daemon.load(Ordering::Acquire)
418 }
419
420 pub fn set_daemon(&self, daemon: bool) {
422 self.is_daemon.store(daemon, Ordering::Release);
423 }
424
425 pub async fn get_lifecycle_state(&self) -> ServiceLifecycle {
427 *self.state.read().await
428 }
429
430 pub fn is_started(&self) -> bool {
432 self.started.load(Ordering::Acquire)
433 }
434}
435
#[cfg(test)]
mod tests {
    use tokio::time::sleep;
    use tokio::time::Duration;

    use super::*;
    use crate::service_manager;

    /// Sample periodic service: loops on the context until it is asked to stop.
    #[derive(Clone)]
    pub struct ExampleTransactionCheckService {
        /// Interval passed to `wait_for_running` on every loop iteration.
        check_interval: Duration,
        /// Only reported in the log line emitted by the simulated check.
        transaction_timeout: Duration,
    }

    impl ExampleTransactionCheckService {
        pub fn new(check_interval: Duration, transaction_timeout: Duration) -> Self {
            Self {
                check_interval,
                transaction_timeout,
            }
        }

        /// Simulates a check by sleeping 100ms and logging the configured timeout.
        async fn perform_transaction_check(&self) {
            sleep(Duration::from_millis(100)).await;
            info!(
                "Transaction check completed with timeout: {:?}",
                self.transaction_timeout
            );
        }
    }

    impl ServiceTask for ExampleTransactionCheckService {
        fn get_service_name(&self) -> String {
            String::from("ExampleTransactionCheckService")
        }

        async fn run(&self, context: &ServiceContext) {
            info!("Start transaction check service thread!");

            loop {
                if context.is_stopped() {
                    break;
                }
                context.wait_for_running(self.check_interval).await;
            }

            info!("End transaction check service thread!");
        }

        async fn on_wait_end(&self) {
            let begin = Instant::now();
            info!("Begin to check prepare message, begin time: {:?}", begin);

            self.perform_transaction_check().await;

            info!(
                "End to check prepare message, consumed time: {}ms",
                begin.elapsed().as_millis()
            );
        }
    }

    service_manager!(ExampleTransactionCheckService);

    /// Minimal service that runs at most five wait iterations.
    #[derive(Clone)]
    struct TestService {
        name: String,
        work_duration: Duration,
    }

    impl TestService {
        fn new(name: String, work_duration: Duration) -> Self {
            Self { name, work_duration }
        }
    }

    impl ServiceTask for TestService {
        fn get_service_name(&self) -> String {
            self.name.clone()
        }

        async fn run(&self, context: &ServiceContext) {
            println!("TestService {} starting {}", self.name, context.is_stopped());

            let mut counter = 0;
            while counter < 5 && !context.is_stopped() {
                context.wait_for_running(Duration::from_millis(100)).await;
                println!("TestService {} running iteration {}", self.name, counter);
                counter += 1;
            }

            println!("TestService {} finished after {} iterations", self.name, counter);
        }

        async fn on_wait_end(&self) {
            println!("TestService {} performing work", self.name);
            sleep(self.work_duration).await;
            println!("TestService {} work completed", self.name);
        }
    }

    service_manager!(TestService);

    /// Walks one service through NotStarted -> Running -> Stopped.
    #[tokio::test]
    async fn test_service_lifecycle() {
        let manager = TestService::new("test-service".to_string(), Duration::from_millis(50)).create_service_task();

        // Freshly created: nothing started, nothing stopped.
        assert_eq!(manager.get_lifecycle_state().await, ServiceLifecycle::NotStarted);
        assert!(!manager.is_started());
        assert!(!manager.is_stopped());

        manager.start().await.unwrap();
        assert_eq!(manager.get_lifecycle_state().await, ServiceLifecycle::Running);
        assert!(manager.is_started());
        assert!(!manager.is_stopped());

        // Let a few iterations run, then poke the service once more.
        sleep(Duration::from_millis(300)).await;
        manager.wakeup();
        sleep(Duration::from_millis(100)).await;

        manager.shutdown().await.unwrap();
        assert_eq!(manager.get_lifecycle_state().await, ServiceLifecycle::Stopped);
        assert!(!manager.is_started());
        assert!(manager.is_stopped());
    }

    /// A daemon service can be started and shut down without errors.
    #[tokio::test]
    async fn test_daemon_service() {
        let manager = TestService::new("daemon-service".to_string(), Duration::from_millis(10)).create_service_task();

        manager.set_daemon(true);
        assert!(manager.is_daemon());

        manager.start().await.unwrap();
        sleep(Duration::from_millis(100)).await;
        manager.shutdown().await.unwrap();
    }

    /// A second `start` on a started service succeeds and leaves it started.
    #[tokio::test]
    async fn test_multiple_start_attempts() {
        let manager =
            TestService::new("multi-start-service".to_string(), Duration::from_millis(10)).create_service_task();

        manager.start().await.unwrap();
        assert!(manager.is_started());

        manager.start().await.unwrap();
        assert!(manager.is_started());

        manager.shutdown().await.unwrap();
        assert!(!manager.is_started());
    }

    /// `make_stop` raises the stop flag on a started service.
    #[tokio::test]
    async fn test_make_stop() {
        let manager = TestService::new("stop-service".to_string(), Duration::from_millis(10)).create_service_task();

        manager.start().await.unwrap();
        sleep(Duration::from_millis(50)).await;

        manager.make_stop();
        assert!(manager.is_stopped());

        // Give the service loop time to notice the flag and exit.
        sleep(Duration::from_millis(100)).await;
    }

    /// Smoke test: start, wake up and shut down the example service.
    #[tokio::test]
    async fn test_example_transaction_service() {
        let manager = ExampleTransactionCheckService::new(Duration::from_millis(100), Duration::from_millis(1000))
            .create_service_task();

        manager.start().await.unwrap();
        sleep(Duration::from_millis(350)).await;
        manager.wakeup();
        sleep(Duration::from_millis(150)).await;
        manager.shutdown().await.unwrap();
    }
}