// assemble_core/work_queue.rs

//! The work queue allows for submission and completion of work.
//!
//! There are two ways of interacting with the worker executor:
//! 1. By directly interfacing with a [`WorkerExecutor`](WorkerExecutor) instance.
//! 2. Using a [`WorkerQueue`](WorkerQueue), which allows for easy handling of multiple requests.

7use crate::error::PayloadError;
8
9use crate::project::error::ProjectError;
10use crossbeam::channel::{bounded, unbounded, Receiver, SendError, Sender, TryRecvError};
11use crossbeam::deque::{Injector, Steal, Stealer, Worker};
12
13use std::any::Any;
14use std::collections::HashMap;
15
16use std::marker::PhantomData;
17
18use std::sync::Arc;
19use std::thread::JoinHandle;
20
21use std::{io, panic, thread};
22use uuid::Uuid;
23
/// A Work Token is a single unit of work done within the Work Queue. Can be built using a [WorkTokenBuilder](WorkTokenBuilder)
pub struct WorkToken {
    /// Invoked by the worker thread immediately before `work` runs.
    pub on_start: Box<dyn Fn() + Send + 'static>,
    /// Invoked by the worker thread immediately after `work` returns.
    pub on_complete: Box<dyn Fn() + Send + 'static>,
    /// The unit of work itself; consumed exactly once by a worker thread.
    pub work: Box<dyn FnOnce() + Send + 'static>,
}
30
31impl WorkToken {
32    fn new(
33        on_start: Box<dyn Fn() + Send + 'static>,
34        on_complete: Box<dyn Fn() + Send + 'static>,
35        work: Box<dyn FnOnce() + Send + 'static>,
36    ) -> Self {
37        Self {
38            on_start,
39            on_complete,
40            work,
41        }
42    }
43}
44
/// Conversion hook for types that can become a [`WorkToken`].
///
/// `on_start` and `on_complete` default to no-ops, so implementors usually
/// only provide `work`, which consumes the value.
pub trait ToWorkToken: Send + 'static {
    /// Callback run before the work starts; defaults to a no-op.
    fn on_start(&self) -> Box<dyn Fn() + Send + Sync> {
        Box::new(|| {})
    }
    /// Callback run after the work completes; defaults to a no-op.
    fn on_complete(&self) -> Box<dyn Fn() + Send + Sync> {
        Box::new(|| {})
    }
    /// The actual work; consumes `self`.
    fn work(self);
}
54
55impl<T: ToWorkToken> From<T> for WorkToken {
56    fn from(tok: T) -> Self {
57        let on_start = tok.on_start();
58        let on_complete = tok.on_complete();
59        WorkTokenBuilder::new(|| tok.work())
60            .on_start(on_start)
61            .on_complete(on_complete)
62            .build()
63    }
64}
65
/// Any plain `FnOnce` closure is itself a work token with no-op callbacks.
impl<F: FnOnce() + Send + 'static> ToWorkToken for F {
    fn work(self) {
        (self)()
    }
}
71
/// Default no-op callback used by [`WorkTokenBuilder::new`].
fn empty() {}
73
/// Builds [`WorkToken`s](WorkToken) for the work queue. Both on_start and on_complete are optional.
///
/// # Example
/// ```rust
/// # use assemble_core::work_queue::{WorkToken, WorkTokenBuilder};
/// let token: WorkToken = WorkTokenBuilder::new(|| { }).build(); // valid
/// let token: WorkToken = WorkTokenBuilder::new(|| { })
///     .on_complete(|| { })
///     .on_start(|| { })
///     .build()
///     ;
/// ```
pub struct WorkTokenBuilder<W, S, C>
where
    W: FnOnce(),
{
    // Callback run before the work; defaults to `empty` via `new`.
    on_start: S,
    // Callback run after the work; defaults to `empty` via `new`.
    on_complete: C,
    // The work closure itself.
    work: W,
}
94
95impl<W, S, C> WorkTokenBuilder<W, S, C>
96where
97    W: FnOnce() + Send + 'static,
98    S: Fn() + Send + 'static,
99    C: Fn() + Send + 'static,
100{
101    pub fn build(self) -> WorkToken {
102        WorkToken::new(
103            Box::new(self.on_start),
104            Box::new(self.on_complete),
105            Box::new(self.work),
106        )
107    }
108}
109
110impl<W> WorkTokenBuilder<W, fn(), fn()>
111where
112    W: FnOnce() + Send + 'static,
113{
114    /// Create a new [`WorkTokenBuilder`](WorkTokenBuilder) to construct new [`WorkToken`s](WorkToken)
115    pub fn new(work: W) -> Self {
116        Self {
117            on_start: empty,
118            on_complete: empty,
119            work,
120        }
121    }
122}
123
124impl<W, S1, C> WorkTokenBuilder<W, S1, C>
125where
126    W: FnOnce(),
127{
128    pub fn on_start<S2: Fn() + Send + 'static>(self, on_start: S2) -> WorkTokenBuilder<W, S2, C> {
129        WorkTokenBuilder {
130            on_start,
131            on_complete: self.on_complete,
132            work: self.work,
133        }
134    }
135}
136
137impl<W, S, C1> WorkTokenBuilder<W, S, C1>
138where
139    W: FnOnce(),
140{
141    pub fn on_complete<C2: Fn() + Send + 'static>(
142        self,
143        on_complete: C2,
144    ) -> WorkTokenBuilder<W, S, C2> {
145        WorkTokenBuilder {
146            on_complete,
147            on_start: self.on_start,
148            work: self.work,
149        }
150    }
151}
152
/// Requests the executor can make of the background `Inner` thread.
enum WorkerQueueRequest {
    /// Ask for a snapshot of every worker's status.
    GetStatus,
}

/// Responses from the `Inner` thread, paired one-to-one with requests.
enum WorkerQueueResponse {
    /// Snapshot of worker id -> most recently reported status.
    Status(HashMap<Uuid, WorkerStatus>),
}

/// Control messages broadcast from `Inner` to individual workers.
#[derive(Debug, Eq, PartialEq)]
enum WorkerMessage {
    /// Tell a worker thread to exit its run loop.
    Stop,
}

/// Identifier attached to each submitted work token (randomly generated in `submit`).
type WorkTokenId = u64;
167
/// A worker queue allows for the submission of work to be done in parallel.
pub struct WorkerExecutor {
    /// Number of worker threads spawned by the inner dispatcher.
    max_jobs: usize,
    /// Global queue submitted work is pushed onto; the inner thread moves
    /// batches from here onto the deque the workers steal from.
    injector: Arc<Injector<WorkerTuple>>,
    /// Live link to the background dispatcher thread; `None` once joined.
    connection: Option<Connection>,
}

/// Channel endpoints connecting a [`WorkerExecutor`] to its background
/// `Inner` dispatcher thread.
struct Connection {
    /// Signals the inner thread to shut down.
    join_send: Sender<()>,
    /// Join handle for the inner dispatcher thread.
    inner_handle: JoinHandle<()>,

    /// Requests into the inner thread.
    request_sender: Sender<WorkerQueueRequest>,
    /// Responses back from the inner thread (paired with requests).
    response_receiver: Receiver<WorkerQueueResponse>,
}
182
183impl Connection {
184    fn handle_request(&self, request: WorkerQueueRequest) -> WorkerQueueResponse {
185        self.request_sender.send(request).unwrap();
186        self.response_receiver.recv().unwrap()
187    }
188}
189
190impl Drop for WorkerExecutor {
191    fn drop(&mut self) {
192        self.join_inner();
193    }
194}
195
196impl WorkerExecutor {
197    pub fn new(pool_size: usize) -> io::Result<Self> {
198        let mut out = Self {
199            max_jobs: pool_size,
200            injector: Arc::new(Injector::new()),
201            connection: None,
202        };
203        out.start()?;
204        Ok(out)
205    }
206
207    /// Can be used to restart a joined worker queue
208    fn start(&mut self) -> io::Result<()> {
209        self.connection = Some(Inner::start(&self.injector, self.max_jobs)?);
210        Ok(())
211    }
212
213    /// Waits for all workers to finish. Unlike the drop implementation, Calls [`finish_jobs`](WorkerQueue::finish_jobs).
214    pub fn join(mut self) -> Result<(), PayloadError<ProjectError>> {
215        self.finish_jobs().map_err(PayloadError::new)?;
216        self.join_inner().map_err(PayloadError::new)?;
217        Ok(())
218    }
219
220    fn join_inner(&mut self) -> thread::Result<()> {
221        if let Some(connection) = std::mem::replace(&mut self.connection, None) {
222            let _ = connection.join_send.send(());
223            connection.inner_handle.join()?;
224        };
225        Ok(())
226    }
227
228    /// Forces all running tokens to end.
229    /// Submit some work to the Worker Queue.
230    ///
231    pub fn submit<I: Into<WorkToken>>(&self, token: I) -> io::Result<WorkHandle> {
232        let work_token = token.into();
233
234        let (handle, channel) = work_channel(self);
235        let id = rand::random();
236        let work_tuple = WorkerTuple(id, work_token, channel);
237        self.injector.push(work_tuple);
238        Ok(handle)
239    }
240
241    pub fn any_panicked(&self) -> bool {
242        let status = self
243            .connection
244            .as_ref()
245            .map(|s| s.handle_request(WorkerQueueRequest::GetStatus));
246        match status {
247            Some(WorkerQueueResponse::Status(status)) => {
248                status.values().any(|s| s == &WorkerStatus::Panic)
249            }
250            None => false,
251        }
252    }
253
254    /// Wait for all current jobs to finish.
255    pub fn finish_jobs(&mut self) -> io::Result<()> {
256        if self.connection.is_none() {
257            panic!("Shouldn't be possible")
258        }
259
260        loop {
261            if self.injector.is_empty() {
262                break;
263            }
264        }
265
266        while let Some(connection) = &self.connection {
267            // thread::sleep(Duration::from_millis(100));
268            let status = connection.handle_request(WorkerQueueRequest::GetStatus);
269            let finished = match status {
270                WorkerQueueResponse::Status(s) => s
271                    .values()
272                    .all(|status| status == &WorkerStatus::Idle || status == &WorkerStatus::Panic),
273            };
274            if finished {
275                break;
276            }
277        }
278        Ok(())
279    }
280
281    /// Create a worker queue instance.
282    pub fn queue(&self) -> WorkerQueue {
283        WorkerQueue::new(self)
284    }
285}
286
/// Background dispatcher owned by a [`WorkerExecutor`]: moves work from the
/// global injector to the workers' deque and services status requests.
struct Inner {
    /// Number of worker threads spawned.
    max_jobs: usize,
    /// Shared global queue that `submit` pushes onto.
    injector: Arc<Injector<WorkerTuple>>,
    /// Local deque the `AssembleWorker`s steal from.
    worker: Worker<WorkerTuple>,
    /// Broadcast channel for control messages (Stop) to the workers.
    message_sender: Sender<WorkerMessage>,
    /// Status updates coming back from the workers.
    status_receiver: Receiver<WorkStatusUpdate>,
    /// Shutdown signal from the owning executor.
    stop_receiver: Receiver<()>,
    /// Join handles of all spawned worker threads.
    handles: Vec<JoinHandle<()>>,
    /// Last reported status per worker id.
    id_to_status: HashMap<Uuid, WorkerStatus>,

    /// Requests from the owning executor.
    request_recv: Receiver<WorkerQueueRequest>,
    /// Responses back to the owning executor.
    response_sndr: Sender<WorkerQueueResponse>,
}
300
/// Clonable handle to one submitted piece of work; `join` blocks until the
/// worker signals completion. Borrows the executor so a handle cannot outlive it.
#[derive(Clone)]
pub struct WorkHandle<'exec> {
    /// Completion signal: the worker sends a single `()` when the token finishes.
    recv: Receiver<()>,
    /// The executor the work was submitted to (ties the lifetime; otherwise unused here).
    owner: &'exec WorkerExecutor,
}
306
307/// Creates a work handle and it's corresponding sender
308fn work_channel(exec: &WorkerExecutor) -> (WorkHandle, Sender<()>) {
309    let (s, r) = bounded::<()>(1);
310    (
311        WorkHandle {
312            recv: r,
313            owner: exec,
314        },
315        s,
316    )
317}
318
319impl WorkHandle<'_> {
320    /// Joins the work handle
321    pub fn join(self) -> thread::Result<()> {
322        self.recv
323            .recv()
324            .map_err(|b| Box::new(b) as Box<dyn Any + Send>)
325    }
326}
327
328mod inner_impl {
329    use super::*;
330    impl Inner {
331        /// Create a Worker queue with a set maximum amount of workers.
332        fn new(
333            injector: &Arc<Injector<WorkerTuple>>,
334            pool_size: usize,
335            stop_recv: Receiver<()>,
336        ) -> io::Result<(
337            Self,
338            Sender<WorkerQueueRequest>,
339            Receiver<WorkerQueueResponse>,
340        )> {
341            let (s, r) = unbounded();
342            let (s2, r2) = unbounded();
343
344            let requests = unbounded();
345            let responses = unbounded();
346
347            let mut output = Self {
348                max_jobs: pool_size,
349                injector: injector.clone(),
350                worker: Worker::new_fifo(),
351                message_sender: s,
352                status_receiver: r2,
353                stop_receiver: stop_recv,
354                handles: vec![],
355                id_to_status: HashMap::new(),
356
357                request_recv: requests.1,
358                response_sndr: responses.0,
359            };
360            for _ in 0..pool_size {
361                let stealer = output.worker.stealer();
362                let (id, handle) = AssembleWorker::new(stealer, r.clone(), s2.clone()).start()?;
363                output.id_to_status.insert(id, WorkerStatus::Unknown);
364                output.handles.push(handle);
365            }
366
367            Ok((output, requests.0, responses.1))
368        }
369
370        pub fn start(
371            injector: &Arc<Injector<WorkerTuple>>,
372            pool_size: usize,
373        ) -> io::Result<Connection> {
374            let (stop_s, stop_r) = unbounded();
375            let (inner, sender, recv) = Self::new(injector, pool_size, stop_r)?;
376
377            let handle = thread::spawn(move || inner.run());
378
379            Ok(Connection {
380                join_send: stop_s,
381                inner_handle: handle,
382                request_sender: sender,
383                response_receiver: recv,
384            })
385        }
386
387        fn run(mut self) {
388            loop {
389                match self.stop_receiver.try_recv() {
390                    Ok(()) => break,
391                    Err(TryRecvError::Empty) => {}
392                    Err(_) => break,
393                }
394
395                let _ = self.injector.steal_batch(&self.worker);
396
397                self.update_worker_status();
398                self.handle_requests();
399            }
400            for _ in &self.handles {
401                self.message_sender.send(WorkerMessage::Stop);
402            }
403            for handle in self.handles {
404                handle.join();
405            }
406        }
407
408        fn update_worker_status(&mut self) {
409            while let Ok(status) = self.status_receiver.try_recv() {
410                self.id_to_status.insert(status.worker_id, status.status);
411            }
412        }
413
414        fn handle_requests(&mut self) {
415            while let Ok(req) = self.request_recv.try_recv() {
416                let response = self.on_request(req);
417                self.response_sndr
418                    .send(response)
419                    .expect("Inner still exists while Outer gone")
420            }
421        }
422
423        fn on_request(&mut self, request: WorkerQueueRequest) -> WorkerQueueResponse {
424            match request {
425                WorkerQueueRequest::GetStatus => {
426                    let map = self.id_to_status.clone();
427                    WorkerQueueResponse::Status(map)
428                }
429            }
430        }
431    }
432}
433
/// Lifecycle states a worker can report to the dispatcher.
#[derive(Debug, Clone, Eq, PartialEq)]
enum WorkerStatus {
    /// Worker spawned but no status update received yet.
    Unknown,
    /// Worker is currently running the token with this id.
    TaskRunning(WorkTokenId),
    /// Worker is waiting for work.
    Idle,
    /// The worker thread panicked while running a token.
    Panic,
}

/// Status message sent from a worker thread to the `Inner` dispatcher.
struct WorkStatusUpdate {
    /// Which worker this update concerns.
    worker_id: Uuid,
    /// The worker's new status.
    status: WorkerStatus,
}
446
/// A single worker thread: steals tuples from the dispatcher's deque, runs
/// them, and reports status changes back.
struct AssembleWorker {
    /// Unique id; the key in the dispatcher's status map.
    id: Uuid,
    /// Steals work from the `Inner` dispatcher's local deque.
    stealer: Stealer<WorkerTuple>,
    /// Control messages (currently only `Stop`) from the dispatcher.
    message_recv: Receiver<WorkerMessage>,
    /// Status updates back to the dispatcher.
    status_send: Sender<WorkStatusUpdate>,
}
453
454impl Drop for AssembleWorker {
455    fn drop(&mut self) {
456        if thread::panicking() {
457            self.report_status(WorkerStatus::Panic).unwrap()
458        }
459    }
460}
461
462impl AssembleWorker {
463    pub fn new(
464        stealer: Stealer<WorkerTuple>,
465        message_recv: Receiver<WorkerMessage>,
466        status_send: Sender<WorkStatusUpdate>,
467    ) -> Self {
468        let id = Uuid::new_v4();
469        Self {
470            id,
471            stealer,
472            message_recv,
473            status_send,
474        }
475    }
476
477    fn start(mut self) -> io::Result<(Uuid, JoinHandle<()>)> {
478        let id = self.id;
479        self.report_status(WorkerStatus::Idle).unwrap();
480        let handle = thread::Builder::new()
481            .name(format!("Assemble Worker (id = {})", id))
482            .spawn(move || self.run())?;
483        Ok((id, handle))
484    }
485
486    fn run(&mut self) {
487        'outer: loop {
488            match self.message_recv.try_recv() {
489                Ok(msg) => match msg {
490                    WorkerMessage::Stop => break 'outer,
491                },
492                Err(TryRecvError::Empty) => {}
493                Err(_) => break 'outer,
494            }
495
496            if let Steal::Success(tuple) = self.stealer.steal() {
497                let WorkerTuple(id, work, vc) = tuple;
498                self.report_status(WorkerStatus::TaskRunning(id)).unwrap();
499
500                (work.on_start)();
501                (work.work)();
502                (work.on_complete)();
503
504                self.report_status(WorkerStatus::Idle).unwrap();
505
506                match vc.send(()) {
507                    Ok(()) => {}
508                    Err(_e) => {
509                        // occurs only if request handle went out of scope
510                    }
511                }
512            }
513        }
514    }
515
516    fn report_status(&mut self, status: WorkerStatus) -> Result<(), SendError<WorkStatusUpdate>> {
517        self.status_send.send(WorkStatusUpdate {
518            worker_id: self.id,
519            status,
520        })
521    }
522}
523
/// (token id, the work itself, completion sender paired with the matching `WorkHandle`).
struct WorkerTuple(WorkTokenId, WorkToken, Sender<()>);
525
/// A worker queue is a way of submitting work to a [`WorkerExecutor`](WorkerExecutor).
///
/// A task submitted to the worker will get put into the worker queue immediately.
/// Dropping a WorkerQueue will force all work handles to be joined.
pub struct WorkerQueue<'exec> {
    /// Executor all submissions are forwarded to.
    executor: &'exec WorkerExecutor,
    /// Handles for every token submitted through this queue; joined on drop.
    handles: Vec<WorkHandle<'exec>>,
}
534
535impl<'exec> Drop for WorkerQueue<'exec> {
536    fn drop(&mut self) {
537        let handles = self.handles.drain(..);
538        for handle in handles {
539            let _ = handle.join();
540        }
541    }
542}
543
544impl<'exec> WorkerQueue<'exec> {
545    /// Create a new worker queue with a given executor.
546    pub fn new(executor: &'exec WorkerExecutor) -> Self {
547        Self {
548            executor,
549            handles: vec![],
550        }
551    }
552
553    /// Submit some work to do by the queue
554    pub fn submit<W: Into<WorkToken>>(&mut self, work: W) -> io::Result<WorkHandle> {
555        let handle = self.executor.submit(work)?;
556        self.handles.push(handle.clone());
557        Ok(handle)
558    }
559
560    /// Finishes the WorkerQueue by finishing all submitted tasks.
561    pub fn join(mut self) -> thread::Result<()> {
562        for handle in self.handles.drain(..) {
563            handle.join()?;
564        }
565        Ok(())
566    }
567
568    pub fn typed<W: Into<WorkToken>>(self) -> TypedWorkerQueue<'exec, W> {
569        TypedWorkerQueue {
570            _data: PhantomData,
571            queue: self,
572        }
573    }
574}
575
/// Allows only submission of a certain type into the worker queue.
pub struct TypedWorkerQueue<'exec, W: Into<WorkToken>> {
    /// Marker for the accepted work type; stores no data.
    _data: PhantomData<W>,
    /// The untyped queue that does the actual work.
    queue: WorkerQueue<'exec>,
}
581
582impl<'exec, W: Into<WorkToken>> TypedWorkerQueue<'exec, W> {
583    /// Create a new worker queue with a given executor.
584    pub fn new(executor: &'exec WorkerExecutor) -> Self {
585        Self {
586            _data: PhantomData,
587            queue: executor.queue(),
588        }
589    }
590
591    /// Submit some work to do by the queue
592    pub fn submit(&mut self, work: W) -> io::Result<WorkHandle> {
593        self.queue.submit(work)
594    }
595
596    /// Finishes the WorkerQueue by finishing all submitted tasks.
597    pub fn join(self) -> thread::Result<()> {
598        self.queue.join()
599    }
600}
601
#[cfg(test)]
mod tests {
    use crate::work_queue::WorkerExecutor;

    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Barrier};
    use std::thread;
    use std::time::Duration;
    // Worker-pool size shared by the tests below.
    const WORK_SIZE: usize = 6;
    #[test]
    #[ignore]
    fn parallelism_works() {
        let mut worker_queue = WorkerExecutor::new(WORK_SIZE).unwrap();

        let _wait_group = Arc::new(Barrier::new(WORK_SIZE));
        let add_all = Arc::new(AtomicUsize::new(0));

        let mut current_worker = 0;

        // First batch: submit twice as many jobs as workers, then flush with
        // finish_jobs.
        for _ in 0..(WORK_SIZE * 2) {
            let add_all = add_all.clone();
            let this_worker = current_worker;
            current_worker += 1;
            worker_queue
                .submit(move || {
                    // NOTE(review): `debug!` is not imported in this module —
                    // presumably a #[macro_use] log import at the crate root;
                    // confirm this compiles.
                    debug!("running worker thread {}", this_worker);
                    add_all.fetch_add(1, Ordering::SeqCst);
                })
                .unwrap();
        }

        worker_queue.finish_jobs().unwrap();
        assert_eq!(add_all.load(Ordering::SeqCst), WORK_SIZE * 2);

        // Second batch: the executor keeps accepting work after finish_jobs;
        // join must flush it.
        for _ in 0..(WORK_SIZE * 2) {
            let add_all = add_all.clone();
            let this_worker = current_worker;
            current_worker += 1;
            worker_queue
                .submit(move || {
                    debug!("running worker thread {}", this_worker);
                    add_all.fetch_add(1, Ordering::SeqCst);
                })
                .unwrap();
        }

        worker_queue.join().unwrap();

        assert_eq!(add_all.load(Ordering::SeqCst), WORK_SIZE * 4);
    }

    #[test]
    fn worker_queues_provide_protection() {
        let exec = WorkerExecutor::new(WORK_SIZE).unwrap();

        let accum = Arc::new(AtomicUsize::new(0));
        {
            let mut queue = exec.queue();
            for _i in 0..64 {
                let accum = accum.clone();
                queue
                    .submit(move || {
                        accum.fetch_add(1, Ordering::Relaxed);
                    })
                    .unwrap();
            }

            // queue should drop here
        }

        // NOTE(review): stores use Relaxed but this load uses Acquire; the
        // joins performed by the queue's Drop are what guarantee visibility.
        assert_eq!(accum.load(Ordering::Acquire), 64);
    }

    // Submits 4x pool_size sleeping jobs and records the maximum number of
    // jobs observed running concurrently; it must never exceed the pool size.
    fn test_executor_pool_size_ensured(pool_size: usize) {
        let workers_running = Arc::new(AtomicUsize::new(0));
        let max_workers_running = Arc::new(AtomicUsize::new(0));

        let executor = WorkerExecutor::new(pool_size).unwrap();
        {
            let mut queue = executor.queue();
            for _ in 0..4 * pool_size {
                let workers_running = workers_running.clone();
                let max_workers_running = max_workers_running.clone();
                let _ = queue.submit(move || {
                    workers_running.fetch_add(1, Ordering::SeqCst);
                    thread::sleep(Duration::from_millis(100));
                    // Record the running-worker high-water mark.
                    let _ = workers_running.fetch_update(
                        Ordering::SeqCst,
                        Ordering::SeqCst,
                        |running| {
                            let _ = max_workers_running.fetch_update(
                                Ordering::SeqCst,
                                Ordering::SeqCst,
                                |max| {
                                    if running > max {
                                        Some(running)
                                    } else {
                                        None
                                    }
                                },
                            );
                            // Returning None leaves workers_running unchanged.
                            None
                        },
                    );

                    workers_running.fetch_sub(1, Ordering::SeqCst);
                });
            }

            queue.join().expect("worker task failed :(");
        }

        let max_workers_running = max_workers_running.load(Ordering::Acquire);
        println!("max running workers: {}", max_workers_running);
        assert!(max_workers_running <= pool_size);
    }

    #[test]
    fn only_correct_number_of_workers_run() {
        test_executor_pool_size_ensured(1);
        test_executor_pool_size_ensured(2);
        test_executor_pool_size_ensured(4);
        test_executor_pool_size_ensured(8);
    }

    #[test]
    #[ignore]
    fn can_stop_after_panic() {
        let executor = WorkerExecutor::new(1).unwrap();
        let job = executor.submit(|| panic!("WOOH I PANICKED")).unwrap();
        job.join()
            .expect_err("Should expect an error because a panic occurred");
        println!("any panicked = {}", executor.any_panicked());
        assert!(executor.any_panicked());
    }
}
737}