1use remoc::prelude::*;
22use serde::{Deserialize, Serialize};
23use std::{collections::HashSet, fmt, hash::Hash, mem::take, ops::Deref, sync::Arc};
24use tokio::sync::{oneshot, watch, RwLock, RwLockReadGuard};
25
26use crate::{default_on_err, send_event, ChangeNotifier, ChangeSender, RecvError, SendError};
27
/// A change event of an [`ObservableHashSet`], delivered to its subscribers.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum HashSetEvent<T> {
    /// All initial values of an incremental subscription have been received.
    ///
    /// Generated locally by `HashSetSubscription::recv`; skipped by serde and
    /// therefore never transmitted.
    #[serde(skip)]
    InitialComplete,
    /// A value was inserted (emitted by `insert` and `replace`, and for each
    /// initial value of an incremental subscription).
    Set(T),
    /// A value was removed (emitted by `remove`, `take` and `retain`).
    Remove(T),
    /// All values were removed.
    Clear,
    /// The capacity was shrunk to fit the contents.
    ShrinkToFit,
    /// The set has reached its final state; no further events follow.
    Done,
}
47
/// A hash set that broadcasts every modification as a [`HashSetEvent`].
///
/// Read access is provided through `Deref<Target = HashSet<T>>`; there is no
/// mutable deref, so all mutations go through the methods of this type and are
/// thereby reported to subscribers.
pub struct ObservableHashSet<T, Codec = remoc::codec::Default> {
    /// Current contents.
    hs: HashSet<T>,
    /// Broadcast sender that delivers change events to subscribers.
    tx: rch::broadcast::Sender<HashSetEvent<T>, Codec>,
    /// Local change notifier handed out by `notifier()`.
    change: ChangeSender,
    /// Handler invoked when broadcasting an event fails.
    on_err: Arc<dyn Fn(SendError) + Send + Sync>,
    /// Set by `done()`; any later mutation panics.
    done: bool,
}
59
60impl<T, Codec> fmt::Debug for ObservableHashSet<T, Codec>
61where
62 T: fmt::Debug,
63{
64 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
65 self.hs.fmt(f)
66 }
67}
68
69impl<T, Codec> From<HashSet<T>> for ObservableHashSet<T, Codec>
70where
71 T: Clone + RemoteSend,
72 Codec: remoc::codec::Codec,
73{
74 fn from(hs: HashSet<T>) -> Self {
75 let (tx, _rx) = rch::broadcast::channel::<_, _, rch::buffer::Default>(1);
76 Self { hs, tx, change: ChangeSender::new(), on_err: Arc::new(default_on_err), done: false }
77 }
78}
79
80impl<T, Codec> From<ObservableHashSet<T, Codec>> for HashSet<T> {
81 fn from(ohs: ObservableHashSet<T, Codec>) -> Self {
82 ohs.hs
83 }
84}
85
86impl<T, Codec> Default for ObservableHashSet<T, Codec>
87where
88 T: Clone + RemoteSend,
89
90 Codec: remoc::codec::Codec,
91{
92 fn default() -> Self {
93 Self::from(HashSet::new())
94 }
95}
96
97impl<T, Codec> ObservableHashSet<T, Codec>
98where
99 T: Eq + Hash + Clone + RemoteSend,
100 Codec: remoc::codec::Codec,
101{
102 pub fn new() -> Self {
104 Self::default()
105 }
106
107 pub fn set_error_handler<E>(&mut self, on_err: E)
110 where
111 E: Fn(SendError) + Send + Sync + 'static,
112 {
113 self.on_err = Arc::new(on_err);
114 }
115
116 pub fn subscribe(&self, buffer: usize) -> HashSetSubscription<T, Codec> {
123 HashSetSubscription::new(
124 HashSetInitialValue::new_value(self.hs.clone()),
125 if self.done { None } else { Some(self.tx.subscribe(buffer)) },
126 )
127 }
128
129 pub fn subscribe_incremental(&self, buffer: usize) -> HashSetSubscription<T, Codec> {
137 HashSetSubscription::new(
138 HashSetInitialValue::new_incremental(self.hs.clone(), self.on_err.clone()),
139 if self.done { None } else { Some(self.tx.subscribe(buffer)) },
140 )
141 }
142
143 pub fn subscriber_count(&self) -> usize {
145 self.tx.receiver_count()
146 }
147
148 pub fn notifier(&self) -> ChangeNotifier {
151 self.change.subscribe()
152 }
153
154 pub fn insert(&mut self, value: T) -> bool {
163 self.assert_not_done();
164 self.change.notify();
165
166 send_event(&self.tx, &*self.on_err, HashSetEvent::Set(value.clone()));
167 self.hs.insert(value)
168 }
169
170 pub fn replace(&mut self, value: T) -> Option<T> {
179 self.assert_not_done();
180 self.change.notify();
181
182 send_event(&self.tx, &*self.on_err, HashSetEvent::Set(value.clone()));
183 self.hs.replace(value)
184 }
185
186 pub fn remove<Q>(&mut self, value: &Q) -> bool
195 where
196 T: std::borrow::Borrow<Q>,
197 Q: Hash + Eq,
198 {
199 self.assert_not_done();
200
201 match self.hs.take(value) {
202 Some(v) => {
203 self.change.notify();
204 send_event(&self.tx, &*self.on_err, HashSetEvent::Remove(v));
205 true
206 }
207 None => false,
208 }
209 }
210
211 pub fn take<Q>(&mut self, value: &Q) -> Option<T>
220 where
221 T: std::borrow::Borrow<Q>,
222 Q: Hash + Eq,
223 {
224 self.assert_not_done();
225
226 match self.hs.take(value) {
227 Some(v) => {
228 self.change.notify();
229 send_event(&self.tx, &*self.on_err, HashSetEvent::Remove(v.clone()));
230 Some(v)
231 }
232 None => None,
233 }
234 }
235
236 pub fn clear(&mut self) {
243 self.assert_not_done();
244
245 if !self.hs.is_empty() {
246 self.hs.clear();
247 self.change.notify();
248 send_event(&self.tx, &*self.on_err, HashSetEvent::Clear);
249 }
250 }
251
252 pub fn retain<F>(&mut self, mut f: F)
259 where
260 F: FnMut(&T) -> bool,
261 {
262 self.assert_not_done();
263
264 self.hs.retain(|v| {
265 if f(v) {
266 true
267 } else {
268 self.change.notify();
269 send_event(&self.tx, &*self.on_err, HashSetEvent::Remove(v.clone()));
270 false
271 }
272 });
273 }
274
275 pub fn shrink_to_fit(&mut self) {
282 self.assert_not_done();
283 send_event(&self.tx, &*self.on_err, HashSetEvent::ShrinkToFit);
284 self.hs.shrink_to_fit()
285 }
286
287 fn assert_not_done(&self) {
289 if self.done {
290 panic!("observable hash set cannot be changed after done has been called");
291 }
292 }
293
294 pub fn done(&mut self) {
300 if !self.done {
301 send_event(&self.tx, &*self.on_err, HashSetEvent::Done);
302 self.done = true;
303 }
304 }
305
306 pub fn is_done(&self) -> bool {
311 self.done
312 }
313
314 pub fn into_inner(self) -> HashSet<T> {
319 self.into()
320 }
321}
322
323impl<T, Codec> Deref for ObservableHashSet<T, Codec> {
324 type Target = HashSet<T>;
325
326 fn deref(&self) -> &Self::Target {
327 &self.hs
328 }
329}
330
331impl<T, Codec> Extend<T> for ObservableHashSet<T, Codec>
332where
333 T: RemoteSend + Eq + Hash + Clone,
334 Codec: remoc::codec::Codec,
335{
336 fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
337 for value in iter {
338 self.insert(value);
339 }
340 }
341}
342
/// Mirrored state shared between a `MirroredHashSet` and its mirroring task.
struct MirroredHashSetInner<T> {
    /// Mirrored contents.
    hs: HashSet<T>,
    /// Whether the initial value has been fully received.
    complete: bool,
    /// Whether the mirrored set has reached its final state.
    done: bool,
    /// Error that stopped the mirroring task, if any; reported by `borrow`.
    error: Option<RecvError>,
    /// Maximum number of values; exceeding it yields `RecvError::MaxSizeExceeded`.
    max_size: usize,
}
350
351impl<T> MirroredHashSetInner<T>
352where
353 T: Eq + Hash,
354{
355 fn handle_event(&mut self, event: HashSetEvent<T>) -> Result<(), RecvError> {
356 match event {
357 HashSetEvent::InitialComplete => {
358 self.complete = true;
359 }
360 HashSetEvent::Set(v) => {
361 self.hs.insert(v);
362 if self.hs.len() > self.max_size {
363 return Err(RecvError::MaxSizeExceeded(self.max_size));
364 }
365 }
366 HashSetEvent::Remove(k) => {
367 self.hs.remove(&k);
368 }
369 HashSetEvent::Clear => {
370 self.hs.clear();
371 }
372 HashSetEvent::ShrinkToFit => {
373 self.hs.shrink_to_fit();
374 }
375 HashSetEvent::Done => {
376 self.done = true;
377 }
378 }
379 Ok(())
380 }
381}
382
/// Initial value carried by a `HashSetSubscription`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(bound(serialize = "T: RemoteSend + Eq + Hash, Codec: remoc::codec::Codec"))]
#[serde(bound(deserialize = "T: RemoteSend + Eq + Hash, Codec: remoc::codec::Codec"))]
enum HashSetInitialValue<T, Codec = remoc::codec::Default> {
    /// A complete copy of the set.
    Value(HashSet<T>),
    /// The values are streamed one by one over a channel.
    Incremental {
        /// Number of values not yet received; decremented by `recv`.
        len: usize,
        /// Receiver of the streamed values.
        rx: rch::mpsc::Receiver<T, Codec>,
    },
}
398
399impl<T, Codec> HashSetInitialValue<T, Codec>
400where
401 T: RemoteSend + Eq + Hash + Clone,
402 Codec: remoc::codec::Codec,
403{
404 fn new_value(hs: HashSet<T>) -> Self {
406 Self::Value(hs)
407 }
408
409 fn new_incremental(hs: HashSet<T>, on_err: Arc<dyn Fn(SendError) + Send + Sync>) -> Self {
411 let (tx, rx) = rch::mpsc::channel(128);
412 let len = hs.len();
413
414 tokio::spawn(async move {
415 for v in hs.into_iter() {
416 match tx.send(v).await {
417 Ok(()) => (),
418 Err(err) if err.is_disconnected() => break,
419 Err(err) => match err.try_into() {
420 Ok(err) => (on_err)(err),
421 Err(_) => unreachable!(),
422 },
423 }
424 }
425 });
426
427 Self::Incremental { len, rx }
428 }
429}
430
/// A subscription to an observable hash set: an initial value followed by
/// change events.
///
/// It is serializable and can therefore be sent over a remoc channel.
/// Consume it with `take_initial` (non-incremental) and `recv`, or turn it into
/// a local copy with `mirror`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(bound(serialize = "T: RemoteSend + Eq + Hash, Codec: remoc::codec::Codec"))]
#[serde(bound(deserialize = "T: RemoteSend + Eq + Hash, Codec: remoc::codec::Codec"))]
pub struct HashSetSubscription<T, Codec = remoc::codec::Default> {
    /// Initial value: a full copy or an incremental stream.
    initial: HashSetInitialValue<T, Codec>,
    /// Whether the initial value has been fully consumed (not transmitted;
    /// `false` after deserialization).
    #[serde(skip, default)]
    complete: bool,
    /// Change events; `None` if the observed set was already done when
    /// subscribing, or after a `Done` event has been received.
    events: Option<rch::broadcast::Receiver<HashSetEvent<T>, Codec>>,
    /// Whether `recv` has returned the final `Done` event (not transmitted;
    /// `false` after deserialization).
    #[serde(skip, default)]
    done: bool,
}
457
458impl<T, Codec> HashSetSubscription<T, Codec>
459where
460 T: RemoteSend + Eq + Hash + Clone,
461 Codec: remoc::codec::Codec,
462{
463 fn new(
464 initial: HashSetInitialValue<T, Codec>, events: Option<rch::broadcast::Receiver<HashSetEvent<T>, Codec>>,
465 ) -> Self {
466 Self { initial, complete: false, events, done: false }
467 }
468
469 pub fn is_incremental(&self) -> bool {
471 matches!(self.initial, HashSetInitialValue::Incremental { .. })
472 }
473
474 pub fn is_complete(&self) -> bool {
478 self.complete
479 }
480
481 pub fn is_done(&self) -> bool {
484 self.events.is_none() || self.done
485 }
486
487 pub fn take_initial(&mut self) -> Option<HashSet<T>> {
496 match &mut self.initial {
497 HashSetInitialValue::Value(value) if !self.complete => {
498 self.complete = true;
499 Some(take(value))
500 }
501 _ => None,
502 }
503 }
504
505 pub async fn recv(&mut self) -> Result<Option<HashSetEvent<T>>, RecvError> {
511 if !self.complete {
513 match &mut self.initial {
514 HashSetInitialValue::Incremental { len, rx } => {
515 if *len > 0 {
516 match rx.recv().await? {
517 Some(v) => {
518 *len -= 1;
520 return Ok(Some(HashSetEvent::Set(v)));
521 }
522 None => return Err(RecvError::Closed),
523 }
524 } else {
525 self.complete = true;
527 return Ok(Some(HashSetEvent::InitialComplete));
528 }
529 }
530 HashSetInitialValue::Value(_) => {
531 panic!("take_initial must be called before recv for non-incremental subscription");
532 }
533 }
534 }
535
536 if let Some(rx) = &mut self.events {
538 match rx.recv().await? {
539 HashSetEvent::Done => self.events = None,
540 evt => return Ok(Some(evt)),
541 }
542 }
543
544 if self.done {
546 Ok(None)
547 } else {
548 self.done = true;
549 Ok(Some(HashSetEvent::Done))
550 }
551 }
552}
553
554impl<T, Codec> HashSetSubscription<T, Codec>
555where
556 T: RemoteSend + Eq + Hash + Clone + RemoteSend + Sync,
557 Codec: remoc::codec::Codec,
558{
559 pub fn mirror(mut self, max_size: usize) -> MirroredHashSet<T, Codec> {
565 let (tx, _rx) = rch::broadcast::channel::<_, _, rch::buffer::Default>(1);
566 let (changed_tx, changed_rx) = watch::channel(());
567 let (dropped_tx, mut dropped_rx) = oneshot::channel();
568
569 let inner = Arc::new(RwLock::new(Some(MirroredHashSetInner {
571 hs: self.take_initial().unwrap_or_default(),
572 complete: self.is_complete(),
573 done: self.is_done(),
574 error: None,
575 max_size,
576 })));
577 let inner_task = inner.clone();
578
579 let tx_send = tx.clone();
581 tokio::spawn(async move {
582 loop {
583 let event = tokio::select! {
584 event = self.recv() => event,
585 _ = &mut dropped_rx => return,
586 };
587
588 let mut inner = inner_task.write().await;
589 let mut inner = match inner.as_mut() {
590 Some(inner) => inner,
591 None => return,
592 };
593
594 changed_tx.send_replace(());
595
596 match event {
597 Ok(Some(event)) => {
598 if tx_send.receiver_count() > 0 {
599 let _ = tx_send.send(event.clone());
600 }
601
602 if let Err(err) = inner.handle_event(event) {
603 inner.error = Some(err);
604 return;
605 }
606
607 if inner.done {
608 break;
609 }
610 }
611 Ok(None) => break,
612 Err(err) => {
613 inner.error = Some(err);
614 return;
615 }
616 }
617 }
618 });
619
620 MirroredHashSet { inner, tx, changed_rx, _dropped_tx: dropped_tx }
621 }
622}
623
/// A local copy of an observable hash set, kept up to date by a background task.
///
/// Created by `HashSetSubscription::mirror`.
pub struct MirroredHashSet<T, Codec = remoc::codec::Default> {
    /// Mirrored state; `None` once `detach` has taken it.
    inner: Arc<RwLock<Option<MirroredHashSetInner<T>>>>,
    /// Forwards received events to subscribers of this mirror.
    tx: rch::broadcast::Sender<HashSetEvent<T>, Codec>,
    /// Signalled by the mirroring task for every received event.
    changed_rx: watch::Receiver<()>,
    /// Dropping this sender makes the mirroring task stop.
    _dropped_tx: oneshot::Sender<()>,
}
631
632impl<T, Codec> fmt::Debug for MirroredHashSet<T, Codec> {
633 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
634 f.debug_struct("MirroredHashSet").finish()
635 }
636}
637
638impl<T, Codec> MirroredHashSet<T, Codec>
639where
640 T: RemoteSend + Eq + Hash + Clone,
641 Codec: remoc::codec::Codec,
642{
643 pub async fn borrow(&self) -> Result<MirroredHashSetRef<'_, T>, RecvError> {
653 let inner = self.inner.read().await;
654 let inner = RwLockReadGuard::map(inner, |inner| inner.as_ref().unwrap());
655 match &inner.error {
656 None => Ok(MirroredHashSetRef(inner)),
657 Some(err) => Err(err.clone()),
658 }
659 }
660
661 pub async fn borrow_and_update(&mut self) -> Result<MirroredHashSetRef<'_, T>, RecvError> {
674 let inner = self.inner.read().await;
675 self.changed_rx.borrow_and_update();
676 let inner = RwLockReadGuard::map(inner, |inner| inner.as_ref().unwrap());
677 match &inner.error {
678 None => Ok(MirroredHashSetRef(inner)),
679 Some(err) => Err(err.clone()),
680 }
681 }
682
683 pub async fn detach(self) -> HashSet<T> {
685 let mut inner = self.inner.write().await;
686 inner.take().unwrap().hs
687 }
688
689 pub async fn changed(&mut self) {
694 let _ = self.changed_rx.changed().await;
695 }
696
697 pub async fn subscribe(&self, buffer: usize) -> Result<HashSetSubscription<T, Codec>, RecvError> {
704 let view = self.borrow().await?;
705 let initial = view.clone();
706 let events = if view.is_done() { None } else { Some(self.tx.subscribe(buffer)) };
707
708 Ok(HashSetSubscription::new(HashSetInitialValue::new_value(initial), events))
709 }
710
711 pub async fn subscribe_incremental(&self, buffer: usize) -> Result<HashSetSubscription<T, Codec>, RecvError> {
719 let view = self.borrow().await?;
720 let initial = view.clone();
721 let events = if view.is_done() { None } else { Some(self.tx.subscribe(buffer)) };
722
723 Ok(HashSetSubscription::new(
724 HashSetInitialValue::new_incremental(initial, Arc::new(default_on_err)),
725 events,
726 ))
727 }
728}
729
impl<T, Codec> Drop for MirroredHashSet<T, Codec> {
    fn drop(&mut self) {
        // Intentionally empty. NOTE(review): presumably exists because Rust
        // forbids moving fields out of a type that implements `Drop` — confirm
        // intent. The fields (including `_dropped_tx`, whose drop stops the
        // mirroring task) are dropped after this runs.
    }
}
735
/// A read-only view of a mirrored hash set, obtained from `MirroredHashSet::borrow`.
///
/// Holds the mirror's read lock; the mirroring task cannot apply events until
/// this view is dropped.
pub struct MirroredHashSetRef<'a, T>(RwLockReadGuard<'a, MirroredHashSetInner<T>>);
738
739impl<'a, T> MirroredHashSetRef<'a, T> {
740 pub fn is_complete(&self) -> bool {
743 self.0.complete
744 }
745
746 pub fn is_done(&self) -> bool {
749 self.0.done
750 }
751}
752
753impl<'a, T> fmt::Debug for MirroredHashSetRef<'a, T>
754where
755 T: fmt::Debug,
756{
757 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
758 self.0.hs.fmt(f)
759 }
760}
761
762impl<'a, T> Deref for MirroredHashSetRef<'a, T> {
763 type Target = HashSet<T>;
764
765 fn deref(&self) -> &Self::Target {
766 &self.0.hs
767 }
768}
769
impl<'a, T> Drop for MirroredHashSetRef<'a, T> {
    fn drop(&mut self) {
        // Intentionally empty. NOTE(review): presumably exists so the guard
        // field cannot be moved out (Rust forbids that for `Drop` types) —
        // confirm intent. The read lock is released when the guard field is
        // dropped after this runs.
    }
}