Skip to main content

heliosdb_proxy/edge/
registry.rs

1//! Home-side registry of subscribed edges + invalidation broadcast.
2//!
3//! When an edge boots in `EdgeRole::Edge`, it calls home's
4//! `POST /api/edge/register` once at startup. The home stores the
5//! edge's address + auth token and adds it to the broadcast set.
6//!
7//! On every committed write, the home calls `broadcast` with an
8//! `InvalidationEvent { up_to_version, tables }`. Each registered
9//! edge receives a copy via the SSE channel that `register` opened.
10//!
11//! Edges that fail to ack within `ack_timeout` are removed from the
12//! set — the home does *not* retry forever. A missed invalidation
13//! degrades correctness only as far as the cache TTL: stale entries
14//! age out within `default_ttl`. That's the explicit "bounded
15//! staleness" contract from the module doc.
16
17use parking_lot::RwLock;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::sync::Arc;
21use std::time::{Duration, Instant};
22use tokio::sync::mpsc;
23
24/// One registered edge node from the home's perspective.
25#[derive(Debug, Clone, Serialize)]
26pub struct EdgeNode {
27    pub edge_id: String,
28    pub region: String,
29    /// HTTP base URL the home pings for ack-checks.
30    pub base_url: String,
31    /// First-seen + last-acked timestamps.
32    pub registered_at: String,
33    pub last_seen: String,
34    /// Total invalidations broadcast to this edge.
35    pub invalidations_sent: u64,
36}
37
38/// Wire shape of an invalidation. Carried over the SSE channel from
39/// home to every edge.
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct InvalidationEvent {
42    /// Logical version assigned by the home at write commit.
43    pub up_to_version: u64,
44    /// Tables touched. Empty = invalidate every cached entry within
45    /// the version bound.
46    pub tables: Vec<String>,
47    /// Wall-clock at which the home committed the write — useful for
48    /// log correlation.
49    pub committed_at: String,
50}
51
52/// Per-edge in-process channel the registry pushes events into.
53/// Each edge holds the matching receiver in its SSE connection.
54struct EdgeSubscription {
55    node: EdgeNode,
56    sender: mpsc::Sender<InvalidationEvent>,
57    /// last_seen as Instant for liveness check (the public node
58    /// stringifies this too).
59    last_seen_inst: Instant,
60}
61
62/// Home-side edge registry. Cheap to clone via Arc.
63#[derive(Clone)]
64pub struct EdgeRegistry {
65    inner: Arc<RwLock<HashMap<String, EdgeSubscription>>>,
66    max_edges: usize,
67    /// Edges that don't ack within this window get expired on the
68    /// next broadcast pass.
69    liveness_window: Duration,
70}
71
72impl EdgeRegistry {
73    pub fn new(max_edges: usize, liveness_window: Duration) -> Self {
74        Self {
75            inner: Arc::new(RwLock::new(HashMap::new())),
76            max_edges,
77            liveness_window,
78        }
79    }
80
81    /// Register a new edge. Returns the receiver the SSE handler
82    /// holds open. Caller is responsible for keeping the receiver
83    /// alive — when it drops, the next broadcast prunes the edge.
84    ///
85    /// The channel is bounded: a slow edge that doesn't drain
86    /// fast enough back-pressures into the broadcast call (which
87    /// is async). Default capacity 64 events lets bursts ride
88    /// through without dropping.
89    pub fn register(
90        &self,
91        edge_id: &str,
92        region: &str,
93        base_url: &str,
94        now_iso: &str,
95    ) -> Result<mpsc::Receiver<InvalidationEvent>, RegistryError> {
96        let mut g = self.inner.write();
97        if !g.contains_key(edge_id) && g.len() >= self.max_edges {
98            return Err(RegistryError::CapacityExceeded(self.max_edges));
99        }
100        let (tx, rx) = mpsc::channel(64);
101        let sub = EdgeSubscription {
102            node: EdgeNode {
103                edge_id: edge_id.to_string(),
104                region: region.to_string(),
105                base_url: base_url.to_string(),
106                registered_at: now_iso.to_string(),
107                last_seen: now_iso.to_string(),
108                invalidations_sent: 0,
109            },
110            sender: tx,
111            last_seen_inst: Instant::now(),
112        };
113        g.insert(edge_id.to_string(), sub);
114        Ok(rx)
115    }
116
117    /// Remove an edge — used when the home decides to evict
118    /// (manual unregister, or cleanup during shutdown).
119    pub fn unregister(&self, edge_id: &str) -> bool {
120        self.inner.write().remove(edge_id).is_some()
121    }
122
123    /// Broadcast an invalidation to every subscribed edge. Edges
124    /// whose channel has closed (receiver dropped) are pruned.
125    /// Returns (sent, pruned).
126    pub async fn broadcast(&self, ev: InvalidationEvent) -> (u32, u32) {
127        // Snapshot recipients under the lock, then send outside it
128        // so we don't hold the write lock across await points.
129        let recipients: Vec<(String, mpsc::Sender<InvalidationEvent>)> = {
130            let g = self.inner.read();
131            g.iter()
132                .map(|(id, sub)| (id.clone(), sub.sender.clone()))
133                .collect()
134        };
135
136        let mut sent = 0u32;
137        let mut dead: Vec<String> = Vec::new();
138        for (id, tx) in recipients {
139            match tx.send(ev.clone()).await {
140                Ok(()) => {
141                    sent += 1;
142                }
143                Err(_) => {
144                    dead.push(id);
145                }
146            }
147        }
148
149        // Prune closed channels + bump per-edge counters under the
150        // write lock.
151        let mut g = self.inner.write();
152        for id in &dead {
153            g.remove(id);
154        }
155        for sub in g.values_mut() {
156            sub.node.invalidations_sent =
157                sub.node.invalidations_sent.saturating_add(1);
158            sub.last_seen_inst = Instant::now();
159        }
160        (sent, dead.len() as u32)
161    }
162
163    /// Read-only snapshot of currently-registered edges. Used by
164    /// the admin UI / `/api/edge` endpoint.
165    pub fn list(&self) -> Vec<EdgeNode> {
166        self.inner
167            .read()
168            .values()
169            .map(|s| s.node.clone())
170            .collect()
171    }
172
173    pub fn count(&self) -> usize {
174        self.inner.read().len()
175    }
176
177    /// Garbage-collect edges that haven't been seen within
178    /// `liveness_window`. Returns the count pruned.
179    pub fn prune_stale(&self) -> u32 {
180        let cutoff = Instant::now() - self.liveness_window;
181        let mut g = self.inner.write();
182        let dead: Vec<String> = g
183            .iter()
184            .filter(|(_, s)| s.last_seen_inst < cutoff)
185            .map(|(id, _)| id.clone())
186            .collect();
187        for id in &dead {
188            g.remove(id);
189        }
190        dead.len() as u32
191    }
192}
193
/// Errors reported by the edge registry.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RegistryError {
    /// A new edge tried to register while the registry was full. The payload
    /// is the configured maximum number of edges.
    CapacityExceeded(usize),
}
198
199impl std::fmt::Display for RegistryError {
200    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201        match self {
202            RegistryError::CapacityExceeded(n) => {
203                write!(f, "edge registry full (max {})", n)
204            }
205        }
206    }
207}
208
209impl std::error::Error for RegistryError {}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Registry with the given capacity and a 60-second liveness window.
    fn registry(max_edges: usize) -> EdgeRegistry {
        EdgeRegistry::new(max_edges, Duration::from_secs(60))
    }

    /// Invalidation event fixture with a fixed `committed_at` of `"ts"`.
    fn event(up_to_version: u64, tables: &[&str]) -> InvalidationEvent {
        InvalidationEvent {
            up_to_version,
            tables: tables.iter().map(|t| t.to_string()).collect(),
            committed_at: "ts".into(),
        }
    }

    /// A registered edge receives exactly what was broadcast.
    #[tokio::test]
    async fn register_returns_receiver_with_invalidations() {
        let reg = registry(10);
        let mut rx = reg
            .register("edge-1", "us-east", "https://e1", "ts")
            .unwrap();
        assert_eq!(reg.count(), 1);

        let (sent, pruned) = reg.broadcast(event(5, &["users"])).await;
        assert_eq!(sent, 1);
        assert_eq!(pruned, 0);

        let received = rx.recv().await.expect("receive");
        assert_eq!(received.up_to_version, 5);
        assert_eq!(received.tables, vec!["users".to_string()]);
    }

    /// An edge whose receiver has been dropped is pruned by the next
    /// broadcast; live edges are unaffected.
    #[tokio::test]
    async fn broadcast_prunes_dropped_receivers() {
        let reg = registry(10);
        let _rx_keep = reg.register("edge-keep", "us-east", "u", "ts").unwrap();
        // Register a second edge and drop its receiver straight away.
        drop(reg.register("edge-drop", "us-west", "u", "ts").unwrap());

        let (sent, pruned) = reg.broadcast(event(1, &[])).await;
        assert_eq!(sent, 1);
        assert_eq!(pruned, 1);
        assert_eq!(reg.count(), 1);
    }

    /// A new id is rejected once the registry is full.
    #[test]
    fn register_rejects_when_at_capacity() {
        let reg = registry(2);
        let _a = reg.register("a", "us-east", "u", "ts").unwrap();
        let _b = reg.register("b", "us-west", "u", "ts").unwrap();

        let err = reg.register("c", "eu-west", "u", "ts").unwrap_err();
        assert!(matches!(err, RegistryError::CapacityExceeded(2)));
    }

    /// Re-registering an id replaces the slot instead of adding a new one.
    #[test]
    fn register_replaces_existing_id() {
        let reg = registry(2);
        let _first = reg.register("a", "us-east", "u", "t1").unwrap();
        // Same id, different region: the slot is replaced and the count
        // stays the same.
        let _second = reg.register("a", "eu-west", "u", "t2").unwrap();

        assert_eq!(reg.count(), 1);
        let nodes = reg.list();
        assert_eq!(nodes[0].region, "eu-west");
    }

    /// `unregister` removes the edge and reports whether it was present.
    #[test]
    fn unregister_removes_edge() {
        let reg = registry(10);
        let _rx = reg.register("edge-1", "us-east", "u", "ts").unwrap();

        assert!(reg.unregister("edge-1"));
        assert_eq!(reg.count(), 0);
        // Idempotent.
        assert!(!reg.unregister("edge-1"));
    }

    /// `list` returns one node per registered edge.
    #[test]
    fn list_returns_snapshot() {
        let reg = registry(10);
        let _a = reg.register("a", "r1", "u1", "ts").unwrap();
        let _b = reg.register("b", "r2", "u2", "ts").unwrap();

        let mut nodes = reg.list();
        // Map iteration order is unspecified, so sort before comparing.
        nodes.sort_by(|x, y| x.edge_id.cmp(&y.edge_id));
        assert_eq!(nodes.len(), 2);
        assert_eq!(nodes[0].edge_id, "a");
        assert_eq!(nodes[1].edge_id, "b");
    }

    /// Every broadcast bumps the per-edge counter by one.
    #[tokio::test]
    async fn invalidations_sent_counter_increments() {
        let reg = registry(10);
        // Held only to keep the channel open.
        let _rx = reg.register("e1", "r", "u", "ts").unwrap();

        for _ in 0..3 {
            let _ = reg.broadcast(event(1, &[])).await;
        }

        let nodes = reg.list();
        assert_eq!(nodes[0].invalidations_sent, 3);
    }

    /// An edge older than the liveness window is garbage-collected.
    #[test]
    fn prune_stale_removes_old_entries() {
        let reg = EdgeRegistry::new(10, Duration::from_millis(10));
        let _rx = reg.register("old", "r", "u", "ts").unwrap();
        std::thread::sleep(Duration::from_millis(20));

        let pruned = reg.prune_stale();
        assert_eq!(pruned, 1);
        assert_eq!(reg.count(), 0);
    }
}
323}