use std::{
fmt,
result::Result,
sync::{Arc, Mutex, Weak},
};
/// Receiving half of the single-value channel.
///
/// Keeps the value it most recently adopted and holds a strong reference to
/// the slot into which updaters place newer values.
#[derive(Debug)]
pub struct Receiver<T> {
/// The value handed out by `latest()` and `latest_mut()`.
latest: T,
/// Slot shared with the updaters (they reference it weakly). It is `Some`
/// while a value has been sent but not yet adopted by this receiver.
latest_set: Arc<Mutex<Option<T>>>,
}
impl<T> Receiver<T> {
    /// Adopts the pending value from the shared slot, if an updater stored one.
    ///
    /// A poisoned lock is recovered instead of being ignored: the slot is only
    /// ever replaced or taken as a whole `Option<T>`, so its contents remain
    /// usable, and skipping the read would leave this receiver stuck on a
    /// stale value for the rest of its life.
    fn update_latest(&mut self) {
        // The guard is a temporary of this statement, so the lock is already
        // released when the previous `latest` is dropped by the assignment
        // below. A slow or panicking `Drop` therefore cannot stall updaters
        // or poison the slot.
        let pending = self
            .latest_set
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .take();
        if let Some(value) = pending {
            self.latest = value;
        }
    }

    /// Returns a shared reference to the most recent value.
    ///
    /// Takes `&mut self` because a newer value sent by an updater is first
    /// moved out of the shared slot and into this receiver.
    pub fn latest(&mut self) -> &T {
        self.update_latest();
        &self.latest
    }

    /// Returns a mutable reference to the most recent value.
    ///
    /// Changes made through the reference stay visible until an updater sends
    /// another value, which then replaces the modified one on the next read.
    pub fn latest_mut(&mut self) -> &mut T {
        self.update_latest();
        &mut self.latest
    }

    /// Returns `true` when no weak handle to the shared slot exists, i.e.
    /// every `Updater` (including clones) has been dropped.
    ///
    /// The answer is a snapshot of the weak reference count taken at the time
    /// of the call.
    pub fn has_no_updater(&self) -> bool {
        Arc::weak_count(&self.latest_set) == 0
    }
}
/// Sending half of the single-value channel.
///
/// Only references the shared slot weakly, so an updater does not keep the
/// slot alive by itself.
#[derive(Debug)]
pub struct Updater<T> {
/// Weak handle to the slot owned by the `Receiver`; upgrading it fails once
/// the slot has been dropped.
latest: Weak<Mutex<Option<T>>>,
}
impl<T> Clone for Updater<T> {
fn clone(&self) -> Self {
Updater {
latest: Weak::clone(&self.latest),
}
}
}
/// Error returned by `Updater::update` when the receiving side no longer
/// exists. The value that could not be delivered is handed back in field `0`.
#[derive(PartialEq, Eq, Clone, Copy)]
pub struct NoReceiverError<T>(pub T);
// Hand-written so that `Debug` is available for every payload type, including
// payloads that do not implement `Debug` themselves.
impl<T> fmt::Debug for NoReceiverError<T> {
    /// Writes only the type name; the rejected payload is not printed.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("NoReceiverError")
    }
}
impl<T> fmt::Display for NoReceiverError<T> {
    /// Writes a fixed, human-readable description of the failure.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("receiver has been dropped")
    }
}
// Marks `NoReceiverError` as a standard error type for every payload `T`.
// The empty body relies on the trait's default methods; the `Debug` and
// `Display` impls the trait requires are provided above without bounds on `T`.
impl<T> std::error::Error for NoReceiverError<T> {}
impl<T> Updater<T> {
    /// Stores `value` as the newest value for the receiver to pick up.
    ///
    /// A value that was sent earlier but not yet read by the receiver is
    /// discarded and replaced.
    ///
    /// # Errors
    ///
    /// Returns `NoReceiverError` carrying `value` back to the caller when the
    /// shared slot no longer exists, i.e. the receiver has been dropped.
    pub fn update(&self, value: T) -> Result<(), NoReceiverError<T>> {
        match self.latest.upgrade() {
            Some(slot) => {
                // Recover from a poisoned lock instead of panicking: the slot
                // is only ever replaced or taken as a whole, so its contents
                // stay usable, and the receiver side tolerates poison too.
                let stale = slot
                    .lock()
                    .unwrap_or_else(|poisoned| poisoned.into_inner())
                    .replace(value);
                // The guard was a temporary of the statement above, so the
                // superseded value is dropped with the lock released. A
                // panicking `Drop` for `T` therefore cannot poison the slot.
                drop(stale);
                Ok(())
            }
            None => Err(NoReceiverError(value)),
        }
    }

    /// Returns `true` once the shared slot can no longer be reached, meaning
    /// the receiver has been dropped and every later `update` will fail.
    pub fn has_no_receiver(&self) -> bool {
        self.latest.upgrade().is_none()
    }
}
/// Creates a connected `Receiver` / `Updater` pair whose receiver reports
/// `initial` until the first update arrives.
///
/// The receiver owns the shared slot; the returned updater only holds a weak
/// handle to it.
pub fn channel_starting_with<T>(initial: T) -> (Receiver<T>, Updater<T>) {
    // The slot starts out empty: nothing has been sent yet.
    let slot = Arc::new(Mutex::new(None));
    let updater = Updater {
        latest: Arc::downgrade(&slot),
    };
    let receiver = Receiver {
        latest: initial,
        latest_set: slot,
    };
    (receiver, updater)
}
/// Creates a channel without an initial value.
///
/// The payload is wrapped in `Option`, and the receiver reports `None` until
/// an updater sends something.
pub fn channel<T>() -> (Receiver<Option<T>>, Updater<Option<T>>) {
    let nothing_yet: Option<T> = None;
    channel_starting_with(nothing_yet)
}
// Unit tests for the receiver/updater pair defined in the parent module.
#[cfg(test)]
mod test {
use super::*;
use std::{mem, sync::Barrier, thread};
// The initial value is visible first; a sent value replaces it.
#[test]
fn send_recv_value() {
let (mut recv, send) = channel_starting_with(12);
assert_eq!(recv.latest(), &12);
send.update(123).unwrap();
assert_eq!(recv.latest(), &123);
}
// Same as above with an `Option` payload supplied explicitly.
#[test]
fn send_recv_option() {
let (mut recv, send) = channel_starting_with(None);
assert_eq!(*recv.latest(), None);
send.update(Some(234)).unwrap();
assert_eq!(*recv.latest(), Some(234));
}
// Returns two handles to one two-party barrier, one handle per thread.
fn barrier_pair() -> (Arc<Barrier>, Arc<Barrier>) {
let barrier = Arc::new(Barrier::new(2));
(barrier.clone(), barrier)
}
// A writer thread sends 1..=1000, then 1001..=2000, with barrier waits
// separating the phases, while this thread polls `latest()`.
#[test]
fn concurrent_send_recv() {
let (mut recv, send) = channel_starting_with(0);
let (barrier, barrier2) = barrier_pair();
thread::spawn(move || {
// First wait: both threads start together. The loop sends 1 through 999.
barrier2.wait(); for num in 1..1000 {
send.update(num).unwrap();
}
send.update(1000).unwrap();
// Second wait: the writer pauses here, so 1000 stays the newest value until
// the reader arrives at its own second wait. The loop sends 1001 through 2000.
barrier2.wait(); for num in 1001..2001 {
send.update(num).unwrap();
}
// Third wait: reached only after 2000 has been sent.
barrier2.wait(); });
let mut distinct_recvs = 1;
let mut last_result = *recv.latest();
// First wait, then poll until the value 1000 (end of phase one) is observed.
barrier.wait(); while last_result < 1000 {
let next = *recv.latest();
if next != last_result {
distinct_recvs += 1;
}
last_result = next;
}
assert!(distinct_recvs > 1);
println!("received: {}", distinct_recvs);
assert_eq!(*recv.latest(), 1000);
// Second wait releases the writer's second phase; the third wait returns once
// the writer has finished it, so 2000 must be the newest value.
barrier.wait(); barrier.wait(); assert_eq!(*recv.latest(), 2000);
}
// The writer's update happens between the two barrier waits, both of which this
// thread passes inside the block where `got` was obtained from `latest()`.
#[test]
fn non_blocking_write_during_read() {
let (mut name_get, name) = channel_starting_with("Nothing".to_owned());
let (barrier, barrier2) = barrier_pair();
thread::spawn(move || {
// Wait for the reader to have fetched its value, then write.
barrier2.wait(); name.update("Something".to_owned()).unwrap();
// Signals that the write has completed.
barrier2.wait(); });
{
let got = name_get.latest();
assert_eq!(*got, "Nothing".to_owned());
// First wait lets the writer proceed; the second returns after its update.
barrier.wait(); barrier.wait(); }
let got2 = name_get.latest();
assert_eq!(*got2, "Something".to_owned());
}
// Updating after the receiver was dropped returns the rejected value in the error.
#[test]
fn error_writing_to_dead_reader() {
let (val_get, val) = channel_starting_with(0);
mem::drop(val_get);
assert_eq!(val.update(123), Err(NoReceiverError(123)));
}
// `has_no_receiver` flips to true once the receiver is dropped.
#[test]
fn updater_has_no_receiver() {
let (receiver, updater) = channel_starting_with(0);
assert!(!updater.has_no_receiver());
mem::drop(receiver);
assert!(updater.has_no_receiver());
}
// `has_no_updater` stays false until the last updater clone is dropped.
#[test]
fn receiver_has_no_updater() {
let (receiver, updater) = channel_starting_with(0);
assert!(!receiver.has_no_updater());
let updater2 = updater.clone();
assert!(!receiver.has_no_updater());
mem::drop(updater);
assert!(!receiver.has_no_updater());
mem::drop(updater2);
assert!(receiver.has_no_updater());
}
// In-place edits made through `latest_mut` are visible to a later `latest`.
#[test]
fn latest_mut() {
// The `_` pattern does not bind the updater; this test never sends a value.
let (mut val_get, _) = channel_starting_with("".to_owned());
{
val_get.latest_mut().push_str("hello");
}
assert_eq!(val_get.latest(), "hello");
}
// Payload type without `Clone`, used to show that updaters can still be cloned.
#[derive(Eq, PartialEq, Debug)]
struct Unclonable(u32);
// Both the original updater and its clone deliver values to the same receiver.
#[test]
fn multiple_updaters() {
let (mut val_get, val1) = channel_starting_with(Unclonable(1));
let val2 = val1.clone();
val1.update(Unclonable(2)).unwrap();
assert_eq!(*val_get.latest(), Unclonable(2));
val2.update(Unclonable(3)).unwrap();
assert_eq!(*val_get.latest(), Unclonable(3));
}
// `channel()` starts at `None` and carries `Option` payloads.
#[test]
fn no_args_channel() {
let (mut val_get, val) = channel();
assert_eq!(*val_get.latest(), None);
val.update(Some(123)).unwrap();
assert_eq!(*val_get.latest(), Some(123));
}
// Compile-time check: `expect` needs the error type to implement `Debug`, which
// `NoReceiverError<T>` does even when `T` itself does not.
#[test]
fn unwrap_non_debug() {
struct NotDebug(u8);
let (_val_get, val) = channel_starting_with(NotDebug(0));
val.update(NotDebug(3))
.expect("This should compile even though `NotDebug` is not Debug");
}
}