use crate::error::{GlommioError, ResourceType};
use std::{
cell::RefCell,
future::Future,
pin::Pin,
task::{Context, Poll, Waker},
};
use intrusive_collections::{
container_of,
linked_list::LinkOps,
offset_of,
Adapter,
LinkedList,
LinkedListLink,
PointerOps,
};
use std::{marker::PhantomPinned, ptr::NonNull, rc::Rc};
/// Module-local result alias: `crate::error::Result` with its second type
/// parameter fixed to `()`. Errors produced here are `GlommioError` values
/// such as `Closed` / `WouldBlock` carrying a `ResourceType::Semaphore`.
type Result<T> = crate::error::Result<T, ()>;
/// Future awaited by [`Semaphore::acquire`] when the requested units cannot be
/// taken immediately.
///
/// While pending, its `node` is linked into the semaphore's waiting list; the
/// `Drop` impl unlinks it so the list never keeps a pointer to a dead waiter.
#[derive(Debug)]
struct Waiter<'a> {
// Intrusive list node; linked into `SemaphoreState::waiters_list` while pending.
node: WaiterNode,
// Semaphore this waiter acquires from. Borrowing it keeps the semaphore alive
// for as long as the waiter exists.
semaphore: &'a Semaphore,
}
/// Intrusive node embedded in a [`Waiter`] and linked into the semaphore's
/// waiting list.
///
/// The list stores raw pointers to this node, so it must not move while
/// linked; `PhantomPinned` makes the type `!Unpin` so that guarantee can be
/// expressed through `Pin`.
#[derive(Debug)]
struct WaiterNode {
// Hook used by the intrusive `LinkedList`.
link: LinkedListLink,
// Number of units this waiter asked for.
units: u64,
// Waker of the task polling the waiter: stored before the node is linked,
// taken (left `None`) when the node is woken by `process_wakes` or `close`.
waker: RefCell<Option<Waker>>,
_p: PhantomPinned,
}
/// Pointer operations for the waiting list.
///
/// Nodes are referenced through non-owning `NonNull<WaiterNode>` pointers:
/// unlinking a node from the list hands the pointer back without freeing
/// anything, because the node's memory belongs to the `Waiter` embedding it.
struct WaiterPointerOps;

unsafe impl PointerOps for WaiterPointerOps {
    type Value = WaiterNode;
    type Pointer = NonNull<WaiterNode>;

    /// Wraps a raw node pointer.
    ///
    /// # Panics
    /// Panics if `value` is null.
    unsafe fn from_raw(&self, value: *const Self::Value) -> Self::Pointer {
        match NonNull::new(value as *mut Self::Value) {
            Some(pointer) => pointer,
            None => panic!("Passed in Pointer can not be null"),
        }
    }

    /// Unwraps the pointer back to a raw const pointer.
    fn into_raw(&self, ptr: Self::Pointer) -> *const Self::Value {
        let raw: *mut Self::Value = ptr.as_ptr();
        raw as *const Self::Value
    }
}
/// Adapter telling the intrusive `LinkedList` how to get from a `WaiterNode`
/// to its embedded `link` field and back.
struct WaiterAdapter {
    pointers_ops: WaiterPointerOps,
    link_ops: LinkOps,
}

impl WaiterAdapter {
    /// Builds the adapter; both operation objects are constructed from unit values.
    fn new() -> Self {
        let link_ops = LinkOps;
        Self {
            link_ops,
            pointers_ops: WaiterPointerOps,
        }
    }
}
unsafe impl Adapter for WaiterAdapter {
    type LinkOps = LinkOps;
    type PointerOps = WaiterPointerOps;

    /// Recovers the `WaiterNode` that embeds the given `link`.
    unsafe fn get_value(
        &self,
        link: <Self::LinkOps as intrusive_collections::LinkOps>::LinkPtr,
    ) -> *const <Self::PointerOps as PointerOps>::Value {
        let link_ptr = link.as_ptr();
        container_of!(link_ptr, WaiterNode, link)
    }

    /// Returns a pointer to the `link` field inside the node at `value`, computed
    /// from the field's byte offset within `WaiterNode`.
    ///
    /// # Panics
    /// Panics if `value` is null.
    unsafe fn get_link(
        &self,
        value: *const <Self::PointerOps as PointerOps>::Value,
    ) -> <Self::LinkOps as intrusive_collections::LinkOps>::LinkPtr {
        assert!(
            !value.is_null(),
            "Passed in pointer to the value can not be null"
        );
        let field_offset = offset_of!(WaiterNode, link);
        let link_addr = (value as *const u8).add(field_offset);
        // The base pointer is non-null (checked above) and adding an in-bounds
        // field offset keeps it non-null.
        core::ptr::NonNull::new_unchecked(link_addr as *mut _)
    }

    fn link_ops(&self) -> &Self::LinkOps {
        &self.link_ops
    }

    fn link_ops_mut(&mut self) -> &mut Self::LinkOps {
        &mut self.link_ops
    }

    fn pointer_ops(&self) -> &Self::PointerOps {
        &self.pointers_ops
    }
}
impl<'a> Waiter<'a> {
    /// Creates an unlinked waiter asking for `units` from `semaphore`, with no
    /// waker stored yet.
    fn new(units: u64, semaphore: &'a Semaphore) -> Waiter<'a> {
        let node = WaiterNode {
            link: LinkedListLink::new(),
            units,
            waker: RefCell::new(None),
            _p: PhantomPinned,
        };
        Waiter { node, semaphore }
    }

    /// Unlinks `waiter_node` from `sem_state`'s waiting list. Does nothing if
    /// the node is not currently linked.
    ///
    /// # Panics
    /// Panics if the node claims to be linked but the list cannot remove it.
    fn remove_from_waiting_queue(
        waiter_node: Pin<&mut WaiterNode>,
        sem_state: &mut SemaphoreState,
    ) {
        if !waiter_node.link.is_linked() {
            return;
        }
        // SAFETY: we never move the node; we only take its address. The only
        // place a node is linked (`register_in_waiting_queue`, called from
        // `Waiter::poll`) pushes it into the list of the same semaphore whose
        // state is passed here, so the pointer belongs to this list.
        let raw = unsafe { Pin::into_inner_unchecked(waiter_node) } as *const WaiterNode;
        let mut cursor = unsafe { sem_state.waiters_list.cursor_mut_from_ptr(raw) };
        let removed = cursor.remove();
        if removed.is_none() {
            panic!("Waiter has to be linked into the list of waiting futures");
        }
    }

    /// Stores `waker` in the node, replacing any previous one, and appends the
    /// node to the back of the waiting list unless it is already linked. An
    /// already-linked node keeps its position in the queue.
    fn register_in_waiting_queue(
        waiter_node: Pin<&mut WaiterNode>,
        sem_state: &mut SemaphoreState,
        waker: Waker,
    ) {
        waiter_node.waker.replace(Some(waker));
        if !waiter_node.link.is_linked() {
            // SAFETY: the node is pinned, so its address stays valid while it is
            // linked. The owning `Waiter`'s `Drop` unlinks it before the memory
            // goes away.
            let node_ptr = unsafe {
                NonNull::new_unchecked(Pin::into_inner_unchecked(waiter_node) as *mut _)
            };
            sem_state.waiters_list.push_back(node_ptr);
        }
    }
}
impl<'a> Drop for Waiter<'a> {
fn drop(&mut self) {
if self.node.link.is_linked() {
let waiter_node = unsafe { Pin::new_unchecked(&mut self.node) };
Self::remove_from_waiting_queue(waiter_node, &mut self.semaphore.state.borrow_mut())
}
}
}
impl<'a> Future for Waiter<'a> {
    type Output = Result<()>;

    /// Tries to take the requested units.
    ///
    /// On success, or on error (closed semaphore), the node is unlinked and the
    /// result is returned. Otherwise the current task's waker is stored, the node
    /// is queued if it is not already, and the future stays pending.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: nothing is moved out of `this`; the `node` field is re-pinned
        // right below and only accessed through that pin.
        let this = unsafe { self.get_unchecked_mut() };
        let semaphore = this.semaphore;
        let mut state = semaphore.state.borrow_mut();
        // SAFETY: `node` is a field of a pinned `Waiter`, so it never moves either.
        let node = unsafe { Pin::new_unchecked(&mut this.node) };
        let wanted = node.units;

        match state.try_acquire(wanted) {
            Ok(false) => {
                Self::register_in_waiting_queue(node, &mut state, cx.waker().clone());
                Poll::Pending
            }
            finished => {
                Self::remove_from_waiting_queue(node, &mut state);
                Poll::Ready(finished.map(|_acquired| ()))
            }
        }
    }
}
/// Mutable state of a [`Semaphore`], stored behind its `RefCell`.
#[derive(Debug)]
struct SemaphoreState {
// Units currently available to acquire.
avail: u64,
// Set by `close`; afterwards every acquisition attempt fails with `Closed`.
closed: bool,
// Pending waiters in registration order (pushed to the back). The list is
// intrusive: the nodes live inside their `Waiter` futures, not in the list.
waiters_list: LinkedList<WaiterAdapter>,
}
impl SemaphoreState {
    /// Creates an open state with `avail` units and no waiters.
    fn new(avail: u64) -> Self {
        Self {
            waiters_list: LinkedList::new(WaiterAdapter::new()),
            closed: false,
            avail,
        }
    }

    /// Units currently available.
    fn available(&self) -> u64 {
        self.avail
    }

    /// Takes `units` if enough are available.
    ///
    /// Returns `Ok(true)` when the units were taken, `Ok(false)` when there are
    /// not enough, and a `Closed` error once the semaphore has been closed.
    fn try_acquire(&mut self, units: u64) -> Result<bool> {
        if self.closed {
            return Err(GlommioError::Closed(ResourceType::Semaphore {
                requested: units,
                available: self.avail,
            }));
        }
        match self.avail.checked_sub(units) {
            Some(remaining) => {
                self.avail = remaining;
                Ok(true)
            }
            None => Ok(false),
        }
    }

    /// Marks the state closed, then unlinks and wakes every queued waiter so
    /// that each one observes the closure on its next poll.
    ///
    /// # Panics
    /// Panics if a queued waiter has no waker stored.
    fn close(&mut self) {
        self.closed = true;
        while let Some(node_ptr) = self.waiters_list.pop_front() {
            // SAFETY: a linked node is pinned inside a live `Waiter`, whose
            // `Drop` would have unlinked it before the memory was released.
            let node = unsafe { node_ptr.as_ref() };
            let waker = node.waker.borrow_mut().take();
            waker
                .expect("Future is linked into the waiting list without a waker")
                .wake();
        }
    }

    /// Returns `units` to the pool; waking waiters is left to the caller.
    fn signal(&mut self, units: u64) {
        self.avail += units;
    }
}
/// A permit holding `units` of a borrowed [`Semaphore`].
///
/// Obtained from [`Semaphore::acquire_permit`] or
/// [`Semaphore::try_acquire_permit`]. When dropped, the units are given back to
/// the semaphore and queued waiters whose requests now fit are woken.
#[derive(Debug)]
#[must_use = "units are only held while the permit is alive. If unused then semaphore will \
immediately release units"]
pub struct Permit<'a> {
// Number of units held by this permit.
units: u64,
// Semaphore the units are returned to on drop.
sem: &'a Semaphore,
}
/// A permit holding `units` of a reference-counted [`Semaphore`].
///
/// Obtained from [`Semaphore::acquire_static_permit`]. Because it owns an
/// `Rc<Semaphore>` instead of borrowing one, it has no lifetime parameter and
/// can be moved freely, for example into another task. When dropped, the units
/// are given back to the semaphore and queued waiters whose requests now fit are
/// woken.
///
/// Like [`Permit`], it is `#[must_use]`: an ignored permit is dropped at once,
/// releasing its units immediately.
#[derive(Debug)]
#[must_use = "units are only held while the permit is alive. If unused then semaphore will \
              immediately release units"]
pub struct StaticPermit {
    units: u64,
    sem: Rc<Semaphore>,
}
impl<'a> Permit<'a> {
fn new(units: u64, sem: &'a Semaphore) -> Permit<'a> {
Permit { units, sem }
}
}
impl StaticPermit {
fn new(units: u64, sem: Rc<Semaphore>) -> StaticPermit {
StaticPermit { units, sem }
}
pub fn close(&self) {
self.sem.close()
}
}
impl<'a> Drop for Permit<'a> {
fn drop(&mut self) {
process_wakes(self.sem, self.units);
}
}
impl Drop for StaticPermit {
fn drop(&mut self) {
process_wakes(&self.sem, self.units);
}
}
fn process_wakes(sem: &Semaphore, units: u64) {
let mut state = sem.state.borrow_mut();
state.signal(units);
let mut available_units = state.avail;
let mut cursor = state.waiters_list.front_mut();
while available_units > 0 {
let mut waker = None;
if let Some(node) = cursor.get() {
let node = unsafe { Pin::new_unchecked(node) };
if node.units <= available_units {
let w = node.waker.borrow_mut().take();
if w.is_some() {
waker = w;
} else {
panic!("Future was linked into the waiting list without a waker");
}
available_units -= node.units;
}
} else {
break;
}
if let Some(waker) = waker {
waker.wake();
cursor.remove();
} else {
cursor.move_next();
}
}
}
/// A counting semaphore for async tasks.
///
/// Tasks acquire a number of units, either directly or through RAII permits
/// ([`Permit`], [`StaticPermit`]), and suspend when not enough units are
/// available. [`Semaphore::signal`] and permit drops return units and wake
/// queued waiters whose requests fit. [`Semaphore::close`] makes every pending
/// and future acquisition fail. The state lives in a `RefCell`, so the type is
/// `!Sync`.
#[derive(Debug)]
pub struct Semaphore {
state: RefCell<SemaphoreState>,
}
impl Semaphore {
    /// Creates an open semaphore holding `avail` units.
    pub fn new(avail: u64) -> Semaphore {
        let state = SemaphoreState::new(avail);
        Semaphore {
            state: RefCell::new(state),
        }
    }

    /// Returns the number of units currently available.
    pub fn available(&self) -> u64 {
        let state = self.state.borrow();
        state.available()
    }

    /// Waits for `units` and returns a [`Permit`] that gives them back on drop.
    ///
    /// # Errors
    /// Returns `GlommioError::Closed` if the semaphore is or becomes closed.
    pub async fn acquire_permit(&self, units: u64) -> Result<Permit<'_>> {
        self.acquire(units).await.map(|()| Permit::new(units, self))
    }

    /// Like [`Semaphore::acquire_permit`], but the returned [`StaticPermit`]
    /// keeps its own `Rc` to the semaphore instead of borrowing it.
    ///
    /// # Errors
    /// Returns `GlommioError::Closed` if the semaphore is or becomes closed.
    pub async fn acquire_static_permit(self: &Rc<Self>, units: u64) -> Result<StaticPermit> {
        self.acquire(units)
            .await
            .map(|()| StaticPermit::new(units, Rc::clone(self)))
    }

    /// Waits until `units` can be taken and takes them.
    ///
    /// The units are not returned automatically; call [`Semaphore::signal`] to
    /// give them back, or use one of the permit methods instead.
    ///
    /// # Errors
    /// Returns `GlommioError::Closed` if the semaphore is or becomes closed.
    pub async fn acquire(&self, units: u64) -> Result<()> {
        // Fast path: no queue and enough units. The state borrow ends inside
        // `try_acquire`, before anything is awaited.
        if self.try_acquire(units)? {
            return Ok(());
        }
        Waiter::new(units, self).await
    }

    /// Takes `units` without waiting, returning whether it succeeded.
    ///
    /// Units are only taken here when no waiter is queued.
    ///
    /// # Errors
    /// Returns `GlommioError::Closed` if the semaphore has been closed.
    pub fn try_acquire(&self, units: u64) -> Result<bool> {
        let mut state = self.state.borrow_mut();
        let acquired = state.waiters_list.is_empty() && state.try_acquire(units)?;
        Ok(acquired)
    }

    /// Takes `units` without waiting and wraps them in a [`Permit`].
    ///
    /// # Errors
    /// * `GlommioError::WouldBlock` when the units cannot be taken right now
    ///   (not enough available, or other waiters are queued).
    /// * `GlommioError::Closed` if the semaphore has been closed.
    pub fn try_acquire_permit(&self, units: u64) -> Result<Permit<'_>> {
        if self.try_acquire(units)? {
            return Ok(Permit::new(units, self));
        }
        Err(GlommioError::WouldBlock(ResourceType::Semaphore {
            requested: units,
            available: self.available(),
        }))
    }

    /// Adds `units` to the semaphore and wakes queued waiters whose requests
    /// now fit.
    pub fn signal(&self, units: u64) {
        process_wakes(self, units);
    }

    /// Closes the semaphore: every queued waiter is woken and fails with
    /// `GlommioError::Closed`, as does every later acquisition attempt.
    pub fn close(&self) {
        self.state.borrow_mut().close();
    }
}
impl Drop for Semaphore {
// Every `Waiter` borrows the semaphore, and `Waiter`'s own `Drop` unlinks its
// node, so no node should still be linked when the semaphore itself is dropped.
// The assertion guards that invariant (a dangling intrusive node would otherwise
// be left behind).
fn drop(&mut self) {
assert!(self.state.borrow().waiters_list.is_empty());
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::{
        enclose,
        timer::{sleep, Timer},
        Local,
        LocalExecutor,
    };
    use futures_lite::future::or;
    // NOTE(review): `Instant` is not named directly in this module. It is kept
    // because the shared test macros (e.g. the timed form of `wait_on_cond!`)
    // presumably expand to code that uses it — confirm before removing.
    use std::{
        cell::Cell,
        rc::Rc,
        time::{Duration, Instant},
    };

    #[test]
    fn semaphore_acquisition_for_zero_unit_works() {
        make_shared_var!(Semaphore::new(1), sem1);
        test_executor!(async move {
            sem1.acquire(0).await.unwrap();
        });
    }

    /// Borrowed and `Rc`-backed permits give their units back when dropped.
    #[test]
    fn permit_raii_works() {
        test_executor!(async move {
            let sem = Rc::new(Semaphore::new(0));
            let exec = Rc::new(Cell::new(0));

            let t1 = Local::local(enclose! { (sem, exec) async move {
                exec.set(exec.get() + 1);
                let _g = sem.acquire_permit(1).await.unwrap();
            }});
            let t2 = Local::local(enclose! { (sem, exec) async move {
                exec.set(exec.get() + 1);
                let _g = sem.acquire_permit(1).await.unwrap();
            }});
            let t3 = Local::local(enclose! { (sem, exec) async move {
                exec.set(exec.get() + 1);
                let g = sem.acquire_static_permit(1).await.unwrap();
                Local::local(async move {
                    let _g = g;
                    sleep(Duration::from_secs(1)).await;
                }).await
            }});

            while exec.get() != 3 {
                Local::later().await;
            }

            sem.signal(1);
            t3.await;
            t2.await;
            t1.await;

            sleep(Duration::from_secs(2)).await;
            assert_eq!(sem.available(), 1);
        });
    }

    #[test]
    fn explicit_signal_unblocks_waiting_semaphore() {
        make_shared_var!(Semaphore::new(0), sem1, sem2);
        make_shared_var_mut!(0, exec1, exec2);
        test_executor!(
            async move {
                {
                    wait_on_cond!(exec1, 1);
                    let _g = sem1.acquire_permit(1).await.unwrap();
                    update_cond!(exec1, 2);
                }
            },
            async move {
                update_cond!(exec2, 1);
                sem2.signal(1);
                wait_on_cond!(exec2, 2, 1);
            }
        );
    }

    #[test]
    fn explicit_signal_unblocks_many_wakers() {
        make_shared_var!(Semaphore::new(0), sem1, sem2, sem3);
        test_executor!(
            async move {
                sem1.acquire(1).await.unwrap();
            },
            async move {
                sem2.acquire(1).await.unwrap();
            },
            async move {
                sem3.signal(2);
            }
        );
    }

    #[test]
    fn broken_semaphore_returns_the_right_error() {
        test_executor!(async move {
            let sem = Semaphore::new(0);
            sem.close();
            match sem.acquire(0).await {
                Ok(_) => panic!("Should have failed"),
                Err(e) => match e {
                    GlommioError::Closed(ResourceType::Semaphore { .. }) => {}
                    _ => panic!("Wrong Error"),
                },
            }
        });
    }

    #[test]
    fn try_acquire_sufficient_units() {
        let sem = Semaphore::new(42);
        assert!(sem.try_acquire(24).unwrap());
    }

    #[test]
    fn try_acquire_permit_sufficient_units() {
        let sem = Semaphore::new(42);
        let _ = sem.try_acquire_permit(24).unwrap();
    }

    #[test]
    fn try_acquire_insufficient_units() {
        let sem = Semaphore::new(42);
        assert!(!sem.try_acquire(62).unwrap());
    }

    #[test]
    fn try_acquire_permit_insufficient_units() {
        let sem = Semaphore::new(42);
        let result = sem.try_acquire_permit(62);
        assert!(result.is_err());

        let err = result.err().unwrap();
        if !matches!(
            err,
            GlommioError::WouldBlock(ResourceType::Semaphore { .. })
        ) {
            panic!("Incorrect error type is returned from try_acquire_permit method");
        }
    }

    #[test]
    fn try_acquire_semaphore_is_closed() {
        let sem = Semaphore::new(42);
        sem.close();

        let result = sem.try_acquire(24);
        assert!(result.is_err());

        let err = result.err().unwrap();
        if !matches!(err, GlommioError::Closed(ResourceType::Semaphore { .. })) {
            panic!("Incorrect error type is returned from try_acquire method");
        }
    }

    #[test]
    fn try_acquire_permit_semaphore_is_closed() {
        let sem = Semaphore::new(42);
        sem.close();

        let result = sem.try_acquire_permit(24);
        assert!(result.is_err());

        let err = result.err().unwrap();
        if !matches!(err, GlommioError::Closed(ResourceType::Semaphore { .. })) {
            panic!("Incorrect error type is returned from try_acquire_permit method");
        }
    }

    #[test]
    #[should_panic]
    fn broken_semaphore_if_close_happens_first() {
        make_shared_var!(Semaphore::new(1), sem1, sem2);
        make_shared_var_mut!(0, exec1, exec2);
        test_executor!(
            async move {
                wait_on_cond!(exec1, 1);
                let _g = sem1.acquire_permit(0).await.unwrap();
            },
            async move {
                sem2.close();
                update_cond!(exec2, 1);
            }
        );
    }

    #[test]
    #[should_panic]
    fn broken_semaphore_if_acquire_happens_first() {
        make_shared_var!(Semaphore::new(0), sem1, sem2);
        make_shared_var_mut!(0, exec1, exec2);
        test_executor!(
            async move {
                update_cond!(exec1, 1);
                let _g = sem1.acquire_permit(1).await.unwrap();
            },
            async move {
                wait_on_cond!(exec2, 1);
                sem2.close();
            }
        );
    }

    /// A waiter that is polled repeatedly must stay linked exactly once.
    #[test]
    fn semaphore_overflow() {
        let ex = LocalExecutor::default();
        let semaphore = Rc::new(Semaphore::new(0));
        let semaphore_c = semaphore.clone();

        ex.run(async move {
            Local::local(async move {
                for _ in 0..100 {
                    Timer::new(Duration::from_micros(100)).await;
                }

                let mut waiters_count = 0;
                for _ in &semaphore_c.state.borrow().waiters_list {
                    waiters_count += 1;
                }

                assert_eq!(1, waiters_count);

                semaphore_c.signal(1);
            })
            .detach();

            semaphore.acquire(1).await.unwrap();
        });
    }

    #[test]
    fn semaphore_ensure_execution_order() {
        let ex = LocalExecutor::default();
        let semaphore = Rc::new(Semaphore::new(0));

        let semaphore_c1 = semaphore.clone();
        let semaphore_c2 = semaphore.clone();
        let semaphore_c3 = semaphore.clone();

        let state = Rc::new(RefCell::new(0));

        let state_c1 = state.clone();
        let state_c2 = state.clone();
        let state_c3 = state.clone();

        ex.run(async move {
            let t1 = Local::local(async move {
                *state_c1.borrow_mut() = 1;
                let _g = semaphore_c1.acquire_permit(1).await.unwrap();
                assert_eq!(*state_c1.borrow(), 3);
                *state_c1.borrow_mut() = 4;
            });

            let t2 = Local::local(async move {
                while *state_c2.borrow() != 1 {
                    Local::later().await;
                }

                *state_c2.borrow_mut() = 2;
                let _g = semaphore_c2.acquire_permit(1).await.unwrap();
                assert_eq!(*state_c2.borrow(), 4);
                *state_c2.borrow_mut() = 5;
            });

            let t3 = Local::local(async move {
                while *state_c3.borrow() != 2 {
                    Local::later().await;
                }

                *state_c3.borrow_mut() = 3;
                let _g = semaphore_c3.acquire_permit(1).await.unwrap();
                assert_eq!(*state_c3.borrow(), 5);
            });

            Local::local(async move {
                while *state.borrow() != 3 {
                    Local::later().await;
                }

                semaphore.signal(1);
            })
            .detach();

            or(or(t1, t2), t3).await;
        });
    }
}