rfs_runner/worker_pool.rs
//! `WorkerPool` manages a pool of asynchronous worker tasks for concurrent job execution.
//!
//! It provides functionality to spawn workers, send data to them,
//! and manage their lifecycle, including a graceful shutdown path
//! ([`WorkerPool::sigint`]) for callers to invoke when an interrupt signal is received.
5//!
6//! # Examples
7//!
8//! ```no_run
9//! use std::sync::Arc;
10//!
11//! use anyhow::{anyhow, Result};
12//! use lettre::message::{Mailbox, MessageBuilder};
13//! use lettre::transport::smtp::authentication::{Credentials, Mechanism};
14//! use lettre::transport::smtp::PoolConfig;
15//! use lettre::{AsyncSmtpTransport, AsyncTransport, Tokio1Executor};
16//! use rfs_runner::{DefaultTemplate, MainProgress, WorkerPool, WorkerTemplate};
17//! use tokio::io::{AsyncBufReadExt, AsyncSeekExt};
18//! const MAX_WORKER: u32 = 8;
19//!
20//! #[tokio::main]
21//! async fn main() -> Result<()> {
22//! use std::env::temp_dir;
23//!
24//! let smtp_credentials = Credentials::new("example@gmail.com".to_owned(), "KeepSecretPasswordFromCommit".to_owned());
25//! let pool_config = PoolConfig::new().max_size(MAX_WORKER);
26//! let smtp_server: Arc<AsyncSmtpTransport<Tokio1Executor>> = Arc::new(
27//! AsyncSmtpTransport::<Tokio1Executor>::starttls_relay("smtp-relay.gmail.com")?
28//! .authentication(vec![Mechanism::Login])
29//! .credentials(smtp_credentials)
30//! .pool_config(pool_config)
31//! .build(),
32//! );
33//!
34//! let tmp = temp_dir().join("maillist.txt");
//! let file = tokio::fs::File::options().read(true).create(true).append(true).open(tmp).await?;
36//! let mut reader = tokio::io::BufReader::new(file);
37//!
38//! let mut total_lines = 0;
39//! loop {
40//! match reader.read_line(&mut String::new()).await? {
41//! 0 => break,
42//! _ => total_lines += 1,
43//! }
44//! }
45//!
46//! reader.rewind().await?;
47//! let mut lines = reader.lines();
48//!
49//! // <-- Start of Main Example -->
50//! let template: DefaultTemplate = WorkerTemplate::new("Email Sent");
51//! let mut worker_pool: WorkerPool<String, DefaultTemplate> = WorkerPool::new(total_lines, template, 20);
52//!
53//! // Spawn all worker and hold the connections
54//! (0..MAX_WORKER).for_each(|_| {
55//! let smtp = smtp_server.clone();
56//! worker_pool.spawn_worker(move |line, progress| send_email(line, progress, smtp.clone()));
57//! });
58//!
59//! while let Some(line) = lines.next_line().await? {
60//! _ = worker_pool.send_seqcst(line).await;
61//! }
62//!
63//! // Wait for all workers to complete
64//! worker_pool.join_all().await;
65//! // <-- End of Main Example -->
66//!
67//! Ok(())
68//! }
69//!
70//! async fn send_email(line: String, progress: MainProgress<DefaultTemplate>, smtp_server: Arc<AsyncSmtpTransport<Tokio1Executor>>) -> Result<()> {
71//! let recipients = line.parse::<Mailbox>()?;
72//! let message = MessageBuilder::new()
73//! .to(recipients)
74//! .from("example@gmail.com".parse()?)
75//! .subject(format!("This is an email sent at #{} progress", progress.length().unwrap()))
76//! .body(String::from("Test mail body"))?;
77//!
78//! smtp_server
79//! .send(message)
80//! .await
81//! .map(|_| progress.increment(1))
82//! .map_err(|err| anyhow!("{err}"))
83//! }
84//! ```
85
86use std::cell::UnsafeCell;
87use std::collections::{HashMap, VecDeque};
88use std::marker::PhantomData;
89use std::process::abort;
90use std::time::Duration;
91
92use colored::Colorize;
93#[cfg(feature = "futures-util")]
94use futures_util::future::join_all;
95use indicatif::{MultiProgress, ProgressDrawTarget, ProgressState, ProgressStyle};
96use tokio::sync::mpsc::{self, Sender};
97use tokio::task::JoinHandle;
98
99use crate::helper::line_err;
100use crate::{limit_string, MainProgress, Result, Uid, WorkerTemplate};
101
102type Handle = JoinHandle<()>;
103type Handles = HashMap<Uid, Handle>;
104
/// [`WorkerPool`] manages a pool of worker tasks for concurrent job execution.
///
/// Workers are spawned as Tokio tasks with [`WorkerPool::spawn_worker`]. Each one is fed
/// through its own channel, and workers are shut down by dropping those channels.
/// The pool is deliberately `!Sync` (see the `_unsafe_sync` marker).
pub struct WorkerPool<D, S>
where
    S: WorkerTemplate,
{
    /// Shared multi-bar display that every progress bar (and `println`) draws into.
    ui: MultiProgress,
    /// Senders to the workers, each paired with the worker's id.
    ///
    /// `send_seqcst` uses this as a round-robin queue. Clearing it closes every
    /// worker's channel.
    channels: VecDeque<(Uid, Sender<D>)>,
    /// Join handles of every spawned worker, keyed by worker id.
    handles: Handles,
    /// The main progress bar; a clone of it is handed to every spawned worker.
    main_progress: MainProgress<S>,
    /// `UnsafeCell` is `!Sync`, so this zero-sized marker makes `WorkerPool` `!Sync`.
    _unsafe_sync: PhantomData<UnsafeCell<()>>,
}
118
119impl<D: Send + 'static, S: WorkerTemplate> WorkerPool<D, S> {
120 /// Creates a new `WorkerPool` with a specified number of workers, a worker template, and a draw frequency.
121 ///
122 /// # Arguments
123 ///
124 /// * `len` - The number of workers in the pool.
125 /// * `template` - The template for creating progress bars for each worker.
126 /// * `draw_hz` - The frequency at which the progress bars are drawn (frames per second).
127 pub fn new(len: u64, template: S, draw_hz: u8) -> Self {
128 let target = ProgressDrawTarget::stderr_with_hz(draw_hz);
129 let ui = MultiProgress::with_draw_target(target);
130 let main_ui = MainProgress::new(len, ui.clone(), template);
131 ui.set_move_cursor(false);
132
133 Self {
134 ui,
135 channels: Default::default(),
136 handles: Default::default(),
137 main_progress: main_ui,
138 _unsafe_sync: PhantomData,
139 }
140 }
141
142 /// Generates a new unique task ID.
143 pub fn new_task_id(&self) -> Uid {
144 Uid::new(self.handles.len() as u32 + 1).unwrap()
145 }
146
147 /// Spawns a new worker in the pool.
148 ///
149 /// # Arguments
150 ///
151 /// * `f` - A closure that represents the worker's task. It takes data `D` and a `MainProgress` instance as input
152 /// and returns a `Result`.
153 pub fn spawn_worker<F, Fut>(&mut self, f: F) -> Uid
154 where
155 Fut: Future<Output = anyhow::Result<()>> + Send + 'static,
156 Fut::Output: Send + 'static,
157 F: Fn(D, MainProgress<S>) -> Fut + Send + 'static,
158 {
159 let task_id = self.new_task_id();
160 let (tx, rx) = mpsc::channel(1);
161 let handle = match super::WorkerHandleBuilder::default()
162 .fn_ptr(f)
163 .main_ui(self.main_ui())
164 .receiver(rx)
165 .build(task_id)
166 {
167 Ok(handle) => handle,
168 Err(build_error) => {
169 self.main_progress.println(build_error.to_string());
170 abort();
171 }
172 };
173
174 let handle: Handle = tokio::spawn(handle.run());
175 self.channels.push_back((task_id, tx));
176 self.handles.insert(task_id, handle);
177
178 task_id
179 }
180
181 /// Returns a clone of the main progress instance.
182 pub fn main_ui(&self) -> MainProgress<S> {
183 self.main_progress.clone()
184 }
185
186 /// Handles the SIGINT signal, setting a prefix message and stopping all workers.
187 pub async fn sigint(&mut self) {
188 let prefix = "<C-c> Received, Waiting all background processes to finished";
189 self.main_progress.set_prefix(limit_string(116, prefix.bright_red().to_string(), None));
190 self.stop_all_workers().await;
191 }
192
193 /// Stops all workers in the pool and waits for them to finish.
194 async fn stop_all_workers(&mut self) {
195 self.close_all();
196
197 loop {
198 if self.handles.iter().filter(|(_, h)| h.is_finished()).count().ge(&self.handles.len()) {
199 break;
200 }
201
202 tokio::time::sleep(Duration::from_millis(1)).await;
203 }
204
205 self
206 .main_progress
207 .println("All background process is finished!".bright_green().to_string())
208 }
209
210 /// Closes all channels in the pool.
211 fn close_all(&mut self) {
212 while self.channels.pop_back().is_some() {
213 // Empty
214 }
215 }
216
217 /// Sends data to the first available worker in the pool.
218 ///
219 /// # Arguments
220 ///
221 /// * `data`: The data to send to the worker.
222 ///
223 /// Returns: `Result<(), D>`
224 ///
225 pub async fn send_seqcst(&mut self, mut data: D) -> Result<(), D> {
226 let timeout_1ms = Duration::from_millis(1);
227
228 'sending: while !self.channels.is_empty() {
229 // Manages closed channel
230 self.channels.retain(|(_, tx)| !tx.is_closed());
231
232 let Some(channel) = self.channels.pop_front() else { continue };
233 let sent = channel.1.send_timeout(data, timeout_1ms).await;
234 self.channels.push_back(channel);
235
236 match sent {
237 Ok(_) => return Ok(()),
238 Err(error) => {
239 data = error.into_inner();
240 continue 'sending;
241 }
242 }
243 }
244
245 Err(data)
246 }
247
248 /// Sends data directly to the worker by its id.
249 ///
250 /// # Arguments
251 ///
252 /// * `id` - The ID of the worker to send data to.
253 /// * `data` - The data to send.
254 ///
255 /// returns: Result<(), D>
256 ///
257 /// # Panics
258 ///
259 /// Panics if the pool naver been spawned a worker.
260 pub async fn send_to(&mut self, id: Uid, data: D) -> Result<(), D> {
261 if self.handles.is_empty() {
262 panic!("No worker has ever been spawned!");
263 }
264
265 self.channels.retain(|(_, tx)| !tx.is_closed());
266
267 for (worker_id, tx) in &self.channels {
268 if worker_id == &id {
269 if let Err(err) = tx.send(data).await {
270 _ = self.ui.println(err.to_string().bright_red().to_string());
271 return Err(err.0);
272 }
273 return Ok(());
274 }
275 }
276
277 Err(data)
278 }
279
280 /// Returns the number of active threads in the pool.
281 pub fn thead_count(&self) -> usize {
282 self.handles.len()
283 }
284
285 /// Waits for all workers to complete their tasks.
286 pub async fn join_all(mut self) {
287 self.stop_all_workers().await;
288
289 #[cfg(feature = "futures-util")]
290 join_all(self.handles.into_values()).await;
291 #[cfg(not(feature = "futures-util"))]
292 for handle in self.handles.into_values() {
293 _ = handle.await;
294 }
295 }
296}
297
298impl<D: Send, S: WorkerTemplate> WorkerPool<D, S> {
299 /// Retrieves a `ProgressStyle` based on the provided template.
300 ///
301 /// This style includes customized progress characters, date formatting, and status indicators.
302 pub fn get_style(template: impl AsRef<str>) -> ProgressStyle {
303 type PS = ProgressState;
304 ProgressStyle::with_template(template.as_ref())
305 .unwrap()
306 .progress_chars("──")
307 .tick_strings(&["◜", "◠", "◝", "◞", "◡", "◟"]) // Progress Chars
308 .with_key("date", |_: &PS, w: &mut dyn std::fmt::Write| {
309 _ = write!(w, "[{}]", crate::dt_now_rfc2822())
310 })
311 .with_key("|", |_: &PS, w: &mut dyn std::fmt::Write| _ = w.write_str("│"))
312 .with_key("-", |_: &PS, w: &mut dyn std::fmt::Write| _ = w.write_str("─"))
313 .with_key("l", |_: &PS, w: &mut dyn std::fmt::Write| _ = w.write_str("╰"))
314 .with_key("status", |ps: &PS, w: &mut dyn std::fmt::Write| {
315 _ = write!(
316 w,
317 "{}",
318 if ps.is_finished() {
319 "FINISHED".bright_green()
320 } else {
321 "RUNNING".bright_yellow()
322 }
323 )
324 })
325 }
326
327 /// Prints a line to the UI.
328 pub fn println(&self, line: impl AsRef<str>) -> Result<(), std::io::Error> {
329 self.ui.println(line)
330 }
331
332 /// Prints an error line to the UI.
333 pub fn eprintln(&self, line: impl AsRef<str>) -> Result<(), std::io::Error> {
334 self.ui.println(line_err(line.as_ref()))
335 }
336
337 /// Generates a horizontal line of a specified length.
338 pub fn horizontal_line(len: usize) -> String {
339 "─".repeat(len)
340 }
341
342 /// Generates a vertical line of a specified length.
343 pub fn vertical_line(len: usize) -> String {
344 "│".repeat(len)
345 }
346}
347
#[cfg(test)]
mod test {
    use super::WorkerPool;
    use crate::DefaultTemplate;

    /// `WorkerPool` holds a `PhantomData<UnsafeCell<()>>` marker so that it is `!Sync`.
    /// This compile-time check fails to build if that guarantee is ever lost.
    #[test]
    fn worker_pool_should_not_be_sync() {
        static_assertions::assert_not_impl_any!(WorkerPool<String, DefaultTemplate>: Sync);
    }
}