rfs_runner/
worker_pool.rs

//! `WorkerPool` manages a pool of worker tasks for concurrent task execution.
//!
//! It can spawn workers, send data to them, and manage their lifecycle,
//! including a graceful shutdown once an interrupt signal has been received.
5//!
6//! # Examples
7//!
8//! ```no_run
9//! use std::sync::Arc;
10//!
11//! use anyhow::{anyhow, Result};
12//! use lettre::message::{Mailbox, MessageBuilder};
13//! use lettre::transport::smtp::authentication::{Credentials, Mechanism};
14//! use lettre::transport::smtp::PoolConfig;
15//! use lettre::{AsyncSmtpTransport, AsyncTransport, Tokio1Executor};
16//! use rfs_runner::{DefaultTemplate, MainProgress, WorkerPool, WorkerTemplate};
17//! use tokio::io::{AsyncBufReadExt, AsyncSeekExt};
18//! const MAX_WORKER: u32 = 8;
19//!
20//! #[tokio::main]
21//! async fn main() -> Result<()> {
22//!   use std::env::temp_dir;
23//!
24//!   let smtp_credentials = Credentials::new("example@gmail.com".to_owned(), "KeepSecretPasswordFromCommit".to_owned());
25//!   let pool_config = PoolConfig::new().max_size(MAX_WORKER);
26//!   let smtp_server: Arc<AsyncSmtpTransport<Tokio1Executor>> = Arc::new(
27//!     AsyncSmtpTransport::<Tokio1Executor>::starttls_relay("smtp-relay.gmail.com")?
28//!         .authentication(vec![Mechanism::Login])
29//!         .credentials(smtp_credentials)
30//!         .pool_config(pool_config)
31//!         .build(),
32//!   );
33//!
34//!   let tmp = temp_dir().join("maillist.txt");
35//!   let file = tokio::fs::File::options().create(true).append(true).open(tmp).await?;
36//!   let mut reader = tokio::io::BufReader::new(file);
37//!
38//!   let mut total_lines = 0;
39//!   loop {
40//!     match reader.read_line(&mut String::new()).await? {
41//!       0 => break,
42//!       _ => total_lines += 1,
43//!     }
44//!   }
45//!
46//!   reader.rewind().await?;
47//!   let mut lines = reader.lines();
48//!
49//!   // <-- Start of Main Example -->
50//!   let template: DefaultTemplate = WorkerTemplate::new("Email Sent");
51//!   let mut worker_pool: WorkerPool<String, DefaultTemplate> = WorkerPool::new(total_lines, template, 20);
52//!
//!   // Spawn all workers; each one holds a clone of the shared SMTP connection pool
54//!   (0..MAX_WORKER).for_each(|_| {
55//!     let smtp = smtp_server.clone();
56//!     worker_pool.spawn_worker(move |line, progress| send_email(line, progress, smtp.clone()));
57//!   });
58//!
59//!   while let Some(line) = lines.next_line().await? {
60//!     _ = worker_pool.send_seqcst(line).await;
61//!   }
62//!
63//!   // Wait for all workers to complete
64//!   worker_pool.join_all().await;
65//!   // <-- End of Main Example -->
66//!
67//!   Ok(())
68//! }
69//!
70//! async fn send_email(line: String, progress: MainProgress<DefaultTemplate>, smtp_server: Arc<AsyncSmtpTransport<Tokio1Executor>>) -> Result<()> {
71//!   let recipients = line.parse::<Mailbox>()?;
72//!   let message = MessageBuilder::new()
73//!       .to(recipients)
74//!       .from("example@gmail.com".parse()?)
75//!       .subject(format!("This is an email sent at #{} progress", progress.length().unwrap()))
76//!       .body(String::from("Test mail body"))?;
77//!
78//!   smtp_server
79//!       .send(message)
80//!       .await
81//!       .map(|_| progress.increment(1))
82//!       .map_err(|err| anyhow!("{err}"))
83//! }
84//! ```
85
86use std::cell::UnsafeCell;
87use std::collections::{HashMap, VecDeque};
88use std::marker::PhantomData;
89use std::process::abort;
90use std::time::Duration;
91
92use colored::Colorize;
93#[cfg(feature = "futures-util")]
94use futures_util::future::join_all;
95use indicatif::{MultiProgress, ProgressDrawTarget, ProgressState, ProgressStyle};
96use tokio::sync::mpsc::{self, Sender};
97use tokio::task::JoinHandle;
98
99use crate::helper::line_err;
100use crate::{limit_string, MainProgress, Result, Uid, WorkerTemplate};
101
102type Handle = JoinHandle<()>;
103type Handles = HashMap<Uid, Handle>;
104
105/// [`WorkerPool`] manages a pool of worker threads for parallel task execution.
106///
107/// It allows spawning workers, sending data to them, and managing their lifecycle.
108pub struct WorkerPool<D, S>
109where
110  S: WorkerTemplate,
111{
112  ui: MultiProgress,
113  channels: VecDeque<(Uid, Sender<D>)>,
114  handles: Handles,
115  main_progress: MainProgress<S>,
116  _unsafe_sync: PhantomData<UnsafeCell<()>>,
117}
118
119impl<D: Send + 'static, S: WorkerTemplate> WorkerPool<D, S> {
120  /// Creates a new `WorkerPool` with a specified number of workers, a worker template, and a draw frequency.
121  ///
122  /// # Arguments
123  ///
124  /// * `len` - The number of workers in the pool.
125  /// * `template` - The template for creating progress bars for each worker.
126  /// * `draw_hz` - The frequency at which the progress bars are drawn (frames per second).
127  pub fn new(len: u64, template: S, draw_hz: u8) -> Self {
128    let target = ProgressDrawTarget::stderr_with_hz(draw_hz);
129    let ui = MultiProgress::with_draw_target(target);
130    let main_ui = MainProgress::new(len, ui.clone(), template);
131    ui.set_move_cursor(false);
132
133    Self {
134      ui,
135      channels: Default::default(),
136      handles: Default::default(),
137      main_progress: main_ui,
138      _unsafe_sync: PhantomData,
139    }
140  }
141
142  /// Generates a new unique task ID.
143  pub fn new_task_id(&self) -> Uid {
144    Uid::new(self.handles.len() as u32 + 1).unwrap()
145  }
146
147  /// Spawns a new worker in the pool.
148  ///
149  /// # Arguments
150  ///
151  /// * `f` - A closure that represents the worker's task. It takes data `D` and a `MainProgress` instance as input
152  ///   and returns a `Result`.
153  pub fn spawn_worker<F, Fut>(&mut self, f: F) -> Uid
154  where
155    Fut: Future<Output = anyhow::Result<()>> + Send + 'static,
156    Fut::Output: Send + 'static,
157    F: Fn(D, MainProgress<S>) -> Fut + Send + 'static,
158  {
159    let task_id = self.new_task_id();
160    let (tx, rx) = mpsc::channel(1);
161    let handle = match super::WorkerHandleBuilder::default()
162      .fn_ptr(f)
163      .main_ui(self.main_ui())
164      .receiver(rx)
165      .build(task_id)
166    {
167      Ok(handle) => handle,
168      Err(build_error) => {
169        self.main_progress.println(build_error.to_string());
170        abort();
171      }
172    };
173
174    let handle: Handle = tokio::spawn(handle.run());
175    self.channels.push_back((task_id, tx));
176    self.handles.insert(task_id, handle);
177
178    task_id
179  }
180
181  /// Returns a clone of the main progress instance.
182  pub fn main_ui(&self) -> MainProgress<S> {
183    self.main_progress.clone()
184  }
185
186  /// Handles the SIGINT signal, setting a prefix message and stopping all workers.
187  pub async fn sigint(&mut self) {
188    let prefix = "<C-c> Received, Waiting all background processes to finished";
189    self.main_progress.set_prefix(limit_string(116, prefix.bright_red().to_string(), None));
190    self.stop_all_workers().await;
191  }
192
193  /// Stops all workers in the pool and waits for them to finish.
194  async fn stop_all_workers(&mut self) {
195    self.close_all();
196
197    loop {
198      if self.handles.iter().filter(|(_, h)| h.is_finished()).count().ge(&self.handles.len()) {
199        break;
200      }
201
202      tokio::time::sleep(Duration::from_millis(1)).await;
203    }
204
205    self
206      .main_progress
207      .println("All background process is finished!".bright_green().to_string())
208  }
209
210  /// Closes all channels in the pool.
211  fn close_all(&mut self) {
212    while self.channels.pop_back().is_some() {
213      // Empty
214    }
215  }
216
217  /// Sends data to the first available worker in the pool.
218  ///
219  /// # Arguments
220  ///
221  /// * `data`: The data to send to the worker.
222  ///
223  /// Returns: `Result<(), D>`
224  ///
225  pub async fn send_seqcst(&mut self, mut data: D) -> Result<(), D> {
226    let timeout_1ms = Duration::from_millis(1);
227
228    'sending: while !self.channels.is_empty() {
229      // Manages closed channel
230      self.channels.retain(|(_, tx)| !tx.is_closed());
231
232      let Some(channel) = self.channels.pop_front() else { continue };
233      let sent = channel.1.send_timeout(data, timeout_1ms).await;
234      self.channels.push_back(channel);
235
236      match sent {
237        Ok(_) => return Ok(()),
238        Err(error) => {
239          data = error.into_inner();
240          continue 'sending;
241        }
242      }
243    }
244
245    Err(data)
246  }
247
248  /// Sends data directly to the worker by its id.
249  ///
250  /// # Arguments
251  ///
252  /// * `id` - The ID of the worker to send data to.
253  /// * `data` - The data to send.
254  ///
255  /// returns: Result<(), D>
256  ///
257  /// # Panics
258  ///
259  /// Panics if the pool naver been spawned a worker.
260  pub async fn send_to(&mut self, id: Uid, data: D) -> Result<(), D> {
261    if self.handles.is_empty() {
262      panic!("No worker has ever been spawned!");
263    }
264
265    self.channels.retain(|(_, tx)| !tx.is_closed());
266
267    for (worker_id, tx) in &self.channels {
268      if worker_id == &id {
269        if let Err(err) = tx.send(data).await {
270          _ = self.ui.println(err.to_string().bright_red().to_string());
271          return Err(err.0);
272        }
273        return Ok(());
274      }
275    }
276
277    Err(data)
278  }
279
280  /// Returns the number of active threads in the pool.
281  pub fn thead_count(&self) -> usize {
282    self.handles.len()
283  }
284
285  /// Waits for all workers to complete their tasks.
286  pub async fn join_all(mut self) {
287    self.stop_all_workers().await;
288
289    #[cfg(feature = "futures-util")]
290    join_all(self.handles.into_values()).await;
291    #[cfg(not(feature = "futures-util"))]
292    for handle in self.handles.into_values() {
293      _ = handle.await;
294    }
295  }
296}
297
298impl<D: Send, S: WorkerTemplate> WorkerPool<D, S> {
299  /// Retrieves a `ProgressStyle` based on the provided template.
300  ///
301  /// This style includes customized progress characters, date formatting, and status indicators.
302  pub fn get_style(template: impl AsRef<str>) -> ProgressStyle {
303    type PS = ProgressState;
304    ProgressStyle::with_template(template.as_ref())
305      .unwrap()
306      .progress_chars("──")
307      .tick_strings(&["◜", "◠", "◝", "◞", "◡", "◟"]) // Progress Chars
308      .with_key("date", |_: &PS, w: &mut dyn std::fmt::Write| {
309        _ = write!(w, "[{}]", crate::dt_now_rfc2822())
310      })
311      .with_key("|", |_: &PS, w: &mut dyn std::fmt::Write| _ = w.write_str("│"))
312      .with_key("-", |_: &PS, w: &mut dyn std::fmt::Write| _ = w.write_str("─"))
313      .with_key("l", |_: &PS, w: &mut dyn std::fmt::Write| _ = w.write_str("╰"))
314      .with_key("status", |ps: &PS, w: &mut dyn std::fmt::Write| {
315        _ = write!(
316          w,
317          "{}",
318          if ps.is_finished() {
319            "FINISHED".bright_green()
320          } else {
321            "RUNNING".bright_yellow()
322          }
323        )
324      })
325  }
326
327  /// Prints a line to the UI.
328  pub fn println(&self, line: impl AsRef<str>) -> Result<(), std::io::Error> {
329    self.ui.println(line)
330  }
331
332  /// Prints an error line to the UI.
333  pub fn eprintln(&self, line: impl AsRef<str>) -> Result<(), std::io::Error> {
334    self.ui.println(line_err(line.as_ref()))
335  }
336
337  /// Generates a horizontal line of a specified length.
338  pub fn horizontal_line(len: usize) -> String {
339    "─".repeat(len)
340  }
341
342  /// Generates a vertical line of a specified length.
343  pub fn vertical_line(len: usize) -> String {
344    "│".repeat(len)
345  }
346}
347
#[cfg(test)]
mod test {
  use static_assertions::assert_not_impl_any;

  use super::WorkerPool;
  use crate::DefaultTemplate;

  /// Pool type checked by the assertion below.
  type Pool = WorkerPool<String, DefaultTemplate>;

  /// Compile-time guard: the pool must not implement `Sync`.
  #[test]
  fn worker_pool_should_not_be_sync() {
    assert_not_impl_any!(Pool: Sync);
  }
}
359}