// nio_threadpool/lib.rs

1use mpmc_channel::MPMC;
2use std::{collections::VecDeque, io, num::NonZero, sync::Arc, thread, time::Duration};
3
4type Channel<Task> = Arc<MPMC<Queue<Task>>>;
5
/// State shared between the pool and its workers, guarded by the channel.
struct Queue<T> {
    /// Tasks waiting for a worker; pushed at the back and popped from the front (FIFO).
    tasks: VecDeque<T>,
    /// Outstanding tasks: incremented when a task is queued and decremented after a worker
    /// has run it, so it covers both queued and in-flight work.
    count: usize,
}
10
11pub struct ThreadPool<Task: Runnable> {
12    channel: Channel<Task>,
13
14    // ----- config -----
15    max_threads_limit: u16,
16    timeout: Option<Duration>,
17    stack_size: Option<NonZero<usize>>,
18    load_factor: usize,
19    name: Box<dyn Fn(usize) -> String + Send + Sync>,
20}
21
22impl<T: Runnable> Default for ThreadPool<T> {
23    fn default() -> Self {
24        Self {
25            channel: Arc::new(MPMC::new(Queue {
26                tasks: VecDeque::with_capacity(64),
27                count: 0,
28            })),
29
30            timeout: Some(Duration::from_secs(7)),
31            max_threads_limit: 32,
32            load_factor: 2,
33            stack_size: None,
34            name: Box::new(|id| format!("Worker: {id}")),
35        }
36    }
37}
38
/// A unit of work that a [`ThreadPool`] can execute.
///
/// `run` takes `self` by value, so a task is consumed when it runs and cannot run twice.
/// The `Send + 'static` bounds let a task be moved onto a worker thread.
pub trait Runnable: Send + 'static {
    /// Executes the task, consuming it.
    fn run(self);
}
42
43impl<Task: Runnable> ThreadPool<Task> {
44    pub fn new() -> Self {
45        Self::default()
46    }
47
48    pub fn timeout(mut self, timeout: Option<Duration>) -> Self {
49        self.timeout = timeout;
50        self
51    }
52
53    pub fn max_threads_limit(mut self, limit: u16) -> Self {
54        assert!(limit > 0, "max threads limit must be greater than 0");
55        self.max_threads_limit = limit;
56        self
57    }
58
59    pub fn stack_size(mut self, stack_size: usize) -> Self {
60        self.stack_size = NonZero::new(stack_size);
61        self
62    }
63
64    pub fn load_factor(mut self, factor: usize) -> Self {
65        assert!(factor != 0, "threadpool load factor must be > 0");
66        self.load_factor = factor;
67        self
68    }
69
70    pub fn name<F>(mut self, f: F) -> Self
71    where
72        F: Fn(usize) -> String + 'static + Send + Sync,
73    {
74        self.name = Box::new(f);
75        self
76    }
77
78    // ---------------  Getter  ------------------
79
80    pub fn get_timeout(&self) -> Option<Duration> {
81        self.timeout
82    }
83
84    pub fn get_stack_size(&self) -> Option<usize> {
85        self.stack_size.map(|size| size.get())
86    }
87
88    pub fn get_max_threads_limit(&self) -> u16 {
89        self.max_threads_limit
90    }
91
92    // -------------------------------------------
93
94    pub fn thread_count(&self) -> usize {
95        Arc::strong_count(&self.channel) - 1
96    }
97
98    pub fn is_thread_limit_reached(&self) -> bool {
99        self.thread_count() > self.get_max_threads_limit().into()
100    }
101
102    pub fn add_task_to_queue(&self, task: Task) -> usize {
103        let mut tx = self.channel.produce();
104        tx.tasks.push_back(task);
105        tx.count += 1;
106        let task_count = tx.count;
107        tx.notify_one();
108        task_count
109    }
110
111    pub fn execute(&self, task: Task) {
112        let task_count = self.add_task_to_queue(task);
113
114        let thread_count = self.thread_count();
115        if thread_count > self.get_max_threads_limit().into() {
116            return;
117        }
118
119        let threshold = thread_count * self.load_factor;
120        if task_count <= threshold {
121            return;
122        }
123
124        let b = self.thread_builder();
125        self.spawn(b)
126            .expect("failed to spawn a thread in thread pool");
127    }
128
129    pub fn spawn(&self, thread_builder: thread::Builder) -> io::Result<thread::JoinHandle<()>> {
130        let timeout = self.timeout;
131        let channel = self.channel.clone();
132
133        let worker = move || {
134            let mut rx = channel.consume();
135            loop {
136                rx = match rx.tasks.pop_front() {
137                    Some(task) => {
138                        drop(rx);
139                        task.run();
140                        
141                        let mut rx = channel.consume();
142                        rx.count -= 1;
143                        rx
144                    }
145                    None => match timeout {
146                        None => rx.wait(),
147                        Some(dur) => match rx.wait_timeout(dur) {
148                            Ok(rx) => rx,
149                            Err(_) => break,
150                        },
151                    },
152                }
153            }
154        };
155        thread_builder.spawn(worker)
156    }
157
158    pub fn thread_builder(&self) -> thread::Builder {
159        let mut thread = thread::Builder::new();
160        if let Some(size) = self.get_stack_size() {
161            thread = thread.stack_size(size);
162        }
163        let name = (self.name)(self.thread_count());
164        if !name.is_empty() {
165            thread = thread.name(name);
166        }
167        thread
168    }
169}