1use std::{
26 mem,
27 ops::Deref,
28 sync::{
29 atomic::{AtomicBool, AtomicUsize, Ordering},
30 Arc,
31 },
32 thread::JoinHandle,
33 time::Duration,
34};
35
36use crossbeam_channel::{bounded, unbounded, RecvTimeoutError, SendError, Sender};
37use thiserror::Error;
38
39pub use global::*;
40
41pub(crate) mod global;
42
/// Command received by the worker thread while it is still blocked,
/// i.e. before it has started reading any queued commands.
enum Blocked {
    /// Exit the worker thread right away without reading any queue.
    Shutdown,
    /// Unblock and start processing queued commands.
    Continue,
}
50
/// A command the worker thread reads from its currently active queue.
enum Command {
    /// A boxed closure to run on the worker thread.
    Task(Box<dyn FnOnce() + Send>),

    /// Switch the worker to its other receiver.
    /// The worker sends `()` on the enclosed sender once the switch is done.
    Swap(Sender<()>),

    /// Leave the worker loop, ending the worker thread.
    Shutdown,
}
62
/// The error returned from operations on the dispatcher.
#[derive(Error, Debug, PartialEq, Eq)]
pub enum DispatchError {
    /// The worker thread panicked while running a task
    /// (reported when joining the worker thread fails).
    #[error("The worker panicked while running a task")]
    WorkerPanic,

    /// The pre-init queue already holds the maximum number of commands;
    /// the command was not queued.
    #[error("Maximum queue size reached")]
    QueueFull,

    /// Pre-init queueing was already turned off by an earlier flush or kill.
    #[error("Pre-init buffer was already flushed")]
    AlreadyFlushed,

    /// A command could not be sent over one of the dispatcher's channels.
    #[error("Failed to send command to worker thread")]
    SendError,

    /// Receiving from a channel failed; wraps the underlying channel error.
    #[error("Failed to receive from channel")]
    RecvError(#[from] crossbeam_channel::RecvError),
}
86
87impl<T> From<SendError<T>> for DispatchError {
88 fn from(_: SendError<T>) -> Self {
89 DispatchError::SendError
90 }
91}
92
/// A clonable handle used to talk to the dispatcher's worker thread.
///
/// Cloning the guard clones the `Arc`, so all clones share the same state.
#[derive(Clone)]
struct DispatchGuard {
    /// State shared between all clones of this guard.
    inner: Arc<DispatchGuardInner>,
}
98
99impl Deref for DispatchGuard {
100 type Target = DispatchGuardInner;
101
102 fn deref(&self) -> &Self::Target {
103 &self.inner
104 }
105}
106
/// The state shared by all clones of a [`DispatchGuard`].
struct DispatchGuardInner {
    /// While `true`, commands are routed to the pre-init queue;
    /// once `false`, they go to the main queue.
    queue_preinit: AtomicBool,

    /// Number of commands rejected because the pre-init queue was full.
    overflow_count: AtomicUsize,

    /// Maximum number of commands accepted into the pre-init queue.
    max_queue_size: usize,

    /// Used to unblock the worker thread, telling it to continue or to shut down.
    block_sender: Sender<Blocked>,

    /// Sender for the pre-init queue.
    preinit_sender: Sender<Command>,

    /// Sender for the main queue, used once pre-init queueing is turned off.
    sender: Sender<Command>,
}
126
127impl DispatchGuard {
128 pub fn launch(&self, task: impl FnOnce() + Send + 'static) -> Result<(), DispatchError> {
129 let task = Command::Task(Box::new(task));
130 self.send(task)
131 }
132
133 pub fn shutdown(&mut self) -> Result<(), DispatchError> {
134 self.flush_init().ok();
137 self.send(Command::Shutdown)
138 }
139
140 fn send(&self, task: Command) -> Result<(), DispatchError> {
141 if self.queue_preinit.load(Ordering::SeqCst) {
142 if self.preinit_sender.len() < self.max_queue_size {
143 self.preinit_sender.send(task)?;
144 Ok(())
145 } else {
146 self.overflow_count.fetch_add(1, Ordering::SeqCst);
147 Err(DispatchError::QueueFull)
151 }
152 } else {
153 self.sender.send(task)?;
154 Ok(())
155 }
156 }
157
158 fn block_on_queue(&self) {
159 let (tx, rx) = crossbeam_channel::bounded(0);
160
161 let task = Command::Task(Box::new(move || {
167 tx.send(())
168 .expect("(worker) Can't send message on single-use channel");
169 }));
170 self.sender
171 .send(task)
172 .expect("Failed to launch the blocking task");
173
174 rx.recv()
175 .expect("Failed to receive message on single-use channel");
176 }
177
178 fn block_on_queue_timeout(&self, timeout: Duration) -> Result<(), RecvTimeoutError> {
180 let (tx, rx) = crossbeam_channel::bounded(0);
181
182 let task = Command::Task(Box::new(move || {
188 _ = tx.send(());
192 }));
193 self.sender
194 .send(task)
195 .expect("Failed to launch the blocking task");
196
197 rx.recv_timeout(timeout)
198 }
199
200 fn kill(&mut self) -> Result<(), DispatchError> {
201 let old_val = self.queue_preinit.swap(false, Ordering::SeqCst);
203 if !old_val {
204 return Err(DispatchError::AlreadyFlushed);
205 }
206
207 self.block_sender.send(Blocked::Shutdown)?;
209 Ok(())
210 }
211
212 fn flush_init(&mut self) -> Result<usize, DispatchError> {
219 let old_val = self.queue_preinit.swap(false, Ordering::SeqCst);
221 if !old_val {
222 return Err(DispatchError::AlreadyFlushed);
223 }
224
225 self.block_sender.send(Blocked::Continue)?;
227
228 let (swap_sender, swap_receiver) = bounded(0);
230
231 self.preinit_sender
233 .send(Command::Swap(swap_sender))
234 .map_err(|_| DispatchError::SendError)?;
235
236 swap_receiver.recv()?;
239
240 global::QUEUE_TASKS.store(false, Ordering::SeqCst);
242
243 let overflow_count = self.overflow_count.load(Ordering::SeqCst);
244 if overflow_count > 0 {
245 Ok(overflow_count)
246 } else {
247 Ok(0)
248 }
249 }
250}
251
/// A dispatcher that runs queued tasks on a dedicated worker thread.
pub struct Dispatcher {
    /// Guard used to communicate with the worker thread.
    guard: DispatchGuard,

    /// Handle of the worker thread; taken out when the thread is joined.
    worker: Option<JoinHandle<()>>,
}
266
267impl Dispatcher {
268 pub fn new(max_queue_size: usize) -> Self {
274 let (block_sender, block_receiver) = bounded(1);
275 let (preinit_sender, preinit_receiver) = unbounded();
276 let (sender, mut unbounded_receiver) = unbounded();
277
278 let queue_preinit = AtomicBool::new(true);
279 let overflow_count = AtomicUsize::new(0);
280
281 let worker = crate::thread::spawn("glean.dispatcher", move || {
282 match block_receiver.recv() {
283 Err(_) => {
284 log::error!("The task producer was disconnected. Worker thread will exit.");
287 return;
288 }
289 Ok(Blocked::Shutdown) => {
290 return;
292 }
293 Ok(Blocked::Continue) => {
294 }
296 }
297
298 let mut receiver = preinit_receiver;
299 loop {
300 use Command::*;
301
302 match receiver.recv() {
303 Ok(Shutdown) => {
304 break;
305 }
306
307 Ok(Task(f)) => {
308 (f)();
309 }
310
311 Ok(Swap(swap_done)) => {
312 mem::swap(&mut receiver, &mut unbounded_receiver);
319
320 swap_done
324 .send(())
325 .expect("The caller of `flush_init` has gone missing");
326 }
327
328 Err(_) => {
330 log::error!("The task producer was disconnected. Worker thread will exit.");
331 return;
332 }
333 }
334 }
335 })
336 .expect("Failed to spawn Glean's dispatcher thread");
337
338 let inner = Arc::new(DispatchGuardInner {
339 queue_preinit,
340 overflow_count,
341 max_queue_size,
342 block_sender,
343 preinit_sender,
344 sender,
345 });
346 let guard = DispatchGuard { inner };
347
348 Dispatcher {
349 guard,
350 worker: Some(worker),
351 }
352 }
353
354 fn guard(&self) -> DispatchGuard {
355 self.guard.clone()
356 }
357
358 #[cfg(test)]
362 fn join(mut self) -> Result<(), DispatchError> {
363 if let Some(worker) = self.worker.take() {
364 worker.join().map_err(|_| DispatchError::WorkerPanic)?;
365 }
366 Ok(())
367 }
368}
369
#[cfg(test)]
mod test {
    use super::*;
    use std::sync::atomic::AtomicU8;
    use std::sync::Mutex;
    use std::thread;

    /// Shared, ordered record of the values pushed by test tasks.
    type Record = Arc<Mutex<Vec<i32>>>;

    /// Creates an empty record.
    fn new_record() -> Record {
        Arc::new(Mutex::new(Vec::new()))
    }

    /// Builds a task that appends `value` to `record` when it runs.
    fn push_task(record: &Record, value: i32) -> impl FnOnce() + Send + 'static {
        let record = Arc::clone(record);
        move || {
            record.lock().unwrap().push(value);
        }
    }

    /// Initializes `env_logger` in test mode; the result of the attempt is ignored.
    fn enable_test_logging() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    #[test]
    fn tasks_run_off_the_main_thread() {
        enable_test_logging();

        let caller_thread = thread::current().id();
        let task_ran = Arc::new(AtomicBool::new(false));

        let dispatcher = Dispatcher::new(100);
        let mut guard = dispatcher.guard();

        // Leave pre-init mode so the task below is sent on the main queue.
        guard
            .flush_init()
            .expect("Failed to get out of preinit queue mode");

        let flag = Arc::clone(&task_ran);
        guard
            .launch(move || {
                assert!(thread::current().id() != caller_thread);
                // The flag must not have been set before this task runs.
                assert!(!flag.load(Ordering::SeqCst));
                flag.store(true, Ordering::SeqCst);
            })
            .expect("Failed to dispatch the test task");

        guard.block_on_queue();
        assert!(task_ran.load(Ordering::SeqCst));
        assert_eq!(caller_thread, thread::current().id());
    }

    #[test]
    fn launch_correctly_adds_tasks_to_preinit_queue() {
        enable_test_logging();

        let caller_thread = thread::current().id();
        let run_count = Arc::new(AtomicU8::new(0));

        let dispatcher = Dispatcher::new(100);
        let mut guard = dispatcher.guard();

        for _ in 0..3 {
            let counter = Arc::clone(&run_count);
            guard
                .launch(move || {
                    assert!(thread::current().id() != caller_thread);
                    counter.fetch_add(1, Ordering::SeqCst);
                })
                .expect("Failed to dispatch the test task");
        }

        // Nothing runs while the worker is still blocked.
        assert_eq!(0, run_count.load(Ordering::SeqCst));

        // Flushing waits for the queued tasks, so all three have run afterwards.
        guard
            .flush_init()
            .expect("Failed to get out of preinit queue mode");
        assert_eq!(3, run_count.load(Ordering::SeqCst));
    }

    #[test]
    fn preinit_tasks_are_processed_after_flush() {
        enable_test_logging();

        let dispatcher = Dispatcher::new(10);
        let mut guard = dispatcher.guard();
        let record = new_record();

        for value in 1..=5 {
            guard.launch(push_task(&record, value)).unwrap();
        }

        // Pushed directly from this thread, before any queued task can run.
        record.lock().unwrap().push(0);
        guard.flush_init().unwrap();

        for value in 6..=10 {
            guard.launch(push_task(&record, value)).unwrap();
        }

        guard.block_on_queue();

        assert_eq!(
            &*record.lock().unwrap(),
            &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        );
    }

    #[test]
    fn tasks_after_shutdown_are_not_processed() {
        enable_test_logging();

        let dispatcher = Dispatcher::new(10);
        let mut guard = dispatcher.guard();
        let record = new_record();

        guard.flush_init().unwrap();
        guard.shutdown().unwrap();

        // Whether this launch succeeds or fails is irrelevant;
        // the task must never run.
        let _ = guard.launch(push_task(&record, 0));

        dispatcher.join().unwrap();

        assert!(record.lock().unwrap().is_empty());
    }

    #[test]
    fn preinit_buffer_fills_up() {
        enable_test_logging();

        let dispatcher = Dispatcher::new(5);
        let mut guard = dispatcher.guard();
        let record = new_record();

        for value in 1..=5 {
            guard.launch(push_task(&record, value)).unwrap();
        }

        // The sixth task exceeds the pre-init capacity and is rejected.
        let rejected = guard.launch(push_task(&record, 10));
        assert_eq!(Err(DispatchError::QueueFull), rejected);

        guard.flush_init().unwrap();

        guard.launch(push_task(&record, 20)).unwrap();

        guard.block_on_queue();

        assert_eq!(&*record.lock().unwrap(), &[1, 2, 3, 4, 5, 20]);
    }

    #[test]
    fn normal_queue_is_unbounded() {
        enable_test_logging();

        let dispatcher = Dispatcher::new(5);
        let mut guard = dispatcher.guard();
        let record = new_record();

        for value in 1..=5 {
            guard.launch(push_task(&record, value)).unwrap();
        }

        guard.flush_init().unwrap();

        // More tasks than the pre-init limit of 5; each one is slow to run,
        // and every launch is still expected to be accepted.
        for value in 6..=20 {
            let slow_push = push_task(&record, value);
            guard
                .launch(move || {
                    thread::sleep(Duration::from_millis(50));
                    slow_push();
                })
                .unwrap();
        }

        guard.shutdown().unwrap();
        dispatcher.join().unwrap();

        let expected: Vec<i32> = (1..=20).collect();
        assert_eq!(&*record.lock().unwrap(), &expected);
    }
}