lambda_channel/thread.rs

use std::fmt;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread::{spawn, JoinHandle};
use std::time::Duration;

use crossbeam_channel::{bounded, select, select_biased, tick};
use genzero;
use quanta::Clock;

use super::channel::*;
use super::err::ThreadPoolError;

const UPDATE_SIZE: u8 = 0;
const STOP_THREAD: u8 = 1;

/// Thread pool metrics, updated approximately every 10 seconds.
#[derive(Default, Clone, Copy)]
pub struct Metrics {
    /// Current number of running threads, which may temporarily differ from [`ThreadPool::get_pool_size`].
    pub active_threads: usize,
    pub input_channel_len: usize,
    pub input_channel_capacity: Option<usize>,
    pub output_channel_len: usize,
    pub output_channel_capacity: Option<usize>,

    /// Total executions of the provided function since the last metrics update.
    /// Inputs that have been processed but whose outputs have not yet been sent are not counted until the send completes.
    pub execution_count: usize,
    /// Average execution duration of the provided function, in nanoseconds, since the last metrics update.
    /// Does not include any time the input/output spends in the channels.
    pub average_execution_duration_ns: usize,
    // pub minimum_execution_duration_ns: usize,
    // pub maximum_execution_duration_ns: usize,
}

/// Lock-free counters shared by all pool threads and drained by the control thread on each metrics tick.
struct ExecutionMetrics {
    clock: Clock,
    execution_counter: AtomicUsize,
    total_execution_time_ns: AtomicUsize,
    // min_execution_time_ns: AtomicUsize,
    // max_execution_time_ns: AtomicUsize,
}

impl ExecutionMetrics {
    fn new() -> Self {
        ExecutionMetrics {
            clock: Clock::new(),
            execution_counter: AtomicUsize::new(0),
            total_execution_time_ns: AtomicUsize::new(0),
        }
    }

    fn update(&self, execution_time: usize) {
        self.execution_counter.fetch_add(1, Ordering::Relaxed);
        self.total_execution_time_ns
            .fetch_add(execution_time, Ordering::Relaxed);
    }

    fn get_and_reset_execution_count(&self) -> usize {
        // `fetch_and` with 0 atomically clears the counter and returns its previous value.
        self.execution_counter.fetch_and(0, Ordering::Relaxed)
    }

    fn get_and_reset_total_execution_time_ns(&self) -> usize {
        self.total_execution_time_ns.fetch_and(0, Ordering::Relaxed)
    }
}

/// The thread pool of a lambda-channel. Each pool thread runs an infinite loop that
/// waits for a message from the input channel, executes the provided function,
/// and sends the result to the output channel if the function has an output.
///
/// The pool starts with one control thread; all additional threads are normal worker threads.
/// The control thread behaves like a worker thread, but it also handles metrics collection, pool resizing, and termination propagation.
///
/// If the pool is dropped, the threads terminate automatically.
/// If the input channel or output channel of the thread pool disconnects, the threads terminate automatically.
/// If a thread is executing on an input or waiting to send an output when termination is triggered, it finishes the execution and sends the
/// output value before terminating.
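///
/// # Example
///
/// A minimal usage sketch (not compiled as a doctest). It assumes the pool is obtained from the
/// crate's `new_lambda_channel` constructor, as in the tests at the bottom of this file; the
/// capacities and the `None` shared resource are illustrative only.
///
/// ```ignore
/// fn double(_shared: &Option<()>, x: u32) -> u32 {
///     x * 2
/// }
///
/// let (tx, rx, pool) = new_lambda_channel(Some(16), Some(16), None, double);
/// assert_eq!(pool.set_pool_size(4), Ok(4)); // grow the pool to four threads
/// tx.send(21).unwrap();
/// assert_eq!(rx.recv().unwrap(), 42);
/// ```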
#[derive(Clone)]
pub struct ThreadPool {
    desired_threads: Arc<AtomicUsize>,
    control_tx: crossbeam_channel::Sender<u8>,
    metrics_rx: genzero::Receiver<Metrics>,
}

impl ThreadPool {
    pub(super) fn new_lambda_pool<
        T: Send + 'static,
        U: Send + 'static,
        V: Clone + Send + 'static,
    >(
        input_channel: Receiver<T>,
        output_channel: Sender<U>,
        shared_resource: V,
        function: fn(&V, T) -> U,
    ) -> Self {
        let desired_threads = Arc::new(AtomicUsize::new(1));
        let (control_tx, metrics_rx) = spawn_primary_lambda_thread(
            input_channel,
            output_channel,
            shared_resource,
            function,
            desired_threads.clone(),
        );

        Self {
            desired_threads,
            control_tx,
            metrics_rx,
        }
    }

    pub(super) fn new_sink_pool<T: Send + 'static, V: Clone + Send + 'static>(
        input_channel: Receiver<T>,
        shared_resource: V,
        function: fn(&V, T),
    ) -> Self {
        let desired_threads = Arc::new(AtomicUsize::new(1));
        let (control_tx, metrics_rx) = spawn_primary_sink_thread(
            input_channel,
            shared_resource,
            function,
            desired_threads.clone(),
        );

        Self {
            desired_threads,
            control_tx,
            metrics_rx,
        }
    }

    /// Returns the target number of threads in the pool. The actual number of threads may temporarily differ
    /// when resizing the pool.
    pub fn get_pool_size(&self) -> usize {
        self.desired_threads.load(Ordering::Acquire)
    }

    /// Sets the target number of threads in the pool. The actual number of threads may temporarily differ
    /// when resizing the pool.
    ///
    /// This function returns the new target pool size if successful, or a [`ThreadPoolError`] if not.
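    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest), assuming a pool obtained from the crate's
    /// `new_lambda_channel` constructor as in the tests at the bottom of this file:
    ///
    /// ```ignore
    /// assert_eq!(pool.set_pool_size(4), Ok(4));
    /// assert_eq!(pool.set_pool_size(0), Err(ThreadPoolError::ValueError));
    /// ```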
    pub fn set_pool_size(&self, n: usize) -> Result<usize, ThreadPoolError> {
        if n < 1 {
            return Err(ThreadPoolError::ValueError);
        }
        self.desired_threads.store(n, Ordering::Relaxed);

        match self.control_tx.send(UPDATE_SIZE) {
            Ok(_) => Ok(n),
            Err(_) => Err(ThreadPoolError::ThreadsLost),
        }
    }

    /// Returns the latest metric values, which are updated approximately every 10 seconds.
    /// Calling this function between updates returns the same values.
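    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest); the field names follow [`Metrics`]:
    ///
    /// ```ignore
    /// let metrics = pool.get_metrics();
    /// println!(
    ///     "{} threads, {} executions, avg {} ns",
    ///     metrics.active_threads, metrics.execution_count, metrics.average_execution_duration_ns,
    /// );
    /// ```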
    pub fn get_metrics(&self) -> Metrics {
        self.metrics_rx.recv().unwrap()
    }
}

impl fmt::Display for ThreadPool {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        "Lambda Channel Thread Pool".fmt(f)
    }
}

/// Spawns the control thread. It executes inputs like a worker thread, but it also resizes the pool,
/// respawns failed workers, and publishes [`Metrics`] on a 10-second tick.
fn spawn_primary_lambda_thread<T: Send + 'static, U: Send + 'static, V: Clone + Send + 'static>(
    input_channel: Receiver<T>,
    output_channel: Sender<U>,
    shared_resource: V,
    function: fn(&V, T) -> U,
    desired_threads: Arc<AtomicUsize>,
) -> (crossbeam_channel::Sender<u8>, genzero::Receiver<Metrics>) {
    let (control_tx, control_rx) = bounded(0);
    let (mut metrics_tx, metrics_rx) = genzero::new(Metrics::default());

    spawn(move || {
        let mut threads = Vec::new();
        let ticker = tick(Duration::from_secs(10));
        let execution_metrics = Arc::new(ExecutionMetrics::new());
        let input_channel_capacity = input_channel.capacity();
        let output_channel_capacity = output_channel.capacity();

        'main: loop {
            select_biased! {
                // `liveness_check` unblocks when the other side of the output channel disconnects.
                recv(output_channel.liveness_check) -> _ => {
                    break 'main;
                },
                recv(control_rx) -> c => {
                    let command = match c {
                        Ok(v) => v,
                        Err(_) => {
                            break 'main;
                        }
                    };

                    match command {
                        UPDATE_SIZE => {
                            // Grow or shrink the worker set towards the desired pool size.
                            let target = desired_threads.load(Ordering::Relaxed);
                            let current = threads.len() + 1;

                            if current < target {
                                for _ in 0..target - current {
                                    threads.push(spawn_worker_lambda_thread(
                                        input_channel.clone(),
                                        output_channel.clone(),
                                        shared_resource.clone(),
                                        function,
                                        execution_metrics.clone(),
                                    ));
                                }
                            } else {
                                for _ in 0..current - target {
                                    let (control_tx, _) = threads.pop().unwrap();
                                    let _ = control_tx.send(STOP_THREAD);
                                }
                            }
                        },
                        STOP_THREAD => {
                            break 'main;
                        },
                        _ => {}
                    }
                },
                recv(ticker) -> _ => {
                    // Respawn any workers that have exited, then publish fresh metrics.
                    let thread_count = threads.len();
                    threads.retain(|thread| !thread.1.is_finished());
                    let failed_threads = thread_count - threads.len();

                    for _ in 0..failed_threads {
                        threads.push(spawn_worker_lambda_thread(
                            input_channel.clone(),
                            output_channel.clone(),
                            shared_resource.clone(),
                            function,
                            execution_metrics.clone(),
                        ));
                    }

                    let execution_count = execution_metrics.get_and_reset_execution_count();
                    let average_execution_duration_ns = match execution_count {
                        0 => 0,
                        _ => execution_metrics.get_and_reset_total_execution_time_ns() / execution_count,
                    };

                    metrics_tx.send(Metrics {
                        active_threads: threads.len() + 1,
                        input_channel_len: input_channel.len(),
                        input_channel_capacity,
                        output_channel_len: output_channel.len(),
                        output_channel_capacity,

                        execution_count,
                        average_execution_duration_ns,
                        // minimum_execution_duration_ns: 0,
                        // maximum_execution_duration_ns: 0,
                    });
                },
                recv(input_channel.receiver) -> msg => {
                    let input = match msg {
                        Ok(v) => v,
                        Err(_) => {
                            break 'main;
                        }
                    };

                    let start_time = execution_metrics.clock.now();
                    let output = function(&shared_resource, input);
                    let execution_time = start_time.elapsed().as_nanos() as usize;

                    // Block until the output is sent, while staying responsive to control
                    // commands and the metrics ticker.
                    'inner: loop {
                        select! {
                            recv(control_rx) -> c => {
                                let command = match c {
                                    Ok(v) => v,
                                    Err(_) => {
                                        break 'main;
                                    }
                                };

                                match command {
                                    UPDATE_SIZE => {
                                        let target = desired_threads.load(Ordering::Relaxed);
                                        let current = threads.len() + 1;

                                        if current < target {
                                            for _ in 0..target - current {
                                                threads.push(spawn_worker_lambda_thread(
                                                    input_channel.clone(),
                                                    output_channel.clone(),
                                                    shared_resource.clone(),
                                                    function,
                                                    execution_metrics.clone(),
                                                ));
                                            }
                                        } else {
                                            for _ in 0..current - target {
                                                let (control_tx, _) = threads.pop().unwrap();
                                                let _ = control_tx.send(STOP_THREAD);
                                            }
                                        }
                                    },
                                    STOP_THREAD => {
                                        break 'main;
                                    },
                                    _ => {}
                                }
                            },
                            recv(ticker) -> _ => {
                                let thread_count = threads.len();
                                threads.retain(|thread| !thread.1.is_finished());
                                let failed_threads = thread_count - threads.len();

                                for _ in 0..failed_threads {
                                    threads.push(spawn_worker_lambda_thread(
                                        input_channel.clone(),
                                        output_channel.clone(),
                                        shared_resource.clone(),
                                        function,
                                        execution_metrics.clone(),
                                    ));
                                }

                                let execution_count = execution_metrics.get_and_reset_execution_count();
                                let average_execution_duration_ns = match execution_count {
                                    0 => 0,
                                    _ => execution_metrics.get_and_reset_total_execution_time_ns() / execution_count,
                                };
                                metrics_tx.send(Metrics {
                                    active_threads: threads.len() + 1,
                                    input_channel_len: input_channel.len(),
                                    input_channel_capacity,
                                    output_channel_len: output_channel.len(),
                                    output_channel_capacity,

                                    execution_count,
                                    average_execution_duration_ns,
                                    // minimum_execution_duration_ns: 0,
                                    // maximum_execution_duration_ns: 0,
                                });
                            },
                            send(output_channel.sender, output) -> result => {
                                match result {
                                    Ok(_) => {
                                        execution_metrics.update(execution_time);
                                        break 'inner;
                                    }
                                    Err(_) => {
                                        break 'main;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    });

    (control_tx, metrics_rx)
}

/// Spawns a worker thread that receives inputs, executes `function`, and sends the outputs until it is stopped.
fn spawn_worker_lambda_thread<T: Send + 'static, U: Send + 'static, V: Clone + Send + 'static>(
    input_channel: Receiver<T>,
    output_channel: Sender<U>,
    shared_resource: V,
    function: fn(&V, T) -> U,
    execution_metrics: Arc<ExecutionMetrics>,
) -> (crossbeam_channel::Sender<u8>, JoinHandle<()>) {
    let (control_tx, control_rx) = bounded(0);

    let handle = spawn(move || 'main: loop {
        select_biased! {
            recv(output_channel.liveness_check) -> _ => {
                break 'main;
            },
            recv(control_rx) -> c => {
                let command = match c {
                    Ok(v) => v,
                    Err(_) => {
                        break 'main;
                    }
                };

                if command == STOP_THREAD {
                    break 'main;
                }
            },
            recv(input_channel.receiver) -> msg => {
                let input = match msg {
                    Ok(v) => v,
                    Err(_) => {
                        break 'main;
                    }
                };

                let start_time = execution_metrics.clock.now();
                let output = function(&shared_resource, input);
                let execution_time = start_time.elapsed().as_nanos() as usize;

                // Block until the output is sent, while staying responsive to stop commands.
                'inner: loop {
                    select! {
                        recv(control_rx) -> c => {
                            let command = match c {
                                Ok(v) => v,
                                Err(_) => {
                                    // Stop receiving new inputs, try to deliver the in-flight output, then exit.
                                    drop(input_channel);
                                    let _ = output_channel.send(output);
                                    break 'main;
                                }
                            };

                            if command == STOP_THREAD {
                                drop(input_channel);
                                let _ = output_channel.send(output);
                                break 'main;
                            }
                        },
                        send(output_channel.sender, output) -> result => {
                            match result {
                                Ok(_) => {
                                    execution_metrics.update(execution_time);
                                    break 'inner;
                                }
                                Err(_) => {
                                    break 'main;
                                }
                            }
                        }
                    }
                }
            }
        }
    });
    (control_tx, handle)
}

fn spawn_primary_sink_thread<T: Send + 'static, V: Clone + Send + 'static>(
    input_channel: Receiver<T>,
    shared_resource: V,
    function: fn(&V, T),
    desired_threads: Arc<AtomicUsize>,
) -> (crossbeam_channel::Sender<u8>, genzero::Receiver<Metrics>) {
    let (control_tx, control_rx) = bounded(0);
    let (mut metrics_tx, metrics_rx) = genzero::new(Metrics::default());

    spawn(move || {
        let mut threads = Vec::new();
        let ticker = tick(Duration::from_secs(10));
        let execution_metrics = Arc::new(ExecutionMetrics::new());
        let input_channel_capacity = input_channel.capacity();
        let output_channel_capacity = None;

        'main: loop {
            select_biased! {
                recv(control_rx) -> c => {
                    let command = match c {
                        Ok(v) => v,
                        Err(_) => {
                            break 'main;
                        }
                    };

                    match command {
                        UPDATE_SIZE => {
                            let target = desired_threads.load(Ordering::Relaxed);
                            let current = threads.len() + 1;

                            if current < target {
                                for _ in 0..target - current {
                                    threads.push(spawn_worker_sink_thread(
                                        input_channel.clone(),
                                        shared_resource.clone(),
                                        function,
                                        execution_metrics.clone(),
                                    ));
                                }
                            } else {
                                for _ in 0..current - target {
                                    let (control_tx, _) = threads.pop().unwrap();
                                    let _ = control_tx.send(STOP_THREAD);
                                }
                            }
                        },
                        STOP_THREAD => {
                            break 'main;
                        },
                        _ => {}
                    }
                },
                recv(ticker) -> _ => {
                    let thread_count = threads.len();
                    threads.retain(|thread| !thread.1.is_finished());
                    let failed_threads = thread_count - threads.len();

                    for _ in 0..failed_threads {
                        threads.push(spawn_worker_sink_thread(
                            input_channel.clone(),
                            shared_resource.clone(),
                            function,
                            execution_metrics.clone(),
                        ));
                    }

                    let execution_count = execution_metrics.get_and_reset_execution_count();
                    let average_execution_duration_ns = match execution_count {
                        0 => 0,
                        _ => execution_metrics.get_and_reset_total_execution_time_ns() / execution_count,
                    };
                    metrics_tx.send(Metrics {
                        active_threads: threads.len() + 1,
                        input_channel_len: input_channel.len(),
                        input_channel_capacity,
                        output_channel_len: 0,
                        output_channel_capacity,

                        execution_count,
                        average_execution_duration_ns,
                        // minimum_execution_duration_ns: 0,
                        // maximum_execution_duration_ns: 0,
                    });
                },
                recv(input_channel.receiver) -> msg => {
                    let input = match msg {
                        Ok(v) => v,
                        Err(_) => {
                            break 'main;
                        }
                    };

                    let start_time = execution_metrics.clock.now();
                    function(&shared_resource, input);
                    let execution_time = start_time.elapsed().as_nanos() as usize;
                    execution_metrics.update(execution_time);
                }
            }
        }
    });

    (control_tx, metrics_rx)
}

fn spawn_worker_sink_thread<T: Send + 'static, V: Clone + Send + 'static>(
    input_channel: Receiver<T>,
    shared_resource: V,
    function: fn(&V, T),
    execution_metrics: Arc<ExecutionMetrics>,
) -> (crossbeam_channel::Sender<u8>, JoinHandle<()>) {
    let (control_tx, control_rx) = bounded(0);

    let handle = spawn(move || 'main: loop {
        select_biased! {
            recv(control_rx) -> c => {
                let command = match c {
                    Ok(v) => v,
                    Err(_) => {
                        break 'main;
                    }
                };

                if command == STOP_THREAD {
                    break 'main;
                }
            },
            recv(input_channel.receiver) -> msg => {
                let input = match msg {
                    Ok(v) => v,
                    Err(_) => {
                        break 'main;
                    }
                };

                let start_time = execution_metrics.clock.now();
                function(&shared_resource, input);
                let execution_time = start_time.elapsed().as_nanos() as usize;
                execution_metrics.update(execution_time);
            }
        }
    });
    (control_tx, handle)
}

#[cfg(test)]
mod tests {
    use crate::new_lambda_channel;

    use super::*;
    use std::thread::sleep;

    fn simple_task(_: &Option<()>, x: u32) -> f32 {
        x as f32
    }

    fn io_task(_: &Option<()>, x: u32) -> f32 {
        sleep(Duration::from_millis(10));
        (x as f32) / 3.0
    }

    #[test]
    fn single_worker() {
        let tasks = 100usize;
        let capacity = 10;
        let (tx, rx, _pool) = new_lambda_channel(Some(capacity), Some(capacity), None, simple_task);

        spawn(move || {
            for i in 0..tasks {
                tx.send(i as u32).unwrap();
            }
        });

        let mut c = 0usize;
        while rx.recv().is_ok() {
            c += 1;
        }

        assert_eq!(c, tasks);
    }

    #[test]
    fn many_workers() {
        let tasks = 100usize;
        let capacity = 10;
        let (tx, rx, pool) = new_lambda_channel(Some(capacity), Some(capacity), None, io_task);
        assert_eq!(pool.set_pool_size(4), Ok(4));

        let clock = Clock::new();
        let start = clock.now();
        spawn(move || {
            for i in 0..tasks {
                tx.send(i as u32).unwrap();
            }
        });

        let mut c = 0usize;
        while rx.recv().is_ok() {
            c += 1;
        }

        assert!(start.elapsed() < Duration::from_millis(4 * (tasks as u64)));
        assert_eq!(c, tasks);
    }

    #[test]
    fn drop_input_tx() {
        let capacity = 10;
        let (tx, rx, pool) = new_lambda_channel(Some(capacity), Some(capacity), None, simple_task);
        assert_eq!(pool.set_pool_size(4), Ok(4));

        for i in 0..(2 * capacity) {
            tx.send(i as u32).unwrap();
        }

        sleep(Duration::from_millis(1));

        // The 4 workers are each holding one value while blocked on the full output channel.
        assert_eq!(tx.len(), 6);
        assert!(rx.is_full());

        // Test growing the pool while the workers are blocked.
        assert_eq!(pool.set_pool_size(6), Ok(6));

        sleep(Duration::from_millis(1));

        // The 6 workers are each holding one value.
        assert_eq!(tx.len(), 4);
        assert!(rx.is_full());

        drop(tx);

        let mut c = 0usize;
        while rx.recv().is_ok() {
            c += 1;
        }

        assert_eq!(c, 2 * capacity);
    }

    #[test]
    fn drop_output_rx() {
        let capacity = 10;
        let (tx, rx, pool) = new_lambda_channel(Some(capacity), Some(capacity), None, simple_task);
        assert_eq!(pool.set_pool_size(4), Ok(4));
        drop(rx);

        let mut c = 0;
        while tx.send(0).is_ok() {
            c += 1;
        }

        assert_eq!(c, 0);
        assert_eq!(tx.len(), c);
    }

    #[test]
    fn thrash_pool_size() {
        let tasks = 100usize;
        let capacity = 10;
        let (tx, rx, pool) = new_lambda_channel(Some(capacity), Some(capacity), None, simple_task);
        assert_eq!(pool.set_pool_size(4), Ok(4));

        spawn(move || {
            for i in 0..tasks {
                tx.send(i as u32).unwrap();
            }
        });

        let mut c = 0;
        while rx.recv().is_ok() {
            c += 1;
            if c >= 10 {
                break;
            }
        }

        assert_eq!(pool.set_pool_size(6), Ok(6));
        while rx.recv().is_ok() {
            c += 1;
            if c >= 20 {
                break;
            }
        }
        assert_eq!(pool.set_pool_size(3), Ok(3));
        while rx.recv().is_ok() {
            c += 1;
            if c >= 30 {
                break;
            }
        }
        assert_eq!(pool.set_pool_size(5), Ok(5));
        sleep(Duration::from_millis(10));
        assert_eq!(pool.set_pool_size(2), Ok(2));
        while rx.recv().is_ok() {
            c += 1;
            if c >= 50 {
                break;
            }
        }
        assert_eq!(pool.set_pool_size(1), Ok(1));

        while rx.recv().is_ok() {
            c += 1;
        }

        assert_eq!(c, tasks);
    }
757}