1use std::fmt;
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::sync::Arc;
4use std::thread::{spawn, JoinHandle};
5use std::time::Duration;
6
7use crossbeam_channel::{bounded, select, select_biased, tick};
8use genzero;
9use quanta::Clock;
10
11use super::channel::*;
12use super::err::ThreadPoolError;
13
/// Control command: re-read the desired pool size and spawn/stop workers to match it.
const UPDATE_SIZE: u8 = 0;
/// Control command: make the receiving thread leave its processing loop.
const STOP_THREAD: u8 = 1;
16
/// Snapshot of a pool's state, published periodically by the pool's primary thread.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct Metrics {
    /// Number of worker threads spawned by the primary thread (the primary thread itself is
    /// not included in this count).
    pub active_threads: usize,

    /// Number of items waiting in the input channel when the snapshot was taken.
    pub input_channel_len: usize,
    /// Capacity of the input channel, if it has one.
    pub input_channel_capacity: Option<usize>,
    /// Number of items waiting in the output channel (always 0 for sink pools).
    pub output_channel_len: usize,
    /// Capacity of the output channel, if it has one (always `None` for sink pools).
    pub output_channel_capacity: Option<usize>,

    /// Number of executions recorded since the previous snapshot.
    pub execution_count: usize,
    /// Mean duration, in nanoseconds, of the executions recorded since the previous snapshot.
    pub average_execution_duration_ns: usize,
}
36
/// Execution statistics shared (through an `Arc`) by a pool's primary thread and its workers.
///
/// Threads add to the counters after each recorded execution; the primary thread reads and
/// resets them whenever it publishes a `Metrics` snapshot.
struct ExecutionMetrics {
    /// Clock used to time each call of the pool's function.
    clock: Clock,
    /// Number of executions recorded since the last reset.
    execution_counter: AtomicUsize,
    /// Sum of the recorded execution durations, in nanoseconds, since the last reset.
    total_execution_time_ns: AtomicUsize,
}
44
45impl ExecutionMetrics {
46 fn new() -> Self {
47 ExecutionMetrics {
48 clock: Clock::new(),
49 execution_counter: AtomicUsize::new(0),
50 total_execution_time_ns: AtomicUsize::new(0),
51 }
52 }
53
54 fn update(&self, execution_time: usize) {
55 self.execution_counter.fetch_add(1, Ordering::Relaxed);
56 self.total_execution_time_ns
57 .fetch_add(execution_time, Ordering::Relaxed);
58 }
59
60 fn get_and_reset_execution_count(&self) -> usize {
61 self.execution_counter.fetch_and(0, Ordering::Relaxed)
62 }
63
64 fn get_and_reset_total_execution_time_ns(&self) -> usize {
65 self.total_execution_time_ns.fetch_and(0, Ordering::Relaxed)
66 }
67}
68
/// Handle to the threads servicing a lambda or sink channel.
///
/// Clones share the same pool. The pool's primary thread exits once its control channel
/// disconnects, i.e. once every handle has been dropped.
#[derive(Clone)]
pub struct ThreadPool {
    /// Requested total number of threads (primary thread included).
    desired_threads: Arc<AtomicUsize>,
    /// Sends control commands to the primary thread.
    control_tx: crossbeam_channel::Sender<u8>,
    /// Receives the metrics snapshots published by the primary thread.
    metrics_rx: genzero::Receiver<Metrics>,
}
86
87impl ThreadPool {
88 pub(super) fn new_lambda_pool<
89 T: Send + 'static,
90 U: Send + 'static,
91 V: Clone + Send + 'static,
92 >(
93 input_channel: Receiver<T>,
94 output_channel: Sender<U>,
95 shared_resource: V,
96 function: fn(&V, T) -> U,
97 ) -> Self {
98 let desired_threads = Arc::new(AtomicUsize::new(1));
99 let (control_tx, metrics_rx) = spawn_primary_lambda_thread(
100 input_channel,
101 output_channel,
102 shared_resource,
103 function,
104 desired_threads.clone(),
105 );
106
107 Self {
108 desired_threads,
109 control_tx,
110 metrics_rx,
111 }
112 }
113
114 pub(super) fn new_sink_pool<T: Send + 'static, V: Clone + Send + 'static>(
115 input_channel: Receiver<T>,
116 shared_resource: V,
117 function: fn(&V, T),
118 ) -> Self {
119 let desired_threads = Arc::new(AtomicUsize::new(1));
120 let (control_tx, metrics_rx) = spawn_primary_sink_thread(
121 input_channel,
122 shared_resource,
123 function,
124 desired_threads.clone(),
125 );
126
127 Self {
128 desired_threads,
129 control_tx,
130 metrics_rx,
131 }
132 }
133
134 pub fn get_pool_size(&self) -> usize {
137 self.desired_threads.load(Ordering::Acquire)
138 }
139
140 pub fn set_pool_size(&self, n: usize) -> Result<usize, ThreadPoolError> {
145 if n < 1 {
146 return Err(ThreadPoolError::ValueError);
147 }
148 self.desired_threads.store(n, Ordering::Relaxed);
149
150 match self.control_tx.send(UPDATE_SIZE) {
151 Ok(_) => Ok(n),
152 Err(_) => Err(ThreadPoolError::ThreadsLost),
153 }
154 }
155
156 pub fn get_metrics(&self) -> Metrics {
159 self.metrics_rx.recv().unwrap()
160 }
161}
162
163impl fmt::Display for ThreadPool {
164 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165 "Lambda Channel Thread Pool".fmt(f)
166 }
167}
168
169fn spawn_primary_lambda_thread<T: Send + 'static, U: Send + 'static, V: Clone + Send + 'static>(
170 input_channel: Receiver<T>,
171 output_channel: Sender<U>,
172 shared_resource: V,
173 function: fn(&V, T) -> U,
174 desired_threads: Arc<AtomicUsize>,
175) -> (crossbeam_channel::Sender<u8>, genzero::Receiver<Metrics>) {
176 let (control_tx, control_rx) = bounded(0);
177 let (mut metrics_tx, metrics_rx) = genzero::new(Metrics::default());
178
179 spawn(move || {
180 let mut threads = Vec::new();
181 let ticker = tick(Duration::from_secs(10));
182 let execution_metrics = Arc::new(ExecutionMetrics::new());
183 let input_channel_capacity = input_channel.capacity();
184 let output_channel_capacity = output_channel.capacity();
185
186 'main: loop {
187 select_biased! {
188 recv(output_channel.liveness_check) -> _ => {
189 break 'main;
190 },
191 recv(control_rx) -> c => {
192 let command = match c {
193 Ok(v) => v,
194 Err(_) => {
195 break 'main;
196 }
197 };
198
199 match command {
200 UPDATE_SIZE => {
201 let target = desired_threads.load(Ordering::Relaxed);
202 let current = threads.len() + 1;
203
204 if current < target {
205 for _ in 0..target-current {
206 threads.push(spawn_worker_lambda_thread(
207 input_channel.clone(),
208 output_channel.clone(),
209 shared_resource.clone(),
210 function,
211 execution_metrics.clone(),
212 ));
213 }
214 } else {
215 for _ in 0..current-target {
216 let (control_tx, _) = threads.pop().unwrap();
217 let _ = control_tx.send(STOP_THREAD);
218 }
219 }
220 },
221 STOP_THREAD => {
222 break 'main;
223 },
224 _ => {}
225 }
226 },
227 recv(ticker) -> _ => {
228 let thread_count = threads.len();
229 threads.retain(|thread| {
230 let delete = thread.1.is_finished();
231 !delete
232 });
233 let failed_threads = thread_count - threads.len();
234
235 for _ in 0..failed_threads {
236 threads.push(spawn_worker_lambda_thread(
237 input_channel.clone(),
238 output_channel.clone(),
239 shared_resource.clone(),
240 function,
241 execution_metrics.clone(),
242 ));
243 }
244
245 let execution_count = execution_metrics.get_and_reset_execution_count();
246 let average_execution_duration_ns = execution_metrics.get_and_reset_total_execution_time_ns() / execution_count;
247 metrics_tx.send(Metrics{
248 active_threads: threads.len(),
249 input_channel_len: input_channel.len(),
250 input_channel_capacity,
251 output_channel_len: output_channel.len(),
252 output_channel_capacity,
253
254 execution_count,
255 average_execution_duration_ns,
256 });
259 },
260 recv(input_channel.receiver) -> msg => {
261 let input = match msg {
262 Ok(v) => v,
263 Err(_) => {
264 break 'main;
265 }
266 };
267
268 let start_time = execution_metrics.clock.now();
269 let output = function(&shared_resource, input);
270 let execution_time = start_time.elapsed().as_nanos() as usize;
271
272 'inner: loop {
273 select! {
274 recv(control_rx) -> c => {
275 let command = match c {
276 Ok(v) => v,
277 Err(_) => {
278 break 'main;
279 }
280 };
281
282 match command {
283 UPDATE_SIZE => {
284 let target = desired_threads.load(Ordering::Relaxed);
285 let current = threads.len() + 1;
286
287 if current < target {
288 for _ in 0..target-current {
289 threads.push(spawn_worker_lambda_thread(
290 input_channel.clone(),
291 output_channel.clone(),
292 shared_resource.clone(),
293 function,
294 execution_metrics.clone(),
295 ));
296 }
297 } else {
298 for _ in 0..current-target {
299 let (control_tx, _) = threads.pop().unwrap();
300 let _ = control_tx.send(STOP_THREAD);
301 }
302 }
303 },
304 STOP_THREAD => {
305 break 'main;
306 },
307 _ => {}
308 }
309 },
310 recv(ticker) -> _ => {
311 let thread_count = threads.len();
312 threads.retain(|thread| {
313 let delete = thread.1.is_finished();
314 !delete
315 });
316 let failed_threads = thread_count - threads.len();
317
318 for _ in 0..failed_threads {
319 threads.push(spawn_worker_lambda_thread(
320 input_channel.clone(),
321 output_channel.clone(),
322 shared_resource.clone(),
323 function,
324 execution_metrics.clone(),
325 ));
326 }
327
328 let execution_count = execution_metrics.get_and_reset_execution_count();
329 let average_execution_duration_ns = execution_metrics.get_and_reset_total_execution_time_ns() / execution_count;
330 metrics_tx.send(Metrics{
331 active_threads: threads.len(),
332 input_channel_len: input_channel.len(),
333 input_channel_capacity,
334 output_channel_len: output_channel.len(),
335 output_channel_capacity,
336
337 execution_count,
338 average_execution_duration_ns,
339 });
342 },
343 send(output_channel.sender, output) -> result => {
344 match result {
345 Ok(_) => {
346 execution_metrics.update(execution_time);
347 break 'inner;
348 }
349 Err(_) => {
350 break 'main;
351 }
352 }
353 }
354 }
355 }
356 }
357 }
358 }
359 });
360
361 (control_tx, metrics_rx)
362}
363
364fn spawn_worker_lambda_thread<T: Send + 'static, U: Send + 'static, V: Clone + Send + 'static>(
365 input_channel: Receiver<T>,
366 output_channel: Sender<U>,
367 shared_resource: V,
368 function: fn(&V, T) -> U,
369 execution_metrics: Arc<ExecutionMetrics>,
370) -> (crossbeam_channel::Sender<u8>, JoinHandle<()>) {
371 let (control_tx, control_rx) = bounded(0);
372
373 let handle = spawn(move || 'main: loop {
374 select_biased! {
375 recv(output_channel.liveness_check) -> _ => {
376 break 'main;
377 },
378 recv(control_rx) -> c => {
379 let command = match c {
380 Ok(v) => v,
381 Err(_) => {
382 break 'main;
383 }
384 };
385
386 if command == STOP_THREAD {
387 break 'main;
388 }
389 },
390 recv(input_channel.receiver) -> msg => {
391 let input = match msg {
392 Ok(v) => v,
393 Err(_) => {
394 break 'main;
395 }
396 };
397
398 let start_time = execution_metrics.clock.now();
399 let output = function(&shared_resource, input);
400 let execution_time = start_time.elapsed().as_nanos() as usize;
401
402 'inner: loop {
403 select! {
404 recv(control_rx) -> c => {
405 let command = match c {
406 Ok(v) => v,
407 Err(_) => {
408 drop(input_channel);
409 let _ = output_channel.send(output);
410 break 'main;
411 }
412 };
413
414 if command == STOP_THREAD {
415 drop(input_channel);
416 let _ = output_channel.send(output);
417 break 'main;
418 }
419 },
420 send(output_channel.sender, output) -> result => {
421 match result {
422 Ok(_) => {
423 execution_metrics.update(execution_time);
424 break 'inner;
425 }
426 Err(_) => {
427 break 'main;
428 }
429 }
430 }
431 }
432 }
433 }
434 }
435 });
436 (control_tx, handle)
437}
438
439fn spawn_primary_sink_thread<T: Send + 'static, V: Clone + Send + 'static>(
440 input_channel: Receiver<T>,
441 shared_resource: V,
442 function: fn(&V, T),
443 desired_threads: Arc<AtomicUsize>,
444) -> (crossbeam_channel::Sender<u8>, genzero::Receiver<Metrics>) {
445 let (control_tx, control_rx) = bounded(0);
446 let (mut metrics_tx, metrics_rx) = genzero::new(Metrics::default());
447
448 spawn(move || {
449 let mut threads = Vec::new();
450 let ticker = tick(Duration::from_secs(10));
451 let execution_metrics = Arc::new(ExecutionMetrics::new());
452 let input_channel_capacity = input_channel.capacity();
453 let output_channel_capacity = None;
454
455 'main: loop {
456 select_biased! {
457 recv(control_rx) -> c => {
458 let command = match c {
459 Ok(v) => v,
460 Err(_) => {
461 break 'main;
462 }
463 };
464
465 match command {
466 UPDATE_SIZE => {
467 let target = desired_threads.load(Ordering::Relaxed);
468 let current = threads.len() + 1;
469
470 if current < target {
471 for _ in 0..target-current {
472 threads.push(spawn_worker_sink_thread(
473 input_channel.clone(),
474 shared_resource.clone(),
475 function,
476 execution_metrics.clone(),
477 ));
478 }
479 } else {
480 for _ in 0..current-target {
481 let (control_tx, _) = threads.pop().unwrap();
482 let _ = control_tx.send(STOP_THREAD);
483 }
484 }
485 },
486 STOP_THREAD => {
487 break 'main;
488 },
489 _ => {}
490 }
491 },
492 recv(ticker) -> _ => {
493 let thread_count = threads.len();
494 threads.retain(|thread| {
495 let delete = thread.1.is_finished();
496 !delete
497 });
498 let failed_threads = thread_count - threads.len();
499
500 for _ in 0..failed_threads {
501 threads.push(spawn_worker_sink_thread(
502 input_channel.clone(),
503 shared_resource.clone(),
504 function,
505 execution_metrics.clone(),
506 ));
507 }
508
509 let execution_count = execution_metrics.get_and_reset_execution_count();
510 let average_execution_duration_ns = execution_metrics.get_and_reset_total_execution_time_ns() / execution_count;
511 metrics_tx.send(Metrics{
512 active_threads: threads.len(),
513 input_channel_len: input_channel.len(),
514 input_channel_capacity,
515 output_channel_len: 0,
516 output_channel_capacity,
517
518 execution_count,
519 average_execution_duration_ns,
520 });
523 },
524 recv(input_channel.receiver) -> msg => {
525 let input = match msg {
526 Ok(v) => v,
527 Err(_) => {
528 break 'main;
529 }
530 };
531
532 let start_time = execution_metrics.clock.now();
533 function(&shared_resource, input);
534 let execution_time = start_time.elapsed().as_nanos() as usize;
535 execution_metrics.update(execution_time);
536 }
537 }
538 }
539 });
540
541 (control_tx, metrics_rx)
542}
543
544fn spawn_worker_sink_thread<T: Send + 'static, V: Clone + Send + 'static>(
545 input_channel: Receiver<T>,
546 shared_resource: V,
547 function: fn(&V, T),
548 execution_metrics: Arc<ExecutionMetrics>,
549) -> (crossbeam_channel::Sender<u8>, JoinHandle<()>) {
550 let (control_tx, control_rx) = bounded(0);
551
552 let handle = spawn(move || 'main: loop {
553 select_biased! {
554 recv(control_rx) -> c => {
555 let command = match c {
556 Ok(v) => v,
557 Err(_) => {
558 break 'main;
559 }
560 };
561
562 if command == STOP_THREAD {
563 break 'main;
564 }
565 },
566 recv(input_channel.receiver) -> msg => {
567 let input = match msg {
568 Ok(v) => v,
569 Err(_) => {
570 break 'main;
571 }
572 };
573
574 let start_time = execution_metrics.clock.now();
575 function(&shared_resource, input);
576 let execution_time = start_time.elapsed().as_nanos() as usize;
577 execution_metrics.update(execution_time);
578 }
579 }
580 });
581 (control_tx, handle)
582}
583
#[cfg(test)]
mod tests {
    use crate::new_lambda_channel;

    use super::*;
    use std::thread::sleep;

    // CPU-trivial task: converts the input to a float.
    fn simple_task(_: &Option<()>, x: u32) -> f32 {
        x as f32
    }

    // Simulates a blocking (I/O-bound) task that takes about 10 ms.
    fn io_task(_: &Option<()>, x: u32) -> f32 {
        sleep(Duration::from_millis(10));
        (x as f32) / 3.0
    }

    // A pool of the default size (1 thread) delivers every item.
    #[test]
    fn single_worker() {
        let tasks = 100usize;
        let capacity = 10;
        let (tx, rx, _pool) = new_lambda_channel(Some(capacity), Some(capacity), None, simple_task);

        spawn(move || {
            for i in 0..tasks {
                tx.send(i as u32).unwrap();
            }
        });

        // The loop ends once the producer thread has dropped `tx` and the pool has shut down.
        // Presumably `recv` fails once all pool-side senders are gone; that behaviour lives
        // in the channel module.
        let mut c = 0usize;
        while rx.recv().is_ok() {
            c += 1;
        }

        assert_eq!(c, tasks);
    }

    // Four threads running the ~10 ms `io_task` must beat a 4 ms-per-task budget. A single
    // thread would need about 10 ms per task.
    #[test]
    fn many_workers() {
        let tasks = 100usize;
        let capacity = 10;
        let (tx, rx, pool) = new_lambda_channel(Some(capacity), Some(capacity), None, io_task);
        assert_eq!(pool.set_pool_size(4), Ok(4));

        let clock = Clock::new();
        let start = clock.now();
        spawn(move || {
            for i in 0..tasks {
                tx.send(i as u32).unwrap();
            }
        });

        let mut c = 0usize;
        while rx.recv().is_ok() {
            c += 1;
        }

        assert!(start.elapsed() < Duration::from_millis(4 * (tasks as u64)));
        assert_eq!(c, tasks);
    }

    // Checks how many inputs get consumed while the output channel is full, then checks that
    // every item is still delivered after the input sender is dropped.
    #[test]
    fn drop_input_tx() {
        let capacity = 10;
        let (tx, rx, pool) = new_lambda_channel(Some(capacity), Some(capacity), None, simple_task);
        assert_eq!(pool.set_pool_size(4), Ok(4));

        for i in 0..(2 * capacity) {
            tx.send(i as u32).unwrap();
        }

        // Timing-sensitive: assumes 1 ms is enough for the pool to reach the steady state below.
        sleep(Duration::from_millis(1));

        // With the output channel full (10 items), each of the 4 threads holds one computed
        // item while blocked sending it: 10 + 4 = 14 of the 20 inputs consumed, 6 remain.
        assert_eq!(tx.len(), 6);
        assert!(rx.is_full());

        // Two more threads take two more items: 4 remain.
        assert_eq!(pool.set_pool_size(6), Ok(6));

        sleep(Duration::from_millis(1));

        assert_eq!(tx.len(), 4);
        assert!(rx.is_full());

        drop(tx);

        // Items held in flight and items still queued must all come out.
        let mut c = 0usize;
        while rx.recv().is_ok() {
            c += 1;
        }

        assert_eq!(c, 2 * capacity);
    }

    // After the output receiver is dropped, the test expects the very first input send to
    // fail. This relies on channel behaviour defined outside this file.
    #[test]
    fn drop_output_rx() {
        let capacity = 10;
        let (tx, rx, pool) = new_lambda_channel(Some(capacity), Some(capacity), None, simple_task);
        assert_eq!(pool.set_pool_size(4), Ok(4));
        drop(rx);

        let mut c = 0;
        while tx.send(0).is_ok() {
            c += 1;
        }

        assert_eq!(c, 0);
        assert_eq!(tx.len(), c);
    }

    // Grows and shrinks the pool repeatedly while draining. Every item must still arrive
    // exactly once, including items held by workers that are stopped mid-delivery.
    #[test]
    fn thrash_pool_size() {
        let tasks = 100usize;
        let capacity = 10;
        let (tx, rx, pool) = new_lambda_channel(Some(capacity), Some(capacity), None, simple_task);
        assert_eq!(pool.set_pool_size(4), Ok(4));

        spawn(move || {
            for i in 0..tasks {
                tx.send(i as u32).unwrap();
            }
        });

        let mut c = 0;
        while rx.recv().is_ok() {
            c += 1;
            if c >= 10 {
                break;
            }
        }

        assert_eq!(pool.set_pool_size(6), Ok(6));
        while rx.recv().is_ok() {
            c += 1;
            if c >= 20 {
                break;
            }
        }
        assert_eq!(pool.set_pool_size(3), Ok(3));
        while rx.recv().is_ok() {
            c += 1;
            if c >= 30 {
                break;
            }
        }
        assert_eq!(pool.set_pool_size(5), Ok(5));
        sleep(Duration::from_millis(10));
        assert_eq!(pool.set_pool_size(2), Ok(2));
        while rx.recv().is_ok() {
            c += 1;
            if c >= 50 {
                break;
            }
        }
        assert_eq!(pool.set_pool_size(1), Ok(1));

        while rx.recv().is_ok() {
            c += 1;
        }

        assert_eq!(c, tasks);
    }
}