1use crate::config::{PriorityConfig, QueueStrategy};
2use crate::error::{AcquireError, TryAcquireError};
3use crate::wait_queue::{WaitQueue, WaiterHandle};
4
5use std::fmt;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::{Arc, Mutex};
10use std::task::{Context, Poll};
11
12#[derive(Debug)]
14pub struct RankedSemaphore {
15 permits: AtomicUsize,
19 waiters: Mutex<WaitQueue>,
21}
22
23pub struct RankedSemaphorePermit<'a> {
25 sem: &'a RankedSemaphore,
26 permits: u32,
27}
28
29pub struct OwnedRankedSemaphorePermit {
31 sem: Arc<RankedSemaphore>,
32 permits: u32,
33}
34
35#[must_use = "futures do nothing unless you `.await` or poll them"]
37pub struct Acquire<'a> {
38 semaphore: &'a RankedSemaphore,
39 permits_needed: usize,
40 priority: isize,
41 waiter_handle: Option<WaiterHandle>,
42}
43
44#[must_use = "futures do nothing unless you `.await` or poll them"]
46pub struct AcquireOwned {
47 semaphore: Arc<RankedSemaphore>,
48 permits_needed: usize,
49 priority: isize,
50 waiter_handle: Option<WaiterHandle>,
51}
52
53impl RankedSemaphore {
54 pub const MAX_PERMITS: usize = usize::MAX >> 3;
56
57 const CLOSED: usize = 1;
59 const PERMIT_SHIFT: usize = 1;
60
61 pub fn new_fifo(permits: usize) -> Self {
63 if permits > Self::MAX_PERMITS {
64 panic!("permits exceed MAX_PERMITS");
65 }
66 Self::new(permits, QueueStrategy::Fifo)
67 }
68
69 pub fn new_lifo(permits: usize) -> Self {
71 if permits > Self::MAX_PERMITS {
72 panic!("permits exceed MAX_PERMITS");
73 }
74 Self::new(permits, QueueStrategy::Lifo)
75 }
76
77 pub fn new(permits: usize, default_strategy: QueueStrategy) -> Self {
79 if permits > Self::MAX_PERMITS {
80 panic!("permits exceed MAX_PERMITS");
81 }
82 let config = PriorityConfig::new().default_strategy(default_strategy);
83 Self::new_with_config(permits, config)
84 }
85
86 pub fn new_with_config(permits: usize, config: PriorityConfig) -> Self {
88 if permits > Self::MAX_PERMITS {
89 panic!("permits exceed MAX_PERMITS");
90 }
91 Self {
92 permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT),
93 waiters: Mutex::new(WaitQueue::new(config)),
94 }
95 }
96
97 pub fn available_permits(&self) -> usize {
99 self.permits.load(Ordering::Acquire) >> Self::PERMIT_SHIFT
100 }
101
102 pub fn is_closed(&self) -> bool {
104 self.permits.load(Ordering::Acquire) & Self::CLOSED == Self::CLOSED
105 }
106
107 pub fn add_permits(&self, added: usize) {
109 if added == 0 {
110 return;
111 }
112
113 self.add_permits_locked(added, self.waiters.lock().unwrap());
115 }
116
117 fn add_permits_locked(
125 &self,
126 mut rem: usize,
127 waiters: std::sync::MutexGuard<'_, crate::wait_queue::WaitQueue>,
128 ) {
129 let mut lock = Some(waiters);
130
131 while rem > 0 {
133 let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock().unwrap());
134
135 if waiters.is_empty() {
137 drop(waiters);
138 break;
139 }
140
141 let (wake_list, permits_assigned) = waiters.select_waiters_to_notify(rem);
143 rem -= permits_assigned;
144
145 if permits_assigned == 0 || wake_list.is_empty() {
147 drop(waiters);
148 break;
149 }
150
151 drop(waiters);
153
154 let mut wake_list = wake_list;
156 wake_list.wake_all();
157
158 if !wake_list.was_full() {
160 break;
161 }
162 }
163
164 if rem > 0 {
166 let prev = self
167 .permits
168 .fetch_add(rem << Self::PERMIT_SHIFT, Ordering::Release);
169 let prev_permits = prev >> Self::PERMIT_SHIFT;
170
171 if prev_permits + rem > Self::MAX_PERMITS {
173 panic!(
174 "number of added permits ({}) would overflow MAX_PERMITS ({})",
175 rem,
176 Self::MAX_PERMITS
177 );
178 }
179 }
180 }
181
182 pub fn close(&self) {
184 self.permits.fetch_or(Self::CLOSED, Ordering::Release);
185
186 let mut waiters = self.waiters.lock().unwrap();
187 waiters.close();
188 }
189
190 pub fn acquire(&self) -> Acquire<'_> {
194 self.acquire_many_with_priority(0, 1)
195 }
196
197 pub fn acquire_with_priority(&self, priority: isize) -> Acquire<'_> {
199 self.acquire_many_with_priority(priority, 1)
200 }
201
202 pub fn acquire_many(&self, n: u32) -> Acquire<'_> {
204 self.acquire_many_with_priority(0, n)
205 }
206
207 pub fn acquire_many_with_priority(&self, priority: isize, n: u32) -> Acquire<'_> {
209 Acquire {
210 semaphore: self,
211 permits_needed: n as usize,
212 priority,
213 waiter_handle: None,
214 }
215 }
216
217 pub fn try_acquire(&self) -> Result<RankedSemaphorePermit<'_>, TryAcquireError> {
221 let mut curr = self.permits.load(Ordering::Acquire);
222 loop {
223 if curr & Self::CLOSED == Self::CLOSED {
225 return Err(TryAcquireError::Closed);
226 }
227
228 if curr < (1 << Self::PERMIT_SHIFT) {
230 return Err(TryAcquireError::NoPermits);
231 }
232
233 let next = curr - (1 << Self::PERMIT_SHIFT);
234 match self.permits.compare_exchange_weak(
235 curr,
236 next,
237 Ordering::AcqRel,
238 Ordering::Acquire,
239 ) {
240 Ok(_) => {
241 return Ok(RankedSemaphorePermit {
242 sem: self,
243 permits: 1,
244 })
245 }
246 Err(actual) => curr = actual,
247 }
248 }
249 }
250
251 pub fn try_acquire_many(&self, n: u32) -> Result<RankedSemaphorePermit<'_>, TryAcquireError> {
253 if n == 0 {
254 return Ok(RankedSemaphorePermit {
255 sem: self,
256 permits: 0,
257 });
258 }
259
260 if n as usize > Self::MAX_PERMITS {
261 panic!("try_acquire_many: n exceeds MAX_PERMITS");
262 }
263
264 let n_shifted = (n as usize) << Self::PERMIT_SHIFT;
265 let mut curr = self.permits.load(Ordering::Acquire);
266 loop {
267 if curr & Self::CLOSED == Self::CLOSED {
269 return Err(TryAcquireError::Closed);
270 }
271
272 if curr < n_shifted {
274 return Err(TryAcquireError::NoPermits);
275 }
276
277 let next = curr - n_shifted;
278 match self.permits.compare_exchange_weak(
279 curr,
280 next,
281 Ordering::AcqRel,
282 Ordering::Acquire,
283 ) {
284 Ok(_) => {
285 return Ok(RankedSemaphorePermit {
286 sem: self,
287 permits: n,
288 })
289 }
290 Err(actual) => curr = actual,
291 }
292 }
293 }
294
295 pub fn forget_permits(&self, n: usize) -> usize {
300 if n == 0 {
301 return 0;
302 }
303
304 let mut curr_bits = self.permits.load(Ordering::Acquire);
305 loop {
306 let curr_permits = curr_bits >> Self::PERMIT_SHIFT;
307 let removed = curr_permits.min(n);
308 let new_permits = curr_permits - removed;
309 let new_bits = (new_permits << Self::PERMIT_SHIFT) | (curr_bits & Self::CLOSED);
310
311 match self.permits.compare_exchange_weak(
312 curr_bits,
313 new_bits,
314 Ordering::AcqRel,
315 Ordering::Acquire,
316 ) {
317 Ok(_) => return removed,
318 Err(actual) => curr_bits = actual,
319 }
320 }
321 }
322
323 pub fn acquire_owned(self: Arc<Self>) -> AcquireOwned {
327 self.acquire_many_owned_with_priority(0, 1)
328 }
329
330 pub fn acquire_owned_with_priority(self: Arc<Self>, priority: isize) -> AcquireOwned {
332 self.acquire_many_owned_with_priority(priority, 1)
333 }
334
335 pub fn acquire_many_owned(self: Arc<Self>, n: u32) -> AcquireOwned {
337 self.acquire_many_owned_with_priority(0, n)
338 }
339
340 pub fn acquire_many_owned_with_priority(
342 self: Arc<Self>,
343 priority: isize,
344 n: u32,
345 ) -> AcquireOwned {
346 AcquireOwned {
347 semaphore: self,
348 permits_needed: n as usize,
349 priority,
350 waiter_handle: None,
351 }
352 }
353
354 pub fn try_acquire_owned(
361 self: Arc<Self>,
362 ) -> Result<OwnedRankedSemaphorePermit, TryAcquireError> {
363 let mut curr = self.permits.load(Ordering::Acquire);
364 loop {
365 if curr & Self::CLOSED == Self::CLOSED {
367 return Err(TryAcquireError::Closed);
368 }
369
370 if curr < (1 << Self::PERMIT_SHIFT) {
372 return Err(TryAcquireError::NoPermits);
373 }
374
375 let next = curr - (1 << Self::PERMIT_SHIFT);
376 match self.permits.compare_exchange_weak(
377 curr,
378 next,
379 Ordering::AcqRel,
380 Ordering::Acquire,
381 ) {
382 Ok(_) => {
383 return Ok(OwnedRankedSemaphorePermit {
384 sem: self,
385 permits: 1,
386 })
387 }
388 Err(actual) => curr = actual,
389 }
390 }
391 }
392
393 pub fn try_acquire_many_owned(
400 self: Arc<Self>,
401 n: u32,
402 ) -> Result<OwnedRankedSemaphorePermit, TryAcquireError> {
403 if n == 0 {
404 return Ok(OwnedRankedSemaphorePermit {
405 sem: self,
406 permits: 0,
407 });
408 }
409
410 if n as usize > Self::MAX_PERMITS {
411 panic!("try_acquire_many_owned: n exceeds MAX_PERMITS");
412 }
413
414 let n_shifted = (n as usize) << Self::PERMIT_SHIFT;
415 let mut curr = self.permits.load(Ordering::Acquire);
416 loop {
417 if curr & Self::CLOSED == Self::CLOSED {
419 return Err(TryAcquireError::Closed);
420 }
421
422 if curr < n_shifted {
424 return Err(TryAcquireError::NoPermits);
425 }
426
427 let next = curr - n_shifted;
428 match self.permits.compare_exchange_weak(
429 curr,
430 next,
431 Ordering::AcqRel,
432 Ordering::Acquire,
433 ) {
434 Ok(_) => {
435 return Ok(OwnedRankedSemaphorePermit {
436 sem: self,
437 permits: n,
438 })
439 }
440 Err(actual) => curr = actual,
441 }
442 }
443 }
444}
445
446impl<'a> Future for Acquire<'a> {
449 type Output = Result<RankedSemaphorePermit<'a>, AcquireError>;
450
451 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
452 let this = &mut *self;
453
454 if this.waiter_handle.is_none() {
456 match this.semaphore.try_acquire_many(this.permits_needed as u32) {
457 Ok(permit) => return Poll::Ready(Ok(permit)),
458 Err(TryAcquireError::NoPermits) => {
459 }
461 Err(TryAcquireError::Closed) => return Poll::Ready(Err(AcquireError::closed())),
462 }
463 }
464
465 if this.waiter_handle.is_none() {
467 let mut waiters = this.semaphore.waiters.lock().unwrap();
468 match this.semaphore.try_acquire_many(this.permits_needed as u32) {
470 Ok(permit) => return Poll::Ready(Ok(permit)),
471 Err(TryAcquireError::NoPermits) => {
472 if this.semaphore.is_closed() {
473 return Poll::Ready(Err(AcquireError::closed()));
474 }
475 this.waiter_handle =
476 Some(waiters.push_waiter(this.permits_needed, this.priority));
477 }
478 Err(TryAcquireError::Closed) => return Poll::Ready(Err(AcquireError::closed())),
479 }
480 }
481
482 let handle = this.waiter_handle.as_ref().unwrap();
484 handle.state.set_waker(cx.waker().clone());
485
486 if handle.state.is_notified() {
487 return Poll::Ready(Ok(RankedSemaphorePermit {
489 sem: this.semaphore,
490 permits: this.permits_needed as u32,
491 }));
492 }
493
494 if handle.state.is_cancelled() {
495 return Poll::Ready(Err(AcquireError::closed()));
496 }
497
498 Poll::Pending
499 }
500}
501
502impl<'a> Drop for Acquire<'a> {
503 fn drop(&mut self) {
504 if let Some(handle) = &self.waiter_handle {
506 handle.state.cancel();
507 }
508 }
509}
510
511impl Future for AcquireOwned {
512 type Output = Result<OwnedRankedSemaphorePermit, AcquireError>;
513
514 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
515 let this = &mut *self;
516
517 if this.waiter_handle.is_none() {
519 match this.semaphore.try_acquire_many(this.permits_needed as u32) {
520 Ok(permit) => {
521 let permits = permit.permits;
524 std::mem::forget(permit);
525 return Poll::Ready(Ok(OwnedRankedSemaphorePermit {
526 sem: Arc::clone(&this.semaphore),
527 permits,
528 }));
529 }
530 Err(TryAcquireError::NoPermits) => {
531 }
533 Err(TryAcquireError::Closed) => return Poll::Ready(Err(AcquireError::closed())),
534 }
535 }
536
537 if this.waiter_handle.is_none() {
539 let mut waiters = this.semaphore.waiters.lock().unwrap();
540 match this.semaphore.try_acquire_many(this.permits_needed as u32) {
542 Ok(permit) => {
543 let permits = permit.permits;
544 std::mem::forget(permit);
545 return Poll::Ready(Ok(OwnedRankedSemaphorePermit {
546 sem: Arc::clone(&this.semaphore),
547 permits,
548 }));
549 }
550 Err(TryAcquireError::NoPermits) => {
551 if this.semaphore.is_closed() {
552 return Poll::Ready(Err(AcquireError::closed()));
553 }
554 this.waiter_handle =
555 Some(waiters.push_waiter(this.permits_needed, this.priority));
556 }
557 Err(TryAcquireError::Closed) => return Poll::Ready(Err(AcquireError::closed())),
558 }
559 }
560
561 let handle = this.waiter_handle.as_ref().unwrap();
563 handle.state.set_waker(cx.waker().clone());
564
565 if handle.state.is_notified() {
566 return Poll::Ready(Ok(OwnedRankedSemaphorePermit {
568 sem: Arc::clone(&this.semaphore),
569 permits: this.permits_needed as u32,
570 }));
571 }
572
573 if handle.state.is_cancelled() {
574 return Poll::Ready(Err(AcquireError::closed()));
575 }
576
577 Poll::Pending
578 }
579}
580
581impl Drop for AcquireOwned {
582 fn drop(&mut self) {
583 if let Some(handle) = &self.waiter_handle {
585 handle.state.cancel();
586 }
587 }
588}
589
590impl<'a> RankedSemaphorePermit<'a> {
593 pub fn forget(mut self) {
595 self.permits = 0;
596 }
597
598 pub fn num_permits(&self) -> usize {
600 self.permits as usize
601 }
602
603 pub fn merge(&mut self, mut other: Self) {
605 if !std::ptr::eq(self.sem, other.sem) {
606 panic!("Cannot merge permits from different semaphores");
607 }
608 self.permits += other.permits;
609 other.permits = 0;
611 }
612
613 pub fn split(&mut self, n: u32) -> Option<Self> {
615 if n > self.permits {
616 return None;
617 }
618 self.permits -= n;
619 Some(Self {
620 sem: self.sem,
621 permits: n,
622 })
623 }
624}
625
626impl<'a> Drop for RankedSemaphorePermit<'a> {
627 fn drop(&mut self) {
628 if self.permits == 0 {
629 return;
630 }
631
632 let permits_to_add = (self.permits as usize) << RankedSemaphore::PERMIT_SHIFT;
634 let waiters = self.sem.waiters.lock().unwrap();
635
636 if waiters.is_empty() {
637 drop(waiters);
639 self.sem
640 .permits
641 .fetch_add(permits_to_add, Ordering::Release);
642 return;
643 }
644
645 self.sem.add_permits_locked(self.permits as usize, waiters);
647 }
648}
649
650impl OwnedRankedSemaphorePermit {
651 pub fn forget(mut self) {
653 self.permits = 0;
654 }
655
656 pub fn num_permits(&self) -> usize {
658 self.permits as usize
659 }
660
661 pub fn merge(&mut self, mut other: Self) {
663 if !Arc::ptr_eq(&self.sem, &other.sem) {
664 panic!("merging permits from different semaphore instances");
665 }
666 self.permits += other.permits;
667 other.permits = 0;
669 }
670
671 pub fn split(&mut self, n: usize) -> Option<Self> {
674 let n = u32::try_from(n).ok()?;
675
676 if n > self.permits {
677 return None;
678 }
679
680 self.permits -= n;
681
682 Some(Self {
683 sem: self.sem.clone(),
684 permits: n,
685 })
686 }
687
688 pub fn semaphore(&self) -> &Arc<RankedSemaphore> {
690 &self.sem
691 }
692}
693
694impl Drop for OwnedRankedSemaphorePermit {
695 fn drop(&mut self) {
696 if self.permits == 0 {
697 return;
698 }
699
700 let permits_to_add = (self.permits as usize) << RankedSemaphore::PERMIT_SHIFT;
702 let waiters = self.sem.waiters.lock().unwrap();
703
704 if waiters.is_empty() {
705 drop(waiters);
707 self.sem
708 .permits
709 .fetch_add(permits_to_add, Ordering::Release);
710 return;
711 }
712
713 self.sem.add_permits_locked(self.permits as usize, waiters);
715 }
716}
717
718impl<'a> fmt::Debug for RankedSemaphorePermit<'a> {
721 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
722 f.debug_struct("RankedSemaphorePermit")
723 .field("permits", &self.permits)
724 .finish()
725 }
726}
727
728impl fmt::Debug for OwnedRankedSemaphorePermit {
729 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
730 f.debug_struct("OwnedRankedSemaphorePermit")
731 .field("permits", &self.permits)
732 .finish()
733 }
734}
735
736impl<'a> fmt::Debug for Acquire<'a> {
737 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
738 f.debug_struct("Acquire")
739 .field("permits_needed", &self.permits_needed)
740 .field("priority", &self.priority)
741 .field("queued", &self.waiter_handle.is_some())
742 .finish()
743 }
744}
745
746impl fmt::Debug for AcquireOwned {
747 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
748 f.debug_struct("AcquireOwned")
749 .field("permits_needed", &self.permits_needed)
750 .field("priority", &self.priority)
751 .field("queued", &self.waiter_handle.is_some())
752 .finish()
753 }
754}