use std::{
borrow,
cell::{Cell, Ref, RefCell},
mem,
ops::Deref,
rc::Rc,
};
use local_event::Event;
use crate::{
error::{RecvError, SendError},
state::{StateSnapshot, Version},
};
/// Packed channel state: a version counter and a closed flag sharing one `usize`.
#[derive(Debug)]
struct State(Cell<usize>);
impl State {
    /// Creates a state holding `Version::INITIAL`
    /// (presumably with the closed bit clear — see `state`).
    fn new() -> Self {
        State(Cell::new(Version::INITIAL.inner()))
    }
    /// Returns a snapshot of the packed version and closed flag.
    #[inline]
    fn load(&self) -> StateSnapshot {
        StateSnapshot::from_usize(self.0.get())
    }
    /// Advances the version by `Version::STEP`.
    ///
    /// The counter wraps on overflow. Versions are only ever compared for
    /// (in)equality, so a wrapped counter still signals a change. A plain `+`
    /// would panic in debug builds after enough sends, which is potentially
    /// reachable on 32-bit targets, while silently wrapping in release builds.
    #[inline]
    fn increment_version(&self) {
        self.0.update(|x| x.wrapping_add(Version::STEP));
    }
    /// ORs the closed bit into the packed state.
    fn set_closed(&self) {
        self.0.update(|x| x | StateSnapshot::CLOSED_BIT);
    }
}
/// State shared by every sender and receiver of one channel.
#[derive(Debug)]
struct Shared<T> {
    /// Packed version counter and closed flag.
    state: State,
    /// Number of live `Sender` handles.
    tx_count: Cell<usize>,
    /// Number of live `Receiver` handles.
    rx_count: Cell<usize>,
    /// Notified when a value is published or the last sender is dropped.
    changed: Event,
    /// Notified when the last receiver is dropped.
    closed: Event,
    /// The current value.
    value: RefCell<T>,
}
impl<T> Shared<T> {
    /// Builds the shared state around `init`, with no senders or receivers counted yet.
    fn new(init: T) -> Self {
        let value = RefCell::new(init);
        let (changed, closed) = (Event::new(), Event::new());
        Shared {
            state: State::new(),
            tx_count: Cell::new(0),
            rx_count: Cell::new(0),
            changed,
            closed,
            value,
        }
    }
    /// Current number of live senders.
    #[inline]
    fn tx_count(&self) -> usize {
        Cell::get(&self.tx_count)
    }
    /// Current number of live receivers.
    #[inline]
    fn rx_count(&self) -> usize {
        Cell::get(&self.rx_count)
    }
}
/// Sending half of a single-threaded watch channel.
///
/// Senders publish values into a slot that every [`Receiver`] can observe. The
/// handle is built on `Rc`, so it is not `Send`.
#[derive(Debug)]
pub struct Sender<T> {
    shared: Rc<Shared<T>>,
}
impl<T> Sender<T> {
    /// Wraps `shared` in a new handle and counts it as a live sender.
    fn from_shared(shared: Rc<Shared<T>>) -> Self {
        shared.tx_count.update(|x| x + 1);
        Self { shared }
    }
    /// Creates a new channel holding `init` and returns its only sender.
    ///
    /// No receiver exists yet, so [`is_closed`](Self::is_closed) is `true` until
    /// [`subscribe`](Self::subscribe) is called.
    #[must_use]
    pub fn new(init: T) -> Self {
        // Reuse `from_shared` so the sender-count bookkeeping lives in one place.
        Self::from_shared(Rc::new(Shared::new(init)))
    }
    /// Modifies the value in place, and publishes it only if `modify` reports a change.
    ///
    /// `modify` receives a mutable reference to the current value and returns
    /// whether it changed the value. Only when it returns `true` is the version
    /// bumped and the `changed` event notified. This works whether or not any
    /// receiver exists.
    ///
    /// Returns the value returned by `modify`.
    ///
    /// # Panics
    /// Panics if the value is currently borrowed. That happens when a [`Guard`]
    /// from [`borrow`](Self::borrow) or [`Receiver::borrow`] is alive, or when
    /// `modify` itself borrows this channel's value.
    #[must_use]
    pub fn send_if_modified<F>(&self, modify: F) -> bool
    where
        F: FnOnce(&mut T) -> bool,
    {
        let mut guard = self.shared.value.borrow_mut();
        if !modify(&mut guard) {
            return false;
        }
        self.shared.state.increment_version();
        // Release the mutable borrow before notifying, so nothing triggered by
        // the notification can observe the value as still borrowed.
        drop(guard);
        self.shared.changed.notify(usize::MAX);
        true
    }
    /// Modifies the value in place and always publishes it, even without receivers.
    ///
    /// # Panics
    /// Same conditions as [`send_if_modified`](Self::send_if_modified).
    pub fn send_modify<F>(&self, modify: F)
    where
        F: FnOnce(&mut T),
    {
        let _ = self.send_if_modified(|value| {
            modify(value);
            true
        });
    }
    /// Publishes `value` and returns the value it replaced, even without receivers.
    ///
    /// # Panics
    /// Same conditions as [`send_if_modified`](Self::send_if_modified).
    #[must_use]
    pub fn send_replace(&self, mut value: T) -> T {
        self.send_modify(|old| mem::swap(old, &mut value));
        value
    }
    /// Publishes `value` if at least one receiver exists.
    ///
    /// # Errors
    /// Returns `SendError::ChannelClosed(value)`, handing the value back, when
    /// there are no receivers.
    ///
    /// # Panics
    /// Same conditions as [`send_if_modified`](Self::send_if_modified).
    pub fn send(&self, value: T) -> Result<(), SendError<T>> {
        if self.is_closed() {
            return Err(SendError::ChannelClosed(value));
        }
        let _ = self.send_replace(value);
        Ok(())
    }
    /// Borrows the current value. The guard's `has_changed` is always `false`.
    #[must_use]
    pub fn borrow(&self) -> Guard<'_, T> {
        let inner = self.shared.value.borrow();
        let has_changed = false;
        Guard { inner, has_changed }
    }
    /// Number of live senders for this channel, including `self`.
    #[must_use]
    pub fn sender_count(&self) -> usize {
        self.shared.tx_count()
    }
    /// Number of live receivers for this channel.
    #[must_use]
    pub fn receiver_count(&self) -> usize {
        self.shared.rx_count()
    }
    /// Returns `true` when no receiver is attached.
    #[must_use]
    #[inline]
    pub fn is_closed(&self) -> bool {
        self.receiver_count() == 0
    }
    /// Completes once no receiver is attached; returns immediately if none is now.
    pub async fn closed(&self) {
        // The condition is re-checked after every wake-up. `subscribe` may attach
        // a new receiver after the last one was dropped but before this task
        // resumes, so one notification does not guarantee the channel is still
        // closed.
        while !self.is_closed() {
            self.shared.closed.listen().await;
        }
    }
    /// Returns `true` if both senders belong to the same channel.
    #[must_use]
    pub fn same_channel(&self, other: &Self) -> bool {
        Rc::ptr_eq(&self.shared, &other.shared)
    }
    /// Attaches a new receiver, which starts with the current value marked as seen.
    #[must_use]
    pub fn subscribe(&self) -> Receiver<T> {
        let shared = self.shared.clone();
        shared.rx_count.update(|x| x + 1);
        let version = shared.state.load().version();
        Receiver { version, shared }
    }
}
impl<T> Clone for Sender<T> {
fn clone(&self) -> Self {
self.shared.tx_count.update(|x| x + 1);
Self {
shared: self.shared.clone(),
}
}
}
impl<T> Drop for Sender<T> {
    /// Unregisters this sender. Dropping the last sender sets the closed flag
    /// and notifies the `changed` event so waiting receivers observe the close.
    fn drop(&mut self) {
        let shared = &*self.shared;
        let remaining = shared.tx_count.get() - 1;
        shared.tx_count.set(remaining);
        if remaining != 0 {
            return;
        }
        shared.state.set_closed();
        shared.changed.notify(usize::MAX);
    }
}
impl<T: Default> Default for Sender<T> {
fn default() -> Self {
Self::new(T::default())
}
}
/// Read guard over the channel's current value.
///
/// The guard dereferences to `T` and records whether the value was unseen when
/// it was borrowed. While the guard is alive the value stays immutably borrowed,
/// so publishing to the channel panics.
#[derive(Debug)]
pub struct Guard<'a, T> {
    inner: Ref<'a, T>,
    has_changed: bool,
}
impl<T> Guard<'_, T> {
    /// Returns whether the value was newer than the borrower's last-seen
    /// version when this guard was created.
    pub fn has_changed(&self) -> bool {
        let Self { has_changed, .. } = self;
        *has_changed
    }
}
impl<T> Deref for Guard<'_, T> {
    type Target = T;
    /// Exposes the borrowed channel value.
    #[inline]
    fn deref(&self) -> &Self::Target {
        Deref::deref(&self.inner)
    }
}
impl<T> AsRef<T> for Guard<'_, T> {
#[inline]
fn as_ref(&self) -> &T {
self
}
}
impl<T> borrow::Borrow<T> for Guard<'_, T> {
    /// Same as dereferencing the guard.
    #[inline]
    fn borrow(&self) -> &T {
        &**self
    }
}
/// Receiving half of a single-threaded watch channel.
///
/// Each receiver tracks the version of the value it last marked as seen. It
/// detects changes by comparing that version with the channel's current one.
#[derive(Debug)]
pub struct Receiver<T> {
    shared: Rc<Shared<T>>,
    /// Version of the value this receiver last marked as seen.
    version: Version,
}
impl<T> Receiver<T> {
    /// Borrows the current value without marking it as seen.
    ///
    /// The guard's `has_changed` reports whether the value is newer than the
    /// last-seen version. The value lives in a `RefCell`, so publishing to the
    /// channel panics while the guard is alive.
    #[must_use]
    pub fn borrow(&self) -> Guard<'_, T> {
        let inner = self.shared.value.borrow();
        let new_version = self.shared.state.load().version();
        let has_changed = self.version != new_version;
        Guard { inner, has_changed }
    }
    /// Borrows the current value and marks it as seen.
    ///
    /// The guard's `has_changed` reports whether the value was unseen before
    /// this call.
    #[must_use]
    pub fn borrow_and_update(&mut self) -> Guard<'_, T> {
        let inner = self.shared.value.borrow();
        let new_version = self.shared.state.load().version();
        let has_changed = self.version != new_version;
        self.version = new_version;
        Guard { inner, has_changed }
    }
    /// Reports whether a value newer than the last-seen one is available.
    ///
    /// # Errors
    /// Returns `RecvError::ChannelClosed` once every sender has been dropped,
    /// even if an unseen value is still present.
    pub fn has_changed(&self) -> Result<bool, RecvError> {
        let state = self.shared.state.load();
        if state.is_closed() {
            return Err(RecvError::ChannelClosed);
        }
        let new_version = state.version();
        Ok(self.version != new_version)
    }
    /// Moves the last-seen version back so the next change check reports a change.
    ///
    /// NOTE(review): relies on `Version::decrement` yielding a version that
    /// differs from the current one — confirm in `state`.
    pub fn mark_changed(&mut self) {
        self.version.decrement();
    }
    /// Marks the current value as seen without borrowing it.
    pub fn mark_unchanged(&mut self) {
        let current_version = self.shared.state.load().version();
        self.version = current_version;
    }
    /// Returns `true` if both receivers belong to the same channel.
    pub fn same_channel(&self, other: &Self) -> bool {
        Rc::ptr_eq(&self.shared, &other.shared)
    }
    /// Classifies the channel state relative to this receiver.
    ///
    /// Returns:
    /// - `Some(true)` if an unseen value exists;
    /// - `Some(false)` if there is none and the channel is open;
    /// - `None` if the channel is closed and no unseen value remains.
    #[inline]
    fn load_change(&self) -> Option<bool> {
        let new_state = self.shared.state.load();
        // Check the version before the closed flag. A value published right
        // before the last sender was dropped must still be reported; otherwise
        // `changed` and `wait_for` would return an error and lose it.
        if new_state.version() != self.version {
            return Some(true);
        }
        if new_state.is_closed() {
            return None;
        }
        Some(false)
    }
    /// Waits until a value newer than the last-seen one is available.
    ///
    /// This does **not** mark the value as seen, since it only takes `&self`.
    /// Follow up with [`borrow_and_update`](Self::borrow_and_update) or
    /// [`mark_unchanged`](Self::mark_unchanged); otherwise later calls return
    /// immediately.
    ///
    /// # Errors
    /// Returns `RecvError::ChannelClosed` when every sender has been dropped and
    /// no unseen value remains.
    pub async fn changed(&self) -> Result<(), RecvError> {
        loop {
            match self.load_change() {
                Some(true) => return Ok(()),
                None => return Err(RecvError::ChannelClosed),
                // Every wake-up is re-checked on the next iteration.
                Some(false) => self.shared.changed.listen().await,
            }
        }
    }
    /// Waits until `cond` returns `true` for the current value, then returns a
    /// guard to that value.
    ///
    /// Each value inspected is marked as seen. `cond` is evaluated on the
    /// current value first, so an already-matching value is returned without
    /// waiting. `cond` runs while the value is borrowed, so it must not publish
    /// to this channel.
    ///
    /// # Errors
    /// Returns `RecvError::ChannelClosed` once every sender has been dropped and
    /// the last published value does not satisfy `cond`.
    pub async fn wait_for<F>(&mut self, mut cond: F) -> Result<Guard<'_, T>, RecvError>
    where
        F: FnMut(&T) -> bool,
    {
        loop {
            let inner = self.shared.value.borrow();
            let new_version = self.shared.state.load().version();
            let has_changed = self.version != new_version;
            self.version = new_version;
            if cond(&inner) {
                return Ok(Guard { inner, has_changed });
            }
            // Release the borrow before suspending so senders can publish.
            drop(inner);
            if self.shared.state.load().is_closed() {
                return Err(RecvError::ChannelClosed);
            }
            self.changed().await?;
        }
    }
}
impl<T> Clone for Receiver<T> {
fn clone(&self) -> Self {
let version = self.version;
let shared = self.shared.clone();
shared.rx_count.update(|x| x + 1);
Self { shared, version }
}
}
impl<T> Drop for Receiver<T> {
    /// Unregisters this receiver. Dropping the last one notifies the `closed`
    /// event that `Sender::closed` waits on.
    fn drop(&mut self) {
        let count = &self.shared.rx_count;
        count.set(count.get() - 1);
        if count.get() == 0 {
            self.shared.closed.notify(usize::MAX);
        }
    }
}
/// Creates a watch channel holding `init` and returns its sender and one receiver.
///
/// The receiver starts with `init` marked as seen.
#[must_use]
pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) {
    // `Sender::new` already allocates the shared state and registers the sender;
    // this avoids the redundant `Rc` clone the hand-rolled version made.
    let tx = Sender::new(init);
    let rx = tx.subscribe();
    (tx, rx)
}