1use core::cell::{Cell, RefCell};
3use core::convert::Infallible;
4use core::future::{poll_fn, Future};
5use core::task::{Poll, Waker};
6
7use heapless::Deque;
8
9use crate::blocking_mutex::raw::RawMutex;
10use crate::blocking_mutex::Mutex;
11use crate::waitqueue::WakerRegistration;
12
/// An asynchronous semaphore.
///
/// A semaphore hands out permits from a shared counter. A successful
/// acquisition yields a [`SemaphoreReleaser`], which gives the permits back
/// through [`Semaphore::release`] when it is dropped.
pub trait Semaphore: Sized {
    /// The error returned when an asynchronous acquisition fails.
    type Error;

    /// Asynchronously acquire `permits` permits from the semaphore.
    async fn acquire(&self, permits: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error>;

    /// Try to acquire `permits` permits without waiting.
    ///
    /// Returns `None` when the permits cannot be handed out right now.
    fn try_acquire(&self, permits: usize) -> Option<SemaphoreReleaser<'_, Self>>;

    /// Asynchronously acquire every available permit.
    ///
    /// In the implementations in this file the acquisition only succeeds once
    /// at least `min` permits are available, and it then takes all of them.
    async fn acquire_all(&self, min: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error>;

    /// Try to acquire every available permit without waiting.
    ///
    /// Returns `None` when the acquisition cannot succeed right now (in the
    /// implementations in this file: fewer than `min` permits are available).
    fn try_acquire_all(&self, min: usize) -> Option<SemaphoreReleaser<'_, Self>>;

    /// Give `permits` permits back to the semaphore.
    fn release(&self, permits: usize);

    /// Overwrite the number of available permits with `permits`.
    fn set(&self, permits: usize);
}
45
/// A guard holding permits acquired from a [`Semaphore`].
///
/// Dropping the guard calls [`Semaphore::release`] with the held permits.
/// Use [`SemaphoreReleaser::disarm`] to keep the permits out of the semaphore.
#[derive(Debug)]
pub struct SemaphoreReleaser<'a, S: Semaphore> {
    /// The semaphore the permits were taken from.
    semaphore: &'a S,
    /// The number of permits held by this guard.
    permits: usize,
}
54
55impl<'a, S: Semaphore> Drop for SemaphoreReleaser<'a, S> {
56 fn drop(&mut self) {
57 self.semaphore.release(self.permits);
58 }
59}
60
61impl<'a, S: Semaphore> SemaphoreReleaser<'a, S> {
62 pub fn permits(&self) -> usize {
64 self.permits
65 }
66
67 pub fn disarm(self) -> usize {
71 let permits = self.permits;
72 core::mem::forget(self);
73 permits
74 }
75}
76
/// A semaphore that grants permits to whoever asks first.
///
/// There is no wait queue: any acquisition succeeds as soon as enough permits
/// are available at the moment it is polled, regardless of how long other
/// tasks have been waiting. Only a single waker registration is stored.
pub struct GreedySemaphore<M: RawMutex> {
    /// Permit counter and waker registration, guarded by the blocking mutex.
    state: Mutex<M, Cell<SemaphoreState>>,
}
84
85impl<M: RawMutex> Default for GreedySemaphore<M> {
86 fn default() -> Self {
87 Self::new(0)
88 }
89}
90
91impl<M: RawMutex> GreedySemaphore<M> {
92 pub const fn new(permits: usize) -> Self {
94 Self {
95 state: Mutex::new(Cell::new(SemaphoreState {
96 permits,
97 waker: WakerRegistration::new(),
98 })),
99 }
100 }
101
102 #[cfg(test)]
103 fn permits(&self) -> usize {
104 self.state.lock(|cell| {
105 let state = cell.replace(SemaphoreState::EMPTY);
106 let permits = state.permits;
107 cell.replace(state);
108 permits
109 })
110 }
111
112 fn poll_acquire(
113 &self,
114 permits: usize,
115 acquire_all: bool,
116 waker: Option<&Waker>,
117 ) -> Poll<Result<SemaphoreReleaser<'_, Self>, Infallible>> {
118 self.state.lock(|cell| {
119 let mut state = cell.replace(SemaphoreState::EMPTY);
120 if let Some(permits) = state.take(permits, acquire_all) {
121 cell.set(state);
122 Poll::Ready(Ok(SemaphoreReleaser {
123 semaphore: self,
124 permits,
125 }))
126 } else {
127 if let Some(waker) = waker {
128 state.register(waker);
129 }
130 cell.set(state);
131 Poll::Pending
132 }
133 })
134 }
135}
136
137impl<M: RawMutex> Semaphore for GreedySemaphore<M> {
138 type Error = Infallible;
139
140 async fn acquire(&self, permits: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error> {
141 poll_fn(|cx| self.poll_acquire(permits, false, Some(cx.waker()))).await
142 }
143
144 fn try_acquire(&self, permits: usize) -> Option<SemaphoreReleaser<'_, Self>> {
145 match self.poll_acquire(permits, false, None) {
146 Poll::Ready(Ok(n)) => Some(n),
147 _ => None,
148 }
149 }
150
151 async fn acquire_all(&self, min: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error> {
152 poll_fn(|cx| self.poll_acquire(min, true, Some(cx.waker()))).await
153 }
154
155 fn try_acquire_all(&self, min: usize) -> Option<SemaphoreReleaser<'_, Self>> {
156 match self.poll_acquire(min, true, None) {
157 Poll::Ready(Ok(n)) => Some(n),
158 _ => None,
159 }
160 }
161
162 fn release(&self, permits: usize) {
163 if permits > 0 {
164 self.state.lock(|cell| {
165 let mut state = cell.replace(SemaphoreState::EMPTY);
166 state.permits += permits;
167 state.wake();
168 cell.set(state);
169 });
170 }
171 }
172
173 fn set(&self, permits: usize) {
174 self.state.lock(|cell| {
175 let mut state = cell.replace(SemaphoreState::EMPTY);
176 if permits > state.permits {
177 state.wake();
178 }
179 state.permits = permits;
180 cell.set(state);
181 });
182 }
183}
184
/// Internal state of a [`GreedySemaphore`].
#[derive(Debug)]
struct SemaphoreState {
    /// Number of permits currently available.
    permits: usize,
    /// Waker registration used to notify a waiting task.
    waker: WakerRegistration,
}
190
191impl SemaphoreState {
192 const EMPTY: SemaphoreState = SemaphoreState {
193 permits: 0,
194 waker: WakerRegistration::new(),
195 };
196
197 fn register(&mut self, w: &Waker) {
198 self.waker.register(w);
199 }
200
201 fn take(&mut self, mut permits: usize, acquire_all: bool) -> Option<usize> {
202 if self.permits < permits {
203 None
204 } else {
205 if acquire_all {
206 permits = self.permits;
207 }
208 self.permits -= permits;
209 Some(permits)
210 }
211 }
212
213 fn wake(&mut self) {
214 self.waker.wake();
215 }
216}
217
/// A semaphore that serves waiters in first-come, first-served order.
///
/// Waiting tasks are kept in a queue with room for `N` wakers. While the
/// queue is non-empty, only the task at its head can acquire permits; when
/// the queue is full, further asynchronous acquisitions fail with
/// [`WaitQueueFull`].
#[derive(Debug)]
pub struct FairSemaphore<M, const N: usize>
where
    M: RawMutex,
{
    /// Permit counter and wait queue, guarded by the blocking mutex.
    state: Mutex<M, RefCell<FairSemaphoreState<N>>>,
}
233
234impl<M, const N: usize> Default for FairSemaphore<M, N>
235where
236 M: RawMutex,
237{
238 fn default() -> Self {
239 Self::new(0)
240 }
241}
242
243impl<M, const N: usize> FairSemaphore<M, N>
244where
245 M: RawMutex,
246{
247 pub const fn new(permits: usize) -> Self {
249 Self {
250 state: Mutex::new(RefCell::new(FairSemaphoreState::new(permits))),
251 }
252 }
253
254 #[cfg(test)]
255 fn permits(&self) -> usize {
256 self.state.lock(|cell| cell.borrow().permits)
257 }
258
259 fn poll_acquire(
260 &self,
261 permits: usize,
262 acquire_all: bool,
263 cx: Option<(&mut Option<usize>, &Waker)>,
264 ) -> Poll<Result<SemaphoreReleaser<'_, Self>, WaitQueueFull>> {
265 let ticket = cx.as_ref().map(|(x, _)| **x).unwrap_or(None);
266 self.state.lock(|cell| {
267 let mut state = cell.borrow_mut();
268 if let Some(permits) = state.take(ticket, permits, acquire_all) {
269 Poll::Ready(Ok(SemaphoreReleaser {
270 semaphore: self,
271 permits,
272 }))
273 } else if let Some((ticket_ref, waker)) = cx {
274 match state.register(ticket, waker) {
275 Ok(ticket) => {
276 *ticket_ref = Some(ticket);
277 Poll::Pending
278 }
279 Err(err) => Poll::Ready(Err(err)),
280 }
281 } else {
282 Poll::Pending
283 }
284 })
285 }
286}
287
/// Error returned by [`FairSemaphore`] when a waiter cannot be added because
/// its wait queue has no free slot.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct WaitQueueFull;
292
293impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> {
294 type Error = WaitQueueFull;
295
296 fn acquire(&self, permits: usize) -> impl Future<Output = Result<SemaphoreReleaser<'_, Self>, Self::Error>> {
297 FairAcquire {
298 sema: self,
299 permits,
300 ticket: None,
301 }
302 }
303
304 fn try_acquire(&self, permits: usize) -> Option<SemaphoreReleaser<'_, Self>> {
305 match self.poll_acquire(permits, false, None) {
306 Poll::Ready(Ok(x)) => Some(x),
307 _ => None,
308 }
309 }
310
311 fn acquire_all(&self, min: usize) -> impl Future<Output = Result<SemaphoreReleaser<'_, Self>, Self::Error>> {
312 FairAcquireAll {
313 sema: self,
314 min,
315 ticket: None,
316 }
317 }
318
319 fn try_acquire_all(&self, min: usize) -> Option<SemaphoreReleaser<'_, Self>> {
320 match self.poll_acquire(min, true, None) {
321 Poll::Ready(Ok(x)) => Some(x),
322 _ => None,
323 }
324 }
325
326 fn release(&self, permits: usize) {
327 if permits > 0 {
328 self.state.lock(|cell| {
329 let mut state = cell.borrow_mut();
330 state.permits += permits;
331 state.wake();
332 });
333 }
334 }
335
336 fn set(&self, permits: usize) {
337 self.state.lock(|cell| {
338 let mut state = cell.borrow_mut();
339 if permits > state.permits {
340 state.wake();
341 }
342 state.permits = permits;
343 });
344 }
345}
346
/// Future returned by `FairSemaphore::acquire`.
#[derive(Debug)]
struct FairAcquire<'a, M: RawMutex, const N: usize> {
    /// Semaphore being acquired from.
    sema: &'a FairSemaphore<M, N>,
    /// Number of permits requested.
    permits: usize,
    /// Position in the wait queue; `None` until the future has been queued.
    ticket: Option<usize>,
}
353
354impl<'a, M: RawMutex, const N: usize> Drop for FairAcquire<'a, M, N> {
355 fn drop(&mut self) {
356 self.sema
357 .state
358 .lock(|cell| cell.borrow_mut().cancel(self.ticket.take()));
359 }
360}
361
362impl<'a, M: RawMutex, const N: usize> core::future::Future for FairAcquire<'a, M, N> {
363 type Output = Result<SemaphoreReleaser<'a, FairSemaphore<M, N>>, WaitQueueFull>;
364
365 fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
366 self.sema
367 .poll_acquire(self.permits, false, Some((&mut self.ticket, cx.waker())))
368 }
369}
370
/// Future returned by `FairSemaphore::acquire_all`.
#[derive(Debug)]
struct FairAcquireAll<'a, M: RawMutex, const N: usize> {
    /// Semaphore being acquired from.
    sema: &'a FairSemaphore<M, N>,
    /// Minimum number of permits that must be available before acquiring.
    min: usize,
    /// Position in the wait queue; `None` until the future has been queued.
    ticket: Option<usize>,
}
377
378impl<'a, M: RawMutex, const N: usize> Drop for FairAcquireAll<'a, M, N> {
379 fn drop(&mut self) {
380 self.sema
381 .state
382 .lock(|cell| cell.borrow_mut().cancel(self.ticket.take()));
383 }
384}
385
386impl<'a, M: RawMutex, const N: usize> core::future::Future for FairAcquireAll<'a, M, N> {
387 type Output = Result<SemaphoreReleaser<'a, FairSemaphore<M, N>>, WaitQueueFull>;
388
389 fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
390 self.sema
391 .poll_acquire(self.min, true, Some((&mut self.ticket, cx.waker())))
392 }
393}
394
/// Internal state of a [`FairSemaphore`].
#[derive(Debug)]
struct FairSemaphoreState<const N: usize> {
    /// Number of permits currently available.
    permits: usize,
    /// Ticket number belonging to the entry at the front of `wakers`.
    next_ticket: usize,
    /// Wait queue in ticket order; a `None` entry marks a canceled waiter.
    wakers: Deque<Option<Waker>, N>,
}
401
402impl<const N: usize> FairSemaphoreState<N> {
403 const fn new(permits: usize) -> Self {
405 Self {
406 permits,
407 next_ticket: 0,
408 wakers: Deque::new(),
409 }
410 }
411
412 fn register(&mut self, ticket: Option<usize>, w: &Waker) -> Result<usize, WaitQueueFull> {
414 self.pop_canceled();
415
416 match ticket {
417 None => {
418 let ticket = self.next_ticket.wrapping_add(self.wakers.len());
419 self.wakers.push_back(Some(w.clone())).or(Err(WaitQueueFull))?;
420 Ok(ticket)
421 }
422 Some(ticket) => {
423 self.set_waker(ticket, Some(w.clone()));
424 Ok(ticket)
425 }
426 }
427 }
428
429 fn cancel(&mut self, ticket: Option<usize>) {
430 if let Some(ticket) = ticket {
431 self.set_waker(ticket, None);
432 }
433 }
434
435 fn set_waker(&mut self, ticket: usize, waker: Option<Waker>) {
436 let i = ticket.wrapping_sub(self.next_ticket);
437 if i < self.wakers.len() {
438 let (a, b) = self.wakers.as_mut_slices();
439 let x = if i < a.len() { &mut a[i] } else { &mut b[i - a.len()] };
440 *x = waker;
441 }
442 }
443
444 fn take(&mut self, ticket: Option<usize>, mut permits: usize, acquire_all: bool) -> Option<usize> {
445 self.pop_canceled();
446
447 if permits > self.permits {
448 return None;
449 }
450
451 match ticket {
452 Some(n) if n != self.next_ticket => return None,
453 None if !self.wakers.is_empty() => return None,
454 _ => (),
455 }
456
457 if acquire_all {
458 permits = self.permits;
459 }
460 self.permits -= permits;
461
462 if ticket.is_some() {
463 self.pop();
464 if self.permits > 0 {
465 self.wake();
466 }
467 }
468
469 Some(permits)
470 }
471
472 fn pop_canceled(&mut self) {
473 while let Some(None) = self.wakers.front() {
474 self.pop();
475 }
476 }
477
478 fn pop(&mut self) {
480 self.wakers.pop_front().unwrap();
481 self.next_ticket = self.next_ticket.wrapping_add(1);
482 }
483
484 fn wake(&mut self) {
485 self.pop_canceled();
486
487 if let Some(Some(waker)) = self.wakers.front() {
488 waker.wake_by_ref();
489 }
490 }
491}
492
#[cfg(test)]
mod tests {
    /// Tests for `GreedySemaphore`.
    mod greedy {
        use core::pin::pin;

        use futures_util::poll;

        use super::super::*;
        use crate::blocking_mutex::raw::NoopRawMutex;

        #[test]
        fn try_acquire() {
            let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);

            let a = semaphore.try_acquire(1).unwrap();
            assert_eq!(a.permits(), 1);
            assert_eq!(semaphore.permits(), 2);

            // Dropping the releaser gives the permit back.
            core::mem::drop(a);
            assert_eq!(semaphore.permits(), 3);
        }

        #[test]
        fn disarm() {
            let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);

            let a = semaphore.try_acquire(1).unwrap();
            // A disarmed releaser does not return its permit.
            assert_eq!(a.disarm(), 1);
            assert_eq!(semaphore.permits(), 2);
        }

        #[futures_test::test]
        async fn acquire() {
            let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);

            let a = semaphore.acquire(1).await.unwrap();
            assert_eq!(a.permits(), 1);
            assert_eq!(semaphore.permits(), 2);

            core::mem::drop(a);
            assert_eq!(semaphore.permits(), 3);
        }

        #[test]
        fn try_acquire_all() {
            let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);

            // Asking for at least 1 takes all 3.
            let a = semaphore.try_acquire_all(1).unwrap();
            assert_eq!(a.permits(), 3);
            assert_eq!(semaphore.permits(), 0);
        }

        #[futures_test::test]
        async fn acquire_all() {
            let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);

            let a = semaphore.acquire_all(1).await.unwrap();
            assert_eq!(a.permits(), 3);
            assert_eq!(semaphore.permits(), 0);
        }

        #[test]
        fn release() {
            let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);
            assert_eq!(semaphore.permits(), 3);
            // Releasing can push the count above its initial value.
            semaphore.release(2);
            assert_eq!(semaphore.permits(), 5);
        }

        #[test]
        fn set() {
            let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);
            assert_eq!(semaphore.permits(), 3);
            semaphore.set(2);
            assert_eq!(semaphore.permits(), 2);
        }

        #[test]
        fn contested() {
            let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);

            let a = semaphore.try_acquire(1).unwrap();
            // Only 2 permits are left, so 3 cannot be taken.
            let b = semaphore.try_acquire(3);
            assert!(b.is_none());

            core::mem::drop(a);

            let b = semaphore.try_acquire(3);
            assert!(b.is_some());
        }

        #[futures_test::test]
        async fn greedy() {
            let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);

            let a = semaphore.try_acquire(1).unwrap();

            // `b` wants all 3 permits and has to wait.
            let b_fut = semaphore.acquire(3);
            let mut b_fut = pin!(b_fut);
            let b = poll!(b_fut.as_mut());
            assert!(b.is_pending());

            // Succeeds even though `b` is already waiting.
            let c = semaphore.try_acquire(1);
            assert!(c.is_some());

            let b = poll!(b_fut.as_mut());
            assert!(b.is_pending());

            core::mem::drop(a);

            // `c` still holds a permit, so `b` cannot proceed yet.
            let b = poll!(b_fut.as_mut());
            assert!(b.is_pending());

            core::mem::drop(c);

            let b = poll!(b_fut.as_mut());
            assert!(b.is_ready());
        }
    }

    /// Tests for `FairSemaphore`.
    mod fair {
        use core::pin::pin;
        use core::time::Duration;

        use futures_executor::ThreadPool;
        use futures_timer::Delay;
        use futures_util::poll;
        use futures_util::task::SpawnExt;
        use static_cell::StaticCell;

        use super::super::*;
        use crate::blocking_mutex::raw::{CriticalSectionRawMutex, NoopRawMutex};

        #[test]
        fn try_acquire() {
            let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);

            let a = semaphore.try_acquire(1).unwrap();
            assert_eq!(a.permits(), 1);
            assert_eq!(semaphore.permits(), 2);

            // Dropping the releaser gives the permit back.
            core::mem::drop(a);
            assert_eq!(semaphore.permits(), 3);
        }

        #[test]
        fn disarm() {
            let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);

            let a = semaphore.try_acquire(1).unwrap();
            // A disarmed releaser does not return its permit.
            assert_eq!(a.disarm(), 1);
            assert_eq!(semaphore.permits(), 2);
        }

        #[futures_test::test]
        async fn acquire() {
            let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);

            let a = semaphore.acquire(1).await.unwrap();
            assert_eq!(a.permits(), 1);
            assert_eq!(semaphore.permits(), 2);

            core::mem::drop(a);
            assert_eq!(semaphore.permits(), 3);
        }

        #[test]
        fn try_acquire_all() {
            let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);

            // Asking for at least 1 takes all 3.
            let a = semaphore.try_acquire_all(1).unwrap();
            assert_eq!(a.permits(), 3);
            assert_eq!(semaphore.permits(), 0);
        }

        #[futures_test::test]
        async fn acquire_all() {
            let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);

            let a = semaphore.acquire_all(1).await.unwrap();
            assert_eq!(a.permits(), 3);
            assert_eq!(semaphore.permits(), 0);
        }

        #[test]
        fn release() {
            let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);
            assert_eq!(semaphore.permits(), 3);
            // Releasing can push the count above its initial value.
            semaphore.release(2);
            assert_eq!(semaphore.permits(), 5);
        }

        #[test]
        fn set() {
            let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);
            assert_eq!(semaphore.permits(), 3);
            semaphore.set(2);
            assert_eq!(semaphore.permits(), 2);
        }

        #[test]
        fn contested() {
            let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);

            let a = semaphore.try_acquire(1).unwrap();
            // Only 2 permits are left, so 3 cannot be taken.
            let b = semaphore.try_acquire(3);
            assert!(b.is_none());

            core::mem::drop(a);

            let b = semaphore.try_acquire(3);
            assert!(b.is_some());
        }

        #[futures_test::test]
        async fn fairness() {
            let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);

            let a = semaphore.try_acquire(1);
            assert!(a.is_some());

            // Poll `b_fut` once so it joins the wait queue.
            let b_fut = semaphore.acquire(3);
            let mut b_fut = pin!(b_fut);
            let b = poll!(b_fut.as_mut());
            assert!(b.is_pending());

            // Permits are available, but `b` is queued, so `c` may not jump ahead.
            let c = semaphore.try_acquire(1);
            assert!(c.is_none());

            // Poll `c_fut` once so it queues up behind `b`.
            let c_fut = semaphore.acquire(1);
            let mut c_fut = pin!(c_fut);
            let c = poll!(c_fut.as_mut());
            assert!(c.is_pending());

            // The queue only has room for two waiters.
            let d = semaphore.acquire(1).await;
            assert!(matches!(d, Err(WaitQueueFull)));

            core::mem::drop(a);

            // `c` is still behind `b`, even though enough permits are free for it.
            let c = poll!(c_fut.as_mut());
            assert!(c.is_pending());

            let b = poll!(b_fut.as_mut());
            assert!(b.is_ready());

            // `b` now holds every permit.
            let c = poll!(c_fut.as_mut());
            assert!(c.is_pending());

            core::mem::drop(b);

            let c = poll!(c_fut.as_mut());
            assert!(c.is_ready());
        }

        #[futures_test::test]
        async fn wakers() {
            let executor = ThreadPool::new().unwrap();

            static SEMAPHORE: StaticCell<FairSemaphore<CriticalSectionRawMutex, 2>> = StaticCell::new();
            let semaphore = &*SEMAPHORE.init(FairSemaphore::new(3));

            let a = semaphore.try_acquire(2);
            assert!(a.is_some());

            let b_task = executor
                .spawn_with_handle(async move { semaphore.acquire(2).await })
                .unwrap();
            // Wait until `b` has registered itself in the wait queue.
            while semaphore.state.lock(|x| x.borrow().wakers.is_empty()) {
                Delay::new(Duration::from_millis(50)).await;
            }

            let c_task = executor
                .spawn_with_handle(async move { semaphore.acquire(1).await })
                .unwrap();

            // Releasing `a` frees enough permits for both `b` and `c`.
            core::mem::drop(a);

            let b = b_task.await.unwrap();
            assert_eq!(b.permits(), 2);

            let c = c_task.await.unwrap();
            assert_eq!(c.permits(), 1);
        }
    }
}