Skip to main content

xds_cache/
watch.rs

//! Watch system for cache update notifications.
//!
//! The watch system provides:
//! - Unique watch identifiers ([`WatchId`])
//! - Watch subscriptions ([`Watch`]) for receiving snapshot updates
//! - Watch management ([`WatchManager`]) for creating, cancelling, and notifying subscriptions
7
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11
12use parking_lot::Mutex;
13use tokio::sync::mpsc;
14use tracing::{debug, trace, warn};
15use xds_core::{NodeHash, XdsError, XdsResult};
16
17use crate::Snapshot;
18
/// Opaque, process-unique identifier assigned to each watch subscription.
///
/// Values come from a single global counter that starts at 1, so successive
/// calls to [`WatchId::next`] never repeat within one process (short of the
/// 64-bit counter wrapping around).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WatchId(u64);

impl WatchId {
    /// Allocate a fresh identifier from the process-wide counter.
    ///
    /// `Relaxed` ordering is enough: the atomic increment alone guarantees
    /// uniqueness, and no other memory is published alongside the value.
    fn next() -> Self {
        static NEXT_RAW_ID: AtomicU64 = AtomicU64::new(1);
        let raw = NEXT_RAW_ID.fetch_add(1, Ordering::Relaxed);
        WatchId(raw)
    }

    /// Return the raw numeric value backing this identifier.
    #[inline]
    pub fn as_u64(&self) -> u64 {
        let WatchId(raw) = *self;
        raw
    }
}

impl std::fmt::Display for WatchId {
    /// Render the identifier as `watch-<n>`, e.g. `watch-7`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let WatchId(raw) = self;
        write!(f, "watch-{}", raw)
    }
}
42
43/// A watch subscription for receiving snapshot updates.
44///
45/// When a snapshot is updated for a node, all active watches for that node
46/// receive the new snapshot through their channel.
47#[derive(Debug)]
48pub struct Watch {
49    /// Unique identifier for this watch.
50    id: WatchId,
51    /// Node this watch is subscribed to.
52    node_hash: NodeHash,
53    /// Receiver for snapshot updates.
54    receiver: mpsc::Receiver<Arc<Snapshot>>,
55}
56
57impl Watch {
58    /// Get the unique identifier for this watch.
59    #[inline]
60    pub fn id(&self) -> WatchId {
61        self.id
62    }
63
64    /// Get the node hash this watch is subscribed to.
65    #[inline]
66    pub fn node_hash(&self) -> NodeHash {
67        self.node_hash
68    }
69
70    /// Receive the next snapshot update.
71    ///
72    /// Returns `None` if the watch has been cancelled.
73    pub async fn recv(&mut self) -> Option<Arc<Snapshot>> {
74        self.receiver.recv().await
75    }
76
77    /// Try to receive a snapshot update without waiting.
78    ///
79    /// Returns:
80    /// - `Ok(snapshot)` if an update is available
81    /// - `Err(TryRecvError::Empty)` if no update is available
82    /// - `Err(TryRecvError::Disconnected)` if the watch has been cancelled
83    pub fn try_recv(&mut self) -> Result<Arc<Snapshot>, mpsc::error::TryRecvError> {
84        self.receiver.try_recv()
85    }
86}
87
88/// Sender half of a watch, used internally to send updates.
89#[derive(Debug, Clone)]
90#[allow(dead_code)] // Used for debugging and future features
91pub(crate) struct WatchSender {
92    id: WatchId,
93    node_hash: NodeHash,
94    sender: mpsc::Sender<Arc<Snapshot>>,
95}
96
97#[allow(dead_code)] // Methods used for debugging and future features
98impl WatchSender {
99    /// Try to send a snapshot update.
100    ///
101    /// Uses `try_send` to avoid blocking. If the channel is full,
102    /// the update is dropped (the receiver will get the next one).
103    pub fn try_send(&self, snapshot: Arc<Snapshot>) -> XdsResult<()> {
104        match self.sender.try_send(snapshot) {
105            Ok(()) => Ok(()),
106            Err(mpsc::error::TrySendError::Full(_)) => {
107                // Channel full, skip this update
108                trace!(watch_id = %self.id, "watch channel full, skipping update");
109                Ok(())
110            }
111            Err(mpsc::error::TrySendError::Closed(_)) => Err(XdsError::WatchClosed {
112                watch_id: self.id.0,
113            }),
114        }
115    }
116
117    /// Get the watch ID.
118    #[inline]
119    pub fn id(&self) -> WatchId {
120        self.id
121    }
122}
123
124/// Manager for watch subscriptions.
125///
126/// Handles creating, tracking, and cancelling watches.
127/// Uses a `Mutex` internally but operations are fast (no I/O).
128#[derive(Debug)]
129pub struct WatchManager {
130    /// Map of node hash to active watch senders.
131    watches: Mutex<HashMap<NodeHash, Vec<WatchSender>>>,
132    /// Channel buffer size for new watches.
133    channel_buffer: usize,
134}
135
136impl Default for WatchManager {
137    fn default() -> Self {
138        Self::new()
139    }
140}
141
142impl WatchManager {
143    /// Create a new watch manager with default settings.
144    pub fn new() -> Self {
145        Self::with_buffer_size(16)
146    }
147
148    /// Create a new watch manager with a custom channel buffer size.
149    pub fn with_buffer_size(buffer_size: usize) -> Self {
150        Self {
151            watches: Mutex::new(HashMap::new()),
152            channel_buffer: buffer_size,
153        }
154    }
155
156    /// Create a new watch for a node.
157    ///
158    /// Returns a `Watch` that will receive snapshot updates for the specified node.
159    pub fn create_watch(&self, node_hash: NodeHash) -> Watch {
160        let id = WatchId::next();
161        let (sender, receiver) = mpsc::channel(self.channel_buffer);
162
163        let watch_sender = WatchSender {
164            id,
165            node_hash,
166            sender,
167        };
168
169        // Lock is held briefly, no I/O
170        {
171            let mut watches = self.watches.lock();
172            watches.entry(node_hash).or_default().push(watch_sender);
173        }
174
175        debug!(watch_id = %id, node = %node_hash, "created watch");
176
177        Watch {
178            id,
179            node_hash,
180            receiver,
181        }
182    }
183
184    /// Cancel a watch subscription.
185    ///
186    /// The watch will no longer receive updates.
187    pub fn cancel_watch(&self, watch_id: WatchId) {
188        let mut watches = self.watches.lock();
189
190        // Find and remove the watch
191        for senders in watches.values_mut() {
192            if let Some(pos) = senders.iter().position(|s| s.id == watch_id) {
193                senders.swap_remove(pos);
194                debug!(watch_id = %watch_id, "cancelled watch");
195                return;
196            }
197        }
198
199        warn!(watch_id = %watch_id, "attempted to cancel unknown watch");
200    }
201
202    /// Notify all watches for a node about a snapshot update.
203    ///
204    /// Removes any closed watches automatically.
205    pub fn notify(&self, node_hash: NodeHash, snapshot: Arc<Snapshot>) {
206        // Clone senders while holding lock briefly
207        let senders: Vec<WatchSender> = {
208            let watches = self.watches.lock();
209            watches.get(&node_hash).cloned().unwrap_or_default()
210        };
211
212        if senders.is_empty() {
213            return;
214        }
215
216        // Track which watches failed (closed)
217        let mut closed_ids = Vec::new();
218
219        for sender in &senders {
220            if let Err(XdsError::WatchClosed { watch_id }) = sender.try_send(Arc::clone(&snapshot))
221            {
222                closed_ids.push(WatchId(watch_id));
223            }
224        }
225
226        // Remove closed watches
227        if !closed_ids.is_empty() {
228            let mut watches = self.watches.lock();
229            if let Some(senders) = watches.get_mut(&node_hash) {
230                senders.retain(|s| !closed_ids.contains(&s.id));
231            }
232            debug!(count = closed_ids.len(), "removed closed watches");
233        }
234
235        trace!(
236            node = %node_hash,
237            watch_count = senders.len() - closed_ids.len(),
238            "notified watches of snapshot update"
239        );
240    }
241
242    /// Get the number of active watches for a node.
243    pub fn watch_count(&self, node_hash: NodeHash) -> usize {
244        let watches = self.watches.lock();
245        watches.get(&node_hash).map(|v| v.len()).unwrap_or(0)
246    }
247
248    /// Get the total number of active watches across all nodes.
249    pub fn total_watch_count(&self) -> usize {
250        let watches = self.watches.lock();
251        watches.values().map(|v| v.len()).sum()
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn watch_id_unique() {
        let id1 = WatchId::next();
        let id2 = WatchId::next();
        assert_ne!(id1, id2);
    }

    #[test]
    fn watch_id_display() {
        let id = WatchId::next();
        let display = format!("{}", id);
        assert!(display.starts_with("watch-"));
        assert_eq!(display, format!("watch-{}", id.as_u64()));
    }

    #[test]
    fn watch_id_concurrent_uniqueness() {
        use std::collections::HashSet;
        use std::sync::Mutex;

        let ids = Arc::new(Mutex::new(HashSet::new()));
        let mut handles = vec![];

        // Spawn 10 threads, each generating 100 IDs
        for _ in 0..10 {
            let ids = Arc::clone(&ids);
            handles.push(thread::spawn(move || {
                for _ in 0..100 {
                    let id = WatchId::next();
                    ids.lock().unwrap().insert(id.0);
                }
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        // All 1000 IDs should be unique
        assert_eq!(ids.lock().unwrap().len(), 1000);
    }

    #[tokio::test]
    async fn watch_manager_create_and_notify() {
        let manager = WatchManager::new();
        let node = NodeHash::from_id("test-node");

        let mut watch = manager.create_watch(node);
        assert_eq!(manager.watch_count(node), 1);

        let snapshot = Arc::new(Snapshot::builder().version("v1").build());
        manager.notify(node, snapshot.clone());

        let received = watch.recv().await.unwrap();
        assert_eq!(received.version(), "v1");
    }

    #[test]
    fn watch_manager_cancel() {
        let manager = WatchManager::new();
        let node = NodeHash::from_id("test-node");

        let watch = manager.create_watch(node);
        assert_eq!(manager.watch_count(node), 1);

        manager.cancel_watch(watch.id());
        assert_eq!(manager.watch_count(node), 0);
    }

    #[test]
    fn watch_manager_cancel_nonexistent() {
        let manager = WatchManager::new();
        // Should not panic
        manager.cancel_watch(WatchId::next());
    }

    #[tokio::test]
    async fn watch_manager_multiple_watches_same_node() {
        let manager = WatchManager::new();
        let node = NodeHash::from_id("test-node");

        let mut watch1 = manager.create_watch(node);
        let mut watch2 = manager.create_watch(node);
        let mut watch3 = manager.create_watch(node);

        assert_eq!(manager.watch_count(node), 3);
        assert_eq!(manager.total_watch_count(), 3);

        let snapshot = Arc::new(Snapshot::builder().version("v1").build());
        manager.notify(node, snapshot);

        // All watches should receive the notification
        let r1 = watch1.recv().await.unwrap();
        let r2 = watch2.recv().await.unwrap();
        let r3 = watch3.recv().await.unwrap();

        assert_eq!(r1.version(), "v1");
        assert_eq!(r2.version(), "v1");
        assert_eq!(r3.version(), "v1");
    }

    #[tokio::test]
    async fn watch_manager_multiple_nodes() {
        let manager = WatchManager::new();
        let node1 = NodeHash::from_id("node-1");
        let node2 = NodeHash::from_id("node-2");

        let mut watch1 = manager.create_watch(node1);
        let mut watch2 = manager.create_watch(node2);

        assert_eq!(manager.total_watch_count(), 2);

        // Notify only node1
        let snapshot1 = Arc::new(Snapshot::builder().version("v1").build());
        manager.notify(node1, snapshot1);

        // watch1 should receive; watch2 must not see node1's update
        let r1 = watch1.recv().await.unwrap();
        assert_eq!(r1.version(), "v1");
        assert!(matches!(
            watch2.try_recv(),
            Err(mpsc::error::TryRecvError::Empty)
        ));

        // Notify node2
        let snapshot2 = Arc::new(Snapshot::builder().version("v2").build());
        manager.notify(node2, snapshot2);

        let r2 = watch2.recv().await.unwrap();
        assert_eq!(r2.version(), "v2");
    }

    #[tokio::test]
    async fn watch_manager_notify_nonexistent_node() {
        let manager = WatchManager::new();
        let node = NodeHash::from_id("nonexistent");

        // Should not panic
        let snapshot = Arc::new(Snapshot::builder().version("v1").build());
        manager.notify(node, snapshot);
    }

    #[tokio::test]
    async fn watch_manager_notify_prunes_dropped_watches() {
        let manager = WatchManager::new();
        let node = NodeHash::from_id("test-node");

        let dropped = manager.create_watch(node);
        let mut kept = manager.create_watch(node);
        drop(dropped);

        // Dropped watches are only pruned on the next notify.
        assert_eq!(manager.watch_count(node), 2);

        let snapshot = Arc::new(Snapshot::builder().version("v1").build());
        manager.notify(node, snapshot);

        assert_eq!(manager.watch_count(node), 1);
        let received = kept.recv().await.unwrap();
        assert_eq!(received.version(), "v1");
    }

    #[test]
    fn watch_full_channel_drops_newest_update() {
        let manager = WatchManager::with_buffer_size(1);
        let node = NodeHash::from_id("test-node");
        let mut watch = manager.create_watch(node);

        let first = Arc::new(Snapshot::builder().version("v1").build());
        let second = Arc::new(Snapshot::builder().version("v2").build());
        manager.notify(node, first);
        manager.notify(node, second);

        // The channel holds a single snapshot, so the second notify is dropped.
        let received = watch.try_recv().unwrap();
        assert_eq!(received.version(), "v1");
        assert!(matches!(
            watch.try_recv(),
            Err(mpsc::error::TryRecvError::Empty)
        ));

        // A full channel is not a closed one: the watch stays registered.
        assert_eq!(manager.watch_count(node), 1);
    }

    #[test]
    fn watch_manager_cleanup_cancelled_watches() {
        let manager = WatchManager::new();
        let node = NodeHash::from_id("test-node");

        let watch1 = manager.create_watch(node);
        let watch2 = manager.create_watch(node);
        let watch3 = manager.create_watch(node);

        assert_eq!(manager.watch_count(node), 3);

        manager.cancel_watch(watch2.id());
        assert_eq!(manager.watch_count(node), 2);

        manager.cancel_watch(watch1.id());
        assert_eq!(manager.watch_count(node), 1);

        manager.cancel_watch(watch3.id());
        assert_eq!(manager.watch_count(node), 0);
    }

    #[tokio::test]
    async fn watch_receive_timeout() {
        use tokio::time::{timeout, Duration};

        let manager = WatchManager::new();
        let node = NodeHash::from_id("test-node");

        let mut watch = manager.create_watch(node);

        // No notification sent, should timeout
        let result = timeout(Duration::from_millis(10), watch.recv()).await;
        assert!(result.is_err(), "Should timeout without notification");
    }

    #[tokio::test]
    async fn watch_dropped_sender_closes_watch() {
        let node = NodeHash::from_id("test-node");
        let mut watch;

        {
            let manager = WatchManager::new();
            watch = manager.create_watch(node);
            // manager dropped here
        }

        // Channel should be closed
        let result = watch.recv().await;
        assert!(
            result.is_none(),
            "Watch should close when manager is dropped"
        );
    }

    #[test]
    fn watch_with_custom_buffer_size() {
        let manager = WatchManager::with_buffer_size(1);
        let node = NodeHash::from_id("test-node");

        let _watch = manager.create_watch(node);
        assert_eq!(manager.channel_buffer, 1);
    }
}