nomad_protocol/extensions/
selective_sync.rs

1//! Selective Sync extension (0x0005)
2//!
3//! Allows clients to subscribe to specific regions of state rather than
4//! receiving all updates. Useful for large state spaces where clients
5//! only need a subset (e.g., viewport in a document, area in a game world).
6//!
7//! Wire format for extension negotiation:
8//! ```text
9//! +0  Flags (1 byte)
10//!     - bit 0: Region subscribe/unsubscribe supported
11//!     - bit 1: Region expressions supported (patterns)
12//!     - bit 2: Nested regions supported
13//! +1  Max regions (2 bytes LE16) - maximum concurrent subscriptions
14//! +3  Max expression length (2 bytes LE16) - maximum pattern length
15//! ```
16//!
17//! Wire format for subscription change:
18//! ```text
19//! +0  Operation (1 byte)
20//!     - 0x00: Subscribe to region
21//!     - 0x01: Unsubscribe from region
22//!     - 0x02: Subscribe with pattern
23//!     - 0x03: Clear all subscriptions
24//! +1  Region spec (variable, based on operation)
25//! ```
26//!
//! Region spec formats:
//! - Subscribe/Unsubscribe: Region ID (4 bytes LE32)
//! - Pattern: Length (2 bytes LE16) + UTF-8 pattern bytes
//! - Clear all: no region spec (the operation byte alone)
30
31use super::negotiation::{ext_type, Extension, NegotiationError};
32use std::collections::HashSet;
33
/// Bit flags carried in the first byte of the selective sync negotiation payload.
pub mod selective_sync_flags {
    /// Bit 0: plain subscribe/unsubscribe by region ID.
    pub const REGION_OPS: u8 = 1 << 0;
    /// Bit 1: pattern-based subscriptions (for example `"users/*"`).
    pub const PATTERNS: u8 = 1 << 1;
    /// Bit 2: nested/hierarchical regions.
    pub const NESTED: u8 = 1 << 2;
}
43
/// Negotiable parameters of the selective sync extension.
///
/// On the wire this is `flags (u8) | max_regions (LE u16) | max_expression_len (LE u16)`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SelectiveSyncConfig {
    /// Bitwise OR of the `selective_sync_flags` constants.
    pub flags: u8,
    /// Upper bound on concurrent subscriptions.
    pub max_regions: u16,
    /// Upper bound on pattern length, meaningful only when patterns are supported.
    pub max_expression_len: u16,
}
54
55impl Default for SelectiveSyncConfig {
56    fn default() -> Self {
57        Self {
58            flags: selective_sync_flags::REGION_OPS,
59            max_regions: 256,
60            max_expression_len: 128,
61        }
62    }
63}
64
65impl SelectiveSyncConfig {
66    /// Create config with all features enabled
67    pub fn full() -> Self {
68        Self {
69            flags: selective_sync_flags::REGION_OPS | selective_sync_flags::PATTERNS | selective_sync_flags::NESTED,
70            max_regions: 1024,
71            max_expression_len: 256,
72        }
73    }
74
75    /// Check if region operations are supported
76    pub fn supports_regions(&self) -> bool {
77        (self.flags & selective_sync_flags::REGION_OPS) != 0
78    }
79
80    /// Check if pattern subscriptions are supported
81    pub fn supports_patterns(&self) -> bool {
82        (self.flags & selective_sync_flags::PATTERNS) != 0
83    }
84
85    /// Check if nested regions are supported
86    pub fn supports_nested(&self) -> bool {
87        (self.flags & selective_sync_flags::NESTED) != 0
88    }
89
90    /// Wire size
91    pub const fn wire_size() -> usize {
92        5 // flags(1) + max_regions(2) + max_expr(2)
93    }
94
95    /// Encode to extension
96    pub fn to_extension(&self) -> Extension {
97        let mut data = Vec::with_capacity(Self::wire_size());
98        data.push(self.flags);
99        data.extend_from_slice(&self.max_regions.to_le_bytes());
100        data.extend_from_slice(&self.max_expression_len.to_le_bytes());
101        Extension::new(ext_type::SELECTIVE_SYNC, data)
102    }
103
104    /// Decode from extension
105    pub fn from_extension(ext: &Extension) -> Option<Self> {
106        if ext.ext_type != ext_type::SELECTIVE_SYNC || ext.data.len() < Self::wire_size() {
107            return None;
108        }
109        Some(Self {
110            flags: ext.data[0],
111            max_regions: u16::from_le_bytes([ext.data[1], ext.data[2]]),
112            max_expression_len: u16::from_le_bytes([ext.data[3], ext.data[4]]),
113        })
114    }
115
116    /// Negotiate between client and server
117    pub fn negotiate(client: &Self, server: &Self) -> Self {
118        Self {
119            flags: client.flags & server.flags,
120            max_regions: client.max_regions.min(server.max_regions),
121            max_expression_len: client.max_expression_len.min(server.max_expression_len),
122        }
123    }
124}
125
/// Operation byte that starts every subscription change message.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum SubscriptionOp {
    /// `0x00`: subscribe to one region by ID.
    Subscribe = 0x00,
    /// `0x01`: unsubscribe from one region by ID.
    Unsubscribe = 0x01,
    /// `0x02`: subscribe to every region matching a pattern.
    SubscribePattern = 0x02,
    /// `0x03`: drop every subscription.
    ClearAll = 0x03,
}
139
140impl SubscriptionOp {
141    /// Convert from byte
142    pub fn from_byte(b: u8) -> Option<Self> {
143        match b {
144            0x00 => Some(Self::Subscribe),
145            0x01 => Some(Self::Unsubscribe),
146            0x02 => Some(Self::SubscribePattern),
147            0x03 => Some(Self::ClearAll),
148            _ => None,
149        }
150    }
151}
152
/// One decoded subscription change sent by a client.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SubscriptionChange {
    /// Start receiving updates for the given region ID.
    Subscribe(u32),
    /// Stop receiving updates for the given region ID.
    Unsubscribe(u32),
    /// Start receiving updates for regions matching a pattern such as `"users/*"`.
    SubscribePattern(String),
    /// Drop every region and pattern subscription.
    ClearAll,
}
165
166impl SubscriptionChange {
167    /// Wire size
168    pub fn wire_size(&self) -> usize {
169        match self {
170            Self::Subscribe(_) | Self::Unsubscribe(_) => 5, // op(1) + region(4)
171            Self::SubscribePattern(p) => 3 + p.len(),       // op(1) + len(2) + pattern
172            Self::ClearAll => 1,                            // op(1)
173        }
174    }
175
176    /// Encode to bytes
177    pub fn encode(&self) -> Vec<u8> {
178        let mut buf = Vec::with_capacity(self.wire_size());
179        match self {
180            Self::Subscribe(id) => {
181                buf.push(SubscriptionOp::Subscribe as u8);
182                buf.extend_from_slice(&id.to_le_bytes());
183            }
184            Self::Unsubscribe(id) => {
185                buf.push(SubscriptionOp::Unsubscribe as u8);
186                buf.extend_from_slice(&id.to_le_bytes());
187            }
188            Self::SubscribePattern(pattern) => {
189                buf.push(SubscriptionOp::SubscribePattern as u8);
190                buf.extend_from_slice(&(pattern.len() as u16).to_le_bytes());
191                buf.extend_from_slice(pattern.as_bytes());
192            }
193            Self::ClearAll => {
194                buf.push(SubscriptionOp::ClearAll as u8);
195            }
196        }
197        buf
198    }
199
200    /// Decode from bytes
201    pub fn decode(data: &[u8]) -> Result<(Self, usize), NegotiationError> {
202        if data.is_empty() {
203            return Err(NegotiationError::TooShort {
204                expected: 1,
205                actual: 0,
206            });
207        }
208
209        let op = SubscriptionOp::from_byte(data[0]).ok_or(NegotiationError::InvalidData)?;
210
211        match op {
212            SubscriptionOp::Subscribe | SubscriptionOp::Unsubscribe => {
213                if data.len() < 5 {
214                    return Err(NegotiationError::TooShort {
215                        expected: 5,
216                        actual: data.len(),
217                    });
218                }
219                let id = u32::from_le_bytes([data[1], data[2], data[3], data[4]]);
220                let change = if op == SubscriptionOp::Subscribe {
221                    Self::Subscribe(id)
222                } else {
223                    Self::Unsubscribe(id)
224                };
225                Ok((change, 5))
226            }
227            SubscriptionOp::SubscribePattern => {
228                if data.len() < 3 {
229                    return Err(NegotiationError::TooShort {
230                        expected: 3,
231                        actual: data.len(),
232                    });
233                }
234                let len = u16::from_le_bytes([data[1], data[2]]) as usize;
235                if data.len() < 3 + len {
236                    return Err(NegotiationError::TooShort {
237                        expected: 3 + len,
238                        actual: data.len(),
239                    });
240                }
241                let pattern = String::from_utf8(data[3..3 + len].to_vec())
242                    .map_err(|_| NegotiationError::InvalidData)?;
243                Ok((Self::SubscribePattern(pattern), 3 + len))
244            }
245            SubscriptionOp::ClearAll => Ok((Self::ClearAll, 1)),
246        }
247    }
248}
249
/// Per-client record of active region and pattern subscriptions.
#[derive(Clone, Debug, Default)]
pub struct SubscriptionState {
    /// Region IDs the client subscribed to directly.
    regions: HashSet<u32>,
    /// Pattern subscriptions, in insertion order and without duplicates.
    patterns: Vec<String>,
    /// Cap applied separately to the number of regions and the number of patterns.
    max_regions: u16,
}
260
261impl SubscriptionState {
262    /// Create new subscription state with limit
263    pub fn new(max_regions: u16) -> Self {
264        Self {
265            regions: HashSet::new(),
266            patterns: Vec::new(),
267            max_regions,
268        }
269    }
270
271    /// Apply a subscription change
272    ///
273    /// Returns true if the change was applied, false if rejected (e.g., at limit)
274    pub fn apply(&mut self, change: &SubscriptionChange) -> bool {
275        match change {
276            SubscriptionChange::Subscribe(id) => {
277                if self.regions.len() >= self.max_regions as usize {
278                    return false;
279                }
280                self.regions.insert(*id);
281                true
282            }
283            SubscriptionChange::Unsubscribe(id) => {
284                self.regions.remove(id);
285                true
286            }
287            SubscriptionChange::SubscribePattern(pattern) => {
288                if self.patterns.len() >= self.max_regions as usize {
289                    return false;
290                }
291                if !self.patterns.contains(pattern) {
292                    self.patterns.push(pattern.clone());
293                }
294                true
295            }
296            SubscriptionChange::ClearAll => {
297                self.regions.clear();
298                self.patterns.clear();
299                true
300            }
301        }
302    }
303
304    /// Check if a region ID is subscribed
305    pub fn is_subscribed(&self, region_id: u32) -> bool {
306        self.regions.contains(&region_id)
307    }
308
309    /// Check if a region matches any pattern
310    ///
311    /// This is a stub - real implementation would use proper pattern matching
312    pub fn matches_pattern(&self, region_path: &str) -> bool {
313        for pattern in &self.patterns {
314            if pattern_matches(pattern, region_path) {
315                return true;
316            }
317        }
318        false
319    }
320
321    /// Get count of active subscriptions
322    pub fn count(&self) -> usize {
323        self.regions.len() + self.patterns.len()
324    }
325
326    /// Check if any subscriptions are active
327    pub fn is_empty(&self) -> bool {
328        self.regions.is_empty() && self.patterns.is_empty()
329    }
330
331    /// Get all subscribed region IDs
332    pub fn region_ids(&self) -> impl Iterator<Item = &u32> {
333        self.regions.iter()
334    }
335
336    /// Get all patterns
337    pub fn patterns(&self) -> &[String] {
338        &self.patterns
339    }
340}
341
/// Glob-style matching of a region path against a subscription pattern.
///
/// Supports:
/// - `*` matches any (possibly empty) sequence of characters within a single
///   path segment; it never matches `/`
/// - `**` (or any longer run of `*`) matches any sequence, including `/`
/// - every other character matches itself literally
///
/// Wildcards may appear anywhere in the pattern (`users/*/profile`, `*.txt`,
/// `a/**/z`). `a/**/b` requires both slashes, so it does not match `a/b`.
///
/// Matching works on UTF-8 bytes. This is sound because `*` and `/` are ASCII
/// and no byte of a multi-byte UTF-8 sequence equals `/`. It runs in
/// `O(pattern.len() * path.len())` time and `O(path.len())` space, so wildcard-heavy
/// patterns cannot trigger exponential backtracking.
fn pattern_matches(pattern: &str, path: &str) -> bool {
    let pat = pattern.as_bytes();
    let text = path.as_bytes();

    // reachable[j] is true when the pattern prefix consumed so far can match
    // exactly text[..j].
    let mut reachable = vec![false; text.len() + 1];
    reachable[0] = true;

    let mut i = 0;
    while i < pat.len() {
        if pat[i] == b'*' {
            let run = pat[i..].iter().take_while(|&&c| c == b'*').count();
            let crosses_separators = run >= 2;
            // A star extends every reachable position rightwards. A single `*`
            // stops at the first '/' it would have to consume.
            let mut carry = false;
            for (j, slot) in reachable.iter_mut().enumerate() {
                carry |= *slot;
                *slot = carry;
                if !crosses_separators && text.get(j) == Some(&b'/') {
                    carry = false;
                }
            }
            i += run;
        } else {
            let literal = pat[i];
            // Walk right-to-left so reachable[j - 1] is read before it is overwritten.
            for j in (1..reachable.len()).rev() {
                reachable[j] = reachable[j - 1] && text[j - 1] == literal;
            }
            reachable[0] = false;
            i += 1;
        }

        if !reachable.iter().any(|&r| r) {
            return false;
        }
    }

    reachable[text.len()]
}
378
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = SelectiveSyncConfig::default();
        assert!(config.supports_regions());
        assert!(!config.supports_patterns());
        assert!(!config.supports_nested());
    }

    #[test]
    fn test_config_full() {
        let config = SelectiveSyncConfig::full();
        assert!(config.supports_regions());
        assert!(config.supports_patterns());
        assert!(config.supports_nested());
    }

    #[test]
    fn test_config_extension_roundtrip() {
        let config = SelectiveSyncConfig {
            flags: selective_sync_flags::REGION_OPS | selective_sync_flags::PATTERNS,
            max_regions: 512,
            max_expression_len: 200,
        };

        let ext = config.to_extension();
        let decoded = SelectiveSyncConfig::from_extension(&ext).unwrap();
        assert_eq!(decoded, config);
    }

    #[test]
    fn test_config_extension_wire_size() {
        let ext = SelectiveSyncConfig::default().to_extension();
        assert_eq!(ext.data.len(), SelectiveSyncConfig::wire_size());
    }

    #[test]
    fn test_config_from_extension_rejects_short_data() {
        let ext = Extension::new(ext_type::SELECTIVE_SYNC, vec![0x01, 0x00]);
        assert!(SelectiveSyncConfig::from_extension(&ext).is_none());
    }

    #[test]
    fn test_config_from_extension_ignores_trailing_bytes() {
        // flags 0x07, max_regions 1024, max_expression_len 256, then one extra byte.
        let ext = Extension::new(
            ext_type::SELECTIVE_SYNC,
            vec![0x07, 0x00, 0x04, 0x00, 0x01, 0xAA],
        );
        assert_eq!(
            SelectiveSyncConfig::from_extension(&ext),
            Some(SelectiveSyncConfig::full())
        );
    }

    #[test]
    fn test_config_negotiate() {
        let client = SelectiveSyncConfig::full();
        let server = SelectiveSyncConfig::default();
        let agreed = SelectiveSyncConfig::negotiate(&client, &server);
        assert_eq!(agreed.flags, selective_sync_flags::REGION_OPS);
        assert_eq!(agreed.max_regions, 256);
        assert_eq!(agreed.max_expression_len, 128);
    }

    #[test]
    fn test_subscribe_roundtrip() {
        let change = SubscriptionChange::Subscribe(12345);
        let encoded = change.encode();
        let (decoded, len) = SubscriptionChange::decode(&encoded).unwrap();
        assert_eq!(decoded, change);
        assert_eq!(len, 5);
    }

    #[test]
    fn test_unsubscribe_roundtrip() {
        let change = SubscriptionChange::Unsubscribe(99999);
        let encoded = change.encode();
        let (decoded, _) = SubscriptionChange::decode(&encoded).unwrap();
        assert_eq!(decoded, change);
    }

    #[test]
    fn test_pattern_roundtrip() {
        let change = SubscriptionChange::SubscribePattern("users/*/profile".to_string());
        let encoded = change.encode();
        let (decoded, len) = SubscriptionChange::decode(&encoded).unwrap();
        assert_eq!(decoded, change);
        assert_eq!(len, 3 + 15); // op + len + "users/*/profile"
    }

    #[test]
    fn test_clear_all() {
        let change = SubscriptionChange::ClearAll;
        let encoded = change.encode();
        assert_eq!(encoded.len(), 1);
        let (decoded, len) = SubscriptionChange::decode(&encoded).unwrap();
        assert_eq!(decoded, change);
        assert_eq!(len, 1);
    }

    #[test]
    fn test_wire_size_matches_encoding() {
        let changes = [
            SubscriptionChange::Subscribe(7),
            SubscriptionChange::Unsubscribe(7),
            SubscriptionChange::SubscribePattern("a/**".to_string()),
            SubscriptionChange::ClearAll,
        ];
        for change in &changes {
            assert_eq!(change.encode().len(), change.wire_size());
        }
    }

    #[test]
    fn test_decode_reports_consumed_length_with_trailing_data() {
        let mut buf = SubscriptionChange::Subscribe(42).encode();
        buf.extend_from_slice(&SubscriptionChange::ClearAll.encode());

        let (first, used) = SubscriptionChange::decode(&buf).unwrap();
        assert_eq!(first, SubscriptionChange::Subscribe(42));
        assert_eq!(used, 5);

        let (second, used_next) = SubscriptionChange::decode(&buf[used..]).unwrap();
        assert_eq!(second, SubscriptionChange::ClearAll);
        assert_eq!(used_next, 1);
    }

    #[test]
    fn test_from_byte() {
        assert_eq!(SubscriptionOp::from_byte(0x00), Some(SubscriptionOp::Subscribe));
        assert_eq!(SubscriptionOp::from_byte(0x03), Some(SubscriptionOp::ClearAll));
        assert_eq!(SubscriptionOp::from_byte(0x04), None);
    }

    #[test]
    fn test_subscription_state() {
        let mut state = SubscriptionState::new(10);

        assert!(state.apply(&SubscriptionChange::Subscribe(1)));
        assert!(state.apply(&SubscriptionChange::Subscribe(2)));
        assert!(state.is_subscribed(1));
        assert!(state.is_subscribed(2));
        assert!(!state.is_subscribed(3));

        assert!(state.apply(&SubscriptionChange::Unsubscribe(1)));
        assert!(!state.is_subscribed(1));

        assert!(state.apply(&SubscriptionChange::ClearAll));
        assert!(state.is_empty());
    }

    #[test]
    fn test_subscription_limit() {
        let mut state = SubscriptionState::new(2);

        assert!(state.apply(&SubscriptionChange::Subscribe(1)));
        assert!(state.apply(&SubscriptionChange::Subscribe(2)));
        assert!(!state.apply(&SubscriptionChange::Subscribe(3))); // At limit

        assert_eq!(state.count(), 2);
    }

    #[test]
    fn test_pattern_subscriptions() {
        let mut state = SubscriptionState::new(1);

        assert!(state.apply(&SubscriptionChange::SubscribePattern("users/*".to_string())));
        assert!(!state.apply(&SubscriptionChange::SubscribePattern("docs/**".to_string())));
        assert_eq!(state.patterns().len(), 1);
        assert_eq!(state.patterns()[0], "users/*");

        assert!(state.matches_pattern("users/alice"));
        assert!(!state.matches_pattern("docs/readme"));
    }

    #[test]
    fn test_duplicate_pattern_is_stored_once() {
        let mut state = SubscriptionState::new(10);
        let change = SubscriptionChange::SubscribePattern("a/*".to_string());
        assert!(state.apply(&change));
        assert!(state.apply(&change));
        assert_eq!(state.patterns().len(), 1);
        assert_eq!(state.count(), 1);
    }

    #[test]
    fn test_unsubscribe_unknown_region_is_noop() {
        let mut state = SubscriptionState::new(4);
        assert!(state.apply(&SubscriptionChange::Unsubscribe(77)));
        assert!(state.is_empty());
    }

    #[test]
    fn test_pattern_matching() {
        assert!(pattern_matches("users/*", "users/alice"));
        assert!(!pattern_matches("users/*", "users/alice/profile"));
        assert!(pattern_matches("users/**", "users/alice/profile"));
        assert!(pattern_matches("data*", "database"));
        assert!(pattern_matches("**", "anything/at/all"));
        assert!(pattern_matches("exact", "exact"));
        assert!(!pattern_matches("exact", "not-exact"));
    }

    #[test]
    fn test_decode_invalid() {
        // Invalid op
        assert!(matches!(
            SubscriptionChange::decode(&[0xFF]),
            Err(NegotiationError::InvalidData)
        ));

        // Truncated subscribe
        assert!(matches!(
            SubscriptionChange::decode(&[0x00, 1, 2]),
            Err(NegotiationError::TooShort { .. })
        ));
    }

    #[test]
    fn test_decode_pattern_errors() {
        let empty: [u8; 0] = [];
        assert!(matches!(
            SubscriptionChange::decode(&empty),
            Err(NegotiationError::TooShort {
                expected: 1,
                actual: 0
            })
        ));

        // Length prefix cut off.
        assert!(matches!(
            SubscriptionChange::decode(&[0x02, 0x01]),
            Err(NegotiationError::TooShort { .. })
        ));

        // Declared length longer than the available payload.
        assert!(matches!(
            SubscriptionChange::decode(&[0x02, 0x05, 0x00, b'a']),
            Err(NegotiationError::TooShort {
                expected: 8,
                actual: 4
            })
        ));

        // Pattern bytes that are not valid UTF-8.
        assert!(matches!(
            SubscriptionChange::decode(&[0x02, 0x02, 0x00, 0xFF, 0xFE]),
            Err(NegotiationError::InvalidData)
        ));
    }
}