use ankurah_proto::{self as proto, Attested};
use tracing::warn;
use crate::{
error::SubscriptionError,
node::Node,
policy::PolicyAgent,
reactor::{ReactorSubscription, ReactorUpdate},
storage::StorageEngine,
};
use ankurah_signals::{Subscribe, SubscriptionGuard};
pub struct SubscriptionHandler {
_peer_id: proto::EntityId,
subscription: ReactorSubscription,
_guard: SubscriptionGuard,
}
impl SubscriptionHandler {
pub fn new<SE, PA>(peer_id: proto::EntityId, node: &Node<SE, PA>) -> Self
where
SE: StorageEngine + Send + Sync + 'static,
PA: PolicyAgent + Send + Sync + 'static,
{
let subscription = node.reactor.subscribe();
let weak_node = node.weak();
let guard = subscription.subscribe(move |update: ReactorUpdate| {
tracing::info!("SubscriptionHandler[{}] received reactor update with {} items", peer_id, update.items.len());
if let Some(node) = weak_node.upgrade() {
tracing::debug!("SubscriptionHandler[{}] sending update to peer {}", peer_id, peer_id);
node.send_update(
peer_id,
proto::NodeUpdateBody::SubscriptionUpdate {
items: update.items.into_iter().filter_map(|item| convert_item(&node, peer_id, item)).collect(),
},
);
}
});
Self { _peer_id: peer_id, subscription, _guard: guard }
}
pub fn subscription_id(&self) -> crate::reactor::ReactorSubscriptionId { self.subscription.id() }
pub fn subscription(&self) -> &ReactorSubscription { &self.subscription }
pub fn remove_predicate(&self, query_id: proto::QueryId) -> Result<(), SubscriptionError> {
self.subscription.remove_predicate(query_id)?;
Ok(())
}
pub async fn subscribe_query<SE, PA>(
&self,
node: &Node<SE, PA>,
query_id: proto::QueryId,
collection_id: proto::CollectionId,
mut selection: ankql::ast::Selection,
cdata: &PA::ContextData,
version: u32,
known_matches: Vec<proto::KnownEntity>,
) -> anyhow::Result<proto::NodeResponseBody>
where
SE: StorageEngine + Send + Sync + 'static,
PA: PolicyAgent + Send + Sync + 'static,
{
if version == 0 {
return Err(anyhow::anyhow!("Invalid version 0 for subscription"));
}
node.policy_agent.can_access_collection(cdata, &collection_id)?;
selection.predicate = node.policy_agent.filter_predicate(cdata, &collection_id, selection.predicate)?;
let storage_collection = node.collections.get(&collection_id).await?;
let matching_entities = node
.reactor
.upsert_query(self.subscription.id(), query_id, collection_id.clone(), selection.clone(), node, cdata, version)
.await?;
let initial_states: Vec<_> = matching_entities
.into_iter()
.filter_map(|e| {
let entity_state = e.to_entity_state().ok()?;
let attestation = node.policy_agent.attest_state(node, &entity_state);
Some(Attested::opt(entity_state, attestation))
})
.collect();
let expanded_states = crate::util::expand_states::expand_states(
initial_states,
known_matches.iter().map(|k| k.entity_id).collect::<Vec<_>>(),
&storage_collection,
)
.await?;
let known_map: std::collections::HashMap<_, _> = known_matches.into_iter().map(|k| (k.entity_id, k.head)).collect();
let mut deltas = Vec::with_capacity(expanded_states.len());
for state in expanded_states {
if node.policy_agent.check_read(cdata, &state.payload.entity_id, &collection_id, &state.payload.state).is_err() {
continue;
}
if let Some(delta) = node.generate_entity_delta(&known_map, state, &storage_collection).await? {
deltas.push(delta);
}
}
Ok(proto::NodeResponseBody::QuerySubscribed { query_id, deltas })
}
}
fn convert_item<SE, PA>(
node: &Node<SE, PA>,
peer_id: proto::EntityId,
item: crate::reactor::ReactorUpdateItem,
) -> Option<proto::SubscriptionUpdateItem>
where
SE: StorageEngine + Send + Sync + 'static,
PA: PolicyAgent + Send + Sync + 'static,
{
let entity_state = match item.entity.to_entity_state() {
Ok(entity_state) => entity_state,
Err(e) => {
warn!("Failed to convert entity {} to EntityState for peer {}: {}", item.entity.id(), peer_id, e);
return None;
}
};
let attestation = node.policy_agent.attest_state(node, &entity_state);
let attested_state = Attested::opt(entity_state, attestation);
let attested_events = item.events;
let content = proto::UpdateContent::StateAndEvent(attested_state.into(), attested_events.into_iter().map(|e| e.into()).collect());
let predicate_relevance = item
.predicate_relevance
.into_iter()
.map(|(pred_id, membership)| {
let proto_membership = match membership {
crate::reactor::MembershipChange::Initial => proto::MembershipChange::Initial,
crate::reactor::MembershipChange::Add => proto::MembershipChange::Add,
crate::reactor::MembershipChange::Remove => proto::MembershipChange::Remove,
};
(pred_id, proto_membership)
})
.collect();
Some(proto::SubscriptionUpdateItem {
entity_id: item.entity.id(),
collection: item.entity.collection().clone(),
content,
predicate_relevance,
})
}