Skip to main content

rocketmq_rust/task/
service_task.rs

1// Copyright 2023 The RocketMQ Rust Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::atomic::AtomicBool;
16use std::sync::atomic::Ordering;
17use std::sync::Arc;
18use std::time::Duration;
19use std::time::Instant;
20
21use rocketmq_error::RocketMQResult;
22use tokio::sync::Notify;
23use tokio::sync::RwLock;
24use tokio::time::timeout;
25use tracing::info;
26use tracing::warn;
27
28/// Service thread context that gets passed to the service
29/// This contains all the control mechanisms
30pub struct ServiceContext {
31    /// Wait point for notifications
32    wait_point: Arc<Notify>,
33    /// Notification flag
34    has_notified: Arc<AtomicBool>,
35    /// Stop flag
36    stopped: Arc<AtomicBool>,
37}
38
39impl ServiceContext {
40    pub fn new(wait_point: Arc<Notify>, has_notified: Arc<AtomicBool>, stopped: Arc<AtomicBool>) -> Self {
41        Self {
42            wait_point,
43            has_notified,
44            stopped,
45        }
46    }
47
48    /// Check if service is stopped
49    pub fn is_stopped(&self) -> bool {
50        self.stopped.load(Ordering::Acquire)
51    }
52
53    /// Wait for running with interval
54    pub async fn wait_for_running(&self, interval: Duration) -> bool {
55        // Check if already notified
56        if self
57            .has_notified
58            .compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire)
59            .is_ok()
60        {
61            return true; // Should call on_wait_end
62        }
63
64        // Entry to wait
65        match timeout(interval, self.wait_point.notified()).await {
66            Ok(_) => {
67                // Notified
68            }
69            Err(_) => {
70                // Timeout occurred - this is normal behavior
71            }
72        }
73        // Reset notification flag
74        self.has_notified.store(false, Ordering::Release);
75        true // Should call on_wait_end
76    }
77
78    pub fn wakeup(&self) {
79        if self
80            .has_notified
81            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
82            .is_ok()
83        {
84            self.wait_point.notify_one();
85        }
86    }
87}
88
89pub trait ServiceTask: Sync + Send {
90    /// Get the service name
91    fn get_service_name(&self) -> String;
92
93    /// implement the service logic here
94    fn run(&self, context: &ServiceContext) -> impl ::core::future::Future<Output = ()> + Send;
95
96    /// override for custom behavior
97    fn on_wait_end(&self) -> impl ::core::future::Future<Output = ()> + Send {
98        async {
99            // Default implementation does nothing
100        }
101    }
102
103    /// Get join time for shutdown (default 90 seconds)
104    fn get_join_time(&self) -> Duration {
105        Duration::from_millis(90_000)
106    }
107}
108
109/// Service thread implementation with lifecycle management
110pub struct ServiceManager<T: ServiceTask + 'static> {
111    /// The actual service implementation
112    service: Arc<T>,
113
114    /// Thread state management
115    state: Arc<RwLock<ServiceLifecycle>>,
116
117    /// Stop flag
118    stopped: Arc<AtomicBool>,
119
120    /// Started flag for restart capability
121    started: Arc<AtomicBool>,
122
123    /// Notification flag
124    has_notified: Arc<AtomicBool>,
125
126    /// Wait point for notifications
127    wait_point: Arc<Notify>,
128
129    /// Task handle for the running service
130    task_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
131
132    /// Whether this is a daemon service
133    is_daemon: AtomicBool,
134}
135
136impl<T: ServiceTask> AsRef<T> for ServiceManager<T> {
137    fn as_ref(&self) -> &T {
138        &self.service
139    }
140}
141
/// Lifecycle state of a service driven by [`ServiceManager`].
///
/// The manager moves through `NotStarted` -> `Starting` -> `Running` ->
/// (`Stopping` ->) `Stopped`. A stopped service can be started again, which
/// re-enters `Starting`.
///
/// `Eq` and `Hash` are derived so states can be used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ServiceLifecycle {
    /// Created but never started.
    NotStarted,
    /// `start` has claimed the service and is spawning its task.
    Starting,
    /// The service task has been spawned.
    Running,
    /// A shutdown is in progress.
    Stopping,
    /// The service's run finished, or a shutdown sequence completed.
    Stopped,
}
151
152impl<T: ServiceTask + 'static> ServiceManager<T> {
153    /// Create new service thread implementation
154    pub fn new(service: T) -> Self {
155        Self {
156            service: Arc::new(service),
157            state: Arc::new(RwLock::new(ServiceLifecycle::NotStarted)),
158            stopped: Arc::new(AtomicBool::new(false)),
159            started: Arc::new(AtomicBool::new(false)),
160            has_notified: Arc::new(AtomicBool::new(false)),
161            wait_point: Arc::new(Notify::new()),
162            task_handle: Arc::new(RwLock::new(None)),
163            is_daemon: AtomicBool::new(false),
164        }
165    }
166
167    pub fn new_arc(service: Arc<T>) -> Self {
168        Self {
169            service,
170            state: Arc::new(RwLock::new(ServiceLifecycle::NotStarted)),
171            stopped: Arc::new(AtomicBool::new(false)),
172            started: Arc::new(AtomicBool::new(false)),
173            has_notified: Arc::new(AtomicBool::new(false)),
174            wait_point: Arc::new(Notify::new()),
175            task_handle: Arc::new(RwLock::new(None)),
176            is_daemon: AtomicBool::new(false),
177        }
178    }
179
180    /// Start the service thread
181    pub async fn start(&self) -> RocketMQResult<()> {
182        let service_name = self.service.get_service_name();
183
184        info!(
185            "Try to start service thread: {} started: {} current_state: {:?}",
186            service_name,
187            self.started.load(Ordering::Acquire),
188            self.get_lifecycle_state().await
189        );
190
191        // Check if already started
192        if self
193            .started
194            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
195            .is_err()
196        {
197            warn!("Service thread {} is already started", service_name);
198            return Ok(());
199        }
200
201        // Update state
202        {
203            let mut state = self.state.write().await;
204            *state = ServiceLifecycle::Starting;
205        }
206
207        // Reset stopped flag
208        self.stopped.store(false, Ordering::Release);
209
210        // Clone necessary components for the task
211        let service = self.service.clone();
212        let state = self.state.clone();
213        let stopped = self.stopped.clone();
214        let started = self.started.clone();
215        let has_notified = self.has_notified.clone();
216        let wait_point = self.wait_point.clone();
217        let task_handle = self.task_handle.clone();
218
219        // Spawn the service task
220        let handle = tokio::spawn(async move {
221            Self::run_internal(service, state, stopped, started, has_notified, wait_point).await;
222        });
223
224        // Store the task handle
225        {
226            let mut handle_guard = task_handle.write().await;
227            *handle_guard = Some(handle);
228        }
229
230        // Update state to running
231        {
232            let mut state = self.state.write().await;
233            *state = ServiceLifecycle::Running;
234        }
235
236        info!(
237            "Started service thread: {} started: {}",
238            service_name,
239            self.started.load(Ordering::Acquire)
240        );
241
242        Ok(())
243    }
244
245    /// Internal run method
246    async fn run_internal(
247        service: Arc<T>,
248        state: Arc<RwLock<ServiceLifecycle>>,
249        stopped: Arc<AtomicBool>,
250        started: Arc<AtomicBool>,
251        has_notified: Arc<AtomicBool>,
252        wait_point: Arc<Notify>,
253    ) {
254        let service_name = service.get_service_name();
255        info!("Service thread {} is running", service_name);
256
257        // Set state to running
258        {
259            let mut state_guard = state.write().await;
260            *state_guard = ServiceLifecycle::Running;
261        }
262        // Create context for the service
263        let context = ServiceContext::new(wait_point.clone(), has_notified.clone(), stopped.clone());
264        // Run the service
265        service.run(&context).await;
266
267        // Clean up after run completes
268        started.store(false, Ordering::Release);
269        stopped.store(true, Ordering::Release);
270        has_notified.store(false, Ordering::Release);
271
272        {
273            let mut state_guard = state.write().await;
274            *state_guard = ServiceLifecycle::Stopped;
275        }
276
277        info!("Service thread {} has stopped", service_name);
278    }
279
280    /// Shutdown the service
281    pub async fn shutdown(&self) -> RocketMQResult<()> {
282        self.shutdown_with_interrupt(false).await
283    }
284
285    /// Shutdown the service with optional interrupt
286    pub async fn shutdown_with_interrupt(&self, interrupt: bool) -> RocketMQResult<()> {
287        let service_name = self.service.get_service_name();
288
289        info!(
290            "Try to shutdown service thread: {} started: {} current_state: {:?}",
291            service_name,
292            self.started.load(Ordering::Acquire),
293            self.get_lifecycle_state().await
294        );
295
296        // Check if not started
297        if self
298            .started
299            .compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire)
300            .is_err()
301        {
302            warn!("Service thread {} is not running", service_name);
303            return Ok(());
304        }
305
306        // Update state
307        {
308            let mut state = self.state.write().await;
309            *state = ServiceLifecycle::Stopping;
310        }
311
312        // Set stopped flag
313        self.stopped.store(true, Ordering::Release);
314
315        info!("Shutdown thread[{}] interrupt={}", service_name, interrupt);
316
317        // Wake up if thread is waiting
318        self.wakeup();
319
320        let begin_time = Instant::now();
321
322        // Wait for the task to complete
323        let join_time = self.service.get_join_time();
324        let result = if !self.is_daemon() {
325            let mut handle_guard = self.task_handle.write().await;
326            if let Some(handle) = handle_guard.take() {
327                if interrupt {
328                    handle.abort();
329                    Ok(())
330                } else {
331                    match timeout(join_time, handle).await {
332                        Ok(_) => Ok(()),
333                        Err(_) => {
334                            warn!("Service thread {} shutdown timeout", service_name);
335                            Ok(())
336                        }
337                    }
338                }
339            } else {
340                Ok(())
341            }
342        } else {
343            Ok(())
344        };
345
346        let elapsed_time = begin_time.elapsed();
347        info!(
348            "Join thread[{}], elapsed time: {}ms, join time: {}ms",
349            service_name,
350            elapsed_time.as_millis(),
351            join_time.as_millis()
352        );
353
354        // Update final state
355        {
356            let mut state = self.state.write().await;
357            *state = ServiceLifecycle::Stopped;
358        }
359
360        result
361    }
362
363    /// Make the service stop (without waiting)
364    pub fn make_stop(&self) {
365        if !self.started.load(Ordering::Acquire) {
366            return;
367        }
368
369        self.stopped.store(true, Ordering::Release);
370        info!("Make stop thread[{}]", self.service.get_service_name());
371    }
372
373    /// Wake up the service thread
374    pub fn wakeup(&self) {
375        if self
376            .has_notified
377            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
378            .is_ok()
379        {
380            self.wait_point.notify_one();
381        }
382    }
383
384    /// Wait for running with interval
385    pub async fn wait_for_running(&self, interval: Duration) {
386        // Check if already notified
387        if self
388            .has_notified
389            .compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire)
390            .is_ok()
391        {
392            self.service.on_wait_end().await;
393            return;
394        }
395
396        // Wait for notification or timeout
397        let wait_result = timeout(interval, self.wait_point.notified()).await;
398
399        // Reset notification flag
400        self.has_notified.store(false, Ordering::Release);
401
402        // Call on_wait_end regardless of how we were woken up
403        self.service.on_wait_end().await;
404
405        if wait_result.is_err() {
406            // Timeout occurred - this is normal behavior
407        }
408    }
409
410    /// Check if service is stopped
411    pub fn is_stopped(&self) -> bool {
412        self.stopped.load(Ordering::Acquire)
413    }
414
415    /// Check if service is daemon
416    pub fn is_daemon(&self) -> bool {
417        self.is_daemon.load(Ordering::Acquire)
418    }
419
420    /// Set daemon flag
421    pub fn set_daemon(&self, daemon: bool) {
422        self.is_daemon.store(daemon, Ordering::Release);
423    }
424
425    /// Get current service state
426    pub async fn get_lifecycle_state(&self) -> ServiceLifecycle {
427        *self.state.read().await
428    }
429
430    /// Check if service is started
431    pub fn is_started(&self) -> bool {
432        self.started.load(Ordering::Acquire)
433    }
434}
435
#[cfg(test)]
mod tests {
    use tokio::time::sleep;
    use tokio::time::Duration;

    use super::*;
    use crate::service_manager;

    /// Example implementation - Transaction Check Service.
    ///
    /// Each cycle waits up to `check_interval`, then runs a simulated
    /// transaction check through `on_wait_end`.
    #[derive(Clone)]
    pub struct ExampleTransactionCheckService {
        check_interval: Duration,
        transaction_timeout: Duration,
    }

    impl ExampleTransactionCheckService {
        pub fn new(check_interval: Duration, transaction_timeout: Duration) -> Self {
            Self {
                check_interval,
                transaction_timeout,
            }
        }
    }

    impl ServiceTask for ExampleTransactionCheckService {
        fn get_service_name(&self) -> String {
            "ExampleTransactionCheckService".to_string()
        }

        async fn run(&self, context: &ServiceContext) {
            info!("Start transaction check service thread!");

            while !context.is_stopped() {
                // The context only reports that the wait ended; the service must
                // run its own `on_wait_end` work.
                if context.wait_for_running(self.check_interval).await {
                    self.on_wait_end().await;
                }
            }

            info!("End transaction check service thread!");
        }

        async fn on_wait_end(&self) {
            let begin = Instant::now();
            info!("Begin to check prepare message, begin time: {:?}", begin);

            // Simulate transaction check work
            self.perform_transaction_check().await;

            let elapsed = begin.elapsed();
            info!("End to check prepare message, consumed time: {}ms", elapsed.as_millis());
        }
    }

    impl ExampleTransactionCheckService {
        /// Simulated unit of work: sleeps briefly, then logs the configured timeout.
        async fn perform_transaction_check(&self) {
            sleep(Duration::from_millis(100)).await;
            info!(
                "Transaction check completed with timeout: {:?}",
                self.transaction_timeout
            );
        }
    }

    // Use the macro to add service thread functionality
    service_manager!(ExampleTransactionCheckService);

    /// Test service that runs at most five wait/work cycles.
    #[derive(Clone)]
    struct TestService {
        name: String,
        work_duration: Duration,
    }

    impl TestService {
        fn new(name: String, work_duration: Duration) -> Self {
            Self { name, work_duration }
        }
    }

    impl ServiceTask for TestService {
        fn get_service_name(&self) -> String {
            self.name.clone()
        }

        async fn run(&self, context: &ServiceContext) {
            println!("TestService {} starting {}", self.name, context.is_stopped());

            let mut counter = 0;

            while !context.is_stopped() && counter < 5 {
                // Run the per-cycle work whenever the wait reports it has ended.
                if context.wait_for_running(Duration::from_millis(100)).await {
                    self.on_wait_end().await;
                }
                println!("TestService {} running iteration {}", self.name, counter);
                counter += 1;
            }

            println!("TestService {} finished after {} iterations", self.name, counter);
        }

        async fn on_wait_end(&self) {
            println!("TestService {} performing work", self.name);
            sleep(self.work_duration).await;
            println!("TestService {} work completed", self.name);
        }
    }

    service_manager!(TestService);

    #[tokio::test]
    async fn test_service_lifecycle() {
        let service = TestService::new("test-service".to_string(), Duration::from_millis(50));
        let service_thread = service.create_service_task();

        // Test initial state
        assert_eq!(service_thread.get_lifecycle_state().await, ServiceLifecycle::NotStarted);
        assert!(!service_thread.is_started());
        assert!(!service_thread.is_stopped());

        // Test start
        service_thread.start().await.unwrap();
        assert_eq!(service_thread.get_lifecycle_state().await, ServiceLifecycle::Running);
        assert!(service_thread.is_started());
        assert!(!service_thread.is_stopped());

        // Let it run for a bit
        sleep(Duration::from_millis(300)).await;

        // Test wakeup
        service_thread.wakeup();
        sleep(Duration::from_millis(100)).await;

        // Test shutdown
        service_thread.shutdown().await.unwrap();
        assert_eq!(service_thread.get_lifecycle_state().await, ServiceLifecycle::Stopped);
        assert!(!service_thread.is_started());
        assert!(service_thread.is_stopped());
    }

    #[tokio::test]
    async fn test_daemon_service() {
        let service = TestService::new("daemon-service".to_string(), Duration::from_millis(10));
        let service_thread = service.create_service_task();

        // Set as daemon
        service_thread.set_daemon(true);
        assert!(service_thread.is_daemon());

        // Start and shutdown; a daemon shutdown does not join the task but
        // still reports the service as stopped.
        service_thread.start().await.unwrap();
        sleep(Duration::from_millis(100)).await;
        service_thread.shutdown().await.unwrap();
        assert!(!service_thread.is_started());
        assert!(service_thread.is_stopped());
        assert_eq!(service_thread.get_lifecycle_state().await, ServiceLifecycle::Stopped);
    }

    #[tokio::test]
    async fn test_multiple_start_attempts() {
        let service = TestService::new("multi-start-service".to_string(), Duration::from_millis(10));
        let service_thread = service.create_service_task();

        // First start should succeed
        service_thread.start().await.unwrap();
        assert!(service_thread.is_started());

        // Second start should be ignored
        service_thread.start().await.unwrap();
        assert!(service_thread.is_started());

        // Shutdown
        service_thread.shutdown().await.unwrap();
        assert!(!service_thread.is_started());
    }

    #[tokio::test]
    async fn test_make_stop() {
        let service = TestService::new("stop-service".to_string(), Duration::from_millis(10));
        let service_thread = service.create_service_task();

        service_thread.start().await.unwrap();
        sleep(Duration::from_millis(50)).await;

        // Make stop should set stopped flag
        service_thread.make_stop();
        assert!(service_thread.is_stopped());

        // Wait a bit for cleanup
        sleep(Duration::from_millis(100)).await;

        // Shutdown either joins the task or finds it already finished; in both
        // cases the manager must end up not started and stopped.
        service_thread.shutdown().await.unwrap();
        assert!(!service_thread.is_started());
        assert!(service_thread.is_stopped());
    }

    #[tokio::test]
    async fn test_example_transaction_service() {
        let service = ExampleTransactionCheckService::new(Duration::from_millis(100), Duration::from_millis(1000));
        let service_thread = service.create_service_task();

        service_thread.start().await.unwrap();
        sleep(Duration::from_millis(350)).await;
        service_thread.wakeup();
        sleep(Duration::from_millis(150)).await;
        service_thread.shutdown().await.unwrap();
        assert!(service_thread.is_stopped());
    }
}