use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TrySendError;
use log::warn;
use super::util::panic_in_test;
use super::packets::IncomingPublishPacket;
use super::commands::{FastCallback, SubscriptionKind};
use crate::wrappers::LmcHashMap;
/// All consumers attached to a single topic filter.
///
/// A dispatched message is fanned out to every bounded ("lossy") queue, every unbounded
/// queue and every fast callback registered here.
#[derive(Default)]
pub struct Subscription
{
    /// Bounded senders; when one is full at dispatch time the message is skipped for it.
    lossy_queue: Vec<mpsc::Sender<IncomingPublishPacket>>,

    /// Unbounded senders.
    unbounded_queue: Vec<mpsc::UnboundedSender<IncomingPublishPacket>>,

    /// Callbacks invoked directly during dispatch, kept sorted by ascending ID.
    fast_callbacks: Vec<FastCallback>,

    /// Set by `set_subscribed`, reset by `clear`.
    subscribed: bool,
}
/// Removes, in place, every element of `vec` for which `predicate` returns `true`.
///
/// Each element is passed to `predicate` exactly once. Removal uses `swap_remove`, so
/// every removal is O(1) but the relative order of the surviving elements is not preserved.
fn prune<T, P: FnMut(&T) -> bool>(vec: &mut Vec<T>, mut predicate: P)
{
    let mut cursor = 0;

    while let Some(candidate) = vec.get(cursor) {
        if !predicate(candidate) {
            cursor += 1;
            continue;
        }

        // The former last element now sits at `cursor` and has not been tested yet,
        // so the cursor stays where it is.
        vec.swap_remove(cursor);
    }
}
impl Subscription
{
    /// Fans `msg` out to every consumer attached to this subscription.
    ///
    /// * Lossy queues get a clone via `try_send`. A full queue simply misses the message;
    ///   a queue whose receiver is closed is removed.
    /// * Unbounded queues get a clone via `send`; a queue whose `send` fails is removed.
    /// * Fast callbacks are then invoked in ascending ID order, each with its own clone.
    pub fn dispatch(&mut self, msg: &IncomingPublishPacket)
    {
        prune(&mut self.lossy_queue, |sender| {
            let outcome = sender.try_send(msg.clone());
            matches!(outcome, Err(TrySendError::Closed(_)))
        });

        prune(&mut self.unbounded_queue, |sender| sender.send(msg.clone()).is_err());

        self.fast_callbacks.iter_mut().for_each(|callback| {
            (callback.f)(msg.clone());
        });
    }

    /// Inserts `new_cb` so that `fast_callbacks` stays sorted by ID.
    ///
    /// A duplicate ID is reported through `panic_in_test!` and the callback is not stored.
    fn add_fast_callback(&mut self, new_cb: FastCallback)
    {
        // Fast path: an ID larger than every existing one (or an empty list) can be
        // appended without searching.
        let appends_in_order = self
            .fast_callbacks
            .last()
            .map_or(true, |last_cb| new_cb.id > last_cb.id);

        if appends_in_order {
            self.fast_callbacks.push(new_cb);
            return;
        }

        match self.fast_callbacks.binary_search_by_key(&new_cb.id, |cb| cb.id) {
            Err(slot) => self.fast_callbacks.insert(slot, new_cb),
            Ok(_) => {
                panic_in_test!("Duplicate fast callback with ID {}", new_cb.id);
            }
        }
    }

    /// Removes the fast callback registered under `id`, logging a warning if there is none.
    pub fn remove_fast_callback(&mut self, id: u32)
    {
        match self.fast_callbacks.binary_search_by_key(&id, |cb| cb.id) {
            Ok(pos) => {
                self.fast_callbacks.remove(pos);
            }
            Err(_) => {
                warn!("Fast callback with ID {} not found", id);
            }
        }
    }

    /// Registers the consumer described by `kind`; `SubscriptionKind::Void` registers nothing.
    pub fn add(&mut self, kind: SubscriptionKind)
    {
        match kind {
            SubscriptionKind::Lossy(sender) => self.lossy_queue.push(sender),
            SubscriptionKind::Unbounded(sender) => self.unbounded_queue.push(sender),
            SubscriptionKind::FastCallback(callback) => self.add_fast_callback(callback),
            SubscriptionKind::Void => ()
        }
    }

    /// Drops every registered consumer and marks the subscription as not subscribed.
    fn clear(&mut self)
    {
        self.subscribed = false;
        self.lossy_queue.clear();
        self.unbounded_queue.clear();
        self.fast_callbacks.clear();
    }

    /// Marks this subscription as subscribed.
    pub fn set_subscribed(&mut self)
    {
        self.subscribed = true;
    }

    /// Returns whether `set_subscribed` was called since creation or the last `clear`.
    pub fn is_subscribed(&self) -> bool
    {
        self.subscribed
    }
}
/// Cursor over the `/`-separated levels of a topic string.
///
/// `head` is the current level; `tail` is everything after the first `/`, or `None` when
/// `head` is the last level.
#[derive(Clone, Copy)]
struct TopicSlicer<'a>
{
    head: &'a str,
    tail: Option<&'a str>,
}

impl<'a> TopicSlicer<'a>
{
    /// Splits `s` at its first `/`.
    ///
    /// An empty `s` yields an empty `head` and no `tail`. A trailing `/` yields `Some("")`
    /// as the tail, and a leading `/` yields an empty `head`.
    fn new(s: &'a str) -> Self
    {
        match s.find('/') {
            None => TopicSlicer { head: s, tail: None },
            Some(sep) => {
                let (head, rest) = s.split_at(sep);
                // `rest` still begins with the one-byte '/' separator; skip it.
                TopicSlicer { head, tail: Some(&rest[1..]) }
            }
        }
    }

    /// Moves to the next level, or returns `None` if `head` was the last one.
    fn tail(self) -> Option<Self>
    {
        let rest = self.tail?;
        Some(TopicSlicer::new(rest))
    }
}
/// One level of the subscription tree.
#[derive(Default)]
struct SubscriptionSetNode
{
    /// Children selected by a literal topic level.
    exact: LmcHashMap<String, Box<SubscriptionSetEntry>>,

    /// Child selected by the single-level wildcard, created on demand.
    any: Option<Box<SubscriptionSetEntry>>,

    /// Subscription for the multi-level wildcard (`#`) filter placed at this level; it
    /// receives every message whose topic reaches this node.
    any_recursive: Subscription,
}
/// A tree entry reached by matching one topic level.
#[derive(Default)]
struct SubscriptionSetEntry
{
    /// Subscription for filters that end exactly at this level.
    sub: Subscription,

    /// Next level, for filters that continue past this one.
    children: SubscriptionSetNode,
}
impl SubscriptionSetEntry
{
    /// Dispatches `msg` to the subscriptions reachable from this entry.
    ///
    /// `tail` holds the topic levels that remain after the level that selected this entry.
    /// If levels remain, matching continues in `children`.
    ///
    /// If the topic ends here, the message goes to this entry's own subscription. It also
    /// goes to the multi-level wildcard filter stored one level below
    /// (`children.any_recursive`, i.e. `<this level>/#`). MQTT defines `#` as also matching
    /// its parent level, so `sport/#` matches `sport` (MQTT 3.1.1 section 4.7.1.2).
    fn dispatch(&mut self, tail: Option<TopicSlicer>, msg: &IncomingPublishPacket)
    {
        match tail {
            Some(sub_topic) => self.children.dispatch(sub_topic, msg),
            None => {
                self.sub.dispatch(msg);
                self.children.any_recursive.dispatch(msg);
            }
        }
    }
}
impl SubscriptionSetNode
{
    /// Dispatches `msg` to every matching filter stored under this node.
    ///
    /// `topic` describes the topic levels that remain at this point. Three slots are
    /// consulted, in this order:
    /// * the multi-level wildcard subscription, which takes the message unconditionally;
    /// * the single-level wildcard entry, which matches any value of `topic.head`;
    /// * the literal entry keyed by `topic.head`, if one exists.
    fn dispatch(&mut self, topic: TopicSlicer, msg: &IncomingPublishPacket)
    {
        let remaining = topic.tail();

        self.any_recursive.dispatch(msg);

        if let Some(wildcard) = self.any.as_mut() {
            wildcard.dispatch(remaining, msg);
        }

        if let Some(literal) = self.exact.get_mut(topic.head) {
            literal.dispatch(remaining, msg);
        }
    }

    /// Returns the literal child entry for level `k`, inserting an empty one if absent.
    fn get_or_create_exact(&mut self, k: &str) -> &mut SubscriptionSetEntry
    {
        self.exact
            .entry(k.to_owned())
            .or_insert_with(Box::default)
    }
}
/// Tree of topic filters, one level per node, each filter leading to its [`Subscription`].
#[derive(Default)]
pub struct SubscriptionSet
{
    /// Top level of the tree, indexed by the first level of each filter.
    root: SubscriptionSetNode,
}
impl SubscriptionSet
{
    /// MQTT single-level wildcard: matches exactly one topic level (MQTT 3.1.1 §4.7.1.3).
    const SINGLE_LEVEL_WILDCARD: &'static str = "+";

    /// MQTT multi-level wildcard: matches the parent level and any number of child levels.
    /// It is only valid as the last level of a filter (MQTT 3.1.1 §4.7.1.2).
    const MULTI_LEVEL_WILDCARD: &'static str = "#";

    /// Routes an incoming PUBLISH packet to every subscription whose filter matches the
    /// packet's topic.
    ///
    /// A topic whose first level starts with `$` (e.g. `$SYS/...`) is only matched against
    /// filters whose first level is literal, never against filters starting with a
    /// wildcard (MQTT 3.1.1 §4.7.2).
    pub fn dispatch(&mut self, msg: &IncomingPublishPacket)
    {
        let topic = TopicSlicer::new(msg.topic());

        if topic.head.starts_with('$') {
            // Skip the root-level wildcard slots; only a literal first level may match.
            if let Some(exact) = self.root.exact.get_mut(topic.head) {
                exact.dispatch(topic.tail(), msg);
            }
        } else {
            self.root.dispatch(topic, msg);
        }
    }

    /// Returns the subscription for filter `topic`, creating any missing tree nodes.
    ///
    /// `+` matches exactly one level and `#` matches the remaining levels. Any other level
    /// is matched literally.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if `#` appears anywhere other than the last level of `topic`.
    pub fn get_or_create(&mut self, topic: &str) -> Result<&mut Subscription, ()>
    {
        let mut topic_slicer = TopicSlicer::new(topic);
        let mut node = &mut self.root;

        while let Some(tail) = topic_slicer.tail() {
            if topic_slicer.head == Self::MULTI_LEVEL_WILDCARD {
                // `#` must be the last level of a filter.
                return Err(());
            }

            node = if topic_slicer.head == Self::SINGLE_LEVEL_WILDCARD {
                &mut node.any.get_or_insert_with(|| Box::new(Default::default())).as_mut().children
            } else {
                &mut node.get_or_create_exact(topic_slicer.head).children
            };

            topic_slicer = tail;
        }

        if topic_slicer.head == Self::MULTI_LEVEL_WILDCARD {
            Ok(&mut node.any_recursive)
        } else if topic_slicer.head == Self::SINGLE_LEVEL_WILDCARD {
            let entry = node.any.get_or_insert_with(|| Box::new(Default::default())).as_mut();
            Ok(&mut entry.sub)
        } else {
            let entry = node.get_or_create_exact(topic_slicer.head);
            Ok(&mut entry.sub)
        }
    }

    /// Looks up the subscription for filter `topic` without creating anything.
    ///
    /// Returns `Ok(None)` when no node exists for some level of the filter.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if `#` appears anywhere other than the last level of `topic`.
    pub fn get(&mut self, topic: &str) -> Result<Option<&mut Subscription>, ()>
    {
        let mut topic_slicer = TopicSlicer::new(topic);
        let mut node = &mut self.root;

        while let Some(tail) = topic_slicer.tail() {
            if topic_slicer.head == Self::MULTI_LEVEL_WILDCARD {
                // `#` must be the last level of a filter.
                return Err(());
            }

            node = if topic_slicer.head == Self::SINGLE_LEVEL_WILDCARD {
                match &mut node.any {
                    Some(entry) => &mut entry.children,
                    None => return Ok(None)
                }
            } else {
                match node.exact.get_mut(topic_slicer.head) {
                    Some(entry) => &mut entry.children,
                    None => return Ok(None)
                }
            };

            topic_slicer = tail;
        }

        if topic_slicer.head == Self::MULTI_LEVEL_WILDCARD {
            Ok(Some(&mut node.any_recursive))
        } else if topic_slicer.head == Self::SINGLE_LEVEL_WILDCARD {
            Ok(node.any.as_mut().map(|entry| &mut entry.sub))
        } else {
            Ok(node.exact.get_mut(topic_slicer.head).map(|entry| &mut entry.sub))
        }
    }

    /// Clears the subscription for filter `topic` if it exists and is marked subscribed.
    ///
    /// Returns `Ok(true)` if a subscribed filter was found and cleared. Returns `Ok(false)`
    /// if the filter does not exist or is not marked subscribed.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` under the same conditions as [`SubscriptionSet::get`].
    pub fn clear(&mut self, topic: &str) -> Result<bool, ()>
    {
        if let Some(sub) = self.get(topic)? {
            if sub.is_subscribed() {
                sub.clear();
                return Ok(true);
            }
        }

        Ok(false)
    }
}