use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use tokio::sync::broadcast;
use aa_proto::assembly::gateway::v1::invalidation_event::Payload;
use aa_proto::assembly::gateway::v1::{ApprovalResolved, Decision, InvalidationEvent, PolicyInvalidated};
use aa_runtime::approval::{ApprovalDecision, ApprovalResolvedNotifier};
/// Identifier of a subscribing assembly. Each id gets its own sequence
/// counter, replay ring and broadcast channel inside `InvalidationHub`.
pub type AssemblyId = String;
/// Maximum number of events retained per assembly for replay on reconnect.
/// Once exceeded, the oldest events are evicted first.
const REPLAY_RING_CAPACITY: usize = 1024;
/// Buffer capacity passed to each per-assembly `tokio::sync::broadcast` channel.
const SUBSCRIBER_CHANNEL_CAPACITY: usize = 1024;
/// Per-assembly delivery state, shared by every `subscribe` call made with the
/// same `AssemblyId`.
struct Subscriber {
/// Live fan-out channel; every `subscribe` call takes a fresh receiver from it.
tx: broadcast::Sender<InvalidationEvent>,
/// Next sequence number to assign. It starts at 1 and is independent per assembly.
next_seq: AtomicU64,
/// Recently broadcast events kept for replay, oldest at the front. The ring is
/// bounded by `REPLAY_RING_CAPACITY` and trimmed by `InvalidationHub::ack`.
ring: Mutex<VecDeque<InvalidationEvent>>,
}
/// Result of `InvalidationHub::subscribe`: buffered events to process first,
/// followed by a receiver for events broadcast after the subscription.
pub struct SubscriptionHandle {
/// Events still held in the replay ring whose `seq` is greater than the
/// caller's `last_seq_seen`, in ring (insertion) order.
pub replay: Vec<InvalidationEvent>,
/// Live receiver created during `subscribe`. It only yields events broadcast
/// after the handle was created.
pub receiver: broadcast::Receiver<InvalidationEvent>,
}
/// Fan-out hub for invalidation events, keyed by assembly id.
///
/// Holds one subscriber state per assembly id that has subscribed at least
/// once. No code visible here removes entries.
#[derive(Default)]
pub struct InvalidationHub {
/// Per-assembly delivery state. Broadcasts and acks take the read lock, so they
/// can run alongside each other; `subscribe` takes the write lock and is exclusive.
subscribers: RwLock<HashMap<AssemblyId, Arc<Subscriber>>>,
}
impl InvalidationHub {
    /// Creates an empty hub behind an [`Arc`] so it can be shared between the
    /// code that subscribes to it and the code that broadcasts through it.
    pub fn new() -> Arc<Self> {
        Arc::new(Self::default())
    }

    /// Subscribes `assembly_id` to invalidation events.
    ///
    /// Reuses the assembly's existing subscriber state, creating it on first use.
    /// Returns a handle with two parts:
    /// - `replay`: events still held in the replay ring with `seq > last_seq_seen`;
    /// - `receiver`: a live receiver for events broadcast after this call.
    ///
    /// The receiver is created and the ring is snapshotted while the map's write
    /// lock is held. `fan_out` holds the read lock for its whole loop, so no
    /// broadcast can land between those two steps. As a result, `replay` and
    /// `receiver` neither overlap nor leave a gap, except for events already
    /// evicted from the bounded ring (`REPLAY_RING_CAPACITY`).
    ///
    /// Increments the `aa_invalidation_replay_count` counter and sets the
    /// `aa_invalidation_subscribers` gauge. That gauge counts assembly entries;
    /// no code here removes entries, so it is not a count of live connections.
    ///
    /// # Panics
    ///
    /// Panics if the subscriber map lock or the replay ring lock is poisoned.
    pub fn subscribe(&self, assembly_id: impl Into<AssemblyId>, last_seq_seen: u64) -> SubscriptionHandle {
        let assembly_id = assembly_id.into();
        let mut subscribers = self
            .subscribers
            .write()
            .expect("invalidation subscribers lock poisoned");
        let subscriber = subscribers
            .entry(assembly_id)
            .or_insert_with(|| {
                // The initial receiver is dropped immediately; `send` simply
                // reports "no receivers" until someone subscribes.
                let (tx, _rx) = broadcast::channel(SUBSCRIBER_CHANNEL_CAPACITY);
                Arc::new(Subscriber {
                    tx,
                    next_seq: AtomicU64::new(1),
                    ring: Mutex::new(VecDeque::new()),
                })
            })
            .clone();
        // Receiver first, then the snapshot, both under the write lock, so that
        // every event is delivered exactly once across replay + live.
        let receiver = subscriber.tx.subscribe();
        let replay: Vec<InvalidationEvent> = {
            let ring = subscriber.ring.lock().expect("replay ring lock poisoned");
            ring.iter().filter(|event| event.seq > last_seq_seen).cloned().collect()
        };
        let subscriber_count = subscribers.len();
        drop(subscribers);
        if !replay.is_empty() {
            metrics::counter!("aa_invalidation_replay_count").increment(replay.len() as u64);
        }
        metrics::gauge!("aa_invalidation_subscribers").set(subscriber_count as f64);
        SubscriptionHandle { replay, receiver }
    }

    /// Broadcasts a `PolicyInvalidated` event for `agent_id` at `policy_version`
    /// to every assembly that currently has subscriber state.
    ///
    /// The event is not recorded for assemblies that subscribe later.
    pub fn broadcast_policy_invalidated(&self, agent_id: impl Into<String>, policy_version: u64) {
        self.fan_out(Payload::PolicyInvalidated(PolicyInvalidated {
            agent_id: agent_id.into(),
            policy_version,
        }));
    }

    /// Broadcasts an `ApprovalResolved` event for `request_id` with `decision`
    /// to every assembly that currently has subscriber state.
    pub fn broadcast_approval_resolved(&self, request_id: impl Into<String>, decision: Decision) {
        self.fan_out(Payload::ApprovalResolved(ApprovalResolved {
            request_id: request_id.into(),
            decision: decision as i32,
        }));
    }

    /// Stamps `payload` with each assembly's next sequence number, appends it to
    /// that assembly's replay ring and sends it on the live channel.
    ///
    /// # Panics
    ///
    /// Panics if the subscriber map lock or a replay ring lock is poisoned.
    fn fan_out(&self, payload: Payload) {
        let subscribers = self
            .subscribers
            .read()
            .expect("invalidation subscribers lock poisoned");
        for subscriber in subscribers.values() {
            {
                // Several broadcasts can run concurrently (they only hold the map
                // read lock). Allocating the seq, appending to the ring and
                // sending all happen under the ring lock. This keeps ring order,
                // channel order and seq order identical, so a client never sees
                // N+1 before N and `ack`'s front-trimming stays correct.
                let mut ring = subscriber.ring.lock().expect("replay ring lock poisoned");
                let seq = subscriber.next_seq.fetch_add(1, Ordering::Relaxed);
                let event = InvalidationEvent {
                    seq,
                    payload: Some(payload.clone()),
                };
                ring.push_back(event.clone());
                while ring.len() > REPLAY_RING_CAPACITY {
                    ring.pop_front();
                }
                // `send` only fails when no receiver is alive. The event is
                // still retained in the ring for replay, so the error is ignored.
                let _ = subscriber.tx.send(event);
            }
            metrics::counter!("aa_invalidation_events_broadcast").increment(1);
        }
    }

    /// Drops every event with `seq <= seq` from `assembly_id`'s replay ring.
    ///
    /// Does nothing for an unknown assembly. The ring is shared per assembly id,
    /// so an ack trims replay for every subscription under that id.
    ///
    /// # Panics
    ///
    /// Panics if the subscriber map lock or the replay ring lock is poisoned.
    pub fn ack(&self, assembly_id: &str, seq: u64) {
        let subscribers = self
            .subscribers
            .read()
            .expect("invalidation subscribers lock poisoned");
        if let Some(subscriber) = subscribers.get(assembly_id) {
            let mut ring = subscriber.ring.lock().expect("replay ring lock poisoned");
            // `fan_out` keeps the ring sorted by seq, so trimming from the front suffices.
            while ring.front().is_some_and(|event| event.seq <= seq) {
                ring.pop_front();
            }
        }
    }

    /// Number of assemblies with subscriber state. This includes assemblies
    /// whose handles have all been dropped.
    ///
    /// # Panics
    ///
    /// Panics if the subscriber map lock is poisoned.
    pub fn subscriber_count(&self) -> usize {
        self.subscribers
            .read()
            .expect("invalidation subscribers lock poisoned")
            .len()
    }
}
impl ApprovalResolvedNotifier for InvalidationHub {
    /// Publishes an `ApprovalResolved` event for approved or rejected requests.
    ///
    /// A rejection is reported on the wire as `Decision::Denied`. A timeout
    /// produces no event at all.
    fn notify_resolved(&self, request_id: &str, decision: &ApprovalDecision) {
        // NOTE(review): subscribers waiting on a timed-out approval receive
        // nothing from this hub — confirm they learn about timeouts elsewhere.
        let wire_decision = match decision {
            ApprovalDecision::Approved { .. } => Some(Decision::Approved),
            ApprovalDecision::Rejected { .. } => Some(Decision::Denied),
            ApprovalDecision::TimedOut { .. } => None,
        };
        if let Some(wire) = wire_decision {
            self.broadcast_approval_resolved(request_id, wire);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Upper bound on how long a live subscriber may wait for a broadcast.
    const DELIVERY_BUDGET: Duration = Duration::from_millis(100);

    /// Returns the `agent_id` of a `PolicyInvalidated` event; panics on any
    /// other payload kind.
    fn policy_agent(event: &InvalidationEvent) -> &str {
        let payload = event.payload.as_ref().expect("payload set");
        match payload {
            Payload::PolicyInvalidated(invalidated) => invalidated.agent_id.as_str(),
            Payload::ApprovalResolved(_) => panic!("expected PolicyInvalidated"),
        }
    }

    /// Awaits the next live event. Fails the test if it takes longer than
    /// `DELIVERY_BUDGET` or if the channel is closed.
    async fn recv_within_budget(handle: &mut SubscriptionHandle) -> InvalidationEvent {
        tokio::time::timeout(DELIVERY_BUDGET, handle.receiver.recv())
            .await
            .expect("event delivered within 100 ms")
            .expect("channel open")
    }

    /// Sequence numbers of a replay batch, in order, for compact assertions.
    fn seqs(events: &[InvalidationEvent]) -> Vec<u64> {
        events.iter().map(|event| event.seq).collect()
    }

    #[tokio::test]
    async fn broadcast_reaches_live_subscriber_within_100ms() {
        let hub = InvalidationHub::new();
        let mut handle = hub.subscribe("asm-1", 0);
        assert!(handle.replay.is_empty());

        let started = std::time::Instant::now();
        hub.broadcast_policy_invalidated("agent-x", 7);
        let event = recv_within_budget(&mut handle).await;

        assert!(started.elapsed() < DELIVERY_BUDGET);
        assert_eq!(event.seq, 1);
        assert_eq!(policy_agent(&event), "agent-x");
    }

    #[tokio::test]
    async fn reconnect_replays_only_events_after_last_seq() {
        let hub = InvalidationHub::new();
        // Register the assembly, then disconnect before anything is broadcast.
        drop(hub.subscribe("asm-1", 0));
        hub.broadcast_policy_invalidated("agent-a", 1);
        hub.broadcast_policy_invalidated("agent-b", 2);

        let full = hub.subscribe("asm-1", 0);
        assert_eq!(seqs(&full.replay), [1, 2]);

        let partial = hub.subscribe("asm-1", 1);
        assert_eq!(seqs(&partial.replay), [2]);
        assert_eq!(policy_agent(&partial.replay[0]), "agent-b");
    }

    #[tokio::test]
    async fn ack_trims_replay_ring() {
        let hub = InvalidationHub::new();
        let _handle = hub.subscribe("asm-1", 0);
        hub.broadcast_policy_invalidated("agent-a", 1);
        hub.broadcast_policy_invalidated("agent-b", 2);

        hub.ack("asm-1", 1);

        let reconnect = hub.subscribe("asm-1", 0);
        assert_eq!(seqs(&reconnect.replay), [2]);
    }

    #[test]
    fn each_subscriber_gets_independent_sequence() {
        let hub = InvalidationHub::new();
        let _a = hub.subscribe("asm-1", 0);
        let _b = hub.subscribe("asm-2", 0);
        assert_eq!(hub.subscriber_count(), 2);

        hub.broadcast_policy_invalidated("agent-a", 1);

        // Both assemblies numbered the same broadcast from their own counter.
        for assembly in ["asm-1", "asm-2"] {
            let reconnect = hub.subscribe(assembly, 0);
            assert_eq!(seqs(&reconnect.replay), [1]);
        }
    }

    #[tokio::test]
    async fn broadcast_approval_resolved_reaches_subscriber() {
        let hub = InvalidationHub::new();
        let mut handle = hub.subscribe("asm-1", 0);
        hub.broadcast_approval_resolved("req-42", Decision::Approved);

        let event = recv_within_budget(&mut handle).await;
        assert_eq!(event.seq, 1);

        let payload = event.payload.expect("payload set");
        match payload {
            Payload::ApprovalResolved(resolved) => {
                assert_eq!(resolved.request_id, "req-42");
                assert_eq!(resolved.decision(), Decision::Approved);
            }
            other => panic!("expected ApprovalResolved, got {other:?}"),
        }
    }
}