use crate::backend::native::v3::pubsub::types::{PubSubEvent, SubscriberId, SubscriptionFilter};
use std::sync::mpsc::{self, Receiver, Sender};
use std::sync::{Arc, Mutex};
/// One registered subscriber: its id, the sending half of its channel, and the
/// filter that decides which events are forwarded to it.
type SubscriberEntry = (SubscriberId, Sender<PubSubEvent>, SubscriptionFilter);

/// Fan-out publisher that forwards [`PubSubEvent`]s to subscribers over
/// `std::sync::mpsc` channels.
///
/// All state lives behind `Arc<Mutex<..>>`, so cloning a `Publisher` produces another
/// handle to the *same* subscriber list and id counter (it does not copy them). This
/// lets one handle be moved to the thread that emits events while others subscribe
/// and unsubscribe.
#[derive(Debug, Clone)]
pub struct Publisher {
    /// Registered subscribers.
    senders: Arc<Mutex<Vec<SubscriberEntry>>>,
    /// Raw value for the next `SubscriberId` handed out by `subscribe`.
    next_id: Arc<Mutex<u64>>,
}
impl Default for Publisher {
fn default() -> Self {
Self::new()
}
}
impl Publisher {
pub fn new() -> Self {
Self {
senders: Arc::new(Mutex::new(Vec::new())),
next_id: Arc::new(Mutex::new(1)),
}
}
pub fn subscribe(&self, filter: SubscriptionFilter) -> (SubscriberId, Receiver<PubSubEvent>) {
let (tx, rx) = mpsc::channel();
let id = {
let mut next = self.next_id.lock().unwrap();
let id = *next;
*next = next.wrapping_add(1);
SubscriberId::new(id)
};
let mut senders = self.senders.lock().unwrap();
senders.push((id, tx, filter));
(id, rx)
}
pub fn unsubscribe(&self, subscriber_id: SubscriberId) -> bool {
let mut senders = self.senders.lock().unwrap();
let pos = senders.iter().position(|(id, _, _)| *id == subscriber_id);
if let Some(pos) = pos {
senders.swap_remove(pos);
true
} else {
false
}
}
pub fn emit(&self, event: PubSubEvent) {
let senders = self.senders.lock().unwrap();
for (_, sender, filter) in senders.iter() {
if filter.matches(&event) {
let _ = sender.send(event.clone());
}
}
}
pub fn subscriber_count(&self) -> usize {
let senders = self.senders.lock().unwrap();
senders.len()
}
pub fn has_subscribers(&self) -> bool {
self.subscriber_count() > 0
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc::TryRecvError;
    use std::time::Duration;

    /// Upper bound for waiting on an event that is expected to arrive.
    const RECV_TIMEOUT: Duration = Duration::from_millis(100);

    #[test]
    fn test_publisher_creation() {
        let publisher = Publisher::new();
        assert_eq!(publisher.subscriber_count(), 0);
        assert!(!publisher.has_subscribers());
    }

    #[test]
    fn test_subscribe_unsubscribe() {
        let publisher = Publisher::new();
        let (id, _rx) = publisher.subscribe(SubscriptionFilter::all());
        assert_eq!(publisher.subscriber_count(), 1);
        assert!(publisher.has_subscribers());
        let removed = publisher.unsubscribe(id);
        assert!(removed);
        assert_eq!(publisher.subscriber_count(), 0);
        let removed = publisher.unsubscribe(id);
        assert!(!removed);
    }

    #[test]
    fn test_subscriber_ids_are_distinct() {
        let publisher = Publisher::new();
        let (id1, _rx1) = publisher.subscribe(SubscriptionFilter::all());
        let (id2, _rx2) = publisher.subscribe(SubscriptionFilter::all());
        assert!(id1 != id2);
    }

    #[test]
    fn test_emit_event() {
        let publisher = Publisher::new();
        let (id, rx) = publisher.subscribe(SubscriptionFilter::all());
        let event = PubSubEvent::NodeChanged {
            node_id: 1,
            snapshot_id: 1,
        };
        publisher.emit(event.clone());
        let received = rx.recv_timeout(RECV_TIMEOUT);
        assert_eq!(received, Ok(event));
        assert!(publisher.unsubscribe(id));
    }

    #[test]
    fn test_unsubscribe_closes_channel() {
        let publisher = Publisher::new();
        let (id, rx) = publisher.subscribe(SubscriptionFilter::all());
        assert!(publisher.unsubscribe(id));
        publisher.emit(PubSubEvent::SnapshotCommitted { snapshot_id: 7 });
        // Unsubscribing drops the publisher's Sender: nothing is delivered and the
        // channel reports disconnection rather than merely being empty.
        assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected));
    }

    #[test]
    fn test_filter_matching() {
        let publisher = Publisher::new();
        let (_id1, rx1) = publisher.subscribe(SubscriptionFilter::all());
        let (_id2, rx2) = publisher.subscribe(SubscriptionFilter::nodes_only());
        let node_event = PubSubEvent::NodeChanged {
            node_id: 1,
            snapshot_id: 1,
        };
        publisher.emit(node_event.clone());
        assert_eq!(rx1.recv_timeout(RECV_TIMEOUT), Ok(node_event.clone()));
        assert_eq!(rx2.recv_timeout(RECV_TIMEOUT), Ok(node_event));
        let edge_event = PubSubEvent::EdgeChanged {
            edge_id: 1,
            from_node: 1,
            to_node: 2,
            snapshot_id: 1,
        };
        publisher.emit(edge_event.clone());
        assert_eq!(rx1.recv_timeout(RECV_TIMEOUT), Ok(edge_event));
        // `emit` delivers synchronously, so a non-blocking check is deterministic here;
        // `Empty` (not `Disconnected`) proves the subscriber is alive but was filtered.
        assert_eq!(rx2.try_recv(), Err(TryRecvError::Empty));
    }

    #[test]
    fn test_multiple_subscribers_receive_independent_events() {
        let publisher = Publisher::new();
        let (_id1, rx1) = publisher.subscribe(SubscriptionFilter::all());
        let (_id2, rx2) = publisher.subscribe(SubscriptionFilter::all());
        let event = PubSubEvent::SnapshotCommitted { snapshot_id: 42 };
        publisher.emit(event.clone());
        assert_eq!(rx1.recv_timeout(RECV_TIMEOUT), Ok(event.clone()));
        assert_eq!(rx2.recv_timeout(RECV_TIMEOUT), Ok(event));
    }
}