use crate::error::PayloadError;
use crate::project::error::ProjectError;

use crossbeam::channel::{bounded, unbounded, Receiver, SendError, Sender, TryRecvError};
use crossbeam::deque::{Injector, Steal, Stealer, Worker};

use std::any::Any;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::Arc;
use std::thread::JoinHandle;
use std::{io, panic, thread};

use uuid::Uuid;

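/// A single unit of work that can be executed by the [`WorkerExecutor`].
///
/// The `on_start` and `on_complete` hooks run on the worker thread immediately
/// before and after `work`.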
pub struct WorkToken {
    pub on_start: Box<dyn Fn() + Send + 'static>,
    pub on_complete: Box<dyn Fn() + Send + 'static>,
    pub work: Box<dyn FnOnce() + Send + 'static>,
}

impl WorkToken {
    fn new(
        on_start: Box<dyn Fn() + Send + 'static>,
        on_complete: Box<dyn Fn() + Send + 'static>,
        work: Box<dyn FnOnce() + Send + 'static>,
    ) -> Self {
        Self {
            on_start,
            on_complete,
            work,
        }
    }
}

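/// Converts a value into a [`WorkToken`].
///
/// `on_start` and `on_complete` default to no-ops, so implementors only need to
/// provide `work`. Any `FnOnce() + Send + 'static` closure implements this trait,
/// which means plain closures can be submitted directly.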
pub trait ToWorkToken: Send + 'static {
    fn on_start(&self) -> Box<dyn Fn() + Send + Sync> {
        Box::new(|| {})
    }
    fn on_complete(&self) -> Box<dyn Fn() + Send + Sync> {
        Box::new(|| {})
    }
    fn work(self);
}

impl<T: ToWorkToken> From<T> for WorkToken {
    fn from(tok: T) -> Self {
        let on_start = tok.on_start();
        let on_complete = tok.on_complete();
        WorkTokenBuilder::new(|| tok.work())
            .on_start(on_start)
            .on_complete(on_complete)
            .build()
    }
}

impl<F: FnOnce() + Send + 'static> ToWorkToken for F {
    fn work(self) {
        (self)()
    }
}

/// Default no-op hook used by [`WorkTokenBuilder::new`].
fn empty() {}

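/// Builder for [`WorkToken`]s with optional `on_start`/`on_complete` hooks.
///
/// A minimal usage sketch (doctest ignored here; it assumes this module is
/// reachable as `crate::work_queue`):
///
/// ```ignore
/// use crate::work_queue::WorkTokenBuilder;
///
/// let token = WorkTokenBuilder::new(|| println!("doing the work"))
///     .on_start(|| println!("starting"))
///     .on_complete(|| println!("done"))
///     .build();
/// ```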
pub struct WorkTokenBuilder<W, S, C>
where
    W: FnOnce(),
{
    on_start: S,
    on_complete: C,
    work: W,
}

impl<W, S, C> WorkTokenBuilder<W, S, C>
where
    W: FnOnce() + Send + 'static,
    S: Fn() + Send + 'static,
    C: Fn() + Send + 'static,
{
    pub fn build(self) -> WorkToken {
        WorkToken::new(
            Box::new(self.on_start),
            Box::new(self.on_complete),
            Box::new(self.work),
        )
    }
}

impl<W> WorkTokenBuilder<W, fn(), fn()>
where
    W: FnOnce() + Send + 'static,
{
    pub fn new(work: W) -> Self {
        Self {
            on_start: empty,
            on_complete: empty,
            work,
        }
    }
}

impl<W, S1, C> WorkTokenBuilder<W, S1, C>
where
    W: FnOnce(),
{
    pub fn on_start<S2: Fn() + Send + 'static>(self, on_start: S2) -> WorkTokenBuilder<W, S2, C> {
        WorkTokenBuilder {
            on_start,
            on_complete: self.on_complete,
            work: self.work,
        }
    }
}

impl<W, S, C1> WorkTokenBuilder<W, S, C1>
where
    W: FnOnce(),
{
    pub fn on_complete<C2: Fn() + Send + 'static>(
        self,
        on_complete: C2,
    ) -> WorkTokenBuilder<W, S, C2> {
        WorkTokenBuilder {
            on_complete,
            on_start: self.on_start,
            work: self.work,
        }
    }
}

/// Requests sent from the [`WorkerExecutor`] to its coordinator thread.
enum WorkerQueueRequest {
    GetStatus,
}

/// Responses from the coordinator thread back to the [`WorkerExecutor`].
enum WorkerQueueResponse {
    Status(HashMap<Uuid, WorkerStatus>),
}

#[derive(Debug, Eq, PartialEq)]
enum WorkerMessage {
    Stop,
}

type WorkTokenId = u64;

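/// A fixed-size pool of worker threads that executes [`WorkToken`]s.
///
/// Work is pushed onto a shared [`Injector`] queue and stolen by the worker
/// threads. A minimal usage sketch (doctest ignored here; it assumes this module
/// is reachable as `crate::work_queue`):
///
/// ```ignore
/// use crate::work_queue::WorkerExecutor;
///
/// let executor = WorkerExecutor::new(4)?;
/// let handle = executor.submit(|| println!("hello from a worker"))?;
/// handle.join().expect("work should not panic");
/// // dropping the executor joins its worker threads
/// # Ok::<(), std::io::Error>(())
/// ```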
pub struct WorkerExecutor {
    max_jobs: usize,
    injector: Arc<Injector<WorkerTuple>>,
    connection: Option<Connection>,
}

/// Channels connecting a [`WorkerExecutor`] to its running [`Inner`] coordinator.
struct Connection {
    join_send: Sender<()>,
    inner_handle: JoinHandle<()>,

    request_sender: Sender<WorkerQueueRequest>,
    response_receiver: Receiver<WorkerQueueResponse>,
}

impl Connection {
    fn handle_request(&self, request: WorkerQueueRequest) -> WorkerQueueResponse {
        self.request_sender.send(request).unwrap();
        self.response_receiver.recv().unwrap()
    }
}

impl Drop for WorkerExecutor {
    fn drop(&mut self) {
        // Ignore the join result: a worker that panicked should not cause a
        // second panic while the executor is being dropped.
        let _ = self.join_inner();
    }
}

impl WorkerExecutor {
    pub fn new(pool_size: usize) -> io::Result<Self> {
        let mut out = Self {
            max_jobs: pool_size,
            injector: Arc::new(Injector::new()),
            connection: None,
        };
        out.start()?;
        Ok(out)
    }

    fn start(&mut self) -> io::Result<()> {
        self.connection = Some(Inner::start(&self.injector, self.max_jobs)?);
        Ok(())
    }

    /// Waits for all submitted work to finish, then shuts down the worker threads.
    pub fn join(mut self) -> Result<(), PayloadError<ProjectError>> {
        self.finish_jobs().map_err(PayloadError::new)?;
        self.join_inner().map_err(PayloadError::new)?;
        Ok(())
    }

    fn join_inner(&mut self) -> thread::Result<()> {
        if let Some(connection) = self.connection.take() {
            let _ = connection.join_send.send(());
            connection.inner_handle.join()?;
        }
        Ok(())
    }

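    /// Submits a piece of work to the pool and returns a [`WorkHandle`] that can
    /// be used to wait for it. Anything convertible into a [`WorkToken`]
    /// (including plain closures) can be submitted.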
    pub fn submit<I: Into<WorkToken>>(&self, token: I) -> io::Result<WorkHandle> {
        let work_token = token.into();

        let (handle, channel) = work_channel(self);
        let id = rand::random();
        let work_tuple = WorkerTuple(id, work_token, channel);
        self.injector.push(work_tuple);
        Ok(handle)
    }

    /// Returns `true` if any worker thread has panicked while running a task.
    pub fn any_panicked(&self) -> bool {
        let status = self
            .connection
            .as_ref()
            .map(|s| s.handle_request(WorkerQueueRequest::GetStatus));
        match status {
            Some(WorkerQueueResponse::Status(status)) => {
                status.values().any(|s| s == &WorkerStatus::Panic)
            }
            None => false,
        }
    }

    /// Blocks until the shared queue has drained and every worker is idle (or has
    /// panicked).
    pub fn finish_jobs(&mut self) -> io::Result<()> {
        if self.connection.is_none() {
            panic!("finish_jobs called on an executor that was already shut down")
        }

        // Wait for the shared queue to drain.
        while !self.injector.is_empty() {
            thread::yield_now();
        }

        // Then wait for every worker to report that it is no longer running a task.
        while let Some(connection) = &self.connection {
            let status = connection.handle_request(WorkerQueueRequest::GetStatus);
            let finished = match status {
                WorkerQueueResponse::Status(s) => s
                    .values()
                    .all(|status| status == &WorkerStatus::Idle || status == &WorkerStatus::Panic),
            };
            if finished {
                break;
            }
            thread::yield_now();
        }
        Ok(())
    }

    /// Creates a [`WorkerQueue`] scoped to this executor.
    pub fn queue(&self) -> WorkerQueue {
        WorkerQueue::new(self)
    }
}

/// Coordinator that owns the worker threads. Runs on its own thread, moving work
/// from the shared injector to the workers and tracking their status.
struct Inner {
    max_jobs: usize,
    injector: Arc<Injector<WorkerTuple>>,
    worker: Worker<WorkerTuple>,
    message_sender: Sender<WorkerMessage>,
    status_receiver: Receiver<WorkStatusUpdate>,
    stop_receiver: Receiver<()>,
    handles: Vec<JoinHandle<()>>,
    id_to_status: HashMap<Uuid, WorkerStatus>,

    request_recv: Receiver<WorkerQueueRequest>,
    response_sndr: Sender<WorkerQueueResponse>,
}

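/// A handle to a single submitted [`WorkToken`]. Joining the handle blocks until
/// the work has completed; an error indicates the worker dropped the handle's
/// channel, typically because the work panicked.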
#[derive(Clone)]
pub struct WorkHandle<'exec> {
    recv: Receiver<()>,
    owner: &'exec WorkerExecutor,
}

/// Creates a completion channel for a single piece of work: the [`WorkHandle`]
/// waits on the receiving side and the worker signals on the sending side.
fn work_channel(exec: &WorkerExecutor) -> (WorkHandle, Sender<()>) {
    let (s, r) = bounded::<()>(1);
    (
        WorkHandle {
            recv: r,
            owner: exec,
        },
        s,
    )
}

impl WorkHandle<'_> {
    /// Blocks until the associated work has finished.
    pub fn join(self) -> thread::Result<()> {
        self.recv
            .recv()
            .map_err(|b| Box::new(b) as Box<dyn Any + Send>)
    }
}

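// Coordinator implementation: shuttles queued work from the injector to the
// worker threads and answers status requests from the owning `WorkerExecutor`.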
mod inner_impl {
    use super::*;

    impl Inner {
        fn new(
            injector: &Arc<Injector<WorkerTuple>>,
            pool_size: usize,
            stop_recv: Receiver<()>,
        ) -> io::Result<(
            Self,
            Sender<WorkerQueueRequest>,
            Receiver<WorkerQueueResponse>,
        )> {
            let (s, r) = unbounded();
            let (s2, r2) = unbounded();

            let requests = unbounded();
            let responses = unbounded();

            let mut output = Self {
                max_jobs: pool_size,
                injector: injector.clone(),
                worker: Worker::new_fifo(),
                message_sender: s,
                status_receiver: r2,
                stop_receiver: stop_recv,
                handles: vec![],
                id_to_status: HashMap::new(),

                request_recv: requests.1,
                response_sndr: responses.0,
            };
            for _ in 0..pool_size {
                let stealer = output.worker.stealer();
                let (id, handle) = AssembleWorker::new(stealer, r.clone(), s2.clone()).start()?;
                output.id_to_status.insert(id, WorkerStatus::Unknown);
                output.handles.push(handle);
            }

            Ok((output, requests.0, responses.1))
        }

        pub fn start(
            injector: &Arc<Injector<WorkerTuple>>,
            pool_size: usize,
        ) -> io::Result<Connection> {
            let (stop_s, stop_r) = unbounded();
            let (inner, sender, recv) = Self::new(injector, pool_size, stop_r)?;

            let handle = thread::spawn(move || inner.run());

            Ok(Connection {
                join_send: stop_s,
                inner_handle: handle,
                request_sender: sender,
                response_receiver: recv,
            })
        }

        fn run(mut self) {
            loop {
                match self.stop_receiver.try_recv() {
                    Ok(()) => break,
                    Err(TryRecvError::Empty) => {}
                    Err(_) => break,
                }

                // Move any queued work from the shared injector into the local
                // deque that the workers steal from.
                let _ = self.injector.steal_batch(&self.worker);

                self.update_worker_status();
                self.handle_requests();
            }

            // Tell every worker to stop, then wait for them to exit. Send and
            // join failures are ignored: a worker may already have panicked and
            // hung up its channel.
            for _ in &self.handles {
                let _ = self.message_sender.send(WorkerMessage::Stop);
            }
            for handle in self.handles {
                let _ = handle.join();
            }
        }

        fn update_worker_status(&mut self) {
            while let Ok(status) = self.status_receiver.try_recv() {
                self.id_to_status.insert(status.worker_id, status.status);
            }
        }

        fn handle_requests(&mut self) {
            while let Ok(req) = self.request_recv.try_recv() {
                let response = self.on_request(req);
                self.response_sndr
                    .send(response)
                    .expect("executor hung up while a status request was pending")
            }
        }

        fn on_request(&mut self, request: WorkerQueueRequest) -> WorkerQueueResponse {
            match request {
                WorkerQueueRequest::GetStatus => {
                    let map = self.id_to_status.clone();
                    WorkerQueueResponse::Status(map)
                }
            }
        }
    }
}

/// The last known state of a single worker thread.
#[derive(Debug, Clone, Eq, PartialEq)]
enum WorkerStatus {
    Unknown,
    TaskRunning(WorkTokenId),
    Idle,
    Panic,
}

/// Status message sent from a worker thread back to the [`Inner`] coordinator.
struct WorkStatusUpdate {
    worker_id: Uuid,
    status: WorkerStatus,
}

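/// A single worker thread. Steals [`WorkerTuple`]s from the coordinator's deque,
/// runs them, and reports its status back over a channel. If the work panics, the
/// `Drop` impl reports [`WorkerStatus::Panic`] as the thread unwinds.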
struct AssembleWorker {
    id: Uuid,
    stealer: Stealer<WorkerTuple>,
    message_recv: Receiver<WorkerMessage>,
    status_send: Sender<WorkStatusUpdate>,
}

impl Drop for AssembleWorker {
    fn drop(&mut self) {
        if thread::panicking() {
            // Best-effort report: if the coordinator is already gone there is no
            // one left to notify, and panicking again here would abort.
            let _ = self.report_status(WorkerStatus::Panic);
        }
    }
}

impl AssembleWorker {
    pub fn new(
        stealer: Stealer<WorkerTuple>,
        message_recv: Receiver<WorkerMessage>,
        status_send: Sender<WorkStatusUpdate>,
    ) -> Self {
        let id = Uuid::new_v4();
        Self {
            id,
            stealer,
            message_recv,
            status_send,
        }
    }

    fn start(mut self) -> io::Result<(Uuid, JoinHandle<()>)> {
        let id = self.id;
        self.report_status(WorkerStatus::Idle).unwrap();
        let handle = thread::Builder::new()
            .name(format!("Assemble Worker (id = {})", id))
            .spawn(move || self.run())?;
        Ok((id, handle))
    }

    fn run(&mut self) {
        'outer: loop {
            match self.message_recv.try_recv() {
                Ok(msg) => match msg {
                    WorkerMessage::Stop => break 'outer,
                },
                Err(TryRecvError::Empty) => {}
                Err(_) => break 'outer,
            }

            if let Steal::Success(tuple) = self.stealer.steal() {
                let WorkerTuple(id, work, vc) = tuple;
                self.report_status(WorkerStatus::TaskRunning(id)).unwrap();

                (work.on_start)();
                (work.work)();
                (work.on_complete)();

                self.report_status(WorkerStatus::Idle).unwrap();

                // Signal completion. The handle may already have been dropped, in
                // which case there is nobody left to notify.
                let _ = vc.send(());
            }
        }
    }

    fn report_status(&mut self, status: WorkerStatus) -> Result<(), SendError<WorkStatusUpdate>> {
        self.status_send.send(WorkStatusUpdate {
            worker_id: self.id,
            status,
        })
    }
}

/// A queued piece of work: its id, the token to run, and the completion sender
/// for the matching [`WorkHandle`].
struct WorkerTuple(WorkTokenId, WorkToken, Sender<()>);

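/// A scope-guard style queue tied to a [`WorkerExecutor`]. Handles for submitted
/// work are retained, and when the queue is dropped it joins all of them, so no
/// submitted work outlives the scope of the queue.
///
/// A minimal usage sketch (doctest ignored here; it assumes this module is
/// reachable as `crate::work_queue`):
///
/// ```ignore
/// use crate::work_queue::WorkerExecutor;
///
/// let executor = WorkerExecutor::new(4)?;
/// {
///     let mut queue = executor.queue();
///     for i in 0..16 {
///         queue.submit(move || println!("task {}", i))?;
///     }
/// } // dropping the queue joins all 16 tasks
/// # Ok::<(), std::io::Error>(())
/// ```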
pub struct WorkerQueue<'exec> {
    executor: &'exec WorkerExecutor,
    handles: Vec<WorkHandle<'exec>>,
}

impl<'exec> Drop for WorkerQueue<'exec> {
    fn drop(&mut self) {
        let handles = self.handles.drain(..);
        for handle in handles {
            let _ = handle.join();
        }
    }
}

impl<'exec> WorkerQueue<'exec> {
    pub fn new(executor: &'exec WorkerExecutor) -> Self {
        Self {
            executor,
            handles: vec![],
        }
    }

    pub fn submit<W: Into<WorkToken>>(&mut self, work: W) -> io::Result<WorkHandle> {
        let handle = self.executor.submit(work)?;
        self.handles.push(handle.clone());
        Ok(handle)
    }

    /// Waits for every piece of work submitted through this queue to finish.
    pub fn join(mut self) -> thread::Result<()> {
        for handle in self.handles.drain(..) {
            handle.join()?;
        }
        Ok(())
    }

    /// Converts this queue into a [`TypedWorkerQueue`] that only accepts one kind of work.
    pub fn typed<W: Into<WorkToken>>(self) -> TypedWorkerQueue<'exec, W> {
        TypedWorkerQueue {
            _data: PhantomData,
            queue: self,
        }
    }
}

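/// A [`WorkerQueue`] restricted to a single work type `W`.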
pub struct TypedWorkerQueue<'exec, W: Into<WorkToken>> {
    _data: PhantomData<W>,
    queue: WorkerQueue<'exec>,
}

impl<'exec, W: Into<WorkToken>> TypedWorkerQueue<'exec, W> {
    pub fn new(executor: &'exec WorkerExecutor) -> Self {
        Self {
            _data: PhantomData,
            queue: executor.queue(),
        }
    }

    pub fn submit(&mut self, work: W) -> io::Result<WorkHandle> {
        self.queue.submit(work)
    }

    pub fn join(self) -> thread::Result<()> {
        self.queue.join()
    }
}

#[cfg(test)]
mod tests {
    use crate::work_queue::WorkerExecutor;

    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Barrier};
    use std::thread;
    use std::time::Duration;

    const WORK_SIZE: usize = 6;

    #[test]
    #[ignore]
    fn parallelism_works() {
        let mut worker_queue = WorkerExecutor::new(WORK_SIZE).unwrap();

        let _wait_group = Arc::new(Barrier::new(WORK_SIZE));
        let add_all = Arc::new(AtomicUsize::new(0));

        let mut current_worker = 0;

        for _ in 0..(WORK_SIZE * 2) {
            let add_all = add_all.clone();
            let this_worker = current_worker;
            current_worker += 1;
            worker_queue
                .submit(move || {
                    debug!("running worker thread {}", this_worker);
                    add_all.fetch_add(1, Ordering::SeqCst);
                })
                .unwrap();
        }

        worker_queue.finish_jobs().unwrap();
        assert_eq!(add_all.load(Ordering::SeqCst), WORK_SIZE * 2);

        for _ in 0..(WORK_SIZE * 2) {
            let add_all = add_all.clone();
            let this_worker = current_worker;
            current_worker += 1;
            worker_queue
                .submit(move || {
                    debug!("running worker thread {}", this_worker);
                    add_all.fetch_add(1, Ordering::SeqCst);
                })
                .unwrap();
        }

        worker_queue.join().unwrap();

        assert_eq!(add_all.load(Ordering::SeqCst), WORK_SIZE * 4);
    }

    #[test]
    fn worker_queues_provide_protection() {
        let exec = WorkerExecutor::new(WORK_SIZE).unwrap();

        let accum = Arc::new(AtomicUsize::new(0));
        {
            let mut queue = exec.queue();
            for _i in 0..64 {
                let accum = accum.clone();
                queue
                    .submit(move || {
                        accum.fetch_add(1, Ordering::Relaxed);
                    })
                    .unwrap();
            }
        }

        // Dropping the queue at the end of the block above joins every submitted
        // task before the assertion runs.
        assert_eq!(accum.load(Ordering::Acquire), 64);
    }

    fn test_executor_pool_size_ensured(pool_size: usize) {
        let workers_running = Arc::new(AtomicUsize::new(0));
        let max_workers_running = Arc::new(AtomicUsize::new(0));

        let executor = WorkerExecutor::new(pool_size).unwrap();
        {
            let mut queue = executor.queue();
            for _ in 0..4 * pool_size {
                let workers_running = workers_running.clone();
                let max_workers_running = max_workers_running.clone();
                let _ = queue.submit(move || {
                    workers_running.fetch_add(1, Ordering::SeqCst);
                    thread::sleep(Duration::from_millis(100));
                    // Record the high-water mark of concurrently running tasks.
                    let _ = workers_running.fetch_update(
                        Ordering::SeqCst,
                        Ordering::SeqCst,
                        |running| {
                            let _ = max_workers_running.fetch_update(
                                Ordering::SeqCst,
                                Ordering::SeqCst,
                                |max| {
                                    if running > max {
                                        Some(running)
                                    } else {
                                        None
                                    }
                                },
                            );
                            None
                        },
                    );

                    workers_running.fetch_sub(1, Ordering::SeqCst);
                });
            }

            queue.join().expect("worker task failed :(");
        }

        let max_workers_running = max_workers_running.load(Ordering::Acquire);
        println!("max running workers: {}", max_workers_running);
        assert!(max_workers_running <= pool_size);
    }

    #[test]
    fn only_correct_number_of_workers_run() {
        test_executor_pool_size_ensured(1);
        test_executor_pool_size_ensured(2);
        test_executor_pool_size_ensured(4);
        test_executor_pool_size_ensured(8);
    }

    #[test]
    #[ignore]
    fn can_stop_after_panic() {
        let executor = WorkerExecutor::new(1).unwrap();
        let job = executor.submit(|| panic!("WOOH I PANICKED")).unwrap();
        job.join()
            .expect_err("Should expect an error because a panic occurred");
        println!("any panicked = {}", executor.any_panicked());
        assert!(executor.any_panicked());
    }
}