use crate::protocol::message::{Message, MessageType};
use crate::server::connection::ConnectionId;
use std::collections::HashSet;
/// Pluggable delivery rule that can be boxed into [`MessageFilter::Custom`].
///
/// The `Send + Sync` supertraits make `Box<dyn FilterPredicate>` (and therefore
/// `MessageFilter`) safe to move and share across threads.
pub trait FilterPredicate: Send + Sync {
    /// Returns `true` if `message` should be delivered for `connection_id`.
    ///
    /// NOTE(review): whether `connection_id` identifies the sender or the
    /// recipient is not established in this module — confirm against callers.
    fn should_deliver(&self, message: &Message, connection_id: &ConnectionId) -> bool;
}
/// A single delivery rule, evaluated per (message, connection id) pair by
/// `MessageFilter::should_deliver`.
pub enum MessageFilter {
    /// Delivers every message.
    All,
    /// Delivers only messages whose `msg_type` is contained in the set.
    MessageType(HashSet<MessageType>),
    /// Delivers only when the connection id is contained in the set.
    FromConnections(HashSet<ConnectionId>),
    /// Delivers unless the connection id is contained in the set.
    ExcludeConnections(HashSet<ConnectionId>),
    /// Delegates the decision to a boxed user-supplied predicate.
    Custom(Box<dyn FilterPredicate>),
}
impl MessageFilter {
    /// Returns a filter that delivers every message.
    #[must_use]
    pub fn all() -> Self {
        Self::All
    }

    /// Returns a filter that delivers only messages whose type is in `types`.
    ///
    /// Accepts any owned iterable of [`MessageType`] (a `Vec`, an array, another
    /// set, ...). Duplicates are collapsed.
    #[must_use]
    pub fn message_types(types: impl IntoIterator<Item = MessageType>) -> Self {
        Self::MessageType(types.into_iter().collect())
    }

    /// Returns a filter that delivers only for the given connection ids.
    ///
    /// Accepts any owned iterable of [`ConnectionId`]. Duplicates are collapsed.
    #[must_use]
    pub fn from_connections(connections: impl IntoIterator<Item = ConnectionId>) -> Self {
        Self::FromConnections(connections.into_iter().collect())
    }

    /// Returns a filter that delivers for every connection id *except* the
    /// given ones.
    ///
    /// Accepts any owned iterable of [`ConnectionId`]. Duplicates are collapsed.
    #[must_use]
    pub fn exclude_connections(connections: impl IntoIterator<Item = ConnectionId>) -> Self {
        Self::ExcludeConnections(connections.into_iter().collect())
    }

    /// Decides whether `message` should be delivered for `connection_id`.
    ///
    /// Set-based variants do a hash lookup; `Custom` forwards both arguments
    /// unchanged to the boxed predicate.
    pub fn should_deliver(&self, message: &Message, connection_id: &ConnectionId) -> bool {
        match self {
            Self::All => true,
            Self::MessageType(types) => types.contains(&message.msg_type),
            Self::FromConnections(conns) => conns.contains(connection_id),
            Self::ExcludeConnections(conns) => !conns.contains(connection_id),
            Self::Custom(predicate) => predicate.should_deliver(message, connection_id),
        }
    }
}
/// An ordered list of [`MessageFilter`]s combined with either AND or OR
/// semantics.
pub struct FilterChain {
    /// Filters, kept in the order they were added.
    filters: Vec<MessageFilter>,
    /// `true`: every filter must accept (AND). `false`: any one filter
    /// accepting is enough (OR).
    all_must_pass: bool,
}
impl FilterChain {
    /// Creates an empty chain in which **every** filter must accept a message.
    #[must_use]
    pub fn new_and() -> Self {
        Self {
            filters: Vec::new(),
            all_must_pass: true,
        }
    }

    /// Creates an empty chain in which **at least one** filter must accept a
    /// message.
    #[must_use]
    pub fn new_or() -> Self {
        Self {
            filters: Vec::new(),
            all_must_pass: false,
        }
    }

    /// Appends `filter` and returns the extended chain (builder style).
    ///
    /// The chain is taken by value, so ignoring the return value silently
    /// drops every filter added so far. That is why this is `#[must_use]`.
    #[must_use = "add_filter consumes the chain and returns the extended one"]
    pub fn add_filter(mut self, filter: MessageFilter) -> Self {
        self.filters.push(filter);
        self
    }

    /// Evaluates the chain for `message` / `connection_id`.
    ///
    /// An empty chain always delivers, in both AND and OR mode. Otherwise the
    /// filters are evaluated in insertion order. Evaluation stops at the first
    /// rejection (AND) or the first acceptance (OR).
    pub fn should_deliver(&self, message: &Message, connection_id: &ConnectionId) -> bool {
        // Explicit guard: `Iterator::any` on an empty iterator yields `false`,
        // which would make an empty OR chain drop every message.
        if self.filters.is_empty() {
            return true;
        }
        let mut verdicts = self
            .filters
            .iter()
            .map(|filter| filter.should_deliver(message, connection_id));
        if self.all_must_pass {
            verdicts.all(|accepted| accepted)
        } else {
            verdicts.any(|accepted| accepted)
        }
    }

    /// Number of filters in the chain.
    #[must_use]
    pub fn len(&self) -> usize {
        self.filters.len()
    }

    /// Returns `true` if no filters have been added.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.filters.is_empty()
    }
}
/// Axis-aligned bounding box used to test whether a point lies in a region.
///
/// All edges are inclusive. `x` is presumably longitude and `y` latitude (the
/// module's tests use `-180..=180` / `-90..=90`) — confirm with callers.
///
/// If `min_x > max_x`, the box is treated as crossing the antimeridian, per
/// the GeoJSON convention (RFC 7946 §5.2). It then covers `x >= min_x` *or*
/// `x <= max_x`. No wrap is applied on the `y` axis, so an inverted `y` range
/// matches nothing.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GeoBboxFilter {
    min_x: f64,
    min_y: f64,
    max_x: f64,
    max_y: f64,
}

impl GeoBboxFilter {
    /// Creates a bounding box from its corner coordinates.
    ///
    /// Values are stored as given. See the type docs for how `min_x > max_x`
    /// is interpreted.
    #[must_use]
    pub fn new(min_x: f64, min_y: f64, max_x: f64, max_y: f64) -> Self {
        Self {
            min_x,
            min_y,
            max_x,
            max_y,
        }
    }

    /// Returns `true` if the point `(x, y)` lies inside the box, edges included.
    ///
    /// Returns `false` whenever the point or any bound is `NaN`.
    #[must_use]
    pub fn contains(&self, x: f64, y: f64) -> bool {
        let in_y = (self.min_y..=self.max_y).contains(&y);
        // `>` (not `<=`) keeps NaN bounds on the non-wrapping path, where every
        // comparison is false, so a NaN bound never matches anything.
        let in_x = if self.min_x > self.max_x {
            // Antimeridian-crossing box: the x-range wraps around ±180°.
            x >= self.min_x || x <= self.max_x
        } else {
            (self.min_x..=self.max_x).contains(&x)
        };
        in_x && in_y
    }
}
/// Allows a [`GeoBboxFilter`] to be boxed into [`MessageFilter::Custom`].
///
/// NOTE(review): this currently delivers every message for every connection.
/// Neither argument is inspected and the box bounds are never consulted.
/// TODO: extract the message's position and test it with
/// [`GeoBboxFilter::contains`].
impl FilterPredicate for GeoBboxFilter {
    fn should_deliver(&self, _msg: &Message, _conn: &ConnectionId) -> bool {
        // Pass-through until the geographic check is wired in.
        true
    }
}
/// Holds a `key` / `value` pair for attribute-based filtering.
///
/// NOTE(review): the `FilterPredicate` impl for this type does not read `key`
/// or `value` yet (it delivers everything). Presumably it is meant to match a
/// message attribute — confirm with the intended design.
// The fields are not read anywhere in this module yet; the allow silences the
// resulting unused-field warning until the predicate consults them.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AttributeFilter {
    key: String,
    value: String,
}

impl AttributeFilter {
    /// Creates a filter for the given `key` / `value` pair.
    ///
    /// Accepts anything convertible into `String` (`&str`, `String`, ...), so
    /// existing callers passing `String` keep working.
    #[must_use]
    pub fn new(key: impl Into<String>, value: impl Into<String>) -> Self {
        Self {
            key: key.into(),
            value: value.into(),
        }
    }
}
/// Allows an [`AttributeFilter`] to be boxed into [`MessageFilter::Custom`].
///
/// NOTE(review): this currently delivers every message for every connection.
/// The stored `key` / `value` are never compared against the message.
/// TODO: implement the attribute match.
impl FilterPredicate for AttributeFilter {
    fn should_deliver(&self, _msg: &Message, _conn: &ConnectionId) -> bool {
        // Pass-through until attribute matching is implemented.
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Predicate that rejects everything; used to verify that
    /// `MessageFilter::Custom` really delegates to the boxed predicate.
    struct RejectAll;

    impl FilterPredicate for RejectAll {
        fn should_deliver(&self, _message: &Message, _connection_id: &ConnectionId) -> bool {
            false
        }
    }

    #[test]
    fn test_filter_all() {
        let filter = MessageFilter::all();
        let msg = Message::ping();
        let conn_id = Uuid::new_v4();
        assert!(filter.should_deliver(&msg, &conn_id));
    }

    #[test]
    fn test_filter_message_type() {
        let filter = MessageFilter::message_types(vec![MessageType::Ping, MessageType::Pong]);
        let ping = Message::ping();
        let pong = Message::pong();
        let data = Message::data(bytes::Bytes::new());
        let conn_id = Uuid::new_v4();
        assert!(filter.should_deliver(&ping, &conn_id));
        assert!(filter.should_deliver(&pong, &conn_id));
        assert!(!filter.should_deliver(&data, &conn_id));
    }

    #[test]
    fn test_filter_from_connections() {
        let conn1 = Uuid::new_v4();
        let conn2 = Uuid::new_v4();
        let conn3 = Uuid::new_v4();
        let filter = MessageFilter::from_connections(vec![conn1, conn2]);
        let msg = Message::ping();
        assert!(filter.should_deliver(&msg, &conn1));
        assert!(filter.should_deliver(&msg, &conn2));
        assert!(!filter.should_deliver(&msg, &conn3));
    }

    #[test]
    fn test_filter_exclude_connections() {
        let conn1 = Uuid::new_v4();
        let conn2 = Uuid::new_v4();
        let filter = MessageFilter::exclude_connections(vec![conn1]);
        let msg = Message::ping();
        assert!(!filter.should_deliver(&msg, &conn1));
        assert!(filter.should_deliver(&msg, &conn2));
    }

    #[test]
    fn test_filter_custom_delegates_to_predicate() {
        let filter = MessageFilter::Custom(Box::new(RejectAll));
        let conn_id = Uuid::new_v4();
        assert!(!filter.should_deliver(&Message::ping(), &conn_id));
    }

    #[test]
    fn test_filter_chain_and() {
        let conn1 = Uuid::new_v4();
        let conn2 = Uuid::new_v4();
        let chain = FilterChain::new_and()
            .add_filter(MessageFilter::message_types(vec![MessageType::Ping]))
            .add_filter(MessageFilter::from_connections(vec![conn1]));
        let ping = Message::ping();
        let pong = Message::pong();
        assert_eq!(chain.len(), 2);
        assert!(!chain.is_empty());
        assert!(chain.should_deliver(&ping, &conn1));
        assert!(!chain.should_deliver(&pong, &conn1));
        // Type matches but connection does not: AND must reject.
        assert!(!chain.should_deliver(&ping, &conn2));
    }

    #[test]
    fn test_filter_chain_or() {
        let conn1 = Uuid::new_v4();
        let conn2 = Uuid::new_v4();
        let chain = FilterChain::new_or()
            .add_filter(MessageFilter::message_types(vec![MessageType::Ping]))
            .add_filter(MessageFilter::from_connections(vec![conn1]));
        let ping = Message::ping();
        let data = Message::data(bytes::Bytes::new());
        assert!(chain.should_deliver(&ping, &conn1));
        assert!(chain.should_deliver(&ping, &conn2));
        assert!(chain.should_deliver(&data, &conn1));
        assert!(!chain.should_deliver(&data, &conn2));
    }

    #[test]
    fn test_empty_chains_deliver_everything() {
        let msg = Message::ping();
        let conn_id = Uuid::new_v4();
        let and_chain = FilterChain::new_and();
        let or_chain = FilterChain::new_or();
        assert!(and_chain.is_empty());
        assert_eq!(or_chain.len(), 0);
        assert!(and_chain.should_deliver(&msg, &conn_id));
        // Guards against falling through to `any()`, which is `false` when empty.
        assert!(or_chain.should_deliver(&msg, &conn_id));
    }

    #[test]
    fn test_geo_bbox_filter() {
        let filter = GeoBboxFilter::new(-180.0, -90.0, 180.0, 90.0);
        assert!(filter.contains(0.0, 0.0));
        assert!(filter.contains(-122.4, 37.8));
        assert!(!filter.contains(200.0, 100.0));
    }

    #[test]
    fn test_geo_bbox_edges_are_inclusive() {
        let filter = GeoBboxFilter::new(-180.0, -90.0, 180.0, 90.0);
        assert!(filter.contains(180.0, 90.0));
        assert!(filter.contains(-180.0, -90.0));
        assert!(!filter.contains(180.5, 0.0));
        assert!(!filter.contains(0.0, -90.5));
    }
}