1use crate::config::{PriorityConfig, QueueStrategy};
11use crate::error::{AcquireError, TryAcquireError};
12use crate::wait_queue::{WaitQueue, WaiterHandle};
13
14use std::fmt;
15use std::future::Future;
16use std::pin::Pin;
17use std::sync::atomic::{AtomicUsize, Ordering};
18use std::sync::{Arc, Mutex};
19use std::task::{Context, Poll};
20
21#[derive(Debug)]
23pub struct RankedSemaphore {
24 permits: AtomicUsize,
28 waiters: Mutex<WaitQueue>,
30}
31
32pub struct RankedSemaphorePermit<'a> {
34 sem: &'a RankedSemaphore,
35 permits: u32,
36}
37
38pub struct OwnedRankedSemaphorePermit {
40 sem: Arc<RankedSemaphore>,
41 permits: u32,
42}
43
44#[must_use = "futures do nothing unless you `.await` or poll them"]
46pub struct Acquire<'a> {
47 semaphore: &'a RankedSemaphore,
48 permits_needed: usize,
49 priority: isize,
50 waiter_handle: Option<WaiterHandle>,
51}
52
53#[must_use = "futures do nothing unless you `.await` or poll them"]
55pub struct AcquireOwned {
56 semaphore: Arc<RankedSemaphore>,
57 permits_needed: usize,
58 priority: isize,
59 waiter_handle: Option<WaiterHandle>,
60}
61
62impl RankedSemaphore {
63 pub const MAX_PERMITS: usize = usize::MAX >> 3;
65
66 const CLOSED: usize = 1;
68 const PERMIT_SHIFT: usize = 1;
69
70 pub fn new_fifo(permits: usize) -> Self {
72 if permits > Self::MAX_PERMITS {
73 panic!("permits exceed MAX_PERMITS");
74 }
75 Self::new(permits, QueueStrategy::Fifo)
76 }
77
78 pub fn new_lifo(permits: usize) -> Self {
80 if permits > Self::MAX_PERMITS {
81 panic!("permits exceed MAX_PERMITS");
82 }
83 Self::new(permits, QueueStrategy::Lifo)
84 }
85
86 pub fn new(permits: usize, default_strategy: QueueStrategy) -> Self {
88 if permits > Self::MAX_PERMITS {
89 panic!("permits exceed MAX_PERMITS");
90 }
91 let config = PriorityConfig::new().default_strategy(default_strategy);
92 Self::new_with_config(permits, config)
93 }
94
95 pub fn new_with_config(permits: usize, config: PriorityConfig) -> Self {
97 if permits > Self::MAX_PERMITS {
98 panic!("permits exceed MAX_PERMITS");
99 }
100 Self {
101 permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT),
102 waiters: Mutex::new(WaitQueue::new(config)),
103 }
104 }
105
106 pub fn available_permits(&self) -> usize {
108 self.permits.load(Ordering::Acquire) >> Self::PERMIT_SHIFT
109 }
110
111 pub fn is_closed(&self) -> bool {
113 self.permits.load(Ordering::Acquire) & Self::CLOSED == Self::CLOSED
114 }
115
116 pub fn add_permits(&self, added: usize) {
118 if added == 0 {
119 return;
120 }
121
122 self.add_permits_locked(added, self.waiters.lock().unwrap());
124 }
125
126 fn add_permits_locked(
134 &self,
135 mut rem: usize,
136 waiters: std::sync::MutexGuard<'_, crate::wait_queue::WaitQueue>,
137 ) {
138 let mut lock = Some(waiters);
139
140 while rem > 0 {
142 let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock().unwrap());
143
144 if waiters.is_empty() {
146 drop(waiters);
147 break;
148 }
149
150 let (wake_list, permits_assigned) = waiters.select_waiters_to_notify(rem);
152 rem -= permits_assigned;
153
154 if permits_assigned == 0 || wake_list.is_empty() {
156 drop(waiters);
157 break;
158 }
159
160 drop(waiters);
162
163 let mut wake_list = wake_list;
165 wake_list.wake_all();
166
167 if !wake_list.was_full() {
169 break;
170 }
171 }
172
173 if rem > 0 {
175 let prev = self
176 .permits
177 .fetch_add(rem << Self::PERMIT_SHIFT, Ordering::Release);
178 let prev_permits = prev >> Self::PERMIT_SHIFT;
179
180 if prev_permits + rem > Self::MAX_PERMITS {
182 panic!(
183 "number of added permits ({}) would overflow MAX_PERMITS ({})",
184 rem,
185 Self::MAX_PERMITS
186 );
187 }
188 }
189 }
190
191 pub fn close(&self) {
193 self.permits.fetch_or(Self::CLOSED, Ordering::Release);
194
195 let mut waiters = self.waiters.lock().unwrap();
196 waiters.close();
197 }
198
199 pub fn acquire(&self) -> Acquire<'_> {
203 self.acquire_many_with_priority(0, 1)
204 }
205
206 pub fn acquire_with_priority(&self, priority: isize) -> Acquire<'_> {
208 self.acquire_many_with_priority(priority, 1)
209 }
210
211 pub fn acquire_many(&self, n: u32) -> Acquire<'_> {
213 self.acquire_many_with_priority(0, n)
214 }
215
216 pub fn acquire_many_with_priority(&self, priority: isize, n: u32) -> Acquire<'_> {
218 Acquire {
219 semaphore: self,
220 permits_needed: n as usize,
221 priority,
222 waiter_handle: None,
223 }
224 }
225
226 pub fn try_acquire(&self) -> Result<RankedSemaphorePermit<'_>, TryAcquireError> {
230 let mut curr = self.permits.load(Ordering::Acquire);
231 loop {
232 if curr & Self::CLOSED == Self::CLOSED {
234 return Err(TryAcquireError::Closed);
235 }
236
237 if curr < (1 << Self::PERMIT_SHIFT) {
239 return Err(TryAcquireError::NoPermits);
240 }
241
242 let next = curr - (1 << Self::PERMIT_SHIFT);
243 match self.permits.compare_exchange_weak(
244 curr,
245 next,
246 Ordering::AcqRel,
247 Ordering::Acquire,
248 ) {
249 Ok(_) => {
250 return Ok(RankedSemaphorePermit {
251 sem: self,
252 permits: 1,
253 })
254 }
255 Err(actual) => curr = actual,
256 }
257 }
258 }
259
260 pub fn try_acquire_many(&self, n: u32) -> Result<RankedSemaphorePermit<'_>, TryAcquireError> {
262 if n == 0 {
263 return Ok(RankedSemaphorePermit {
264 sem: self,
265 permits: 0,
266 });
267 }
268
269 if n as usize > Self::MAX_PERMITS {
270 panic!("try_acquire_many: n exceeds MAX_PERMITS");
271 }
272
273 let n_shifted = (n as usize) << Self::PERMIT_SHIFT;
274 let mut curr = self.permits.load(Ordering::Acquire);
275 loop {
276 if curr & Self::CLOSED == Self::CLOSED {
278 return Err(TryAcquireError::Closed);
279 }
280
281 if curr < n_shifted {
283 return Err(TryAcquireError::NoPermits);
284 }
285
286 let next = curr - n_shifted;
287 match self.permits.compare_exchange_weak(
288 curr,
289 next,
290 Ordering::AcqRel,
291 Ordering::Acquire,
292 ) {
293 Ok(_) => {
294 return Ok(RankedSemaphorePermit {
295 sem: self,
296 permits: n,
297 })
298 }
299 Err(actual) => curr = actual,
300 }
301 }
302 }
303
304 pub fn forget_permits(&self, n: usize) -> usize {
309 if n == 0 {
310 return 0;
311 }
312
313 let mut curr_bits = self.permits.load(Ordering::Acquire);
314 loop {
315 let curr_permits = curr_bits >> Self::PERMIT_SHIFT;
316 let removed = curr_permits.min(n);
317 let new_permits = curr_permits - removed;
318 let new_bits = (new_permits << Self::PERMIT_SHIFT) | (curr_bits & Self::CLOSED);
319
320 match self.permits.compare_exchange_weak(
321 curr_bits,
322 new_bits,
323 Ordering::AcqRel,
324 Ordering::Acquire,
325 ) {
326 Ok(_) => return removed,
327 Err(actual) => curr_bits = actual,
328 }
329 }
330 }
331
332 pub fn acquire_owned(self: Arc<Self>) -> AcquireOwned {
336 self.acquire_many_owned_with_priority(0, 1)
337 }
338
339 pub fn acquire_owned_with_priority(self: Arc<Self>, priority: isize) -> AcquireOwned {
341 self.acquire_many_owned_with_priority(priority, 1)
342 }
343
344 pub fn acquire_many_owned(self: Arc<Self>, n: u32) -> AcquireOwned {
346 self.acquire_many_owned_with_priority(0, n)
347 }
348
349 pub fn acquire_many_owned_with_priority(
351 self: Arc<Self>,
352 priority: isize,
353 n: u32,
354 ) -> AcquireOwned {
355 AcquireOwned {
356 semaphore: self,
357 permits_needed: n as usize,
358 priority,
359 waiter_handle: None,
360 }
361 }
362
363 pub fn try_acquire_owned(
370 self: Arc<Self>,
371 ) -> Result<OwnedRankedSemaphorePermit, TryAcquireError> {
372 let mut curr = self.permits.load(Ordering::Acquire);
373 loop {
374 if curr & Self::CLOSED == Self::CLOSED {
376 return Err(TryAcquireError::Closed);
377 }
378
379 if curr < (1 << Self::PERMIT_SHIFT) {
381 return Err(TryAcquireError::NoPermits);
382 }
383
384 let next = curr - (1 << Self::PERMIT_SHIFT);
385 match self.permits.compare_exchange_weak(
386 curr,
387 next,
388 Ordering::AcqRel,
389 Ordering::Acquire,
390 ) {
391 Ok(_) => {
392 return Ok(OwnedRankedSemaphorePermit {
393 sem: self,
394 permits: 1,
395 })
396 }
397 Err(actual) => curr = actual,
398 }
399 }
400 }
401
402 pub fn try_acquire_many_owned(
409 self: Arc<Self>,
410 n: u32,
411 ) -> Result<OwnedRankedSemaphorePermit, TryAcquireError> {
412 if n == 0 {
413 return Ok(OwnedRankedSemaphorePermit {
414 sem: self,
415 permits: 0,
416 });
417 }
418
419 if n as usize > Self::MAX_PERMITS {
420 panic!("try_acquire_many_owned: n exceeds MAX_PERMITS");
421 }
422
423 let n_shifted = (n as usize) << Self::PERMIT_SHIFT;
424 let mut curr = self.permits.load(Ordering::Acquire);
425 loop {
426 if curr & Self::CLOSED == Self::CLOSED {
428 return Err(TryAcquireError::Closed);
429 }
430
431 if curr < n_shifted {
433 return Err(TryAcquireError::NoPermits);
434 }
435
436 let next = curr - n_shifted;
437 match self.permits.compare_exchange_weak(
438 curr,
439 next,
440 Ordering::AcqRel,
441 Ordering::Acquire,
442 ) {
443 Ok(_) => {
444 return Ok(OwnedRankedSemaphorePermit {
445 sem: self,
446 permits: n,
447 })
448 }
449 Err(actual) => curr = actual,
450 }
451 }
452 }
453}
454
455impl<'a> Future for Acquire<'a> {
458 type Output = Result<RankedSemaphorePermit<'a>, AcquireError>;
459
460 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
461 let this = &mut *self;
462
463 if this.waiter_handle.is_none() {
465 match this.semaphore.try_acquire_many(this.permits_needed as u32) {
466 Ok(permit) => return Poll::Ready(Ok(permit)),
467 Err(TryAcquireError::NoPermits) => {
468 }
470 Err(TryAcquireError::Closed) => return Poll::Ready(Err(AcquireError::closed())),
471 }
472 }
473
474 if this.waiter_handle.is_none() {
476 let mut waiters = this.semaphore.waiters.lock().unwrap();
477 match this.semaphore.try_acquire_many(this.permits_needed as u32) {
479 Ok(permit) => return Poll::Ready(Ok(permit)),
480 Err(TryAcquireError::NoPermits) => {
481 if this.semaphore.is_closed() {
482 return Poll::Ready(Err(AcquireError::closed()));
483 }
484 this.waiter_handle =
485 Some(waiters.push_waiter(this.permits_needed, this.priority));
486 }
487 Err(TryAcquireError::Closed) => return Poll::Ready(Err(AcquireError::closed())),
488 }
489 }
490
491 let handle = this.waiter_handle.as_ref().unwrap();
493 handle.state.set_waker(cx.waker().clone());
494
495 if handle.state.is_notified() {
496 return Poll::Ready(Ok(RankedSemaphorePermit {
498 sem: this.semaphore,
499 permits: this.permits_needed as u32,
500 }));
501 }
502
503 if handle.state.is_cancelled() {
504 return Poll::Ready(Err(AcquireError::closed()));
505 }
506
507 Poll::Pending
508 }
509}
510
511impl<'a> Drop for Acquire<'a> {
512 fn drop(&mut self) {
513 if let Some(handle) = &self.waiter_handle {
515 handle.state.cancel();
516 }
517 }
518}
519
520impl Future for AcquireOwned {
521 type Output = Result<OwnedRankedSemaphorePermit, AcquireError>;
522
523 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
524 let this = &mut *self;
525
526 if this.waiter_handle.is_none() {
528 match this.semaphore.try_acquire_many(this.permits_needed as u32) {
529 Ok(permit) => {
530 let permits = permit.permits;
533 std::mem::forget(permit);
534 return Poll::Ready(Ok(OwnedRankedSemaphorePermit {
535 sem: Arc::clone(&this.semaphore),
536 permits,
537 }));
538 }
539 Err(TryAcquireError::NoPermits) => {
540 }
542 Err(TryAcquireError::Closed) => return Poll::Ready(Err(AcquireError::closed())),
543 }
544 }
545
546 if this.waiter_handle.is_none() {
548 let mut waiters = this.semaphore.waiters.lock().unwrap();
549 match this.semaphore.try_acquire_many(this.permits_needed as u32) {
551 Ok(permit) => {
552 let permits = permit.permits;
553 std::mem::forget(permit);
554 return Poll::Ready(Ok(OwnedRankedSemaphorePermit {
555 sem: Arc::clone(&this.semaphore),
556 permits,
557 }));
558 }
559 Err(TryAcquireError::NoPermits) => {
560 if this.semaphore.is_closed() {
561 return Poll::Ready(Err(AcquireError::closed()));
562 }
563 this.waiter_handle =
564 Some(waiters.push_waiter(this.permits_needed, this.priority));
565 }
566 Err(TryAcquireError::Closed) => return Poll::Ready(Err(AcquireError::closed())),
567 }
568 }
569
570 let handle = this.waiter_handle.as_ref().unwrap();
572 handle.state.set_waker(cx.waker().clone());
573
574 if handle.state.is_notified() {
575 return Poll::Ready(Ok(OwnedRankedSemaphorePermit {
577 sem: Arc::clone(&this.semaphore),
578 permits: this.permits_needed as u32,
579 }));
580 }
581
582 if handle.state.is_cancelled() {
583 return Poll::Ready(Err(AcquireError::closed()));
584 }
585
586 Poll::Pending
587 }
588}
589
590impl Drop for AcquireOwned {
591 fn drop(&mut self) {
592 if let Some(handle) = &self.waiter_handle {
594 handle.state.cancel();
595 }
596 }
597}
598
599impl<'a> RankedSemaphorePermit<'a> {
602 pub fn forget(mut self) {
604 self.permits = 0;
605 }
606
607 pub fn num_permits(&self) -> usize {
609 self.permits as usize
610 }
611
612 pub fn merge(&mut self, mut other: Self) {
614 if !std::ptr::eq(self.sem, other.sem) {
615 panic!("Cannot merge permits from different semaphores");
616 }
617 self.permits += other.permits;
618 other.permits = 0;
620 }
621
622 pub fn split(&mut self, n: u32) -> Option<Self> {
624 if n > self.permits {
625 return None;
626 }
627 self.permits -= n;
628 Some(Self {
629 sem: self.sem,
630 permits: n,
631 })
632 }
633}
634
635impl<'a> Drop for RankedSemaphorePermit<'a> {
636 fn drop(&mut self) {
637 if self.permits == 0 {
638 return;
639 }
640
641 let permits_to_add = (self.permits as usize) << RankedSemaphore::PERMIT_SHIFT;
643 let waiters = self.sem.waiters.lock().unwrap();
644
645 if waiters.is_empty() {
646 drop(waiters);
648 self.sem
649 .permits
650 .fetch_add(permits_to_add, Ordering::Release);
651 return;
652 }
653
654 self.sem.add_permits_locked(self.permits as usize, waiters);
656 }
657}
658
659impl OwnedRankedSemaphorePermit {
660 pub fn forget(mut self) {
662 self.permits = 0;
663 }
664
665 pub fn num_permits(&self) -> usize {
667 self.permits as usize
668 }
669
670 pub fn merge(&mut self, mut other: Self) {
672 if !Arc::ptr_eq(&self.sem, &other.sem) {
673 panic!("merging permits from different semaphore instances");
674 }
675 self.permits += other.permits;
676 other.permits = 0;
678 }
679
680 pub fn split(&mut self, n: usize) -> Option<Self> {
683 let n = u32::try_from(n).ok()?;
684
685 if n > self.permits {
686 return None;
687 }
688
689 self.permits -= n;
690
691 Some(Self {
692 sem: self.sem.clone(),
693 permits: n,
694 })
695 }
696
697 pub fn semaphore(&self) -> &Arc<RankedSemaphore> {
699 &self.sem
700 }
701}
702
703impl Drop for OwnedRankedSemaphorePermit {
704 fn drop(&mut self) {
705 if self.permits == 0 {
706 return;
707 }
708
709 let permits_to_add = (self.permits as usize) << RankedSemaphore::PERMIT_SHIFT;
711 let waiters = self.sem.waiters.lock().unwrap();
712
713 if waiters.is_empty() {
714 drop(waiters);
716 self.sem
717 .permits
718 .fetch_add(permits_to_add, Ordering::Release);
719 return;
720 }
721
722 self.sem.add_permits_locked(self.permits as usize, waiters);
724 }
725}
726
727impl<'a> fmt::Debug for RankedSemaphorePermit<'a> {
730 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
731 f.debug_struct("RankedSemaphorePermit")
732 .field("permits", &self.permits)
733 .finish()
734 }
735}
736
737impl fmt::Debug for OwnedRankedSemaphorePermit {
738 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
739 f.debug_struct("OwnedRankedSemaphorePermit")
740 .field("permits", &self.permits)
741 .finish()
742 }
743}
744
745impl<'a> fmt::Debug for Acquire<'a> {
746 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
747 f.debug_struct("Acquire")
748 .field("permits_needed", &self.permits_needed)
749 .field("priority", &self.priority)
750 .field("queued", &self.waiter_handle.is_some())
751 .finish()
752 }
753}
754
755impl fmt::Debug for AcquireOwned {
756 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
757 f.debug_struct("AcquireOwned")
758 .field("permits_needed", &self.permits_needed)
759 .field("priority", &self.priority)
760 .field("queued", &self.waiter_handle.is_some())
761 .finish()
762 }
763}