// nomad_protocol/extensions/selective_sync.rs
use super::negotiation::{ext_type, Extension, NegotiationError};
use std::collections::HashSet;

/// Capability bits carried in `SelectiveSyncConfig::flags`.
///
/// Each constant occupies its own bit, so capabilities are combined with `|`
/// and tested with `&`.
pub mod selective_sync_flags {
    /// Advertises region subscribe/unsubscribe operations.
    pub const REGION_OPS: u8 = 1 << 0;
    /// Advertises pattern subscriptions (`SubscriptionChange::SubscribePattern`).
    pub const PATTERNS: u8 = 1 << 1;
    /// Advertises nested-subscription support; its semantics are not defined in
    /// this module.
    pub const NESTED: u8 = 1 << 2;
}

/// Selective-sync capabilities and limits advertised by one side of a
/// negotiation.
///
/// Serialized as 5 bytes: `flags`, then `max_regions` and
/// `max_expression_len` as little-endian `u16`s.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SelectiveSyncConfig {
    /// Bitwise OR of `selective_sync_flags` constants.
    pub flags: u8,
    /// Maximum number of region subscriptions.
    pub max_regions: u16,
    /// Maximum subscription-expression length (not enforced by this module).
    pub max_expression_len: u16,
}

55impl Default for SelectiveSyncConfig {
56 fn default() -> Self {
57 Self {
58 flags: selective_sync_flags::REGION_OPS,
59 max_regions: 256,
60 max_expression_len: 128,
61 }
62 }
63}
64
65impl SelectiveSyncConfig {
66 pub fn full() -> Self {
68 Self {
69 flags: selective_sync_flags::REGION_OPS | selective_sync_flags::PATTERNS | selective_sync_flags::NESTED,
70 max_regions: 1024,
71 max_expression_len: 256,
72 }
73 }
74
75 pub fn supports_regions(&self) -> bool {
77 (self.flags & selective_sync_flags::REGION_OPS) != 0
78 }
79
80 pub fn supports_patterns(&self) -> bool {
82 (self.flags & selective_sync_flags::PATTERNS) != 0
83 }
84
85 pub fn supports_nested(&self) -> bool {
87 (self.flags & selective_sync_flags::NESTED) != 0
88 }
89
90 pub const fn wire_size() -> usize {
92 5 }
94
95 pub fn to_extension(&self) -> Extension {
97 let mut data = Vec::with_capacity(Self::wire_size());
98 data.push(self.flags);
99 data.extend_from_slice(&self.max_regions.to_le_bytes());
100 data.extend_from_slice(&self.max_expression_len.to_le_bytes());
101 Extension::new(ext_type::SELECTIVE_SYNC, data)
102 }
103
104 pub fn from_extension(ext: &Extension) -> Option<Self> {
106 if ext.ext_type != ext_type::SELECTIVE_SYNC || ext.data.len() < Self::wire_size() {
107 return None;
108 }
109 Some(Self {
110 flags: ext.data[0],
111 max_regions: u16::from_le_bytes([ext.data[1], ext.data[2]]),
112 max_expression_len: u16::from_le_bytes([ext.data[3], ext.data[4]]),
113 })
114 }
115
116 pub fn negotiate(client: &Self, server: &Self) -> Self {
118 Self {
119 flags: client.flags & server.flags,
120 max_regions: client.max_regions.min(server.max_regions),
121 max_expression_len: client.max_expression_len.min(server.max_expression_len),
122 }
123 }
124}
125
/// Opcode byte written at the start of every encoded `SubscriptionChange`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum SubscriptionOp {
    /// Subscribe to one region id.
    Subscribe = 0x00,
    /// Unsubscribe from one region id.
    Unsubscribe = 0x01,
    /// Subscribe with a textual pattern.
    SubscribePattern = 0x02,
    /// Drop every region and pattern subscription.
    ClearAll = 0x03,
}

140impl SubscriptionOp {
141 pub fn from_byte(b: u8) -> Option<Self> {
143 match b {
144 0x00 => Some(Self::Subscribe),
145 0x01 => Some(Self::Unsubscribe),
146 0x02 => Some(Self::SubscribePattern),
147 0x03 => Some(Self::ClearAll),
148 _ => None,
149 }
150 }
151}
152
/// One subscription operation together with its payload.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SubscriptionChange {
    /// Add a region id (encoded as a little-endian `u32`).
    Subscribe(u32),
    /// Remove a region id (encoded as a little-endian `u32`).
    Unsubscribe(u32),
    /// Add a pattern (encoded with a little-endian `u16` byte-length prefix).
    SubscribePattern(String),
    /// Remove every region and pattern subscription.
    ClearAll,
}

166impl SubscriptionChange {
167 pub fn wire_size(&self) -> usize {
169 match self {
170 Self::Subscribe(_) | Self::Unsubscribe(_) => 5, Self::SubscribePattern(p) => 3 + p.len(), Self::ClearAll => 1, }
174 }
175
176 pub fn encode(&self) -> Vec<u8> {
178 let mut buf = Vec::with_capacity(self.wire_size());
179 match self {
180 Self::Subscribe(id) => {
181 buf.push(SubscriptionOp::Subscribe as u8);
182 buf.extend_from_slice(&id.to_le_bytes());
183 }
184 Self::Unsubscribe(id) => {
185 buf.push(SubscriptionOp::Unsubscribe as u8);
186 buf.extend_from_slice(&id.to_le_bytes());
187 }
188 Self::SubscribePattern(pattern) => {
189 buf.push(SubscriptionOp::SubscribePattern as u8);
190 buf.extend_from_slice(&(pattern.len() as u16).to_le_bytes());
191 buf.extend_from_slice(pattern.as_bytes());
192 }
193 Self::ClearAll => {
194 buf.push(SubscriptionOp::ClearAll as u8);
195 }
196 }
197 buf
198 }
199
200 pub fn decode(data: &[u8]) -> Result<(Self, usize), NegotiationError> {
202 if data.is_empty() {
203 return Err(NegotiationError::TooShort {
204 expected: 1,
205 actual: 0,
206 });
207 }
208
209 let op = SubscriptionOp::from_byte(data[0]).ok_or(NegotiationError::InvalidData)?;
210
211 match op {
212 SubscriptionOp::Subscribe | SubscriptionOp::Unsubscribe => {
213 if data.len() < 5 {
214 return Err(NegotiationError::TooShort {
215 expected: 5,
216 actual: data.len(),
217 });
218 }
219 let id = u32::from_le_bytes([data[1], data[2], data[3], data[4]]);
220 let change = if op == SubscriptionOp::Subscribe {
221 Self::Subscribe(id)
222 } else {
223 Self::Unsubscribe(id)
224 };
225 Ok((change, 5))
226 }
227 SubscriptionOp::SubscribePattern => {
228 if data.len() < 3 {
229 return Err(NegotiationError::TooShort {
230 expected: 3,
231 actual: data.len(),
232 });
233 }
234 let len = u16::from_le_bytes([data[1], data[2]]) as usize;
235 if data.len() < 3 + len {
236 return Err(NegotiationError::TooShort {
237 expected: 3 + len,
238 actual: data.len(),
239 });
240 }
241 let pattern = String::from_utf8(data[3..3 + len].to_vec())
242 .map_err(|_| NegotiationError::InvalidData)?;
243 Ok((Self::SubscribePattern(pattern), 3 + len))
244 }
245 SubscriptionOp::ClearAll => Ok((Self::ClearAll, 1)),
246 }
247 }
248}
249
/// Set of active region and pattern subscriptions.
///
/// `Default` yields a limit of 0, so a default-constructed state rejects every
/// new `Subscribe` and `SubscribePattern`. Use `SubscriptionState::new` to set
/// a real limit.
#[derive(Debug, Clone, Default)]
pub struct SubscriptionState {
    /// Subscribed region ids.
    regions: HashSet<u32>,
    /// Subscribed patterns, in insertion order, without duplicates.
    patterns: Vec<String>,
    /// Limit applied separately to `regions` and to `patterns`.
    max_regions: u16,
}

261impl SubscriptionState {
262 pub fn new(max_regions: u16) -> Self {
264 Self {
265 regions: HashSet::new(),
266 patterns: Vec::new(),
267 max_regions,
268 }
269 }
270
271 pub fn apply(&mut self, change: &SubscriptionChange) -> bool {
275 match change {
276 SubscriptionChange::Subscribe(id) => {
277 if self.regions.len() >= self.max_regions as usize {
278 return false;
279 }
280 self.regions.insert(*id);
281 true
282 }
283 SubscriptionChange::Unsubscribe(id) => {
284 self.regions.remove(id);
285 true
286 }
287 SubscriptionChange::SubscribePattern(pattern) => {
288 if self.patterns.len() >= self.max_regions as usize {
289 return false;
290 }
291 if !self.patterns.contains(pattern) {
292 self.patterns.push(pattern.clone());
293 }
294 true
295 }
296 SubscriptionChange::ClearAll => {
297 self.regions.clear();
298 self.patterns.clear();
299 true
300 }
301 }
302 }
303
304 pub fn is_subscribed(&self, region_id: u32) -> bool {
306 self.regions.contains(®ion_id)
307 }
308
309 pub fn matches_pattern(&self, region_path: &str) -> bool {
313 for pattern in &self.patterns {
314 if pattern_matches(pattern, region_path) {
315 return true;
316 }
317 }
318 false
319 }
320
321 pub fn count(&self) -> usize {
323 self.regions.len() + self.patterns.len()
324 }
325
326 pub fn is_empty(&self) -> bool {
328 self.regions.is_empty() && self.patterns.is_empty()
329 }
330
331 pub fn region_ids(&self) -> impl Iterator<Item = &u32> {
333 self.regions.iter()
334 }
335
336 pub fn patterns(&self) -> &[String] {
338 &self.patterns
339 }
340}
341
/// Tests whether `path` is selected by `pattern`.
///
/// Forms, checked in this order:
/// - exact match (`pattern == path`);
/// - `**`: matches every path;
/// - `prefix/*`: the part of `path` before its last `/` equals `prefix`
///   (exactly one more segment);
/// - `prefix/**`: `path` starts with `prefix/`, i.e. lies strictly below `prefix`;
/// - `prefix*`: plain string-prefix match.
///
/// Only a trailing `*`, `/*` or `/**` acts as a wildcard. Any other `*` is
/// compared literally.
fn pattern_matches(pattern: &str, path: &str) -> bool {
    if pattern == path || pattern == "**" {
        return true;
    }

    if let Some(prefix) = pattern.strip_suffix("/*") {
        return path
            .rsplit_once('/')
            .map_or(false, |(parent, _)| parent == prefix);
    }

    if let Some(prefix) = pattern.strip_suffix("/**") {
        // Require a `/` right after the prefix so that `users/**` does not
        // match sibling names such as `usersfoo` or `users_admin/x`.
        return path
            .strip_prefix(prefix)
            .map_or(false, |rest| rest.starts_with('/'));
    }

    if let Some(prefix) = pattern.strip_suffix('*') {
        return path.starts_with(prefix);
    }

    false
}

// Unit tests for the selective-sync config, wire codec and subscription state.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = SelectiveSyncConfig::default();
        assert!(config.supports_regions());
        assert!(!config.supports_patterns());
        assert!(!config.supports_nested());
    }

    #[test]
    fn test_config_full() {
        let config = SelectiveSyncConfig::full();
        assert!(config.supports_regions());
        assert!(config.supports_patterns());
        assert!(config.supports_nested());
    }

    #[test]
    fn test_config_extension_roundtrip() {
        let config = SelectiveSyncConfig {
            flags: selective_sync_flags::REGION_OPS | selective_sync_flags::PATTERNS,
            max_regions: 512,
            max_expression_len: 200,
        };

        let ext = config.to_extension();
        let decoded = SelectiveSyncConfig::from_extension(&ext).unwrap();
        assert_eq!(decoded, config);
    }

    #[test]
    fn test_from_extension_rejects_short_payload() {
        let ext = Extension::new(ext_type::SELECTIVE_SYNC, vec![0x01, 0x00, 0x01]);
        assert_eq!(SelectiveSyncConfig::from_extension(&ext), None);
    }

    #[test]
    fn test_negotiate_takes_intersection() {
        let client = SelectiveSyncConfig::full();
        let server = SelectiveSyncConfig::default();
        let agreed = SelectiveSyncConfig::negotiate(&client, &server);
        assert_eq!(agreed.flags, selective_sync_flags::REGION_OPS);
        assert_eq!(agreed.max_regions, 256);
        assert_eq!(agreed.max_expression_len, 128);
        assert_eq!(SelectiveSyncConfig::negotiate(&server, &client), agreed);
    }

    #[test]
    fn test_subscribe_roundtrip() {
        let change = SubscriptionChange::Subscribe(12345);
        let encoded = change.encode();
        let (decoded, len) = SubscriptionChange::decode(&encoded).unwrap();
        assert_eq!(decoded, change);
        assert_eq!(len, 5);
    }

    #[test]
    fn test_unsubscribe_roundtrip() {
        let change = SubscriptionChange::Unsubscribe(99999);
        let encoded = change.encode();
        let (decoded, _) = SubscriptionChange::decode(&encoded).unwrap();
        assert_eq!(decoded, change);
    }

    #[test]
    fn test_pattern_roundtrip() {
        let change = SubscriptionChange::SubscribePattern("users/*/profile".to_string());
        let encoded = change.encode();
        let (decoded, len) = SubscriptionChange::decode(&encoded).unwrap();
        assert_eq!(decoded, change);
        assert_eq!(len, 3 + 15);
    }

    #[test]
    fn test_clear_all() {
        let change = SubscriptionChange::ClearAll;
        let encoded = change.encode();
        assert_eq!(encoded.len(), 1);
        let (decoded, len) = SubscriptionChange::decode(&encoded).unwrap();
        assert_eq!(decoded, change);
        assert_eq!(len, 1);
    }

    #[test]
    fn test_encoded_len_matches_wire_size() {
        let changes = [
            SubscriptionChange::Subscribe(7),
            SubscriptionChange::Unsubscribe(7),
            SubscriptionChange::SubscribePattern("a/**".to_string()),
            SubscriptionChange::ClearAll,
        ];
        for change in &changes {
            assert_eq!(change.encode().len(), change.wire_size());
        }
    }

    #[test]
    fn test_decode_reports_consumed_bytes_in_stream() {
        let mut stream = SubscriptionChange::Subscribe(1).encode();
        stream.extend(SubscriptionChange::ClearAll.encode());
        let (first, used) = SubscriptionChange::decode(&stream).unwrap();
        assert_eq!(first, SubscriptionChange::Subscribe(1));
        let (second, rest) = SubscriptionChange::decode(&stream[used..]).unwrap();
        assert_eq!(second, SubscriptionChange::ClearAll);
        assert_eq!(used + rest, stream.len());
    }

    #[test]
    fn test_subscription_state() {
        let mut state = SubscriptionState::new(10);

        assert!(state.apply(&SubscriptionChange::Subscribe(1)));
        assert!(state.apply(&SubscriptionChange::Subscribe(2)));
        assert!(state.is_subscribed(1));
        assert!(state.is_subscribed(2));
        assert!(!state.is_subscribed(3));

        assert!(state.apply(&SubscriptionChange::Unsubscribe(1)));
        assert!(!state.is_subscribed(1));

        assert!(state.apply(&SubscriptionChange::ClearAll));
        assert!(state.is_empty());
    }

    #[test]
    fn test_subscription_limit() {
        let mut state = SubscriptionState::new(2);

        assert!(state.apply(&SubscriptionChange::Subscribe(1)));
        assert!(state.apply(&SubscriptionChange::Subscribe(2)));
        assert!(!state.apply(&SubscriptionChange::Subscribe(3)));
        assert_eq!(state.count(), 2);
    }

    #[test]
    fn test_pattern_matching() {
        assert!(pattern_matches("users/*", "users/alice"));
        assert!(!pattern_matches("users/*", "users/alice/profile"));
        assert!(pattern_matches("users/**", "users/alice/profile"));
        assert!(pattern_matches("data*", "database"));
        assert!(pattern_matches("**", "anything/at/all"));
        assert!(pattern_matches("exact", "exact"));
        assert!(!pattern_matches("exact", "not-exact"));
    }

    #[test]
    fn test_decode_invalid() {
        // Unknown opcode.
        assert!(matches!(
            SubscriptionChange::decode(&[0xFF]),
            Err(NegotiationError::InvalidData)
        ));

        // Subscribe opcode with a truncated id.
        assert!(matches!(
            SubscriptionChange::decode(&[0x00, 1, 2]),
            Err(NegotiationError::TooShort { .. })
        ));
    }

    #[test]
    fn test_decode_truncated_pattern_payload() {
        assert!(matches!(
            SubscriptionChange::decode(&[0x02, 5, 0, b'a']),
            Err(NegotiationError::TooShort { expected: 8, actual: 4 })
        ));
    }

    #[test]
    fn test_decode_rejects_invalid_utf8_pattern() {
        assert!(matches!(
            SubscriptionChange::decode(&[0x02, 1, 0, 0xFF]),
            Err(NegotiationError::InvalidData)
        ));
    }
}