1#[cfg(feature = "semaphore-total")]
3use crate::SyncWriteGuard;
4#[cfg(doc)]
5use crate::{FIFO, LIFO, Mutex, queue::*};
6use crate::{
7 Priority, RwLock,
8 queue::{PriorityQueue, PriorityQueueHandle},
9 waiter::{self, Waiter, WaiterFlagFut},
10};
11use core::{
12 cmp::Ordering,
13 error::Error,
14 fmt::{Debug, Display},
15 marker::PhantomData,
16 mem::ManuallyDrop,
17 usize,
18};
19
20#[cfg(feature = "const-default")]
21use const_default::ConstDefault;
22pub trait SemaphoreQueue<P: Priority>: PriorityQueue<SemaphoreWaiter<P>> {}
23impl<T: PriorityQueue<SemaphoreWaiter<P>>, P: Priority> SemaphoreQueue<P> for T {}
24
/// One entry in a semaphore's queue: either an acquirer that is still waiting
/// or a permit that is currently held.
#[derive(Debug)]
pub struct SemaphoreWaiter<P: Priority> {
    /// Priority used to order this entry in the queue.
    priority: P,
    /// Flag/wake-up state of the entry (`has_lock`, eviction flag, ...).
    waiter: Waiter,
    /// Number of permits requested or held. With the `semaphore-total`
    /// feature the top bit (`WITHIN_TOTAL_BIT`) may be set to mark a request
    /// made through `acquire_within_total`; read the plain value via `count()`.
    count: usize,
}
34
35impl<P: Priority> SemaphoreWaiter<P> {
36 #[inline(always)]
37 fn count(&self) -> usize {
38 if cfg!(feature = "semaphore-total") {
39 self.count & !WITHIN_TOTAL_BIT
40 } else {
41 self.count
42 }
43 }
44
45 #[cfg(feature = "semaphore-total")]
46 #[inline(always)]
47 fn count_and_flag(&self) -> (usize, bool) {
48 (
49 self.count & !WITHIN_TOTAL_BIT,
50 self.count & WITHIN_TOTAL_BIT != 0,
51 )
52 }
53}
54
55impl<P: Priority> Priority for SemaphoreWaiter<P> {
61 #[inline]
62 fn compare(&self, other: &Self) -> core::cmp::Ordering {
63 match (self.waiter.has_lock(), other.waiter.has_lock()) {
64 (true, false) => Ordering::Greater,
65 (false, true) => Ordering::Less,
66 _ => self.priority.compare(&other.priority),
71 }
72 }
73
74 #[inline]
75 fn compare_new(&self, old: &Self) -> Ordering {
76 match (self.waiter.has_lock(), old.waiter.has_lock()) {
77 (true, false) => Ordering::Greater,
78 (false, true) => Ordering::Less,
79 (is_held, _) => {
80 let ret = self.priority.compare_new(&old.priority);
81
82 if !is_held {
83 return ret.then_with(|| self.count().compare(&old.count()).reverse());
90 }
91 ret
92 }
93 }
94 }
95}
96
// Feature-selected backing queue; `arena-queue` wins when both features are on.
#[cfg(feature = "arena-queue")]
type DefaultSemaphoreQueue_<P> = crate::queue::DualLinkArenaQueue<SemaphoreWaiter<P>>;
#[cfg(all(feature = "box-queue", not(feature = "arena-queue")))]
type DefaultSemaphoreQueue_<P> = crate::queue::DualLinkBoxQueue<SemaphoreWaiter<P>>;

/// Default queue type for [`Semaphore`].
///
/// Resolves to `DualLinkArenaQueue` when the `arena-queue` feature is enabled,
/// otherwise to `DualLinkBoxQueue` when `box-queue` is enabled. It does not
/// exist when neither feature is enabled.
#[cfg(any(feature = "arena-queue", feature = "box-queue"))]
pub type DefaultSemaphoreQueue<P> = DefaultSemaphoreQueue_<P>;
117
/// State of a [`Semaphore`], kept behind the semaphore's `RwLock`.
#[derive(Default)]
struct SemaphoreInner<P: Priority, Q: SemaphoreQueue<P>> {
    /// Every live entry: held permits as well as acquirers still waiting.
    queue: Q,
    /// Total number of permits managed by the semaphore (only tracked with
    /// `semaphore-total`); raised by `try_add_permits`, lowered when permits
    /// are forgotten.
    #[cfg(feature = "semaphore-total")]
    total: usize,
    /// Permits not currently handed out to any holder.
    available: usize,
    _phantom: PhantomData<P>,
}
134
impl<P: Priority, Q: SemaphoreQueue<P>> SemaphoreInner<P, Q> {
    /// Walks the queue starting at `next` and grants permits to pending
    /// waiters, in queue order, while enough permits are available.
    ///
    /// Entries that already hold permits are skipped. With `semaphore-total`,
    /// pending waiters flagged with `WAITER_FLAG_WANTS_EVICT` are skipped as
    /// well. Any other pending waiter that cannot be satisfied stops the walk,
    /// so waiters after it are not activated ahead of it.
    #[inline]
    fn activate_waiters(&mut self, mut next: Option<Q::SharedHandle>) {
        while let Some(handle) = next.take() {
            let node = self.queue.get_by_handle(handle.as_ref());
            let flags = node.waiter.flags();
            let is_held = flags & waiter::WAITER_FLAG_HAS_LOCK != 0;

            let count = node.count();
            // Fetch the successor before `update_node` below touches this entry.
            // NOTE(review): presumably returning `true` from the update closure
            // lets the queue reposition the entry, which would change its
            // successor — confirm against `PriorityQueue::update_node`.
            next = self.queue.get_next_handle(handle.as_ref());

            if is_held {
                continue;
            }

            if count > self.available {
                // A pending `acquire_within_total` waiter that has already been
                // told to give up will resolve with an error and leave the
                // queue, so it must not block the waiters behind it.
                let will_evict = cfg!(feature = "semaphore-total")
                    && flags & waiter::WAITER_FLAG_WANTS_EVICT != 0;

                if will_evict {
                    continue;
                }

                break;
            }

            self.available -= count;
            // NOTE(review): `start()` presumably sets `WAITER_FLAG_HAS_LOCK` and
            // wakes the waiting future — confirm in `waiter`.
            self.queue.update_node(handle.as_ref(), |x| {
                x.waiter.start();
                true
            });
        }
    }

    /// With `semaphore-total`: walks the queue from `start` and calls `evict()`
    /// on every entry that was requested through `acquire_within_total`
    /// (`WITHIN_TOTAL_BIT` set) and asks for more permits than the current
    /// `total`. NOTE(review): `None` is assumed to mean "from the head of the
    /// queue" — confirm against `iter_at`.
    #[cfg(feature = "semaphore-total")]
    #[inline]
    fn notify_oversized_waiters(&self, start: Option<&Q::Handle>) {
        for node in self.queue.iter_at(start) {
            let (count, should_evict) = node.count_and_flag();

            if should_evict && count > self.total {
                node.waiter.evict();
            }
        }
    }
}
190
/// Compile-time default state: the queue's `ConstDefault` value and zero
/// permits (zero total as well when `semaphore-total` is enabled).
#[cfg(feature = "const-default")]
impl<P: Priority, Q: ConstDefault + SemaphoreQueue<P>> ConstDefault for SemaphoreInner<P, Q> {
    const DEFAULT: Self = SemaphoreInner {
        available: 0,
        #[cfg(feature = "semaphore-total")]
        total: 0,
        queue: <Q as ConstDefault>::DEFAULT,
        _phantom: PhantomData,
    };
}
201
/// Owns one entry in a semaphore's queue (pending or held) and removes it from
/// the queue when dropped.
struct SemaphorePermitWaiter<'a, P: Priority, Q: SemaphoreQueue<P>> {
    /// Semaphore whose queue contains the entry.
    semaphore: &'a Semaphore<P, Q>,
    /// Handle to the entry. Wrapped in `ManuallyDrop` so that `Drop` impls can
    /// move it out and hand it to `dequeue`.
    handle: ManuallyDrop<Q::Handle>,
}
211
// SAFETY(review): the waiter holds only `&Semaphore` and a `Q::Handle`, so
// sharing it needs `Semaphore: Sync` plus `Q::Handle: Sync`. Queue nodes are
// normally reached through the semaphore's lock. The exception is the
// `LOAD_PURE` path in `with_waker`, which reads the node directly — confirm
// that loader is safe to call concurrently with queue mutation.
unsafe impl<'a, P: Priority, Q: SemaphoreQueue<P>> Sync for SemaphorePermitWaiter<'a, P, Q>
where
    Semaphore<P, Q>: Sync,
    Q::Handle: Sync,
{
}

// SAFETY(review): moving the waiter to another thread moves a `&Semaphore`
// (sound only when `Semaphore: Sync`) and the handle (`Q::Handle: Send`).
unsafe impl<'a, P: Priority, Q: SemaphoreQueue<P>> Send for SemaphorePermitWaiter<'a, P, Q>
where
    Semaphore<P, Q>: Sync,
    Q::Handle: Send,
{
}
225
226impl<'a, P: Priority, Q: SemaphoreQueue<P>> SemaphorePermitWaiter<'a, P, Q> {
227 const HAS_PURE_LOAD: bool = Q::Handle::LOAD_PURE.is_some();
228}
229
230impl<'a, P: Priority, Q: SemaphoreQueue<P>> waiter::WaiterHandle
231 for SemaphorePermitWaiter<'a, P, Q>
232{
233 #[inline]
234 fn with_waker<T>(&self, f: impl FnOnce(&Waiter) -> T) -> T {
235 if Self::HAS_PURE_LOAD {
236 unsafe { f(&Q::Handle::LOAD_PURE.unwrap_unchecked()(&self.handle).waiter) }
237 } else {
238 let sem = self.semaphore.0.read();
239 f(&sem.queue.get_by_handle(&self.handle).waiter)
240 }
241 }
242}
243
impl<'a, P: Priority, Q: SemaphoreQueue<P>> Drop for SemaphorePermitWaiter<'a, P, Q> {
    /// Removes this entry from the queue.
    ///
    /// Runs only for entries that were never turned into a `SemaphorePermit`,
    /// since `SemaphorePermit` wraps this type in `ManuallyDrop` and has its
    /// own `Drop`. Examples: an acquire future dropped before it completed, or
    /// the error path of `acquire_within_total`. If the entry had already been
    /// granted permits, they are returned to `available`. Waiters behind it
    /// are then activated when its departure may let them proceed.
    #[inline]
    fn drop(&mut self) {
        let mut sem = self.semaphore.0.write();

        // SAFETY(review): the handle is moved out exactly once, here; `self`
        // is being dropped so the field is never read again.
        let handle = unsafe { ManuallyDrop::take(&mut self.handle) };
        let node = sem.queue.get_by_handle(&handle);
        // The entry may have been granted permits (e.g. by `activate_waiters`)
        // even though its future never completed; hand them back.
        let is_active = node.waiter.has_lock();
        if is_active {
            sem.available += node.count();
        }

        // Activation is only worth attempting if some permits may be free.
        let has_available = is_active || sem.available > 0;

        let (prev, next) = sem.queue.dequeue(handle);

        if next.is_none() || !has_available {
            return;
        }

        if is_active {
            // Permits were just returned: offer them to the following waiters.
            return sem.activate_waiters(next);
        }

        // A pending entry left. If nothing precedes it, or its predecessor
        // holds permits, it was presumably the first pending entry (held
        // entries compare greater and so appear to sit at the front). It may
        // have been blocking the waiters behind it, so retry activation.
        if prev.is_none_or(|x| x.waiter.has_lock()) {
            sem.activate_waiters(next);
        }
    }
}
279
/// Guard for permits acquired from a [`Semaphore`].
///
/// Dropping it returns its permits to the semaphore and activates waiters that
/// can now proceed. See also `forget`, `split`, and `merge`.
#[repr(transparent)]
pub struct SemaphorePermit<'a, P: Priority, Q: SemaphoreQueue<P>>(
    // `ManuallyDrop` suppresses `SemaphorePermitWaiter`'s own `Drop` (used for
    // cancelled acquisitions); this type's `Drop` releases the permits instead.
    ManuallyDrop<SemaphorePermitWaiter<'a, P, Q>>,
);
289
// SAFETY(review): `SemaphorePermit` is a transparent wrapper around
// `SemaphorePermitWaiter`, so it carries the same requirements. Sharing needs
// `Semaphore: Sync` and `Q::Handle: Sync`.
unsafe impl<'a, P: Priority, Q: SemaphoreQueue<P>> Sync for SemaphorePermit<'a, P, Q>
where
    Semaphore<P, Q>: Sync,
    Q::Handle: Sync,
{
}

// SAFETY(review): sending moves a `&Semaphore` (requires `Semaphore: Sync`)
// and the queue handle (requires `Q::Handle: Send`).
unsafe impl<'a, P: Priority, Q: SemaphoreQueue<P>> Send for SemaphorePermit<'a, P, Q>
where
    Semaphore<P, Q>: Sync,
    Q::Handle: Send,
{
}
303
304impl<'a, P: Priority, Q: SemaphoreQueue<P>> SemaphorePermit<'a, P, Q> {
305 #[cfg(feature = "evict")]
306 #[inline]
307 pub fn evicted(&mut self) -> impl Future<Output = ()> {
312 waiter::VoidFut(WaiterFlagFut::<_, { waiter::WAITER_FLAG_WANTS_EVICT }>::new(&*self.0))
313 }
314
315 #[inline]
317 pub fn forget(mut self) {
318 let mut sem = self.0.semaphore.0.write();
319 #[cfg_attr(not(feature = "semaphore-total"), allow(unused))]
320 let count = sem.queue.get_by_handle(&self.0.handle).count();
321
322 #[cfg_attr(not(feature = "semaphore-total"), allow(unused))]
323 let (_, maybe_next) = sem
324 .queue
325 .dequeue(unsafe { ManuallyDrop::take(&mut self.0.handle) });
326
327 core::mem::forget(self);
328
329 #[cfg(feature = "semaphore-total")]
330 {
331 sem.total -= count;
332
333 if let Some(next) = maybe_next {
334 sem.downgrade()
335 .notify_oversized_waiters(Some(next.as_ref()));
336 }
337 }
338 }
339
340 #[inline]
341 pub fn permits(&self) -> usize {
343 if SemaphorePermitWaiter::<'a, P, Q>::HAS_PURE_LOAD {
344 return unsafe { Q::Handle::LOAD_PURE.unwrap_unchecked()(&self.0.handle).count() };
345 }
346 let sem = self.0.semaphore.0.read();
347
348 sem.queue.get_by_handle(&self.0.handle).count()
349 }
350
351 #[inline]
352 pub fn belongs_to(&self, semaphore: &Semaphore<P, Q>) -> bool {
354 core::ptr::eq(self.0.semaphore, semaphore)
355 }
356
357 #[inline]
364 pub fn split(&mut self, count: usize) -> Result<Self, InsufficientPermitsError>
365 where
366 P: Clone,
367 {
368 assert!(
369 count > 0,
370 "count must be greater than zero, received {count}"
371 );
372 let mut sem = self.0.semaphore.0.write();
373
374 let mut priority: Option<P> = None;
375 let mut avail = 0;
376
377 sem.queue.update_node(&self.0.handle, |node| {
378 avail = node.count();
379 if avail > count {
381 node.count -= count;
382 priority = Some(node.priority.clone());
383 }
384
385 false
386 });
387
388 if priority.is_none() {
389 return Err(InsufficientPermitsError {
390 total: avail,
391 requested: count,
392 });
393 }
394
395 let handle = sem.queue.enqueue(SemaphoreWaiter {
396 priority: priority.unwrap(),
397 waiter: Waiter::new(true),
398 count,
399 });
400
401 Ok(SemaphorePermitWaiter {
402 semaphore: self.0.semaphore,
403 handle: ManuallyDrop::new(handle),
404 }
405 .into())
406 }
407
408 pub fn split_with_priority(
409 &mut self,
410 count: usize,
411 priority: P,
412 ) -> Result<Self, InsufficientPermitsError> {
413 assert!(
414 count > 0,
415 "count must be greater than zero, received {count}"
416 );
417 let mut sem = self.0.semaphore.0.write();
418
419 let mut avail = 0;
420 let mut has_capacity = false;
421
422 sem.queue.update_node(&self.0.handle, |node| {
423 avail = node.count();
424 if avail > count {
426 node.count -= count;
427 has_capacity = true
428 }
429
430 false
431 });
432
433 if !has_capacity {
434 return Err(InsufficientPermitsError {
435 total: avail,
436 requested: count,
437 });
438 }
439
440 let handle = sem.queue.enqueue(SemaphoreWaiter {
441 priority: priority.into(),
442 waiter: Waiter::new(true),
443 count,
444 });
445
446 Ok(SemaphorePermitWaiter {
447 semaphore: self.0.semaphore,
448 handle: ManuallyDrop::new(handle),
449 }
450 .into())
451 }
452
453 #[inline]
454 pub fn merge(&mut self, mut other: Self) -> Result<(), ()> {
459 if &raw const *self.0.semaphore != other.0.semaphore {
460 return Err(());
461 }
462
463 let mut sem = self.0.semaphore.0.write();
464
465 let other_count = sem.queue.get_by_handle(&other.0.handle).count();
466
467 let mut would_overflow = false;
468 sem.queue.update_node(&self.0.handle, |node| {
469 would_overflow = node.count() + other_count > MAX_PERMITS;
470 if !would_overflow {
471 node.count += other_count
472 }
473
474 false
475 });
476
477 if would_overflow {
478 return Err(());
479 }
480
481 let other_handle = unsafe { ManuallyDrop::take(&mut other.0.handle) };
482 core::mem::forget(other);
483
484 sem.queue.dequeue(other_handle);
485
486 Ok(())
487 }
488}
489
490impl<'a, P: Priority, Q: SemaphoreQueue<P>> Drop for SemaphorePermit<'a, P, Q> {
491 #[inline]
492 fn drop(&mut self) {
494 let mut sem = self.0.semaphore.0.write();
495
496 let handle = unsafe { ManuallyDrop::take(&mut self.0.handle) };
497 let count = sem.queue.get_by_handle(&handle).count();
498 let (_, next) = sem.queue.dequeue(handle);
499 if cfg!(feature = "semaphore-total") {
500 sem.available += count;
501 } else {
502 sem.available = match sem.available.checked_add(count) {
505 Some(x) => x,
506 None => {
507 let avail = sem.available;
508 drop(sem);
510 core::panic!(
511 "failed to release {} permits back to semaphore as that would overflow (current available: {})",
512 count,
513 avail
514 );
515 }
516 }
517 }
518
519 sem.activate_waiters(next);
520 }
521}
522
523impl<'a, P: Priority, Q: SemaphoreQueue<P>> From<SemaphorePermitWaiter<'a, P, Q>>
524 for SemaphorePermit<'a, P, Q>
525{
526 #[inline(always)]
527 fn from(value: SemaphorePermitWaiter<'a, P, Q>) -> Self {
528 Self(ManuallyDrop::new(value))
529 }
530}
531
532impl<'a, P: Priority, Q: SemaphoreQueue<P>> Debug for SemaphorePermit<'a, P, Q> {
533 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
534 f.debug_struct("SemaphorePermit")
535 .field("permits", &self.permits())
536 .finish()
537 }
538}
539
/// An async counting semaphore whose waiters are ordered by a priority `P`.
///
/// Each acquisition carries a priority and enters the queue `Q`. It is
/// granted immediately when enough permits are available and no pending
/// waiter ahead of it outranks it; otherwise it waits until woken. All state
/// lives behind an `RwLock`.
#[derive(Default)]
pub struct Semaphore<
    P: Priority,
    #[cfg(any(feature = "arena-queue", feature = "box-queue"))] Q: SemaphoreQueue<P> = DefaultSemaphoreQueue<P>,
    #[cfg(not(any(feature = "arena-queue", feature = "box-queue")))] Q: SemaphoreQueue<P>,
>(RwLock<SemaphoreInner<P, Q>>);
558
// SAFETY(review): the semaphore's state lives behind `RwLock`, and these impls
// only require `Q` to be `Send`/`Sync`. The priorities `P` are stored inside
// the queue's `SemaphoreWaiter` entries, so soundness presumably relies on
// `Q`'s own `Send`/`Sync` reflecting `P` — confirm that a `P: !Send` value
// cannot cross threads through a `Q: Send` queue.
unsafe impl<P: Priority, Q: SemaphoreQueue<P> + Send + Sync> Sync for Semaphore<P, Q> {}
unsafe impl<P: Priority, Q: SemaphoreQueue<P> + Send> Send for Semaphore<P, Q> {}
561
/// Upper bound on the number of permits a single holder may have
/// (`usize::MAX >> 1`). The top bit of a waiter's count is left free for
/// `WITHIN_TOTAL_BIT`.
pub const MAX_PERMITS: usize = usize::MAX >> 1;
/// Marker bit ORed into a waiter's stored count for requests made through
/// `acquire_within_total`.
const WITHIN_TOTAL_BIT: usize = 1 << (usize::BITS - 1);
569
/// Error returned when a permit request cannot be satisfied.
///
/// Produced in two cases:
/// - `SemaphorePermit::split` / `split_with_priority`, when the permit holds
///   too few permits;
/// - `Semaphore::acquire_within_total`, when the request exceeds the
///   semaphore's total, either up front or because the total later shrank
///   below it while waiting.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InsufficientPermitsError {
    /// Permits that could have satisfied the request, or `usize::MAX` when
    /// that number is not known (see [`Self::total`]).
    total: usize,
    /// Permits that were requested.
    requested: usize,
}

impl InsufficientPermitsError {
    /// The permits that were available to satisfy the request, or `None` if
    /// the request was rejected after the semaphore's total shrank while it
    /// was waiting.
    #[inline(always)]
    pub fn total(&self) -> Option<usize> {
        (self.total != usize::MAX).then_some(self.total)
    }

    /// The number of permits that were requested.
    #[inline(always)]
    pub fn requested(&self) -> usize {
        self.requested
    }
}

impl Display for InsufficientPermitsError {
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self.total() {
            None => write!(
                f,
                "semaphore lacks sufficient permits: want {}",
                self.requested
            ),
            Some(total) => write!(
                f,
                "insufficient total permits: have {} want {}",
                total, self.requested
            ),
        }
    }
}

impl Error for InsufficientPermitsError {}
630
631impl<P: Priority, Q: SemaphoreQueue<P>> Debug for Semaphore<P, Q> {
632 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
633 let sem = self.0.read();
634 let mut dbg = f.debug_struct("Semaphore");
635 dbg.field("available", &sem.available);
636
637 #[cfg(feature = "semaphore-total")]
638 dbg.field("total", &sem.total);
639 dbg.finish()
640 }
641}
642
643impl<P: Priority, Q: SemaphoreQueue<P>> Semaphore<P, Q> {
644 #[inline]
645 pub fn new(capacity: usize) -> Self {
647 Self(RwLock::new(SemaphoreInner {
648 queue: Q::default(),
649 #[cfg(feature = "semaphore-total")]
650 total: capacity,
651 available: capacity,
652 _phantom: PhantomData,
653 }))
654 }
655
656 #[cfg(feature = "const-default")]
657 pub const fn const_new(permits: usize) -> Self
662 where
663 Q: ConstDefault,
664 {
665 Self(RwLock::new(SemaphoreInner {
666 queue: Q::DEFAULT,
667 #[cfg(feature = "semaphore-total")]
668 total: permits,
669 available: permits,
670 _phantom: PhantomData,
671 }))
672 }
673
674 #[inline]
675 fn do_acquire(
676 &self,
677 inner: &mut SemaphoreInner<P, Q>,
678 priority: P,
679 count: usize,
680 ) -> (SemaphorePermitWaiter<'_, P, Q>, bool) {
681 let has_available = inner.available >= (count & !WITHIN_TOTAL_BIT);
682 let will_acquire = has_available
683 && inner
684 .queue
685 .iter()
686 .skip_while(|x| {
688 let flags = x.waiter.flags();
689 if flags & waiter::WAITER_FLAG_HAS_LOCK != 0 {
690 #[cfg(feature = "evict")]
691 if priority.compare(&x.priority).is_gt() {
692 x.waiter.evict();
693 }
694 return true;
695 }
696
697 cfg!(feature = "semaphore-total")
700 && flags & waiter::WAITER_FLAG_WANTS_EVICT == 0
701 })
702 .next()
703 .is_none_or(|first_pending| priority.compare(&first_pending.priority).is_ge());
706
707 let handle = inner.queue.enqueue(SemaphoreWaiter {
708 priority,
709 waiter: Waiter::new(will_acquire),
710 count,
711 });
712
713 let guard = SemaphorePermitWaiter {
714 semaphore: self,
715 handle: ManuallyDrop::new(handle),
716 };
717
718 if will_acquire {
719 inner.available -= count & !WITHIN_TOTAL_BIT;
720 }
721
722 #[cfg(feature = "evict")]
723 if !has_available {
726 let node = inner.queue.get_by_handle(&guard.handle);
727 for ex in inner.queue.iter() {
728 if ex.waiter.has_lock() {
729 if node.priority.compare(&ex.priority).is_gt() {
732 ex.waiter.evict();
733 }
734 }
735 }
736 }
737
738 return (guard, will_acquire);
739 }
740
741 #[inline]
744 pub fn acquire(&self, priority: P) -> impl Future<Output = SemaphorePermit<'_, P, Q>> {
745 self.acquire_many(priority, 1)
746 }
747
748 #[inline(always)]
754 pub fn acquire_from(
755 &self,
756 priority: impl Into<P>,
757 ) -> impl Future<Output = SemaphorePermit<'_, P, Q>> {
758 self.acquire(priority.into())
759 }
760
761 #[inline]
762 pub fn acquire_default(&self) -> impl Future<Output = SemaphorePermit<'_, P, Q>>
768 where
769 P: Default,
770 {
771 self.acquire_many(Default::default(), 1)
772 }
773
774 pub async fn acquire_many(&self, priority: P, count: usize) -> SemaphorePermit<'_, P, Q> {
782 assert!(
783 count < MAX_PERMITS,
784 "count for a single holder must be less than {} and not zero (received {})",
785 MAX_PERMITS,
786 count
787 );
788 let guard = {
790 let mut inner = self.0.write();
791 let (guard, did_acquire) = self.do_acquire(&mut inner, priority.into(), count);
792
793 if did_acquire {
794 return guard.into();
795 }
796 guard
797 };
798
799 WaiterFlagFut::<_, { waiter::WAITER_FLAG_HAS_LOCK }>::new(&guard).await;
800
801 guard.into()
802 }
803
804 #[inline(always)]
812 pub async fn acquire_many_default(
813 &self,
814 count: usize,
815 ) -> impl Future<Output = SemaphorePermit<'_, P, Q>>
816 where
817 P: Default,
818 {
819 self.acquire_many(Default::default(), count)
820 }
821
822 #[inline(always)]
830 pub async fn acquire_many_from(
831 &self,
832 count: usize,
833 priority: impl Into<P>,
834 ) -> impl Future<Output = SemaphorePermit<'_, P, Q>>
835 where
836 P: Default,
837 {
838 self.acquire_many(priority.into(), count)
839 }
840
841 #[cfg(feature = "semaphore-total")]
842 pub async fn acquire_within_total(
867 &self,
868 priority: P,
869 count: usize,
870 ) -> Result<SemaphorePermit<'_, P, Q>, InsufficientPermitsError> {
871 assert!(
872 count < MAX_PERMITS,
873 "count for a single holder must be less than {} and not zero (received {})",
874 MAX_PERMITS,
875 count
876 );
877 let guard = {
878 let mut inner = self.0.write();
879 if inner.total < count {
880 return Err(InsufficientPermitsError {
881 total: inner.total,
882 requested: count,
883 });
884 }
885
886 let (guard, did_acquire) =
887 self.do_acquire(&mut inner, priority.into(), count | WITHIN_TOTAL_BIT);
888
889 if did_acquire {
890 return Ok(guard.into());
891 }
892
893 guard
894 };
895
896 let flags = WaiterFlagFut::<
897 _,
898 { waiter::WAITER_FLAG_HAS_LOCK | waiter::WAITER_FLAG_WANTS_EVICT },
899 >::new(&guard)
900 .await;
901
902 if flags & waiter::WAITER_FLAG_HAS_LOCK == 0 {
903 return Err(InsufficientPermitsError {
906 total: usize::MAX,
907 requested: count,
908 });
909 }
910
911 Ok(guard.into())
915 }
916
917 #[inline]
918 pub fn add_permits(&self, count: usize) -> usize {
936 self.try_add_permits(count).expect("must add permits")
937 }
938
939 #[cfg(feature = "semaphore-total")]
940 #[inline]
941 pub fn total_permits(&self) -> usize {
943 self.0.read().total
944 }
945
946 #[inline]
948 pub fn available_permits(&self) -> usize {
949 self.0.read().available
950 }
951
952 #[inline]
953 pub fn try_add_permits(&self, count: usize) -> Result<usize, ()> {
956 let mut inner = self.0.write();
957
958 #[cfg(feature = "semaphore-total")]
959 {
960 inner.total = inner.total.checked_add(count).ok_or(())?;
961 inner.available += count
963 }
964 #[cfg(not(feature = "semaphore-total"))]
965 {
966 inner.available = inner.available.checked_add(count).ok_or(())?;
967 }
968
969 let head = inner.queue.head_handle();
970 inner.activate_waiters(head);
971
972 Ok(inner.available)
973 }
974
975 #[inline]
984 pub fn forget_permits(&self, mut count: usize) -> usize {
985 let mut inner = self.0.write();
986
987 count = count.min(inner.available);
988
989 inner.available -= count;
990
991 #[cfg(feature = "semaphore-total")]
992 if count != 0 {
993 unsafe { core::hint::assert_unchecked(inner.total >= count) };
994 inner.total -= count;
995
996 inner.downgrade().notify_oversized_waiters(None);
997 }
998
999 count
1000 }
1001}
1002
/// Compile-time default semaphore, built from `SemaphoreInner`'s
/// `ConstDefault` value (empty queue, no permits).
#[cfg(feature = "const-default")]
impl<P: Priority, Q: ConstDefault + SemaphoreQueue<P>> ConstDefault for Semaphore<P, Q> {
    const DEFAULT: Self = Self(RwLock::new(<SemaphoreInner<P, Q> as ConstDefault>::DEFAULT));
}