1use super::{
2 comm::{
3 self, CommandReceiver, CommandSender, GlobalSignal, ParkingReceiver, ParkingSender,
4 SubComm, WORKER_ID_GEN,
5 },
6 par::FnContext,
7 task::{AsyncTask, ParTask, SysTask, Task},
8};
9use crate::{
10 ds::{
11 Array, ListPos, ManagedConstPtr, ManagedMutPtr, NonNullExt, SetValueList, UnsafeFuture,
12 WakeSend,
13 },
14 ecs::{
15 cache::{CacheItem, RefreshCacheStorage},
16 cmd::CommandObject,
17 sys::system::{RawSystemCycleIter, SystemCycleIter, SystemData, SystemGroup, SystemId},
18 wait::WaitQueues,
19 worker::{Message, PanicMessage, Work, WorkerId},
20 },
21 util::prelude::*,
22 MAX_GROUP,
23};
24use crossbeam_deque as cb;
25use std::{
26 any::Any,
27 cell::{Cell, UnsafeCell},
28 collections::HashMap,
29 hash::BuildHasher,
30 marker::PhantomPinned,
31 mem,
32 ops::{Deref, IndexMut},
33 pin::Pin,
34 ptr::NonNull,
35 rc::Rc,
36 sync::{
37 atomic::{AtomicU32, Ordering},
38 Arc,
39 },
40 thread::{self, Thread, ThreadId},
41 time::Duration,
42};
43
/// Scheduler that drives system execution across worker groups.
///
/// It pulls systems from each system group's active cycle, hands ready systems
/// either to a [`WorkGroup`] (non-dedicated systems) or to the dedicated queue
/// that this scheduler's own thread drains in `work_one`, and collects the
/// result messages sent back by whoever ran them.
#[derive(Debug)]
pub(crate) struct Scheduler<W: Work + 'static, S> {
    /// Worker groups, indexed in parallel with the system groups.
    wgroups: Array<WorkGroup<W>, MAX_GROUP>,

    /// Wait queues consulted before a system may be dispatched
    /// (see `Helper::update_task`).
    waits: WaitQueues<S>,

    /// States and counters of systems handled during the current `execute_all`.
    record: ScheduleRecord<S>,

    /// Per-group pending lists of non-dedicated systems not yet dispatchable.
    nor_pendings: Array<Pending<S>, MAX_GROUP>,

    /// Per-group pending lists of dedicated systems not yet dispatchable.
    dedi_pendings: Array<Pending<S>, MAX_GROUP>,

    /// Sender side of the dedicated task queue.
    tx_dedi: ParkingSender<Task>,

    /// Receiver side of the dedicated task queue, drained in `work_one`.
    rx_dedi: ParkingReceiver<Task>,

    /// Sender of result messages; passed to every work group on creation.
    tx_msg: ParkingSender<Message>,

    /// Receiver of result messages (`Fin`, `Panic`, `Aborted`, `Handle`).
    rx_msg: ParkingReceiver<Message>,

    /// Command sender; completed dedicated futures are pushed through it.
    tx_cmd: CommandSender,

    /// Command receiver; only queried for emptiness here (`has_command`).
    rx_cmd: Rc<CommandReceiver>,

    /// Worker id used when this thread executes dedicated tasks.
    wid: WorkerId,

    /// Waker installed on async tasks polled by this thread.
    waker: MainWaker,

    /// Count of outstanding dedicated futures. It is decremented here when one
    /// becomes ready; presumably incremented by whoever spawns them (not visible
    /// in this file — confirm).
    fut_cnt: Arc<AtomicU32>,
}
83
84impl<W, S> Scheduler<W, S>
85where
86 W: Work + 'static,
87{
88 pub(crate) fn num_groups(&self) -> usize {
89 self.wgroups.len()
90 }
91
92 pub(crate) fn num_workers(&self) -> usize {
93 self.wgroups.iter().map(WorkGroup::len).sum()
94 }
95
96 pub(crate) fn is_work_groups_exhausted(&self) -> bool {
98 self.wgroups.iter().all(|wg| wg.is_exhausted())
99 }
100
101 pub(crate) fn has_command(&self) -> bool {
103 !self.rx_cmd.is_empty()
104 }
105
106 pub(crate) fn has_dedicated_future(&self) -> bool {
107 self.fut_cnt.load(Ordering::Relaxed) > 0
108 }
109
110 pub(crate) fn wait_exhausted(&self) {
111 for wg in self.wgroups.iter() {
112 wg.wait_exhausted();
113 }
114 }
115
116 pub(crate) fn wait_receiving_dedicated_task(&self) {
117 self.rx_dedi.wait_timeout(Duration::MAX);
118 }
119
120 pub(crate) fn take_workers(mut self) -> Vec<W> {
124 self.wgroups.iter_mut().fold(Vec::new(), |mut acc, wgroup| {
125 acc.append(&mut wgroup.take_workers());
126 acc
127 })
128 }
129
130 pub(crate) fn get_wait_queues_mut(&mut self) -> &mut WaitQueues<S> {
131 &mut self.waits
132 }
133
134 pub(crate) fn get_tx_dedi_queue(&self) -> &ParkingSender<Task> {
135 &self.tx_dedi
136 }
137
138 pub(crate) fn get_send_message_queue(&self) -> &ParkingSender<Message> {
139 &self.tx_msg
140 }
141
142 pub(crate) fn get_future_count(&self) -> &Arc<AtomicU32> {
143 &self.fut_cnt
144 }
145
146 fn work_one(&mut self) {
147 if let Ok(task) = self.rx_dedi.try_recv() {
148 match task {
151 Task::System(task) => self.work_for_system_task(task),
152 Task::Parallel(task) => self.work_for_parallel_task(task),
153 Task::Async(task) => self.work_for_async_task(task),
154 }
155 }
156 }
157
158 fn work_for_system_task(&self, task: SysTask) {
159 let sid = task.sid();
160
161 let resp = match task.execute(self.wid) {
162 Ok(_) => Message::Fin(self.wid, sid),
163 Err(payload) => Message::Panic(PanicMessage {
164 wid: self.wid,
165 sid,
166 payload,
167 unrecoverable: false,
168 }),
169 };
170
171 self.tx_msg.send(resp).unwrap();
174 }
175
176 fn work_for_parallel_task(&self, task: ParTask) {
177 task.execute(self.wid, FnContext::NOT_MIGRATED);
178 }
179
180 fn work_for_async_task(&self, task: AsyncTask) {
181 unsafe {
183 if !task.will_wake(&self.waker) {
184 task.set_waker(self.waker.clone());
185 }
186 }
187
188 let on_ready = |ready| {
190 self.fut_cnt.fetch_sub(1, Ordering::Relaxed);
192
193 let cmd = CommandObject::Future(ready);
195 self.tx_cmd.send_or_cancel(cmd);
196 };
197 task.execute(self.wid, on_ready);
198 }
199}
200
201impl<W, S> Scheduler<W, S>
202where
203 W: Work + 'static,
204 S: BuildHasher + Default + 'static,
205{
206 pub(crate) fn new(
207 mut workers: Vec<W>,
208 groups: &[usize],
209 tx_cmd: CommandSender,
210 rx_cmd: Rc<CommandReceiver>,
211 ) -> Self {
212 assert_eq!(workers.len(), groups.iter().sum::<usize>());
213
214 let num_groups = groups.len();
215 let pending_limit: usize = workers.len();
216
217 let (nor_pendings, dedi_pendings) =
218 (0..num_groups).fold((Array::new(), Array::new()), |(mut nor, mut dedi), _| {
219 nor.push(Pending::new(pending_limit));
220 dedi.push(Pending::new(pending_limit));
221 (nor, dedi)
222 });
223
224 let (tx_msg, rx_msg) = comm::parking_channel(thread::current());
225
226 let wgroups = (0..num_groups).fold(Array::new(), |mut acc, i| {
227 let mut left = workers.split_off(groups[i]); mem::swap(&mut workers, &mut left); let mut group = WorkGroup::new(i as u16, left, &tx_msg, &tx_cmd);
233 group.initialize(&rx_msg);
234 acc.push(group);
235
236 acc
237 });
238
239 let (tx_dedi, rx_dedi) = comm::parking_channel(thread::current());
240 let waker = MainWaker::new(tx_dedi.clone());
241
242 let id = WORKER_ID_GEN.get();
243 WORKER_ID_GEN.set(id + 1);
244 let wid = WorkerId::new(
245 id,
246 WorkerId::dummy().group_index(),
247 WorkerId::dummy().worker_index(),
248 );
249
250 Self {
255 wgroups,
256 waits: WaitQueues::new(),
257 record: ScheduleRecord::new(),
258 nor_pendings,
259 dedi_pendings,
260 tx_dedi,
261 rx_dedi,
262 tx_msg,
263 rx_msg,
264 tx_cmd,
265 rx_cmd,
266 wid,
267 waker,
268 fut_cnt: Arc::new(AtomicU32::new(0)),
269 }
270 }
271
272 pub(crate) fn execute_all<T>(&mut self, sgroups: &mut T, cache: &mut RefreshCacheStorage<S>)
273 where
274 T: IndexMut<usize, Output = SystemGroup<S>>,
275 {
276 let num_groups = self.wgroups.len();
284 let mut lives = [false; MAX_GROUP];
285 let mut units: Array<ScheduleUnit<'_, W, S>, MAX_GROUP> = Array::new();
286 for i in 0..num_groups {
287 lives[i] = sgroups[i].len_active() > 0;
288 let cycle = sgroups[i].get_active_mut().iter_begin().into_raw();
289 let unit = ScheduleUnit::new(i, self, cycle, cache);
290 units.push(unit);
291 }
292 let tickables = lives;
293 let mut panicked = Vec::new();
294
295 for (i, _) in lives.iter().enumerate().filter(|(_, live)| **live) {
297 self.wgroups[i].open();
298 }
299
300 loop {
306 for (i, live) in lives.iter_mut().enumerate().filter(|(_, live)| **live) {
309 let pull_end = units[i].pull_many() == PullRes::Empty;
310 let no_pending = !self.has_pending(i);
311 if pull_end && no_pending {
312 self.wgroups[i].close();
313 *live = false;
314 }
315 }
316
317 self.work_one();
319 if !self.rx_dedi.is_empty() {
320 continue;
321 }
322
323 if lives.iter().any(|&live| live) {
325 self.wait(&mut units, cache, &mut panicked);
326 } else {
327 self.consume_messages(cache, &mut panicked);
328 break;
329 }
330 }
331
332 drop(units);
337 while let Some((sid, payload)) = panicked.pop() {
338 sgroups[sid.group_index() as usize]
339 .poison(&sid, payload)
340 .unwrap();
341 }
342 for i in 0..num_groups {
343 if tickables[i] {
344 sgroups[i].tick();
345 }
346 }
347 self.record.clear();
348
349 #[cfg(debug_assertions)]
350 self.validate_clean();
351 }
352
353 fn wait<'s, T>(
354 &mut self,
355 units: &mut T,
356 cache: &mut RefreshCacheStorage<'_, S>,
357 panicked: &mut Vec<(SystemId, Box<dyn Any + Send>)>,
358 ) where
359 T: IndexMut<usize, Output = ScheduleUnit<'s, W, S>>,
360 {
361 if let Ok(msg) = self.rx_msg.recv_timeout(Duration::MAX) {
363 self.handle_message(msg, cache, panicked);
364 }
365 while let Ok(msg) = self.rx_msg.try_recv() {
366 self.handle_message(msg, cache, panicked);
367 }
368
369 self.pending_to_ready(units, cache);
370 }
371
372 fn consume_messages(
373 &mut self,
374 cache: &mut RefreshCacheStorage<'_, S>,
375 panicked: &mut Vec<(SystemId, Box<dyn Any + Send>)>,
376 ) {
377 while self.record.num_injected() > self.record.num_completed() {
378 if let Ok(msg) = self.rx_msg.recv_timeout(Duration::MAX) {
379 self.handle_message(msg, cache, panicked);
380 }
381 }
382 }
383
384 fn handle_message(
385 &mut self,
386 msg: Message,
387 cache: &mut RefreshCacheStorage<'_, S>,
388 panicked: &mut Vec<(SystemId, Box<dyn Any + Send>)>,
389 ) {
390 match msg {
391 Message::Handle(..) => unreachable!(),
392 Message::Fin(_wid, sid) => {
393 self.record.insert(sid, RunResult::Finished);
394 let cache = cache.get(&sid).unwrap();
395 self.waits.dequeue(&cache.get_wait_indices());
396 }
397 Message::Aborted(_wid, sid) => {
398 self.record.insert(sid, RunResult::Aborted);
399 let cache = cache.get(&sid).unwrap();
400 self.waits.dequeue(&cache.get_wait_indices());
401 }
402 Message::Panic(msg) => {
403 self.record.insert(msg.sid, RunResult::Panicked);
404 self.panic_helper(cache, panicked, msg);
405 }
406 };
407 }
408
409 fn pending_to_ready<'s, T>(&mut self, units: &mut T, cache: &mut RefreshCacheStorage<'_, S>)
410 where
411 T: IndexMut<usize, Output = ScheduleUnit<'s, W, S>>,
412 {
413 #[allow(clippy::needless_range_loop)] let num_groups = self.wgroups.len();
415 for i in 0..num_groups {
416 unsafe {
418 let target = NonNull::new_unchecked(&mut self.wgroups[i] as *mut _);
420 Helper::pending_to_ready::<W, S>(
421 Or::A(target),
422 &mut self.nor_pendings[i],
423 &mut self.waits,
424 &mut units[i].cycle(),
425 cache,
426 );
427
428 let target = &mut self.tx_dedi as *mut _;
430 let target = NonNull::new_unchecked(target);
431 Helper::pending_to_ready::<W, S>(
432 Or::B(target),
433 &mut self.dedi_pendings[i],
434 &mut self.waits,
435 &mut units[i].cycle(),
436 cache,
437 );
438 }
439 }
440 }
441
442 fn panic_helper(
443 &mut self,
444 cache: &mut RefreshCacheStorage<S>,
445 panicked: &mut Vec<(SystemId, Box<dyn Any + Send>)>,
446 msg: PanicMessage,
447 ) {
448 if msg.unrecoverable {
449 panic!("unrecoverable");
450 }
451
452 let cache = {
453 #[cfg(not(target_arch = "wasm32"))]
454 {
455 cache.get(&msg.sid).unwrap()
456 }
457
458 #[cfg(target_arch = "wasm32")]
462 {
463 let mut cache = cache.get_mut(&msg.sid).unwrap();
464 let buf = cache.get_request_buffer_mut();
465 buf.clear();
466 cache
467 }
468 };
469
470 self.waits.dequeue(&cache.get_wait_indices());
471 panicked.push((msg.sid, msg.payload));
472
473 #[cfg(target_arch = "wasm32")]
474 {
475 debug_assert_eq!(msg.sid.group_index(), msg.wid.group_index());
479 let gi = msg.wid.group_index() as usize;
480 let wi = msg.wid.worker_index() as usize;
481
482 self.wgroups[gi].insert_search(wi);
483 }
484 }
485
486 fn has_pending(&self, index: usize) -> bool {
487 !self.nor_pendings[index].is_empty() || !self.dedi_pendings[index].is_empty()
488 }
489
490 #[cfg(debug_assertions)]
491 fn validate_clean(&self) {
492 assert!(self.waits.is_all_queue_empty());
494
495 assert!(self.record.is_empty());
497
498 let num_groups = self.wgroups.len();
500 for i in 0..num_groups {
501 assert!(!self.has_pending(i));
502 }
503
504 for task in self.rx_dedi.buffer().iter() {
509 if matches!(task, Task::System(_) | Task::Parallel(_)) {
510 panic!("expected empty dedicated queue, but found: {task:?}");
511 }
512 }
513
514 match self.rx_msg.try_recv() {
516 Err(std::sync::mpsc::TryRecvError::Empty) => {}
517 Ok(msg) => panic!("unexpected remaining msg in channel: {msg:?}"),
518 Err(err) => panic!("unexpected error from channel: {err:?}"),
519 }
520 }
521}
522
/// Per-group cursor that pulls systems out of a system group's active cycle.
///
/// It holds raw pointers into the owning [`Scheduler`] and into the cache storage,
/// presumably so that one unit per group can coexist with `&mut Scheduler` inside
/// `execute_all` — the pointers carry no lifetime, so callers must drop the unit
/// first.
#[derive(Debug)]
struct ScheduleUnit<'s, W: Work + 'static, S> {
    /// Raw iterator over the group's active systems; must be at its end on drop.
    cycle: RawSystemCycleIter<S>,

    /// Work group that receives non-dedicated systems of this group.
    wgroup: NonNull<WorkGroup<W>>,
    /// Scheduler's wait queues.
    waits: NonNull<WaitQueues<S>>,
    /// Scheduler's run record.
    record: NonNull<ScheduleRecord<S>>,
    /// Scheduler's per-group pending lists for non-dedicated systems.
    nor_pendings: NonNull<[Pending<S>]>,
    /// Scheduler's per-group pending lists for dedicated systems.
    dedi_pendings: NonNull<[Pending<S>]>,
    /// Scheduler's dedicated task sender.
    tx_dedi: NonNull<ParkingSender<Task>>,

    /// Cache storage passed to `execute_all`.
    cache: NonNull<RefreshCacheStorage<'s, S>>,
}
539
impl<'s, W, S> ScheduleUnit<'s, W, S>
where
    W: Work + 'static,
    S: BuildHasher + Default + 'static,
{
    /// Builds the unit for group `index` by capturing raw pointers to `sched`'s
    /// shared state and to `cache`.
    ///
    /// NOTE(review): `index` must be < `sched.wgroups.len()` (otherwise
    /// `unwrap_unchecked` is UB), and `sched`/`cache` must outlive the unit
    /// without being moved.
    fn new(
        index: usize,
        sched: &mut Scheduler<W, S>,
        cycle: RawSystemCycleIter<S>,
        cache: &mut RefreshCacheStorage<'s, S>,
    ) -> Self {
        unsafe {
            let ptr = sched.wgroups.get_mut(index).unwrap_unchecked() as *mut _;
            let wgroup = NonNull::new_unchecked(ptr);
            // The pointers below come from references, so they are never null.
            let ptr = &mut sched.waits as *mut _;
            let waits = NonNull::new_unchecked(ptr);
            let ptr = &mut sched.record as *mut _;
            let record = NonNull::new_unchecked(ptr);
            let ptr = sched.nor_pendings.as_mut_slice() as *mut _;
            let nor_pendings = NonNull::new_unchecked(ptr);
            let ptr = sched.dedi_pendings.as_mut_slice() as *mut _;
            let dedi_pendings = NonNull::new_unchecked(ptr);
            let ptr = &mut sched.tx_dedi as *mut _;
            let tx_dedi = NonNull::new_unchecked(ptr);
            let ptr = cache as *mut _;
            let cache = NonNull::new_unchecked(ptr);

            Self {
                cycle,
                wgroup,
                waits,
                record,
                nor_pendings,
                dedi_pendings,
                tx_dedi,
                cache,
            }
        }
    }

    /// Pulls systems until the cycle is exhausted (`Empty`) or a pending list is
    /// full (`PendingFull`).
    fn pull_many(&mut self) -> PullRes {
        loop {
            match self.pull_one() {
                PullRes::Empty => return PullRes::Empty,
                PullRes::Success => {}
                PullRes::PendingFull => return PullRes::PendingFull,
            }
        }
    }

    /// Pulls the system under the cursor.
    ///
    /// If the system can enter its wait queues now, it is dispatched right away.
    /// Otherwise its cycle position is parked in the matching pending list. Both
    /// cases record it as `Injected` and advance the cursor. If the pending list
    /// is full, the cursor stays put and `PendingFull` is returned.
    fn pull_one(&mut self) -> PullRes {
        let mut cycle = self.cycle();
        if cycle.position().is_end() {
            return PullRes::Empty;
        }

        let sdata = cycle.get().unwrap();
        let sid = sdata.id();
        let gi = sid.group_index() as usize;
        // Dedicated systems go to the dedicated queue, others to the work group.
        let (pending, target) = if sdata.flags().is_dedi() {
            (&mut self.dedi_pendings()[gi], Or::B(self.tx_dedi))
        } else {
            (&mut self.nor_pendings()[gi], Or::A(self.wgroup))
        };

        if let Some(cache) = Helper::update_task(self.waits(), sdata, self.cache()) {
            // Runnable now: dispatch and advance.
            self.record().insert(sid, RunResult::Injected);
            Helper::move_ready_system(target, sdata, cache);
            unsafe { self.cycle.next() };
            PullRes::Success
        }
        else if pending.push(cycle.position()) {
            // Not runnable yet: remember the position and advance.
            self.record().insert(sid, RunResult::Injected);
            unsafe { self.cycle.next() };
            PullRes::Success
        }
        else {
            // Pending list is full: leave the cursor on this system.
            PullRes::PendingFull
        }
    }

    /// Reinterprets the raw cycle as a borrowed iterator.
    ///
    /// NOTE(review): the returned lifetime `'o` is unbounded; callers must not
    /// let it outlive the underlying system group.
    fn cycle<'o>(&mut self) -> SystemCycleIter<'o, S> {
        unsafe { SystemCycleIter::from_raw(self.cycle) }
    }

    /// Dereferences the scheduler's wait queues (unbounded lifetime `'o`).
    fn waits<'o>(&mut self) -> &'o mut WaitQueues<S> {
        unsafe { self.waits.as_mut() }
    }

    /// Dereferences the scheduler's run record (unbounded lifetime `'o`).
    fn record<'o>(&mut self) -> &'o mut ScheduleRecord<S> {
        unsafe { self.record.as_mut() }
    }

    /// Dereferences the normal pending lists (unbounded lifetime `'o`).
    fn nor_pendings<'o>(&mut self) -> &'o mut [Pending<S>] {
        unsafe { self.nor_pendings.as_mut() }
    }

    /// Dereferences the dedicated pending lists (unbounded lifetime `'o`).
    fn dedi_pendings<'o>(&mut self) -> &'o mut [Pending<S>] {
        unsafe { self.dedi_pendings.as_mut() }
    }

    /// Dereferences the cache storage (unbounded lifetime `'o`).
    fn cache<'o>(&mut self) -> &'o mut RefreshCacheStorage<'s, S> {
        unsafe { self.cache.as_mut() }
    }
}
674
675impl<W, S> Drop for ScheduleUnit<'_, W, S>
676where
677 W: Work + 'static,
678{
679 fn drop(&mut self) {
680 debug_assert!(self.cycle.position().is_end());
682 }
683}
684
/// Stateless helpers shared by [`Scheduler`] and [`ScheduleUnit`].
struct Helper;

impl Helper {
    /// Tries to enqueue the system described by `sdata` into `waits`.
    ///
    /// Returns the system's refreshed cache item when `waits.enqueue` accepts it,
    /// or `None` when the system must stay pending.
    ///
    /// # Panics
    ///
    /// Panics if `cache` has no entry for the system.
    fn update_task<'a, S>(
        waits: &mut WaitQueues<S>,
        sdata: &mut SystemData,
        cache: &'a mut RefreshCacheStorage<S>,
    ) -> Option<&'a mut CacheItem>
    where
        S: BuildHasher + Default + 'static,
    {
        let sid = sdata.id();
        let mut cache = cache.get_mut(&sid).unwrap();
        let (wait, retry) = cache.get_wait_retry_indices_mut();
        if waits.enqueue(&wait, retry) {
            // `wait` borrows from `cache`; end that borrow before `refresh`.
            drop(wait);
            Some(cache.refresh())
        } else {
            None
        }
    }

    /// Walks `pending` and dispatches every system that can now be enqueued,
    /// removing it from the list.
    fn pending_to_ready<W, S>(
        target: Or<NonNull<WorkGroup<W>>, NonNull<ParkingSender<Task>>>,
        pending: &mut Pending<S>,
        waits: &mut WaitQueues<S>,
        cycle: &mut SystemCycleIter<'_, S>,
        cache: &mut RefreshCacheStorage<'_, S>,
    ) where
        S: BuildHasher + Default + 'static,
        W: Work + 'static,
    {
        let mut cur = pending.first_position();

        // `next` is taken before a possible removal so iteration can continue.
        while let Some((next, &cycle_pos)) = pending.iter_next(cur) {
            let sdata = cycle.get_at(cycle_pos).unwrap();
            if let Some(cache) = Self::update_task(waits, sdata, cache) {
                pending.remove(&cycle_pos);
                Self::move_ready_system(target, sdata, cache);
            }
            cur = next;
        }
    }

    /// Dispatches a system whose wait queues were entered.
    ///
    /// Private systems are invoked immediately on the calling thread; others are
    /// wrapped in a `SysTask` and injected into the work group (`Or::A`) or sent
    /// to the dedicated queue (`Or::B`).
    ///
    /// # Panics
    ///
    /// Panics if sending to the dedicated queue fails.
    fn move_ready_system<W>(
        target: Or<NonNull<WorkGroup<W>>, NonNull<ParkingSender<Task>>>,
        sdata: &mut SystemData,
        cache: &mut CacheItem,
    ) where
        W: Work + 'static,
    {
        let sid = sdata.id();

        // NOTE(review): soundness relies on `target` still pointing at a live
        // work group / sender and on the request buffer outliving the task.
        unsafe {
            let mut invoker = sdata.task_ptr();
            let buf = ManagedMutPtr::new(cache.request_buffer_ptr());

            if sdata.flags().is_private() {
                invoker.invoke_private(sid, buf);
            } else {
                let task = Task::System(SysTask::new(invoker, buf, sid));
                match target {
                    Or::A(wgroup) => wgroup.as_ref().inject_task(task),
                    Or::B(dedi) => dedi.as_ref().send(task).unwrap(),
                }
            }
        }
    }
}
768
/// Outcome of pulling from a [`ScheduleUnit`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PullRes {
    /// The cycle has no more systems.
    Empty,
    /// One system was dispatched or parked in a pending list.
    Success,
    /// The system could not be dispatched and its pending list is full.
    PendingFull,
}
775
/// Per-run record of system states plus counters for each recorded state.
#[derive(Debug)]
pub(crate) struct ScheduleRecord<S> {
    /// Last recorded state of each system.
    record: HashMap<SystemId, RunResult, S>,
    /// Number of `Injected` insertions.
    injected: usize,
    /// Number of `Finished` insertions.
    finished: usize,
    /// Number of `Panicked` insertions.
    panicked: usize,
    /// Number of `Aborted` insertions.
    aborted: usize,
}
784
785impl<S> ScheduleRecord<S>
786where
787 S: BuildHasher + Default + 'static,
788{
789 fn new() -> Self {
790 Self {
791 record: HashMap::default(),
792 injected: 0,
793 finished: 0,
794 panicked: 0,
795 aborted: 0,
796 }
797 }
798
799 #[cfg(debug_assertions)]
800 pub(crate) fn len(&self) -> usize {
801 self.record.len()
802 }
803
804 #[cfg(debug_assertions)]
805 pub(crate) fn is_empty(&self) -> bool {
806 self.len() == 0
807 }
808
809 pub(crate) fn clear(&mut self) {
810 self.record.clear();
811 self.injected = 0;
812 self.finished = 0;
813 self.panicked = 0;
814 self.aborted = 0;
815 }
816
817 pub(crate) fn num_injected(&self) -> usize {
818 self.injected
819 }
820
821 pub(crate) fn num_completed(&self) -> usize {
822 self.finished + self.panicked
823 }
824
825 fn insert(&mut self, sid: SystemId, state: RunResult) {
826 match state {
827 RunResult::Injected => self.injected += 1,
828 RunResult::Finished => self.finished += 1,
829 RunResult::Panicked => self.panicked += 1,
830 RunResult::Aborted => self.aborted += 1,
831 }
832 self.record.insert(sid, state);
833 }
834}
835
/// State recorded for a system during one scheduling run.
#[derive(Debug)]
pub(crate) enum RunResult {
    /// Dispatched or parked in a pending list.
    Injected,
    /// A `Message::Fin` was received for it.
    Finished,
    /// A `Message::Panic` was received for it.
    Panicked,
    /// A `Message::Aborted` was received for it.
    Aborted,
}
843
/// Bounded list of cycle positions of systems that are waiting to be dispatched.
#[derive(Debug)]
struct Pending<S> {
    /// Cycle positions, appended at the back.
    list: SetValueList<ListPos, S>,

    /// Maximum number of positions `push` accepts.
    limit: usize,
}
853impl<S> Pending<S>
854where
855 S: BuildHasher + Default,
856{
857 fn new(limit: usize) -> Self {
858 Self {
859 list: SetValueList::new(ListPos::end()),
860 limit,
861 }
862 }
863
864 fn is_empty(&self) -> bool {
865 self.list.is_empty()
866 }
867
868 fn push(&mut self, pos: ListPos) -> bool {
869 if self.list.len() < self.limit {
870 self.list.push_back(pos);
871 true
872 } else {
873 false
874 }
875 }
876
877 fn remove(&mut self, pos: &ListPos) {
878 self.list.remove(pos);
879 }
880}
881
882impl<S> Deref for Pending<S> {
884 type Target = SetValueList<ListPos, S>;
885
886 fn deref(&self) -> &Self::Target {
887 &self.list
888 }
889}
890
/// A group of workers that share one task injector and one [`GlobalSignal`].
#[derive(Debug)]
struct WorkGroup<W: Work + 'static> {
    /// Workers of the group; `workers[i]` is unparked with `sub_cxs[i]`.
    workers: Vec<W>,

    /// Pinned per-worker contexts whose addresses are handed to the workers.
    sub_cxs: Vec<Pin<Box<SubContext>>>,

    /// Signal shared by every context (counters, abort flag, notifications).
    signal: Arc<GlobalSignal>,

    /// Queue that `inject_task` pushes tasks into.
    injector: Arc<cb::Injector<Task>>,
}
905
impl<W> WorkGroup<W>
where
    W: Work + 'static,
{
    /// Creates a group that owns `workers`, with one pinned [`SubContext`] per worker.
    ///
    /// The signal starts as a placeholder without thread handles and is replaced
    /// in `initialize`.
    fn new(
        group_index: u16,
        workers: Vec<W>,
        tx_msg: &ParkingSender<Message>,
        tx_cmd: &CommandSender,
    ) -> Self {
        let injector = Arc::new(cb::Injector::new());

        // Placeholder; replaced once worker thread handles are known.
        let dummy_signal = Arc::new(GlobalSignal::new(Vec::new()));

        let comms = SubComm::with_len(
            group_index,
            &injector,
            &dummy_signal,
            tx_msg,
            tx_cmd,
            workers.len(),
        );

        let sub_cxs = comms
            .into_iter()
            .map(|comm| {
                Box::pin(SubContext {
                    guide: sub::SubStateGuide::new(),
                    // Placeholder handle; each worker overwrites it with its own
                    // thread in `SubContext::set_handle`.
                    handle: UnsafeCell::new(thread::current()),
                    comm,
                    need_close: Cell::new(false),
                    _pin: PhantomPinned,
                })
            })
            .collect();

        Self {
            workers,
            sub_cxs,
            signal: dummy_signal,
            injector,
        }
    }

    /// Registers every worker's thread handle and installs the real signal.
    ///
    /// Unparks each worker (presumably leading to `SubContext::execute`, whose first
    /// call on a thread sends `Message::Handle`). It then waits for one message per
    /// worker on `rx_msg` and builds a [`GlobalSignal`] from the collected handles,
    /// which is set on every context.
    fn initialize(&mut self, rx_msg: &ParkingReceiver<Message>) {
        for i in 0..self.len() {
            self.unpark_one(i);
        }

        let mut remain = self.len();
        while remain > 0 {
            if let Ok(msg) = rx_msg.recv_timeout(Duration::MAX) {
                // Only handle registrations are expected here.
                debug_assert_eq!(
                    mem::discriminant(&msg),
                    mem::discriminant(&Message::Handle(WorkerId::dummy()))
                );
                remain -= 1;
            }
        }

        // NOTE(review): reading the handles is presumably sound because every
        // worker wrote its handle before sending the `Handle` message received above.
        let handles = self
            .sub_cxs
            .iter()
            .map(|sub_cx| unsafe { (*sub_cx.handle.get()).clone() })
            .collect();

        self.signal = Arc::new(GlobalSignal::new(handles));

        for sub_cx in self.sub_cxs.iter_mut() {
            sub_cx.as_mut().set_flags(Arc::clone(&self.signal));
        }
    }

    /// Asks every worker to start waiting for tasks.
    ///
    /// A worker is unparked only when `push_open` queued a new `Wait`; when it
    /// instead cancelled a queued `Close`, the worker is still running.
    fn open(&mut self) {
        for i in 0..self.len() {
            if self.sub_cxs[i].guide.push_open() {
                self.unpark_one(i);
            }
        }
    }

    /// Queues a `Close` for every worker and wakes them all so they can observe it.
    fn close(&mut self) {
        for i in 0..self.len() {
            self.sub_cxs[i].guide.push_close();
        }
        self.signal.sub().notify_all();
    }

    /// Number of workers in the group.
    fn len(&self) -> usize {
        debug_assert_eq!(self.workers.len(), self.sub_cxs.len());

        self.sub_cxs.len()
    }

    /// Shuts the group down and moves its workers out, leaving the group empty.
    fn take_workers(&mut self) -> Vec<W> {
        self.destroy();
        self.sub_cxs.clear();
        mem::take(&mut self.workers)
    }

    /// Blocks until the group is exhausted.
    fn wait_exhausted(&self) {
        while !self.is_exhausted() {
            self.signal.wait_open_count(0);
        }
    }

    /// `true` when no context has queued states and no worker is open.
    fn is_exhausted(&self) -> bool {
        let is_guide_empty = self.sub_cxs.iter().all(|cx| cx.guide.is_empty());

        let is_all_closed = self.signal.open_count() == 0;

        is_guide_empty && is_all_closed
    }

    /// Pushes `task` into the shared injector and wakes one worker.
    fn inject_task(&self, task: Task) {
        debug_assert!(
            !self.workers.is_empty(),
            "no workers for a non-dedicated task"
        );

        self.injector.push(task);
        self.signal.sub().notify_one();
    }

    /// Debug check that all contexts, counters and the injector are empty.
    #[cfg(debug_assertions)]
    fn validate_clean(&self) {
        for cx in self.sub_cxs.iter() {
            cx.validate_clean();
        }

        assert_eq!(self.signal.open_count(), 0);
        assert_eq!(self.signal.work_count(), 0);
        assert_eq!(self.signal.future_count(), 0);

        assert!(self.injector.is_empty());
    }

    /// Queues a `Reset` for worker `index` and unparks it.
    fn insert_reset(&mut self, index: usize) {
        self.sub_cxs[index].guide.push_reset();
        self.unpark_one(index);
    }

    /// Puts a `Search` at the front of worker `index`'s queue and unparks it.
    ///
    /// Only called on wasm32 (from `Scheduler::panic_helper`).
    #[allow(dead_code)]
    fn insert_search(&mut self, index: usize) {
        let must_true = self.sub_cxs[index].guide.push_search();
        debug_assert!(must_true);
        self.unpark_one(index);
    }

    /// Unparks worker `index` with a pointer to its pinned context.
    ///
    /// # Panics
    ///
    /// Panics if the worker refuses the unpark request.
    fn unpark_one(&mut self, index: usize) {
        let ptr = self.sub_cxs[index].as_ref().get_ref() as *const SubContext;
        // The pointer comes from a pinned box, so it is non-null and stable.
        let ptr = unsafe {
            let ptr = NonNullExt::new_unchecked(ptr.cast_mut());
            ManagedConstPtr::new(ptr)
        };

        let must_true = self.workers[index].unpark(ptr);
        assert!(must_true);
    }

    /// Fully stops the group.
    ///
    /// 1. Sets the abort flag and keeps notifying until the group is exhausted.
    /// 2. Clears the flag.
    /// 3. Queues `Reset` on every worker and waits until all of them are open.
    /// 4. Closes the group and waits until none is open.
    fn destroy(&mut self) {
        self.signal.set_abort(true);
        while !self.is_exhausted() {
            self.signal.sub().notify_all();
            self.signal.wait_open_count(0);
        }
        self.signal.set_abort(false);

        for i in 0..self.len() {
            self.insert_reset(i);
        }
        self.signal.wait_open_count(self.len() as u32);

        self.close();
        self.signal.wait_open_count(0);

        #[cfg(debug_assertions)]
        self.validate_clean();
    }
}
1118}
1119
1120impl<W> Drop for WorkGroup<W>
1121where
1122 W: Work + 'static,
1123{
1124 fn drop(&mut self) {
1125 self.destroy();
1126 }
1127}
1128
thread_local! {
    // Context this worker thread is bound to; dangling when unbound
    // (set in `SubContext::set_handle`, cleared by the `Reset` state).
    pub(crate) static SUB_CONTEXT: Cell<NonNullExt<SubContext>> = const {
        Cell::new(NonNullExt::dangling())
    };

    // Worker id of this thread; `WorkerId::dummy()` when unbound.
    pub(crate) static WORKER_ID: Cell<WorkerId> = const {
        Cell::new(WorkerId::dummy())
    };
}
1140
/// Per-worker execution context, pinned so that workers can hold raw pointers to it.
#[derive(Debug)]
pub struct SubContext {
    /// Queue of state transitions requested by the owning work group.
    guide: sub::SubStateGuide,

    /// Handle of the worker thread, written by the worker in `set_handle`.
    handle: UnsafeCell<Thread>,

    /// Channels, task queues and shared signal used by this worker.
    comm: SubComm,

    /// Set after `can_close` succeeds, so that one more search runs before closing.
    need_close: Cell<bool>,

    _pin: PhantomPinned,
}
1155
1156impl SubContext {
1157 #[rustfmt::skip]
1161 pub fn execute(ptr: ManagedConstPtr<Self>) {
1162 if ptr.comm.maybe_uninit_worker_id() != WORKER_ID.get() {
1164 Self::set_handle(ptr);
1165 return;
1166 }
1167
1168 let this = {
1172 #[cfg(target_arch = "wasm32")] { ptr.into_ref() }
1173 #[cfg(not(target_arch = "wasm32"))] { &*ptr }
1174 };
1175
1176 this.comm.signal().add_open_count(1);
1179
1180 let mut cur = this.guide.pop();
1182 let mut steal = cb::Steal::Empty;
1183 while let Some(next) = this.execute_by_state(cur, &mut steal) {
1184 cur = next;
1185 }
1186
1187 #[cfg(not(target_arch = "wasm32"))]
1190 let this = ptr.into_ref();
1191
1192 this.comm.signal().sub_open_count(1);
1197 }
1198
1199 pub(crate) fn get_comm(&self) -> &SubComm {
1200 &self.comm
1201 }
1202
1203 fn set_handle(ptr: ManagedConstPtr<Self>) {
1204 let handle = unsafe { &mut *ptr.handle.get() };
1205 *handle = thread::current();
1206 SUB_CONTEXT.set(ptr.as_nonnullext());
1207 WORKER_ID.set(ptr.comm.maybe_uninit_worker_id());
1208 ptr.comm.send_message(Message::Handle(ptr.comm.worker_id()));
1209 }
1210
1211 #[inline]
1212 fn execute_by_state(&self, cur: SubState, steal: &mut cb::Steal<Task>) -> Option<SubState> {
1213 match cur {
1214 SubState::Wait => {
1215 self.comm.wait();
1216 if self.comm.signal().is_abort() {
1217 Some(SubState::Abort)
1218 } else {
1219 Some(SubState::Search)
1220 }
1221 }
1222 SubState::Search => {
1223 *steal = self.comm.search();
1224 if steal.is_success() {
1225 Some(SubState::Work)
1226 } else if self.need_close.take() {
1227 Some(SubState::Close)
1228 } else if self.can_close() {
1229 self.need_close.set(true);
1232 Some(SubState::Search)
1233 } else {
1234 Some(SubState::Wait)
1235 }
1236 }
1237 SubState::Work => {
1238 self.work(steal);
1239 Some(SubState::Search)
1240 }
1241 SubState::Abort => {
1242 self.abort();
1243 Some(SubState::Search)
1244 }
1245 SubState::Close => {
1246 self.comm.signal().sub().notify_all();
1247 None
1248 }
1249 SubState::Reset => {
1250 SUB_CONTEXT.set(NonNullExt::dangling());
1251 WORKER_ID.set(WorkerId::dummy());
1252 Some(SubState::Search)
1253 }
1254 }
1255 }
1256
1257 #[inline]
1258 fn can_close(&self) -> bool {
1259 let fut_cnt = self.comm.signal().future_count();
1264 let work_cnt = self.comm.signal().work_count();
1265 if fut_cnt > 0 || work_cnt > 0 {
1266 return false;
1267 }
1268
1269 if self.guide.need_close() {
1272 return true;
1273 }
1274
1275 false
1277 }
1278
1279 pub(super) fn work(&self, steal: &mut cb::Steal<Task>) {
1280 self.comm.signal().add_work_count(1);
1282
1283 loop {
1286 match mem::replace(steal, cb::Steal::Empty) {
1287 cb::Steal::Success(cur) => match cur {
1288 Task::System(task) => self.work_for_system_task(task),
1289 Task::Parallel(task) => self.work_for_parallel_task(task),
1290 Task::Async(task) => self.work_for_async_task(task),
1291 },
1292 cb::Steal::Empty => break,
1293 cb::Steal::Retry => {}
1294 }
1295 *steal = self.comm.pop();
1296 }
1297
1298 self.comm.signal().sub_work_count(1);
1302 }
1303
1304 fn work_for_system_task(&self, task: SysTask) {
1305 let wid = self.comm.worker_id();
1306 let sid = task.sid();
1307
1308 let resp = match task.execute(wid) {
1309 Ok(_) => Message::Fin(self.comm.worker_id(), sid),
1310 Err(payload) => Message::Panic(PanicMessage {
1311 wid: self.comm.worker_id(),
1312 sid,
1313 payload,
1314 unrecoverable: false,
1315 }),
1316 };
1317
1318 self.comm.send_message(resp);
1319 }
1320
1321 fn work_for_parallel_task(&self, task: ParTask) {
1322 let wid = self.comm.worker_id();
1323
1324 task.execute(wid, FnContext::MIGRATED);
1325 }
1326
1327 fn work_for_async_task(&self, task: AsyncTask) {
1328 let waker = UnsafeWaker::new(self as *const SubContext); unsafe {
1331 if !task.will_wake(&waker) {
1332 task.set_waker(waker);
1333 }
1334 }
1335
1336 let wid = self.comm.worker_id();
1338 let on_ready = |ready| {
1339 self.comm.signal().sub_future_count(1);
1341
1342 let cmd = CommandObject::Future(ready);
1344 self.comm.send_command_or_cancel(cmd);
1345 };
1346 task.execute(wid, on_ready);
1347 }
1348
1349 fn abort(&self) {
1355 loop {
1356 match self.comm.pop() {
1357 cb::Steal::Success(task) => self.abort_task(task),
1358 cb::Steal::Empty => {
1359 if self.comm.signal().future_count() == 0 {
1360 self.comm.signal().sub().notify_all();
1363 break;
1364 }
1365
1366 self.comm.wait();
1368 }
1369 cb::Steal::Retry => {}
1370 }
1371 }
1372 }
1373
1374 fn abort_task(&self, task: Task) {
1375 match task {
1376 Task::System(task) => {
1378 let wid = self.comm.worker_id();
1379 let sid = task.sid();
1380 self.comm.send_message(Message::Aborted(wid, sid));
1381 }
1382 Task::Parallel(task) => {
1385 self.work_for_parallel_task(task);
1386 }
1387 Task::Async(task) => {
1390 unsafe { task.destroy() };
1393 self.comm.signal().sub_future_count(1);
1394 }
1395 }
1396 }
1397
1398 fn set_flags(self: Pin<&mut Self>, flags: Arc<GlobalSignal>) {
1399 unsafe {
1401 let this = self.get_unchecked_mut();
1402 this.comm.set_signal(flags);
1403 }
1404 }
1405
1406 #[cfg(debug_assertions)]
1407 fn validate_clean(&self) {
1408 let mut v = Vec::new();
1410 while !self.guide.is_empty() {
1411 v.push(self.guide.pop());
1412 }
1413 if !v.is_empty() {
1414 panic!("guide is not empry: {v:?}");
1415 }
1416
1417 match self.get_comm().search() {
1419 cb::Steal::Empty => {}
1420 _ => panic!("validation failed due to remaining task"),
1421 }
1422 }
1423}
1424
1425mod sub {
1427 use super::SubState;
1428 use crate::ds::ArrayDeque;
1429 use std::sync::{
1430 atomic::{AtomicU32, Ordering},
1431 Mutex,
1432 };
1433
1434 #[derive(Debug)]
1457 pub(super) struct SubStateGuide {
1458 queue: Mutex<ArrayDeque<SubState, 8>>,
1464
1465 close: AtomicU32,
1469
1470 #[cfg(debug_assertions)]
1474 open: std::cell::Cell<bool>,
1475 }
1476
1477 impl SubStateGuide {
1478 pub(super) fn new() -> Self {
1479 Self {
1480 queue: Mutex::new(ArrayDeque::new()),
1481 close: AtomicU32::new(0),
1482 #[cfg(debug_assertions)]
1483 open: std::cell::Cell::new(false),
1484 }
1485 }
1486
1487 pub(super) fn is_empty(&self) -> bool {
1488 let queue = self.queue.lock().unwrap();
1489 queue.is_empty()
1490 }
1491
1492 pub(super) fn push_open(&self) -> bool {
1503 #[cfg(debug_assertions)]
1504 {
1505 assert!(!self.open.get());
1507 self.open.set(true);
1508 }
1509
1510 let mut queue = self.queue.lock().unwrap();
1513 if queue.capacity() - queue.len() < 2 {
1514 debug_assert_eq!(queue[queue.len() - 1], SubState::Close);
1515 debug_assert_eq!(queue[queue.len() - 2], SubState::Wait);
1516
1517 queue.pop_back();
1518 return false;
1519 }
1520
1521 queue.push_back(SubState::Wait);
1523 true
1524 }
1525
1526 pub(super) fn push_close(&self) {
1531 #[cfg(debug_assertions)]
1532 {
1533 assert!(self.open.get());
1535 self.open.set(false);
1536 }
1537
1538 let mut queue = self.queue.lock().unwrap();
1539 let must_true = queue.push_back(SubState::Close);
1540 debug_assert!(must_true);
1541 self.close.fetch_add(1, Ordering::Relaxed);
1542 }
1543
1544 pub(super) fn push_reset(&self) {
1545 #[cfg(debug_assertions)]
1546 {
1547 assert!(!self.open.get());
1549 self.open.set(true);
1550 }
1551
1552 let mut queue = self.queue.lock().unwrap();
1553 let must_true = queue.push_back(SubState::Reset);
1554 debug_assert!(must_true);
1555 }
1556
1557 #[allow(dead_code)]
1563 pub(super) fn push_search(&self) -> bool {
1564 let mut queue = self.queue.lock().unwrap();
1565 queue.push_front(SubState::Search)
1566 }
1567
1568 pub(super) fn pop(&self) -> SubState {
1570 let mut queue = self.queue.lock().unwrap();
1571 queue.pop_front().unwrap()
1572 }
1573
1574 pub(super) fn need_close(&self) -> bool {
1576 if self.close.load(Ordering::Relaxed) == 0 {
1577 return false;
1578 }
1579
1580 let mut queue = self.queue.lock().unwrap();
1581 if queue.front() != Some(&SubState::Close) {
1582 return false;
1583 }
1584
1585 self.close.fetch_sub(1, Ordering::Relaxed);
1586 queue.pop_front();
1587 true
1588 }
1589 }
1590}
1591
/// States of a worker's loop, driven by `SubContext::execute_by_state`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SubState {
    /// Block in `comm.wait()`, then search (or abort if the abort flag is set).
    Wait,
    /// Try to obtain a task; leads to work, another search, close, or wait.
    Search,
    /// Run the obtained task and any further popped tasks.
    Work,
    /// Drain queued tasks (systems are reported as aborted), then search.
    Abort,
    /// Notify all waiters and leave the loop.
    Close,
    /// Clear this thread's `SUB_CONTEXT`/`WORKER_ID`, then search.
    Reset,
}
1601
/// Waker for futures polled on the scheduler thread; woken futures are sent back
/// through the dedicated task queue.
#[derive(Debug, Clone)]
pub(crate) struct MainWaker {
    /// Dedicated queue that receives woken futures.
    tx_dedi: ParkingSender<Task>,
    /// Thread that created the waker; the only field compared by `PartialEq`.
    tid: ThreadId,
}
1607
1608impl MainWaker {
1609 pub(crate) fn new(tx_dedi: ParkingSender<Task>) -> Self {
1610 Self {
1611 tx_dedi,
1612 tid: thread::current().id(),
1613 }
1614 }
1615}
1616
1617impl WakeSend for MainWaker {
1618 fn wake_send(&self, handle: UnsafeFuture) {
1619 let task = Task::Async(AsyncTask(handle));
1620
1621 if self.tx_dedi.send(task).is_err() {
1626 unsafe { handle.destroy() };
1629 }
1630 }
1631}
1632
1633impl PartialEq for MainWaker {
1634 fn eq(&self, other: &Self) -> bool {
1635 self.tid == other.tid
1636 }
1637}
1638
/// Waker for futures polled by a worker; wraps a raw pointer to that worker's
/// [`SubContext`].
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(transparent)]
pub(crate) struct UnsafeWaker {
    /// Context whose communicator receives woken futures.
    cx: *const SubContext,
}
1644
1645impl UnsafeWaker {
1646 pub(crate) const fn new(cx: *const SubContext) -> Self {
1647 Self { cx }
1648 }
1649}
1650
// SAFETY: NOTE(review): `UnsafeWaker` only holds a raw pointer to a pinned
// `SubContext`, and `wake_send` only uses that context's `SubComm`. Sending and
// sharing it across threads is presumably sound because the context outlives
// every future that can wake through it — confirm.
unsafe impl Send for UnsafeWaker {}
unsafe impl Sync for UnsafeWaker {}
1653
1654impl WakeSend for UnsafeWaker {
1655 fn wake_send(&self, handle: UnsafeFuture) {
1656 let comm = unsafe { self.cx.as_ref().unwrap_unchecked().get_comm() };
1658
1659 comm.push_future_task(handle);
1661
1662 comm.wake_self();
1665 }
1666}