use super::RecvError;
use std::{
cell::{Cell, Ref, RefCell},
future::Future,
ops::Deref,
pin::Pin,
rc::Rc,
task::{Context, Poll, Waker},
};
/// The value held by a watch channel together with its version counter.
///
/// `version` starts at 0 (see `channel`) and is incremented with wrapping arithmetic
/// by every successful send; receivers detect changes by comparing it with the
/// version they last saw.
#[derive(Debug)]
pub struct Value<T> {
// The current value.
value: T,
// Bumped on every successful send / send_modify.
version: u64,
}
/// State shared (through an `Rc`) between the sender and every receiver of a channel.
#[derive(Debug)]
struct State<T> {
// Current value and its version; mutably borrowed only briefly while sending.
value: RefCell<Value<T>>,
// Wakers of tasks waiting on a `Notification`; drained by `wake_all`.
wakers: RefCell<Vec<Waker>>,
// Set to `false` when the `Sender` is dropped.
sender_exists: Cell<bool>,
}
impl<T> State<T> {
    /// Registers `waker` to be woken on the next send or when the sender is dropped.
    ///
    /// A waker that [`Waker::will_wake`] the same task as an already registered one is
    /// not stored again, so repeatedly polling a pending notification doesn't grow the
    /// list.
    fn add_waker(&self, waker: &Waker) {
        let mut wakers = self.wakers.borrow_mut();
        if !wakers.iter().any(|w| waker.will_wake(w)) {
            wakers.push(waker.clone());
        }
    }

    /// Wakes every registered waker and clears the list.
    fn wake_all(&self) {
        // Move the wakers out *before* waking them so the `RefCell` is not borrowed while
        // `wake()` runs: a waker that re-registers itself synchronously (calling
        // `add_waker`) would otherwise panic with `BorrowMutError`. Wakers registered
        // during this loop are kept for the next round.
        let wakers = std::mem::take(&mut *self.wakers.borrow_mut());
        for waker in wakers {
            waker.wake();
        }
    }
}
/// Error returned by `Sender::send` and `Sender::try_send_modify` when the current value
/// is borrowed (e.g. a `ValueRef` is alive) and therefore cannot be replaced.
///
/// For `send`, the rejected value is handed back in field `.0` so it is not lost;
/// `try_send_modify` uses `SendError<()>`.
#[derive(thiserror::Error, Debug)]
#[error(
"failed to send this value, as someone is currently holding a reference to the previous value"
)]
pub struct SendError<T>(pub T);
/// The sending half of a watch channel created by `channel`.
///
/// Dropping it marks the channel as closed and wakes all waiting receivers.
#[derive(Debug)]
pub struct Sender<T> {
// Shared with every receiver of the channel.
state: Rc<State<T>>,
}
/// A receiving half of a watch channel, obtained from `channel`, `Sender::subscribe`
/// or by cloning another receiver.
#[derive(Debug)]
pub struct Receiver<T> {
// Shared with the sender and all other receivers.
state: Rc<State<T>>,
// Version this receiver has already observed: set to the current version on creation,
// updated by `mark_seen` and by a `changed()` future that resolves with `Ok`.
seen_version: u64,
}
impl<T> Sender<T> {
pub fn subscribe(&self) -> Receiver<T> {
Receiver {
state: self.state.clone(),
seen_version: self.state.value.borrow().version,
}
}
pub fn send(&self, value: T) -> Result<(), SendError<T>> {
if let Ok(mut value_ref) = self.state.value.try_borrow_mut() {
value_ref.value = value;
value_ref.version = value_ref.version.wrapping_add(1);
} else {
return Err(SendError(value));
}
self.state.wake_all();
Ok(())
}
pub fn send_modify(&self, modify: impl FnOnce(&mut T)) {
self.try_send_modify(modify)
.expect("no receivers referencing the old value")
}
pub fn try_send_modify(&self, modify: impl FnOnce(&mut T)) -> Result<(), SendError<()>> {
let mut value_ref = self
.state
.value
.try_borrow_mut()
.map_err(|_| SendError(()))?;
modify(&mut value_ref.value);
value_ref.version = value_ref.version.wrapping_add(1);
self.state.wake_all();
Ok(())
}
pub fn borrow(&self) -> ValueRef<'_, T> {
ValueRef(self.state.value.borrow())
}
pub fn get(&self) -> T
where
T: Copy,
{
*self.borrow().deref()
}
pub fn get_cloned(&self) -> T
where
T: Clone,
{
self.borrow().deref().clone()
}
pub fn is_closed(&self) -> bool {
Rc::strong_count(&self.state) == 1
}
}
impl<T> Drop for Sender<T> {
    /// Marks the channel as closed and wakes every waiting receiver so their pending
    /// `changed()` futures get polled again and can observe the closure.
    fn drop(&mut self) {
        let shared = &self.state;
        shared.sender_exists.set(false);
        shared.wake_all();
    }
}
/// A shared borrow of a watch channel's current value, returned by [`Sender::borrow`]
/// and [`Receiver::borrow`]; dereferences to `T`.
///
/// While any `ValueRef` is alive the value cannot be replaced: [`Sender::send`] and
/// [`Sender::try_send_modify`] return [`SendError`] and [`Sender::send_modify`] panics.
/// Drop it as soon as you are done reading.
pub struct ValueRef<'a, T>(Ref<'a, Value<T>>);

impl<T> Deref for ValueRef<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.0.value
    }
}

/// Formats exactly like the borrowed value, mirroring the `Debug` impl of
/// [`std::cell::Ref`].
impl<T: std::fmt::Debug> std::fmt::Debug for ValueRef<'_, T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(&**self, f)
    }
}
/// Future returned by `Receiver::changed`.
///
/// It holds the receiver mutably because resolving with `Ok(())` updates the
/// receiver's seen version.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Notification<'a, T> {
rx: &'a mut Receiver<T>,
}
impl<T> Receiver<T> {
    /// Reports whether the channel holds a version other than the one this receiver
    /// last marked as seen. Does not mark anything as seen.
    pub fn has_changed(&self) -> bool {
        let latest = self.state.value.borrow().version;
        latest != self.seen_version
    }

    /// Marks the channel's current version as seen by this receiver.
    pub fn mark_seen(&mut self) {
        let latest = self.state.value.borrow().version;
        self.seen_version = latest;
    }

    /// Version currently stored in the channel (test-only accessor).
    #[cfg(feature = "internal_test")]
    pub fn value_version(&self) -> u64 {
        let current = self.state.value.borrow();
        current.version
    }

    /// Version this receiver last marked as seen (test-only accessor).
    #[cfg(feature = "internal_test")]
    pub fn seen_version(&self) -> u64 {
        self.seen_version
    }

    /// Returns a future that resolves when the value changes or the sender is dropped;
    /// the `Result` it yields tells which. The receiver stays mutably borrowed until
    /// the future is dropped.
    pub fn changed(&mut self) -> Notification<'_, T> {
        Notification { rx: self }
    }

    /// Borrows the current value. While the returned [`ValueRef`] is alive the sender
    /// cannot replace the value.
    pub fn borrow(&self) -> ValueRef<'_, T> {
        let guard = self.state.value.borrow();
        ValueRef(guard)
    }

    /// Returns a copy of the current value.
    pub fn get(&self) -> T
    where
        T: Copy,
    {
        let current = self.borrow();
        *current
    }

    /// Returns a clone of the current value.
    pub fn get_cloned(&self) -> T
    where
        T: Clone,
    {
        let current = self.borrow();
        T::clone(&current)
    }
}
impl<T> Clone for Receiver<T> {
fn clone(&self) -> Self {
Self {
state: self.state.clone(),
seen_version: self.state.value.borrow().version,
}
}
}
impl<T> Future for Notification<'_, T> {
    type Output = Result<(), RecvError>;

    /// Resolves with `Ok(())` (marking the version as seen) as soon as the channel holds
    /// a version this receiver hasn't seen. Otherwise, resolves with `Err(RecvError)` if
    /// the sender has been dropped, or registers the task's waker and stays pending.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Look for an unseen version *before* checking whether the sender is gone: a
        // value sent right before the sender was dropped must still be reported,
        // otherwise the receiver would never learn about the final update.
        let version = self.rx.state.value.borrow().version;
        if version != self.rx.seen_version {
            self.rx.seen_version = version;
            return Poll::Ready(Ok(()));
        }
        if !self.rx.state.sender_exists.get() {
            return Poll::Ready(Err(RecvError));
        }
        // Woken again by the next send or by the sender's `Drop`.
        self.rx.state.add_waker(cx.waker());
        Poll::Pending
    }
}
/// Creates an `Rc`-based (single-threaded) watch channel holding `initial`.
///
/// Returns the sender and one receiver. More receivers can be obtained with
/// [`Sender::subscribe`] or by cloning a [`Receiver`]. The returned receiver treats
/// `initial` as already seen.
pub fn channel<T>(initial: T) -> (Sender<T>, Receiver<T>) {
    let value = RefCell::new(Value {
        value: initial,
        version: 0,
    });
    let shared = Rc::new(State {
        value,
        wakers: RefCell::new(Vec::new()),
        sender_exists: Cell::new(true),
    });
    let sender = Sender { state: shared };
    let receiver = sender.subscribe();
    (sender, receiver)
}
// Tests for the watch channel. They run under the crate's fiber runtime via
// `#[crate::test(tarantool = "crate")]`, so they are only built with `internal_test`.
#[cfg(feature = "internal_test")]
mod tests {
// Float test data such as `3.14` in `sender_get` would otherwise trip clippy's
// `approx_constant` lint; these literals are not meant as approximations of PI/E.
#![allow(clippy::approx_constant)]
use super::*;
use crate::fiber;
use crate::fiber::r#async::timeout::{self, IntoTimeout};
use futures::join;
use std::time::Duration;
const _1_SEC: Duration = Duration::from_secs(1);
// A value sent before anyone awaits is still reported to every receiver, whether it
// came from `channel`, `clone` or `subscribe`.
#[crate::test(tarantool = "crate")]
fn receive_notification_sent_before() {
let (tx, mut rx_1) = channel::<i32>(10);
let mut rx_2 = rx_1.clone();
let mut rx_3 = tx.subscribe();
tx.send(20).unwrap();
assert_eq!(
fiber::block_on(async move {
let _ = join!(rx_1.changed(), rx_2.changed(), rx_3.changed());
(*rx_1.borrow(), *rx_2.borrow(), *rx_3.borrow())
}),
(20, 20, 20)
);
}
// Receivers already waiting in another fiber are woken by a later send.
#[crate::test(tarantool = "crate")]
fn receive_notification_sent_after() {
let (tx, mut rx_1) = channel::<i32>(10);
let mut rx_2 = rx_1.clone();
let mut rx_3 = tx.subscribe();
let jh = fiber::start_async(async move {
let _ = join!(rx_1.changed(), rx_2.changed(), rx_3.changed());
(*rx_1.borrow(), *rx_2.borrow(), *rx_3.borrow())
});
tx.send(20).unwrap();
assert_eq!(jh.join(), (20, 20, 20))
}
// The same receiver can await `changed()` repeatedly, once per send.
#[crate::test(tarantool = "crate")]
fn receive_multiple_notifications() {
let (tx, mut rx_1) = channel::<i32>(10);
let jh = fiber::start_async(async {
rx_1.changed().await.unwrap();
*rx_1.borrow()
});
tx.send(1).unwrap();
assert_eq!(jh.join(), 1);
let jh = fiber::start_async(async {
rx_1.changed().await.unwrap();
*rx_1.borrow()
});
tx.send(2).unwrap();
assert_eq!(jh.join(), 2);
}
// Several sends before a receive coalesce: only the latest value is observed, and a
// second `changed()` has nothing left to report (it times out).
#[crate::test(tarantool = "crate")]
fn retains_only_last_notification() {
let (tx, mut rx_1) = channel::<i32>(10);
tx.send(1).unwrap();
tx.send(2).unwrap();
tx.send(3).unwrap();
let v = fiber::block_on(async {
rx_1.changed().await.unwrap();
*rx_1.borrow()
});
assert_eq!(v, 3);
assert_eq!(
fiber::block_on(rx_1.changed().timeout(_1_SEC)),
Err(timeout::Error::Expired)
);
}
// Dropping the sender resolves a pending `changed()` with `RecvError`.
#[crate::test(tarantool = "crate")]
fn notification_receive_error() {
let (tx, mut rx_1) = channel::<i32>(10);
let jh = fiber::start_async(rx_1.changed());
drop(tx);
assert_eq!(jh.join(), Err(RecvError));
}
// Two receivers waiting in separate fibers are both notified by one send.
#[crate::test(tarantool = "crate")]
fn notification_received_in_concurrent_fiber() {
let (tx, mut rx_1) = channel::<i32>(10);
let mut rx_2 = rx_1.clone();
let jh_1 = fiber::start_async(rx_1.changed());
let jh_2 = fiber::start_async(rx_2.changed());
tx.send(1).unwrap();
assert!(jh_1.join().is_ok());
assert!(jh_2.join().is_ok());
}
// `send_modify` changes the value in place and notifies receivers.
#[crate::test(tarantool = "crate")]
fn send_modify() {
let (tx, mut rx) = channel(vec![13]);
let jh = fiber::start(|| {
fiber::block_on(rx.changed()).unwrap();
rx.get_cloned()
});
tx.send_modify(|v| v.push(37));
assert_eq!(jh.join(), [13, 37]);
}
// Sender-side accessors. Also checks that mutating through interior mutability
// (`RefCell` inside the value) does not bump the version, and that a live `ValueRef`
// makes `try_send_modify` fail until it is dropped.
#[crate::test(tarantool = "crate")]
fn sender_get() {
let (tx, _) = channel(69);
assert_eq!(tx.get(), 69);
tx.send(420).unwrap();
assert_eq!(tx.get(), 420);
let (tx, _) = channel("foo".to_string());
assert_eq!(tx.get_cloned(), "foo");
tx.send("bar".into()).unwrap();
assert_eq!(tx.get_cloned(), "bar");
let (tx, mut rx) = channel(RefCell::new(vec![3.14]));
let value_ref = tx.borrow();
assert_eq!(*value_ref.borrow(), [3.14]);
value_ref.borrow_mut().push(2.71);
assert_eq!(*tx.get_cloned().borrow(), [3.14, 2.71]);
let res = fiber::block_on(rx.changed().timeout(Duration::ZERO));
assert_eq!(res, Err(timeout::Error::Expired));
tx.try_send_modify(|v| v.get_mut().push(1.61)).unwrap_err();
drop(value_ref);
tx.send_modify(|v| v.get_mut().push(1.61));
fiber::block_on(rx.changed()).unwrap();
assert_eq!(*rx.get_cloned().borrow(), [3.14, 2.71, 1.61]);
}
// `is_closed` reflects whether any receiver is still alive.
#[crate::test(tarantool = "crate")]
fn check_closed() {
let (tx, rx_1) = channel(());
assert!(!tx.is_closed());
drop(rx_1);
assert!(tx.is_closed());
let _rx_2 = tx.subscribe();
assert!(!tx.is_closed());
}
}