use serde::{Deserialize, Serialize};
use std::{collections::HashSet, fmt, hash::Hash, mem::take, ops::Deref, sync::Arc};
use tokio::sync::{RwLock, RwLockReadGuard, oneshot, watch};
use tracing::Instrument;
use super::{ChangeNotifier, ChangeSender, RecvError, SendError, default_on_err, send_event};
use crate::{exec, prelude::*};
/// A change event describing one modification of an [`ObservableHashSet`].
///
/// Events are broadcast to every subscription of the set and are applied in order by
/// mirrors to reproduce the set's contents.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum HashSetEvent<T> {
    /// A value was inserted or replaced (sent by `insert` and `replace`).
    Set(T),
    /// A value was removed (sent by `remove`, `take` and `retain`).
    Remove(T),
    /// All values were removed (sent by `clear` when the set was non-empty).
    Clear,
    /// The capacity of the set was shrunk (sent by `shrink_to_fit`).
    ShrinkToFit,
    /// The set will not be modified anymore (sent once by `done`).
    Done,
    /// All values of an incremental initial value have been delivered.
    ///
    /// Produced by `HashSetSubscription::recv` after the last incremental `Set` event;
    /// this variant is excluded from serialization (`#[serde(skip)]`).
    #[serde(skip)]
    InitialComplete,
}
/// A hash set that broadcasts every modification as a [`HashSetEvent`] to its subscribers.
///
/// Read access is available through `Deref<Target = HashSet<T>>`; all modifications must go
/// through the methods of this type so that they are observed.
pub struct ObservableHashSet<T, Codec = crate::codec::Default> {
    /// Current contents of the set.
    hs: HashSet<T>,
    /// Broadcast sender used to distribute change events to subscriptions.
    tx: rch::broadcast::Sender<HashSetEvent<T>, Codec>,
    /// Change notifier; `notify` is called on modifications.
    change: ChangeSender,
    /// Error handler passed to `send_event` for every event.
    /// NOTE(review): presumably invoked when broadcasting an event fails — confirm in `send_event`.
    on_err: Arc<dyn Fn(SendError) + Send + Sync>,
    /// Set by `done`; afterwards every modifying method panics.
    done: bool,
}
impl<T, Codec> fmt::Debug for ObservableHashSet<T, Codec>
where
    T: fmt::Debug,
{
    /// Formats only the contained values, exactly as the underlying `HashSet` would.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Debug::fmt(&self.hs, f)
    }
}
impl<T, Codec> From<HashSet<T>> for ObservableHashSet<T, Codec>
where
T: Clone + RemoteSend,
Codec: crate::codec::Codec,
{
fn from(hs: HashSet<T>) -> Self {
let (tx, _rx) = rch::broadcast::channel::<_, _, { rch::DEFAULT_BUFFER }>(1);
Self { hs, tx, change: ChangeSender::new(), on_err: Arc::new(default_on_err), done: false }
}
}
impl<T, Codec> From<ObservableHashSet<T, Codec>> for HashSet<T> {
fn from(ohs: ObservableHashSet<T, Codec>) -> Self {
ohs.hs
}
}
impl<T, Codec> Default for ObservableHashSet<T, Codec>
where
    T: Clone + RemoteSend,
    Codec: crate::codec::Codec,
{
    /// Creates an empty observable hash set.
    fn default() -> Self {
        let empty: HashSet<T> = HashSet::new();
        empty.into()
    }
}
impl<T, Codec> ObservableHashSet<T, Codec>
where
    T: Eq + Hash + Clone + RemoteSend,
    Codec: crate::codec::Codec,
{
    /// Creates an empty observable hash set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the error handler that is passed to `send_event` for every change event.
    ///
    /// Initially `default_on_err` is used.
    pub fn set_error_handler<E>(&mut self, on_err: E)
    where
        E: Fn(SendError) + Send + Sync + 'static,
    {
        self.on_err = Arc::new(on_err);
    }

    /// Subscribes to the set, providing the current contents as a single value.
    ///
    /// The initial value must be fetched with `HashSetSubscription::take_initial` before
    /// calling `recv`. `buffer` is passed to the broadcast subscription. If `done` has
    /// already been called, the subscription carries no event receiver.
    pub fn subscribe(&self, buffer: usize) -> HashSetSubscription<T, Codec> {
        HashSetSubscription::new(
            HashSetInitialValue::new_value(self.hs.clone()),
            if self.done { None } else { Some(self.tx.subscribe(buffer)) },
        )
    }

    /// Subscribes to the set, providing the current contents incrementally.
    ///
    /// `HashSetSubscription::recv` first yields one `HashSetEvent::Set` per current value,
    /// then `HashSetEvent::InitialComplete`, then change events. Send errors while streaming
    /// the initial contents are reported to the current error handler.
    pub fn subscribe_incremental(&self, buffer: usize) -> HashSetSubscription<T, Codec> {
        HashSetSubscription::new(
            HashSetInitialValue::new_incremental(self.hs.clone(), self.on_err.clone()),
            if self.done { None } else { Some(self.tx.subscribe(buffer)) },
        )
    }

    /// Returns the number of receivers of the change-event broadcast channel.
    pub fn subscriber_count(&self) -> usize {
        self.tx.receiver_count()
    }

    /// Returns a notifier that is triggered whenever the set is modified.
    pub fn notifier(&self) -> ChangeNotifier {
        self.change.subscribe()
    }

    /// Inserts a value, returning whether it was not present before.
    ///
    /// A `Set` event is sent even if the value was already present.
    ///
    /// # Panics
    /// If `done` has been called.
    pub fn insert(&mut self, value: T) -> bool {
        self.assert_not_done();
        self.change.notify();
        send_event(&self.tx, &*self.on_err, HashSetEvent::Set(value.clone()));
        self.hs.insert(value)
    }

    /// Inserts a value, replacing an equal one; returns the replaced value, if any.
    ///
    /// A `Set` event is always sent.
    ///
    /// # Panics
    /// If `done` has been called.
    pub fn replace(&mut self, value: T) -> Option<T> {
        self.assert_not_done();
        self.change.notify();
        send_event(&self.tx, &*self.on_err, HashSetEvent::Set(value.clone()));
        self.hs.replace(value)
    }

    /// Removes a value, returning whether it was present.
    ///
    /// Like `HashSet::remove`, the borrowed form `Q` may be unsized, so e.g. a `&str` can be
    /// used for an `ObservableHashSet<String>`. A `Remove` event is only sent if a value was
    /// actually removed.
    ///
    /// # Panics
    /// If `done` has been called.
    pub fn remove<Q>(&mut self, value: &Q) -> bool
    where
        T: std::borrow::Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        self.assert_not_done();
        match self.hs.take(value) {
            Some(v) => {
                self.change.notify();
                send_event(&self.tx, &*self.on_err, HashSetEvent::Remove(v));
                true
            }
            None => false,
        }
    }

    /// Removes and returns a value, if present.
    ///
    /// Like `HashSet::take`, the borrowed form `Q` may be unsized. A `Remove` event is only
    /// sent if a value was actually removed.
    ///
    /// # Panics
    /// If `done` has been called.
    pub fn take<Q>(&mut self, value: &Q) -> Option<T>
    where
        T: std::borrow::Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        self.assert_not_done();
        match self.hs.take(value) {
            Some(v) => {
                self.change.notify();
                send_event(&self.tx, &*self.on_err, HashSetEvent::Remove(v.clone()));
                Some(v)
            }
            None => None,
        }
    }

    /// Removes all values.
    ///
    /// Nothing is notified or sent if the set is already empty.
    ///
    /// # Panics
    /// If `done` has been called.
    pub fn clear(&mut self) {
        self.assert_not_done();
        if !self.hs.is_empty() {
            self.hs.clear();
            self.change.notify();
            send_event(&self.tx, &*self.on_err, HashSetEvent::Clear);
        }
    }

    /// Keeps only the values for which `f` returns `true`.
    ///
    /// A separate `Remove` event is sent for every value that is dropped.
    ///
    /// # Panics
    /// If `done` has been called.
    pub fn retain<F>(&mut self, mut f: F)
    where
        F: FnMut(&T) -> bool,
    {
        self.assert_not_done();
        self.hs.retain(|v| {
            if f(v) {
                true
            } else {
                self.change.notify();
                send_event(&self.tx, &*self.on_err, HashSetEvent::Remove(v.clone()));
                false
            }
        });
    }

    /// Shrinks the capacity of the set and sends a `ShrinkToFit` event.
    ///
    /// The change notifier is not triggered because the contents do not change.
    ///
    /// # Panics
    /// If `done` has been called.
    pub fn shrink_to_fit(&mut self) {
        self.assert_not_done();
        send_event(&self.tx, &*self.on_err, HashSetEvent::ShrinkToFit);
        self.hs.shrink_to_fit()
    }

    /// Panics if `done` has been called.
    fn assert_not_done(&self) {
        if self.done {
            panic!("observable hash set cannot be changed after done has been called");
        }
    }

    /// Marks the set as final and sends a `Done` event (only on the first call).
    ///
    /// Afterwards all modifying methods panic and new subscriptions carry no event receiver.
    pub fn done(&mut self) {
        if !self.done {
            send_event(&self.tx, &*self.on_err, HashSetEvent::Done);
            self.done = true;
        }
    }

    /// Returns whether `done` has been called.
    pub fn is_done(&self) -> bool {
        self.done
    }

    /// Consumes the observable set and returns the underlying hash set.
    pub fn into_inner(self) -> HashSet<T> {
        self.into()
    }
}
impl<T, Codec> Deref for ObservableHashSet<T, Codec> {
    type Target = HashSet<T>;

    /// Gives read-only access to the wrapped hash set.
    fn deref(&self) -> &Self::Target {
        let Self { hs, .. } = self;
        hs
    }
}
impl<T, Codec> Extend<T> for ObservableHashSet<T, Codec>
where
    T: RemoteSend + Eq + Hash + Clone,
    Codec: crate::codec::Codec,
{
    /// Inserts every value through `insert`, so each one produces its own `Set` event.
    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
        iter.into_iter().for_each(|value| {
            self.insert(value);
        });
    }
}
/// State of a mirrored hash set, shared between a `MirroredHashSet` and its update task.
struct MirroredHashSetInner<T> {
    /// Mirrored contents.
    hs: HashSet<T>,
    /// Whether the initial value has been received completely.
    complete: bool,
    /// Whether a `Done` event has been applied (or the subscription was already done).
    done: bool,
    /// Error that stopped the update task, if any.
    error: Option<RecvError>,
    /// Maximum number of elements; exceeding it yields `RecvError::MaxSizeExceeded`.
    max_size: usize,
}
impl<T> MirroredHashSetInner<T>
where
    T: Eq + Hash,
{
    /// Applies one change event to the mirrored state.
    ///
    /// Returns `RecvError::MaxSizeExceeded` when a `Set` event grows the set beyond
    /// `max_size`; the offending value has already been inserted at that point.
    fn handle_event(&mut self, event: HashSetEvent<T>) -> Result<(), RecvError> {
        match event {
            HashSetEvent::Set(value) => {
                self.hs.insert(value);
                if self.hs.len() > self.max_size {
                    return Err(RecvError::MaxSizeExceeded(self.max_size));
                }
            }
            HashSetEvent::Remove(value) => {
                self.hs.remove(&value);
            }
            HashSetEvent::Clear => self.hs.clear(),
            HashSetEvent::ShrinkToFit => self.hs.shrink_to_fit(),
            HashSetEvent::InitialComplete => self.complete = true,
            HashSetEvent::Done => self.done = true,
        }
        Ok(())
    }
}
/// Initial contents delivered to a hash set subscription.
#[derive(Debug, Serialize, Deserialize)]
#[serde(bound(serialize = "T: RemoteSend + Eq + Hash, Codec: crate::codec::Codec"))]
#[serde(bound(deserialize = "T: RemoteSend + Eq + Hash, Codec: crate::codec::Codec"))]
enum HashSetInitialValue<T, Codec = crate::codec::Default> {
    /// The complete contents, transferred as one value.
    Value(HashSet<T>),
    /// The contents, streamed element by element.
    Incremental {
        /// Number of elements not yet received.
        len: usize,
        /// Receiver of the streamed elements.
        rx: rch::mpsc::Receiver<T, Codec>,
    },
}
impl<T, Codec> HashSetInitialValue<T, Codec>
where
    T: RemoteSend + Eq + Hash + Clone,
    Codec: crate::codec::Codec,
{
    /// Wraps the complete contents for transfer as a single value.
    fn new_value(hs: HashSet<T>) -> Self {
        Self::Value(hs)
    }

    /// Spawns a task that streams the elements of `hs` over an mpsc channel.
    ///
    /// Streaming stops early when the receiver disconnects; any other send error is
    /// converted and passed to `on_err`. Returns the element count and the receiving half.
    fn new_incremental(hs: HashSet<T>, on_err: Arc<dyn Fn(SendError) + Send + Sync>) -> Self {
        let len = hs.len();
        let (tx, rx) = rch::mpsc::channel(128);
        let feeder = async move {
            for value in hs {
                if let Err(err) = tx.send(value).await {
                    if err.is_disconnected() {
                        break;
                    }
                    match err.try_into() {
                        Ok(err) => (on_err)(err),
                        Err(_) => unreachable!(),
                    }
                }
            }
        };
        exec::spawn(feeder.in_current_span());
        Self::Incremental { len, rx }
    }
}
/// Subscription to the contents and changes of an observed or mirrored hash set.
///
/// The initial contents are obtained with `take_initial` (non-incremental subscriptions)
/// or through `recv` (incremental subscriptions); afterwards `recv` yields change events.
#[derive(Debug, Serialize, Deserialize)]
#[serde(bound(serialize = "T: RemoteSend + Eq + Hash, Codec: crate::codec::Codec"))]
#[serde(bound(deserialize = "T: RemoteSend + Eq + Hash, Codec: crate::codec::Codec"))]
pub struct HashSetSubscription<T, Codec = crate::codec::Default> {
    /// Contents of the hash set at the time of subscription.
    initial: HashSetInitialValue<T, Codec>,
    /// Whether the initial value has been obtained completely; not serialized.
    #[serde(skip, default)]
    complete: bool,
    /// Change-event receiver; `None` if the set was already done when subscribing,
    /// or once a `Done` event has been received from it.
    events: Option<rch::broadcast::Receiver<HashSetEvent<T>, Codec>>,
    /// Whether `recv` has already returned the `Done` event; not serialized.
    #[serde(skip, default)]
    done: bool,
}
impl<T, Codec> HashSetSubscription<T, Codec>
where
    T: RemoteSend + Eq + Hash + Clone,
    Codec: crate::codec::Codec,
{
    /// Creates a subscription from its initial value and an optional event receiver.
    fn new(
        initial: HashSetInitialValue<T, Codec>, events: Option<rch::broadcast::Receiver<HashSetEvent<T>, Codec>>,
    ) -> Self {
        Self { initial, complete: false, events, done: false }
    }

    /// Returns `true` if the initial value is delivered element by element through `recv`.
    pub fn is_incremental(&self) -> bool {
        matches!(self.initial, HashSetInitialValue::Incremental { .. })
    }

    /// Returns `true` once the initial value has been obtained completely: by
    /// `take_initial` for non-incremental subscriptions, or after `recv` has returned
    /// `HashSetEvent::InitialComplete` for incremental ones.
    pub fn is_complete(&self) -> bool {
        self.complete
    }

    /// Returns `true` if no more change events will arrive: there is no event receiver
    /// (never attached, or `Done` already received from it), or `recv` returned `Done`.
    pub fn is_done(&self) -> bool {
        self.events.is_none() || self.done
    }

    /// Takes the initial value of a non-incremental subscription.
    ///
    /// Returns `None` for incremental subscriptions and on every call after the first.
    pub fn take_initial(&mut self) -> Option<HashSet<T>> {
        match &mut self.initial {
            HashSetInitialValue::Value(value) if !self.complete => {
                self.complete = true;
                Some(take(value))
            }
            _ => None,
        }
    }

    /// Receives the next event.
    ///
    /// For an incremental subscription this first yields one `Set` event per initial
    /// element, followed by `InitialComplete`. Then change events are forwarded. A single
    /// `Done` event is returned when the observed set is done (or when there is no event
    /// receiver); every call after that returns `Ok(None)`.
    ///
    /// # Errors
    /// Receive errors of the underlying channels are propagated. `RecvError::Closed` is
    /// returned if the incremental stream ends before all announced elements arrived.
    ///
    /// # Panics
    /// If called on a non-incremental subscription before `take_initial`.
    pub async fn recv(&mut self) -> Result<Option<HashSetEvent<T>>, RecvError> {
        // Deliver the initial value before any change event.
        if !self.complete {
            match &mut self.initial {
                HashSetInitialValue::Incremental { len, rx } => {
                    if *len > 0 {
                        match rx.recv().await? {
                            Some(v) => {
                                *len -= 1;
                                return Ok(Some(HashSetEvent::Set(v)));
                            }
                            // The sender stopped before all `len` elements were delivered.
                            None => return Err(RecvError::Closed),
                        }
                    } else {
                        // All announced elements were received.
                        self.complete = true;
                        return Ok(Some(HashSetEvent::InitialComplete));
                    }
                }
                HashSetInitialValue::Value(_) => {
                    panic!("take_initial must be called before recv for non-incremental subscription");
                }
            }
        }
        // Forward change events; a received `Done` detaches the receiver and falls through
        // so that the `Done` event is produced below exactly once.
        if let Some(rx) = &mut self.events {
            match rx.recv().await? {
                HashSetEvent::Done => self.events = None,
                evt => return Ok(Some(evt)),
            }
        }
        // No receiver left: report `Done` once, then end of stream.
        if self.done {
            Ok(None)
        } else {
            self.done = true;
            Ok(Some(HashSetEvent::Done))
        }
    }
}
impl<T, Codec> HashSetSubscription<T, Codec>
where
    T: RemoteSend + Eq + Hash + Clone + Sync,
    Codec: crate::codec::Codec,
{
    /// Mirrors the subscribed hash set locally, consuming the subscription.
    ///
    /// The initial value (if not taken yet) seeds the mirror. A spawned task then applies
    /// every event returned by `recv` to the mirrored state and re-broadcasts the event to
    /// subscribers of the mirror. The task stops when:
    /// - the returned [`MirroredHashSet`] is dropped or detached,
    /// - a `Done` event has been applied or the event stream ends, or
    /// - an error occurs, including more than `max_size` elements. The error is stored and
    ///   reported by `MirroredHashSet::borrow`.
    pub fn mirror(mut self, max_size: usize) -> MirroredHashSet<T, Codec> {
        let (tx, _rx) = rch::broadcast::channel::<_, _, { rch::DEFAULT_BUFFER }>(1);
        let (changed_tx, changed_rx) = watch::channel(());
        let (dropped_tx, mut dropped_rx) = oneshot::channel();

        // Seed the mirror with the initial value, if it has not been taken already.
        let inner = Arc::new(RwLock::new(Some(MirroredHashSetInner {
            hs: self.take_initial().unwrap_or_default(),
            complete: self.is_complete(),
            done: self.is_done(),
            error: None,
            max_size,
        })));
        let inner_task = inner.clone();
        let tx_send = tx.clone();

        exec::spawn(
            async move {
                loop {
                    // Stop as soon as the mirror is dropped (its `_dropped_tx` goes away).
                    let event = tokio::select! {
                        event = self.recv() => event,
                        _ = &mut dropped_rx => return,
                    };

                    // The write lock is held while signalling, forwarding and applying, so
                    // readers and mirror subscribers observe events atomically.
                    let mut inner = inner_task.write().await;
                    let inner = match inner.as_mut() {
                        Some(inner) => inner,
                        // The state was taken by `detach`.
                        None => return,
                    };

                    changed_tx.send_replace(());

                    match event {
                        Ok(Some(event)) => {
                            if tx_send.receiver_count() > 0 {
                                let _ = tx_send.send(event.clone());
                            }
                            if let Err(err) = inner.handle_event(event) {
                                inner.error = Some(err);
                                return;
                            }
                            if inner.done {
                                break;
                            }
                        }
                        Ok(None) => break,
                        Err(err) => {
                            inner.error = Some(err);
                            return;
                        }
                    }
                }
            }
            .in_current_span(),
        );

        MirroredHashSet { inner, tx, changed_rx, _dropped_tx: dropped_tx }
    }
}
/// A local mirror of a hash set, kept up to date by the task spawned in
/// `HashSetSubscription::mirror`.
pub struct MirroredHashSet<T, Codec = crate::codec::Default> {
    /// Mirrored state shared with the update task; `None` after `detach`.
    inner: Arc<RwLock<Option<MirroredHashSetInner<T>>>>,
    /// Sender used by the update task to re-broadcast events to subscribers of the mirror.
    tx: rch::broadcast::Sender<HashSetEvent<T>, Codec>,
    /// Signalled by the update task every time it receives an event or error.
    changed_rx: watch::Receiver<()>,
    /// Dropped together with the mirror, which makes the update task return.
    _dropped_tx: oneshot::Sender<()>,
}
impl<T, Codec> fmt::Debug for MirroredHashSet<T, Codec> {
    /// Prints only the type name; the contents live behind an async lock.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("MirroredHashSet");
        builder.finish()
    }
}
impl<T, Codec> MirroredHashSet<T, Codec>
where
T: RemoteSend + Eq + Hash + Clone,
Codec: crate::codec::Codec,
{
pub async fn borrow(&self) -> Result<MirroredHashSetRef<'_, T>, RecvError> {
let inner = self.inner.read().await;
let inner = RwLockReadGuard::map(inner, |inner| inner.as_ref().unwrap());
match &inner.error {
None => Ok(MirroredHashSetRef(inner)),
Some(err) => Err(err.clone()),
}
}
pub async fn borrow_and_update(&mut self) -> Result<MirroredHashSetRef<'_, T>, RecvError> {
let inner = self.inner.read().await;
self.changed_rx.borrow_and_update();
let inner = RwLockReadGuard::map(inner, |inner| inner.as_ref().unwrap());
match &inner.error {
None => Ok(MirroredHashSetRef(inner)),
Some(err) => Err(err.clone()),
}
}
pub async fn detach(self) -> HashSet<T> {
let mut inner = self.inner.write().await;
inner.take().unwrap().hs
}
pub async fn changed(&mut self) {
let _ = self.changed_rx.changed().await;
}
pub async fn subscribe(&self, buffer: usize) -> Result<HashSetSubscription<T, Codec>, RecvError> {
let view = self.borrow().await?;
let initial = view.clone();
let events = if view.is_done() { None } else { Some(self.tx.subscribe(buffer)) };
Ok(HashSetSubscription::new(HashSetInitialValue::new_value(initial), events))
}
pub async fn subscribe_incremental(&self, buffer: usize) -> Result<HashSetSubscription<T, Codec>, RecvError> {
let view = self.borrow().await?;
let initial = view.clone();
let events = if view.is_done() { None } else { Some(self.tx.subscribe(buffer)) };
Ok(HashSetSubscription::new(
HashSetInitialValue::new_incremental(initial, Arc::new(default_on_err)),
events,
))
}
}
impl<T, Codec> Drop for MirroredHashSet<T, Codec> {
    fn drop(&mut self) {
        // Intentionally empty. The fields are dropped afterwards; dropping `_dropped_tx`
        // makes the update task spawned by `mirror` return.
        // Having a `Drop` impl also forbids moving fields out of `MirroredHashSet`.
        // NOTE(review): presumably the impl exists for that or for drop-order reasons — confirm.
    }
}
/// Read-only view of a mirrored hash set, holding a read lock on the mirrored state.
///
/// The mirror's update task cannot apply changes while this view is alive.
pub struct MirroredHashSetRef<'a, T>(RwLockReadGuard<'a, MirroredHashSetInner<T>>);
impl<T> MirroredHashSetRef<'_, T> {
pub fn is_complete(&self) -> bool {
self.0.complete
}
pub fn is_done(&self) -> bool {
self.0.done
}
}
impl<T> fmt::Debug for MirroredHashSetRef<'_, T>
where
    T: fmt::Debug,
{
    /// Formats the mirrored values exactly as the underlying `HashSet` would.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Debug::fmt(&self.0.hs, f)
    }
}
impl<T> Deref for MirroredHashSetRef<'_, T> {
type Target = HashSet<T>;
fn deref(&self) -> &Self::Target {
&self.0.hs
}
}
impl<T> Drop for MirroredHashSetRef<'_, T> {
    fn drop(&mut self) {
        // Intentionally empty: the read lock is released when the contained guard is
        // dropped right after this.
        // NOTE(review): presumably kept for drop-order reasons — confirm.
    }
}