rocketmq_rust/task/
service_task.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17use std::sync::atomic::AtomicBool;
18use std::sync::atomic::Ordering;
19use std::sync::Arc;
20use std::time::Duration;
21use std::time::Instant;
22
23use rocketmq_error::RocketMQResult;
24use tokio::sync::Notify;
25use tokio::sync::RwLock;
26use tokio::time::timeout;
27use tracing::info;
28use tracing::warn;
29
30/// Service thread context that gets passed to the service
31/// This contains all the control mechanisms
32pub struct ServiceContext {
33    /// Wait point for notifications
34    wait_point: Arc<Notify>,
35    /// Notification flag
36    has_notified: Arc<AtomicBool>,
37    /// Stop flag
38    stopped: Arc<AtomicBool>,
39}
40
41impl ServiceContext {
42    pub fn new(
43        wait_point: Arc<Notify>,
44        has_notified: Arc<AtomicBool>,
45        stopped: Arc<AtomicBool>,
46    ) -> Self {
47        Self {
48            wait_point,
49            has_notified,
50            stopped,
51        }
52    }
53
54    /// Check if service is stopped
55    pub fn is_stopped(&self) -> bool {
56        self.stopped.load(Ordering::Acquire)
57    }
58
59    /// Wait for running with interval
60    pub async fn wait_for_running(&self, interval: Duration) -> bool {
61        // Check if already notified
62        if self
63            .has_notified
64            .compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire)
65            .is_ok()
66        {
67            return true; // Should call on_wait_end
68        }
69
70        // Entry to wait
71        match timeout(interval, self.wait_point.notified()).await {
72            Ok(_) => {
73                // Notified
74            }
75            Err(_) => {
76                // Timeout occurred - this is normal behavior
77            }
78        }
79        // Reset notification flag
80        self.has_notified.store(false, Ordering::Release);
81        true // Should call on_wait_end
82    }
83
84    pub fn wakeup(&self) {
85        if self
86            .has_notified
87            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
88            .is_ok()
89        {
90            self.wait_point.notify_one();
91        }
92    }
93}
94
95/*#[trait_variant::make(ServiceTask: Send)]
96pub trait ServiceTaskInner: Sync {
97    /// Get the service name
98    fn get_service_name(&self) -> String;
99
100    /// Main run method - implement the service logic here
101    async fn run(&self, context: &ServiceTaskContext);
102
103    /// Called when wait ends - override for custom behavior
104    async fn on_wait_end(&self);
105
106    /// Get join time for shutdown (default 90 seconds)
107    fn get_join_time(&self) -> Duration {
108        Duration::from_millis(90_000)
109    }
110}*/
111
112pub trait ServiceTask: Sync + Send {
113    /// Get the service name
114    fn get_service_name(&self) -> String;
115
116    /// implement the service logic here
117    fn run(&self, context: &ServiceContext) -> impl ::core::future::Future<Output = ()> + Send;
118
119    /// override for custom behavior
120    fn on_wait_end(&self) -> impl ::core::future::Future<Output = ()> + Send {
121        async {
122            // Default implementation does nothing
123        }
124    }
125
126    /// Get join time for shutdown (default 90 seconds)
127    fn get_join_time(&self) -> Duration {
128        Duration::from_millis(90_000)
129    }
130}
131
132/// Service thread implementation with lifecycle management
133pub struct ServiceManager<T: ServiceTask + 'static> {
134    /// The actual service implementation
135    service: Arc<T>,
136
137    /// Thread state management
138    state: Arc<RwLock<ServiceLifecycle>>,
139
140    /// Stop flag
141    stopped: Arc<AtomicBool>,
142
143    /// Started flag for restart capability
144    started: Arc<AtomicBool>,
145
146    /// Notification flag
147    has_notified: Arc<AtomicBool>,
148
149    /// Wait point for notifications
150    wait_point: Arc<Notify>,
151
152    /// Task handle for the running service
153    task_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
154
155    /// Whether this is a daemon service
156    is_daemon: AtomicBool,
157}
158
159impl<T: ServiceTask> AsRef<T> for ServiceManager<T> {
160    fn as_ref(&self) -> &T {
161        &self.service
162    }
163}
164
/// Lifecycle phase of a service driven by `ServiceManager`.
///
/// `Eq` and `Hash` are derived so the state can be used in sets and map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ServiceLifecycle {
    /// The manager was created and `start` has never been called.
    NotStarted,
    /// `start` was called and the service task is being spawned.
    Starting,
    /// The service task has been spawned.
    Running,
    /// A shutdown was requested and the manager is waiting for the task.
    Stopping,
    /// Shutdown completed or the service's `run` returned.
    Stopped,
}
174
175impl<T: ServiceTask + 'static> ServiceManager<T> {
176    /// Create new service thread implementation
177    pub fn new(service: T) -> Self {
178        Self {
179            service: Arc::new(service),
180            state: Arc::new(RwLock::new(ServiceLifecycle::NotStarted)),
181            stopped: Arc::new(AtomicBool::new(false)),
182            started: Arc::new(AtomicBool::new(false)),
183            has_notified: Arc::new(AtomicBool::new(false)),
184            wait_point: Arc::new(Notify::new()),
185            task_handle: Arc::new(RwLock::new(None)),
186            is_daemon: AtomicBool::new(false),
187        }
188    }
189
190    pub fn new_arc(service: Arc<T>) -> Self {
191        Self {
192            service,
193            state: Arc::new(RwLock::new(ServiceLifecycle::NotStarted)),
194            stopped: Arc::new(AtomicBool::new(false)),
195            started: Arc::new(AtomicBool::new(false)),
196            has_notified: Arc::new(AtomicBool::new(false)),
197            wait_point: Arc::new(Notify::new()),
198            task_handle: Arc::new(RwLock::new(None)),
199            is_daemon: AtomicBool::new(false),
200        }
201    }
202
203    /// Start the service thread
204    pub async fn start(&self) -> RocketMQResult<()> {
205        let service_name = self.service.get_service_name();
206
207        info!(
208            "Try to start service thread: {} started: {} current_state: {:?}",
209            service_name,
210            self.started.load(Ordering::Acquire),
211            self.get_lifecycle_state().await
212        );
213
214        // Check if already started
215        if self
216            .started
217            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
218            .is_err()
219        {
220            warn!("Service thread {} is already started", service_name);
221            return Ok(());
222        }
223
224        // Update state
225        {
226            let mut state = self.state.write().await;
227            *state = ServiceLifecycle::Starting;
228        }
229
230        // Reset stopped flag
231        self.stopped.store(false, Ordering::Release);
232
233        // Clone necessary components for the task
234        let service = self.service.clone();
235        let state = self.state.clone();
236        let stopped = self.stopped.clone();
237        let started = self.started.clone();
238        let has_notified = self.has_notified.clone();
239        let wait_point = self.wait_point.clone();
240        let task_handle = self.task_handle.clone();
241
242        // Spawn the service task
243        let handle = tokio::spawn(async move {
244            Self::run_internal(service, state, stopped, started, has_notified, wait_point).await;
245        });
246
247        // Store the task handle
248        {
249            let mut handle_guard = task_handle.write().await;
250            *handle_guard = Some(handle);
251        }
252
253        // Update state to running
254        {
255            let mut state = self.state.write().await;
256            *state = ServiceLifecycle::Running;
257        }
258
259        info!(
260            "Started service thread: {} started: {}",
261            service_name,
262            self.started.load(Ordering::Acquire)
263        );
264
265        Ok(())
266    }
267
268    /// Internal run method
269    async fn run_internal(
270        service: Arc<T>,
271        state: Arc<RwLock<ServiceLifecycle>>,
272        stopped: Arc<AtomicBool>,
273        started: Arc<AtomicBool>,
274        has_notified: Arc<AtomicBool>,
275        wait_point: Arc<Notify>,
276    ) {
277        let service_name = service.get_service_name();
278        info!("Service thread {} is running", service_name);
279
280        // Set state to running
281        {
282            let mut state_guard = state.write().await;
283            *state_guard = ServiceLifecycle::Running;
284        }
285        // Create context for the service
286        let context =
287            ServiceContext::new(wait_point.clone(), has_notified.clone(), stopped.clone());
288        // Run the service
289        service.run(&context).await;
290
291        // Clean up after run completes
292        started.store(false, Ordering::Release);
293        stopped.store(true, Ordering::Release);
294        has_notified.store(false, Ordering::Release);
295
296        {
297            let mut state_guard = state.write().await;
298            *state_guard = ServiceLifecycle::Stopped;
299        }
300
301        info!("Service thread {} has stopped", service_name);
302    }
303
304    /// Shutdown the service
305    pub async fn shutdown(&self) -> RocketMQResult<()> {
306        self.shutdown_with_interrupt(false).await
307    }
308
309    /// Shutdown the service with optional interrupt
310    pub async fn shutdown_with_interrupt(&self, interrupt: bool) -> RocketMQResult<()> {
311        let service_name = self.service.get_service_name();
312
313        info!(
314            "Try to shutdown service thread: {} started: {} current_state: {:?}",
315            service_name,
316            self.started.load(Ordering::Acquire),
317            self.get_lifecycle_state().await
318        );
319
320        // Check if not started
321        if self
322            .started
323            .compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire)
324            .is_err()
325        {
326            warn!("Service thread {} is not running", service_name);
327            return Ok(());
328        }
329
330        // Update state
331        {
332            let mut state = self.state.write().await;
333            *state = ServiceLifecycle::Stopping;
334        }
335
336        // Set stopped flag
337        self.stopped.store(true, Ordering::Release);
338
339        info!("Shutdown thread[{}] interrupt={}", service_name, interrupt);
340
341        // Wake up if thread is waiting
342        self.wakeup();
343
344        let begin_time = Instant::now();
345
346        // Wait for the task to complete
347        let join_time = self.service.get_join_time();
348        let result = if !self.is_daemon() {
349            let mut handle_guard = self.task_handle.write().await;
350            if let Some(handle) = handle_guard.take() {
351                if interrupt {
352                    handle.abort();
353                    Ok(())
354                } else {
355                    match timeout(join_time, handle).await {
356                        Ok(_) => Ok(()),
357                        Err(_) => {
358                            warn!("Service thread {} shutdown timeout", service_name);
359                            Ok(())
360                        }
361                    }
362                }
363            } else {
364                Ok(())
365            }
366        } else {
367            Ok(())
368        };
369
370        let elapsed_time = begin_time.elapsed();
371        info!(
372            "Join thread[{}], elapsed time: {}ms, join time: {}ms",
373            service_name,
374            elapsed_time.as_millis(),
375            join_time.as_millis()
376        );
377
378        // Update final state
379        {
380            let mut state = self.state.write().await;
381            *state = ServiceLifecycle::Stopped;
382        }
383
384        result
385    }
386
387    /// Make the service stop (without waiting)
388    pub fn make_stop(&self) {
389        if !self.started.load(Ordering::Acquire) {
390            return;
391        }
392
393        self.stopped.store(true, Ordering::Release);
394        info!("Make stop thread[{}]", self.service.get_service_name());
395    }
396
397    /// Wake up the service thread
398    pub fn wakeup(&self) {
399        if self
400            .has_notified
401            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
402            .is_ok()
403        {
404            self.wait_point.notify_one();
405        }
406    }
407
408    /// Wait for running with interval
409    pub async fn wait_for_running(&self, interval: Duration) {
410        // Check if already notified
411        if self
412            .has_notified
413            .compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire)
414            .is_ok()
415        {
416            self.service.on_wait_end().await;
417            return;
418        }
419
420        // Wait for notification or timeout
421        let wait_result = timeout(interval, self.wait_point.notified()).await;
422
423        // Reset notification flag
424        self.has_notified.store(false, Ordering::Release);
425
426        // Call on_wait_end regardless of how we were woken up
427        self.service.on_wait_end().await;
428
429        if wait_result.is_err() {
430            // Timeout occurred - this is normal behavior
431        }
432    }
433
434    /// Check if service is stopped
435    pub fn is_stopped(&self) -> bool {
436        self.stopped.load(Ordering::Acquire)
437    }
438
439    /// Check if service is daemon
440    pub fn is_daemon(&self) -> bool {
441        self.is_daemon.load(Ordering::Acquire)
442    }
443
444    /// Set daemon flag
445    pub fn set_daemon(&self, daemon: bool) {
446        self.is_daemon.store(daemon, Ordering::Release);
447    }
448
449    /// Get current service state
450    pub async fn get_lifecycle_state(&self) -> ServiceLifecycle {
451        *self.state.read().await
452    }
453
454    /// Check if service is started
455    pub fn is_started(&self) -> bool {
456        self.started.load(Ordering::Acquire)
457    }
458}
459
#[cfg(test)]
mod tests {
    use tokio::time::sleep;
    use tokio::time::Duration;

    use super::*;
    use crate::service_manager;

    /// Example implementation modelled on a transaction check service.
    pub struct ExampleTransactionCheckService {
        check_interval: Duration,
        transaction_timeout: Duration,
    }

    impl ExampleTransactionCheckService {
        pub fn new(check_interval: Duration, transaction_timeout: Duration) -> Self {
            Self {
                check_interval,
                transaction_timeout,
            }
        }
    }

    impl ServiceTask for ExampleTransactionCheckService {
        fn get_service_name(&self) -> String {
            "ExampleTransactionCheckService".to_string()
        }

        async fn run(&self, context: &ServiceContext) {
            info!("Start transaction check service thread!");

            while !context.is_stopped() {
                // `ServiceContext::wait_for_running` cannot reach the service, so the
                // wait-end hook has to be invoked here. Otherwise the check never runs.
                if context.wait_for_running(self.check_interval).await {
                    self.on_wait_end().await;
                }
            }

            info!("End transaction check service thread!");
        }

        async fn on_wait_end(&self) {
            let begin = Instant::now();
            info!("Begin to check prepare message, begin time: {:?}", begin);

            // Simulate transaction check work
            self.perform_transaction_check().await;

            let elapsed = begin.elapsed();
            info!(
                "End to check prepare message, consumed time: {}ms",
                elapsed.as_millis()
            );
        }
    }

    impl ExampleTransactionCheckService {
        async fn perform_transaction_check(&self) {
            // Simulate work
            sleep(Duration::from_millis(100)).await;
            info!(
                "Transaction check completed with timeout: {:?}",
                self.transaction_timeout
            );
        }
    }

    impl Clone for ExampleTransactionCheckService {
        fn clone(&self) -> Self {
            Self {
                check_interval: self.check_interval,
                transaction_timeout: self.transaction_timeout,
            }
        }
    }

    // Use the macro to add service thread functionality
    service_manager!(ExampleTransactionCheckService);

    /// Service that runs at most five wait/work iterations.
    #[derive(Clone)]
    struct TestService {
        name: String,
        work_duration: Duration,
    }

    impl TestService {
        fn new(name: String, work_duration: Duration) -> Self {
            Self {
                name,
                work_duration,
            }
        }
    }

    impl ServiceTask for TestService {
        fn get_service_name(&self) -> String {
            self.name.clone()
        }

        async fn run(&self, context: &ServiceContext) {
            println!(
                "TestService {} starting {}",
                self.name,
                context.is_stopped()
            );

            let mut counter = 0;

            while !context.is_stopped() && counter < 5 {
                // Run the per-wait work; the context does not call the hook itself.
                if context.wait_for_running(Duration::from_millis(100)).await {
                    self.on_wait_end().await;
                }
                println!("TestService {} running iteration {}", self.name, counter);
                counter += 1;
            }

            println!(
                "TestService {} finished after {} iterations",
                self.name, counter
            );
        }

        async fn on_wait_end(&self) {
            println!("TestService {} performing work", self.name);
            sleep(self.work_duration).await;
            println!("TestService {} work completed", self.name);
        }
    }

    service_manager!(TestService);

    #[tokio::test]
    async fn test_service_lifecycle() {
        let service = TestService::new("test-service".to_string(), Duration::from_millis(50));
        let service_thread = service.create_service_task();

        // Test initial state
        assert_eq!(
            service_thread.get_lifecycle_state().await,
            ServiceLifecycle::NotStarted
        );
        assert!(!service_thread.is_started());
        assert!(!service_thread.is_stopped());

        // Test start
        service_thread.start().await.unwrap();
        assert_eq!(
            service_thread.get_lifecycle_state().await,
            ServiceLifecycle::Running
        );
        assert!(service_thread.is_started());
        assert!(!service_thread.is_stopped());

        // Let it run for a bit
        sleep(Duration::from_millis(300)).await;

        // Test wakeup
        service_thread.wakeup();
        sleep(Duration::from_millis(100)).await;

        // Test shutdown
        service_thread.shutdown().await.unwrap();
        assert_eq!(
            service_thread.get_lifecycle_state().await,
            ServiceLifecycle::Stopped
        );
        assert!(!service_thread.is_started());
        assert!(service_thread.is_stopped());
    }

    #[tokio::test]
    async fn test_daemon_service() {
        let service = TestService::new("daemon-service".to_string(), Duration::from_millis(10));
        let service_thread = service.create_service_task();

        // Set as daemon
        service_thread.set_daemon(true);
        assert!(service_thread.is_daemon());

        // Start and shutdown
        service_thread.start().await.unwrap();
        sleep(Duration::from_millis(100)).await;
        service_thread.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn test_multiple_start_attempts() {
        let service =
            TestService::new("multi-start-service".to_string(), Duration::from_millis(10));
        let service_thread = service.create_service_task();

        // First start should succeed
        service_thread.start().await.unwrap();
        assert!(service_thread.is_started());

        // Second start should be ignored
        service_thread.start().await.unwrap();
        assert!(service_thread.is_started());

        // Shutdown
        service_thread.shutdown().await.unwrap();
        assert!(!service_thread.is_started());
    }

    #[tokio::test]
    async fn test_make_stop() {
        let service = TestService::new("stop-service".to_string(), Duration::from_millis(10));
        let service_thread = service.create_service_task();

        service_thread.start().await.unwrap();
        sleep(Duration::from_millis(50)).await;

        // Make stop should set stopped flag
        service_thread.make_stop();
        assert!(service_thread.is_stopped());

        // Wait a bit for cleanup
        sleep(Duration::from_millis(100)).await;
    }

    #[tokio::test]
    async fn test_example_transaction_service() {
        let service = ExampleTransactionCheckService::new(
            Duration::from_millis(100),
            Duration::from_millis(1000),
        );
        let service_thread = service.create_service_task();

        service_thread.start().await.unwrap();
        sleep(Duration::from_millis(350)).await;
        service_thread.wakeup();
        sleep(Duration::from_millis(150)).await;
        service_thread.shutdown().await.unwrap();
    }
}