use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
/// Public, serializable snapshot of one registered edge node, as returned by
/// `EdgeRegistry::list`.
///
/// Field declaration order is also the serialized field order, so keep it stable.
#[derive(Clone, Debug, Serialize)]
pub struct EdgeNode {
    /// Caller-supplied identifier; also the key the registry stores this edge under.
    pub edge_id: String,
    /// Region label supplied at registration.
    pub region: String,
    /// Base URL supplied at registration.
    pub base_url: String,
    /// Timestamp string supplied by the caller when the edge (re-)registered.
    pub registered_at: String,
    /// Initialized to the registration timestamp string; `EdgeRegistry` does not
    /// currently refresh it afterwards.
    pub last_seen: String,
    /// Number of broadcasts the registry has credited to this edge.
    pub invalidations_sent: u64,
}
/// A cache-invalidation notice fanned out to every registered edge.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct InvalidationEvent {
    /// Version watermark carried by the event; presumably edges invalidate data at
    /// or below this version — TODO confirm against the consumer.
    pub up_to_version: u64,
    /// Names of the tables the event refers to (may be empty).
    pub tables: Vec<String>,
    /// Commit timestamp string supplied by the producer.
    pub committed_at: String,
}
/// Internal per-edge state held by the registry: the public metadata, the sending
/// half of the edge's invalidation channel, and a monotonic timestamp consulted by
/// `EdgeRegistry::prune_stale`.
struct EdgeSubscription {
    /// Metadata exposed through `EdgeRegistry::list`.
    node: EdgeNode,
    /// Sender paired with the receiver handed back from `EdgeRegistry::register`.
    sender: mpsc::Sender<InvalidationEvent>,
    /// Monotonic instant of the last registration or liveness refresh.
    last_seen_inst: Instant,
}
/// Registry of edge nodes that receive invalidation broadcasts.
///
/// Cloning is cheap: every clone shares the same map through an `Arc<RwLock<..>>`.
#[derive(Clone)]
pub struct EdgeRegistry {
    /// Subscriptions keyed by edge id.
    inner: Arc<RwLock<HashMap<String, EdgeSubscription>>>,
    /// Upper bound on the number of distinct edge ids held at once.
    max_edges: usize,
    /// Age after which `prune_stale` considers a subscription stale.
    liveness_window: Duration,
}
impl EdgeRegistry {
    /// Buffer size of each edge's invalidation channel.
    const CHANNEL_CAPACITY: usize = 64;

    /// Creates an empty registry.
    ///
    /// * `max_edges` — maximum number of distinct edge ids held at once; re-registering
    ///   an id that is already present does not count against the limit.
    /// * `liveness_window` — how long an edge may go without being registered or
    ///   receiving a broadcast before [`EdgeRegistry::prune_stale`] removes it.
    pub fn new(max_edges: usize, liveness_window: Duration) -> Self {
        Self {
            inner: Arc::new(RwLock::new(HashMap::new())),
            max_edges,
            liveness_window,
        }
    }

    /// Registers (or re-registers) an edge and returns the receiving half of its
    /// invalidation channel.
    ///
    /// Re-registering an existing `edge_id` replaces the previous subscription: its
    /// sender is dropped, so the previous receiver sees the channel close once it has
    /// drained any buffered events. Metadata, including `registered_at` and the
    /// `invalidations_sent` counter, is reset.
    ///
    /// # Errors
    ///
    /// Returns [`RegistryError::CapacityExceeded`] when `edge_id` is new and the
    /// registry already holds `max_edges` entries.
    pub fn register(
        &self,
        edge_id: &str,
        region: &str,
        base_url: &str,
        now_iso: &str,
    ) -> Result<mpsc::Receiver<InvalidationEvent>, RegistryError> {
        let mut g = self.inner.write();
        if !g.contains_key(edge_id) && g.len() >= self.max_edges {
            return Err(RegistryError::CapacityExceeded(self.max_edges));
        }
        let (tx, rx) = mpsc::channel(Self::CHANNEL_CAPACITY);
        let sub = EdgeSubscription {
            node: EdgeNode {
                edge_id: edge_id.to_string(),
                region: region.to_string(),
                base_url: base_url.to_string(),
                registered_at: now_iso.to_string(),
                last_seen: now_iso.to_string(),
                invalidations_sent: 0,
            },
            sender: tx,
            last_seen_inst: Instant::now(),
        };
        // Any replaced subscription is dropped here, closing its channel.
        g.insert(edge_id.to_string(), sub);
        Ok(rx)
    }

    /// Removes `edge_id`, returning whether it was registered. Dropping the stored
    /// sender closes the edge's channel.
    pub fn unregister(&self, edge_id: &str) -> bool {
        self.inner.write().remove(edge_id).is_some()
    }

    /// Sends `ev` to every registered edge and returns `(delivered, pruned)`.
    ///
    /// The subscriber list is snapshotted under the read lock, and the lock is released
    /// before any `.await`, so registrations may change while sends are in flight.
    /// Afterwards, under the write lock:
    /// * an edge whose receiver was dropped is removed, but only if the map still
    ///   holds that same channel, so an edge that re-registered mid-broadcast is kept;
    /// * an edge that received the event has its `invalidations_sent` counter bumped
    ///   and its liveness refreshed, again only for the channel that actually received
    ///   it, so entries registered after the snapshot are not credited.
    ///
    /// `pruned` counts the entries actually removed.
    ///
    /// NOTE(review): sends are awaited one at a time on bounded channels, so an edge
    /// whose channel is full (receiver not draining) stalls delivery to every edge
    /// after it. Consider `try_send` or a per-edge timeout if that is a concern.
    pub async fn broadcast(&self, ev: InvalidationEvent) -> (u32, u32) {
        let recipients: Vec<(String, mpsc::Sender<InvalidationEvent>)> = {
            let g = self.inner.read();
            g.iter()
                .map(|(id, sub)| (id.clone(), sub.sender.clone()))
                .collect()
        };
        if recipients.is_empty() {
            return (0, 0);
        }

        let mut delivered = Vec::with_capacity(recipients.len());
        let mut closed = Vec::new();
        for (id, tx) in recipients {
            // A bounded `send` only fails once the receiver has been dropped.
            if tx.send(ev.clone()).await.is_ok() {
                delivered.push((id, tx));
            } else {
                closed.push((id, tx));
            }
        }

        let now = Instant::now();
        let mut g = self.inner.write();
        let mut pruned = 0u32;
        for (id, tx) in &closed {
            // Compare channels, not just ids: the id may now belong to a newer,
            // still-live registration that must not be discarded.
            if matches!(g.get(id), Some(sub) if sub.sender.same_channel(tx)) {
                g.remove(id);
                pruned = pruned.saturating_add(1);
            }
        }
        for (id, tx) in &delivered {
            if let Some(sub) = g.get_mut(id).filter(|s| s.sender.same_channel(tx)) {
                sub.node.invalidations_sent = sub.node.invalidations_sent.saturating_add(1);
                sub.last_seen_inst = now;
            }
        }
        let sent = u32::try_from(delivered.len()).unwrap_or(u32::MAX);
        (sent, pruned)
    }

    /// Returns a snapshot of every registered edge's metadata, in unspecified order.
    ///
    /// `last_seen` holds the registration timestamp string; nothing in this registry
    /// updates it afterwards.
    pub fn list(&self) -> Vec<EdgeNode> {
        self.inner
            .read()
            .values()
            .map(|s| s.node.clone())
            .collect()
    }

    /// Number of currently registered edges.
    pub fn count(&self) -> usize {
        self.inner.read().len()
    }

    /// Removes every edge that has been neither registered nor refreshed by a
    /// broadcast within `liveness_window`, and returns how many were removed.
    pub fn prune_stale(&self) -> u32 {
        let window = self.liveness_window;
        let mut g = self.inner.write();
        let before = g.len();
        // `elapsed()` avoids computing `Instant::now() - window`, which panics when the
        // window reaches back past the monotonic clock's representable range.
        g.retain(|_, sub| sub.last_seen_inst.elapsed() <= window);
        u32::try_from(before - g.len()).unwrap_or(u32::MAX)
    }
}
/// Errors returned by registry operations.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RegistryError {
    /// The registry is full; carries the configured maximum number of edges.
    CapacityExceeded(usize),
}

impl std::fmt::Display for RegistryError {
    /// Renders a short human-readable message naming the configured limit.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::CapacityExceeded(limit) => write!(f, "edge registry full (max {})", limit),
        }
    }
}

/// Lets the error be boxed as `dyn std::error::Error` and propagated with `?`.
impl std::error::Error for RegistryError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an event with the given version and table names.
    fn event(up_to_version: u64, tables: &[&str]) -> InvalidationEvent {
        InvalidationEvent {
            up_to_version,
            tables: tables.iter().map(|&t| t.to_owned()).collect(),
            committed_at: "ts".into(),
        }
    }

    #[tokio::test]
    async fn register_returns_receiver_with_invalidations() {
        let r = EdgeRegistry::new(10, Duration::from_secs(60));
        let mut rx = r.register("edge-1", "us-east", "https://e1", "ts").unwrap();
        assert_eq!(r.count(), 1);
        let (sent, pruned) = r.broadcast(event(5, &["users"])).await;
        assert_eq!(sent, 1);
        assert_eq!(pruned, 0);
        let ev = rx.recv().await.expect("receive");
        assert_eq!(ev.up_to_version, 5);
        assert_eq!(ev.tables, vec!["users".to_string()]);
    }

    #[tokio::test]
    async fn broadcast_prunes_dropped_receivers() {
        let r = EdgeRegistry::new(10, Duration::from_secs(60));
        let _rx_keep = r.register("edge-keep", "us-east", "u", "ts").unwrap();
        {
            let _rx_drop = r.register("edge-drop", "us-west", "u", "ts").unwrap();
        }
        let (sent, pruned) = r.broadcast(event(1, &[])).await;
        assert_eq!(sent, 1);
        assert_eq!(pruned, 1);
        assert_eq!(r.count(), 1);
    }

    #[tokio::test]
    async fn broadcast_with_no_edges_is_noop() {
        let r = EdgeRegistry::new(10, Duration::from_secs(60));
        assert_eq!(r.broadcast(event(1, &[])).await, (0, 0));
        assert_eq!(r.count(), 0);
    }

    #[test]
    fn register_rejects_when_at_capacity() {
        let r = EdgeRegistry::new(2, Duration::from_secs(60));
        let _a = r.register("a", "us-east", "u", "ts").unwrap();
        let _b = r.register("b", "us-west", "u", "ts").unwrap();
        let err = r.register("c", "eu-west", "u", "ts").unwrap_err();
        assert!(matches!(err, RegistryError::CapacityExceeded(2)));
    }

    #[test]
    fn register_replaces_existing_id() {
        let r = EdgeRegistry::new(2, Duration::from_secs(60));
        let _a1 = r.register("a", "us-east", "u", "t1").unwrap();
        let _a2 = r.register("a", "eu-west", "u", "t2").unwrap();
        assert_eq!(r.count(), 1);
        let nodes = r.list();
        assert_eq!(nodes[0].region, "eu-west");
    }

    #[test]
    fn register_replaces_existing_id_even_when_full() {
        let r = EdgeRegistry::new(1, Duration::from_secs(60));
        let _a1 = r.register("a", "us-east", "u", "t1").unwrap();
        let _a2 = r
            .register("a", "eu-west", "u", "t2")
            .expect("replacing an existing id must not hit the cap");
        assert_eq!(r.count(), 1);
        assert_eq!(r.list()[0].registered_at, "t2");
    }

    #[tokio::test]
    async fn reregister_closes_previous_receiver() {
        let r = EdgeRegistry::new(10, Duration::from_secs(60));
        let mut old_rx = r.register("a", "us-east", "u", "t1").unwrap();
        let _new_rx = r.register("a", "us-east", "u", "t2").unwrap();
        assert!(old_rx.recv().await.is_none());
    }

    #[test]
    fn unregister_removes_edge() {
        let r = EdgeRegistry::new(10, Duration::from_secs(60));
        let _rx = r.register("edge-1", "us-east", "u", "ts").unwrap();
        assert!(r.unregister("edge-1"));
        assert_eq!(r.count(), 0);
        assert!(!r.unregister("edge-1"));
    }

    #[tokio::test]
    async fn unregister_closes_receiver() {
        let r = EdgeRegistry::new(10, Duration::from_secs(60));
        let mut rx = r.register("edge-1", "us-east", "u", "ts").unwrap();
        assert!(r.unregister("edge-1"));
        assert!(rx.recv().await.is_none());
    }

    #[test]
    fn list_returns_snapshot() {
        let r = EdgeRegistry::new(10, Duration::from_secs(60));
        let _a = r.register("a", "r1", "u1", "ts").unwrap();
        let _b = r.register("b", "r2", "u2", "ts").unwrap();
        let mut nodes = r.list();
        nodes.sort_by(|a, b| a.edge_id.cmp(&b.edge_id));
        assert_eq!(nodes.len(), 2);
        assert_eq!(nodes[0].edge_id, "a");
        assert_eq!(nodes[1].edge_id, "b");
    }

    #[tokio::test]
    async fn invalidations_sent_counter_increments() {
        let r = EdgeRegistry::new(10, Duration::from_secs(60));
        // Held (not read) so the channel stays open; 3 events fit in the buffer.
        let _rx = r.register("e1", "r", "u", "ts").unwrap();
        for _ in 0..3 {
            let _ = r.broadcast(event(1, &[])).await;
        }
        let n = r.list();
        assert_eq!(n[0].invalidations_sent, 3);
    }

    #[test]
    fn prune_stale_removes_old_entries() {
        let r = EdgeRegistry::new(10, Duration::from_millis(10));
        let _rx = r.register("old", "r", "u", "ts").unwrap();
        std::thread::sleep(Duration::from_millis(20));
        let pruned = r.prune_stale();
        assert_eq!(pruned, 1);
        assert_eq!(r.count(), 0);
    }

    #[test]
    fn prune_stale_keeps_fresh_entries() {
        let r = EdgeRegistry::new(10, Duration::from_secs(60));
        let _rx = r.register("fresh", "r", "u", "ts").unwrap();
        assert_eq!(r.prune_stale(), 0);
        assert_eq!(r.count(), 1);
    }
}