use std::{
borrow::Cow,
collections::{HashMap, HashSet},
sync::Arc,
};
use itertools::Itertools;
use petgraph::graph::NodeIndex;
use zenoh_protocol::{
core::{Region, ZenohIdProto},
network::declare::{
common::ext::WireExprType, ext, Declare, DeclareBody, DeclareSubscriber, SubscriberId,
UndeclareSubscriber,
},
};
use super::Hat;
use crate::net::{
protocol::network::Network,
routing::{
dispatcher::{
face::FaceState,
pubsub::SubscriberInfo,
resource::{NodeId, Resource},
tables::{Route, RoutingExpr, TablesData},
},
gateway::{Direction, RouteBuilder},
hat::{DispatcherContext, HatBaseTrait, HatPubSubTrait, Sources},
RoutingContext,
},
};
#[allow(unused_imports)]
use crate::zenoh_core::polyfill::*;
impl Hat {
    /// Re-declares router subscriptions to the children of source trees that changed.
    ///
    /// `new_children[sid]` lists the new children of the tree rooted at graph node index `sid`.
    /// For each tree that has at least one child and whose root is still in the graph, every
    /// resource the root router subscribes to is re-declared to those children, tagged with
    /// `sid` as the routing node id.
    pub(super) fn pubsub_tree_change(
        &self,
        tables: &mut TablesData,
        new_children: &[Vec<NodeIndex>],
    ) {
        let net = self.net();
        let sub_info = SubscriberInfo;
        for (tree_sid, tree_children) in new_children.iter().enumerate() {
            if tree_children.is_empty() {
                continue;
            }
            let tree_idx = NodeIndex::new(tree_sid);
            if !net.graph.contains_node(tree_idx) {
                continue;
            }
            let tree_id = net.graph[tree_idx].zid;
            for res in &self.router_subs {
                // `router_subs` is a set: a membership test is equivalent to (and cheaper
                // than) scanning every element for `tree_id`.
                if self.res_hat(res).router_subs.contains(&tree_id) {
                    self.send_sourced_subscriber_to_net_children(
                        tables,
                        tree_children,
                        res,
                        None,
                        &sub_info,
                        tree_sid as NodeId,
                    );
                }
            }
        }
    }

    /// Sends a `DeclareSubscriber` for `res` to the face of each router in `children`.
    ///
    /// Children that are no longer in the graph are skipped. The face matching `src_face` (if
    /// any) is skipped too, so a declaration is not echoed back to its sender. `node_id` is
    /// carried in the declaration's node-id extension. `_sub_info` is currently unused.
    fn send_sourced_subscriber_to_net_children(
        &self,
        tables: &TablesData,
        children: &[NodeIndex],
        res: &Arc<Resource>,
        src_face: Option<&Arc<FaceState>>,
        _sub_info: &SubscriberInfo,
        node_id: NodeId,
    ) {
        let net = self.net();
        for child in children {
            if !net.graph.contains_node(*child) {
                continue;
            }
            let zid = &net.graph[*child].zid;
            let Some(mut someface) = self.face(tables, zid).cloned() else {
                tracing::trace!("Unable to find face for zid {}", zid);
                continue;
            };
            let is_source = src_face
                .map(|src_face| someface.id == src_face.id)
                .unwrap_or(false);
            if is_source {
                continue;
            }
            let key_expr = Resource::decl_key(res, &mut someface);
            tracing::debug!(dst = %someface);
            someface.primitives.send_declare(RoutingContext::with_expr(
                &mut Declare {
                    interest_id: None,
                    ext_qos: ext::QoSType::DECLARE,
                    ext_tstamp: None,
                    ext_nodeid: ext::NodeIdType { node_id },
                    body: DeclareBody::DeclareSubscriber(DeclareSubscriber {
                        id: SubscriberId::default(),
                        wire_expr: key_expr,
                    }),
                },
                res.expr().to_string(),
            ));
        }
    }

    /// Declares `res` along the source tree rooted at router `source`.
    ///
    /// Nothing is sent if the tree for `source` has not been computed yet (logged at trace
    /// level) or if no graph index is known for `source` (logged at error level).
    fn propagate_sourced_subscriber(
        &self,
        tables: &TablesData,
        res: &Arc<Resource>,
        sub_info: &SubscriberInfo,
        src_face: Option<&Arc<FaceState>>,
        source: &ZenohIdProto,
    ) {
        let net = self.net();
        match net.get_idx(source) {
            Some(tree_sid) => {
                if net.trees.len() > tree_sid.index() {
                    self.send_sourced_subscriber_to_net_children(
                        tables,
                        &net.trees[tree_sid.index()].children,
                        res,
                        src_face,
                        sub_info,
                        tree_sid.index() as NodeId,
                    );
                } else {
                    // Arguments follow the format string: node zid first, then its index.
                    tracing::trace!(
                        "Propagating sub {}: tree for node {} sid:{} not yet ready",
                        res.expr(),
                        source,
                        tree_sid.index()
                    );
                }
            }
            None => tracing::error!(
                "Error propagating sub {}: cannot get index of {}!",
                res.expr(),
                source
            ),
        }
    }

    /// Sends an `UndeclareSubscriber` for `res` to the face of each router in `children`.
    ///
    /// Children that are no longer in `net`'s graph, and the face matching `src_face` (if
    /// any), are skipped. `routing_context` is carried in the node-id extension; `None` is
    /// sent as node id `0`.
    #[inline]
    fn send_forget_sourced_subscriber_to_net_children(
        &self,
        tables: &TablesData,
        net: &Network,
        children: &[NodeIndex],
        res: &Arc<Resource>,
        src_face: Option<&Arc<FaceState>>,
        routing_context: Option<NodeId>,
    ) {
        for child in children {
            if !net.graph.contains_node(*child) {
                continue;
            }
            let zid = &net.graph[*child].zid;
            let Some(mut someface) = self.face(tables, zid).cloned() else {
                tracing::trace!("Unable to find face for zid {}", zid);
                continue;
            };
            let is_source = src_face
                .map(|src_face| someface.id == src_face.id)
                .unwrap_or(false);
            if is_source {
                continue;
            }
            let wire_expr = Resource::decl_key(res, &mut someface);
            tracing::debug!(dst = %someface);
            someface.primitives.send_declare(RoutingContext::with_expr(
                &mut Declare {
                    interest_id: None,
                    ext_qos: ext::QoSType::DECLARE,
                    ext_tstamp: None,
                    ext_nodeid: ext::NodeIdType {
                        node_id: routing_context.unwrap_or(0),
                    },
                    body: DeclareBody::UndeclareSubscriber(UndeclareSubscriber {
                        id: SubscriberId::default(),
                        ext_wire_expr: WireExprType { wire_expr },
                    }),
                },
                res.expr().to_string(),
            ));
        }
    }

    /// Undeclares `res` along the source tree rooted at router `source`.
    ///
    /// Nothing is sent if the tree for `source` has not been computed yet (logged at trace
    /// level) or if no graph index is known for `source` (logged at error level).
    fn propagate_forget_sourced_subscriber(
        &self,
        tables: &TablesData,
        res: &Arc<Resource>,
        src_face: Option<&Arc<FaceState>>,
        source: &ZenohIdProto,
    ) {
        // Same accessor as the rest of this impl, rather than unwrapping `routers_net` inline.
        let net = self.net();
        match net.get_idx(source) {
            Some(tree_sid) => {
                if net.trees.len() > tree_sid.index() {
                    self.send_forget_sourced_subscriber_to_net_children(
                        tables,
                        net,
                        &net.trees[tree_sid.index()].children,
                        res,
                        src_face,
                        Some(tree_sid.index() as NodeId),
                    );
                } else {
                    // Arguments follow the format string: node zid first, then its index.
                    tracing::trace!(
                        "Propagating forget sub {}: tree for node {} sid:{} not yet ready",
                        res.expr(),
                        source,
                        tree_sid.index()
                    );
                }
            }
            None => tracing::error!(
                "Error propagating forget sub {}: cannot get index of {}!",
                res.expr(),
                source
            ),
        }
    }

    /// Removes the network link to router `zid` and drops all router subscriptions held by
    /// the nodes that `remove_link` reports as removed.
    ///
    /// Returns the resources left with no router subscriber at all. Those resources are also
    /// removed from `self.router_subs`. This method itself sends no undeclarations.
    pub(super) fn unregister_node_subscribers(
        &mut self,
        zid: &ZenohIdProto,
    ) -> HashSet<Arc<Resource>> {
        let removed_routers = self
            .net_mut()
            .remove_link(zid)
            .into_iter()
            .map(|(_, node)| node)
            .collect::<HashSet<_>>();
        let mut resources = HashSet::new();
        // Snapshot the set: it is mutated while iterating.
        for mut res in self.router_subs.iter().cloned().collect_vec() {
            self.res_hat_mut(&mut res)
                .router_subs
                .retain(|router| !removed_routers.contains(router));
            if self.res_hat(&res).router_subs.is_empty() {
                self.router_subs.retain(|r| !Arc::ptr_eq(r, &res));
                resources.insert(res);
            }
        }
        resources
    }
}
impl HatPubSubTrait for Hat {
    /// Maps each resource with router subscriptions to the routers (other than this one,
    /// `tables.zid`) subscribing to it. Peer and client sources are always empty here.
    #[tracing::instrument(level = "debug", skip(tables), ret)]
    fn sourced_subscribers(&self, tables: &TablesData) -> HashMap<Arc<Resource>, Sources> {
        self.router_subs
            .iter()
            .map(|sub| {
                let routers = self
                    .res_hat(sub)
                    .router_subs
                    .iter()
                    .copied()
                    .filter(|router| router != &tables.zid)
                    .collect();
                (
                    sub.clone(),
                    Sources {
                        routers,
                        peers: Vec::default(),
                        clients: Vec::default(),
                    },
                )
            })
            .collect()
    }

    /// This hat tracks no sourced publishers; always returns an empty map.
    #[tracing::instrument(level = "debug", skip(_tables), ret)]
    fn sourced_publishers(&self, _tables: &TablesData) -> HashMap<Arc<Resource>, Sources> {
        HashMap::default()
    }

    /// Computes the set of faces data for `expr` must be forwarded to.
    ///
    /// For every resource matching `expr`, each subscribing router is reached through the
    /// next hop given by the source tree. If `src_region` is this hat's region, the tree is
    /// the one identified by `node_id`; otherwise it is the tree rooted at this router.
    /// An expression without a key expression yields an empty route.
    #[tracing::instrument(level = "debug", skip(tables, src_region), ret)]
    fn compute_data_route(
        &self,
        tables: &TablesData,
        src_region: &Region,
        expr: &RoutingExpr,
        node_id: NodeId,
    ) -> Arc<Route> {
        /// Adds to `route` the next-hop face, in tree `source`, towards each router in `subs`.
        #[inline]
        fn insert_faces_for_subs(
            this: &Hat,
            route: &mut RouteBuilder<Direction>,
            expr: &RoutingExpr,
            tables: &TablesData,
            net: &Network,
            source: NodeId,
            subs: &HashSet<ZenohIdProto>,
        ) {
            let Some(tree) = net.trees.get(source as usize) else {
                tracing::trace!("Tree for node sid:{} not yet ready", source);
                return;
            };
            for sub in subs {
                let Some(sub_idx) = net.get_idx(sub) else {
                    continue;
                };
                // No entry or no direction means the subscriber is not reachable in this tree.
                let Some(&Some(direction)) = tree.directions.get(sub_idx.index()) else {
                    continue;
                };
                if !net.graph.contains_node(direction) {
                    continue;
                }
                if let Some(face) = this.face(tables, &net.graph[direction].zid) {
                    tracing::debug!(dst = %face, dst.has_subscriber = true);
                    route.insert(face.id, || {
                        let wire_expr = expr.get_best_key(face.id);
                        Direction {
                            dst_face: face.clone(),
                            wire_expr: wire_expr.to_owned(),
                            node_id: source,
                        }
                    });
                }
            }
        }

        let mut route = RouteBuilder::<Direction>::new();
        let Some(key_expr) = expr.key_expr() else {
            return Arc::new(route.build());
        };
        // Reuse the cached match list when the expression resolved to a resource with context.
        let matches = expr
            .resource()
            .as_ref()
            .and_then(|res| res.ctx.as_ref())
            .map(|ctx| Cow::from(&ctx.matches))
            .unwrap_or_else(|| Cow::from(Resource::get_matches(tables, key_expr)));

        // The network and the source tree do not depend on the matched resource: resolve once.
        let net = self.net();
        let router_source = if *src_region == self.region() {
            node_id
        } else {
            net.idx.index() as NodeId
        };
        for mres in matches.iter() {
            let mres = mres.upgrade().unwrap();
            insert_faces_for_subs(
                self,
                &mut route,
                expr,
                tables,
                net,
                router_source,
                &self.res_hat(&mres).router_subs,
            );
        }
        Arc::new(route.build())
    }

    /// Records a subscription on `res` from the router behind `ctx.src_face`/`node_id` and
    /// declares it along that router's source tree (excluding the source face).
    ///
    /// Logs an error and does nothing if the originating router cannot be determined.
    #[tracing::instrument(level = "debug", skip(ctx, _id, node_id, info), ret)]
    fn register_subscriber(
        &mut self,
        ctx: DispatcherContext,
        _id: SubscriberId,
        mut res: Arc<Resource>,
        node_id: NodeId,
        info: &SubscriberInfo,
    ) {
        debug_assert!(self.owns(ctx.src_face));
        let Some(router) = self.get_router(ctx.src_face, node_id) else {
            tracing::error!(%node_id, "Subscriber from unknown router");
            return;
        };
        debug_assert_ne!(router, ctx.tables.zid);
        self.res_hat_mut(&mut res).router_subs.insert(router);
        self.router_subs.insert(res.clone());
        self.propagate_sourced_subscriber(ctx.tables, &res, info, Some(ctx.src_face), &router);
    }

    /// Removes the subscription on `res` held by the router behind `ctx.src_face`/`node_id`
    /// and undeclares it along that router's source tree (excluding the source face).
    ///
    /// Returns the resource, or `None` (after logging an error) if the originating router is
    /// unknown or no resource was supplied.
    #[tracing::instrument(level = "debug", skip(ctx, _id, node_id), ret)]
    fn unregister_subscriber(
        &mut self,
        ctx: DispatcherContext,
        _id: SubscriberId,
        res: Option<Arc<Resource>>,
        node_id: NodeId,
    ) -> Option<Arc<Resource>> {
        debug_assert!(self.owns(ctx.src_face));
        let Some(router) = self.get_router(ctx.src_face, node_id) else {
            tracing::error!(%node_id, "Subscriber from unknown router");
            return None;
        };
        debug_assert_ne!(router, ctx.tables.zid);
        let Some(mut res) = res else {
            tracing::error!("Subscriber undeclaration in router region with no resource");
            return None;
        };
        self.res_hat_mut(&mut res).router_subs.remove(&router);
        if self.res_hat(&res).router_subs.is_empty() {
            self.router_subs.retain(|r| !Arc::ptr_eq(r, &res));
        }
        self.propagate_forget_sourced_subscriber(ctx.tables, &res, Some(ctx.src_face), &router);
        Some(res)
    }

    /// Drops all router subscriptions tied to the router of `ctx.src_face` (see
    /// `unregister_node_subscribers`) and returns the resources left without subscribers.
    #[tracing::instrument(level = "debug", skip(ctx), ret)]
    fn unregister_face_subscribers(&mut self, ctx: DispatcherContext) -> HashSet<Arc<Resource>> {
        self.unregister_node_subscribers(&ctx.src_face.zid)
    }

    /// Declares `res` on behalf of this router (`ctx.tables.zid`) if another hat reported
    /// subscribers (`other_info` is `Some`) and it is not already declared.
    #[tracing::instrument(level = "debug", skip(ctx), ret)]
    fn propagate_subscriber(
        &mut self,
        ctx: DispatcherContext,
        mut res: Arc<Resource>,
        other_info: Option<SubscriberInfo>,
    ) {
        let Some(other_info) = other_info else {
            debug_assert!(self.owns(ctx.src_face));
            return;
        };
        if !self.res_hat(&res).router_subs.contains(&ctx.tables.zid) {
            self.res_hat_mut(&mut res)
                .router_subs
                .insert(ctx.tables.zid);
            self.router_subs.insert(res.clone());
            self.propagate_sourced_subscriber(ctx.tables, &res, &other_info, None, &ctx.tables.zid);
        }
    }

    /// Withdraws this router's own declaration of `res`, unless the source face belongs to
    /// this hat.
    #[tracing::instrument(level = "debug", skip(ctx), ret)]
    fn unpropagate_subscriber(&mut self, ctx: DispatcherContext, mut res: Arc<Resource>) {
        if self.owns(ctx.src_face) {
            return;
        }
        let was_propagated = self.res_hat(&res).router_subs.contains(&ctx.tables.zid);
        debug_assert!(was_propagated);
        if was_propagated {
            self.res_hat_mut(&mut res)
                .router_subs
                .remove(&ctx.tables.zid);
            if self.res_hat(&res).router_subs.is_empty() {
                self.router_subs.retain(|r| !Arc::ptr_eq(r, &res));
            }
            self.propagate_forget_sourced_subscriber(ctx.tables, &res, None, &ctx.tables.zid);
        }
    }

    /// Returns `Some(SubscriberInfo)` if any router other than this one subscribes to `res`.
    #[tracing::instrument(level = "trace", ret)]
    fn remote_subscribers_of(&self, tables: &TablesData, res: &Resource) -> Option<SubscriberInfo> {
        self.res_hat(res)
            .router_subs
            .iter()
            .any(|router| router != &tables.zid)
            .then_some(SubscriberInfo)
    }

    /// Returns every resource subscribed by at least one router other than this one,
    /// restricted to those matching `res` when `res` is `Some`.
    // `Option::is_none_or` is newer than the crate's MSRV; presumably older toolchains get it
    // from the `zenoh_core::polyfill` glob import (TODO confirm).
    #[allow(clippy::incompatible_msrv)]
    #[tracing::instrument(level = "trace", skip(tables), ret)]
    fn remote_subscribers_matching(
        &self,
        tables: &TablesData,
        res: Option<&Resource>,
    ) -> HashMap<Arc<Resource>, SubscriberInfo> {
        self.router_subs
            .iter()
            .filter_map(|sub| {
                let has_remote = self
                    .res_hat(sub)
                    .router_subs
                    .iter()
                    .any(|router| router != &tables.zid);
                (has_remote && res.is_none_or(|res| res.matches(sub)))
                    .then(|| (sub.clone(), SubscriberInfo))
            })
            .collect()
    }
}