use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
/// A transformation applied to a list of candidate socket addresses.
///
/// Implementations take ownership of the list and return the resulting list, which
/// may be reordered, shortened, or otherwise rewritten. The `Send + Sync + 'static`
/// bounds let filters be stored as `Arc<dyn AddrFilter>` (as `CompositeFilter`
/// does) and shared across threads.
pub trait AddrFilter: Send + Sync + 'static {
/// Applies this filter to `addrs` and returns the filtered list.
fn filter(&self, addrs: Vec<SocketAddr>) -> Vec<SocketAddr>;
}
/// A no-op filter: the address list is handed back exactly as received, same
/// elements in the same order.
#[derive(Debug, Default, Clone, Copy)]
pub struct PassThroughFilter;

impl AddrFilter for PassThroughFilter {
    fn filter(&self, addrs: Vec<SocketAddr>) -> Vec<SocketAddr> {
        std::convert::identity(addrs)
    }
}
#[derive(Debug, Default, Clone, Copy)]
pub struct DropLoopbackFilter;
impl AddrFilter for DropLoopbackFilter {
fn filter(&self, addrs: Vec<SocketAddr>) -> Vec<SocketAddr> {
addrs
.into_iter()
.filter(|sa| !sa.ip().is_loopback())
.collect()
}
}
#[derive(Debug, Default, Clone, Copy)]
pub struct DropUnspecifiedFilter;
impl AddrFilter for DropUnspecifiedFilter {
fn filter(&self, addrs: Vec<SocketAddr>) -> Vec<SocketAddr> {
addrs
.into_iter()
.filter(|sa| !sa.ip().is_unspecified())
.collect()
}
}
/// Drops repeated addresses, keeping only the first occurrence of each.
///
/// Two entries are duplicates when they compare equal as `SocketAddr` values. The
/// relative order of the retained (first-seen) entries is unchanged.
#[derive(Debug, Default, Clone, Copy)]
pub struct DedupFilter;

impl AddrFilter for DedupFilter {
    fn filter(&self, mut addrs: Vec<SocketAddr>) -> Vec<SocketAddr> {
        let mut already_kept = std::collections::HashSet::with_capacity(addrs.len());
        // `insert` returns `false` for an address seen before; those are discarded.
        addrs.retain(|candidate| already_kept.insert(*candidate));
        addrs
    }
}
#[derive(Debug, Default, Clone, Copy)]
pub struct PreferIpv6Filter;
impl AddrFilter for PreferIpv6Filter {
fn filter(&self, mut addrs: Vec<SocketAddr>) -> Vec<SocketAddr> {
addrs.sort_by_key(|sa| match sa.ip() {
IpAddr::V6(_) => 0u8,
IpAddr::V4(_) => 1u8,
});
addrs
}
}
/// A filter that runs a sequence of other filters, feeding each stage's output
/// into the next.
///
/// Stages execute in the order they were added. An empty composite returns its
/// input unchanged. The stages are held as `Arc<dyn AddrFilter>`, so cloning a
/// `CompositeFilter` shares the same filter instances rather than copying them.
#[derive(Default, Clone)]
pub struct CompositeFilter {
    filters: Vec<Arc<dyn AddrFilter>>,
}

impl std::fmt::Debug for CompositeFilter {
    // The stages are trait objects without a `Debug` bound, so only their count is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CompositeFilter")
            .field("filter_count", &self.filters.len())
            .finish()
    }
}

impl CompositeFilter {
    /// Creates a composite with no stages (equivalent to `CompositeFilter::default()`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `f` as the last stage and returns the extended composite.
    #[must_use = "`push` consumes the composite and returns the extended one; dropping the result discards the chain"]
    pub fn push<F: AddrFilter>(mut self, f: F) -> Self {
        self.filters.push(Arc::new(f));
        self
    }

    /// Appends an already shared filter as the last stage and returns the extended composite.
    #[must_use = "`push_arc` consumes the composite and returns the extended one; dropping the result discards the chain"]
    pub fn push_arc(mut self, f: Arc<dyn AddrFilter>) -> Self {
        self.filters.push(f);
        self
    }

    /// Number of stages in the chain.
    pub fn len(&self) -> usize {
        self.filters.len()
    }

    /// Returns `true` when the chain has no stages, i.e. acts as a pass-through.
    pub fn is_empty(&self) -> bool {
        self.filters.is_empty()
    }
}

impl AddrFilter for CompositeFilter {
    fn filter(&self, addrs: Vec<SocketAddr>) -> Vec<SocketAddr> {
        // Thread ownership of the list through each stage in insertion order.
        self.filters
            .iter()
            .fold(addrs, |current, stage| stage.filter(current))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    // Fixtures: documentation-range IPv4 `192.0.2.1` and IPv6 `2001:db8::1`, plus
    // IPv4 loopback and the IPv4 unspecified address, each on a caller-chosen port.
    fn v4(port: u16) -> SocketAddr {
        SocketAddr::from((Ipv4Addr::new(192, 0, 2, 1), port))
    }

    fn v6(port: u16) -> SocketAddr {
        SocketAddr::from((Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1), port))
    }

    fn lo(port: u16) -> SocketAddr {
        SocketAddr::from((Ipv4Addr::LOCALHOST, port))
    }

    fn unspec(port: u16) -> SocketAddr {
        SocketAddr::from((Ipv4Addr::UNSPECIFIED, port))
    }

    #[test]
    fn passthrough_returns_input_untouched() {
        let input = vec![v4(1), v6(2), lo(3)];
        let expected = input.clone();
        assert_eq!(PassThroughFilter.filter(input), expected);
    }

    #[test]
    fn drop_loopback_removes_loopback_only() {
        let kept = DropLoopbackFilter.filter(vec![v4(1), lo(2), v6(3)]);
        assert_eq!(kept, vec![v4(1), v6(3)]);
    }

    #[test]
    fn drop_unspecified_removes_zero_addrs() {
        let kept = DropUnspecifiedFilter.filter(vec![v4(1), unspec(2), v6(3)]);
        assert_eq!(kept, vec![v4(1), v6(3)]);
    }

    #[test]
    fn dedup_preserves_first_seen_order() {
        let input = vec![v4(1), v4(2), v4(1), v6(3), v4(2)];
        assert_eq!(DedupFilter.filter(input), vec![v4(1), v4(2), v6(3)]);
    }

    #[test]
    fn prefer_ipv6_brings_v6_to_front_stable() {
        let input = vec![v4(1), v6(2), v4(3), v6(4)];
        assert_eq!(
            PreferIpv6Filter.filter(input),
            vec![v6(2), v6(4), v4(1), v4(3)]
        );
    }

    #[test]
    fn composite_chains_in_order() {
        let chain = CompositeFilter::new()
            .push(DropLoopbackFilter)
            .push(DropUnspecifiedFilter)
            .push(DedupFilter)
            .push(PreferIpv6Filter);
        assert_eq!(chain.len(), 4);
        let input = vec![v4(1), lo(2), unspec(3), v4(1), v6(4)];
        assert_eq!(chain.filter(input), vec![v6(4), v4(1)]);
    }

    #[test]
    fn empty_composite_is_passthrough() {
        let chain = CompositeFilter::new();
        assert!(chain.is_empty());
        let input = vec![v4(1), v6(2)];
        let expected = input.clone();
        assert_eq!(chain.filter(input), expected);
    }
}