1#[cfg(feature = "semaphore-total")]
3use crate::SyncWriteGuard;
4#[cfg(doc)]
5use crate::{FIFO, LIFO, Mutex, queue::*};
6use crate::{
7 Priority, RwLock,
8 queue::{PriorityQueue, PriorityQueueHandle},
9 waiter::{self, Waiter, WaiterFlagFut},
10};
11use core::{
12 cmp::Ordering,
13 error::Error,
14 fmt::{Debug, Display},
15 marker::PhantomData,
16 mem::ManuallyDrop,
17 usize,
18};
19
20#[cfg(feature = "const-default")]
21use const_default::ConstDefault;
22pub trait SemaphoreQueue<P: Priority>: PriorityQueue<SemaphoreWaiter<P>> {}
23impl<T: PriorityQueue<SemaphoreWaiter<P>>, P: Priority> SemaphoreQueue<P> for T {}
24
/// A single entry in a semaphore's queue: either a request that is still pending or one
/// that already holds its permits.
#[derive(Debug)]
pub struct SemaphoreWaiter<P: Priority> {
    /// Caller-supplied priority used when ordering entries.
    priority: P,
    /// Per-entry waiter state; its flags (has-lock / wants-evict) are read by the semaphore.
    waiter: Waiter,
    /// Number of permits requested or held. The top bit (`WITHIN_TOTAL_BIT`) may be set as a
    /// flag by `acquire_within_total`, so read the value through `count()`.
    count: usize,
}
34
35impl<P: Priority> SemaphoreWaiter<P> {
36 #[inline(always)]
37 fn count(&self) -> usize {
38 if cfg!(feature = "semaphore-total") {
39 self.count & !WITHIN_TOTAL_BIT
40 } else {
41 self.count
42 }
43 }
44
45 #[cfg(feature = "semaphore-total")]
46 #[inline(always)]
47 fn count_and_flag(&self) -> (usize, bool) {
48 (
49 self.count & !WITHIN_TOTAL_BIT,
50 self.count & WITHIN_TOTAL_BIT != 0,
51 )
52 }
53}
54
55impl<P: Priority> Priority for SemaphoreWaiter<P> {
61 #[inline]
62 fn compare(&self, other: &Self) -> core::cmp::Ordering {
63 match (self.waiter.has_lock(), other.waiter.has_lock()) {
64 (true, false) => Ordering::Greater,
65 (false, true) => Ordering::Less,
66 _ => self.priority.compare(&other.priority),
71 }
72 }
73
74 #[inline]
75 fn compare_new(&self, old: &Self) -> Ordering {
76 match (self.waiter.has_lock(), old.waiter.has_lock()) {
77 (true, false) => Ordering::Greater,
78 (false, true) => Ordering::Less,
79 (is_held, _) => {
80 let ret = self.priority.compare_new(&old.priority);
81
82 if !is_held {
83 return ret.then_with(|| self.count().compare(&old.count()).reverse());
90 }
91 ret
92 }
93 }
94 }
95}
96
// Private alias selecting the concrete queue: `DualLinkArenaQueue` whenever `arena-queue`
// is enabled, otherwise `DualLinkBoxQueue` when only `box-queue` is enabled.
#[cfg(feature = "arena-queue")]
type DefaultSemaphoreQueue_<P> = crate::queue::DualLinkArenaQueue<SemaphoreWaiter<P>>;
#[cfg(all(feature = "box-queue", not(feature = "arena-queue")))]
type DefaultSemaphoreQueue_<P> = crate::queue::DualLinkBoxQueue<SemaphoreWaiter<P>>;

/// The queue type used by `Semaphore` when no queue parameter is given.
///
/// Only defined when at least one of the `arena-queue` / `box-queue` features is enabled.
#[cfg(any(feature = "arena-queue", feature = "box-queue"))]
pub type DefaultSemaphoreQueue<P> = DefaultSemaphoreQueue_<P>;
117
/// The state a `Semaphore` keeps behind its lock.
#[derive(Default)]
struct SemaphoreInner<P: Priority, Q: SemaphoreQueue<P>> {
    /// Queue containing both entries that hold permits and entries still waiting.
    queue: Q,
    /// Running total of permits: set to the initial capacity, increased by
    /// `try_add_permits`, decreased by `forget_permits` and `SemaphorePermit::forget`.
    #[cfg(feature = "semaphore-total")]
    total: usize,
    /// Permits that are currently not handed out.
    available: usize,
    /// Ties the priority type `P` to this struct without storing a value of it.
    _phantom: PhantomData<P>,
}
134
impl<P: Priority, Q: SemaphoreQueue<P>> SemaphoreInner<P, Q> {
    /// Walks the queue starting at `next` and grants permits to pending entries for as
    /// long as `available` covers their request.
    ///
    /// Entries that already hold permits are skipped. The walk stops at the first pending
    /// entry whose request exceeds `available`, unless (with `semaphore-total`) that entry
    /// carries `WAITER_FLAG_WANTS_EVICT`, in which case it is skipped as well.
    #[inline]
    fn activate_waiters(&mut self, mut next: Option<Q::SharedHandle>) {
        while let Some(handle) = next.take() {
            let node = self.queue.get_by_handle(handle.as_ref());
            let flags = node.waiter.flags();
            let is_held = flags & waiter::WAITER_FLAG_HAS_LOCK != 0;

            let count = node.count();
            // Look up the successor now so every `continue` below moves on to it.
            next = self.queue.get_next_handle(handle.as_ref());

            if is_held {
                // This entry already has its permits; nothing to grant.
                continue;
            }

            if count > self.available {
                // `cfg!` makes this constantly `false` without the `semaphore-total` feature.
                let will_evict = cfg!(feature = "semaphore-total")
                    && flags & waiter::WAITER_FLAG_WANTS_EVICT != 0;

                if will_evict {
                    // A flagged entry does not hold up the entries queued behind it.
                    continue;
                }

                // An unflagged entry that cannot be satisfied ends the walk.
                break;
            }

            self.available -= count;
            // NOTE(review): `start()` presumably marks the waiter as holding its permits and
            // wakes it, and the `true` result presumably tells the queue the node's ordering
            // may have changed — confirm against `Waiter::start` / `update_node`.
            self.queue.update_node(handle.as_ref(), |x| {
                x.waiter.start();
                true
            });
        }
    }

    /// Calls `evict()` on every entry yielded by `iter_at(start)` that has the within-total
    /// flag set and asks for more permits than `total`.
    #[cfg(feature = "semaphore-total")]
    #[inline]
    fn notify_oversized_waiters(&self, start: Option<&Q::Handle>) {
        for node in self.queue.iter_at(start) {
            let (count, should_evict) = node.count_and_flag();

            if should_evict && count > self.total {
                node.waiter.evict();
            }
        }
    }
}
190
/// Compile-time default: the queue's own `DEFAULT` and zero permits.
#[cfg(feature = "const-default")]
impl<P, Q> ConstDefault for SemaphoreInner<P, Q>
where
    P: Priority,
    Q: ConstDefault + SemaphoreQueue<P>,
{
    const DEFAULT: Self = SemaphoreInner {
        available: 0,
        #[cfg(feature = "semaphore-total")]
        total: 0,
        queue: <Q as ConstDefault>::DEFAULT,
        _phantom: PhantomData,
    };
}
201
/// A queue entry together with the semaphore it was enqueued on.
///
/// Dropping this value removes the entry from the semaphore's queue.
struct SemaphorePermitWaiter<'a, P: Priority, Q: SemaphoreQueue<P>> {
    /// The semaphore whose queue owns the entry.
    semaphore: &'a Semaphore<P, Q>,
    /// Handle of the entry. Wrapped in `ManuallyDrop` because it is moved out by value
    /// (via `ManuallyDrop::take`) when the entry is dequeued.
    handle: ManuallyDrop<Q::Handle>,
}
211
// NOTE(review): hand-written thread-safety impls. `Sync` requires the semaphore and the
// handle to be `Sync`; the soundness argument is not visible in this file — confirm.
unsafe impl<'a, P: Priority, Q: SemaphoreQueue<P>> Sync for SemaphorePermitWaiter<'a, P, Q>
where
    Semaphore<P, Q>: Sync,
    Q::Handle: Sync,
{
}

// NOTE(review): `Send` requires the semaphore to be `Sync` (only a shared reference to it is
// stored) and the handle to be `Send` — confirm the handle has no thread affinity.
unsafe impl<'a, P: Priority, Q: SemaphoreQueue<P>> Send for SemaphorePermitWaiter<'a, P, Q>
where
    Semaphore<P, Q>: Sync,
    Q::Handle: Send,
{
}
225
impl<'a, P: Priority, Q: SemaphoreQueue<P>> SemaphorePermitWaiter<'a, P, Q> {
    /// `true` when the queue's handle type provides `LOAD_PURE`, i.e. a way to reach the
    /// entry from the handle alone. Callers use it to avoid taking the semaphore's lock.
    const HAS_PURE_LOAD: bool = Q::Handle::LOAD_PURE.is_some();
}
229
230impl<'a, P: Priority, Q: SemaphoreQueue<P>> waiter::WaiterHandle
231 for SemaphorePermitWaiter<'a, P, Q>
232{
233 #[inline]
234 fn with_waker<T>(&self, f: impl FnOnce(&Waiter) -> T) -> T {
235 if Self::HAS_PURE_LOAD {
236 unsafe { f(&Q::Handle::LOAD_PURE.unwrap_unchecked()(&self.handle).waiter) }
237 } else {
238 let sem = self.semaphore.0.read();
239 f(&sem.queue.get_by_handle(&self.handle).waiter)
240 }
241 }
242}
243
impl<'a, P: Priority, Q: SemaphoreQueue<P>> Drop for SemaphorePermitWaiter<'a, P, Q> {
    /// Removes the entry from the queue, returning its permits if it already held them, and
    /// re-runs waiter activation for the entries that followed it when that can help.
    #[inline]
    fn drop(&mut self) {
        let mut sem = self.semaphore.0.write();

        // The handle is moved out here and consumed by `dequeue` below; `self.handle` is not
        // touched again afterwards.
        let handle = unsafe { ManuallyDrop::take(&mut self.handle) };
        let node = sem.queue.get_by_handle(&handle);
        let is_active = node.waiter.has_lock();
        if is_active {
            // The entry had already been granted its permits: hand them back.
            sem.available += node.count();
        }

        let has_available = is_active || sem.available > 0;

        let (prev, next) = sem.queue.dequeue(handle);

        // Nothing to activate without a successor or without any permits to give out.
        if next.is_none() || !has_available {
            return;
        }

        if is_active {
            // Permits were just returned, so the successors may now be satisfiable.
            return sem.activate_waiters(next);
        }

        // A still-pending entry was removed. Activation is only retried when the entry in
        // front of it held permits (or there was none) — presumably because only then could
        // the removed entry have been the one blocking its successors; confirm.
        if prev.is_none_or(|x| x.waiter.has_lock()) {
            sem.activate_waiters(next);
        }
    }
}
279
/// Permits held on a `Semaphore`; dropping the value returns them to the semaphore.
///
/// The inner entry is wrapped in `ManuallyDrop` so that only this type's own `Drop`
/// implementation runs, not the one of `SemaphorePermitWaiter`.
#[repr(transparent)]
pub struct SemaphorePermit<'a, P: Priority, Q: SemaphoreQueue<P>>(
    ManuallyDrop<SemaphorePermitWaiter<'a, P, Q>>,
);
289
// NOTE(review): hand-written thread-safety impls using the same bounds as the ones on
// `SemaphorePermitWaiter`; the soundness argument is not visible in this file — confirm.
unsafe impl<'a, P: Priority, Q: SemaphoreQueue<P>> Sync for SemaphorePermit<'a, P, Q>
where
    Semaphore<P, Q>: Sync,
    Q::Handle: Sync,
{
}

unsafe impl<'a, P: Priority, Q: SemaphoreQueue<P>> Send for SemaphorePermit<'a, P, Q>
where
    Semaphore<P, Q>: Sync,
    Q::Handle: Send,
{
}
303
304impl<'a, P: Priority, Q: SemaphoreQueue<P>> SemaphorePermit<'a, P, Q> {
    /// Returns a future tied to the `WAITER_FLAG_WANTS_EVICT` flag of this permit's entry.
    ///
    /// NOTE(review): presumably the future completes once another request has asked this
    /// holder to give up its permits (see the `evict()` calls in `do_acquire`) — confirm
    /// against `WaiterFlagFut`.
    #[cfg(feature = "evict")]
    #[inline]
    pub fn evicted(&mut self) -> impl Future<Output = ()> {
        waiter::VoidFut(WaiterFlagFut::<_, { waiter::WAITER_FLAG_WANTS_EVICT }>::new(&*self.0))
    }
314
    /// Consumes the permit without giving its permits back to the semaphore.
    ///
    /// The entry is removed from the queue and `available` is left untouched. With the
    /// `semaphore-total` feature the permits are also subtracted from `total`, and the
    /// entries from the removed entry's successor onwards are passed to
    /// `notify_oversized_waiters`.
    #[inline]
    pub fn forget(mut self) {
        let mut sem = self.0.semaphore.0.write();
        #[cfg_attr(not(feature = "semaphore-total"), allow(unused))]
        let count = sem.queue.get_by_handle(&self.0.handle).count();

        #[cfg_attr(not(feature = "semaphore-total"), allow(unused))]
        let (_, maybe_next) = sem
            .queue
            .dequeue(unsafe { ManuallyDrop::take(&mut self.0.handle) });

        // The handle was moved out above, so `Drop` must not run for `self`.
        core::mem::forget(self);

        #[cfg(feature = "semaphore-total")]
        {
            sem.total -= count;

            if let Some(next) = maybe_next {
                sem.downgrade()
                    .notify_oversized_waiters(Some(next.as_ref()));
            }
        }
    }
339
340 #[inline]
341 pub fn permits(&self) -> usize {
343 if SemaphorePermitWaiter::<'a, P, Q>::HAS_PURE_LOAD {
344 return unsafe { Q::Handle::LOAD_PURE.unwrap_unchecked()(&self.0.handle).count() };
345 }
346 let sem = self.0.semaphore.0.read();
347
348 sem.queue.get_by_handle(&self.0.handle).count()
349 }
350
351 #[inline]
352 pub fn belongs_to(&self, semaphore: &Semaphore<P, Q>) -> bool {
354 core::ptr::eq(self.0.semaphore, semaphore)
355 }
356
357 #[inline]
364 pub fn split(&mut self, count: usize) -> Result<Self, InsufficientPermitsError>
365 where
366 P: Clone,
367 {
368 assert!(
369 count > 0,
370 "count must be greater than zero, received {count}"
371 );
372 let mut sem = self.0.semaphore.0.write();
373
374 let mut priority: Option<P> = None;
375 let mut avail = 0;
376
377 sem.queue.update_node(&self.0.handle, |node| {
378 avail = node.count();
379 if avail > count {
381 node.count -= count;
382 priority = Some(node.priority.clone());
383 }
384
385 false
386 });
387
388 if priority.is_none() {
389 return Err(InsufficientPermitsError {
390 total: avail,
391 requested: count,
392 });
393 }
394
395 let handle = sem.queue.enqueue(SemaphoreWaiter {
396 priority: priority.unwrap(),
397 waiter: Waiter::new(true),
398 count,
399 });
400
401 Ok(SemaphorePermitWaiter {
402 semaphore: self.0.semaphore,
403 handle: ManuallyDrop::new(handle),
404 }
405 .into())
406 }
407
408 pub fn split_with_priority(
409 &mut self,
410 count: usize,
411 priority: P,
412 ) -> Result<Self, InsufficientPermitsError> {
413 assert!(
414 count > 0,
415 "count must be greater than zero, received {count}"
416 );
417 let mut sem = self.0.semaphore.0.write();
418
419 let mut avail = 0;
420 let mut has_capacity = false;
421
422 sem.queue.update_node(&self.0.handle, |node| {
423 avail = node.count();
424 if avail > count {
426 node.count -= count;
427 has_capacity = true
428 }
429
430 false
431 });
432
433 if !has_capacity {
434 return Err(InsufficientPermitsError {
435 total: avail,
436 requested: count,
437 });
438 }
439
440 let handle = sem.queue.enqueue(SemaphoreWaiter {
441 priority: priority.into(),
442 waiter: Waiter::new(true),
443 count,
444 });
445
446 Ok(SemaphorePermitWaiter {
447 semaphore: self.0.semaphore,
448 handle: ManuallyDrop::new(handle),
449 }
450 .into())
451 }
452
453 #[inline]
454 pub fn merge(&mut self, mut other: Self) -> Result<(), ()> {
459 if &raw const *self.0.semaphore != other.0.semaphore {
460 return Err(());
461 }
462
463 let mut sem = self.0.semaphore.0.write();
464
465 let other_count = sem.queue.get_by_handle(&other.0.handle).count();
466
467 let mut would_overflow = false;
468 sem.queue.update_node(&self.0.handle, |node| {
469 would_overflow = node.count() + other_count > MAX_PERMITS;
470 if !would_overflow {
471 node.count += other_count
472 }
473
474 false
475 });
476
477 if would_overflow {
478 return Err(());
479 }
480
481 let other_handle = unsafe { ManuallyDrop::take(&mut other.0.handle) };
482 core::mem::forget(other);
483
484 sem.queue.dequeue(other_handle);
485
486 Ok(())
487 }
488}
489
impl<'a, P: Priority, Q: SemaphoreQueue<P>> Drop for SemaphorePermit<'a, P, Q> {
    /// Returns the held permits to the semaphore, removes the entry from the queue and
    /// tries to activate the entries that followed it.
    ///
    /// # Panics
    /// Without the `semaphore-total` feature, panics if adding the permits back would
    /// overflow the available count. The lock is released before panicking.
    #[inline]
    fn drop(&mut self) {
        let mut sem = self.0.semaphore.0.write();

        // The handle is moved out here and consumed by `dequeue`; it is not used again.
        let handle = unsafe { ManuallyDrop::take(&mut self.0.handle) };
        let count = sem.queue.get_by_handle(&handle).count();
        let (_, next) = sem.queue.dequeue(handle);
        if cfg!(feature = "semaphore-total") {
            // NOTE(review): unchecked add — presumably safe because `available` is bounded
            // by the checked `total` when this feature is on; confirm.
            sem.available += count;
        } else {
            sem.available = match sem.available.checked_add(count) {
                Some(x) => x,
                None => {
                    let avail = sem.available;
                    // Release the write lock before unwinding.
                    drop(sem);
                    core::panic!(
                        "failed to release {} permits back to semaphore as that would overflow (current available: {})",
                        count,
                        avail
                    );
                }
            }
        }

        sem.activate_waiters(next);
    }
}
522
523impl<'a, P: Priority, Q: SemaphoreQueue<P>> From<SemaphorePermitWaiter<'a, P, Q>>
524 for SemaphorePermit<'a, P, Q>
525{
526 #[inline(always)]
527 fn from(value: SemaphorePermitWaiter<'a, P, Q>) -> Self {
528 Self(ManuallyDrop::new(value))
529 }
530}
531
532impl<'a, P: Priority, Q: SemaphoreQueue<P>> Debug for SemaphorePermit<'a, P, Q> {
533 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
534 f.debug_struct("SemaphorePermit")
535 .field("permits", &self.permits())
536 .finish()
537 }
538}
539
/// An async semaphore whose waiters are ordered by a `Priority`.
///
/// `Q` selects the queue implementation. It defaults to `DefaultSemaphoreQueue<P>` when the
/// `arena-queue` or `box-queue` feature is enabled and must be named explicitly otherwise.
pub struct Semaphore<
    P: Priority,
    #[cfg(any(feature = "arena-queue", feature = "box-queue"))] Q: SemaphoreQueue<P> = DefaultSemaphoreQueue<P>,
    #[cfg(not(any(feature = "arena-queue", feature = "box-queue")))] Q: SemaphoreQueue<P>,
>(RwLock<SemaphoreInner<P, Q>>);
557
// NOTE(review): hand-written thread-safety impls keyed on the queue type only (`P` is not
// bounded directly); the soundness argument is not visible in this file — confirm.
unsafe impl<P: Priority, Q: SemaphoreQueue<P> + Send + Sync> Sync for Semaphore<P, Q> {}
unsafe impl<P: Priority, Q: SemaphoreQueue<P> + Send> Send for Semaphore<P, Q> {}
560
/// Upper bound for permit counts: acquisitions must request fewer than this many permits
/// and a merged permit may not exceed it. Equals `usize::MAX` with the top bit cleared.
pub const MAX_PERMITS: usize = usize::MAX >> 1;
/// The top bit of a `usize`. It is OR-ed into an entry's stored count by
/// `acquire_within_total` and masked off again whenever the count is read.
const WITHIN_TOTAL_BIT: usize = 1 << (usize::BITS - 1);
568
/// Error reported when a request asks for more permits than can be provided.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InsufficientPermitsError {
    /// Number of permits that were on hand, or `usize::MAX` when no figure is known.
    total: usize,
    /// Number of permits that were asked for.
    requested: usize,
}

impl InsufficientPermitsError {
    /// Returns the number of permits that were on hand, or `None` when the error was
    /// created without a known figure (stored internally as `usize::MAX`).
    #[inline(always)]
    pub fn total(&self) -> Option<usize> {
        (self.total != usize::MAX).then_some(self.total)
    }

    /// Returns the number of permits that were requested.
    #[inline(always)]
    pub fn requested(&self) -> usize {
        self.requested
    }
}

impl Display for InsufficientPermitsError {
    /// Writes a human-readable message; the wording depends on whether a total is known.
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self.total() {
            None => write!(
                f,
                "semaphore lacks sufficient permits: want {}",
                self.requested
            ),
            Some(total) => write!(
                f,
                "insufficient total permits: have {} want {}",
                total, self.requested
            ),
        }
    }
}

impl Error for InsufficientPermitsError {}
629
630impl<P: Priority, Q: SemaphoreQueue<P>> Debug for Semaphore<P, Q> {
631 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
632 let sem = self.0.read();
633 let mut dbg = f.debug_struct("Semaphore");
634 dbg.field("available", &sem.available);
635
636 #[cfg(feature = "semaphore-total")]
637 dbg.field("total", &sem.total);
638 dbg.finish()
639 }
640}
641
642impl<P: Priority, Q: SemaphoreQueue<P>> Semaphore<P, Q> {
643 #[inline]
644 pub fn new(capacity: usize) -> Self {
646 Self(RwLock::new(SemaphoreInner {
647 queue: Q::default(),
648 #[cfg(feature = "semaphore-total")]
649 total: capacity,
650 available: capacity,
651 _phantom: PhantomData,
652 }))
653 }
654
    /// `const` counterpart of `new`: creates a semaphore with `permits` permits, all of them
    /// available, using the queue's `ConstDefault::DEFAULT` value.
    #[cfg(feature = "const-default")]
    pub const fn const_new(permits: usize) -> Self
    where
        Q: ConstDefault,
    {
        Self(RwLock::new(SemaphoreInner {
            queue: Q::DEFAULT,
            #[cfg(feature = "semaphore-total")]
            total: permits,
            available: permits,
            _phantom: PhantomData,
        }))
    }
672
673 #[inline]
674 fn do_acquire(
675 &self,
676 inner: &mut SemaphoreInner<P, Q>,
677 priority: P,
678 count: usize,
679 ) -> (SemaphorePermitWaiter<'_, P, Q>, bool) {
680 let has_available = inner.available >= (count & !WITHIN_TOTAL_BIT);
681 let will_acquire = has_available
682 && inner
683 .queue
684 .iter()
685 .skip_while(|x| {
687 let flags = x.waiter.flags();
688 if flags & waiter::WAITER_FLAG_HAS_LOCK != 0 {
689 #[cfg(feature = "evict")]
690 if priority.compare(&x.priority).is_gt() {
691 x.waiter.evict();
692 }
693 return true;
694 }
695
696 cfg!(feature = "semaphore-total")
699 && flags & waiter::WAITER_FLAG_WANTS_EVICT == 0
700 })
701 .next()
702 .is_none_or(|first_pending| priority.compare(&first_pending.priority).is_ge());
705
706 let handle = inner.queue.enqueue(SemaphoreWaiter {
707 priority,
708 waiter: Waiter::new(will_acquire),
709 count,
710 });
711
712 let guard = SemaphorePermitWaiter {
713 semaphore: self,
714 handle: ManuallyDrop::new(handle),
715 };
716
717 if will_acquire {
718 inner.available -= count & !WITHIN_TOTAL_BIT;
719 }
720
721 #[cfg(feature = "evict")]
722 if !has_available {
725 let node = inner.queue.get_by_handle(&guard.handle);
726 for ex in inner.queue.iter() {
727 if ex.waiter.has_lock() {
728 if node.priority.compare(&ex.priority).is_gt() {
731 ex.waiter.evict();
732 }
733 }
734 }
735 }
736
737 return (guard, will_acquire);
738 }
739
    /// Acquires a single permit with the given `priority`.
    ///
    /// Shorthand for `acquire_many(priority, 1)`.
    #[inline]
    pub fn acquire(&self, priority: P) -> impl Future<Output = SemaphorePermit<'_, P, Q>> {
        self.acquire_many(priority, 1)
    }
746
    /// Acquires a single permit, converting `priority` into `P` first.
    ///
    /// Shorthand for `acquire(priority.into())`.
    #[inline(always)]
    pub fn acquire_from(
        &self,
        priority: impl Into<P>,
    ) -> impl Future<Output = SemaphorePermit<'_, P, Q>> {
        self.acquire(priority.into())
    }
759
760 #[inline]
761 pub fn acquire_default(&self) -> impl Future<Output = SemaphorePermit<'_, P, Q>>
767 where
768 P: Default,
769 {
770 self.acquire_many(Default::default(), 1)
771 }
772
773 pub async fn acquire_many(&self, priority: P, count: usize) -> SemaphorePermit<'_, P, Q> {
781 assert!(
782 count < MAX_PERMITS,
783 "count for a single holder must be less than {} and not zero (received {})",
784 MAX_PERMITS,
785 count
786 );
787 let guard = {
789 let mut inner = self.0.write();
790 let (guard, did_acquire) = self.do_acquire(&mut inner, priority.into(), count);
791
792 if did_acquire {
793 return guard.into();
794 }
795 guard
796 };
797
798 WaiterFlagFut::<_, { waiter::WAITER_FLAG_HAS_LOCK }>::new(&guard).await;
799
800 guard.into()
801 }
802
    /// Acquires `count` permits using `P`'s default priority.
    ///
    /// NOTE(review): this is an `async fn` whose return type is itself `impl Future`, so the
    /// caller has to await twice (`.await.await`) to obtain the permit. `acquire_default` is
    /// a plain `fn` returning the future; the asymmetry looks unintentional but changing it
    /// would alter the public signature — confirm before touching.
    #[inline(always)]
    pub async fn acquire_many_default(
        &self,
        count: usize,
    ) -> impl Future<Output = SemaphorePermit<'_, P, Q>>
    where
        P: Default,
    {
        self.acquire_many(Default::default(), count)
    }
820
821 #[inline(always)]
829 pub async fn acquire_many_from(
830 &self,
831 count: usize,
832 priority: impl Into<P>,
833 ) -> impl Future<Output = SemaphorePermit<'_, P, Q>>
834 where
835 P: Default,
836 {
837 self.acquire_many(priority.into(), count)
838 }
839
840 #[cfg(feature = "semaphore-total")]
841 pub async fn acquire_within_total(
866 &self,
867 priority: P,
868 count: usize,
869 ) -> Result<SemaphorePermit<'_, P, Q>, InsufficientPermitsError> {
870 assert!(
871 count < MAX_PERMITS,
872 "count for a single holder must be less than {} and not zero (received {})",
873 MAX_PERMITS,
874 count
875 );
876 let guard = {
877 let mut inner = self.0.write();
878 if inner.total < count {
879 return Err(InsufficientPermitsError {
880 total: inner.total,
881 requested: count,
882 });
883 }
884
885 let (guard, did_acquire) =
886 self.do_acquire(&mut inner, priority.into(), count | WITHIN_TOTAL_BIT);
887
888 if did_acquire {
889 return Ok(guard.into());
890 }
891
892 guard
893 };
894
895 let flags = WaiterFlagFut::<
896 _,
897 { waiter::WAITER_FLAG_HAS_LOCK | waiter::WAITER_FLAG_WANTS_EVICT },
898 >::new(&guard)
899 .await;
900
901 if flags & waiter::WAITER_FLAG_HAS_LOCK == 0 {
902 return Err(InsufficientPermitsError {
905 total: usize::MAX,
906 requested: count,
907 });
908 }
909
910 Ok(guard.into())
914 }
915
    /// Adds `count` permits to the semaphore and returns the new number of available permits.
    ///
    /// # Panics
    /// Panics when `try_add_permits` fails, i.e. when the addition would overflow.
    #[inline]
    pub fn add_permits(&self, count: usize) -> usize {
        self.try_add_permits(count).expect("must add permits")
    }
937
    /// Returns the semaphore's current `total` permit count (read under the read lock).
    #[cfg(feature = "semaphore-total")]
    #[inline]
    pub fn total_permits(&self) -> usize {
        self.0.read().total
    }
944
    /// Returns the number of permits that are currently not handed out (read under the read
    /// lock; the value may be stale by the time the caller looks at it).
    #[inline]
    pub fn available_permits(&self) -> usize {
        self.0.read().available
    }
950
    /// Adds `count` permits, then tries to activate waiters starting from the queue head.
    ///
    /// Returns the number of available permits after activation.
    ///
    /// # Errors
    /// Returns `Err(())` without changing anything when the addition would overflow: the
    /// `total` count with the `semaphore-total` feature, the `available` count without it.
    #[inline]
    pub fn try_add_permits(&self, count: usize) -> Result<usize, ()> {
        let mut inner = self.0.write();

        #[cfg(feature = "semaphore-total")]
        {
            inner.total = inner.total.checked_add(count).ok_or(())?;
            // NOTE(review): unchecked add — presumably cannot overflow because `available`
            // never exceeds `total`, which was just checked; confirm.
            inner.available += count
        }
        #[cfg(not(feature = "semaphore-total"))]
        {
            inner.available = inner.available.checked_add(count).ok_or(())?;
        }

        let head = inner.queue.head_handle();
        inner.activate_waiters(head);

        Ok(inner.available)
    }
973
    /// Permanently removes up to `count` of the currently available permits and returns how
    /// many were actually removed.
    ///
    /// Permits that are held or requested are not touched; `count` is clamped to the number
    /// of available permits. With the `semaphore-total` feature the removed permits are also
    /// subtracted from `total`, and every queued entry is then checked by
    /// `notify_oversized_waiters`.
    #[inline]
    pub fn forget_permits(&self, mut count: usize) -> usize {
        let mut inner = self.0.write();

        count = count.min(inner.available);

        inner.available -= count;

        #[cfg(feature = "semaphore-total")]
        if count != 0 {
            // SAFETY (as assumed here): `count` was clamped to `available`, and `available`
            // is expected never to exceed `total` — TODO confirm that invariant holds on
            // every path that mutates either field.
            unsafe { core::hint::assert_unchecked(inner.total >= count) };
            inner.total -= count;

            inner.downgrade().notify_oversized_waiters(None);
        }

        count
    }
1000}
1001
/// Compile-time default: a semaphore wrapping `SemaphoreInner`'s `ConstDefault::DEFAULT`
/// value (zero permits).
#[cfg(feature = "const-default")]
impl<P: Priority, Q: ConstDefault + SemaphoreQueue<P>> ConstDefault for Semaphore<P, Q> {
    const DEFAULT: Self = Self(RwLock::new(ConstDefault::DEFAULT));
}
1006
impl<P: Priority, Q: SemaphoreQueue<P>> Default for Semaphore<P, Q> {
    /// Creates a semaphore with zero permits; equivalent to `Semaphore::new(0)`.
    #[inline]
    fn default() -> Self {
        Self::new(0)
    }
}