use crate::current;
use crate::runtime::execution::ExecutionState;
use crate::runtime::task::{clock::VectorClock, TaskId};
use crate::runtime::thread;
use std::cell::RefCell;
use std::collections::VecDeque;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::sync::Mutex;
use std::task::{Context, Poll, Waker};
use tracing::trace;
/// Bookkeeping for a single pending acquire request.
///
/// A waiter is held in an `Arc` by the `Acquire` future that created it and, while
/// queued, by the semaphore's `waiters` queue as well.
struct Waiter {
/// Task that created the request (`ExecutionState::me()` at construction time).
task_id: TaskId,
/// Number of permits the request asks for.
num_permits: usize,
/// Set while the waiter sits in the semaphore's wait queue.
is_queued: AtomicBool,
/// Set once the requested permits have been handed to this waiter.
has_permits: AtomicBool,
/// Vector clock captured with `current::clock()` when the waiter was created.
clock: VectorClock,
/// Waker stored by `Acquire::poll` when an acquire attempt finds no permits; it is
/// used to wake the future when permits are released or the semaphore is closed.
waker: Mutex<Option<Waker>>,
}
// Hand-written `Debug`: every field except `clock` is printed.
impl fmt::Debug for Waiter {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Destructuring makes the compiler flag this impl whenever a field is added.
        let Self {
            task_id,
            num_permits,
            is_queued,
            has_permits,
            clock: _,
            waker,
        } = self;

        let mut out = f.debug_struct("Waiter");
        out.field("task_id", task_id);
        out.field("num_permits", num_permits);
        out.field("is_queued", is_queued);
        out.field("has_permits", has_permits);
        out.field("waker", waker);
        out.finish()
    }
}
impl Waiter {
    /// Creates a waiter for `num_permits` permits on behalf of the calling task.
    ///
    /// The waiter starts out neither queued nor holding permits and has no waker.
    /// The task id and vector clock are read from the current execution.
    fn new(num_permits: usize) -> Self {
        let task_id = ExecutionState::me();
        let clock = current::clock();
        Self {
            task_id,
            num_permits,
            is_queued: AtomicBool::new(false),
            has_permits: AtomicBool::new(false),
            clock,
            waker: Mutex::new(None),
        }
    }
}
/// The pool of currently available permits together with the vector clocks
/// associated with them.
struct PermitsAvailable {
/// Number of permits that can be acquired right now.
num_available: usize,
/// FIFO of `(batch size, clock)` pairs describing the available permits.
/// `None` until first needed: `const_new` leaves it unset and
/// `init_permit_clocks` fills it in lazily.
permit_clocks: Option<VecDeque<(usize, VectorClock)>>,
/// Accumulates (via `VectorClock::update`) the clock passed to every successful
/// non-empty `acquire`.
last_acquire: VectorClock,
}
// Hand-written `Debug`: only the permit count is printed, the clocks are left out.
impl fmt::Debug for PermitsAvailable {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("PermitsAvailable");
        out.field("num_available", &self.num_available);
        out.finish()
    }
}
impl PermitsAvailable {
    /// Builds a pool holding `num_permits` permits.
    ///
    /// A non-empty pool starts with a single batch stamped with the caller's
    /// current vector clock; an empty pool starts with no batches.
    fn new(num_permits: usize) -> Self {
        let mut batches = VecDeque::new();
        if num_permits != 0 {
            batches.push_back((num_permits, current::clock()));
        }
        Self {
            num_available: num_permits,
            permit_clocks: Some(batches),
            last_acquire: VectorClock::new(),
        }
    }

    /// `const` counterpart of [`PermitsAvailable::new`]. The batch queue is left
    /// unset and is created on first use by `init_permit_clocks`.
    const fn const_new(num_permits: usize) -> Self {
        let last_acquire = VectorClock::new();
        Self {
            num_available: num_permits,
            permit_clocks: None,
            last_acquire,
        }
    }

    /// Number of permits that can currently be acquired.
    fn available(&self) -> usize {
        self.num_available
    }

    /// Creates the batch queue if it does not exist yet. All permits available at
    /// that moment are put into one batch carrying a fresh (`VectorClock::new()`) clock.
    fn init_permit_clocks(&mut self) {
        if self.permit_clocks.is_some() {
            return;
        }
        let mut batches = VecDeque::new();
        if self.num_available != 0 {
            batches.push_back((self.num_available, VectorClock::new()));
        }
        self.permit_clocks = Some(batches);
    }

    /// Takes `num_permits` permits out of the pool.
    ///
    /// On success the returned clock is the combination of the clocks of every
    /// batch that contributed a permit, and `acquire_clock` is folded into
    /// `last_acquire`. Requesting zero permits succeeds immediately with a fresh
    /// clock and leaves the pool untouched.
    ///
    /// # Errors
    ///
    /// `TryAcquireError::NoPermits` when fewer than `num_permits` are available.
    fn acquire(&mut self, mut num_permits: usize, acquire_clock: VectorClock) -> Result<VectorClock, TryAcquireError> {
        if num_permits == 0 {
            return Ok(VectorClock::new());
        }
        if num_permits > self.num_available {
            return Err(TryAcquireError::NoPermits);
        }

        self.init_permit_clocks();
        self.last_acquire.update(&acquire_clock);
        self.num_available -= num_permits;

        // Consume batches from the front until the request is satisfied.
        let batches = self.permit_clocks.as_mut().unwrap();
        let mut acquired_clock = VectorClock::new();
        while num_permits > 0 {
            let Some((batch_size, batch_clock)) = batches.front_mut() else {
                break;
            };
            acquired_clock.update(batch_clock);
            if *batch_size > num_permits {
                // The front batch covers the rest of the request and survives, shrunk.
                *batch_size -= num_permits;
                num_permits = 0;
            } else {
                // The front batch is used up entirely.
                num_permits -= *batch_size;
                batches.pop_front();
            }
        }
        assert_eq!(num_permits, 0);
        Ok(acquired_clock)
    }

    /// Returns `num_permits` permits to the pool as a new batch tagged with `clock`.
    fn release(&mut self, num_permits: usize, clock: VectorClock) {
        self.init_permit_clocks();
        self.num_available += num_permits;
        let batches = self.permit_clocks.as_mut().unwrap();
        batches.push_back((num_permits, clock));
    }
}
/// Queueing policy of a `BatchSemaphore`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Fairness {
/// Waiters are served in queue order: a new request cannot take permits while
/// other waiters are queued, and released permits are handed out starting from
/// the front of the queue.
StrictlyFair,
/// A request may take permits even while other waiters are queued; on release,
/// every queued waiter whose request fits the available permits is unblocked.
Unfair,
}
/// Mutable state of a `BatchSemaphore`, kept behind its `RefCell`.
#[derive(Debug)]
struct BatchSemaphoreState {
/// Identifier used for the `crate::annotations` records. `None` for semaphores
/// built with `const_new` until `init_object_id` assigns one.
id: Option<crate::annotations::ObjectId>,
/// Queue of pending requests, oldest first.
waiters: VecDeque<Arc<Waiter>>,
/// Permits that can currently be acquired, with their clocks.
permits_available: PermitsAvailable,
/// Set once the semaphore has been closed; further acquires fail.
closed: bool,
}
impl BatchSemaphoreState {
    /// Attempts to take `num_permits` permits for the current task without queueing.
    ///
    /// On success the clock of the acquired permits is merged into the current
    /// task's clock.
    ///
    /// # Errors
    ///
    /// * `TryAcquireError::Closed` if the semaphore is closed.
    /// * `TryAcquireError::NoPermits` if not enough permits are available, or if
    ///   the semaphore is strictly fair and other waiters are already queued.
    ///
    /// # Panics
    ///
    /// Panics if `num_permits` is zero.
    fn acquire_permits(&mut self, num_permits: usize, fairness: Fairness) -> Result<(), TryAcquireError> {
        assert!(num_permits > 0);
        if self.closed {
            return Err(TryAcquireError::Closed);
        }
        // Only an unfair semaphore lets a request bypass already-queued waiters.
        let may_bypass_queue = fairness == Fairness::Unfair;
        if !may_bypass_queue && !self.waiters.is_empty() {
            return Err(TryAcquireError::NoPermits);
        }
        let clock = self.permits_available.acquire(num_permits, current::clock())?;
        ExecutionState::with(|s| {
            s.update_clock(&clock);
        });
        Ok(())
    }

    /// Grants permits to queued waiters in queue order, stopping at the first
    /// waiter whose request does not fit the available permits (or when the
    /// queue is empty).
    ///
    /// Each served waiter is dequeued, marked as holding permits, has its task
    /// unblocked with the permits' clock merged in, and has its waker (if any) woken.
    fn unblock_waiters_from_front(&mut self) {
        loop {
            let front_fits = match self.waiters.front() {
                Some(front) => front.num_permits <= self.permits_available.available(),
                None => false,
            };
            if !front_fits {
                return;
            }

            let waiter = self.waiters.pop_front().unwrap();
            crate::annotations::record_semaphore_acquire_unblocked(
                self.id.unwrap(),
                waiter.task_id,
                waiter.num_permits,
            );
            let clock = self
                .permits_available
                .acquire(waiter.num_permits, waiter.clock.clone())
                .unwrap();
            trace!("granted {:?} permits to waiter {:?}", waiter.num_permits, waiter);

            // The waiter must have been queued and must not already hold permits.
            assert!(waiter.is_queued.swap(false, Ordering::SeqCst));
            assert!(!waiter.has_permits.swap(true, Ordering::SeqCst));

            ExecutionState::with(|s| {
                let task = s.get_mut(waiter.task_id);
                assert!(!task.finished());
                task.clock.update(&clock);
                task.unblock();
            });

            let mut maybe_waker = waiter.waker.lock().unwrap();
            if let Some(waker) = maybe_waker.take() {
                waker.wake();
            }
        }
    }
}
/// A counting semaphore whose permits can be acquired and released in batches,
/// either through the `Acquire` future or through the blocking/try variants.
#[derive(Debug)]
pub struct BatchSemaphore {
/// All mutable state; accessed through `RefCell` borrows.
state: RefCell<BatchSemaphoreState>,
/// Queueing policy, fixed at construction.
fairness: Fairness,
}
/// Reasons a non-blocking acquire attempt can fail.
#[derive(Debug, PartialEq, Eq)]
pub enum TryAcquireError {
/// The semaphore has been closed.
Closed,
/// The request could not be satisfied right now: too few permits are available,
/// or the semaphore is strictly fair and other waiters are queued.
NoPermits,
}
/// Error produced when waiting for permits on a semaphore that has been closed.
#[derive(Debug)]
pub struct AcquireError(());

impl AcquireError {
    /// Builds the error value reported for a closed semaphore.
    fn closed() -> Self {
        Self(())
    }
}

impl fmt::Display for AcquireError {
    /// Writes the fixed message `semaphore closed`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("semaphore closed")
    }
}

impl std::error::Error for AcquireError {}
impl BatchSemaphore {
/// Creates a semaphore with `num_permits` initial permits and the given
/// `fairness` policy.
///
/// The creation is recorded with `crate::annotations` and the resulting object
/// id is stored right away.
pub fn new(num_permits: usize, fairness: Fairness) -> Self {
    let id = crate::annotations::record_semaphore_created();
    let permits_available = PermitsAvailable::new(num_permits);
    let state = BatchSemaphoreState {
        id: Some(id),
        waiters: VecDeque::new(),
        permits_available,
        closed: false,
    };
    Self {
        state: RefCell::new(state),
        fairness,
    }
}
/// `const` constructor: creates a semaphore with `num_permits` initial permits
/// and the given `fairness` policy.
///
/// No object id is assigned here; `init_object_id` assigns one later.
pub const fn const_new(num_permits: usize, fairness: Fairness) -> Self {
    let permits_available = PermitsAvailable::const_new(num_permits);
    let state = BatchSemaphoreState {
        id: None,
        waiters: VecDeque::new(),
        permits_available,
        closed: false,
    };
    Self {
        state: RefCell::new(state),
        fairness,
    }
}
/// Returns the number of permits that are currently available.
pub fn available_permits(&self) -> usize {
    self.state.borrow().permits_available.available()
}
/// Assigns an object id to the semaphore if it does not have one yet (which is
/// the case after `const_new`), recording the creation with `crate::annotations`.
fn init_object_id(&self) {
    let mut state = self.state.borrow_mut();
    if state.id.is_some() {
        return;
    }
    state.id = Some(crate::annotations::record_semaphore_created());
}
/// Closes the semaphore. Closing an already closed semaphore does nothing.
///
/// Every queued waiter is removed from the queue and its waker (if any) is
/// woken; its task is unblocked unless the execution is in cleanup.
pub fn close(&self) {
    self.init_object_id();
    let mut state = self.state.borrow_mut();
    if state.closed {
        return;
    }
    crate::annotations::record_semaphore_closed(state.id.unwrap());
    state.closed = true;

    // Address of the state, used only for the trace output below.
    let ptr: *const BatchSemaphoreState = &*state;
    for waiter in state.waiters.drain(..) {
        trace!(
            "semaphore {:p} removing and waking up waiter {:?} on close",
            ptr,
            waiter,
        );
        // A queued waiter is flagged as queued and cannot already hold permits.
        assert!(waiter.is_queued.swap(false, Ordering::SeqCst));
        assert!(!waiter.has_permits.load(Ordering::SeqCst));
        ExecutionState::with(|exec_state| {
            if !exec_state.in_cleanup() {
                exec_state.get_mut(waiter.task_id).unblock();
            }
        });
        let mut maybe_waker = waiter.waker.lock().unwrap();
        if let Some(waker) = maybe_waker.take() {
            waker.wake();
        }
    }
}
/// Returns `true` once the semaphore has been closed.
pub fn is_closed(&self) -> bool {
    self.state.borrow().closed
}
/// Tries to acquire `num_permits` permits without waiting.
///
/// Whatever the outcome, the attempt is recorded with `crate::annotations` and
/// the call ends with `thread::switch()`. On failure the current task's clock is
/// updated with the semaphore's `last_acquire` clock; on success waiters of an
/// unfair semaphore whose requests no longer fit are re-blocked.
///
/// # Errors
///
/// * `TryAcquireError::Closed` if the semaphore is closed.
/// * `TryAcquireError::NoPermits` if the permits cannot be taken right now.
///
/// # Panics
///
/// Panics if `num_permits` is zero (asserted by `acquire_permits`).
pub fn try_acquire(&self, num_permits: usize) -> Result<(), TryAcquireError> {
    self.init_object_id();
    let mut state = self.state.borrow_mut();
    let id = state.id.unwrap();

    let outcome = state.acquire_permits(num_permits, self.fairness);
    if outcome.is_err() {
        ExecutionState::with(|s| {
            s.update_clock(&state.permits_available.last_acquire);
        });
    }
    drop(state);

    let acquired = outcome.is_ok();
    if acquired {
        self.reblock_if_unfair();
    }
    crate::annotations::record_semaphore_try_acquire(id, num_permits, acquired);
    thread::switch();
    outcome
}
/// For an unfair semaphore, blocks the task of every queued waiter whose request
/// exceeds the permits that are currently available. Does nothing for a strictly
/// fair semaphore.
///
/// Called after a successful acquire, which may have taken permits that queued
/// waiters had been unblocked for.
fn reblock_if_unfair(&self) {
    if self.fairness != Fairness::Unfair {
        return;
    }
    // The state is only read here, so a shared borrow is sufficient.
    let state = self.state.borrow();
    // Nothing in the loop changes the permit count; read it once.
    let available = state.permits_available.available();
    ExecutionState::with(|s| {
        for waiter in state.waiters.iter().filter(|w| available < w.num_permits) {
            s.get_mut(waiter.task_id).block(false);
        }
    });
}
/// Appends `waiter` to the back of the wait queue and flags it as queued.
///
/// # Panics
///
/// Panics if the waiter already holds permits or is already flagged as queued.
fn enqueue_waiter(&self, waiter: &Arc<Waiter>) {
    let mut state = self.state.borrow_mut();
    trace!("enqueuing waiter {:?} for semaphore {:p}", waiter, &self.state);
    state.waiters.push_back(Arc::clone(waiter));

    assert!(!waiter.has_permits.load(Ordering::SeqCst));
    assert!(!waiter.is_queued.swap(true, Ordering::SeqCst));
}
/// Removes `waiter` (matched by `Arc` pointer identity) from the wait queue and
/// clears its queued flag.
///
/// On a strictly fair semaphore, removing the front waiter may let the waiters
/// behind it proceed, so the queue is re-examined from the front in that case.
///
/// # Panics
///
/// Panics if the semaphore is closed, if the waiter holds permits, if it is not
/// in the queue, or if it was not flagged as queued.
fn remove_waiter(&self, waiter: &Arc<Waiter>) {
    let mut state = self.state.borrow_mut();
    trace!(waiters = ?state.waiters, "removing waiter {:?} from semaphore {:p}", waiter, &self.state);
    assert!(!state.closed);
    assert!(!waiter.has_permits.load(Ordering::SeqCst));

    let index = state
        .waiters
        .iter()
        .position(|x| Arc::ptr_eq(x, waiter))
        .expect("did not find waiter");
    state.waiters.remove(index).unwrap();
    assert!(waiter.is_queued.swap(false, Ordering::SeqCst));

    let removed_front = index == 0;
    if removed_front && self.fairness == Fairness::StrictlyFair {
        state.unblock_waiters_from_front();
    }
}
/// Returns a future that resolves once `num_permits` permits have been
/// acquired, or with an `AcquireError` if the semaphore is closed.
///
/// The semaphore's object id is assigned here if it has not been yet.
pub fn acquire(&self, num_permits: usize) -> Acquire<'_> {
    self.init_object_id();
    Acquire::new(self, num_permits)
}
/// Acquires `num_permits` permits by driving the `acquire` future to completion
/// with `crate::future::block_on`.
///
/// # Errors
///
/// Returns `AcquireError` if the semaphore is closed.
pub fn acquire_blocking(&self, num_permits: usize) -> Result<(), AcquireError> {
    self.init_object_id();
    let pending = self.acquire(num_permits);
    crate::future::block_on(pending)
}
/// Returns `num_permits` permits to the semaphore. Releasing zero permits does
/// nothing beyond making sure the object id is assigned.
///
/// If `ExecutionState::should_stop()` reports true, the permits are returned
/// with a fresh clock, the wait queue is discarded, the semaphore is marked
/// closed and no task is woken. Otherwise the permits are returned with the
/// current task's incremented clock, eligible waiters are unblocked according
/// to the fairness policy, and the call ends with `thread::switch()`.
pub fn release(&self, num_permits: usize) {
    self.init_object_id();
    if num_permits == 0 {
        return;
    }

    let mut state = self.state.borrow_mut();
    crate::annotations::record_semaphore_release(state.id.unwrap(), num_permits);

    if ExecutionState::should_stop() {
        state.permits_available.release(num_permits, VectorClock::new());
        state.waiters.iter().for_each(|waiter| {
            waiter.is_queued.swap(false, Ordering::SeqCst);
        });
        state.waiters.clear();
        state.closed = true;
        return;
    }

    ExecutionState::with(|s| {
        let clock = s.increment_clock();
        state.permits_available.release(num_permits, clock.clone());
    });

    let me = ExecutionState::me();
    trace!(task = ?me, avail = ?state.permits_available, waiters = ?state.waiters, "released {} permits for semaphore {:p}", num_permits, &self.state);

    match self.fairness {
        Fairness::Unfair => {
            // Unblock every waiter whose request fits; they stay queued and
            // attempt the acquire themselves when they run.
            let num_available = state.permits_available.available();
            let eligible = state.waiters.iter().filter(|w| w.num_permits <= num_available);
            for waiter in eligible {
                ExecutionState::with(|s| {
                    let task = s.get_mut(waiter.task_id);
                    assert!(!task.finished());
                    task.unblock();
                });
                // The waker is left in place because the waiter remains queued.
                let maybe_waker = waiter.waker.lock().unwrap();
                if let Some(waker) = maybe_waker.as_ref() {
                    waker.wake_by_ref();
                }
            }
        }
        Fairness::StrictlyFair => {
            // Permits are handed over directly, in queue order.
            state.unblock_waiters_from_front();
        }
    }

    drop(state);
    thread::switch();
}
}
// SAFETY: NOTE(review) — `BatchSemaphore` keeps its state in a `RefCell`, which is
// not `Sync`, so these impls assert that the state is never accessed concurrently.
// Nothing in this file enforces that; presumably the surrounding runtime only ever
// runs one task at a time — confirm before relying on these impls elsewhere.
unsafe impl Send for BatchSemaphore {}
unsafe impl Sync for BatchSemaphore {}
/// The default semaphore is strictly fair and starts with the default permit
/// count (`usize::default()`, i.e. zero).
impl Default for BatchSemaphore {
    fn default() -> Self {
        let initial_permits = usize::default();
        Self::new(initial_permits, Fairness::StrictlyFair)
    }
}
/// Future returned by `BatchSemaphore::acquire`; resolves once the requested
/// permits have been acquired or the semaphore has been closed.
#[derive(Debug)]
pub struct Acquire<'a> {
/// Request bookkeeping, shared with the semaphore's wait queue while queued.
waiter: Arc<Waiter>,
/// Semaphore the permits are requested from.
semaphore: &'a BatchSemaphore,
/// Set when `poll` has returned `Poll::Ready`. `Drop` uses it to tell permits
/// that were granted but never observed by `poll` from permits that were handed
/// to the caller.
completed: bool, }
impl<'a> Acquire<'a> {
    /// Creates a not-yet-polled request for `num_permits` permits on `semaphore`.
    /// Nothing is queued or acquired until the future is polled.
    fn new(semaphore: &'a BatchSemaphore, num_permits: usize) -> Self {
        Self {
            waiter: Arc::new(Waiter::new(num_permits)),
            semaphore,
            completed: false,
        }
    }
}
// Drives an acquire request: resolves with `Ok(())` once the permits are held and
// with `Err(AcquireError)` if the semaphore has been closed.
impl Future for Acquire<'_> {
type Output = Result<(), AcquireError>;
/// Polls the request.
///
/// # Panics
///
/// Panics if polled again after it has returned `Poll::Ready`, or if one of the
/// internal consistency assertions on the waiter's flags fails.
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
assert!(!self.completed);
if self.waiter.has_permits.load(Ordering::SeqCst) {
// The permits were already granted to this waiter while it was queued
// (see `unblock_waiters_from_front`); it must have been dequeued by then.
assert!(!self.waiter.is_queued.load(Ordering::SeqCst));
self.completed = true;
trace!("Acquire::poll for waiter {:?} with permits", self.waiter);
Poll::Ready(Ok(()))
} else if self.semaphore.is_closed() {
// Closing the semaphore drains the queue, so the waiter cannot still be queued.
assert!(!self.waiter.is_queued.load(Ordering::SeqCst));
self.completed = true;
trace!("Acquire::poll for waiter {:?} with closed", self.waiter);
Poll::Ready(Err(AcquireError::closed()))
} else {
let is_queued = self.waiter.is_queued.load(Ordering::SeqCst);
trace!("Acquire::poll for waiter {:?}; is queued: {is_queued:?}", self.waiter);
// Invariant: a waker is stored exactly while the waiter is queued.
assert_eq!(is_queued, self.waiter.waker.lock().unwrap().is_some());
// A waiter that is already queued on a strictly fair semaphore does not retry
// on its own; it receives its permits from `unblock_waiters_from_front`.
// Every other combination attempts the acquire now.
let try_to_acquire = match (self.semaphore.fairness, is_queued) {
(Fairness::Unfair, false) | (Fairness::StrictlyFair, false) | (Fairness::Unfair, true) => true,
(Fairness::StrictlyFair, true) => false,
};
if try_to_acquire {
let mut state = self.semaphore.state.borrow_mut();
let id = state.id.unwrap();
let acquire_result = state.acquire_permits(self.waiter.num_permits, self.semaphore.fairness);
// Release the borrow before calling helpers that borrow the state themselves.
drop(state);
match acquire_result {
Ok(()) => {
if is_queued {
// Only reachable for an unfair semaphore: the queued waiter took the
// permits itself and now has to leave the queue.
crate::annotations::record_semaphore_acquire_unblocked(
id,
self.waiter.task_id,
self.waiter.num_permits,
);
self.semaphore.remove_waiter(&self.waiter);
} else {
crate::annotations::record_semaphore_acquire_fast(id, self.waiter.num_permits);
}
// `remove_waiter` asserts that the waiter holds no permits, so the flag is
// set only after the waiter has been removed.
self.waiter.has_permits.store(true, Ordering::SeqCst);
self.completed = true;
trace!("Acquire::poll for waiter {:?} that got permits", self.waiter);
// Waiters that were unblocked for permits this request just took are blocked again.
self.semaphore.reblock_if_unfair();
thread::switch();
Poll::Ready(Ok(()))
}
Err(TryAcquireError::NoPermits) => {
// Remember the most recent waker so a later release or close can wake this future.
let mut maybe_waker = self.waiter.waker.lock().unwrap();
*maybe_waker = Some(cx.waker().clone());
if !is_queued {
crate::annotations::record_semaphore_acquire_blocked(id, self.waiter.num_permits);
self.semaphore.enqueue_waiter(&self.waiter);
// NOTE(review): `enqueue_waiter` already swaps `is_queued` to true, so this
// store looks redundant.
self.waiter.is_queued.store(true, Ordering::SeqCst);
}
trace!("Acquire::poll for waiter {:?} that is enqueued", self.waiter);
Poll::Pending
}
// The closed flag was checked at the top of this `poll` call.
Err(TryAcquireError::Closed) => unreachable!(),
}
} else {
// NOTE(review): the stored waker is not refreshed from `cx` on this path;
// presumably the task's waker never changes between polls — confirm.
Poll::Pending
}
}
}
}
// Cleans up a request that is dropped before it was observed as complete.
impl Drop for Acquire<'_> {
    fn drop(&mut self) {
        trace!("Acquire::drop for Acquire {:p} with waiter {:?}", self, self.waiter);

        // Still waiting: take the waiter out of the semaphore's queue.
        if self.waiter.is_queued.load(Ordering::SeqCst) {
            self.semaphore.remove_waiter(&self.waiter);
            return;
        }

        // Permits were granted but `poll` never returned them to the caller:
        // hand them back to the semaphore.
        let granted = self.waiter.has_permits.load(Ordering::SeqCst);
        if granted && !self.completed {
            self.semaphore.release(self.waiter.num_permits);
        }
    }
}
// Attaches a name and kind to the semaphore's annotation object.
impl crate::annotations::WithName for &BatchSemaphore {
    fn with_name_and_kind(self, name: Option<&str>, kind: Option<&str>) -> Self {
        // Make sure an object id exists before naming it.
        self.init_object_id();
        let id = self.state.borrow().id.unwrap();
        crate::annotations::record_name_for_object(id, name, kind);
        self
    }
}
// By-value variant: names the semaphore through the `&BatchSemaphore`
// implementation and hands the semaphore back.
impl crate::annotations::WithName for BatchSemaphore {
    fn with_name_and_kind(self, name: Option<&str>, kind: Option<&str>) -> Self {
        // Dispatch explicitly to the impl for `&BatchSemaphore` to avoid recursing
        // into this method.
        <&BatchSemaphore as crate::annotations::WithName>::with_name_and_kind(&self, name, kind);
        self
    }
}