use core::cell::Cell;
use core::pin::Pin;
use core::task::{Poll, Waker};
use pin_cell::{PinCell, PinMut};
use pin_project::pin_project;
use wakerset::{ExtractedWakers, WakerList, WakerSlot};
/// A generation counter that tasks can asynchronously wait on.
///
/// Each call to `increment_generation` advances the counter and wakes the
/// waiters currently linked into `list`.
#[derive(Debug, Default)]
#[pin_project]
pub struct Source {
    /// Current generation; starts at 0 (the `Default` of `Cell<usize>`).
    generation: Cell<usize>,
    /// Waker slots of pending `Waiter::wait` futures. Structurally pinned
    /// because the list is only ever accessed through `Pin<&mut WakerList>`.
    #[pin]
    list: PinCell<WakerList>,
}
impl Source {
    /// Creates a source at generation 0 with no linked waiters.
    pub(crate) fn new() -> Self {
        Self::default()
    }

    /// Returns a `Waiter` whose last-seen generation is this source's
    /// current generation.
    pub(crate) fn waiter(self: Pin<&Self>) -> Waiter<'_> {
        Waiter::new(self)
    }

    /// Advances the generation by one and wakes the waiters linked to this
    /// source.
    ///
    /// The counter uses wrapping arithmetic, so reaching `usize::MAX` rolls
    /// over to 0 instead of panicking on overflow in debug builds.
    pub(crate) fn increment_generation(self: Pin<&Self>) {
        // Publish the new generation before waking anyone, so a woken waiter
        // that polls immediately observes it.
        self.generation
            .set(self.generation.get().wrapping_add(1));
        let round =
            PinMut::as_mut(&mut self.project_ref().list.borrow_mut()).begin_extraction();
        let mut wakers = ExtractedWakers::new();
        let mut more = true;
        while more {
            // The `borrow_mut` guard is a temporary dropped at the end of this
            // statement, so `wake_all` below runs without the list borrowed.
            // NOTE(review): presumably this lets a woken task touch the list
            // (e.g. re-link) without hitting an already-borrowed `PinCell`.
            more = PinMut::as_mut(&mut self.project_ref().list.borrow_mut())
                .extract_some_wakers(round, &mut wakers);
            wakers.wake_all();
        }
    }

    /// Links `slot` into the waker list, associating it with `waker`.
    fn link(self: Pin<&Self>, slot: Pin<&mut WakerSlot>, waker: Waker) {
        PinMut::as_mut(&mut self.project_ref().list.borrow_mut()).link(slot, waker)
    }

    /// Removes `slot` from the waker list.
    fn unlink(self: Pin<&Self>, slot: Pin<&mut WakerSlot>) {
        PinMut::as_mut(&mut self.project_ref().list.borrow_mut()).unlink(slot)
    }
}
/// A handle that waits for a `Source` to advance past the generation this
/// waiter last observed.
#[derive(Debug)]
pub(crate) struct Waiter<'a> {
    /// Last generation of `source` this waiter has observed.
    generation: usize,
    /// The source being observed.
    source: Pin<&'a Source>,
}
impl<'a> Waiter<'a> {
    /// Creates a waiter that treats `source`'s current generation as already
    /// seen.
    pub(crate) fn new(source: Pin<&'a Source>) -> Self {
        Self {
            generation: source.generation.get(),
            source,
        }
    }

    /// Marks the source's current generation as seen, so the next `wait`
    /// stays pending until the generation changes again.
    pub(crate) fn update_generation(&mut self) {
        self.generation = self.source.generation.get();
    }

    /// Waits until the source's generation differs from the last-seen one.
    ///
    /// This does not update the last-seen generation; call
    /// `update_generation` afterwards to acknowledge the change.
    ///
    /// # Errors
    ///
    /// Returns `MissedUpdate` when the source's generation is neither the
    /// last-seen generation nor the one immediately after it, i.e. at least
    /// one generation was skipped without being observed.
    pub(crate) async fn wait(&self) -> Result<(), MissedUpdate> {
        // Owns the pinned waker slot and unlinks it from the source when the
        // wait finishes or its future is dropped while still pending.
        struct Guard<'a, 'b> {
            source: Pin<&'a Source>,
            slot: Pin<&'b mut WakerSlot>,
        }
        impl Drop for Guard<'_, '_> {
            fn drop(&mut self) {
                if self.slot.is_linked() {
                    self.source.unlink(self.slot.as_mut());
                }
            }
        }
        use core::pin::pin;
        let mut guard = Guard {
            source: self.source,
            slot: pin!(WakerSlot::new()),
        };
        core::future::poll_fn(|cx| {
            let current = self.source.generation.get();
            if current == self.generation {
                // NOTE(review): on a repeated poll with an unchanged generation
                // the slot may still be linked from the previous poll; this
                // relies on `WakerList::link` accepting an already-linked
                // slot — confirm against the wakerset docs.
                self.source.link(guard.slot.as_mut(), cx.waker().clone());
                return Poll::Pending;
            }
            // Wrapping keeps this consistent with a counter that rolls over
            // past usize::MAX, and avoids a debug-build overflow panic.
            let expected = self.generation.wrapping_add(1);
            if current != expected {
                return Poll::Ready(Err(MissedUpdate { expected, current }));
            }
            Poll::Ready(Ok(()))
        })
        .await
    }
}
/// Error returned by `Waiter::wait` when the source's generation moved past
/// the one immediately following the waiter's last-seen generation, meaning
/// at least one intermediate generation was never observed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct MissedUpdate {
    /// The generation the waiter expected to observe next.
    pub(crate) expected: usize,
    /// The generation the source actually had when the waiter was polled.
    pub(crate) current: usize,
}
#[cfg(test)]
mod tests {
    use std::cell::Cell;
    use std::future::Future;
    use std::pin::pin;
    use std::task::{Context, Poll};

    use crate::datastore::generational;

    /// Each increment wakes the waiter, which observes it exactly once and in
    /// order.
    #[test]
    fn example() {
        let source = pin!(generational::Source::new());
        let counter = Cell::new(0);
        let sum = Cell::new(0);
        let mut waiter = source.as_ref().waiter();
        let mut future = pin!(async {
            loop {
                // Every increment below is followed by a poll, so no update
                // may ever be reported as missed.
                assert!(waiter.wait().await.is_ok(), "unexpected missed update");
                waiter.update_generation();
                sum.set(sum.get() + counter.get());
            }
        });
        let mut context = Context::from_waker(futures::task::noop_waker_ref());
        for i in 1..10 {
            assert!(future.as_mut().poll(&mut context).is_pending());
            assert_eq!(sum.get(), (i - 1) * i / 2);
            counter.set(i);
            source.as_ref().increment_generation();
            assert!(future.as_mut().poll(&mut context).is_pending());
            assert_eq!(sum.get(), i * (i + 1) / 2);
        }
    }

    /// Two increments between polls are reported as a `MissedUpdate`.
    #[test]
    fn missed_update() {
        let source = pin!(generational::Source::new());
        let waiter = source.as_ref().waiter();
        let mut future = pin!(waiter.wait());
        let mut context = Context::from_waker(futures::task::noop_waker_ref());
        assert!(future.as_mut().poll(&mut context).is_pending());
        source.as_ref().increment_generation();
        source.as_ref().increment_generation();
        match future.as_mut().poll(&mut context) {
            Poll::Ready(Err(missed)) => {
                assert_eq!(missed.expected, 1);
                assert_eq!(missed.current, 2);
            }
            Poll::Ready(Ok(())) => panic!("expected a missed update, got Ok"),
            Poll::Pending => panic!("expected a missed update, got Pending"),
        }
    }
}