oxigdal_websocket/broadcast/
filter.rs1use crate::protocol::message::{Message, MessageType};
4use crate::server::connection::ConnectionId;
5use std::collections::HashSet;
6
7pub trait FilterPredicate: Send + Sync {
9 fn should_deliver(&self, message: &Message, connection_id: &ConnectionId) -> bool;
11}
12
13pub enum MessageFilter {
15 All,
17 MessageType(HashSet<MessageType>),
19 FromConnections(HashSet<ConnectionId>),
21 ExcludeConnections(HashSet<ConnectionId>),
23 Custom(Box<dyn FilterPredicate>),
25}
26
27impl MessageFilter {
28 pub fn all() -> Self {
30 Self::All
31 }
32
33 pub fn message_types(types: Vec<MessageType>) -> Self {
35 Self::MessageType(types.into_iter().collect())
36 }
37
38 pub fn from_connections(connections: Vec<ConnectionId>) -> Self {
40 Self::FromConnections(connections.into_iter().collect())
41 }
42
43 pub fn exclude_connections(connections: Vec<ConnectionId>) -> Self {
45 Self::ExcludeConnections(connections.into_iter().collect())
46 }
47
48 pub fn should_deliver(&self, message: &Message, connection_id: &ConnectionId) -> bool {
50 match self {
51 Self::All => true,
52 Self::MessageType(types) => types.contains(&message.msg_type),
53 Self::FromConnections(conns) => conns.contains(connection_id),
54 Self::ExcludeConnections(conns) => !conns.contains(connection_id),
55 Self::Custom(predicate) => predicate.should_deliver(message, connection_id),
56 }
57 }
58}
59
60pub struct FilterChain {
62 filters: Vec<MessageFilter>,
63 all_must_pass: bool,
65}
66
67impl FilterChain {
68 pub fn new_and() -> Self {
70 Self {
71 filters: Vec::new(),
72 all_must_pass: true,
73 }
74 }
75
76 pub fn new_or() -> Self {
78 Self {
79 filters: Vec::new(),
80 all_must_pass: false,
81 }
82 }
83
84 pub fn add_filter(mut self, filter: MessageFilter) -> Self {
86 self.filters.push(filter);
87 self
88 }
89
90 pub fn should_deliver(&self, message: &Message, connection_id: &ConnectionId) -> bool {
92 if self.filters.is_empty() {
93 return true;
94 }
95
96 if self.all_must_pass {
97 self.filters
99 .iter()
100 .all(|f| f.should_deliver(message, connection_id))
101 } else {
102 self.filters
104 .iter()
105 .any(|f| f.should_deliver(message, connection_id))
106 }
107 }
108
109 pub fn len(&self) -> usize {
111 self.filters.len()
112 }
113
114 pub fn is_empty(&self) -> bool {
116 self.filters.is_empty()
117 }
118}
119
/// An axis-aligned bounding box given by its minimum and maximum x/y edges.
///
/// A small plain-data type, so it derives `Debug`, `Clone`, `Copy` and `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GeoBboxFilter {
    min_x: f64,
    min_y: f64,
    max_x: f64,
    max_y: f64,
}
127
128impl GeoBboxFilter {
129 pub fn new(min_x: f64, min_y: f64, max_x: f64, max_y: f64) -> Self {
131 Self {
132 min_x,
133 min_y,
134 max_x,
135 max_y,
136 }
137 }
138
139 pub fn contains(&self, x: f64, y: f64) -> bool {
141 x >= self.min_x && x <= self.max_x && y >= self.min_y && y <= self.max_y
142 }
143}
144
145impl FilterPredicate for GeoBboxFilter {
146 fn should_deliver(&self, _message: &Message, _connection_id: &ConnectionId) -> bool {
147 true
150 }
151}
152
/// A filter that carries a `key` / `value` attribute pair.
///
/// NOTE(review): presumably intended to match messages by attribute; its
/// `FilterPredicate` impl currently ignores the pair and accepts everything.
// `dead_code` is allowed because nothing in this module reads `key` or `value` yet.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AttributeFilter {
    key: String,
    value: String,
}
159
160impl AttributeFilter {
161 pub fn new(key: String, value: String) -> Self {
163 Self { key, value }
164 }
165}
166
167impl FilterPredicate for AttributeFilter {
168 fn should_deliver(&self, _message: &Message, _connection_id: &ConnectionId) -> bool {
169 true
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Test predicate that accepts a message only for one specific connection.
    struct OnlyConnection(ConnectionId);

    impl FilterPredicate for OnlyConnection {
        fn should_deliver(&self, _message: &Message, connection_id: &ConnectionId) -> bool {
            *connection_id == self.0
        }
    }

    #[test]
    fn test_filter_all() {
        let filter = MessageFilter::all();
        let msg = Message::ping();
        let conn_id = Uuid::new_v4();

        assert!(filter.should_deliver(&msg, &conn_id));
    }

    #[test]
    fn test_filter_message_type() {
        let filter = MessageFilter::message_types(vec![MessageType::Ping, MessageType::Pong]);
        let ping = Message::ping();
        let pong = Message::pong();
        let data = Message::data(bytes::Bytes::new());
        let conn_id = Uuid::new_v4();

        assert!(filter.should_deliver(&ping, &conn_id));
        assert!(filter.should_deliver(&pong, &conn_id));
        assert!(!filter.should_deliver(&data, &conn_id));
    }

    #[test]
    fn test_filter_from_connections() {
        let conn1 = Uuid::new_v4();
        let conn2 = Uuid::new_v4();
        let conn3 = Uuid::new_v4();

        let filter = MessageFilter::from_connections(vec![conn1, conn2]);
        let msg = Message::ping();

        assert!(filter.should_deliver(&msg, &conn1));
        assert!(filter.should_deliver(&msg, &conn2));
        assert!(!filter.should_deliver(&msg, &conn3));
    }

    #[test]
    fn test_filter_exclude_connections() {
        let conn1 = Uuid::new_v4();
        let conn2 = Uuid::new_v4();

        let filter = MessageFilter::exclude_connections(vec![conn1]);
        let msg = Message::ping();

        assert!(!filter.should_deliver(&msg, &conn1));
        assert!(filter.should_deliver(&msg, &conn2));
    }

    #[test]
    fn test_filter_custom() {
        let conn1 = Uuid::new_v4();
        let conn2 = Uuid::new_v4();

        let filter = MessageFilter::Custom(Box::new(OnlyConnection(conn1)));
        let msg = Message::ping();

        assert!(filter.should_deliver(&msg, &conn1));
        assert!(!filter.should_deliver(&msg, &conn2));
    }

    #[test]
    fn test_filter_chain_and() {
        let conn1 = Uuid::new_v4();

        let chain = FilterChain::new_and()
            .add_filter(MessageFilter::message_types(vec![MessageType::Ping]))
            .add_filter(MessageFilter::from_connections(vec![conn1]));

        let ping = Message::ping();
        let pong = Message::pong();

        assert!(chain.should_deliver(&ping, &conn1));
        assert!(!chain.should_deliver(&pong, &conn1));
    }

    #[test]
    fn test_filter_chain_or() {
        let conn1 = Uuid::new_v4();
        let conn2 = Uuid::new_v4();

        let chain = FilterChain::new_or()
            .add_filter(MessageFilter::message_types(vec![MessageType::Ping]))
            .add_filter(MessageFilter::from_connections(vec![conn1]));

        let ping = Message::ping();
        let data = Message::data(bytes::Bytes::new());

        assert!(chain.should_deliver(&ping, &conn1));
        assert!(chain.should_deliver(&ping, &conn2));
        assert!(chain.should_deliver(&data, &conn1));
        assert!(!chain.should_deliver(&data, &conn2));
    }

    #[test]
    fn test_empty_chain_delivers_everything() {
        let conn = Uuid::new_v4();
        let msg = Message::ping();

        let and_chain = FilterChain::new_and();
        let or_chain = FilterChain::new_or();

        assert!(and_chain.is_empty());
        assert_eq!(or_chain.len(), 0);
        assert!(and_chain.should_deliver(&msg, &conn));
        assert!(or_chain.should_deliver(&msg, &conn));
    }

    #[test]
    fn test_filter_chain_len() {
        let chain = FilterChain::new_and()
            .add_filter(MessageFilter::all())
            .add_filter(MessageFilter::all());

        assert_eq!(chain.len(), 2);
        assert!(!chain.is_empty());
    }

    #[test]
    fn test_geo_bbox_filter() {
        let filter = GeoBboxFilter::new(-180.0, -90.0, 180.0, 90.0);

        assert!(filter.contains(0.0, 0.0));
        assert!(filter.contains(-122.4, 37.8));
        assert!(!filter.contains(200.0, 100.0));
    }

    #[test]
    fn test_geo_bbox_filter_edges_inclusive() {
        let filter = GeoBboxFilter::new(0.0, 0.0, 10.0, 10.0);

        assert!(filter.contains(0.0, 0.0));
        assert!(filter.contains(10.0, 10.0));
        assert!(!filter.contains(10.5, 5.0));
        assert!(!filter.contains(5.0, -0.5));
        assert!(!filter.contains(f64::NAN, 5.0));
    }
}