1use std::fmt;
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::sync::Arc;
4use std::thread::{spawn, JoinHandle};
5use std::time::Duration;
6
7use crossbeam_channel::{bounded, select, select_biased, tick};
8use genzero;
9use quanta::Clock;
10
11use super::channel::*;
12use super::err::ThreadPoolError;
13
/// Control message: re-read the shared desired thread count and grow or
/// shrink the pool to match it (handled by the primary thread).
const UPDATE_SIZE: u8 = 0x00;
/// Control message: ask the receiving thread to exit its processing loop.
const STOP_THREAD: u8 = 0x01;
16
/// Snapshot of a thread pool's state, published periodically by the pool's
/// primary thread.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct Metrics {
    /// Threads currently in the pool, including the primary thread.
    pub active_threads: usize,

    /// Items waiting in the input channel when the snapshot was taken.
    pub input_channel_len: usize,
    /// Capacity reported by the input channel, if it has one.
    pub input_channel_capacity: Option<usize>,
    /// Results waiting in the output channel (always `0` for sink pools).
    pub output_channel_len: usize,
    /// Capacity reported by the output channel (always `None` for sink pools).
    pub output_channel_capacity: Option<usize>,

    /// Executions recorded since the previous snapshot.
    pub execution_count: usize,
    /// Mean duration of those executions in nanoseconds (`0` when none ran).
    pub average_execution_duration_ns: usize,
}
36
37struct ExecutionMetrics {
38 clock: Clock,
39 execution_counter: AtomicUsize,
40 total_execution_time_ns: AtomicUsize,
41 }
44
45impl ExecutionMetrics {
46 fn new() -> Self {
47 ExecutionMetrics {
48 clock: Clock::new(),
49 execution_counter: AtomicUsize::new(0),
50 total_execution_time_ns: AtomicUsize::new(0),
51 }
52 }
53
54 fn update(&self, execution_time: usize) {
55 self.execution_counter.fetch_add(1, Ordering::Relaxed);
56 self.total_execution_time_ns
57 .fetch_add(execution_time, Ordering::Relaxed);
58 }
59
60 fn get_and_reset_execution_count(&self) -> usize {
61 self.execution_counter.fetch_and(0, Ordering::Relaxed)
62 }
63
64 fn get_and_reset_total_execution_time_ns(&self) -> usize {
65 self.total_execution_time_ns.fetch_and(0, Ordering::Relaxed)
66 }
67}
68
69#[derive(Clone)]
81pub struct ThreadPool {
82 desired_threads: Arc<AtomicUsize>,
83 control_tx: crossbeam_channel::Sender<u8>,
84 metrics_rx: genzero::Receiver<Metrics>,
85}
86
87impl ThreadPool {
88 pub(super) fn new_lambda_pool<
89 T: Send + 'static,
90 U: Send + 'static,
91 V: Clone + Send + 'static,
92 >(
93 input_channel: Receiver<T>,
94 output_channel: Sender<U>,
95 shared_resource: V,
96 function: fn(&V, T) -> U,
97 ) -> Self {
98 let desired_threads = Arc::new(AtomicUsize::new(1));
99 let (control_tx, metrics_rx) = spawn_primary_lambda_thread(
100 input_channel,
101 output_channel,
102 shared_resource,
103 function,
104 desired_threads.clone(),
105 );
106
107 Self {
108 desired_threads,
109 control_tx,
110 metrics_rx,
111 }
112 }
113
114 pub(super) fn new_sink_pool<T: Send + 'static, V: Clone + Send + 'static>(
115 input_channel: Receiver<T>,
116 shared_resource: V,
117 function: fn(&V, T),
118 ) -> Self {
119 let desired_threads = Arc::new(AtomicUsize::new(1));
120 let (control_tx, metrics_rx) = spawn_primary_sink_thread(
121 input_channel,
122 shared_resource,
123 function,
124 desired_threads.clone(),
125 );
126
127 Self {
128 desired_threads,
129 control_tx,
130 metrics_rx,
131 }
132 }
133
134 pub fn get_pool_size(&self) -> usize {
137 self.desired_threads.load(Ordering::Acquire)
138 }
139
140 pub fn set_pool_size(&self, n: usize) -> Result<usize, ThreadPoolError> {
145 if n < 1 {
146 return Err(ThreadPoolError::ValueError);
147 }
148 self.desired_threads.store(n, Ordering::Relaxed);
149
150 match self.control_tx.send(UPDATE_SIZE) {
151 Ok(_) => Ok(n),
152 Err(_) => Err(ThreadPoolError::ThreadsLost),
153 }
154 }
155
156 pub fn get_metrics(&self) -> Metrics {
159 self.metrics_rx.recv().unwrap()
160 }
161}
162
163impl fmt::Display for ThreadPool {
164 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165 "Lambda Channel Thread Pool".fmt(f)
166 }
167}
168
169fn spawn_primary_lambda_thread<T: Send + 'static, U: Send + 'static, V: Clone + Send + 'static>(
170 input_channel: Receiver<T>,
171 output_channel: Sender<U>,
172 shared_resource: V,
173 function: fn(&V, T) -> U,
174 desired_threads: Arc<AtomicUsize>,
175) -> (crossbeam_channel::Sender<u8>, genzero::Receiver<Metrics>) {
176 let (control_tx, control_rx) = bounded(0);
177 let (mut metrics_tx, metrics_rx) = genzero::new(Metrics::default());
178
179 spawn(move || {
180 let mut threads = Vec::new();
181 let ticker = tick(Duration::from_secs(10));
182 let execution_metrics = Arc::new(ExecutionMetrics::new());
183 let input_channel_capacity = input_channel.capacity();
184 let output_channel_capacity = output_channel.capacity();
185
186 'main: loop {
187 select_biased! {
188 recv(output_channel.liveness_check) -> _ => {
189 break 'main;
190 },
191 recv(control_rx) -> c => {
192 let command = match c {
193 Ok(v) => v,
194 Err(_) => {
195 break 'main;
196 }
197 };
198
199 match command {
200 UPDATE_SIZE => {
201 let target = desired_threads.load(Ordering::Relaxed);
202 let current = threads.len() + 1;
203
204 if current < target {
205 for _ in 0..target-current {
206 threads.push(spawn_worker_lambda_thread(
207 input_channel.clone(),
208 output_channel.clone(),
209 shared_resource.clone(),
210 function,
211 execution_metrics.clone(),
212 ));
213 }
214 } else {
215 for _ in 0..current-target {
216 let (control_tx, _) = threads.pop().unwrap();
217 let _ = control_tx.send(STOP_THREAD);
218 }
219 }
220 },
221 STOP_THREAD => {
222 break 'main;
223 },
224 _ => {}
225 }
226 },
227 recv(ticker) -> _ => {
228 let thread_count = threads.len();
229 threads.retain(|thread| {
230 let delete = thread.1.is_finished();
231 !delete
232 });
233 let failed_threads = thread_count - threads.len();
234
235 for _ in 0..failed_threads {
236 threads.push(spawn_worker_lambda_thread(
237 input_channel.clone(),
238 output_channel.clone(),
239 shared_resource.clone(),
240 function,
241 execution_metrics.clone(),
242 ));
243 }
244
245 let execution_count = execution_metrics.get_and_reset_execution_count();
246 let average_execution_duration_ns = match execution_count {
247 0 => 0,
248 _ => execution_metrics.get_and_reset_total_execution_time_ns() / execution_count,
249 };
250
251 metrics_tx.send(Metrics{
252 active_threads: threads.len() + 1,
253 input_channel_len: input_channel.len(),
254 input_channel_capacity,
255 output_channel_len: output_channel.len(),
256 output_channel_capacity,
257
258 execution_count,
259 average_execution_duration_ns,
260 });
263 },
264 recv(input_channel.receiver) -> msg => {
265 let input = match msg {
266 Ok(v) => v,
267 Err(_) => {
268 break 'main;
269 }
270 };
271
272 let start_time = execution_metrics.clock.now();
273 let output = function(&shared_resource, input);
274 let execution_time = start_time.elapsed().as_nanos() as usize;
275
276 'inner: loop {
277 select! {
278 recv(control_rx) -> c => {
279 let command = match c {
280 Ok(v) => v,
281 Err(_) => {
282 break 'main;
283 }
284 };
285
286 match command {
287 UPDATE_SIZE => {
288 let target = desired_threads.load(Ordering::Relaxed);
289 let current = threads.len() + 1;
290
291 if current < target {
292 for _ in 0..target-current {
293 threads.push(spawn_worker_lambda_thread(
294 input_channel.clone(),
295 output_channel.clone(),
296 shared_resource.clone(),
297 function,
298 execution_metrics.clone(),
299 ));
300 }
301 } else {
302 for _ in 0..current-target {
303 let (control_tx, _) = threads.pop().unwrap();
304 let _ = control_tx.send(STOP_THREAD);
305 }
306 }
307 },
308 STOP_THREAD => {
309 break 'main;
310 },
311 _ => {}
312 }
313 },
314 recv(ticker) -> _ => {
315 let thread_count = threads.len();
316 threads.retain(|thread| {
317 let delete = thread.1.is_finished();
318 !delete
319 });
320 let failed_threads = thread_count - threads.len();
321
322 for _ in 0..failed_threads {
323 threads.push(spawn_worker_lambda_thread(
324 input_channel.clone(),
325 output_channel.clone(),
326 shared_resource.clone(),
327 function,
328 execution_metrics.clone(),
329 ));
330 }
331
332 let execution_count = execution_metrics.get_and_reset_execution_count();
333 let average_execution_duration_ns = match execution_count {
334 0 => 0,
335 _ => execution_metrics.get_and_reset_total_execution_time_ns() / execution_count,
336 };
337 metrics_tx.send(Metrics{
338 active_threads: threads.len() + 1,
339 input_channel_len: input_channel.len(),
340 input_channel_capacity,
341 output_channel_len: output_channel.len(),
342 output_channel_capacity,
343
344 execution_count,
345 average_execution_duration_ns,
346 });
349 },
350 send(output_channel.sender, output) -> result => {
351 match result {
352 Ok(_) => {
353 execution_metrics.update(execution_time);
354 break 'inner;
355 }
356 Err(_) => {
357 break 'main;
358 }
359 }
360 }
361 }
362 }
363 }
364 }
365 }
366 });
367
368 (control_tx, metrics_rx)
369}
370
371fn spawn_worker_lambda_thread<T: Send + 'static, U: Send + 'static, V: Clone + Send + 'static>(
372 input_channel: Receiver<T>,
373 output_channel: Sender<U>,
374 shared_resource: V,
375 function: fn(&V, T) -> U,
376 execution_metrics: Arc<ExecutionMetrics>,
377) -> (crossbeam_channel::Sender<u8>, JoinHandle<()>) {
378 let (control_tx, control_rx) = bounded(0);
379
380 let handle = spawn(move || 'main: loop {
381 select_biased! {
382 recv(output_channel.liveness_check) -> _ => {
383 break 'main;
384 },
385 recv(control_rx) -> c => {
386 let command = match c {
387 Ok(v) => v,
388 Err(_) => {
389 break 'main;
390 }
391 };
392
393 if command == STOP_THREAD {
394 break 'main;
395 }
396 },
397 recv(input_channel.receiver) -> msg => {
398 let input = match msg {
399 Ok(v) => v,
400 Err(_) => {
401 break 'main;
402 }
403 };
404
405 let start_time = execution_metrics.clock.now();
406 let output = function(&shared_resource, input);
407 let execution_time = start_time.elapsed().as_nanos() as usize;
408
409 'inner: loop {
410 select! {
411 recv(control_rx) -> c => {
412 let command = match c {
413 Ok(v) => v,
414 Err(_) => {
415 drop(input_channel);
416 let _ = output_channel.send(output);
417 break 'main;
418 }
419 };
420
421 if command == STOP_THREAD {
422 drop(input_channel);
423 let _ = output_channel.send(output);
424 break 'main;
425 }
426 },
427 send(output_channel.sender, output) -> result => {
428 match result {
429 Ok(_) => {
430 execution_metrics.update(execution_time);
431 break 'inner;
432 }
433 Err(_) => {
434 break 'main;
435 }
436 }
437 }
438 }
439 }
440 }
441 }
442 });
443 (control_tx, handle)
444}
445
446fn spawn_primary_sink_thread<T: Send + 'static, V: Clone + Send + 'static>(
447 input_channel: Receiver<T>,
448 shared_resource: V,
449 function: fn(&V, T),
450 desired_threads: Arc<AtomicUsize>,
451) -> (crossbeam_channel::Sender<u8>, genzero::Receiver<Metrics>) {
452 let (control_tx, control_rx) = bounded(0);
453 let (mut metrics_tx, metrics_rx) = genzero::new(Metrics::default());
454
455 spawn(move || {
456 let mut threads = Vec::new();
457 let ticker = tick(Duration::from_secs(10));
458 let execution_metrics = Arc::new(ExecutionMetrics::new());
459 let input_channel_capacity = input_channel.capacity();
460 let output_channel_capacity = None;
461
462 'main: loop {
463 select_biased! {
464 recv(control_rx) -> c => {
465 let command = match c {
466 Ok(v) => v,
467 Err(_) => {
468 break 'main;
469 }
470 };
471
472 match command {
473 UPDATE_SIZE => {
474 let target = desired_threads.load(Ordering::Relaxed);
475 let current = threads.len() + 1;
476
477 if current < target {
478 for _ in 0..target-current {
479 threads.push(spawn_worker_sink_thread(
480 input_channel.clone(),
481 shared_resource.clone(),
482 function,
483 execution_metrics.clone(),
484 ));
485 }
486 } else {
487 for _ in 0..current-target {
488 let (control_tx, _) = threads.pop().unwrap();
489 let _ = control_tx.send(STOP_THREAD);
490 }
491 }
492 },
493 STOP_THREAD => {
494 break 'main;
495 },
496 _ => {}
497 }
498 },
499 recv(ticker) -> _ => {
500 let thread_count = threads.len();
501 threads.retain(|thread| {
502 let delete = thread.1.is_finished();
503 !delete
504 });
505 let failed_threads = thread_count - threads.len();
506
507 for _ in 0..failed_threads {
508 threads.push(spawn_worker_sink_thread(
509 input_channel.clone(),
510 shared_resource.clone(),
511 function,
512 execution_metrics.clone(),
513 ));
514 }
515
516 let execution_count = execution_metrics.get_and_reset_execution_count();
517 let average_execution_duration_ns = match execution_count {
518 0 => 0,
519 _ => execution_metrics.get_and_reset_total_execution_time_ns() / execution_count,
520 };
521 metrics_tx.send(Metrics{
522 active_threads: threads.len() + 1,
523 input_channel_len: input_channel.len(),
524 input_channel_capacity,
525 output_channel_len: 0,
526 output_channel_capacity,
527
528 execution_count,
529 average_execution_duration_ns,
530 });
533 },
534 recv(input_channel.receiver) -> msg => {
535 let input = match msg {
536 Ok(v) => v,
537 Err(_) => {
538 break 'main;
539 }
540 };
541
542 let start_time = execution_metrics.clock.now();
543 function(&shared_resource, input);
544 let execution_time = start_time.elapsed().as_nanos() as usize;
545 execution_metrics.update(execution_time);
546 }
547 }
548 }
549 });
550
551 (control_tx, metrics_rx)
552}
553
554fn spawn_worker_sink_thread<T: Send + 'static, V: Clone + Send + 'static>(
555 input_channel: Receiver<T>,
556 shared_resource: V,
557 function: fn(&V, T),
558 execution_metrics: Arc<ExecutionMetrics>,
559) -> (crossbeam_channel::Sender<u8>, JoinHandle<()>) {
560 let (control_tx, control_rx) = bounded(0);
561
562 let handle = spawn(move || 'main: loop {
563 select_biased! {
564 recv(control_rx) -> c => {
565 let command = match c {
566 Ok(v) => v,
567 Err(_) => {
568 break 'main;
569 }
570 };
571
572 if command == STOP_THREAD {
573 break 'main;
574 }
575 },
576 recv(input_channel.receiver) -> msg => {
577 let input = match msg {
578 Ok(v) => v,
579 Err(_) => {
580 break 'main;
581 }
582 };
583
584 let start_time = execution_metrics.clock.now();
585 function(&shared_resource, input);
586 let execution_time = start_time.elapsed().as_nanos() as usize;
587 execution_metrics.update(execution_time);
588 }
589 }
590 });
591 (control_tx, handle)
592}
593
#[cfg(test)]
mod tests {
    use crate::new_lambda_channel;

    use super::*;
    use std::thread::sleep;
    use std::time::Instant;

    fn simple_task(_: &Option<()>, x: u32) -> f32 {
        x as f32
    }

    fn io_task(_: &Option<()>, x: u32) -> f32 {
        sleep(Duration::from_millis(10));
        (x as f32) / 3.0
    }

    /// Polls `condition` about once per millisecond until it holds or `timeout`
    /// elapses, and returns whether it held. This replaces fixed sleeps, which
    /// make queue-length assertions flaky on loaded machines.
    fn wait_until(mut condition: impl FnMut() -> bool, timeout: Duration) -> bool {
        let deadline = Instant::now() + timeout;
        loop {
            if condition() {
                return true;
            }
            if Instant::now() >= deadline {
                return false;
            }
            sleep(Duration::from_millis(1));
        }
    }

    #[test]
    fn single_worker() {
        let tasks = 100usize;
        let capacity = 10;
        let (tx, rx, _pool) = new_lambda_channel(Some(capacity), Some(capacity), None, simple_task);

        spawn(move || {
            for i in 0..tasks {
                tx.send(i as u32).unwrap();
            }
        });

        let mut c = 0usize;
        while rx.recv().is_ok() {
            c += 1;
        }

        assert_eq!(c, tasks);
    }

    #[test]
    fn many_workers() {
        let tasks = 100usize;
        let capacity = 10;
        let (tx, rx, pool) = new_lambda_channel(Some(capacity), Some(capacity), None, io_task);
        assert_eq!(pool.set_pool_size(4), Ok(4));

        let clock = Clock::new();
        let start = clock.now();
        spawn(move || {
            for i in 0..tasks {
                tx.send(i as u32).unwrap();
            }
        });

        let mut c = 0usize;
        while rx.recv().is_ok() {
            c += 1;
        }

        assert!(start.elapsed() < Duration::from_millis(4 * (tasks as u64)));
        assert_eq!(c, tasks);
    }

    #[test]
    fn drop_input_tx() {
        let capacity = 10;
        let (tx, rx, pool) = new_lambda_channel(Some(capacity), Some(capacity), None, simple_task);
        assert_eq!(pool.set_pool_size(4), Ok(4));

        for i in 0..(2 * capacity) {
            tx.send(i as u32).unwrap();
        }

        // 20 inputs settle as: 10 buffered outputs, plus one result held by each
        // of the 4 blocked threads, plus 6 still queued on the input side.
        assert!(
            wait_until(|| tx.len() == 6, Duration::from_secs(1)),
            "expected 6 queued inputs, found {}",
            tx.len()
        );
        assert!(rx.is_full());

        assert_eq!(pool.set_pool_size(6), Ok(6));

        // Two more threads each take one more input.
        assert!(
            wait_until(|| tx.len() == 4, Duration::from_secs(1)),
            "expected 4 queued inputs, found {}",
            tx.len()
        );
        assert!(rx.is_full());

        drop(tx);

        let mut c = 0usize;
        while rx.recv().is_ok() {
            c += 1;
        }

        assert_eq!(c, 2 * capacity);
    }

    #[test]
    fn drop_output_rx() {
        let capacity = 10;
        let (tx, rx, pool) = new_lambda_channel(Some(capacity), Some(capacity), None, simple_task);
        assert_eq!(pool.set_pool_size(4), Ok(4));
        drop(rx);

        let mut c = 0;
        while tx.send(0).is_ok() {
            c += 1;
        }

        assert_eq!(c, 0);
        assert_eq!(tx.len(), c);
    }

    #[test]
    fn thrash_pool_size() {
        let tasks = 100usize;
        let capacity = 10;
        let (tx, rx, pool) = new_lambda_channel(Some(capacity), Some(capacity), None, simple_task);
        assert_eq!(pool.set_pool_size(4), Ok(4));

        spawn(move || {
            for i in 0..tasks {
                tx.send(i as u32).unwrap();
            }
        });

        let mut c = 0;
        while rx.recv().is_ok() {
            c += 1;
            if c >= 10 {
                break;
            }
        }

        assert_eq!(pool.set_pool_size(6), Ok(6));
        while rx.recv().is_ok() {
            c += 1;
            if c >= 20 {
                break;
            }
        }
        assert_eq!(pool.set_pool_size(3), Ok(3));
        while rx.recv().is_ok() {
            c += 1;
            if c >= 30 {
                break;
            }
        }
        assert_eq!(pool.set_pool_size(5), Ok(5));
        sleep(Duration::from_millis(10));
        assert_eq!(pool.set_pool_size(2), Ok(2));
        while rx.recv().is_ok() {
            c += 1;
            if c >= 50 {
                break;
            }
        }
        assert_eq!(pool.set_pool_size(1), Ok(1));

        while rx.recv().is_ok() {
            c += 1;
        }

        assert_eq!(c, tasks);
    }
}