//! lambda_channel/thread.rs

1use std::fmt;
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::sync::Arc;
4use std::thread::{spawn, JoinHandle};
5use std::time::Duration;
6
7use crossbeam_channel::{bounded, select, select_biased, tick};
8use genzero;
9use quanta::Clock;
10
11use super::channel::*;
12use super::err::ThreadPoolError;
13
// Control-channel protocol: the pool handle and the control thread exchange these
// single-byte commands over zero-capacity crossbeam channels.
/// Tells a control thread to re-read `desired_threads` and spawn/stop workers to match.
const UPDATE_SIZE: u8 = 0;
/// Tells the receiving thread to terminate.
const STOP_THREAD: u8 = 1;
16
/// Struct containing thread pool metrics, updated approx. every 10 seconds.
#[derive(Default, Clone, Copy)]
pub struct Metrics {
    /// Current number of running threads, which may temporarily differ from [`ThreadPool::get_pool_size`].
    pub active_threads: usize,
    /// Number of inputs currently queued in the input channel.
    pub input_channel_len: usize,
    /// Capacity of the input channel, or `None` if unbounded.
    pub input_channel_capacity: Option<usize>,
    /// Number of outputs currently queued in the output channel (0 for sink pools).
    pub output_channel_len: usize,
    /// Capacity of the output channel; `None` if unbounded or for sink pools.
    pub output_channel_capacity: Option<usize>,

    /// Total executions of the provided function since the last metrics update.
    /// Inputs that have been executed on, but not sent, are not counted until they have been sent.
    pub execution_count: usize,
    /// Average nano-second execution duration of the provided function since the last metrics update.
    /// Does not include any time the input/output spends in the channels.
    pub average_execution_duration_ns: usize,
    // pub minimum_execution_duration_ns: usize,
    // pub maximum_execution_duration_ns: usize,
}
36
/// Lock-free accumulator for per-tick execution statistics.
/// One instance is shared (via `Arc`) between the control thread and all workers.
struct ExecutionMetrics {
    // Monotonic clock the threads use to time each function execution.
    clock: Clock,
    // Completed executions since the last metrics tick.
    execution_counter: AtomicUsize,
    // Sum of execution durations (ns) since the last metrics tick.
    total_execution_time_ns: AtomicUsize,
    // min_execution_time_ns: AtomicUsize,
    // max_execution_time_ns: AtomicUsize,
}
44
45impl ExecutionMetrics {
46    fn new() -> Self {
47        ExecutionMetrics {
48            clock: Clock::new(),
49            execution_counter: AtomicUsize::new(0),
50            total_execution_time_ns: AtomicUsize::new(0),
51        }
52    }
53
54    fn update(&self, execution_time: usize) {
55        self.execution_counter.fetch_add(1, Ordering::Relaxed);
56        self.total_execution_time_ns
57            .fetch_add(execution_time, Ordering::Relaxed);
58    }
59
60    fn get_and_reset_execution_count(&self) -> usize {
61        self.execution_counter.fetch_and(0, Ordering::Relaxed)
62    }
63
64    fn get_and_reset_total_execution_time_ns(&self) -> usize {
65        self.total_execution_time_ns.fetch_and(0, Ordering::Relaxed)
66    }
67}
68
/// The thread pool of a lambda-channel that spawns threads running an infinite loop.
/// The thread loop waits for a message from the input channel, executes a provided function,
/// and sends the result to the output channel if the function has an output.
///
/// The pool starts with one control thread, all additional threads are normal worker threads.
/// The control thread is identical to a normal worker thread, but also handles metrics collection, pool resizing, and termination propagation.
///
/// If the pool is dropped, the threads will automatically terminate.
/// If the input channel or output channel to the thread pool disconnects, the threads will automatically terminate.
/// If the thread is executing on an input or waiting to send an output when termination is triggered, it will finish execution and send the
/// output value before terminating.
#[derive(Clone)]
pub struct ThreadPool {
    // Target thread count, shared with the control thread (read on UPDATE_SIZE).
    desired_threads: Arc<AtomicUsize>,
    // Sends control commands (UPDATE_SIZE / STOP_THREAD) to the control thread;
    // dropping all clones of this sender terminates the pool.
    control_tx: crossbeam_channel::Sender<u8>,
    // Receives the latest published Metrics snapshot.
    metrics_rx: genzero::Receiver<Metrics>,
}
86
87impl ThreadPool {
88    pub(super) fn new_lambda_pool<
89        T: Send + 'static,
90        U: Send + 'static,
91        V: Clone + Send + 'static,
92    >(
93        input_channel: Receiver<T>,
94        output_channel: Sender<U>,
95        shared_resource: V,
96        function: fn(&V, T) -> U,
97    ) -> Self {
98        let desired_threads = Arc::new(AtomicUsize::new(1));
99        let (control_tx, metrics_rx) = spawn_primary_lambda_thread(
100            input_channel,
101            output_channel,
102            shared_resource,
103            function,
104            desired_threads.clone(),
105        );
106
107        Self {
108            desired_threads,
109            control_tx,
110            metrics_rx,
111        }
112    }
113
114    pub(super) fn new_sink_pool<T: Send + 'static, V: Clone + Send + 'static>(
115        input_channel: Receiver<T>,
116        shared_resource: V,
117        function: fn(&V, T),
118    ) -> Self {
119        let desired_threads = Arc::new(AtomicUsize::new(1));
120        let (control_tx, metrics_rx) = spawn_primary_sink_thread(
121            input_channel,
122            shared_resource,
123            function,
124            desired_threads.clone(),
125        );
126
127        Self {
128            desired_threads,
129            control_tx,
130            metrics_rx,
131        }
132    }
133
134    /// Returns the target number of threads in the pool. The actual number of threads may temporarily differ
135    /// when resizing the pool.
136    pub fn get_pool_size(&self) -> usize {
137        self.desired_threads.load(Ordering::Acquire)
138    }
139
140    /// Sets the target number of threads in the pool. The actual number of threads may temporarily differ
141    /// when resizing the pool.
142    ///
143    /// This function will return the desired thread size if successful, or a [`ThreadPoolError`] if not.
144    pub fn set_pool_size(&self, n: usize) -> Result<usize, ThreadPoolError> {
145        if n < 1 {
146            return Err(ThreadPoolError::ValueError);
147        }
148        self.desired_threads.store(n, Ordering::Relaxed);
149
150        match self.control_tx.send(UPDATE_SIZE) {
151            Ok(_) => Ok(n),
152            Err(_) => Err(ThreadPoolError::ThreadsLost),
153        }
154    }
155
156    /// Returns the latest metric values, which is updated approx. every 10 seconds.
157    /// Calling this function between updates, will return the same values.
158    pub fn get_metrics(&self) -> Metrics {
159        self.metrics_rx.recv().unwrap()
160    }
161}
162
163impl fmt::Display for ThreadPool {
164    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165        "Lambda Channel Thread Pool".fmt(f)
166    }
167}
168
169fn spawn_primary_lambda_thread<T: Send + 'static, U: Send + 'static, V: Clone + Send + 'static>(
170    input_channel: Receiver<T>,
171    output_channel: Sender<U>,
172    shared_resource: V,
173    function: fn(&V, T) -> U,
174    desired_threads: Arc<AtomicUsize>,
175) -> (crossbeam_channel::Sender<u8>, genzero::Receiver<Metrics>) {
176    let (control_tx, control_rx) = bounded(0);
177    let (mut metrics_tx, metrics_rx) = genzero::new(Metrics::default());
178
179    spawn(move || {
180        let mut threads = Vec::new();
181        let ticker = tick(Duration::from_secs(10));
182        let execution_metrics = Arc::new(ExecutionMetrics::new());
183        let input_channel_capacity = input_channel.capacity();
184        let output_channel_capacity = output_channel.capacity();
185
186        'main: loop {
187            select_biased! {
188                recv(output_channel.liveness_check) -> _ => {
189                    break 'main;
190                },
191                recv(control_rx) -> c => {
192                    let command = match c {
193                        Ok(v) => v,
194                        Err(_) => {
195                            break 'main;
196                        }
197                    };
198
199                    match command {
200                        UPDATE_SIZE => {
201                            let target = desired_threads.load(Ordering::Relaxed);
202                            let current = threads.len() + 1;
203
204                            if current < target {
205                                for _ in 0..target-current {
206                                    threads.push(spawn_worker_lambda_thread(
207                                        input_channel.clone(),
208                                        output_channel.clone(),
209                                        shared_resource.clone(),
210                                        function,
211                                        execution_metrics.clone(),
212                                    ));
213                                }
214                            } else {
215                                for _ in 0..current-target {
216                                    let (control_tx, _) = threads.pop().unwrap();
217                                    let _ = control_tx.send(STOP_THREAD);
218                                }
219                            }
220                        },
221                        STOP_THREAD => {
222                            break 'main;
223                        },
224                        _ => {}
225                    }
226                },
227                recv(ticker) -> _ => {
228                    let thread_count = threads.len();
229                    threads.retain(|thread| {
230                        let delete = thread.1.is_finished();
231                        !delete
232                    });
233                    let failed_threads = thread_count - threads.len();
234
235                    for _ in 0..failed_threads {
236                        threads.push(spawn_worker_lambda_thread(
237                            input_channel.clone(),
238                            output_channel.clone(),
239                            shared_resource.clone(),
240                            function,
241                            execution_metrics.clone(),
242                        ));
243                    }
244
245                    let execution_count = execution_metrics.get_and_reset_execution_count();
246                    let average_execution_duration_ns = execution_metrics.get_and_reset_total_execution_time_ns() / execution_count;
247                    metrics_tx.send(Metrics{
248                        active_threads: threads.len(),
249                        input_channel_len: input_channel.len(),
250                        input_channel_capacity,
251                        output_channel_len: output_channel.len(),
252                        output_channel_capacity,
253
254                        execution_count,
255                        average_execution_duration_ns,
256                        // minimum_execution_duration_ns: 0,
257                        // maximum_execution_duration_ns: 0,
258                    });
259                },
260                recv(input_channel.receiver) -> msg => {
261                    let input = match msg {
262                        Ok(v) => v,
263                        Err(_) => {
264                            break 'main;
265                        }
266                    };
267
268                    let start_time = execution_metrics.clock.now();
269                    let output = function(&shared_resource, input);
270                    let execution_time = start_time.elapsed().as_nanos() as usize;
271
272                    'inner: loop {
273                        select! {
274                            recv(control_rx) -> c => {
275                                let command = match c {
276                                    Ok(v) => v,
277                                    Err(_) => {
278                                        break 'main;
279                                    }
280                                };
281
282                                match command {
283                                    UPDATE_SIZE => {
284                                        let target = desired_threads.load(Ordering::Relaxed);
285                                        let current = threads.len() + 1;
286
287                                        if current < target {
288                                            for _ in 0..target-current {
289                                                threads.push(spawn_worker_lambda_thread(
290                                                    input_channel.clone(),
291                                                    output_channel.clone(),
292                                                    shared_resource.clone(),
293                                                    function,
294                                                    execution_metrics.clone(),
295                                                ));
296                                            }
297                                        } else {
298                                            for _ in 0..current-target {
299                                                let (control_tx, _) = threads.pop().unwrap();
300                                                let _ = control_tx.send(STOP_THREAD);
301                                            }
302                                        }
303                                    },
304                                    STOP_THREAD => {
305                                        break 'main;
306                                    },
307                                    _ => {}
308                                }
309                            },
310                            recv(ticker) -> _ => {
311                                let thread_count = threads.len();
312                                threads.retain(|thread| {
313                                    let delete = thread.1.is_finished();
314                                    !delete
315                                });
316                                let failed_threads = thread_count - threads.len();
317
318                                for _ in 0..failed_threads {
319                                    threads.push(spawn_worker_lambda_thread(
320                                        input_channel.clone(),
321                                        output_channel.clone(),
322                                        shared_resource.clone(),
323                                        function,
324                                        execution_metrics.clone(),
325                                    ));
326                                }
327
328                                let execution_count = execution_metrics.get_and_reset_execution_count();
329                                let average_execution_duration_ns = execution_metrics.get_and_reset_total_execution_time_ns() / execution_count;
330                                metrics_tx.send(Metrics{
331                                    active_threads: threads.len(),
332                                    input_channel_len: input_channel.len(),
333                                    input_channel_capacity,
334                                    output_channel_len: output_channel.len(),
335                                    output_channel_capacity,
336
337                                    execution_count,
338                                    average_execution_duration_ns,
339                                    // minimum_execution_duration_ns: 0,
340                                    // maximum_execution_duration_ns: 0,
341                                });
342                            },
343                            send(output_channel.sender, output) -> result => {
344                                match result {
345                                    Ok(_) => {
346                                        execution_metrics.update(execution_time);
347                                        break 'inner;
348                                    }
349                                    Err(_) => {
350                                        break 'main;
351                                    }
352                                }
353                            }
354                        }
355                    }
356                }
357            }
358        }
359    });
360
361    (control_tx, metrics_rx)
362}
363
/// Spawns a worker thread for a lambda pool and returns its private control
/// sender plus its join handle.
///
/// The worker terminates on: output-receiver drop (liveness check), control
/// channel disconnect or STOP_THREAD, or input-channel disconnect. If it is
/// holding a computed output when told to stop, it attempts to deliver that
/// output before exiting.
fn spawn_worker_lambda_thread<T: Send + 'static, U: Send + 'static, V: Clone + Send + 'static>(
    input_channel: Receiver<T>,
    output_channel: Sender<U>,
    shared_resource: V,
    function: fn(&V, T) -> U,
    execution_metrics: Arc<ExecutionMetrics>,
) -> (crossbeam_channel::Sender<u8>, JoinHandle<()>) {
    // Zero-capacity channel: the control thread's STOP_THREAD rendezvous with us.
    let (control_tx, control_rx) = bounded(0);

    let handle = spawn(move || 'main: loop {
        // Biased priority: shutdown signals are checked before new input.
        select_biased! {
            recv(output_channel.liveness_check) -> _ => {
                // Output receiver dropped: terminate.
                break 'main;
            },
            recv(control_rx) -> c => {
                let command = match c {
                    Ok(v) => v,
                    Err(_) => {
                        // Control sender dropped (pool shutting down).
                        break 'main;
                    }
                };

                if command == STOP_THREAD {
                    break 'main;
                }
            },
            recv(input_channel.receiver) -> msg => {
                let input = match msg {
                    Ok(v) => v,
                    Err(_) => {
                        // Input channel disconnected.
                        break 'main;
                    }
                };

                let start_time = execution_metrics.clock.now();
                let output = function(&shared_resource, input);
                let execution_time = start_time.elapsed().as_nanos() as usize;

                // Block on delivering the output while remaining responsive to
                // control commands.
                'inner: loop {
                    select! {
                        recv(control_rx) -> c => {
                            let command = match c {
                                Ok(v) => v,
                                Err(_) => {
                                    // Release our input receiver clone before the final
                                    // (possibly blocking) send, then deliver best-effort.
                                    // NOTE(review): presumably dropping the receiver first
                                    // helps the input channel disconnect promptly during
                                    // shutdown — confirm against super::channel semantics.
                                    drop(input_channel);
                                    let _ = output_channel.send(output);
                                    break 'main;
                                }
                            };

                            if command == STOP_THREAD {
                                // Same best-effort delivery as the Err arm above.
                                drop(input_channel);
                                let _ = output_channel.send(output);
                                break 'main;
                            }
                        },
                        send(output_channel.sender, output) -> result => {
                            match result {
                                Ok(_) => {
                                    // Count the execution only once its output is delivered.
                                    execution_metrics.update(execution_time);
                                    break 'inner;
                                }
                                Err(_) => {
                                    break 'main;
                                }
                            }
                        }
                    }
                }
            }
        }
    });
    (control_tx, handle)
}
438
439fn spawn_primary_sink_thread<T: Send + 'static, V: Clone + Send + 'static>(
440    input_channel: Receiver<T>,
441    shared_resource: V,
442    function: fn(&V, T),
443    desired_threads: Arc<AtomicUsize>,
444) -> (crossbeam_channel::Sender<u8>, genzero::Receiver<Metrics>) {
445    let (control_tx, control_rx) = bounded(0);
446    let (mut metrics_tx, metrics_rx) = genzero::new(Metrics::default());
447
448    spawn(move || {
449        let mut threads = Vec::new();
450        let ticker = tick(Duration::from_secs(10));
451        let execution_metrics = Arc::new(ExecutionMetrics::new());
452        let input_channel_capacity = input_channel.capacity();
453        let output_channel_capacity = None;
454
455        'main: loop {
456            select_biased! {
457                recv(control_rx) -> c => {
458                    let command = match c {
459                        Ok(v) => v,
460                        Err(_) => {
461                            break 'main;
462                        }
463                    };
464
465                    match command {
466                        UPDATE_SIZE => {
467                            let target = desired_threads.load(Ordering::Relaxed);
468                            let current = threads.len() + 1;
469
470                            if current < target {
471                                for _ in 0..target-current {
472                                    threads.push(spawn_worker_sink_thread(
473                                        input_channel.clone(),
474                                        shared_resource.clone(),
475                                        function,
476                                        execution_metrics.clone(),
477                                    ));
478                                }
479                            } else {
480                                for _ in 0..current-target {
481                                    let (control_tx, _) = threads.pop().unwrap();
482                                    let _ = control_tx.send(STOP_THREAD);
483                                }
484                            }
485                        },
486                        STOP_THREAD => {
487                            break 'main;
488                        },
489                        _ => {}
490                    }
491                },
492                recv(ticker) -> _ => {
493                    let thread_count = threads.len();
494                    threads.retain(|thread| {
495                        let delete = thread.1.is_finished();
496                        !delete
497                    });
498                    let failed_threads = thread_count - threads.len();
499
500                    for _ in 0..failed_threads {
501                        threads.push(spawn_worker_sink_thread(
502                            input_channel.clone(),
503                            shared_resource.clone(),
504                            function,
505                            execution_metrics.clone(),
506                        ));
507                    }
508
509                    let execution_count = execution_metrics.get_and_reset_execution_count();
510                    let average_execution_duration_ns = execution_metrics.get_and_reset_total_execution_time_ns() / execution_count;
511                    metrics_tx.send(Metrics{
512                        active_threads: threads.len(),
513                        input_channel_len: input_channel.len(),
514                        input_channel_capacity,
515                        output_channel_len: 0,
516                        output_channel_capacity,
517
518                        execution_count,
519                        average_execution_duration_ns,
520                        // minimum_execution_duration_ns: 0,
521                        // maximum_execution_duration_ns: 0,
522                    });
523                },
524                recv(input_channel.receiver) -> msg => {
525                    let input = match msg {
526                        Ok(v) => v,
527                        Err(_) => {
528                            break 'main;
529                        }
530                    };
531
532                    let start_time = execution_metrics.clock.now();
533                    function(&shared_resource, input);
534                    let execution_time = start_time.elapsed().as_nanos() as usize;
535                    execution_metrics.update(execution_time);
536                }
537            }
538        }
539    });
540
541    (control_tx, metrics_rx)
542}
543
544fn spawn_worker_sink_thread<T: Send + 'static, V: Clone + Send + 'static>(
545    input_channel: Receiver<T>,
546    shared_resource: V,
547    function: fn(&V, T),
548    execution_metrics: Arc<ExecutionMetrics>,
549) -> (crossbeam_channel::Sender<u8>, JoinHandle<()>) {
550    let (control_tx, control_rx) = bounded(0);
551
552    let handle = spawn(move || 'main: loop {
553        select_biased! {
554            recv(control_rx) -> c => {
555                let command = match c {
556                    Ok(v) => v,
557                    Err(_) => {
558                        break 'main;
559                    }
560                };
561
562                if command == STOP_THREAD {
563                    break 'main;
564                }
565            },
566            recv(input_channel.receiver) -> msg => {
567                let input = match msg {
568                    Ok(v) => v,
569                    Err(_) => {
570                        break 'main;
571                    }
572                };
573
574                let start_time = execution_metrics.clock.now();
575                function(&shared_resource, input);
576                let execution_time = start_time.elapsed().as_nanos() as usize;
577                execution_metrics.update(execution_time);
578            }
579        }
580    });
581    (control_tx, handle)
582}
583
#[cfg(test)]
mod tests {
    use crate::new_lambda_channel;

    use super::*;
    use std::thread::sleep;

    // Trivial CPU-free task: converts the input without blocking.
    fn simple_task(_: &Option<()>, x: u32) -> f32 {
        x as f32
    }

    // Simulated I/O-bound task: sleeps 10ms per input, so extra workers
    // measurably increase throughput.
    fn io_task(_: &Option<()>, x: u32) -> f32 {
        sleep(Duration::from_millis(10));
        (x as f32) / 3.0
    }

    // A pool with the default single (control) thread processes every input exactly once.
    #[test]
    fn single_worker() {
        let tasks = 100usize;
        let capacity = 10;
        let (tx, rx, _pool) = new_lambda_channel(Some(capacity), Some(capacity), None, simple_task);

        // Dropping tx when this producer thread finishes disconnects the input
        // channel, which eventually terminates the pool and closes rx.
        spawn(move || {
            for i in 0..tasks {
                tx.send(i as u32).unwrap();
            }
        });

        let mut c = 0usize;
        while rx.recv().is_ok() {
            c += 1;
        }

        assert_eq!(c, tasks);
    }

    // With 4 threads on a 10ms task, 100 tasks must finish well under the
    // single-threaded time (100 * 10ms); the bound below is 400ms.
    #[test]
    fn many_workers() {
        let tasks = 100usize;
        let capacity = 10;
        let (tx, rx, pool) = new_lambda_channel(Some(capacity), Some(capacity), None, io_task);
        assert_eq!(pool.set_pool_size(4), Ok(4));

        let clock = Clock::new();
        let start = clock.now();
        spawn(move || {
            for i in 0..tasks {
                tx.send(i as u32).unwrap();
            }
        });

        let mut c = 0usize;
        while rx.recv().is_ok() {
            c += 1;
        }

        assert!(start.elapsed() < Duration::from_millis(4 * (tasks as u64)));
        assert_eq!(c, tasks);
    }

    // Fills both channels, resizes while workers are blocked on a full output
    // channel, then drops the producer and checks no input is lost.
    // NOTE(review): relies on 1ms sleeps for the pool to settle — timing-sensitive.
    #[test]
    fn drop_input_tx() {
        let capacity = 10;
        let (tx, rx, pool) = new_lambda_channel(Some(capacity), Some(capacity), None, simple_task);
        assert_eq!(pool.set_pool_size(4), Ok(4));

        for i in 0..(2 * capacity) {
            tx.send(i as u32).unwrap();
        }

        sleep(Duration::from_millis(1));

        // The 4 members are holding 4 values.
        assert_eq!(tx.len(), 6);
        assert!(rx.is_full());

        // Test recruit while blocked.
        assert_eq!(pool.set_pool_size(6), Ok(6));

        sleep(Duration::from_millis(1));

        // The 6 members are holding 6 values.
        assert_eq!(tx.len(), 4);
        assert!(rx.is_full());

        drop(tx);

        let mut c = 0usize;
        while rx.recv().is_ok() {
            c += 1;
        }

        assert_eq!(c, 2 * capacity);
    }

    // Dropping the output receiver must terminate the pool: sends start failing
    // immediately and nothing is left queued.
    #[test]
    fn drop_output_rx() {
        let capacity = 10;
        let (tx, rx, pool) = new_lambda_channel(Some(capacity), Some(capacity), None, simple_task);
        assert_eq!(pool.set_pool_size(4), Ok(4));
        drop(rx);

        let mut c = 0;
        while tx.send(0).is_ok() {
            c += 1;
        }

        assert_eq!(c, 0);
        assert_eq!(tx.len(), c);
    }

    // Repeatedly grows and shrinks the pool mid-stream; every input must still
    // come out exactly once.
    #[test]
    fn thrash_pool_size() {
        let tasks = 100usize;
        let capacity = 10;
        let (tx, rx, pool) = new_lambda_channel(Some(capacity), Some(capacity), None, simple_task);
        assert_eq!(pool.set_pool_size(4), Ok(4));

        spawn(move || {
            for i in 0..tasks {
                tx.send(i as u32).unwrap();
            }
        });

        let mut c = 0;
        while rx.recv().is_ok() {
            c += 1;
            if c >= 10 {
                break;
            }
        }

        assert_eq!(pool.set_pool_size(6), Ok(6));
        while rx.recv().is_ok() {
            c += 1;
            if c >= 20 {
                break;
            }
        }
        assert_eq!(pool.set_pool_size(3), Ok(3));
        while rx.recv().is_ok() {
            c += 1;
            if c >= 30 {
                break;
            }
        }
        assert_eq!(pool.set_pool_size(5), Ok(5));
        sleep(Duration::from_millis(10));
        assert_eq!(pool.set_pool_size(2), Ok(2));
        while rx.recv().is_ok() {
            c += 1;
            if c >= 50 {
                break;
            }
        }
        assert_eq!(pool.set_pool_size(1), Ok(1));

        while rx.recv().is_ok() {
            c += 1;
        }

        assert_eq!(c, tasks);
    }
}