1use mpmc_channel::MPMC;
2use std::{collections::VecDeque, io, num::NonZero, sync::Arc, thread, time::Duration};
3
4type Channel<Task> = Arc<MPMC<Queue<Task>>>;
5
/// State shared between the pool and its workers behind the channel.
struct Queue<T> {
    /// Tasks waiting for a worker; pushed at the back, popped from the front (FIFO).
    tasks: VecDeque<T>,
    /// Number of tasks that are either still queued or currently running.
    count: usize,
}
10
11pub struct ThreadPool<Task: Runnable> {
12 channel: Channel<Task>,
13
14 max_threads_limit: u16,
16 timeout: Option<Duration>,
17 stack_size: Option<NonZero<usize>>,
18 load_factor: usize,
19 name: Box<dyn Fn(usize) -> String + Send + Sync>,
20}
21
22impl<T: Runnable> Default for ThreadPool<T> {
23 fn default() -> Self {
24 Self {
25 channel: Arc::new(MPMC::new(Queue {
26 tasks: VecDeque::with_capacity(64),
27 count: 0,
28 })),
29
30 timeout: Some(Duration::from_secs(7)),
31 max_threads_limit: 32,
32 load_factor: 2,
33 stack_size: None,
34 name: Box::new(|id| format!("Worker: {id}")),
35 }
36 }
37}
38
/// A unit of work that a pool worker can execute.
///
/// Tasks are moved onto worker threads, hence the `'static + Send` bounds.
pub trait Runnable: 'static + Send {
    /// Consumes the task and runs it on the calling thread.
    fn run(self);
}
42
43impl<Task: Runnable> ThreadPool<Task> {
44 pub fn new() -> Self {
45 Self::default()
46 }
47
48 pub fn timeout(mut self, timeout: Option<Duration>) -> Self {
49 self.timeout = timeout;
50 self
51 }
52
53 pub fn max_threads_limit(mut self, limit: u16) -> Self {
54 assert!(limit > 0, "max threads limit must be greater than 0");
55 self.max_threads_limit = limit;
56 self
57 }
58
59 pub fn stack_size(mut self, stack_size: usize) -> Self {
60 self.stack_size = NonZero::new(stack_size);
61 self
62 }
63
64 pub fn load_factor(mut self, factor: usize) -> Self {
65 assert!(factor != 0, "threadpool load factor must be > 0");
66 self.load_factor = factor;
67 self
68 }
69
70 pub fn name<F>(mut self, f: F) -> Self
71 where
72 F: Fn(usize) -> String + 'static + Send + Sync,
73 {
74 self.name = Box::new(f);
75 self
76 }
77
78 pub fn get_timeout(&self) -> Option<Duration> {
81 self.timeout
82 }
83
84 pub fn get_stack_size(&self) -> Option<usize> {
85 self.stack_size.map(|size| size.get())
86 }
87
88 pub fn get_max_threads_limit(&self) -> u16 {
89 self.max_threads_limit
90 }
91
92 pub fn thread_count(&self) -> usize {
95 Arc::strong_count(&self.channel) - 1
96 }
97
98 pub fn is_thread_limit_reached(&self) -> bool {
99 self.thread_count() > self.get_max_threads_limit().into()
100 }
101
102 pub fn add_task_to_queue(&self, task: Task) -> usize {
103 let mut tx = self.channel.produce();
104 tx.tasks.push_back(task);
105 tx.count += 1;
106 let task_count = tx.count;
107 tx.notify_one();
108 task_count
109 }
110
111 pub fn execute(&self, task: Task) {
112 let task_count = self.add_task_to_queue(task);
113
114 let thread_count = self.thread_count();
115 if thread_count > self.get_max_threads_limit().into() {
116 return;
117 }
118
119 let threshold = thread_count * self.load_factor;
120 if task_count <= threshold {
121 return;
122 }
123
124 let b = self.thread_builder();
125 self.spawn(b)
126 .expect("failed to spawn a thread in thread pool");
127 }
128
129 pub fn spawn(&self, thread_builder: thread::Builder) -> io::Result<thread::JoinHandle<()>> {
130 let timeout = self.timeout;
131 let channel = self.channel.clone();
132
133 let worker = move || {
134 let mut rx = channel.consume();
135 loop {
136 rx = match rx.tasks.pop_front() {
137 Some(task) => {
138 drop(rx);
139 task.run();
140
141 let mut rx = channel.consume();
142 rx.count -= 1;
143 rx
144 }
145 None => match timeout {
146 None => rx.wait(),
147 Some(dur) => match rx.wait_timeout(dur) {
148 Ok(rx) => rx,
149 Err(_) => break,
150 },
151 },
152 }
153 }
154 };
155 thread_builder.spawn(worker)
156 }
157
158 pub fn thread_builder(&self) -> thread::Builder {
159 let mut thread = thread::Builder::new();
160 if let Some(size) = self.get_stack_size() {
161 thread = thread.stack_size(size);
162 }
163 let name = (self.name)(self.thread_count());
164 if !name.is_empty() {
165 thread = thread.name(name);
166 }
167 thread
168 }
169}