use std::{
collections::HashMap,
fmt::{self, Debug, Formatter},
mem,
slice::Iter,
sync::{Arc, Mutex},
};
use itertools::Itertools;
use crate::{messages::Reason, MatchingPolicy, ID, URI};
use super::super::{random_id, ConnectionInfo};
/// One node of the subscription trie.
///
/// Each edge is labelled with one `.`-separated topic segment; the empty
/// segment `""` is the wildcard edge that matches any single segment. A node
/// stores the subscribers whose topic ends exactly at it, split by whether
/// they subscribed with prefix matching.
pub struct SubscriptionPatternNode<P: PatternData> {
    /// Children keyed by the next topic segment (`""` for a wildcard segment).
    edges: HashMap<String, SubscriptionPatternNode<P>>,
    /// Non-prefix subscribers whose topic ends at this node.
    connections: Vec<DataWrapper<P>>,
    /// Prefix-policy subscribers whose topic ends at this node.
    prefix_connections: Vec<DataWrapper<P>>,
    /// Subscription ID reported for entries of `connections`.
    id: ID,
    /// Subscription ID reported for entries of `prefix_connections`.
    prefix_id: ID,
}

impl<P: PatternData> Default for SubscriptionPatternNode<P> {
    /// Creates an empty node; equivalent to [`SubscriptionPatternNode::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Data that can be stored as a subscriber in a `SubscriptionPatternNode`.
pub trait PatternData {
    /// Returns the identifier used to locate this subscriber again when it
    /// unsubscribes (subscribers are removed by comparing these IDs).
    fn get_id(&self) -> ID;
}
/// A subscriber stored in the trie together with the policy it subscribed with.
struct DataWrapper<P: PatternData> {
    subscriber: P,
    // Reported back to the caller alongside the subscriber when it matches.
    policy: MatchingPolicy,
}
/// Iterator over every subscription matching a published topic, created by
/// [`SubscriptionPatternNode::filter`].
///
/// Yields `(subscriber, subscription_id, policy)` triples from a depth-first
/// walk of the trie. The walk is kept on an explicit linked stack of frames.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct MatchIterator<'a, P>
where
    P: PatternData,
{
    /// The published topic, split into `.`-separated segments.
    uri: Vec<String>,
    /// Top of the traversal stack; earlier frames hang off its `parent`.
    current: Box<StackFrame<'a, P>>,
}
/// One level of the depth-first match traversal performed by `MatchIterator`.
struct StackFrame<'a, P>
where
    P: PatternData,
{
    // Trie node being visited by this frame.
    node: &'a SubscriptionPatternNode<P>,
    // How far the traversal has progressed within `node`.
    state: IterState<'a, P>,
    // Number of topic segments consumed to reach `node`; the root frame is 0.
    depth: usize,
    // Frame to resume once this one is finished; `None` only for the root frame.
    parent: Option<Box<StackFrame<'a, P>>>,
}
#[derive(Debug)]
pub struct PatternError {
reason: Reason,
}
/// Progress of one traversal frame through its node during matching.
///
/// A frame starts in `None` and moves to `Prefix`, then `PrefixComplete`.
/// From there it either goes to `Subs` (all topic segments consumed), or
/// descends into the wildcard child (`Wildcard`) and/or the strict child
/// (`Strict`). It finishes in `AllComplete`, after which the frame is popped.
#[derive(Clone)]
enum IterState<'a, P: PatternData> {
    /// Frame just created; nothing visited yet.
    None,
    /// Descended into the wildcard (`""`) child; the strict child is tried on return.
    Wildcard,
    /// Descended into the child matching the current segment exactly.
    Strict,
    /// Yielding the node's prefix subscribers.
    Prefix(Iter<'a, DataWrapper<P>>),
    /// Prefix subscribers exhausted; choose between own subscribers and children.
    PrefixComplete,
    /// Yielding the node's non-prefix subscribers (topic fully consumed).
    Subs(Iter<'a, DataWrapper<P>>),
    /// Nothing left in this frame; return to the parent.
    AllComplete,
}
impl PatternError {
    /// Wraps `reason` in a `PatternError`.
    #[inline]
    pub fn new(reason: Reason) -> PatternError {
        Self { reason }
    }

    /// Consumes the error, handing back the underlying `Reason`.
    pub fn reason(self) -> Reason {
        let PatternError { reason } = self;
        reason
    }
}
impl PatternData for Arc<Mutex<ConnectionInfo>> {
    /// Reads the connection's ID while holding the lock.
    ///
    /// # Panics
    /// Panics if the mutex has been poisoned.
    fn get_id(&self) -> ID {
        let info = self.lock().unwrap();
        info.id
    }
}
impl<'a, P: PatternData> Debug for IterState<'a, P> {
    /// Prints only the variant name; wrapped slice iterators are not shown.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let variant = match self {
            IterState::None => "None",
            IterState::Wildcard => "Wildcard",
            IterState::Strict => "Strict",
            IterState::Prefix(_) => "Prefix",
            IterState::PrefixComplete => "PrefixComplete",
            IterState::Subs(_) => "Subs",
            IterState::AllComplete => "AllComplete",
        };
        f.write_str(variant)
    }
}
impl<P: PatternData> Debug for SubscriptionPatternNode<P> {
    /// Renders the whole subtree rooted at this node, starting unindented.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let root_indent = 0;
        Self::fmt_with_indent(self, f, root_indent)
    }
}
impl<P: PatternData> SubscriptionPatternNode<P> {
    /// Writes this node and its whole subtree, one node per line.
    ///
    /// A node's line shows its subscription ID, then the subscriber IDs of its
    /// prefix subscribers (`pre`) and of its other subscribers (`subs`). Each
    /// child follows as `<segment> - <child line>`, preceded by `indent * 2`
    /// spaces.
    fn fmt_with_indent(&self, f: &mut Formatter<'_>, indent: usize) -> fmt::Result {
        writeln!(
            f,
            "{} pre: {:?} subs: {:?}",
            self.id,
            self.prefix_connections
                .iter()
                .map(|sub| sub.subscriber.get_id())
                .join(","),
            self.connections
                .iter()
                .map(|sub| sub.subscriber.get_id())
                .join(","),
        )?;
        for (chunk, node) in &self.edges {
            write!(f, "{:width$}{} - ", "", chunk, width = indent * 2)?;
            node.fmt_with_indent(f, indent + 1)?;
        }
        Ok(())
    }

    /// Registers `subscriber` for `topic` under `matching_policy`.
    ///
    /// The topic is split on `.` and each segment becomes one level of the
    /// trie below this node. An empty segment is only accepted for
    /// `MatchingPolicy::Wildcard`, where it is stored on the `""` edge.
    ///
    /// Returns the subscription ID of the node the subscriber was attached to:
    /// its `prefix_id` for prefix subscriptions, its `id` otherwise.
    ///
    /// # Errors
    /// `Reason::InvalidURI` if any segment, including the first, is empty and
    /// the policy is not `Wildcard`.
    pub fn subscribe_with(
        &mut self,
        topic: &URI,
        subscriber: P,
        matching_policy: MatchingPolicy,
    ) -> Result<ID, PatternError> {
        // `str::split` always yields at least one segment, so the subscriber is
        // never attached to this (root) node itself. Delegating the full
        // segment list means the first segment is validated by the same rule
        // as every later one.
        self.add_subscription(topic.uri.split('.'), subscriber, matching_policy)
    }

    /// Removes `subscriber` from the node addressed by the `.`-separated `topic`.
    ///
    /// `is_prefix` selects whether the prefix or the non-prefix subscriber
    /// list of that node is pruned. Every entry whose ID equals the
    /// subscriber's ID is removed.
    ///
    /// Returns the matching subscription ID of the node. This is returned even
    /// if the subscriber was not present there.
    ///
    /// # Errors
    /// `Reason::InvalidURI` if some segment of `topic` has no node in the trie.
    pub fn unsubscribe_with(
        &mut self,
        topic: &str,
        subscriber: &P,
        is_prefix: bool,
    ) -> Result<ID, PatternError> {
        let uri_bits = topic.split('.');
        self.remove_subscription(uri_bits, subscriber.get_id(), is_prefix)
    }

    /// Creates an empty node with freshly generated subscription IDs.
    #[inline]
    pub fn new() -> SubscriptionPatternNode<P> {
        SubscriptionPatternNode {
            edges: HashMap::new(),
            connections: Vec::new(),
            prefix_connections: Vec::new(),
            id: random_id(),
            prefix_id: random_id(),
        }
    }

    /// Walks (creating as needed) the path given by `uri_bits` and stores the
    /// subscriber on the final node.
    fn add_subscription<'a, I>(
        &mut self,
        mut uri_bits: I,
        subscriber: P,
        matching_policy: MatchingPolicy,
    ) -> Result<ID, PatternError>
    where
        I: Iterator<Item = &'a str>,
    {
        match uri_bits.next() {
            Some(uri_bit) => {
                // Empty segments are the wildcard marker and only make sense
                // for wildcard subscriptions.
                if uri_bit.is_empty() && matching_policy != MatchingPolicy::Wildcard {
                    return Err(PatternError::new(Reason::InvalidURI));
                }
                self.edges
                    .entry(uri_bit.to_string())
                    .or_insert_with(SubscriptionPatternNode::new)
                    .add_subscription(uri_bits, subscriber, matching_policy)
            }
            None => {
                let wrapped = DataWrapper {
                    subscriber,
                    policy: matching_policy,
                };
                if matching_policy == MatchingPolicy::Prefix {
                    self.prefix_connections.push(wrapped);
                    Ok(self.prefix_id)
                } else {
                    self.connections.push(wrapped);
                    Ok(self.id)
                }
            }
        }
    }

    /// Walks the existing path given by `uri_bits` and removes every entry
    /// with `subscriber_id` from the final node's prefix or non-prefix list.
    fn remove_subscription<'a, I>(
        &mut self,
        mut uri_bits: I,
        subscriber_id: ID,
        is_prefix: bool,
    ) -> Result<ID, PatternError>
    where
        I: Iterator<Item = &'a str>,
    {
        match uri_bits.next() {
            Some(uri_bit) => self
                .edges
                .get_mut(uri_bit)
                .ok_or_else(|| PatternError::new(Reason::InvalidURI))?
                .remove_subscription(uri_bits, subscriber_id, is_prefix),
            None => {
                if is_prefix {
                    self.prefix_connections
                        .retain(|sub| sub.subscriber.get_id() != subscriber_id);
                    Ok(self.prefix_id)
                } else {
                    self.connections
                        .retain(|sub| sub.subscriber.get_id() != subscriber_id);
                    Ok(self.id)
                }
            }
        }
    }

    /// Returns an iterator over every subscription matching the published `topic`.
    ///
    /// The following are yielded:
    /// - prefix subscribers of every node on a matching path;
    /// - non-prefix subscribers of nodes reached after consuming all segments.
    ///
    /// A `""` (wildcard) edge matches any single segment.
    pub fn filter(&self, topic: URI) -> MatchIterator<'_, P> {
        MatchIterator {
            current: Box::new(StackFrame {
                node: self,
                depth: 0,
                state: IterState::None,
                parent: None,
            }),
            uri: topic.uri.split('.').map(|s| s.to_string()).collect(),
        }
    }
}
impl<'a, P: PatternData> MatchIterator<'a, P> {
    /// Descends into `child`, making it the current frame.
    fn push(&mut self, child: &'a SubscriptionPatternNode<P>) {
        let new_node = Box::new(StackFrame {
            parent: None,
            depth: self.current.depth + 1,
            node: child,
            state: IterState::None,
        });
        let parent = mem::replace(&mut self.current, new_node);
        self.current.parent = Some(parent);
    }

    /// Yields the next subscriber from the current frame's active subscriber
    /// list, if the frame is in `Prefix` or `Subs` state and that list is not
    /// yet exhausted.
    fn next_in_frame(&mut self) -> Option<(&'a P, ID, MatchingPolicy)> {
        let frame = &mut *self.current;
        let node = frame.node;
        match frame.state {
            IterState::Prefix(ref mut subs) => subs
                .next()
                .map(|sub| (&sub.subscriber, node.prefix_id, sub.policy)),
            IterState::Subs(ref mut subs) => subs
                .next()
                .map(|sub| (&sub.subscriber, node.id, sub.policy)),
            IterState::None
            | IterState::Wildcard
            | IterState::Strict
            | IterState::PrefixComplete
            | IterState::AllComplete => None,
        }
    }

    /// Performs one state transition of the current frame. This may push a
    /// child frame or pop back to the parent.
    ///
    /// Returns `false` once the root frame has finished, i.e. the traversal
    /// is exhausted.
    fn advance(&mut self) -> bool {
        let node = self.current.node;
        let depth = self.current.depth;
        match self.current.state {
            IterState::None => {
                self.current.state = IterState::Prefix(node.prefix_connections.iter());
            }
            // Only reached once the prefix iterator has been drained.
            IterState::Prefix(_) => {
                self.current.state = IterState::PrefixComplete;
            }
            IterState::PrefixComplete => {
                if depth == self.uri.len() {
                    // Every segment has been consumed: this node's own
                    // subscribers match.
                    self.current.state = IterState::Subs(node.connections.iter());
                } else if let Some(child) = node.edges.get("") {
                    // Visit the wildcard branch first; the strict branch is
                    // tried when the traversal returns to this frame.
                    self.current.state = IterState::Wildcard;
                    self.push(child);
                } else if let Some(child) = node.edges.get(&self.uri[depth]) {
                    self.current.state = IterState::Strict;
                    self.push(child);
                } else {
                    self.current.state = IterState::AllComplete;
                }
            }
            IterState::Wildcard => {
                // Back from the wildcard branch; try the strict branch. An
                // empty published segment would select the `""` edge a second
                // time and yield its subscribers twice, so it is skipped.
                let strict_child = self
                    .uri
                    .get(depth)
                    .filter(|segment| !segment.is_empty())
                    .and_then(|segment| node.edges.get(segment));
                match strict_child {
                    Some(child) => {
                        self.current.state = IterState::Strict;
                        self.push(child);
                    }
                    None => self.current.state = IterState::AllComplete,
                }
            }
            IterState::Strict | IterState::Subs(_) => {
                self.current.state = IterState::AllComplete;
            }
            IterState::AllComplete => match self.current.parent.take() {
                // Resume the parent frame from whatever state it was left in.
                Some(parent) => self.current = parent,
                // Only the root frame has no parent: the traversal is over.
                None => return false,
            },
        }
        true
    }
}

impl<'a, P: PatternData> Iterator for MatchIterator<'a, P> {
    type Item = (&'a P, ID, MatchingPolicy);

    /// Yields matches in depth-first order. At each node, prefix subscribers
    /// come first; then, if the topic is fully consumed, the node's remaining
    /// subscribers. Wildcard branches are visited before strict ones.
    fn next(&mut self) -> Option<(&'a P, ID, MatchingPolicy)> {
        // A loop rather than mutual recursion with the state machine, so stack
        // usage stays constant no matter how many empty nodes lie between two
        // yielded subscribers.
        loop {
            if let Some(found) = self.next_in_frame() {
                return Some(found);
            }
            if !self.advance() {
                return None;
            }
        }
    }
}
#[cfg(test)]
mod test {
    use super::{PatternData, SubscriptionPatternNode};
    use crate::{MatchingPolicy, ID, URI};

    /// Minimal subscriber stand-in identified only by its numeric ID.
    #[derive(Clone)]
    struct MockData {
        id: ID,
    }

    impl PatternData for MockData {
        fn get_id(&self) -> ID {
            self.id
        }
    }

    impl MockData {
        pub fn new(id: ID) -> MockData {
            MockData { id }
        }
    }

    /// Registers the four subscriptions shared by both tests (subscriber IDs
    /// 1 through 4) and returns their subscription IDs in registration order.
    fn subscribe_fixture(root: &mut SubscriptionPatternNode<MockData>) -> Vec<ID> {
        [
            ("com.example.test..topic", 1, MatchingPolicy::Wildcard),
            ("com.example.test.specific.topic", 2, MatchingPolicy::Strict),
            ("com.example", 3, MatchingPolicy::Prefix),
            ("com.example.test", 4, MatchingPolicy::Prefix),
        ]
        .iter()
        .map(|&(topic, subscriber_id, policy)| {
            root.subscribe_with(&URI::new(topic), MockData::new(subscriber_id), policy)
                .unwrap()
        })
        .collect()
    }

    /// Collects the subscription IDs yielded when `topic` is published.
    fn matching_ids(root: &SubscriptionPatternNode<MockData>, topic: &str) -> Vec<ID> {
        root.filter(URI::new(topic))
            .map(|(_connection, id, _policy)| id)
            .collect()
    }

    #[test]
    fn adding_patterns() {
        let mut root = SubscriptionPatternNode::new();
        let ids = subscribe_fixture(&mut root);
        assert_eq!(
            matching_ids(&root, "com.example.test.specific.topic"),
            vec![ids[2], ids[3], ids[0], ids[1]]
        );
    }

    #[test]
    fn removing_patterns() {
        let mut root = SubscriptionPatternNode::new();
        let ids = subscribe_fixture(&mut root);
        root.unsubscribe_with("com.example.test..topic", &MockData::new(1), false)
            .unwrap();
        root.unsubscribe_with("com.example.test", &MockData::new(4), true)
            .unwrap();
        assert_eq!(
            matching_ids(&root, "com.example.test.specific.topic"),
            vec![ids[2], ids[1]]
        )
    }
}