1use parking_lot::RwLock;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::sync::Arc;
21use std::time::{Duration, Instant};
22use tokio::sync::mpsc;
23
24#[derive(Debug, Clone, Serialize)]
26pub struct EdgeNode {
27 pub edge_id: String,
28 pub region: String,
29 pub base_url: String,
31 pub registered_at: String,
33 pub last_seen: String,
34 pub invalidations_sent: u64,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct InvalidationEvent {
42 pub up_to_version: u64,
44 pub tables: Vec<String>,
47 pub committed_at: String,
50}
51
52struct EdgeSubscription {
55 node: EdgeNode,
56 sender: mpsc::Sender<InvalidationEvent>,
57 last_seen_inst: Instant,
60}
61
62#[derive(Clone)]
64pub struct EdgeRegistry {
65 inner: Arc<RwLock<HashMap<String, EdgeSubscription>>>,
66 max_edges: usize,
67 liveness_window: Duration,
70}
71
72impl EdgeRegistry {
73 pub fn new(max_edges: usize, liveness_window: Duration) -> Self {
74 Self {
75 inner: Arc::new(RwLock::new(HashMap::new())),
76 max_edges,
77 liveness_window,
78 }
79 }
80
81 pub fn register(
90 &self,
91 edge_id: &str,
92 region: &str,
93 base_url: &str,
94 now_iso: &str,
95 ) -> Result<mpsc::Receiver<InvalidationEvent>, RegistryError> {
96 let mut g = self.inner.write();
97 if !g.contains_key(edge_id) && g.len() >= self.max_edges {
98 return Err(RegistryError::CapacityExceeded(self.max_edges));
99 }
100 let (tx, rx) = mpsc::channel(64);
101 let sub = EdgeSubscription {
102 node: EdgeNode {
103 edge_id: edge_id.to_string(),
104 region: region.to_string(),
105 base_url: base_url.to_string(),
106 registered_at: now_iso.to_string(),
107 last_seen: now_iso.to_string(),
108 invalidations_sent: 0,
109 },
110 sender: tx,
111 last_seen_inst: Instant::now(),
112 };
113 g.insert(edge_id.to_string(), sub);
114 Ok(rx)
115 }
116
117 pub fn unregister(&self, edge_id: &str) -> bool {
120 self.inner.write().remove(edge_id).is_some()
121 }
122
123 pub async fn broadcast(&self, ev: InvalidationEvent) -> (u32, u32) {
127 let recipients: Vec<(String, mpsc::Sender<InvalidationEvent>)> = {
130 let g = self.inner.read();
131 g.iter()
132 .map(|(id, sub)| (id.clone(), sub.sender.clone()))
133 .collect()
134 };
135
136 let mut sent = 0u32;
137 let mut dead: Vec<String> = Vec::new();
138 for (id, tx) in recipients {
139 match tx.send(ev.clone()).await {
140 Ok(()) => {
141 sent += 1;
142 }
143 Err(_) => {
144 dead.push(id);
145 }
146 }
147 }
148
149 let mut g = self.inner.write();
152 for id in &dead {
153 g.remove(id);
154 }
155 for sub in g.values_mut() {
156 sub.node.invalidations_sent =
157 sub.node.invalidations_sent.saturating_add(1);
158 sub.last_seen_inst = Instant::now();
159 }
160 (sent, dead.len() as u32)
161 }
162
163 pub fn list(&self) -> Vec<EdgeNode> {
166 self.inner
167 .read()
168 .values()
169 .map(|s| s.node.clone())
170 .collect()
171 }
172
173 pub fn count(&self) -> usize {
174 self.inner.read().len()
175 }
176
177 pub fn prune_stale(&self) -> u32 {
180 let cutoff = Instant::now() - self.liveness_window;
181 let mut g = self.inner.write();
182 let dead: Vec<String> = g
183 .iter()
184 .filter(|(_, s)| s.last_seen_inst < cutoff)
185 .map(|(id, _)| id.clone())
186 .collect();
187 for id in &dead {
188 g.remove(id);
189 }
190 dead.len() as u32
191 }
192}
193
/// Failure modes of [`EdgeRegistry::register`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RegistryError {
    /// A new edge id was rejected because the registry already holds the
    /// contained maximum number of edges.
    CapacityExceeded(usize),
}

impl std::fmt::Display for RegistryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::CapacityExceeded(max) => write!(f, "edge registry full (max {max})"),
        }
    }
}

impl std::error::Error for RegistryError {}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an event with a fixed `committed_at`, so each test spells out
    /// only the fields it cares about.
    fn event(up_to_version: u64, tables: &[&str]) -> InvalidationEvent {
        InvalidationEvent {
            up_to_version,
            tables: tables.iter().map(|&t| t.to_owned()).collect(),
            committed_at: "ts".into(),
        }
    }

    #[tokio::test]
    async fn register_returns_receiver_with_invalidations() {
        let r = EdgeRegistry::new(10, Duration::from_secs(60));
        let mut rx = r.register("edge-1", "us-east", "https://e1", "ts").unwrap();
        assert_eq!(r.count(), 1);
        let (sent, pruned) = r.broadcast(event(5, &["users"])).await;
        assert_eq!(sent, 1);
        assert_eq!(pruned, 0);
        let ev = rx.recv().await.expect("receive");
        assert_eq!(ev.up_to_version, 5);
        assert_eq!(ev.tables, vec!["users".to_string()]);
    }

    #[tokio::test]
    async fn broadcast_prunes_dropped_receivers() {
        let r = EdgeRegistry::new(10, Duration::from_secs(60));
        let _rx_keep = r.register("edge-keep", "us-east", "u", "ts").unwrap();
        {
            // The receiver is dropped at the end of this scope, so the next
            // send to "edge-drop" fails.
            let _rx_drop = r.register("edge-drop", "us-west", "u", "ts").unwrap();
        }
        let (sent, pruned) = r.broadcast(event(1, &[])).await;
        assert_eq!(sent, 1);
        assert_eq!(pruned, 1);
        assert_eq!(r.count(), 1);
        assert_eq!(r.list()[0].edge_id, "edge-keep");
    }

    #[test]
    fn register_rejects_when_at_capacity() {
        let r = EdgeRegistry::new(2, Duration::from_secs(60));
        let _a = r.register("a", "us-east", "u", "ts").unwrap();
        let _b = r.register("b", "us-west", "u", "ts").unwrap();
        let err = r.register("c", "eu-west", "u", "ts").unwrap_err();
        assert!(matches!(err, RegistryError::CapacityExceeded(2)));
    }

    #[test]
    fn register_replaces_existing_id() {
        let r = EdgeRegistry::new(2, Duration::from_secs(60));
        let _a1 = r.register("a", "us-east", "u", "t1").unwrap();
        let _a2 = r.register("a", "eu-west", "u", "t2").unwrap();
        assert_eq!(r.count(), 1);
        let nodes = r.list();
        assert_eq!(nodes[0].region, "eu-west");
        assert_eq!(nodes[0].registered_at, "t2");
        assert_eq!(nodes[0].invalidations_sent, 0);
    }

    #[test]
    fn unregister_removes_edge() {
        let r = EdgeRegistry::new(10, Duration::from_secs(60));
        let _rx = r.register("edge-1", "us-east", "u", "ts").unwrap();
        assert!(r.unregister("edge-1"));
        assert_eq!(r.count(), 0);
        assert!(!r.unregister("edge-1"));
    }

    #[test]
    fn list_returns_snapshot() {
        let r = EdgeRegistry::new(10, Duration::from_secs(60));
        let _a = r.register("a", "r1", "u1", "ts").unwrap();
        let _b = r.register("b", "r2", "u2", "ts").unwrap();
        let mut nodes = r.list();
        nodes.sort_by(|a, b| a.edge_id.cmp(&b.edge_id));
        assert_eq!(nodes.len(), 2);
        assert_eq!(nodes[0].edge_id, "a");
        assert_eq!(nodes[1].edge_id, "b");
    }

    #[tokio::test]
    async fn invalidations_sent_counter_increments() {
        let r = EdgeRegistry::new(10, Duration::from_secs(60));
        // Kept alive but never drained; 3 events fit well within the channel.
        let _rx = r.register("e1", "r", "u", "ts").unwrap();
        for _ in 0..3 {
            assert_eq!(r.broadcast(event(1, &[])).await, (1, 0));
        }
        let n = r.list();
        assert_eq!(n[0].invalidations_sent, 3);
    }

    #[test]
    fn prune_stale_removes_old_entries() {
        let r = EdgeRegistry::new(10, Duration::from_millis(10));
        let _rx = r.register("old", "r", "u", "ts").unwrap();
        std::thread::sleep(Duration::from_millis(20));
        let pruned = r.prune_stale();
        assert_eq!(pruned, 1);
        assert_eq!(r.count(), 0);
    }

    #[test]
    fn prune_stale_keeps_fresh_entries() {
        let r = EdgeRegistry::new(10, Duration::from_secs(60));
        let _rx = r.register("fresh", "r", "u", "ts").unwrap();
        assert_eq!(r.prune_stale(), 0);
        assert_eq!(r.count(), 1);
    }
}
323}