Skip to main content

oxigdal_websocket/broadcast/
filter.rs

1//! Message filtering for selective broadcasting
2
3use crate::protocol::message::{Message, MessageType};
4use crate::server::connection::ConnectionId;
5use std::collections::HashSet;
6
7/// Filter predicate trait
8pub trait FilterPredicate: Send + Sync {
9    /// Check if a message should be delivered to a connection
10    fn should_deliver(&self, message: &Message, connection_id: &ConnectionId) -> bool;
11}
12
13/// Message filter
14pub enum MessageFilter {
15    /// Accept all messages
16    All,
17    /// Accept messages of specific types
18    MessageType(HashSet<MessageType>),
19    /// Accept messages from specific connections
20    FromConnections(HashSet<ConnectionId>),
21    /// Reject messages from specific connections
22    ExcludeConnections(HashSet<ConnectionId>),
23    /// Custom filter predicate
24    Custom(Box<dyn FilterPredicate>),
25}
26
27impl MessageFilter {
28    /// Create an "accept all" filter
29    pub fn all() -> Self {
30        Self::All
31    }
32
33    /// Create a filter for specific message types
34    pub fn message_types(types: Vec<MessageType>) -> Self {
35        Self::MessageType(types.into_iter().collect())
36    }
37
38    /// Create a filter for specific connections
39    pub fn from_connections(connections: Vec<ConnectionId>) -> Self {
40        Self::FromConnections(connections.into_iter().collect())
41    }
42
43    /// Create a filter to exclude specific connections
44    pub fn exclude_connections(connections: Vec<ConnectionId>) -> Self {
45        Self::ExcludeConnections(connections.into_iter().collect())
46    }
47
48    /// Check if a message should be delivered
49    pub fn should_deliver(&self, message: &Message, connection_id: &ConnectionId) -> bool {
50        match self {
51            Self::All => true,
52            Self::MessageType(types) => types.contains(&message.msg_type),
53            Self::FromConnections(conns) => conns.contains(connection_id),
54            Self::ExcludeConnections(conns) => !conns.contains(connection_id),
55            Self::Custom(predicate) => predicate.should_deliver(message, connection_id),
56        }
57    }
58}
59
60/// Filter chain for combining multiple filters
61pub struct FilterChain {
62    filters: Vec<MessageFilter>,
63    /// If true, all filters must pass (AND). If false, any filter can pass (OR)
64    all_must_pass: bool,
65}
66
67impl FilterChain {
68    /// Create a new filter chain (AND logic)
69    pub fn new_and() -> Self {
70        Self {
71            filters: Vec::new(),
72            all_must_pass: true,
73        }
74    }
75
76    /// Create a new filter chain (OR logic)
77    pub fn new_or() -> Self {
78        Self {
79            filters: Vec::new(),
80            all_must_pass: false,
81        }
82    }
83
84    /// Add a filter to the chain
85    pub fn add_filter(mut self, filter: MessageFilter) -> Self {
86        self.filters.push(filter);
87        self
88    }
89
90    /// Check if a message should be delivered
91    pub fn should_deliver(&self, message: &Message, connection_id: &ConnectionId) -> bool {
92        if self.filters.is_empty() {
93            return true;
94        }
95
96        if self.all_must_pass {
97            // AND logic - all filters must pass
98            self.filters
99                .iter()
100                .all(|f| f.should_deliver(message, connection_id))
101        } else {
102            // OR logic - any filter can pass
103            self.filters
104                .iter()
105                .any(|f| f.should_deliver(message, connection_id))
106        }
107    }
108
109    /// Get filter count
110    pub fn len(&self) -> usize {
111        self.filters.len()
112    }
113
114    /// Check if empty
115    pub fn is_empty(&self) -> bool {
116        self.filters.is_empty()
117    }
118}
119
/// Geographic bounding box used to test whether a point falls inside a region.
///
/// `x` is the longitude-like axis and `y` the latitude-like axis.
///
/// If `min_x > max_x`, the box crosses the antimeridian, following the
/// GeoJSON bbox convention in RFC 7946 §5.2. It then covers `x >= min_x`
/// *or* `x <= max_x`. The `y` range never wraps.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GeoBboxFilter {
    min_x: f64,
    min_y: f64,
    max_x: f64,
    max_y: f64,
}

impl GeoBboxFilter {
    /// Create a bounding box from its west/south (`min_x`, `min_y`) and
    /// east/north (`max_x`, `max_y`) edges.
    ///
    /// No validation is performed. Pass `min_x > max_x` to describe a box
    /// that crosses the antimeridian.
    pub fn new(min_x: f64, min_y: f64, max_x: f64, max_y: f64) -> Self {
        Self {
            min_x,
            min_y,
            max_x,
            max_y,
        }
    }

    /// Return `true` if the point (`x`, `y`) lies inside the box, edges included.
    ///
    /// Returns `false` if the point or any bound is NaN.
    pub fn contains(&self, x: f64, y: f64) -> bool {
        let in_y = (self.min_y..=self.max_y).contains(&y);
        let in_x = if self.min_x > self.max_x {
            // Antimeridian-crossing box: covered longitudes wrap around ±180.
            // A plain range check would reject every point here.
            x >= self.min_x || x <= self.max_x
        } else {
            (self.min_x..=self.max_x).contains(&x)
        };
        in_x && in_y
    }
}
144
145impl FilterPredicate for GeoBboxFilter {
146    fn should_deliver(&self, _message: &Message, _connection_id: &ConnectionId) -> bool {
147        // In a real implementation, would extract coordinates from message
148        // For now, accept all
149        true
150    }
151}
152
/// Filter configured with an attribute `key` and the `value` it should match.
///
/// NOTE(review): this type's `FilterPredicate` implementation does not inspect
/// message attributes yet and delivers every message. The configured pair is
/// exposed via [`AttributeFilter::key`] and [`AttributeFilter::value`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AttributeFilter {
    key: String,
    value: String,
}

impl AttributeFilter {
    /// Create a filter for attribute `key` with the expected `value`.
    pub fn new(key: String, value: String) -> Self {
        Self { key, value }
    }

    /// Attribute name this filter is configured with.
    pub fn key(&self) -> &str {
        &self.key
    }

    /// Expected attribute value this filter is configured with.
    pub fn value(&self) -> &str {
        &self.value
    }
}
166
167impl FilterPredicate for AttributeFilter {
168    fn should_deliver(&self, _message: &Message, _connection_id: &ConnectionId) -> bool {
169        // In a real implementation, would check message attributes
170        true
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Test-only predicate that accepts exactly one connection id.
    struct OnlyConnection(ConnectionId);

    impl FilterPredicate for OnlyConnection {
        fn should_deliver(&self, _message: &Message, connection_id: &ConnectionId) -> bool {
            connection_id == &self.0
        }
    }

    #[test]
    fn test_filter_all() {
        let filter = MessageFilter::all();
        let msg = Message::ping();
        let conn_id = Uuid::new_v4();

        assert!(filter.should_deliver(&msg, &conn_id));
    }

    #[test]
    fn test_filter_message_type() {
        let filter = MessageFilter::message_types(vec![MessageType::Ping, MessageType::Pong]);
        let ping = Message::ping();
        let pong = Message::pong();
        let data = Message::data(bytes::Bytes::new());
        let conn_id = Uuid::new_v4();

        assert!(filter.should_deliver(&ping, &conn_id));
        assert!(filter.should_deliver(&pong, &conn_id));
        assert!(!filter.should_deliver(&data, &conn_id));
    }

    #[test]
    fn test_filter_from_connections() {
        let conn1 = Uuid::new_v4();
        let conn2 = Uuid::new_v4();
        let conn3 = Uuid::new_v4();

        let filter = MessageFilter::from_connections(vec![conn1, conn2]);
        let msg = Message::ping();

        assert!(filter.should_deliver(&msg, &conn1));
        assert!(filter.should_deliver(&msg, &conn2));
        assert!(!filter.should_deliver(&msg, &conn3));
    }

    #[test]
    fn test_filter_exclude_connections() {
        let conn1 = Uuid::new_v4();
        let conn2 = Uuid::new_v4();

        let filter = MessageFilter::exclude_connections(vec![conn1]);
        let msg = Message::ping();

        assert!(!filter.should_deliver(&msg, &conn1));
        assert!(filter.should_deliver(&msg, &conn2));
    }

    #[test]
    fn test_filter_custom() {
        let conn1 = Uuid::new_v4();
        let conn2 = Uuid::new_v4();

        let filter = MessageFilter::Custom(Box::new(OnlyConnection(conn1)));
        let msg = Message::ping();

        assert!(filter.should_deliver(&msg, &conn1));
        assert!(!filter.should_deliver(&msg, &conn2));
    }

    #[test]
    fn test_filter_chain_and() {
        let conn1 = Uuid::new_v4();
        let conn2 = Uuid::new_v4();

        let chain = FilterChain::new_and()
            .add_filter(MessageFilter::message_types(vec![MessageType::Ping]))
            .add_filter(MessageFilter::from_connections(vec![conn1]));

        let ping = Message::ping();
        let pong = Message::pong();

        assert_eq!(chain.len(), 2);
        assert!(chain.should_deliver(&ping, &conn1));
        assert!(!chain.should_deliver(&pong, &conn1));
        // The type filter passes but the connection filter does not.
        assert!(!chain.should_deliver(&ping, &conn2));
    }

    #[test]
    fn test_filter_chain_or() {
        let conn1 = Uuid::new_v4();
        let conn2 = Uuid::new_v4();

        let chain = FilterChain::new_or()
            .add_filter(MessageFilter::message_types(vec![MessageType::Ping]))
            .add_filter(MessageFilter::from_connections(vec![conn1]));

        let ping = Message::ping();
        let data = Message::data(bytes::Bytes::new());

        assert!(chain.should_deliver(&ping, &conn1));
        assert!(chain.should_deliver(&ping, &conn2));
        assert!(chain.should_deliver(&data, &conn1));
        assert!(!chain.should_deliver(&data, &conn2));
    }

    #[test]
    fn test_empty_filter_chain_delivers_everything() {
        let conn = Uuid::new_v4();
        let msg = Message::ping();

        for chain in [FilterChain::new_and(), FilterChain::new_or()] {
            assert!(chain.is_empty());
            assert_eq!(chain.len(), 0);
            assert!(chain.should_deliver(&msg, &conn));
        }
    }

    #[test]
    fn test_geo_bbox_filter() {
        let filter = GeoBboxFilter::new(-180.0, -90.0, 180.0, 90.0);

        assert!(filter.contains(0.0, 0.0));
        assert!(filter.contains(-122.4, 37.8));
        assert!(!filter.contains(200.0, 100.0));
    }
}
264        assert!(filter.contains(0.0, 0.0));
265        assert!(filter.contains(-122.4, 37.8));
266        assert!(!filter.contains(200.0, 100.0));
267    }
268}