use arc_swap::ArcSwap;
use once_cell::sync::Lazy;
use crate::filters::{
CreateFilterArgs, CreationError, DynFilterFactory, FilterInstance, FilterSet,
};
/// Process-wide table of filter factories, created on first access as an
/// empty [`FilterSet`].
///
/// Readers take a snapshot with `load`/`load_full`; writers publish a whole
/// replacement set with `store` rather than mutating the current one.
static REGISTRY: Lazy<ArcSwap<FilterSet>> = Lazy::new(|| {
    let empty_set = std::sync::Arc::new(FilterSet::default());
    ArcSwap::new(empty_set)
});
/// Handle to the process-wide registry of filter factories, keyed by name.
///
/// This is a zero-sized type; all state lives in the module-level
/// `REGISTRY` static.
#[derive(Debug)]
pub struct FilterRegistry;

impl FilterRegistry {
    /// Adds every factory in `factories` to the global registry.
    ///
    /// The current set is copied, each factory is inserted into the copy, and
    /// the copy is then published with a single `store`. Readers therefore
    /// see either the previous set or the fully updated one, never a
    /// partially updated set. How an already-present key is treated is up to
    /// `FilterSet::insert`.
    ///
    /// NOTE(review): the load → clone → store sequence is not atomic. If two
    /// threads call `register` at the same time, both may clone the same
    /// snapshot, and the later `store` silently drops the factories added by
    /// the earlier call. Either keep registration single-threaded or switch
    /// to `ArcSwap::rcu` or a writer lock.
    pub fn register(factories: impl IntoIterator<Item = DynFilterFactory>) {
        // Build the replacement set off to the side so readers are never
        // exposed to an in-progress update.
        let mut registry = FilterSet::clone(&REGISTRY.load());
        for factory in factories {
            registry.insert(factory);
        }
        REGISTRY.store(std::sync::Arc::new(registry));
    }

    /// Creates a new filter instance from the factory registered under `key`.
    ///
    /// # Errors
    ///
    /// Returns [`CreationError::NotFound`] carrying `key` when no factory is
    /// registered under that name. Otherwise it returns whatever the
    /// factory's `create_filter` produces for `args`.
    pub fn get(key: &str, args: CreateFilterArgs) -> Result<FilterInstance, CreationError> {
        // Hold one snapshot for the whole call so the lookup and the
        // creation see the same registry contents.
        let registry = REGISTRY.load();
        let factory = registry
            .get(key)
            .ok_or_else(|| CreationError::NotFound(key.to_owned()))?;
        factory.create_filter(args)
    }

    /// Returns a shared handle to the factory registered under `key`, or
    /// `None` if there is no such factory.
    pub fn get_factory(key: &str) -> Option<std::sync::Arc<DynFilterFactory>> {
        REGISTRY.load().get(key).cloned()
    }
}
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;
    use crate::test::{alloc_buffer, load_test_filters};
    use super::*;
    use crate::filters::{
        Filter, FilterError, FilterRegistry, PacketMut, ReadContext, WriteContext,
    };
    use crate::net::endpoint::{Endpoint, EndpointAddress};

    /// A filter whose `read` and `write` always fail with a custom error.
    ///
    /// NOTE(review): nothing in this module references this type, hence the
    /// `dead_code` allowance. It cannot be the filter registered under
    /// `"TestFilter"` by `load_test_filters`, because `insert_and_get`
    /// asserts that one succeeds on `read`/`write`. Consider removing it.
    #[allow(dead_code)]
    struct TestFilter {}

    impl Filter for TestFilter {
        fn read<P: PacketMut>(&self, _: &mut ReadContext<'_, P>) -> Result<(), FilterError> {
            Err(FilterError::Custom("test error"))
        }

        fn write<P: PacketMut>(&self, _: &mut WriteContext<P>) -> Result<(), FilterError> {
            Err(FilterError::Custom("test error"))
        }
    }

    /// Checks that an unknown key yields `NotFound`, then creates the
    /// registered `"TestFilter"` and drives one `read` and one `write`
    /// through it with empty buffers, expecting both to succeed.
    #[tokio::test]
    async fn insert_and_get() {
        load_test_filters();

        assert_eq!(
            Some(CreationError::NotFound("not.found".to_string())),
            FilterRegistry::get("not.found", CreateFilterArgs::fixed(None)).err(),
        );

        // A single lookup both proves the key is registered and yields the
        // instance under test.
        let instance = FilterRegistry::get("TestFilter", CreateFilterArgs::fixed(None))
            .expect("load_test_filters should register `TestFilter`");
        let filter = instance.filter();

        let addr: EndpointAddress = (Ipv4Addr::LOCALHOST, 8080).into();
        let endpoint = Endpoint::new(addr.clone());
        let endpoints = crate::net::cluster::ClusterMap::new_default([endpoint].into());
        let mut dest = Vec::new();

        assert!(filter
            .read(&mut ReadContext::new(
                &endpoints,
                addr.clone(),
                alloc_buffer([]),
                &mut dest,
            ))
            .is_ok());
        assert!(filter
            .write(&mut WriteContext::new(addr.clone(), addr, alloc_buffer([])))
            .is_ok());
    }
}