use hashbrown::{HashMap, hash_map::Entry};
use miette::{IntoDiagnostic, Result, miette};
use tokio::sync::mpsc::Sender;
use tracing::{debug, warn};
use worterbuch_common::{KeySegment, PStateEvent, RegularKeySegment, StateEvent, SubscriptionId};
type Subs = Vec<Subscriber>;
type Tree = HashMap<KeySegment, Node>;
#[derive(Clone, Debug)]
pub enum EventSender {
State(Sender<StateEvent>),
PState(Sender<PStateEvent>),
}
#[derive(Clone, Debug)]
pub struct Subscriber {
pattern: Vec<KeySegment>,
tx: EventSender,
id: SubscriptionId,
unique: bool,
}
impl Subscriber {
pub fn new(
id: SubscriptionId,
pattern: Vec<KeySegment>,
tx: EventSender,
unique: bool,
) -> Subscriber {
Subscriber {
pattern,
tx,
id,
unique,
}
}
pub async fn send_pstate(&self, event: PStateEvent) -> Result<()> {
if let EventSender::PState(tx) = &self.tx {
tx.send(event).await.into_diagnostic()?;
Ok(())
} else {
Err(miette!(
"Tried to send a PSatetEvent to a StateEvent subscriber"
))
}
}
pub async fn send_state(&self, event: StateEvent) -> Result<()> {
if let EventSender::State(tx) = &self.tx {
tx.send(event).await.into_diagnostic()?;
Ok(())
} else {
Err(miette!(
"Tried to send a SatetEvent to a PStateEvent subscriber"
))
}
}
pub fn is_unique(&self) -> bool {
self.unique
}
pub fn is_pstate_subscriber(&self) -> bool {
match self.tx {
EventSender::State(_) => false,
EventSender::PState(_) => true,
}
}
}
#[derive(Clone, Debug)]
pub struct LsSubscriber {
pub parent: Vec<RegularKeySegment>,
tx: Sender<Vec<RegularKeySegment>>,
pub id: SubscriptionId,
}
impl LsSubscriber {
pub fn new(
id: SubscriptionId,
parent: Vec<RegularKeySegment>,
tx: Sender<Vec<RegularKeySegment>>,
) -> LsSubscriber {
LsSubscriber { parent, tx, id }
}
pub async fn send(&self, children: Vec<RegularKeySegment>) -> Result<()> {
self.tx.send(children).await.into_diagnostic()?;
Ok(())
}
}
#[derive(Default)]
pub struct Node {
pub subscribers: Subs,
pub tree: Tree,
}
#[derive(Default)]
pub struct Subscribers {
data: Node,
}
impl Subscribers {
pub fn get_subscribers(&self, key: &[RegularKeySegment]) -> Vec<Subscriber> {
let mut all_subscribers = Vec::new();
add_matches(&self.data, key, &mut all_subscribers);
all_subscribers
}
pub fn add_subscriber(&mut self, pattern: &[KeySegment], subscriber: Subscriber) {
debug!("Adding subscriber for pattern {:?}", pattern);
let mut current = &mut self.data;
for elem in pattern {
current = match current.tree.entry(elem.to_owned()) {
Entry::Occupied(e) => e.into_mut(),
Entry::Vacant(e) => e.insert(Node::default()),
};
}
current.subscribers.push(subscriber);
}
pub fn unsubscribe(&mut self, pattern: &[KeySegment], subscription: &SubscriptionId) -> bool {
let mut current = &mut self.data;
for elem in pattern {
if let Some(node) = current.tree.get_mut(elem) {
current = node;
} else {
warn!("No subscriber found for pattern {:?}", pattern);
return false;
}
}
let mut removed = false;
current.subscribers.retain(|s| {
let retain = &s.id != subscription;
removed = removed || !retain;
if !retain {
debug!("Removing subscription {subscription:?} to pattern {pattern:?}");
}
retain
});
if !removed {
debug!("no matching subscription found")
}
removed
}
pub fn remove_subscriber(&mut self, subscriber: Subscriber) {
let mut current = &mut self.data;
for elem in &subscriber.pattern {
if let Some(node) = current.tree.get_mut(elem) {
current = node;
} else {
warn!("No subscriber found for pattern {:?}", subscriber.pattern);
return;
}
}
current.subscribers.retain(|s| s.id != subscriber.id);
}
}
fn add_matches(
mut current: &Node,
remaining_path: &[RegularKeySegment],
all_subscribers: &mut Vec<Subscriber>,
) {
let mut remaining_path = remaining_path;
for elem in remaining_path {
remaining_path = &remaining_path[1..];
if let Some(node) = current.tree.get(&KeySegment::Wildcard) {
add_matches(node, remaining_path, all_subscribers);
}
if let Some(node) = current.tree.get(&KeySegment::MultiWildcard) {
add_all_children(node, all_subscribers);
}
if let Some(node) = current.tree.get(&KeySegment::from(elem.to_owned())) {
current = node;
} else {
return;
}
}
all_subscribers.extend(current.subscribers.clone());
}
fn add_all_children(node: &Node, all_subscribers: &mut Vec<Subscriber>) {
all_subscribers.extend(node.subscribers.clone());
for node in node.tree.values() {
add_all_children(node, all_subscribers);
}
}
#[cfg(test)]
mod test {
use super::*;
use tokio::sync::mpsc::channel;
use uuid::Uuid;
use worterbuch_common::parse_segments;
fn reg_key_segs(key: &str) -> Vec<RegularKeySegment> {
parse_segments(key).unwrap()
}
fn key_segs(key: &str) -> Vec<KeySegment> {
KeySegment::parse(key)
}
#[test]
fn get_subscribers() {
let mut subscribers = Subscribers::default();
let (tx, _rx) = channel(1);
let pattern = KeySegment::parse("test/?/b/#");
let id = SubscriptionId {
client_id: Uuid::new_v4(),
transaction_id: 123,
};
let subscriber = Subscriber::new(
id,
pattern.clone().into_iter().map(|s| s.to_owned()).collect(),
EventSender::PState(tx),
false,
);
subscribers.add_subscriber(&pattern, subscriber);
let res = subscribers.get_subscribers(®_key_segs("test/a/b/c/d"));
assert_eq!(res.len(), 1);
let res = subscribers.get_subscribers(®_key_segs("test/a/b"));
assert_eq!(res.len(), 0);
}
#[test]
fn subscribers_are_cleaned_up() {
let mut subscribers = Subscribers::default();
let (tx, _rx) = channel(1);
let pattern = key_segs("test/?/b/#");
let id = SubscriptionId {
client_id: Uuid::new_v4(),
transaction_id: 123,
};
let subscriber = Subscriber::new(
id.clone(),
pattern.clone().into_iter().map(|s| s.to_owned()).collect(),
EventSender::PState(tx),
false,
);
let res = subscribers.get_subscribers(®_key_segs("test/a/b/c/d"));
assert_eq!(res.len(), 0);
subscribers.add_subscriber(&pattern, subscriber);
let res = subscribers.get_subscribers(®_key_segs("test/a/b/c/d"));
assert_eq!(res.len(), 1);
subscribers.unsubscribe(&pattern, &id);
let res = subscribers.get_subscribers(®_key_segs("test/a/b/c/d"));
assert_eq!(res.len(), 0);
}
}