use crate::TryPopError;
use arc_swap::ArcSwapOption;
use std::{
fmt::Debug,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
};
/// Publishing handle of the queue.
///
/// Values are appended with `push`; new readers are obtained with
/// `subscribe`. Cloning yields another publisher on the same queue.
#[derive(Debug)]
pub struct Pub<T> {
// Link this publisher tries first when appending; `push` moves it past
// every value this publisher has appended.
tail: Link<T>,
// Publisher/subscriber counters shared by all handles of the same queue.
state: Arc<State>,
}
impl<T> Default for Pub<T> {
fn default() -> Self {
let me = Self {
tail: Default::default(),
state: Default::default(),
};
me.state.publishers(1);
me
}
}
impl<T> Clone for Pub<T> {
fn clone(&self) -> Self {
self.state.publishers(1);
Self {
tail: self.tail.clone(),
state: self.state.clone(),
}
}
}
impl<T> Drop for Pub<T> {
    /// Unregisters this publisher from the shared counters.
    ///
    /// `Sub::try_pop` reads this count to tell an empty queue apart from a
    /// finished one.
    fn drop(&mut self) {
        // The counter value before the decrement is of no interest here.
        let _count_before = self.state.publishers(-1);
    }
}
impl<T: Clone> Pub<T> {
    /// Creates a publisher attached to the same queue as `subscriber`.
    ///
    /// The new publisher is registered in the shared counters and starts at
    /// the subscriber's current read position; `push` walks forward from
    /// there to the real end of the chain before appending.
    fn from(subscriber: &Sub<T>) -> Self {
        subscriber.state.publishers(1);
        Pub {
            state: subscriber.state.clone(),
            tail: subscriber.head.clone(),
        }
    }

    /// Creates a subscriber that starts reading at this publisher's current
    /// tail, i.e. after everything this publisher has pushed so far.
    pub fn subscribe(&self) -> Sub<T> {
        Sub::from(self)
    }

    /// Appends `value` to the end of the queue.
    ///
    /// The value is always appended. The return value only reports whether
    /// at least one subscriber was registered when the call started: `true`
    /// if so, `false` otherwise.
    pub fn push(&mut self, value: T) -> bool {
        // Sampled before the append, matching the documented return value.
        let subscribers = self.state.subscribers(0);

        // The expected "empty slot" value for the compare-and-swap.
        let vacant: Option<Arc<(T, Link<T>)>> = None;

        // The new node carries its own (still empty) successor link, which
        // becomes this publisher's tail once the node is installed.
        let new_tail = Link::default();
        let node = Arc::new((value, new_tail.clone()));

        let mut target = self.tail.clone();
        loop {
            let occupant = target
                .next
                .compare_and_swap(&vacant, Some(Arc::clone(&node)));
            match occupant.as_deref() {
                // The slot was empty and now holds our node.
                None => break,
                // Another publisher got there first. The chain is
                // append-only, so resume from the node that beat us rather
                // than re-walking from the stale `self.tail`.
                Some((_, following)) => target = following.last(),
            }
        }

        self.tail = new_tail;
        subscribers != 0
    }
}
/// Subscribing handle of the queue.
///
/// Each subscriber keeps its own read position, so every subscriber gets
/// its own clone of each value it pops with `try_pop`.
#[derive(Debug)]
pub struct Sub<T> {
// Publisher/subscriber counters shared by all handles of the same queue.
state: Arc<State>,
// Link whose slot holds the next value to pop; `try_pop` advances it.
head: Link<T>,
}
impl<T> Default for Sub<T> {
fn default() -> Self {
let me = Self {
head: Default::default(),
state: Default::default(),
};
me.state.subscribers(1);
me
}
}
impl<T> Clone for Sub<T> {
fn clone(&self) -> Self {
let me = Self {
state: self.state.clone(),
head: self.head.clone(),
};
me.state.subscribers(1);
me
}
}
impl<T> Drop for Sub<T> {
    /// Unregisters this subscriber from the shared counters.
    ///
    /// `Pub::push` reads this count to report whether anyone was subscribed.
    fn drop(&mut self) {
        // The counter value before the decrement is of no interest here.
        let _count_before = self.state.subscribers(-1);
    }
}
impl<T: Clone> Sub<T> {
fn from(publisher: &Pub<T>) -> Self {
publisher.state.subscribers(1);
Self {
state: publisher.state.clone(),
head: publisher.tail.clone(),
}
}
pub fn publish(&self) -> Pub<T> {
Pub::from(self)
}
pub fn try_pop(&mut self) -> Result<T, TryPopError> {
let publishers = self.state.publishers(0);
let next = self.head.next.load_full();
let Some((value, link)) = next.as_deref() else {
return match publishers {
0 => Err(TryPopError::Finished),
_ => Err(TryPopError::Empty),
};
};
self.head = link.clone();
Ok(value.clone())
}
}
/// Counters shared by every publisher and subscriber of one queue.
#[derive(Debug, Default)]
struct State {
    /// Number of currently registered subscribers.
    subscribers: AtomicU64,
    /// Number of currently registered publishers.
    publishers: AtomicU64,
}

impl State {
    /// Reads (`change == 0`) or adjusts the publisher counter.
    ///
    /// Returns the current count when `change` is zero, otherwise the count
    /// as it was immediately before the adjustment.
    fn publishers(&self, change: i8) -> u64 {
        Self::adjust(&self.publishers, change)
    }

    /// Reads (`change == 0`) or adjusts the subscriber counter.
    ///
    /// Returns the current count when `change` is zero, otherwise the count
    /// as it was immediately before the adjustment.
    fn subscribers(&self, change: i8) -> u64 {
        Self::adjust(&self.subscribers, change)
    }

    /// Applies a signed `change` to `counter` and returns the prior value
    /// (or the current value when `change` is zero).
    fn adjust(counter: &AtomicU64, change: i8) -> u64 {
        // `unsigned_abs` is well defined for `i8::MIN`, whereas negating
        // `i8::MIN` overflows.
        let magnitude = u64::from(change.unsigned_abs());
        match change {
            0 => counter.load(Ordering::SeqCst),
            1.. => counter.fetch_add(magnitude, Ordering::SeqCst),
            _ => counter.fetch_sub(magnitude, Ordering::SeqCst),
        }
    }
}
/// One hop of the singly linked chain behind the queue.
///
/// Clones of a link share the same slot, because `next` is an `Arc`.
#[derive(Debug)]
struct Link<T> {
// Shared slot: either empty, or a node made of a value and the link to the
// node after it.
next: Arc<ArcSwapOption<(T, Link<T>)>>,
}
impl<T> Clone for Link<T> {
fn clone(&self) -> Self {
Self {
next: self.next.clone(),
}
}
}
impl<T> Default for Link<T> {
fn default() -> Self {
Self {
next: Default::default(),
}
}
}
impl<T> Link<T> {
    /// Follows the chain from this link and returns the first link whose
    /// slot was observed empty.
    ///
    /// Other threads may append concurrently, so the result is only the end
    /// of the chain as seen during the walk.
    fn last(&self) -> Link<T> {
        let mut cursor = self.clone();
        while let Some(following) = cursor
            .next
            .load()
            .as_deref()
            .map(|(_, link)| link.clone())
        {
            cursor = following;
        }
        cursor
    }
}
/// Single-threaded check: a subscriber only sees values pushed after it
/// subscribed, and each subscriber advances independently.
#[test]
fn sync_pub_sub() {
    // Debug helper: dumps a publisher and the address of its tail slot.
    fn dbg_pub<T: Debug>(publ: &Pub<T>, name: impl std::fmt::Display) {
        eprintln!("pub {name} {publ:#?}, {:p}", publ.tail.next);
    }

    // Debug helper: dumps a subscriber, the address of its head slot and the
    // address of the slot after it (a throwaway default slot if there is none).
    fn dbg_sub<T: Debug>(sub: &Sub<T>, name: impl std::fmt::Display) {
        let following = sub
            .head
            .next
            .load_full()
            .as_deref()
            .map(|(_, link)| link.next.clone())
            .unwrap_or_default();
        eprintln!(
            "sub {name} {sub:#?}, {:p} -> {:p}",
            sub.head.next,
            following
        );
    }

    let mut publisher = Pub::default();

    // `early` subscribes before anything is pushed, `late` after the first value.
    let mut early = publisher.subscribe();
    publisher.push(0);
    let mut late = publisher.subscribe();

    dbg_sub(&early, 1);
    dbg_sub(&late, 2);
    dbg_pub(&publisher, 0);

    publisher.push(1);
    publisher.push(2);

    // `late` missed 0 and starts at 1.
    assert_eq!(late.try_pop(), Ok(1));

    // `early` sees the whole sequence, then an empty (not finished) queue
    // because the publisher is still alive.
    for expected in 0..3 {
        assert_eq!(early.try_pop(), Ok(expected));
    }
    assert_eq!(early.try_pop(), Err(TryPopError::Empty));

    assert_eq!(late.try_pop(), Ok(2));
    assert_eq!(late.try_pop(), Err(TryPopError::Empty));
}
/// One publisher, twenty subscriber threads: every subscriber must receive
/// the complete sequence in publication order and then observe `Finished`.
#[test]
fn threads_pub_sub() {
    let mut publisher = Pub::default();

    // Everyone subscribes before the first push so nobody misses a value.
    let subscribers: Vec<_> = (0..20).map(|_| publisher.subscribe()).collect();

    // First half of the sequence is published from the test thread.
    let mut published = vec![];
    for n in 0..5 {
        publisher.push(n);
        published.push(n);
        eprintln!("pub {n}");
    }

    let mut workers = vec![];
    for (i, mut subscriber) in subscribers.into_iter().enumerate() {
        let worker = std::thread::Builder::new()
            .name(format!("t{i}"))
            .spawn(move || {
                // Spin on `try_pop` until the publisher is gone and the
                // queue is drained.
                let mut received = vec![];
                loop {
                    match subscriber.try_pop() {
                        Ok(value) => {
                            eprintln!("t{i} got {value}");
                            received.push(value);
                        }
                        Err(TryPopError::Empty) => eprintln!("t{i} got None"),
                        Err(TryPopError::Finished) => {
                            eprintln!("t{i} finished");
                            break;
                        }
                    }
                }
                received
            })
            .expect("t1");
        workers.push(worker);
    }

    // Second half is published concurrently with the readers; dropping the
    // publisher is what lets the readers finish.
    let publishing = std::thread::Builder::new()
        .name("pub".to_owned())
        .spawn(move || {
            for n in 5..10 {
                publisher.push(n);
                published.push(n);
                eprintln!("pub {n}");
            }
            eprintln!("pub drop");
            drop(publisher);
            published
        })
        .expect("pub");
    workers.push(publishing);

    let expected: Vec<i32> = (0..10).collect();
    for worker in workers {
        let name = worker.thread().name().expect("thread namme").to_owned();
        assert_eq!(worker.join().expect(&name), expected, "{name}");
    }
}
/// Twenty threads each push one value through their own publisher clone and
/// read through their own subscriber: every reader, plus one subscriber on
/// the test thread, must end up with all twenty values (in any order).
#[test]
fn threads_multi_pub_sub() {
    use std::collections::BTreeMap;

    let publisher = Pub::default();
    let mut subscriber = publisher.subscribe();

    let mut workers = vec![];
    for i in 0..20 {
        // Cloned on the test thread, so the publisher count cannot reach
        // zero before every worker's publisher has been registered.
        let mut publisher = publisher.clone();
        let worker = std::thread::Builder::new()
            .name(format!("t{i:02}"))
            .spawn(move || {
                let mut subscriber = publisher.subscribe();
                eprintln!("t{i} push");
                publisher.push(i);
                eprintln!("t{i} pushed");
                drop(publisher);

                // Drain until every publisher is gone.
                let mut received = vec![];
                loop {
                    match subscriber.try_pop() {
                        Ok(value) => {
                            eprintln!("t{i} got {value}");
                            received.push(value);
                        }
                        Err(TryPopError::Empty) => {}
                        Err(TryPopError::Finished) => {
                            eprintln!("t{i} finished");
                            break;
                        }
                    }
                }
                received
            })
            .expect("t1");
        workers.push(worker);
    }
    // Release the test thread's own publisher so readers can finish.
    drop(publisher);

    let mut all = BTreeMap::new();
    for worker in workers {
        let name = worker.thread().name().expect("thread namme").to_owned();
        let values = worker.join().expect(&name);
        eprintln!("{name}: {} {values:?}", values.len());
        all.insert(name, values);
    }

    // The test thread's subscriber is checked like any other reader.
    let check_entry = all.entry("check".into()).or_default();
    while let Ok(popped) = subscriber.try_pop() {
        check_entry.push(popped);
    }
    eprintln!(
        "check: {len} {values:?}",
        len = check_entry.len(),
        values = check_entry.as_slice()
    );

    let expected: Vec<i32> = (0..20).collect();
    for (name, mut values) in all {
        values.sort();
        assert_eq!(values, expected, "{name}");
    }
}