1use crate::{
2 task::{Task, TaskFn, TaskListeners},
3 worker::Worker,
4 ThreadPoolBuilder,
5};
6
7use crossbeam_channel::{bounded, Receiver, Sender, TrySendError};
8
9use std::{
10 sync::{atomic::AtomicUsize, Arc, Mutex},
11 thread,
12 time::Duration,
13};
14
/// Factory invoked once per worker to obtain the `thread::Builder` used to
/// spawn it (lets callers configure thread names, stack sizes, etc.).
pub type ThreadFactory = dyn Fn() -> thread::Builder + Send + Sync + 'static;
17
/// Crate-internal shorthand for results whose error type is [`TPError`].
type TPResult<T> = Result<T, TPError>;
19
/// Errors returned when a task cannot be scheduled on the pool.
#[derive(Debug)]
pub enum TPError {
    /// The task was rejected under the `Abort` rejection policy.
    Abort,

    /// The pool has been shut down and accepts no new tasks.
    Closed,
}

impl std::error::Error for TPError {}

impl std::fmt::Display for TPError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `write!`, not `writeln!`: `Display` output must not carry a trailing
        // newline — callers (loggers, `{}` interpolation) add their own breaks.
        match self {
            TPError::Abort => write!(f, "task abortion error."),
            // Fixed typo: was "closwd".
            TPError::Closed => write!(f, "the thread pool is closed."),
        }
    }
}
44
/// Policy applied to a task that can be neither queued nor handed to a worker.
///
/// Derives were extended with `Debug`, `Copy`, `PartialEq`, `Eq` — all
/// backward-compatible additions that make the policy loggable and comparable.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RejectedTaskHandler {
    /// Fail the submission with [`TPError::Abort`]; the task is dropped.
    Abort,

    /// Silently drop the task and report success.
    Discard,

    /// Run the task synchronously on the submitting thread.
    CallerRuns,
}
57
/// State shared between all clones of the pool handle and its workers.
///
/// The `Option`s are `take`n (set to `None`) on shutdown/wait; the rest of
/// the code treats `None` as "pool closed / workers already drained".
pub(crate) struct ThreadPoolSharedData {
    // Sending half of the bounded task channel; `None` once shut down.
    pub(crate) sender: Mutex<Option<Sender<Task>>>,
    // Long-lived core workers (up to `core_pool_size`); `None` after `wait`.
    pub(crate) core_workers: Mutex<Option<Vec<Worker>>>,
    // Overflow (non-core) workers; `None` after `wait`.
    pub(crate) workers: Mutex<Option<Vec<Worker>>>,
    // Monotonic source of unique task ids.
    pub(crate) next_task_id: AtomicUsize,
}
64
65impl ThreadPoolSharedData {
66 pub(crate) fn num_of_core_workers(&self) -> usize {
67 self.core_workers
68 .lock()
69 .unwrap()
70 .as_ref()
71 .map_or(0, Vec::len)
72 }
73
74 pub(crate) fn num_of_active_workers(&self) -> usize {
75 self.workers.lock().unwrap().as_ref().map_or(0, |x| {
76 x.iter().filter(|worker| !worker.is_finished()).count()
77 })
78 }
79}
80
/// A handle to the thread pool; cheap to clone — all mutable state lives
/// behind the shared `Arc`s, so clones observe the same pool.
#[derive(Clone)]
pub struct ThreadPool {
    // Receiving half of the task channel, cloned into every worker.
    // (sic: field name is misspelled "reciver"; it is crate-internal API
    // used elsewhere, so it cannot be renamed in isolation.)
    pub(crate) reciver: Receiver<Task>,
    // State shared with workers and with other clones of this handle.
    pub(crate) share: Arc<ThreadPoolSharedData>,

    // New tasks spawn core workers until this many exist.
    pub(crate) core_pool_size: usize,
    // Upper bound on total workers (core + non-core), enforced when the
    // channel is full.
    pub(crate) max_pool_size: usize,
    // Duration handed to each worker — presumably the idle timeout for
    // non-core workers; confirm semantics in `Worker`.
    pub(crate) keep_alive_time: Duration,
    // Policy for tasks that can be neither queued nor assigned.
    pub(crate) rejected_task_handler: RejectedTaskHandler,
    // Before/after-execute hooks attached to every task. (sic: "lisenters")
    pub(crate) task_lisenters: Arc<TaskListeners>,
    // Produces the `thread::Builder` used to spawn each worker thread.
    pub(crate) thread_factory: Arc<ThreadFactory>,
}
144
145impl ThreadPool {
146 pub(crate) fn from_builder(builder: ThreadPoolBuilder) -> Self {
150 let (sender, reciver) = bounded(builder.channel_capacity);
151 Self {
152 reciver,
153 share: Arc::new(ThreadPoolSharedData {
154 sender: Mutex::new(Some(sender)),
155 core_workers: Mutex::new(Some(Vec::default())),
156 workers: Mutex::new(Some(Vec::default())),
157 next_task_id: AtomicUsize::new(0),
158 }),
159 core_pool_size: builder.core_pool_size,
160 max_pool_size: builder.max_pool_size,
161 keep_alive_time: builder.keep_alive_time,
162 rejected_task_handler: builder.rejected_task_handler,
163 task_lisenters: Arc::new(builder.task_lisenters),
164 thread_factory: builder.thread_factory,
165 }
166 }
167
168 pub fn execute<F>(&self, task_fn: F) -> Result<(), TPError>
183 where
184 F: FnOnce() + Send + 'static,
185 {
186 if self.is_closed() {
187 return Err(TPError::Closed);
188 }
189
190 let task = self.create_task(Box::new(task_fn));
191 let mut core_workers = self.share.core_workers.lock().unwrap();
192 if let Some(core_workers) = core_workers.as_mut() {
193 if core_workers.len() < self.core_pool_size {
195 let worker = self.create_worker(task, true);
196 core_workers.push(worker);
197 return Ok(());
198 }
199 }
200 drop(core_workers);
202 self.send_task(task)
203 }
204
205 #[must_use]
207 pub fn active_count(&self) -> usize {
208 self.share.num_of_active_workers() + self.share.num_of_core_workers()
209 }
210
211 pub fn shutdown(&self) {
230 self.share.sender.lock().unwrap().take();
231 }
232
233 #[must_use]
235 pub fn is_closed(&self) -> bool {
236 self.share.sender.lock().unwrap().is_none()
237 }
238
239 pub fn wait(&self) -> std::thread::Result<()> {
276 self.shutdown();
277 Self::wait_workers(self.share.core_workers.lock().unwrap().take())?;
278 Self::wait_workers(self.share.workers.lock().unwrap().take())
279 }
280
281 fn wait_workers(workers: Option<Vec<Worker>>) -> std::thread::Result<()> {
282 if let Some(workers) = workers {
283 for worker in workers {
284 worker.join()?;
285 }
286 }
287 Ok(())
288 }
289
290 fn create_task(&self, task_fn: TaskFn) -> Task {
291 let id = self
292 .share
293 .next_task_id
294 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
295 Task::create(id, task_fn, self.task_lisenters.clone())
296 }
297
298 fn create_worker(&self, task: Task, is_core: bool) -> Worker {
299 Worker::new(
300 is_core,
301 self.keep_alive_time,
302 self.thread_factory.clone(),
303 self.reciver.clone(),
304 task,
305 )
306 }
307
308 fn send_task(&self, task: Task) -> TPResult<()> {
309 let sender = self.share.sender.lock().unwrap();
310 if sender.is_none() {
311 return Err(TPError::Closed);
312 }
313
314 if let Err(err) = sender.as_ref().unwrap().try_send(task) {
315 drop(sender);
317 return match err {
318 TrySendError::Full(task) => self.process_task_if_channel_full(task),
319 TrySendError::Disconnected(_) => Err(TPError::Closed),
320 };
321 }
322 Ok(())
323 }
324
325 fn process_task_if_channel_full(&self, task: Task) -> TPResult<()> {
326 let mut workers = self.share.workers.lock().unwrap();
327 if workers.is_none() {
328 return self.reject(task);
331 }
332
333 let non_core_workers = workers.as_mut().unwrap();
334 let idle_worker = non_core_workers
336 .iter_mut()
337 .find(|worker| worker.is_finished());
338 if let Some(idle_worker) = idle_worker {
339 idle_worker.restart(task);
340 return Ok(());
341 }
342
343 if non_core_workers.len() < self.max_pool_size - self.core_pool_size {
344 let worker = self.create_worker(task, false);
345 non_core_workers.push(worker);
346 Ok(())
347 } else {
348 drop(workers);
350 self.reject(task)
353 }
354 }
355
356 fn reject(&self, task: Task) -> Result<(), TPError> {
357 match &self.rejected_task_handler {
358 RejectedTaskHandler::Abort => Err(TPError::Abort),
359 RejectedTaskHandler::CallerRuns => {
360 task.run();
361 Ok(())
362 }
363 RejectedTaskHandler::Discard => Ok(()),
364 }
365 }
366}
367
#[cfg(test)]
mod tests {

    use crate::{RejectedTaskHandler, ThreadPoolBuilder};
    use std::{
        collections::HashSet,
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc, Mutex,
        },
        thread,
        time::Duration,
    };

    // Submits 100 counting tasks from 10 submitter threads, then checks the
    // worker accounting and the final counter after draining the pool.
    #[test]
    fn test_execute_in_multiple_threads() {
        let thread_pool = ThreadPoolBuilder::default()
            .core_pool_size(4)
            .max_pool_size(10)
            .channel_capacity(100)
            .keep_alive_time(Duration::from_secs(100))
            .build();

        let sum = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::new();
        for _ in 0..10 {
            let sum = sum.clone();
            let thread_pool = thread_pool.clone();
            handles.push(thread::spawn(move || {
                for _ in 0..10 {
                    let sum = sum.clone();
                    thread_pool
                        .execute(move || {
                            sum.fetch_add(1, Ordering::SeqCst);
                        })
                        .ok();
                }
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        // Pool is still open; exactly core_pool_size core workers exist and
        // at most max_pool_size - core_pool_size (6) non-core workers are live.
        assert!(thread_pool.share.sender.lock().unwrap().is_some());
        assert_eq!(4, thread_pool.share.num_of_core_workers());
        assert!(thread_pool.share.num_of_active_workers() <= 6);

        thread_pool.wait().unwrap();
        assert_eq!(100, sum.load(Ordering::Relaxed));
    }

    // 100 threads race to shut the pool down; each thread that already sees
    // it closed bumps the counter before (redundantly) shutting down again.
    // NOTE(review): the expected value 99 assumes exactly one thread observes
    // the pool open before any shutdown lands — several threads could pass
    // `is_closed()` before the first `shutdown()`; confirm this cannot flake.
    #[test]
    fn test_shutdown_in_multiple_threads() {
        let thread_pool = ThreadPoolBuilder::default().build();
        let counter = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::new();
        for _ in 0..100 {
            let thread_pool = thread_pool.clone();
            let counter = counter.clone();
            handles.push(thread::spawn(move || {
                if thread_pool.is_closed() {
                    counter.fetch_add(1, Ordering::SeqCst);
                    assert!(thread_pool.execute(|| ()).is_err());
                }
                thread_pool.shutdown();
                assert!(thread_pool.execute(|| ()).is_err());
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(99, counter.load(Ordering::Relaxed));
    }

    // Checks that the before-execute hook fires (and records the id) before
    // the after-execute hook for every one of the 50 tasks.
    #[test]
    fn test_lisenters() {
        let map0 = Arc::new(Mutex::new(HashSet::new()));
        let map1 = map0.clone();
        let map2 = map0.clone();

        let thread_pool = ThreadPoolBuilder::default()
            .lisenter_before_execute(move |id| {
                let mut map = map0.lock().unwrap();
                map.insert(id);
            })
            .lisenter_after_execute(move |id| {
                // The before-execute hook must already have seen this id.
                assert!(map1.lock().unwrap().contains(&id));
            })
            .channel_capacity(50)
            .build();

        for _ in 0..50 {
            thread_pool
                .execute(|| {
                    thread::sleep(Duration::from_millis(20));
                })
                .unwrap();
        }
        thread_pool.shutdown();
        thread_pool.wait().unwrap();
        // Every task id was recorded exactly once.
        assert_eq!(50, map2.lock().unwrap().len());
    }

    // Verifies the custom thread factory names every spawned worker thread.
    // NOTE(review): the exact worker counts (2 core, 3 non-core) depend on
    // the 20ms task duration keeping workers busy while 20 tasks overflow a
    // capacity-5 channel — this looks timing-sensitive; confirm stability.
    #[test]
    fn test_thread_factory() {
        let thread_pool = ThreadPoolBuilder::new()
            .thread_factory_fn(|| thread::Builder::new().name("test".into()))
            .core_pool_size(2)
            .max_pool_size(5)
            .channel_capacity(5)
            .rejected_handler(RejectedTaskHandler::Discard)
            .build();

        for _ in 0..20 {
            thread_pool
                .execute(|| thread::sleep(Duration::from_millis(20)))
                .unwrap();
        }

        let workers = thread_pool.share.core_workers.lock().unwrap();
        assert!(workers.as_ref().unwrap().len() == 2);
        for core_worker in workers.as_ref().unwrap() {
            assert_eq!(Some("test"), core_worker.handle.thread().name());
        }

        let workers = thread_pool.share.workers.lock().unwrap();
        assert!(workers.as_ref().unwrap().len() == 3);
        for worker in workers.as_ref().unwrap() {
            assert_eq!(Some("test"), worker.handle.thread().name());
        }
    }
}