// oxigdal_ws/subscription.rs

//! Subscription management for WebSocket clients.
2
3use crate::error::{Error, Result};
4use crate::protocol::{EventType, SubscriptionFilter};
5use dashmap::DashMap;
6use std::collections::HashSet;
7use std::ops::Range;
8use std::sync::Arc;
9use uuid::Uuid;
10
11/// Subscription type.
12#[derive(Debug, Clone, PartialEq, Eq)]
13pub enum SubscriptionType {
14    /// Tile subscription with bbox and zoom range
15    Tiles {
16        /// Bounding box [min_x, min_y, max_x, max_y]
17        bbox: [i64; 4], // Using i64 for exact comparison
18        /// Zoom level range
19        zoom_range: Range<u8>,
20    },
21    /// Feature subscription with filters
22    Features {
23        /// Layer name
24        layer: Option<String>,
25    },
26    /// Event subscription
27    Events {
28        /// Event types
29        event_types: HashSet<EventType>,
30    },
31}
32
33/// A subscription to WebSocket updates.
34#[derive(Debug, Clone)]
35pub struct Subscription {
36    /// Unique subscription ID
37    pub id: String,
38    /// Client ID that owns this subscription
39    pub client_id: String,
40    /// Subscription type
41    pub subscription_type: SubscriptionType,
42    /// Optional filter
43    pub filter: Option<SubscriptionFilter>,
44}
45
46impl Subscription {
47    /// Create a new subscription.
48    pub fn new(
49        client_id: String,
50        subscription_type: SubscriptionType,
51        filter: Option<SubscriptionFilter>,
52    ) -> Self {
53        Self {
54            id: Uuid::new_v4().to_string(),
55            client_id,
56            subscription_type,
57            filter,
58        }
59    }
60
61    /// Create a tile subscription.
62    pub fn tiles(
63        client_id: String,
64        bbox: [f64; 4],
65        zoom_range: Range<u8>,
66        filter: Option<SubscriptionFilter>,
67    ) -> Self {
68        // Convert f64 bbox to i64 for exact comparison (multiply by 1e6)
69        let bbox_i64 = [
70            (bbox[0] * 1_000_000.0) as i64,
71            (bbox[1] * 1_000_000.0) as i64,
72            (bbox[2] * 1_000_000.0) as i64,
73            (bbox[3] * 1_000_000.0) as i64,
74        ];
75
76        Self::new(
77            client_id,
78            SubscriptionType::Tiles {
79                bbox: bbox_i64,
80                zoom_range,
81            },
82            filter,
83        )
84    }
85
86    /// Create a feature subscription.
87    pub fn features(
88        client_id: String,
89        layer: Option<String>,
90        filter: Option<SubscriptionFilter>,
91    ) -> Self {
92        Self::new(client_id, SubscriptionType::Features { layer }, filter)
93    }
94
95    /// Create an event subscription.
96    pub fn events(
97        client_id: String,
98        event_types: HashSet<EventType>,
99        filter: Option<SubscriptionFilter>,
100    ) -> Self {
101        Self::new(client_id, SubscriptionType::Events { event_types }, filter)
102    }
103
104    /// Check if a tile matches this subscription.
105    pub fn matches_tile(&self, tile_x: u32, tile_y: u32, zoom: u8) -> bool {
106        match &self.subscription_type {
107            SubscriptionType::Tiles { bbox, zoom_range } => {
108                if !zoom_range.contains(&zoom) {
109                    return false;
110                }
111
112                // Convert tile coordinates to bbox
113                let n = 2_u32.pow(zoom.into());
114                let tile_bbox = Self::tile_to_bbox(tile_x, tile_y, n);
115
116                // Convert back to i64 for comparison
117                let tile_bbox_i64 = [
118                    (tile_bbox[0] * 1_000_000.0) as i64,
119                    (tile_bbox[1] * 1_000_000.0) as i64,
120                    (tile_bbox[2] * 1_000_000.0) as i64,
121                    (tile_bbox[3] * 1_000_000.0) as i64,
122                ];
123
124                // Check if bboxes overlap
125                Self::bboxes_overlap(bbox, &tile_bbox_i64)
126            }
127            _ => false,
128        }
129    }
130
131    /// Convert tile coordinates to geographic bbox.
132    fn tile_to_bbox(x: u32, y: u32, n: u32) -> [f64; 4] {
133        let n_f64 = n as f64;
134        let min_lon = (x as f64 / n_f64) * 360.0 - 180.0;
135        let max_lon = ((x + 1) as f64 / n_f64) * 360.0 - 180.0;
136
137        let lat_rad = |y_val: f64| -> f64 {
138            let n_rad = std::f64::consts::PI - 2.0 * std::f64::consts::PI * y_val / n_f64;
139            n_rad.sinh().atan().to_degrees()
140        };
141
142        let max_lat = lat_rad(y as f64);
143        let min_lat = lat_rad((y + 1) as f64);
144
145        [min_lon, min_lat, max_lon, max_lat]
146    }
147
148    /// Check if two bboxes overlap.
149    fn bboxes_overlap(bbox1: &[i64; 4], bbox2: &[i64; 4]) -> bool {
150        bbox1[0] <= bbox2[2] && bbox1[2] >= bbox2[0] && bbox1[1] <= bbox2[3] && bbox1[3] >= bbox2[1]
151    }
152
153    /// Check if a feature matches this subscription.
154    pub fn matches_feature(&self, layer: Option<&str>) -> bool {
155        match &self.subscription_type {
156            SubscriptionType::Features { layer: sub_layer } => {
157                if let Some(sub_layer) = sub_layer {
158                    layer == Some(sub_layer.as_str())
159                } else {
160                    true // Match all layers
161                }
162            }
163            _ => false,
164        }
165    }
166
167    /// Check if an event matches this subscription.
168    pub fn matches_event(&self, event_type: EventType) -> bool {
169        match &self.subscription_type {
170            SubscriptionType::Events { event_types } => event_types.contains(&event_type),
171            _ => false,
172        }
173    }
174}
175
176/// Manager for all subscriptions.
177pub struct SubscriptionManager {
178    /// All subscriptions by ID
179    subscriptions: Arc<DashMap<String, Subscription>>,
180    /// Subscriptions by client ID
181    client_subscriptions: Arc<DashMap<String, HashSet<String>>>,
182}
183
184impl SubscriptionManager {
185    /// Create a new subscription manager.
186    pub fn new() -> Self {
187        Self {
188            subscriptions: Arc::new(DashMap::new()),
189            client_subscriptions: Arc::new(DashMap::new()),
190        }
191    }
192
193    /// Add a subscription.
194    pub fn add(&self, subscription: Subscription) -> Result<String> {
195        let sub_id = subscription.id.clone();
196        let client_id = subscription.client_id.clone();
197
198        self.subscriptions.insert(sub_id.clone(), subscription);
199
200        self.client_subscriptions
201            .entry(client_id)
202            .or_default()
203            .insert(sub_id.clone());
204
205        Ok(sub_id)
206    }
207
208    /// Remove a subscription.
209    pub fn remove(&self, subscription_id: &str) -> Result<()> {
210        if let Some((_, subscription)) = self.subscriptions.remove(subscription_id) {
211            if let Some(mut client_subs) =
212                self.client_subscriptions.get_mut(&subscription.client_id)
213            {
214                client_subs.remove(subscription_id);
215            }
216            Ok(())
217        } else {
218            Err(Error::NotFound(format!(
219                "Subscription not found: {}",
220                subscription_id
221            )))
222        }
223    }
224
225    /// Remove all subscriptions for a client.
226    pub fn remove_client(&self, client_id: &str) -> Result<()> {
227        if let Some((_, sub_ids)) = self.client_subscriptions.remove(client_id) {
228            for sub_id in sub_ids {
229                self.subscriptions.remove(&sub_id);
230            }
231        }
232        Ok(())
233    }
234
235    /// Get a subscription by ID.
236    pub fn get(&self, subscription_id: &str) -> Option<Subscription> {
237        self.subscriptions.get(subscription_id).map(|s| s.clone())
238    }
239
240    /// Get all subscriptions for a client.
241    pub fn get_client_subscriptions(&self, client_id: &str) -> Vec<Subscription> {
242        if let Some(sub_ids) = self.client_subscriptions.get(client_id) {
243            sub_ids.iter().filter_map(|id| self.get(id)).collect()
244        } else {
245            Vec::new()
246        }
247    }
248
249    /// Find subscriptions matching a tile.
250    pub fn find_tile_subscriptions(&self, tile_x: u32, tile_y: u32, zoom: u8) -> Vec<Subscription> {
251        self.subscriptions
252            .iter()
253            .filter(|entry| entry.value().matches_tile(tile_x, tile_y, zoom))
254            .map(|entry| entry.value().clone())
255            .collect()
256    }
257
258    /// Find subscriptions matching a feature.
259    pub fn find_feature_subscriptions(&self, layer: Option<&str>) -> Vec<Subscription> {
260        self.subscriptions
261            .iter()
262            .filter(|entry| entry.value().matches_feature(layer))
263            .map(|entry| entry.value().clone())
264            .collect()
265    }
266
267    /// Find subscriptions matching an event.
268    pub fn find_event_subscriptions(&self, event_type: EventType) -> Vec<Subscription> {
269        self.subscriptions
270            .iter()
271            .filter(|entry| entry.value().matches_event(event_type))
272            .map(|entry| entry.value().clone())
273            .collect()
274    }
275
276    /// Get total subscription count.
277    pub fn count(&self) -> usize {
278        self.subscriptions.len()
279    }
280
281    /// Get client count.
282    pub fn client_count(&self) -> usize {
283        self.client_subscriptions.len()
284    }
285}
286
287impl Default for SubscriptionManager {
288    fn default() -> Self {
289        Self::new()
290    }
291}
292
#[cfg(test)]
mod tests {
    use super::*;

    /// Tile subscription for the southwest quadrant (lon <= 0, lat <= 0).
    fn southwest_tiles(client_id: &str, zoom_range: Range<u8>) -> Subscription {
        Subscription::tiles(
            client_id.to_string(),
            [-180.0, -90.0, 0.0, 0.0],
            zoom_range,
            None,
        )
    }

    #[test]
    fn test_subscription_creation() {
        let sub = Subscription::tiles(
            "client-1".to_string(),
            [-180.0, -90.0, 180.0, 90.0],
            0..14,
            None,
        );

        assert_eq!(sub.client_id, "client-1");
        assert!(!sub.id.is_empty());
    }

    #[test]
    fn test_tile_matching() {
        let sub = southwest_tiles("client-1", 5..10);

        // Should match tiles in the southwest quadrant at zoom 5-9
        // At zoom 5 (n=32), tile (0,0) is northwest, tile (0,31) is southwest
        // At zoom 5, southwest quadrant uses x=[0,15], y=[16,31]
        assert!(sub.matches_tile(0, 31, 5)); // Southwest corner tile
        assert!(sub.matches_tile(15, 16, 5)); // Near equator/prime meridian boundary
        // At zoom 8 (n=256), southwest quadrant uses x=[0,127], y=[128,255]
        assert!(sub.matches_tile(100, 200, 8)); // Deep in southwest quadrant

        // Should not match tiles outside zoom range
        assert!(!sub.matches_tile(0, 31, 4)); // Correct tile but wrong zoom
        assert!(!sub.matches_tile(0, 31, 10)); // Correct tile but wrong zoom

        // Should not match tiles outside bbox
        assert!(!sub.matches_tile(0, 0, 5)); // Northwest, not southwest
        // At zoom 8, x=16 is still west of the prime meridian (x < 128), so
        // this tile is northwest as well.
        assert!(!sub.matches_tile(16, 16, 8)); // Northwest, not southwest
        assert!(!sub.matches_tile(200, 16, 8)); // Northeast, not southwest
        assert!(!sub.matches_tile(200, 200, 8)); // Southeast, not southwest
    }

    #[test]
    fn test_subscription_manager() {
        let manager = SubscriptionManager::new();

        let sub1 = southwest_tiles("client-1", 0..14);
        let sub_id1 = sub1.id.clone();

        assert!(manager.add(sub1).is_ok());

        assert_eq!(manager.count(), 1);
        assert_eq!(manager.client_count(), 1);

        let retrieved = manager.get(&sub_id1);
        assert!(retrieved.is_some());
        if let Some(ref sub) = retrieved {
            assert_eq!(sub.id, sub_id1);
        }

        assert!(manager.remove(&sub_id1).is_ok());
        assert_eq!(manager.count(), 0);
    }

    #[test]
    fn test_find_tile_subscriptions() {
        let manager = SubscriptionManager::new();

        // client-1: southwest quadrant
        let sub1 = southwest_tiles("client-1", 5..10);

        // client-2: northeast quadrant
        let sub2 =
            Subscription::tiles("client-2".to_string(), [0.0, 0.0, 180.0, 90.0], 5..10, None);

        assert!(manager.add(sub1).is_ok());
        assert!(manager.add(sub2).is_ok());

        // Find subscriptions for a tile in the southwest quadrant
        // At zoom 5, tile (0,31) is in the southwest corner
        let matches = manager.find_tile_subscriptions(0, 31, 5);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].client_id, "client-1");

        // Find subscriptions for a tile in the northeast quadrant
        // At zoom 5, tile (31,0) is in the northeast corner
        let matches2 = manager.find_tile_subscriptions(31, 0, 5);
        assert_eq!(matches2.len(), 1);
        assert_eq!(matches2[0].client_id, "client-2");
    }

    #[test]
    fn test_remove_client() {
        let manager = SubscriptionManager::new();

        let sub1 = southwest_tiles("client-1", 0..14);

        let sub2 = Subscription::features("client-1".to_string(), Some("layer1".to_string()), None);

        assert!(manager.add(sub1).is_ok());
        assert!(manager.add(sub2).is_ok());

        assert_eq!(manager.count(), 2);

        assert!(manager.remove_client("client-1").is_ok());
        assert_eq!(manager.count(), 0);
    }
}
416}