1use std::sync::atomic::AtomicBool;
18use std::sync::atomic::Ordering;
19use std::sync::Arc;
20use std::time::Duration;
21use std::time::Instant;
22
23use rocketmq_error::RocketMQResult;
24use tokio::sync::Notify;
25use tokio::sync::RwLock;
26use tokio::time::timeout;
27use tracing::info;
28use tracing::warn;
29
/// Handle passed to [`ServiceTask::run`] that lets the run loop check for a stop
/// request and park between iterations.
///
/// All fields are shared (`Arc`) with the owning [`ServiceManager`], which created
/// them and hands clones to this context when it spawns the run task.
pub struct ServiceContext {
    /// Notified by `wakeup` to end a pending `wait_for_running` early.
    wait_point: Arc<Notify>,
    /// Set by `wakeup` and cleared when a wait consumes the wake-up.
    has_notified: Arc<AtomicBool>,
    /// Stop-request flag shared with the owning `ServiceManager`.
    stopped: Arc<AtomicBool>,
}
40
41impl ServiceContext {
42 pub fn new(
43 wait_point: Arc<Notify>,
44 has_notified: Arc<AtomicBool>,
45 stopped: Arc<AtomicBool>,
46 ) -> Self {
47 Self {
48 wait_point,
49 has_notified,
50 stopped,
51 }
52 }
53
54 pub fn is_stopped(&self) -> bool {
56 self.stopped.load(Ordering::Acquire)
57 }
58
59 pub async fn wait_for_running(&self, interval: Duration) -> bool {
61 if self
63 .has_notified
64 .compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire)
65 .is_ok()
66 {
67 return true; }
69
70 match timeout(interval, self.wait_point.notified()).await {
72 Ok(_) => {
73 }
75 Err(_) => {
76 }
78 }
79 self.has_notified.store(false, Ordering::Release);
81 true }
83
84 pub fn wakeup(&self) {
85 if self
86 .has_notified
87 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
88 .is_ok()
89 {
90 self.wait_point.notify_one();
91 }
92 }
93}
94
95pub trait ServiceTask: Sync + Send {
113 fn get_service_name(&self) -> String;
115
116 fn run(&self, context: &ServiceContext) -> impl ::core::future::Future<Output = ()> + Send;
118
119 fn on_wait_end(&self) -> impl ::core::future::Future<Output = ()> + Send {
121 async {
122 }
124 }
125
126 fn get_join_time(&self) -> Duration {
128 Duration::from_millis(90_000)
129 }
130}
131
/// Owns a [`ServiceTask`] and drives its lifecycle: spawning its run loop on Tokio,
/// waking it between iterations, and stopping/joining it.
pub struct ServiceManager<T: ServiceTask + 'static> {
    /// The managed service, shared with the spawned run task.
    service: Arc<T>,

    /// Current lifecycle state; written both by the manager and by the run task.
    state: Arc<RwLock<ServiceLifecycle>>,

    /// Stop-request flag, exposed to the run loop via `ServiceContext::is_stopped`.
    stopped: Arc<AtomicBool>,

    /// `true` from a successful `start` until the run task ends or `shutdown` runs.
    started: Arc<AtomicBool>,

    /// Set by `wakeup` and cleared when a wait consumes the wake-up.
    has_notified: Arc<AtomicBool>,

    /// Used to park and wake the run loop between iterations.
    wait_point: Arc<Notify>,

    /// Handle of the spawned run task, if one was started and not yet taken.
    task_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,

    /// Daemon services are not joined during shutdown.
    is_daemon: AtomicBool,
}
158
159impl<T: ServiceTask> AsRef<T> for ServiceManager<T> {
160 fn as_ref(&self) -> &T {
161 &self.service
162 }
163}
164
/// Lifecycle of a service managed by [`ServiceManager`].
///
/// A normal run goes `NotStarted` → `Starting` → `Running` → `Stopping` → `Stopped`.
/// A service whose run loop returns on its own goes from `Running` to `Stopped`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ServiceLifecycle {
    /// Created but never started.
    NotStarted,
    /// `start` was called and the run task is being spawned.
    Starting,
    /// The run task is active.
    Running,
    /// A shutdown was requested and the task is being stopped/joined.
    Stopping,
    /// The run task finished or shutdown completed.
    Stopped,
}
174
175impl<T: ServiceTask + 'static> ServiceManager<T> {
176 pub fn new(service: T) -> Self {
178 Self {
179 service: Arc::new(service),
180 state: Arc::new(RwLock::new(ServiceLifecycle::NotStarted)),
181 stopped: Arc::new(AtomicBool::new(false)),
182 started: Arc::new(AtomicBool::new(false)),
183 has_notified: Arc::new(AtomicBool::new(false)),
184 wait_point: Arc::new(Notify::new()),
185 task_handle: Arc::new(RwLock::new(None)),
186 is_daemon: AtomicBool::new(false),
187 }
188 }
189
190 pub fn new_arc(service: Arc<T>) -> Self {
191 Self {
192 service,
193 state: Arc::new(RwLock::new(ServiceLifecycle::NotStarted)),
194 stopped: Arc::new(AtomicBool::new(false)),
195 started: Arc::new(AtomicBool::new(false)),
196 has_notified: Arc::new(AtomicBool::new(false)),
197 wait_point: Arc::new(Notify::new()),
198 task_handle: Arc::new(RwLock::new(None)),
199 is_daemon: AtomicBool::new(false),
200 }
201 }
202
203 pub async fn start(&self) -> RocketMQResult<()> {
205 let service_name = self.service.get_service_name();
206
207 info!(
208 "Try to start service thread: {} started: {} current_state: {:?}",
209 service_name,
210 self.started.load(Ordering::Acquire),
211 self.get_lifecycle_state().await
212 );
213
214 if self
216 .started
217 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
218 .is_err()
219 {
220 warn!("Service thread {} is already started", service_name);
221 return Ok(());
222 }
223
224 {
226 let mut state = self.state.write().await;
227 *state = ServiceLifecycle::Starting;
228 }
229
230 self.stopped.store(false, Ordering::Release);
232
233 let service = self.service.clone();
235 let state = self.state.clone();
236 let stopped = self.stopped.clone();
237 let started = self.started.clone();
238 let has_notified = self.has_notified.clone();
239 let wait_point = self.wait_point.clone();
240 let task_handle = self.task_handle.clone();
241
242 let handle = tokio::spawn(async move {
244 Self::run_internal(service, state, stopped, started, has_notified, wait_point).await;
245 });
246
247 {
249 let mut handle_guard = task_handle.write().await;
250 *handle_guard = Some(handle);
251 }
252
253 {
255 let mut state = self.state.write().await;
256 *state = ServiceLifecycle::Running;
257 }
258
259 info!(
260 "Started service thread: {} started: {}",
261 service_name,
262 self.started.load(Ordering::Acquire)
263 );
264
265 Ok(())
266 }
267
268 async fn run_internal(
270 service: Arc<T>,
271 state: Arc<RwLock<ServiceLifecycle>>,
272 stopped: Arc<AtomicBool>,
273 started: Arc<AtomicBool>,
274 has_notified: Arc<AtomicBool>,
275 wait_point: Arc<Notify>,
276 ) {
277 let service_name = service.get_service_name();
278 info!("Service thread {} is running", service_name);
279
280 {
282 let mut state_guard = state.write().await;
283 *state_guard = ServiceLifecycle::Running;
284 }
285 let context =
287 ServiceContext::new(wait_point.clone(), has_notified.clone(), stopped.clone());
288 service.run(&context).await;
290
291 started.store(false, Ordering::Release);
293 stopped.store(true, Ordering::Release);
294 has_notified.store(false, Ordering::Release);
295
296 {
297 let mut state_guard = state.write().await;
298 *state_guard = ServiceLifecycle::Stopped;
299 }
300
301 info!("Service thread {} has stopped", service_name);
302 }
303
304 pub async fn shutdown(&self) -> RocketMQResult<()> {
306 self.shutdown_with_interrupt(false).await
307 }
308
309 pub async fn shutdown_with_interrupt(&self, interrupt: bool) -> RocketMQResult<()> {
311 let service_name = self.service.get_service_name();
312
313 info!(
314 "Try to shutdown service thread: {} started: {} current_state: {:?}",
315 service_name,
316 self.started.load(Ordering::Acquire),
317 self.get_lifecycle_state().await
318 );
319
320 if self
322 .started
323 .compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire)
324 .is_err()
325 {
326 warn!("Service thread {} is not running", service_name);
327 return Ok(());
328 }
329
330 {
332 let mut state = self.state.write().await;
333 *state = ServiceLifecycle::Stopping;
334 }
335
336 self.stopped.store(true, Ordering::Release);
338
339 info!("Shutdown thread[{}] interrupt={}", service_name, interrupt);
340
341 self.wakeup();
343
344 let begin_time = Instant::now();
345
346 let join_time = self.service.get_join_time();
348 let result = if !self.is_daemon() {
349 let mut handle_guard = self.task_handle.write().await;
350 if let Some(handle) = handle_guard.take() {
351 if interrupt {
352 handle.abort();
353 Ok(())
354 } else {
355 match timeout(join_time, handle).await {
356 Ok(_) => Ok(()),
357 Err(_) => {
358 warn!("Service thread {} shutdown timeout", service_name);
359 Ok(())
360 }
361 }
362 }
363 } else {
364 Ok(())
365 }
366 } else {
367 Ok(())
368 };
369
370 let elapsed_time = begin_time.elapsed();
371 info!(
372 "Join thread[{}], elapsed time: {}ms, join time: {}ms",
373 service_name,
374 elapsed_time.as_millis(),
375 join_time.as_millis()
376 );
377
378 {
380 let mut state = self.state.write().await;
381 *state = ServiceLifecycle::Stopped;
382 }
383
384 result
385 }
386
387 pub fn make_stop(&self) {
389 if !self.started.load(Ordering::Acquire) {
390 return;
391 }
392
393 self.stopped.store(true, Ordering::Release);
394 info!("Make stop thread[{}]", self.service.get_service_name());
395 }
396
397 pub fn wakeup(&self) {
399 if self
400 .has_notified
401 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
402 .is_ok()
403 {
404 self.wait_point.notify_one();
405 }
406 }
407
408 pub async fn wait_for_running(&self, interval: Duration) {
410 if self
412 .has_notified
413 .compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire)
414 .is_ok()
415 {
416 self.service.on_wait_end().await;
417 return;
418 }
419
420 let wait_result = timeout(interval, self.wait_point.notified()).await;
422
423 self.has_notified.store(false, Ordering::Release);
425
426 self.service.on_wait_end().await;
428
429 if wait_result.is_err() {
430 }
432 }
433
434 pub fn is_stopped(&self) -> bool {
436 self.stopped.load(Ordering::Acquire)
437 }
438
439 pub fn is_daemon(&self) -> bool {
441 self.is_daemon.load(Ordering::Acquire)
442 }
443
444 pub fn set_daemon(&self, daemon: bool) {
446 self.is_daemon.store(daemon, Ordering::Release);
447 }
448
449 pub async fn get_lifecycle_state(&self) -> ServiceLifecycle {
451 *self.state.read().await
452 }
453
454 pub fn is_started(&self) -> bool {
456 self.started.load(Ordering::Acquire)
457 }
458}
459
#[cfg(test)]
mod tests {
    use tokio::time::sleep;
    use tokio::time::Duration;

    use super::*;
    use crate::service_manager;

    /// Example periodic service modelled on a transaction-check loop: every time its
    /// wait ends (interval elapsed or `wakeup`), it performs one check.
    #[derive(Clone)]
    pub struct ExampleTransactionCheckService {
        check_interval: Duration,
        transaction_timeout: Duration,
    }

    impl ExampleTransactionCheckService {
        pub fn new(check_interval: Duration, transaction_timeout: Duration) -> Self {
            Self {
                check_interval,
                transaction_timeout,
            }
        }

        /// Simulated check work.
        async fn perform_transaction_check(&self) {
            sleep(Duration::from_millis(100)).await;
            info!(
                "Transaction check completed with timeout: {:?}",
                self.transaction_timeout
            );
        }
    }

    impl ServiceTask for ExampleTransactionCheckService {
        fn get_service_name(&self) -> String {
            "ExampleTransactionCheckService".to_string()
        }

        async fn run(&self, context: &ServiceContext) {
            info!("Start transaction check service thread!");

            while !context.is_stopped() {
                context.wait_for_running(self.check_interval).await;
                // `ServiceContext::wait_for_running` does not call `on_wait_end`, so the
                // check has to be triggered here or it would never run.
                self.on_wait_end().await;
            }

            info!("End transaction check service thread!");
        }

        async fn on_wait_end(&self) {
            let begin = Instant::now();
            info!("Begin to check prepare message, begin time: {:?}", begin);

            self.perform_transaction_check().await;

            let elapsed = begin.elapsed();
            info!(
                "End to check prepare message, consumed time: {}ms",
                elapsed.as_millis()
            );
        }
    }

    service_manager!(ExampleTransactionCheckService);

    /// Service that runs at most five 100 ms wait iterations.
    #[derive(Clone)]
    struct TestService {
        name: String,
        work_duration: Duration,
    }

    impl TestService {
        fn new(name: String, work_duration: Duration) -> Self {
            Self {
                name,
                work_duration,
            }
        }
    }

    impl ServiceTask for TestService {
        fn get_service_name(&self) -> String {
            self.name.clone()
        }

        async fn run(&self, context: &ServiceContext) {
            println!(
                "TestService {} starting {}",
                self.name,
                context.is_stopped()
            );

            let mut counter = 0;

            while !context.is_stopped() && counter < 5 {
                context.wait_for_running(Duration::from_millis(100)).await;
                println!("TestService {} running iteration {}", self.name, counter);
                counter += 1;
            }

            println!(
                "TestService {} finished after {} iterations",
                self.name, counter
            );
        }

        async fn on_wait_end(&self) {
            println!("TestService {} performing work", self.name);
            sleep(self.work_duration).await;
            println!("TestService {} work completed", self.name);
        }
    }

    service_manager!(TestService);

    #[tokio::test]
    async fn test_service_lifecycle() {
        let service = TestService::new("test-service".to_string(), Duration::from_millis(50));
        let service_thread = service.create_service_task();

        assert_eq!(
            service_thread.get_lifecycle_state().await,
            ServiceLifecycle::NotStarted
        );
        assert!(!service_thread.is_started());
        assert!(!service_thread.is_stopped());

        service_thread.start().await.unwrap();
        assert_eq!(
            service_thread.get_lifecycle_state().await,
            ServiceLifecycle::Running
        );
        assert!(service_thread.is_started());
        assert!(!service_thread.is_stopped());

        sleep(Duration::from_millis(300)).await;

        service_thread.wakeup();
        sleep(Duration::from_millis(100)).await;

        service_thread.shutdown().await.unwrap();
        assert_eq!(
            service_thread.get_lifecycle_state().await,
            ServiceLifecycle::Stopped
        );
        assert!(!service_thread.is_started());
        assert!(service_thread.is_stopped());
    }

    #[tokio::test]
    async fn test_daemon_service() {
        let service = TestService::new("daemon-service".to_string(), Duration::from_millis(10));
        let service_thread = service.create_service_task();

        service_thread.set_daemon(true);
        assert!(service_thread.is_daemon());

        service_thread.start().await.unwrap();
        sleep(Duration::from_millis(100)).await;
        service_thread.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn test_multiple_start_attempts() {
        let service =
            TestService::new("multi-start-service".to_string(), Duration::from_millis(10));
        let service_thread = service.create_service_task();

        service_thread.start().await.unwrap();
        assert!(service_thread.is_started());

        // A second start is a no-op while the service is running.
        service_thread.start().await.unwrap();
        assert!(service_thread.is_started());

        service_thread.shutdown().await.unwrap();
        assert!(!service_thread.is_started());
    }

    #[tokio::test]
    async fn test_make_stop() {
        let service = TestService::new("stop-service".to_string(), Duration::from_millis(10));
        let service_thread = service.create_service_task();

        service_thread.start().await.unwrap();
        sleep(Duration::from_millis(50)).await;

        service_thread.make_stop();
        assert!(service_thread.is_stopped());

        // `make_stop` only raises the flag; wake the loop so it observes it promptly.
        service_thread.wakeup();
        sleep(Duration::from_millis(100)).await;

        // The run loop exited on its own, so the task itself reset the lifecycle.
        assert!(!service_thread.is_started());
        assert_eq!(
            service_thread.get_lifecycle_state().await,
            ServiceLifecycle::Stopped
        );
    }

    #[tokio::test]
    async fn test_example_transaction_service() {
        let service = ExampleTransactionCheckService::new(
            Duration::from_millis(100),
            Duration::from_millis(1000),
        );
        let service_thread = service.create_service_task();

        service_thread.start().await.unwrap();
        sleep(Duration::from_millis(350)).await;
        service_thread.wakeup();
        sleep(Duration::from_millis(150)).await;
        service_thread.shutdown().await.unwrap();
    }
}