use std::cell::RefCell;
use std::sync::{Arc, Barrier, Mutex};
use std::thread::{Builder, JoinHandle, LocalKey};
use std::time::Duration;
use anyhow::anyhow;
use crate::blocking_queue::BlockingQueue;
use crate::command::Command;
use crate::shutdown_mode::ShutdownMode;
use crate::signal::Signal;
struct RunInAllThreadsCommand {
f: Arc<dyn Fn() + Send + Sync>,
b: Arc<Barrier>,
}
impl RunInAllThreadsCommand {
pub fn new(f: Arc<dyn Fn() + Send + Sync>, b: Arc<Barrier>) -> RunInAllThreadsCommand {
RunInAllThreadsCommand {
f,
b,
}
}
}
impl Command for RunInAllThreadsCommand {
fn execute(&self) -> Result<(), anyhow::Error> {
{
(self.f)();
}
self.b.wait();
Ok(())
}
}
struct RunMutInAllThreadsCommand {
f: Arc<Mutex<dyn FnMut() + Send + Sync>>,
b: Arc<Barrier>,
}
impl RunMutInAllThreadsCommand {
pub fn new(f: Arc<Mutex<dyn FnMut() + Send + Sync>>, b: Arc<Barrier>) -> RunMutInAllThreadsCommand {
RunMutInAllThreadsCommand {
f,
b,
}
}
}
impl Command for RunMutInAllThreadsCommand {
fn execute(&self) -> Result<(), anyhow::Error> {
{
// Run f while holding the lock, then release it before waiting at the barrier;
// holding the lock across the barrier wait would deadlock the remaining threads.
let mut f = self.f.lock().unwrap();
f();
}
self.b.wait();
Ok(())
}
}
/// Execute tasks concurrently while maintaining bounds on memory consumption
///
/// To demonstrate the use case this implementation solves, consider a program that reads
/// lines from a file and writes those lines to another file after some processing. The processing
/// itself is stateless and can be done in parallel on each line, but the reading and writing must
/// be sequential. With this implementation the input is read in the main thread, submitted for
/// concurrent processing to a processing thread pool, and collected for writing in a writing
/// thread pool with a single thread. See ./examples/read_process_write_pipeline.rs. Submission to
/// a thread pool goes through a blocking bounded queue, so if the processing or the writing
/// thread pool cannot keep up, its queue fills up and creates backpressure that pauses the
/// reading. The resulting pipeline therefore stabilizes at the throughput of the slowest stage,
/// with memory consumption determined by the queue sizes and the number of threads in each
/// thread pool.
///
/// For reference see [Command Pattern](https://en.wikipedia.org/wiki/Command_pattern) and
/// [Producer-Consumer](https://en.wikipedia.org/wiki/Producer%E2%80%93consumer_problem)
///
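/// A minimal usage sketch (a sketch only: the builder calls mirror those in the tests below,
/// `PrintCommand` is illustrative, and error handling is elided with `?`):
/// ```ignore
/// use crate::command::Command;
/// use crate::thread_pool_builder::ThreadPoolBuilder;
///
/// struct PrintCommand(String);
///
/// impl Command for PrintCommand {
///     fn execute(&self) -> Result<(), anyhow::Error> {
///         println!("{}", self.0);
///         Ok(())
///     }
/// }
///
/// let mut pool = ThreadPoolBuilder::new()
///     .name("worker".to_string())
///     .tasks(4)
///     .queue_size(16)
///     .build()?;
/// for i in 0..100 {
///     pool.submit(Box::new(PrintCommand(format!("line {i}"))));
/// }
/// pool.shutdown();
/// pool.join()?;
/// ```
///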
pub struct ThreadPool {
name: String,
tasks: usize,
queue: Arc<BlockingQueue<Box<dyn Command + Send + Sync>, Signal>>,
threads: Vec<JoinHandle<Result<(), anyhow::Error>>>,
join_error_handler: fn(String, String),
shutdown_mode: ShutdownMode,
expired: bool,
}
impl ThreadPool {
pub(crate) fn new(
name: String,
tasks: usize,
queue_size: usize,
join_error_handler: fn(String, String),
shutdown_mode: ShutdownMode,
) -> Result<ThreadPool, anyhow::Error> {
let start_barrier = Arc::new(Barrier::new(tasks + 1));
let mut threads = Vec::<JoinHandle<Result<(), anyhow::Error>>>::new();
let queue = Arc::new(
BlockingQueue::<Box<dyn Command + Send + Sync>, Signal>::new(queue_size)
);
for i in 0..tasks {
let barrier = start_barrier.clone();
let t = Self::create_thread(
&name,
i,
barrier,
queue.clone(),
);
threads.push(t?);
}
// The constructor and every worker meet at the start barrier, so the pool is returned only
// once all threads are running.
start_barrier.wait();
Ok(
ThreadPool {
name,
tasks,
queue: queue.clone(),
threads,
join_error_handler,
shutdown_mode,
expired: false,
}
)
}
/// Get the number of concurrent threads in the thread pool
pub fn tasks(&self) -> usize {
self.tasks
}
fn create_thread(
name: &String,
index: usize,
barrier: Arc<Barrier>,
queue: Arc<BlockingQueue<Box<dyn Command + Send + Sync>, Signal>>,
) -> Result<JoinHandle<Result<(), anyhow::Error>>, anyhow::Error> {
let builder = Builder::new();
Ok(builder
.name(format!("{name}-{index}"))
.spawn(move || {
// Signal readiness and block until every worker (and the constructor) has reached the start barrier.
barrier.wait();
let mut r: Result<(), anyhow::Error> = Ok(());
loop {
// Blocks until a command is available or a signal has been raised on the queue.
let (command, signal) = queue.dequeue();
if let Some(c) = command {
match c.execute() {
Ok(_) => {}
Err(e) => {
// Remember the most recent failure; it is reported when the pool is joined.
r = Err(e);
}
}
}
if let Some(s) = signal {
match s {
Signal::Shutdown => {
// Leave the worker loop, returning the last recorded error (if any).
break r;
}
}
}
}
}
)?
)
}
/// Execute f in all threads.
///
/// This function returns only after f has completed in all threads. It can be used to collect
/// data produced by the threads. See ./examples/fetch_thread_local.rs.
///
/// Caveat: this is a [barrier](https://en.wikipedia.org/wiki/Barrier_%28computer_science%29)
/// function. If one of the threads is busy with a long-running task or is deadlocked, the caller
/// and all the other pool threads are held at the barrier until f has run in every thread.
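///
/// A sketch of collecting per-thread data (a sketch only: `pool` stands for an already built
/// `ThreadPool`):
/// ```ignore
/// use std::sync::{Arc, Mutex};
///
/// let collected = Arc::new(Mutex::new(Vec::new()));
/// let sink = collected.clone();
/// pool.in_all_threads_mut(Arc::new(Mutex::new(move || {
///     // Each pool thread records its own name.
///     let name = std::thread::current().name().unwrap_or("unnamed").to_string();
///     sink.lock().unwrap().push(name);
/// })));
/// // The barrier guarantees every pool thread has run the closure by now.
/// assert_eq!(collected.lock().unwrap().len(), pool.tasks());
/// ```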
pub fn in_all_threads_mut(&self, f: Arc<Mutex<dyn FnMut() + Send + Sync>>) {
let b = Arc::new(Barrier::new(self.tasks + 1));
for _i in 0..self.tasks {
self.submit(Box::new(RunMutInAllThreadsCommand::new(f.clone(), b.clone())));
}
b.wait();
}
/// Execute f in all threads.
///
/// This function returns only after f has completed in all threads. It can be used to flush
/// data produced by the threads or simply execute work concurrently. See ./examples/flush_thread_local.rs.
///
/// Caveat: this is a [barrier](https://en.wikipedia.org/wiki/Barrier_%28computer_science%29)
/// function. If one of the threads is busy with a long-running task or is deadlocked, the caller
/// and all the other pool threads are held at the barrier until f has run in every thread.
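///
/// A sketch of flushing per-thread state (a sketch only: `pool` stands for an already built
/// `ThreadPool` and `WRITER` for a hypothetical `thread_local!` buffered writer):
/// ```ignore
/// use std::sync::Arc;
///
/// pool.in_all_threads(Arc::new(|| {
///     WRITER.with(|w| w.borrow_mut().flush().expect("flush failed"));
/// }));
/// ```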
pub fn in_all_threads(&self, f: Arc<dyn Fn() + Send + Sync>) {
let b = Arc::new(Barrier::new(self.tasks + 1));
for _i in 0..self.tasks {
self.submit(Box::new(RunInAllThreadsCommand::new(f.clone(), b.clone())));
}
b.wait();
}
/// Initializes the `local_key` to contain `val`.
///
/// See ./examples/thread_local.rs
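///
/// A sketch (a sketch only: the `thread_local!` key is illustrative; `pool` must be mutable
/// because this method takes `&mut self`):
/// ```ignore
/// use std::cell::RefCell;
///
/// thread_local! {
///     static COUNTER: RefCell<u64> = RefCell::new(0);
/// }
///
/// // Every pool thread now starts from the same initial value in its own copy.
/// pool.set_thread_local(&COUNTER, 42);
/// ```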
pub fn set_thread_local<T>(&mut self, local_key: &'static LocalKey<RefCell<T>>, val: T)
where T: Sync + Send + Clone {
self.in_all_threads_mut(
Arc::new(
Mutex::new(
move || {
local_key.with(
|value| {
value.replace(val.clone())
}
);
}
)
)
);
}
/// Shut down the thread pool.
///
/// This will shut down the thread pool according to configuration. When configured with
/// * [ShutdownMode::Immediate] - terminate each thread after completing the current task
/// * [ShutdownMode::CompletePending] - terminate after completing all pending tasks
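///
/// A sketch of a graceful shutdown (a sketch only: builder calls mirror those in the tests
/// below, error handling is elided with `?`):
/// ```ignore
/// let mut pool = ThreadPoolBuilder::new()
///     .name("worker".to_string())
///     .tasks(2)
///     .queue_size(16)
///     .shutdown_mode(ShutdownMode::CompletePending)
///     .build()?;
/// // ... submit work ...
/// pool.shutdown(); // waits for the queue to drain, then signals the workers
/// pool.join()?;    // propagates any errors reported by the workers
/// ```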
pub fn shutdown(&mut self) {
self.expired = true;
match self.shutdown_mode {
ShutdownMode::Immediate => {
self.queue.signal(Signal::Shutdown);
}
ShutdownMode::CompletePending => {
self.queue.wait_empty(Duration::MAX);
self.queue.signal(Signal::Shutdown);
}
}
}
/// Wait until all thread pool threads have completed.
pub fn join(&mut self) -> Result<(), anyhow::Error> {
let mut join_errors = Vec::<String>::new();
while self.threads.len() > 0 {
let t = self.threads.pop().unwrap();
let name = t.thread().name().unwrap_or("unnamed").to_string();
match t.join() {
Ok(r) => {
match r {
Ok(_) => {}
Err(e) => {
let message = format!("{e:?}");
join_errors.push(message.clone());
(self.join_error_handler)(name, message);
}
}
}
Err(e) => {
let mut message = "Unknown error".to_string();
if let Some(error) = e.downcast_ref::<&'static str>() {
message = error.to_string();
}
join_errors.push(message.clone());
(self.join_error_handler)(name, message);
}
}
}
if join_errors.is_empty() {
Ok(())
} else {
Err(anyhow!("Errors occurred while joining threads in the {} pool: {}", self.name, join_errors.join(", "))
)
}
}
/// Submit a command for execution
///
/// Blocks while the queue is full. Panics if the thread pool has already been shut down.
pub fn submit(&self, command: Box<dyn Command + Send + Sync>) {
if self.expired {
panic!("the thread pool {} is expired", self.name)
}
// Duration::MAX makes the enqueue wait indefinitely, so the command is never handed back.
let _ = self.try_submit(command, Duration::MAX);
}
/// Submit a command for execution, waiting at most `timeout` for space in the queue
///
/// Returns `None` on success and gives the command back if it could not be enqueued within the timeout
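///
/// A sketch of retrying under backpressure (a sketch only: `pool` stands for an already built
/// `ThreadPool` and `MyCommand` is illustrative):
/// ```ignore
/// use std::time::Duration;
///
/// let mut command: Box<dyn Command + Send + Sync> = Box::new(MyCommand::new());
/// // On timeout the command is handed back, so it can be retried (or dropped).
/// while let Some(rejected) = pool.try_submit(command, Duration::from_millis(100)) {
///     command = rejected;
/// }
/// ```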
pub fn try_submit(&self, command: Box<dyn Command + Send + Sync>, timeout: Duration) -> Option<Box<dyn Command + Send + Sync>> {
self.queue.try_enqueue(command, timeout)
}
}
#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::shutdown_mode::ShutdownMode;
use crate::shutdown_mode::ShutdownMode::CompletePending;
use crate::thread_pool_builder::ThreadPoolBuilder;
use super::*;
struct TestCommand {
_payload: i32,
execution_counter: Arc<AtomicUsize>,
}
impl TestCommand {
pub fn new(payload: i32, execution_counter: Arc<AtomicUsize>) -> TestCommand {
TestCommand {
_payload: payload,
execution_counter,
}
}
}
impl Command for TestCommand {
fn execute(&self) -> Result<(), anyhow::Error> {
self.execution_counter.fetch_add(1, Ordering::Relaxed);
Ok(())
}
}
#[test]
fn test_create() {
let mut thread_pool_builder = ThreadPoolBuilder::new();
let tp_result = thread_pool_builder
.name("t".to_string())
.tasks(4)
.queue_size(8)
.build();
match tp_result {
Ok(mut tp) => {
tp.shutdown();
tp.join().unwrap();
}
Err(e) => {
panic!("failed to build the thread pool: {e:?}");
}
}
}
#[test]
fn test_submit() {
let mut thread_pool_builder = ThreadPoolBuilder::new();
let mut tp = thread_pool_builder
.name("t".to_string())
.tasks(4)
.queue_size(2048)
.build()
.unwrap();
let execution_counter = Arc::new(AtomicUsize::from(0));
for _i in 0..1024 {
let ec = execution_counter.clone();
tp.submit(Box::new(TestCommand::new(4, ec)));
}
tp.shutdown();
tp.join().expect("Failed to join thread pool");
// A second join is a no-op once every worker thread has already been collected.
tp.join().unwrap();
// The assertions below are timing-dependent: depending on how quickly the shutdown signal
// reaches the workers, some, but not necessarily all, of the 1024 commands get executed.
// They are left commented out to keep the test deterministic.
// assert!(execution_counter.fetch_or(0, Ordering::Relaxed) > 0);
// assert!(execution_counter.fetch_or(0, Ordering::Relaxed) < 1024);
}
#[test]
fn test_shutdown_complete_pending() {
let mut thread_pool_builder = ThreadPoolBuilder::new();
let mut tp = thread_pool_builder
.name("t".to_string())
.tasks(4)
.queue_size(2048)
.shutdown_mode(ShutdownMode::CompletePending)
.build()
.unwrap();
let execution_counter = Arc::new(AtomicUsize::from(0));
for _i in 0..1024 {
let ec = execution_counter.clone();
tp.submit(Box::new(TestCommand::new(4, ec)));
}
tp.shutdown();
tp.join().expect("Failed to join thread pool");
// A second join is a no-op once every worker thread has already been collected.
tp.join().unwrap();
assert_eq!(execution_counter.fetch_or(0, Ordering::Relaxed), 1024);
}
struct PanicTestCommand {}
impl PanicTestCommand {
pub fn new() -> PanicTestCommand {
PanicTestCommand {}
}
}
impl Command for PanicTestCommand {
fn execute(&self) -> Result<(), anyhow::Error> {
Err(anyhow!("simulating error during command execution"))
}
}
#[test]
fn test_join_error_handler() {
let mut thread_pool_builder = ThreadPoolBuilder::new();
let mut tp = thread_pool_builder
.name("t".to_string())
.tasks(4)
.shutdown_mode(CompletePending)
.queue_size(8)
.join_error_handler(
|name, message| {
println!("Thread {name} ended with and error {message}")
}
)
.build()
.unwrap();
for _i in 0..2 {
tp.submit(Box::new(PanicTestCommand::new()));
}
tp.shutdown();
let r = tp.join();
assert!(r.is_err());
}
#[test]
#[should_panic]
fn test_use_after_join() {
let mut thread_pool_builder = ThreadPoolBuilder::new();
let mut tp = thread_pool_builder
.name("t".to_string())
.tasks(4)
.queue_size(2048)
.shutdown_mode(ShutdownMode::CompletePending)
.build()
.unwrap();
let execution_counter = Arc::new(AtomicUsize::from(0));
for _i in 0..1024 {
let ec = execution_counter.clone();
tp.submit(Box::new(TestCommand::new(4, ec)));
}
tp.shutdown();
tp.join().expect("Failed to join thread pool");
let execution_counter = Arc::new(AtomicUsize::from(0));
for _i in 0..1024 {
let ec = execution_counter.clone();
tp.submit(Box::new(TestCommand::new(4, ec)));
}
}
}