use core::{pin::Pin, task::Poll};
use std::{cell::{Cell, RefCell}, collections::VecDeque, rc::Rc, sync::{Arc}, task::{Context, Waker}};
use crossbeam_utils::CachePadded;
use lfqueue::UnboundedQueue;
use crate::base::{event::{EventSetter, EventState, EventWaker, impl_async_event}, lot::{AsyncLotSignature, AsyncLotTemplate, impl_async_lot}, signal::{TechnicalCounter, ValidityMarker, ValidityState, WakeQueue}};
use crate::atomic::*;
/// An event wrapper that keeps a count of the awaiters currently parked on it.
///
/// `inner` does the actual signalling; `counter` is read through `count()` and
/// is adjusted by the `CountedAwaiter`s handed out by `wait()`.
pub struct CountedEvent<E> {
// The wrapped event that every operation is forwarded to.
inner: E,
// Number of awaiters currently counted as waiting.
counter: AtomicUsize,
}
#[pin_project::pin_project(PinnedDrop)]
// NOTE(review): `private_bounds` is allowed because `EventSetter` is presumably
// less visible than this `pub` type — confirm this is intentional.
#[allow(private_bounds)]
/// Future returned by `CountedEvent::wait`.
///
/// Wraps the inner event's waiter `F` and keeps `master`'s counter in step with
/// its own `state`: the counter is incremented when the awaiter first parks
/// (its first `Pending`), and decremented when a parked awaiter completes or is
/// dropped while still parked.
pub struct CountedAwaiter<'a, E, F>
where
E: EventSetter<'a>,
F: Future<Output = ()>,
{
// Event whose counter this awaiter adjusts.
master: &'a CountedEvent<E>,
// The wrapped waiter; structurally pinned by `pin_project`, so it is only
// ever polled through `Pin<&mut F>`.
#[pin]
future: F,
// Lifecycle position; decides whether dropping must undo the counter increment.
state: CountedAwaiterState,
}
/// Lifecycle of a `CountedAwaiter`.
///
/// Only the `Waiting` state contributes to the owning event's waiter counter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
enum CountedAwaiterState {
    /// Not yet parked on the inner event; not counted.
    Init = 0,
    /// Parked on the inner event; counted.
    Waiting = 1,
    /// The inner future has resolved; no longer counted.
    Complete = 2,
}
#[pin_project::pinned_drop]
impl<'a, E, F> PinnedDrop for CountedAwaiter<'a, E, F>
where
    E: EventSetter<'a>,
    F: Future<Output = ()>,
{
    fn drop(self: Pin<&mut Self>) {
        // Only a parked awaiter was ever added to the counter; `Init` and
        // `Complete` awaiters have nothing to undo.
        if let CountedAwaiterState::Waiting = self.state {
            self.master.counter.fetch_sub(1, Ordering::Release);
        }
    }
}
/// Polls the wrapped waiter while keeping `master.counter` in step with this
/// awaiter's state.
///
/// State transitions:
/// * `Init` -> `Waiting` the first time the inner future returns `Pending`
///   (the counter is incremented once for this awaiter);
/// * `Init` -> `Complete` if the inner future is ready on the first poll
///   (the counter is never touched);
/// * `Waiting` -> `Complete` when the inner future becomes ready (the earlier
///   increment is undone).
///
/// Once `Complete`, further polls return `Ready(())` without polling the inner
/// future again.
impl<'a, E, F> Future for CountedAwaiter<'a, E, F>
where
    E: EventSetter<'a>,
    // No `Unpin` bound: `future` is a `#[pin]` field, so the projection already
    // yields `Pin<&mut F>` and `!Unpin` waiters can be polled as well.
    F: Future<Output = ()>,
{
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        match *this.state {
            CountedAwaiterState::Init => match this.future.poll(cx) {
                Poll::Pending => {
                    // Count the awaiter only once it has actually parked.
                    this.master.counter.fetch_add(1, Ordering::Release);
                    *this.state = CountedAwaiterState::Waiting;
                    Poll::Pending
                }
                Poll::Ready(()) => {
                    // Never counted, so nothing to undo; record completion so a
                    // later poll does not touch the finished inner future.
                    *this.state = CountedAwaiterState::Complete;
                    Poll::Ready(())
                }
            },
            CountedAwaiterState::Waiting => match this.future.poll(cx) {
                Poll::Pending => Poll::Pending,
                Poll::Ready(()) => {
                    this.master.counter.fetch_sub(1, Ordering::Release);
                    *this.state = CountedAwaiterState::Complete;
                    Poll::Ready(())
                }
            },
            CountedAwaiterState::Complete => Poll::Ready(()),
        }
    }
}
/// Inherent convenience API mirroring the `EventSetter` operations.
// NOTE(review): `private_bounds` / `private_interfaces` are allowed because
// `EventSetter` (and its associated `Waiter`) appear to be less visible than
// this public API — confirm this is intentional.
#[allow(private_bounds)]
impl<E> CountedEvent<E>
where
    for<'a> E: EventSetter<'a>,
{
    /// Wraps `source`, starting with a waiter count of zero.
    #[allow(private_interfaces)]
    pub fn new(source: E) -> Self {
        Self {
            inner: source,
            counter: AtomicUsize::new(0),
        }
    }

    /// Returns an awaiter for the wrapped event that adjusts
    /// [`count`](Self::count) as it parks and completes.
    #[allow(private_interfaces)]
    pub fn wait(&self) -> <Self as EventSetter>::Waiter {
        <Self as EventSetter>::wait(self)
    }

    /// Forwards to `EventSetter::set_one`, discarding its `bool` result.
    #[allow(private_interfaces)]
    pub fn set_one(&self) {
        <Self as EventSetter>::set_one(self);
    }

    /// Forwards to `EventSetter::set_all` with a no-op callback.
    #[allow(private_interfaces)]
    pub fn set_all(&self) {
        <Self as EventSetter>::set_all(self, || {});
    }

    /// Returns a snapshot of the number of awaiters currently counted as waiting.
    ///
    /// If awaiters are polled or dropped concurrently, the value may already
    /// be stale when it is returned.
    pub fn count(&self) -> usize {
        self.counter.load(Acquire)
    }
}
/// Forwards every operation to the wrapped event `E`.
///
/// The only extra behaviour is that `wait` returns a [`CountedAwaiter`], which
/// alone maintains `counter`: it increments it when it first parks, and
/// decrements it when it completes or is dropped while parked.
impl<'a, E> EventSetter<'a> for CountedEvent<E>
where
    E: EventSetter<'a> + 'a,
{
    type Waiter = CountedAwaiter<'a, E, E::Waiter>;

    fn new() -> Self {
        Self {
            counter: AtomicUsize::default(),
            inner: E::new(),
        }
    }

    fn new_set() -> Self {
        Self {
            counter: AtomicUsize::default(),
            inner: E::new_set(),
        }
    }

    fn wait(&'a self) -> CountedAwaiter<'a, E, E::Waiter> {
        CountedAwaiter {
            master: self,
            future: self.inner.wait(),
            state: CountedAwaiterState::Init,
        }
    }

    fn try_wait(&self) -> bool {
        self.inner.try_wait()
    }

    fn set_all<F: FnMut()>(&self, functor: F) {
        // Do not touch `counter` here: every woken awaiter already decrements
        // it itself when it completes (or is dropped while parked). Also
        // decrementing here would count each waiter twice and wrap the
        // counter below zero, and would be inconsistent with `set_one`.
        self.inner.set_all(functor);
    }

    fn set_one(&self) -> bool {
        self.inner.set_one()
    }

    fn has_waiters(&self) -> bool {
        self.inner.has_waiters()
    }
}
/// Shared validity marker: the `ValidityState` is stored as its `u8`
/// discriminant in a reference-counted atomic byte.
impl ValidityMarker for Arc<AtomicU8> {
    fn create() -> Self {
        let initial = ValidityState::Idle as u8;
        Arc::new(AtomicU8::new(initial))
    }

    fn get(&self) -> ValidityState {
        let atom: &AtomicU8 = self;
        let raw = atom.load(Acquire);
        ValidityState::from_u8(raw).expect("Invalid validity state discriminant stored.")
    }

    fn set(&self, value: ValidityState) {
        let atom: &AtomicU8 = self;
        atom.store(value as u8, Release);
    }
}
/// Thread-safe counter: an `AtomicUsize` wrapped in crossbeam's `CachePadded`.
/// Updates use `Release` and reads use `Acquire`.
impl TechnicalCounter for CachePadded<AtomicUsize> {
    fn increment(&self) {
        let atom: &AtomicUsize = self;
        atom.fetch_add(1, Release);
    }

    fn decrement(&self) {
        let atom: &AtomicUsize = self;
        atom.fetch_sub(1, Release);
    }

    fn get(&self) -> usize {
        let atom: &AtomicUsize = self;
        atom.load(Acquire)
    }
}
/// Adapter exposing `lfqueue::UnboundedQueue` as a `WakeQueue`.
///
/// Both methods use fully qualified paths. These presumably resolve to
/// `UnboundedQueue`'s inherent methods, which take precedence over this trait's
/// same-named methods — NOTE(review): confirm, otherwise they would recurse.
impl<T> WakeQueue<T> for UnboundedQueue<T>
where
    T: Send + Sync,
{
    fn enqueue(&self, item: T) {
        UnboundedQueue::enqueue(self, item);
    }

    fn dequeue(&self) -> Option<T> {
        UnboundedQueue::dequeue(self)
    }
}
/// Single-threaded counter backed by a `Cell<usize>` (hence `!Sync`).
#[derive(Default)]
struct LocalCounter(Cell<usize>);

impl TechnicalCounter for LocalCounter {
    fn increment(&self) {
        let next = self.0.get() + 1;
        self.0.set(next);
    }

    fn decrement(&self) {
        // Plain subtraction: with overflow checks enabled (the debug default),
        // decrementing past zero panics.
        let next = self.0.get() - 1;
        self.0.set(next);
    }

    fn get(&self) -> usize {
        self.0.get()
    }
}
/// Single-threaded FIFO queue backed by a `RefCell<VecDeque<T>>`.
struct LocalQueue<T>(RefCell<VecDeque<T>>);

// Written by hand so that `T` is not required to implement `Default`.
impl<T> Default for LocalQueue<T> {
    fn default() -> Self {
        LocalQueue(RefCell::new(VecDeque::new()))
    }
}

impl<T> WakeQueue<T> for LocalQueue<T> {
    fn enqueue(&self, item: T) {
        let mut queue = self.0.borrow_mut();
        queue.push_back(item);
    }

    fn dequeue(&self) -> Option<T> {
        let mut queue = self.0.borrow_mut();
        queue.pop_front()
    }
}
/// Single-threaded validity marker: a shared `Cell` holding the state directly.
impl ValidityMarker for Rc<Cell<ValidityState>> {
    fn create() -> Self {
        let cell = Cell::new(ValidityState::Idle);
        Rc::new(cell)
    }

    fn get(&self) -> ValidityState {
        // Reach the inner `Cell` explicitly: `self.get()` would resolve to this
        // trait method on `Rc` and recurse forever.
        let cell: &Cell<ValidityState> = self;
        cell.get()
    }

    fn set(&self, value: ValidityState) {
        let cell: &Cell<ValidityState> = self;
        cell.set(value);
    }
}
// Single-threaded lot: validity marker, counter and queue are `Rc`/`Cell`/
// `RefCell`-based (`Rc<Cell<ValidityState>>`, `LocalCounter`, `LocalQueue`), so
// the generated type is presumably `!Send`/`!Sync` — TODO confirm against the macro.
impl_async_lot!(
name = LocalAllocatedAsyncLot,
waker = EventWaker,
validity = Rc<Cell<ValidityState>>,
counter = LocalCounter,
queue = LocalQueue
);
// Shared lot: `Arc<AtomicU8>` validity marker, cache-padded atomic counter, and
// `lfqueue::UnboundedQueue` as the queue. This is the lot behind `Event`.
impl_async_lot!(
name = AllocatedAsyncLot,
waker = EventWaker,
validity = Arc<AtomicU8>,
counter = CachePadded<AtomicUsize>,
queue = UnboundedQueue
);
/// Thread-safe event state: every operation maps directly onto the
/// corresponding inherent `AtomicU8` method, with the caller's orderings.
impl EventState for AtomicU8 {
    fn new(state: u8) -> Self {
        AtomicU8::new(state)
    }

    fn store(&self, value: u8, ordering: Ordering) {
        // Fully qualified so the inherent `store` is used, not this trait method.
        AtomicU8::store(self, value, ordering);
    }

    fn cmpxchng_weak(
        &self,
        current: u8,
        new: u8,
        success: Ordering,
        failure: Ordering,
    ) -> Result<u8, u8> {
        self.compare_exchange_weak(current, new, success, failure)
    }
}
// Thread-safe event built on `AllocatedAsyncLot` with an `AtomicU8` state.
impl_async_event!(
name = Event,
waitername = EventAwait,
lot = AllocatedAsyncLot,
state = AtomicU8
);
// SAFETY: NOTE(review): thread-safety is asserted manually here rather than
// derived by the compiler. The components visible in this file (`Arc<AtomicU8>`,
// `CachePadded<AtomicUsize>`, `UnboundedQueue`, `AtomicU8`) are intended for
// shared use. Soundness also depends on `EventWaker` and on what
// `impl_async_event!` generates — confirm nothing reachable from `Event` is
// `!Send`/`!Sync` (e.g. a `Cell`, `Rc` or raw pointer without synchronization).
unsafe impl Send for Event {}
unsafe impl Sync for Event {}
/// Single-threaded event state backed by a `Cell<u8>`.
///
/// Follows the same contract as the `AtomicU8` state. The memory-ordering
/// arguments are accepted for interface compatibility and ignored, because a
/// `Cell` cannot be shared between threads.
struct LocalState(Cell<u8>);

impl EventState for LocalState {
    /// Sets the state to `new` if it currently equals `current`.
    ///
    /// Mirrors `AtomicU8::compare_exchange_weak`. On success it returns
    /// `Ok(previous)`, which always equals `current`; on failure it returns
    /// `Err(actual)`. Unlike the atomic operation it never fails spuriously.
    fn cmpxchng_weak(
        &self,
        current: u8,
        new: u8,
        _: Ordering,
        _: Ordering,
    ) -> Result<u8, u8> {
        let previous = self.0.get();
        if previous == current {
            self.0.set(new);
            // Report the replaced value (as the atomic version does), not the
            // value that was written.
            Ok(previous)
        } else {
            Err(previous)
        }
    }

    fn new(state: u8) -> Self {
        Self(Cell::new(state))
    }

    fn store(&self, value: u8, _: Ordering) {
        self.0.set(value);
    }
}
// Single-threaded event built on `LocalAllocatedAsyncLot` with the `Cell`-based
// `LocalState`. Unlike `Event`, no manual `Send`/`Sync` impls accompany it here.
impl_async_event!(
name = LocalEvent,
waitername = LocalEventAwait,
lot = LocalAllocatedAsyncLot,
state = LocalState
);