Skip to main content

rs_pkg/worker/
worker.rs

1use crate::monitor::Monitor;
2use std::{future::Future, sync::Arc};
3use tokio::sync::{
4    Mutex,
5    mpsc::{
6        Receiver, Sender, channel,
7        error::{
8            SendError,
9            TryRecvError::{Disconnected, Empty},
10        },
11    },
12};
13use tracing::{Instrument, debug, debug_span};
14
15#[derive(Clone)]
16pub struct Worker<J> {
17    name: String,
18    work_count: Arc<Mutex<usize>>,
19    monitor: Monitor,
20    recv: Arc<Mutex<Receiver<J>>>,
21    send: Arc<Sender<J>>,
22    graceful: bool,
23}
24
25async fn handle<F, Fut, J>(
26    trigger: Arc<Mutex<Receiver<J>>>,
27    done: Arc<Mutex<Receiver<()>>>,
28    how: Arc<Mutex<F>>,
29    graceful: bool,
30) where
31    F: FnMut(J) -> Fut + Send + Sync + 'static,
32    Fut: Future<Output = ()> + Send + 'static,
33    J: Send + Sync + 'static,
34{
35    match graceful {
36        false => {
37            tokio::spawn(
38                async move {
39                    let mut done = done.lock().await;
40                    let mut how_guard = how.lock().await;
41                    loop {
42                        // 非阻塞检查信号
43                        match done.try_recv() {
44                            Ok(_) | Err(Disconnected) => {
45                                done.close();
46                                return;
47                            }
48
49                            Err(Empty) => {
50                                // 检查是否有新的工作项
51                                let mut guard = trigger.lock().await;
52                                if let Ok(item) = guard.try_recv() {
53                                    drop(guard); // 释放锁,避免在异步调用时持有锁
54                                    how_guard(item).await;
55                                }
56                            }
57                        }
58                    }
59                }
60                .instrument(debug_span!("handle")),
61            );
62        }
63
64        true => {
65            tokio::spawn(
66                async move {
67                    let mut how_guard = how.lock().await;
68                    let mut guard = trigger.lock().await;
69                    loop {
70                        // 检查是否有新的工作项
71                        match guard.recv().await {
72                            Some(item) => {
73                                how_guard(item).await;
74                            }
75
76                            None => return,
77                        }
78                    }
79                }
80                .instrument(debug_span!("grace")),
81            );
82        }
83    }
84}
85
86impl<J> Worker<J> {
87    pub fn new(name: &str, buf: usize) -> Self {
88        let (tx, rx) = channel(buf);
89        let work_count = Arc::new(Mutex::new(0));
90        Self {
91            name: name.to_string(),
92            work_count,
93            monitor: Monitor::new(name),
94            recv: Arc::new(Mutex::new(rx)),
95            graceful: false,
96            send: Arc::new(tx),
97        }
98    }
99
100    pub fn with_on_start<F, Fut>(mut self, task: F) -> Self
101    where
102        F: FnOnce() -> Fut + Send + Sync + 'static,
103        Fut: Future<Output = ()> + Send + Sync + 'static,
104    {
105        self.monitor = self.monitor.with_on_start(task);
106        self
107    }
108
109    pub fn with_on_exit<F, Fut>(mut self, task: F) -> Self
110    where
111        F: FnOnce() -> Fut + Send + Sync + 'static,
112        Fut: Future<Output = ()> + Send + Sync + 'static,
113    {
114        self.monitor = self.monitor.with_on_exit(task);
115        self
116    }
117
118    pub fn with_graceful(mut self, graceful: bool) -> Self {
119        self.graceful = graceful;
120        self
121    }
122
123    pub fn with_trigger(mut self, trigger: (Arc<Sender<J>>, Arc<Mutex<Receiver<J>>>)) -> Self {
124        let (send, recv) = trigger;
125        self.send = send;
126        self.recv = recv;
127        self
128    }
129
130    pub fn get_sender(&self) -> Arc<Sender<J>> {
131        self.send.clone()
132    }
133
134    pub async fn send(&self, job: J) -> Result<(), SendError<J>> {
135        self.send.send(job).await
136    }
137
138    pub fn name(&self) -> String {
139        self.name.to_string()
140    }
141
142    pub async fn count(&self) -> usize {
143        let guard = self.work_count.lock().await;
144        *guard
145    }
146
147    pub async fn stop(&self) -> Result<(), SendError<()>> {
148        self.monitor.stop().await
149    }
150
151    pub async fn run<F, Fut>(&self, how: F)
152    where
153        F: FnMut(J) -> Fut + Send + Sync + 'static,
154        Fut: Future<Output = ()> + Send + 'static,
155        J: Send + Sync + 'static,
156    {
157        debug!("WORKER START - {}", self.name);
158        let trigger = self.recv.clone();
159        let graceful = self.graceful;
160        let how = Arc::new(Mutex::new(how));
161        let task = move |done: Receiver<()>| async move {
162            let done = Arc::new(Mutex::new(done));
163            handle(trigger, done, how, graceful).await;
164        };
165
166        _ = self
167            .monitor
168            .run(task)
169            .instrument(debug_span!("monitor"))
170            .await;
171    }
172}