1use parking_lot::RwLock;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::sync::Arc;
21use std::time::{Duration, Instant};
22use tokio::sync::mpsc;
23
/// Public, serializable description of one registered edge node.
///
/// Snapshots of these are returned by `EdgeRegistry::list`; the registry keeps the
/// live copy inside its private per-edge subscription entry.
#[derive(Debug, Clone, Serialize)]
pub struct EdgeNode {
    /// Caller-supplied identifier; also the key the registry stores this edge under.
    pub edge_id: String,
    /// Region label supplied at registration (e.g. "us-east" in the tests); not validated.
    pub region: String,
    /// Base URL supplied at registration; stored verbatim and not validated here.
    pub base_url: String,
    /// The `now_iso` string passed to `EdgeRegistry::register`.
    /// Presumably an ISO-8601 timestamp (per the parameter name) — not parsed or checked.
    pub registered_at: String,
    /// Set to the same `now_iso` string at registration.
    /// NOTE(review): nothing in this module updates it afterwards (`broadcast` refreshes
    /// only the internal `Instant`); confirm whether callers expect it to stay current.
    pub last_seen: String,
    /// Number of invalidation broadcasts credited to this edge.
    /// Incremented (saturating) by `EdgeRegistry::broadcast`; reset to 0 on (re-)registration.
    pub invalidations_sent: u64,
}
37
/// An invalidation notice fanned out to every registered edge by
/// `EdgeRegistry::broadcast`; each edge receives its own clone over its channel.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvalidationEvent {
    /// Version bound carried by the event.
    /// NOTE(review): whether the bound is inclusive, and what sequence it refers to,
    /// is not established in this module — confirm with the producer.
    pub up_to_version: u64,
    /// Names of the tables this event concerns. May be empty (the tests send `vec![]`);
    /// what an empty list means is left to consumers — TODO confirm.
    pub tables: Vec<String>,
    /// Commit timestamp as a string; its format is not enforced here
    /// (the tests use a placeholder value).
    pub committed_at: String,
}
51
/// Registry-internal state for one edge: its public metadata plus the parts that are
/// not included in the serializable `EdgeNode`.
struct EdgeSubscription {
    /// Public metadata, cloned out by `EdgeRegistry::list`.
    node: EdgeNode,
    /// Sending half of the bounded channel whose receiver was returned from
    /// `EdgeRegistry::register`. A send error means that receiver has been closed/dropped.
    sender: mpsc::Sender<InvalidationEvent>,
    /// Monotonic timestamp compared against the liveness window by
    /// `EdgeRegistry::prune_stale`; set at registration and refreshed by `broadcast`.
    last_seen_inst: Instant,
}
61
/// Registry of edge nodes subscribed to invalidation events.
///
/// The map is guarded by a `parking_lot::RwLock` behind an `Arc`, so cloning an
/// `EdgeRegistry` yields another handle to the same shared state.
#[derive(Clone)]
pub struct EdgeRegistry {
    /// Subscriptions keyed by edge id.
    inner: Arc<RwLock<HashMap<String, EdgeSubscription>>>,
    /// Maximum number of distinct edge ids `register` will accept.
    max_edges: usize,
    /// How long a subscription may go unrefreshed before `prune_stale` removes it.
    liveness_window: Duration,
}
71
72impl EdgeRegistry {
73 pub fn new(max_edges: usize, liveness_window: Duration) -> Self {
74 Self {
75 inner: Arc::new(RwLock::new(HashMap::new())),
76 max_edges,
77 liveness_window,
78 }
79 }
80
81 pub fn register(
90 &self,
91 edge_id: &str,
92 region: &str,
93 base_url: &str,
94 now_iso: &str,
95 ) -> Result<mpsc::Receiver<InvalidationEvent>, RegistryError> {
96 let mut g = self.inner.write();
97 if !g.contains_key(edge_id) && g.len() >= self.max_edges {
98 return Err(RegistryError::CapacityExceeded(self.max_edges));
99 }
100 let (tx, rx) = mpsc::channel(64);
101 let sub = EdgeSubscription {
102 node: EdgeNode {
103 edge_id: edge_id.to_string(),
104 region: region.to_string(),
105 base_url: base_url.to_string(),
106 registered_at: now_iso.to_string(),
107 last_seen: now_iso.to_string(),
108 invalidations_sent: 0,
109 },
110 sender: tx,
111 last_seen_inst: Instant::now(),
112 };
113 g.insert(edge_id.to_string(), sub);
114 Ok(rx)
115 }
116
117 pub fn unregister(&self, edge_id: &str) -> bool {
120 self.inner.write().remove(edge_id).is_some()
121 }
122
123 pub async fn broadcast(&self, ev: InvalidationEvent) -> (u32, u32) {
127 let recipients: Vec<(String, mpsc::Sender<InvalidationEvent>)> = {
130 let g = self.inner.read();
131 g.iter()
132 .map(|(id, sub)| (id.clone(), sub.sender.clone()))
133 .collect()
134 };
135
136 let mut sent = 0u32;
137 let mut dead: Vec<String> = Vec::new();
138 for (id, tx) in recipients {
139 match tx.send(ev.clone()).await {
140 Ok(()) => {
141 sent += 1;
142 }
143 Err(_) => {
144 dead.push(id);
145 }
146 }
147 }
148
149 let mut g = self.inner.write();
152 for id in &dead {
153 g.remove(id);
154 }
155 for sub in g.values_mut() {
156 sub.node.invalidations_sent = sub.node.invalidations_sent.saturating_add(1);
157 sub.last_seen_inst = Instant::now();
158 }
159 (sent, dead.len() as u32)
160 }
161
162 pub fn list(&self) -> Vec<EdgeNode> {
165 self.inner.read().values().map(|s| s.node.clone()).collect()
166 }
167
168 pub fn count(&self) -> usize {
169 self.inner.read().len()
170 }
171
172 pub fn prune_stale(&self) -> u32 {
175 let cutoff = Instant::now() - self.liveness_window;
176 let mut g = self.inner.write();
177 let dead: Vec<String> = g
178 .iter()
179 .filter(|(_, s)| s.last_seen_inst < cutoff)
180 .map(|(id, _)| id.clone())
181 .collect();
182 for id in &dead {
183 g.remove(id);
184 }
185 dead.len() as u32
186 }
187}
188
/// Errors reported by edge-registry operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RegistryError {
    /// A new edge id was rejected because the registry already holds the configured
    /// maximum number of edges; the payload is that maximum.
    CapacityExceeded(usize),
}

impl std::fmt::Display for RegistryError {
    /// Renders a short, human-readable description of the error.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::CapacityExceeded(limit) => write!(out, "edge registry full (max {})", limit),
        }
    }
}

/// Lets `RegistryError` be used as a `dyn std::error::Error` (e.g. boxed or with `?`).
impl std::error::Error for RegistryError {}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an event with the given version and tables and a placeholder timestamp.
    fn event(up_to_version: u64, tables: &[&str]) -> InvalidationEvent {
        InvalidationEvent {
            up_to_version,
            tables: tables.iter().map(|t| t.to_string()).collect(),
            committed_at: "ts".into(),
        }
    }

    /// A registry with room for ten edges and a one-minute liveness window.
    fn fresh_registry() -> EdgeRegistry {
        EdgeRegistry::new(10, Duration::from_secs(60))
    }

    #[tokio::test]
    async fn register_returns_receiver_with_invalidations() {
        let reg = fresh_registry();
        let mut inbox = reg
            .register("edge-1", "us-east", "https://e1", "ts")
            .unwrap();
        assert_eq!(reg.count(), 1);

        let outcome = reg.broadcast(event(5, &["users"])).await;
        assert_eq!(outcome, (1, 0));

        let received = inbox.recv().await.expect("receive");
        assert_eq!(received.up_to_version, 5);
        assert_eq!(received.tables, vec!["users".to_string()]);
    }

    #[tokio::test]
    async fn broadcast_prunes_dropped_receivers() {
        let reg = fresh_registry();
        let _kept = reg.register("edge-keep", "us-east", "u", "ts").unwrap();
        // Dropping the receiver right away closes that edge's channel.
        drop(reg.register("edge-drop", "us-west", "u", "ts").unwrap());

        let outcome = reg.broadcast(event(1, &[])).await;
        assert_eq!(outcome, (1, 1));
        assert_eq!(reg.count(), 1);
    }

    #[test]
    fn register_rejects_when_at_capacity() {
        let reg = EdgeRegistry::new(2, Duration::from_secs(60));
        let _first = reg.register("a", "us-east", "u", "ts").unwrap();
        let _second = reg.register("b", "us-west", "u", "ts").unwrap();

        let refused = reg.register("c", "eu-west", "u", "ts");
        assert!(matches!(refused, Err(RegistryError::CapacityExceeded(2))));
    }

    #[test]
    fn register_replaces_existing_id() {
        let reg = EdgeRegistry::new(2, Duration::from_secs(60));
        let _old = reg.register("a", "us-east", "u", "t1").unwrap();
        let _new = reg.register("a", "eu-west", "u", "t2").unwrap();

        assert_eq!(reg.count(), 1);
        let regions: Vec<String> = reg.list().into_iter().map(|n| n.region).collect();
        assert_eq!(regions, vec!["eu-west".to_string()]);
    }

    #[test]
    fn unregister_removes_edge() {
        let reg = fresh_registry();
        let _inbox = reg.register("edge-1", "us-east", "u", "ts").unwrap();

        assert!(reg.unregister("edge-1"));
        assert_eq!(reg.count(), 0);
        // A second removal finds nothing.
        assert!(!reg.unregister("edge-1"));
    }

    #[test]
    fn list_returns_snapshot() {
        let reg = fresh_registry();
        let _a = reg.register("a", "r1", "u1", "ts").unwrap();
        let _b = reg.register("b", "r2", "u2", "ts").unwrap();

        // `list` order is unspecified, so compare sorted ids.
        let mut ids: Vec<String> = reg.list().into_iter().map(|n| n.edge_id).collect();
        ids.sort();
        assert_eq!(ids, ["a", "b"]);
    }

    #[tokio::test]
    async fn invalidations_sent_counter_increments() {
        let reg = fresh_registry();
        let _inbox = reg.register("e1", "r", "u", "ts").unwrap();

        for _ in 0..3 {
            let _ = reg.broadcast(event(1, &[])).await;
        }

        let nodes = reg.list();
        assert_eq!(nodes[0].invalidations_sent, 3);
    }

    #[test]
    fn prune_stale_removes_old_entries() {
        let reg = EdgeRegistry::new(10, Duration::from_millis(10));
        let _inbox = reg.register("old", "r", "u", "ts").unwrap();

        // Outlive the 10 ms liveness window.
        std::thread::sleep(Duration::from_millis(20));

        assert_eq!(reg.prune_stale(), 1);
        assert_eq!(reg.count(), 0);
    }
}