use std::collections::{HashMap, HashSet};
use super::ClientId;
use crate::protocol::common::server::connection_graph_update::{
AdvertisedService, ConnectionGraphUpdate, PublishedTopic, SubscribedTopic,
};
/// Maps a topic or service name to the set of ID strings associated with it (publisher,
/// subscriber, or provider IDs, depending on which map of [`ConnectionGraph`] it is).
type MapOfSets = HashMap<String, HashSet<String>>;
/// A snapshot of which IDs publish or subscribe to each topic and which IDs provide each
/// service, together with the set of clients registered through `add_subscriber`.
///
/// NOTE(review): the registered clients are presumably the ones that asked to receive
/// connection-graph updates; confirm against the caller.
#[derive(Debug, Default, Clone)]
pub struct ConnectionGraph {
    /// Topic name -> IDs of its publishers.
    published_topics: MapOfSets,
    /// Topic name -> IDs of its subscribers.
    subscribed_topics: MapOfSets,
    /// Service name -> IDs of its providers.
    advertised_services: MapOfSets,
    /// Clients registered via `add_subscriber`; not replaced by `update`.
    subscribers: HashSet<ClientId>,
}
impl ConnectionGraph {
    /// Creates an empty connection graph with no topics, services, or subscribers.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the publishers of `topic`, replacing any previously recorded publisher IDs.
    pub fn set_published_topic(
        &mut self,
        topic: impl Into<String>,
        publisher_ids: impl IntoIterator<Item = impl Into<String>>,
    ) {
        self.published_topics.insert(
            topic.into(),
            publisher_ids.into_iter().map(Into::into).collect(),
        );
    }

    /// Sets the subscribers of `topic`, replacing any previously recorded subscriber IDs.
    pub fn set_subscribed_topic(
        &mut self,
        topic: impl Into<String>,
        subscriber_ids: impl IntoIterator<Item = impl Into<String>>,
    ) {
        self.subscribed_topics.insert(
            topic.into(),
            subscriber_ids.into_iter().map(Into::into).collect(),
        );
    }

    /// Sets the providers of `service`, replacing any previously recorded provider IDs.
    pub fn set_advertised_service(
        &mut self,
        service: impl Into<String>,
        provider_ids: impl IntoIterator<Item = impl Into<String>>,
    ) {
        self.advertised_services.insert(
            service.into(),
            provider_ids.into_iter().map(Into::into).collect(),
        );
    }

    /// Registers `client_id`. Returns `true` if it was not already registered.
    pub(crate) fn add_subscriber(&mut self, client_id: ClientId) -> bool {
        self.subscribers.insert(client_id)
    }

    /// Unregisters `client_id`. Returns `true` if it was registered.
    pub(crate) fn remove_subscriber(&mut self, client_id: ClientId) -> bool {
        self.subscribers.remove(&client_id)
    }

    /// Unregisters every client.
    #[cfg(feature = "remote-access")]
    pub(crate) fn clear_subscribers(&mut self) {
        self.subscribers.clear();
    }

    /// Returns `true` if at least one client is registered.
    pub(crate) fn has_subscribers(&self) -> bool {
        !self.subscribers.is_empty()
    }

    /// Returns `true` if `client_id` is registered.
    pub(crate) fn is_subscriber(&self, client_id: ClientId) -> bool {
        self.subscribers.contains(&client_id)
    }

    /// Computes the update that turns the topology of `self` into that of `other`.
    ///
    /// - Every topic/service in `other` whose ID set is new or differs from `self` is reported
    ///   with its complete ID set.
    /// - A topic that lost all publishers (or all subscribers) but is still referenced by the
    ///   other side in `other` is reported with an empty ID list for the side it lost.
    /// - Services absent from `other` are listed in `removed_services`; topics with neither
    ///   publishers nor subscribers in `other` are listed in `removed_topics`.
    ///
    /// Registered clients (`subscribers`) are not compared. Entry order and ID order come from
    /// hash-map/hash-set iteration and are therefore arbitrary.
    pub(crate) fn diff(&self, other: &ConnectionGraph) -> ConnectionGraphUpdate {
        let mut diff = ConnectionGraphUpdate::default();

        // New or changed entries are sent with their full ID sets.
        for (name, publisher_ids) in &other.published_topics {
            if self.published_topics.get(name) != Some(publisher_ids) {
                diff.published_topics.push(PublishedTopic {
                    name: name.clone(),
                    publisher_ids: publisher_ids.iter().cloned().collect(),
                });
            }
        }
        for (name, subscriber_ids) in &other.subscribed_topics {
            if self.subscribed_topics.get(name) != Some(subscriber_ids) {
                diff.subscribed_topics.push(SubscribedTopic {
                    name: name.clone(),
                    subscriber_ids: subscriber_ids.iter().cloned().collect(),
                });
            }
        }
        for (name, provider_ids) in &other.advertised_services {
            if self.advertised_services.get(name) != Some(provider_ids) {
                diff.advertised_services.push(AdvertisedService {
                    name: name.clone(),
                    provider_ids: provider_ids.iter().cloned().collect(),
                });
            }
        }

        // A topic can vanish from one side of the graph while the other side still references
        // it (e.g. its last publisher leaves but subscribers remain). Such a topic is not listed
        // in `removed_topics`, and without an explicit empty entry nothing in the update would
        // say that its publishers/subscribers are gone.
        for name in self.published_topics.keys() {
            if !other.published_topics.contains_key(name)
                && other.subscribed_topics.contains_key(name)
            {
                diff.published_topics.push(PublishedTopic {
                    name: name.clone(),
                    publisher_ids: Default::default(),
                });
            }
        }
        for name in self.subscribed_topics.keys() {
            if !other.subscribed_topics.contains_key(name)
                && other.published_topics.contains_key(name)
            {
                diff.subscribed_topics.push(SubscribedTopic {
                    name: name.clone(),
                    subscriber_ids: Default::default(),
                });
            }
        }

        diff.removed_services = self
            .advertised_services
            .keys()
            .filter(|name| !other.advertised_services.contains_key(*name))
            .cloned()
            .collect();

        // A topic may appear in both maps of `self`; the set deduplicates it. It is removed only
        // when `other` has neither publishers nor subscribers for it.
        let removed_topics: HashSet<&String> = self
            .published_topics
            .keys()
            .chain(self.subscribed_topics.keys())
            .filter(|name| {
                !other.published_topics.contains_key(*name)
                    && !other.subscribed_topics.contains_key(*name)
            })
            .collect();
        diff.removed_topics = removed_topics.into_iter().cloned().collect();

        diff
    }

    /// Returns an update describing the whole graph, i.e. the diff from an empty graph.
    pub(crate) fn as_initial_update(&self) -> ConnectionGraphUpdate {
        ConnectionGraph::default().diff(self)
    }

    /// Replaces the topics and services of `self` with those of `new` and returns the diff.
    ///
    /// The registered clients of `self` are kept; the ones in `new` are discarded.
    pub(crate) fn update(&mut self, new: ConnectionGraph) -> ConnectionGraphUpdate {
        let diff = self.diff(&new);
        self.published_topics = new.published_topics;
        self.subscribed_topics = new.subscribed_topics;
        self.advertised_services = new.advertised_services;
        diff
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Graphs are built through the public setters; each setter replaces the whole ID set for
    // its name, which matches inserting a fresh `HashSet` into the backing map.

    #[test]
    fn test_empty_update() {
        let mut graph = ConnectionGraph::new();
        assert_eq!(
            graph.update(ConnectionGraph::new()),
            ConnectionGraphUpdate::default()
        );
    }

    #[test]
    fn test_new_published_topic() {
        let mut next = ConnectionGraph::new();
        next.set_published_topic("topic1", ["publisher1"]);

        let mut graph = ConnectionGraph::new();
        assert_eq!(
            graph.update(next),
            ConnectionGraphUpdate {
                published_topics: vec![PublishedTopic::new("topic1", ["publisher1"])],
                ..Default::default()
            },
        );
    }

    #[test]
    fn test_removed_topic() {
        let mut graph = ConnectionGraph::new();
        graph.set_published_topic("topic1", ["publisher1"]);

        assert_eq!(
            graph.update(ConnectionGraph::new()),
            ConnectionGraphUpdate {
                removed_topics: vec!["topic1".into()],
                ..Default::default()
            }
        );
    }

    #[test]
    fn test_changed_publishers() {
        let mut graph = ConnectionGraph::new();
        graph.set_published_topic("topic1", ["publisher1"]);
        let mut next = ConnectionGraph::new();
        next.set_published_topic("topic1", ["publisher2"]);

        assert_eq!(
            graph.update(next),
            ConnectionGraphUpdate {
                published_topics: vec![PublishedTopic::new("topic1", ["publisher2"])],
                ..Default::default()
            }
        );
    }

    #[test]
    fn test_service_changes() {
        let mut graph = ConnectionGraph::new();
        graph.set_advertised_service("service1", ["provider1"]);
        let mut next = ConnectionGraph::new();
        next.set_advertised_service("service2", ["provider2"]);

        assert_eq!(
            graph.update(next),
            ConnectionGraphUpdate {
                advertised_services: vec![AdvertisedService::new("service2", ["provider2"])],
                removed_services: vec!["service1".into()],
                ..Default::default()
            }
        );
    }

    #[test]
    fn test_complex_update() {
        let mut graph = ConnectionGraph::new();
        graph.set_published_topic("topic1", ["publisher1"]);
        graph.set_subscribed_topic("topic1", ["subscriber1"]);
        graph.set_advertised_service("service1", ["provider1"]);

        let mut next = ConnectionGraph::new();
        next.set_published_topic("topic2", ["publisher2"]);
        next.set_subscribed_topic("topic2", ["subscriber2"]);
        next.set_advertised_service("service2", ["provider2"]);

        assert_eq!(
            graph.update(next),
            ConnectionGraphUpdate {
                published_topics: vec![PublishedTopic::new("topic2", ["publisher2"])],
                subscribed_topics: vec![SubscribedTopic::new("topic2", ["subscriber2"])],
                advertised_services: vec![AdvertisedService::new("service2", ["provider2"])],
                removed_topics: vec!["topic1".into()],
                removed_services: vec!["service1".into()],
            }
        );
    }
}