1#![cfg_attr(test, recursion_limit = "256")]
36#![deny(
37 clippy::unwrap_used,
38 missing_docs,
39 missing_debug_implementations,
40 missing_copy_implementations,
41 trivial_casts,
42 trivial_numeric_casts,
43 unstable_features,
44 unused_import_braces,
45 unused_qualifications
46)]
47
48mod config;
49mod handler;
50mod state;
51mod worker;
52
53pub use crate::config::ThreadPoolConfig;
54#[cfg(feature = "async")]
55pub use crate::handler::ThreadPoolAsyncHandler;
56pub use crate::handler::{JoinHandle, ThreadPoolSyncHandler};
57
58use crate::state::State;
59use crate::worker::{MsgForWorker, Worker};
60use flume::{Receiver as FlumeReceiver, RecvTimeoutError, Sender as FlumeSender};
61use std::{
62 num::NonZeroU16,
63 sync::{
64 atomic::{AtomicU32, Ordering},
65 Arc,
66 },
67 time::Duration,
68};
69
/// Error value reporting that the thread pool is disconnected.
///
/// This is a field-less marker type: it carries no data, and its
/// `Display` output is always the fixed message `Thread pool disconnected`.
#[derive(Debug, Clone, Copy)]
pub struct ThreadPoolDisconnected;

impl std::error::Error for ThreadPoolDisconnected {}

impl std::fmt::Display for ThreadPoolDisconnected {
    /// Writes the fixed, human-readable error message to `formatter`.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The message has no arguments, so it is written directly as a string.
        formatter.write_str("Thread pool disconnected")
    }
}
80
81#[derive(Debug)]
82pub struct ThreadPool<Shared: 'static + Clone + Send> {
85 sender: FlumeSender<MsgForWorker<Shared>>,
86}
87
88impl<Shared: 'static + Clone + Send> ThreadPool<Shared> {
89 pub fn start(config: ThreadPoolConfig, shared: Shared) -> Self {
92 let state = State::new(config);
93 let (sender, receiver) = if let Some(queue_size) = config.queue_size {
94 flume::bounded(queue_size)
95 } else {
96 flume::unbounded()
97 };
98
99 for _ in 0..config.min_workers.get() {
100 let worker = Worker::new(
101 config.keep_alive,
102 receiver.clone(),
103 sender.clone(),
104 shared.clone(),
105 state.clone(),
106 );
107 std::thread::spawn(move || worker.run());
108 }
109
110 ThreadPool { sender }
111 }
112 #[cfg(feature = "async")]
113 pub fn async_handler(&self) -> ThreadPoolAsyncHandler<Shared> {
115 ThreadPoolAsyncHandler::new(self.sender.clone())
116 }
117 pub fn sync_handler(&self) -> ThreadPoolSyncHandler<Shared> {
119 ThreadPoolSyncHandler::new(self.sender.clone())
120 }
121 #[cfg(feature = "async")]
122 pub fn into_async_handler(self) -> ThreadPoolAsyncHandler<Shared> {
124 ThreadPoolAsyncHandler::new(self.sender)
125 }
126 pub fn into_sync_handler(self) -> ThreadPoolSyncHandler<Shared> {
128 ThreadPoolSyncHandler::new(self.sender)
129 }
130}
131
/// Boxed one-shot closure that is called with a reference to the `Shared` value.
///
/// The closure must be `Send` and `'static` so the box can be moved to another thread.
type Job<Shared> = Box<dyn FnOnce(&Shared) + 'static + Send>;
133
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// Runs jobs through the synchronous handler.
    ///
    /// Checks that `execute` returns the closure's result and that four one-second jobs
    /// started with `launch` run concurrently (all joined in under two seconds).
    #[test]
    fn test_sync() -> Result<(), ThreadPoolDisconnected> {
        let tp = ThreadPool::start(ThreadPoolConfig::default(), ());

        let tp_handler = tp.into_sync_handler();

        assert_eq!(4, tp_handler.execute(|_| { 2 + 2 })?);

        let start = Instant::now();

        // Launch all four jobs before joining any of them so they can overlap.
        let r1 = tp_handler.launch(|_| std::thread::sleep(Duration::from_secs(1)))?;
        let r2 = tp_handler.launch(|_| std::thread::sleep(Duration::from_secs(1)))?;
        let r3 = tp_handler.launch(|_| std::thread::sleep(Duration::from_secs(1)))?;
        let r4 = tp_handler.launch(|_| std::thread::sleep(Duration::from_secs(1)))?;

        r1.join().expect("ThreadPool disconnected");
        r2.join().expect("ThreadPool disconnected");
        r3.join().expect("ThreadPool disconnected");
        r4.join().expect("ThreadPool disconnected");

        let elapsed = start.elapsed();

        // Run one after another, the jobs would need at least four seconds.
        assert!(elapsed.as_secs() < 2);

        Ok(())
    }

    /// Runs jobs through the asynchronous handler.
    ///
    /// Checks that `execute` resolves to the closure's result and that ten one-second jobs
    /// awaited together finish in under three seconds.
    #[cfg(feature = "async")]
    #[test]
    fn test_async() -> Result<(), ThreadPoolDisconnected> {
        futures::executor::block_on(async {
            // `NonZeroU16::new` validates the value, so no `unsafe` constructor is needed.
            let four = NonZeroU16::new(4).expect("4 is non-zero");

            let shared: i32 = 42;
            let conf = ThreadPoolConfig::default()
                .min_workers(four)
                .max_available_workers(four);
            println!("conf={:?}", conf);
            let tp = ThreadPool::start(conf, shared);

            let tp_handler = tp.into_async_handler();

            assert_eq!(4u32, tp_handler.execute(|_| { 2 + 2 }).await?);

            let start = Instant::now();

            use futures::join;
            let (res1, res2, res3, res4, res5, res6, res7, res8, res9, res10) = join!(
                tp_handler.execute(|_| { std::thread::sleep(Duration::from_secs(1)) }),
                tp_handler.execute(|_| { std::thread::sleep(Duration::from_secs(1)) }),
                tp_handler.execute(|_| { std::thread::sleep(Duration::from_secs(1)) }),
                tp_handler.execute(|_| { std::thread::sleep(Duration::from_secs(1)) }),
                tp_handler.execute(|_| { std::thread::sleep(Duration::from_secs(1)) }),
                tp_handler.execute(|_| { std::thread::sleep(Duration::from_secs(1)) }),
                tp_handler.execute(|_| { std::thread::sleep(Duration::from_secs(1)) }),
                tp_handler.execute(|_| { std::thread::sleep(Duration::from_secs(1)) }),
                tp_handler.execute(|_| { std::thread::sleep(Duration::from_secs(1)) }),
                tp_handler.execute(|_| { std::thread::sleep(Duration::from_secs(1)) }),
            );

            // Propagate the first failure, if any, in submission order.
            res1?;
            res2?;
            res3?;
            res4?;
            res5?;
            res6?;
            res7?;
            res8?;
            res9?;
            res10?;

            let elapsed = start.elapsed();

            // NOTE(review): with only four workers, ten one-second jobs would need about
            // three seconds, so this bound presumably relies on the pool adding workers
            // under load - confirm against the worker implementation.
            assert!(elapsed.as_secs() < 3);

            Ok(())
        })
    }
}