use std::net::IpAddr;
use super::types::{AddressFamily, InetExtension, Protocol, TcpState, UnixShow};
/// A fully configured socket query, produced by calling `.build()` on one of
/// the builders returned by the [`SocketFilter`] constructor functions.
#[derive(Clone, Debug)]
pub struct SocketFilter {
    /// The socket family being queried together with its family-specific options.
    pub kind: FilterKind,
}

/// Family-specific payload carried by a [`SocketFilter`].
#[derive(Clone, Debug)]
pub enum FilterKind {
    /// Internet sockets (built for TCP, UDP, MPTCP, SCTP, DCCP or raw).
    Inet(InetFilter),
    /// Unix-domain sockets.
    Unix(UnixFilter),
    /// Netlink sockets.
    Netlink(NetlinkFilter),
    /// Packet sockets.
    Packet(PacketFilter),
}
impl SocketFilter {
pub fn tcp() -> InetFilterBuilder {
InetFilterBuilder::new(Protocol::Tcp)
}
pub fn udp() -> InetFilterBuilder {
InetFilterBuilder::new(Protocol::Udp)
}
pub fn mptcp() -> InetFilterBuilder {
InetFilterBuilder::new(Protocol::Mptcp)
}
pub fn sctp() -> InetFilterBuilder {
InetFilterBuilder::new(Protocol::Sctp)
}
pub fn dccp() -> InetFilterBuilder {
InetFilterBuilder::new(Protocol::Dccp)
}
pub fn raw() -> InetFilterBuilder {
InetFilterBuilder::new(Protocol::Raw)
}
pub fn unix() -> UnixFilterBuilder {
UnixFilterBuilder::new()
}
pub fn netlink() -> NetlinkFilterBuilder {
NetlinkFilterBuilder::new()
}
pub fn packet() -> PacketFilterBuilder {
PacketFilterBuilder::new()
}
}
/// Options for an internet-socket query.
///
/// Usually assembled through [`InetFilterBuilder`] rather than by hand.
#[derive(Debug, Clone)]
pub struct InetFilter {
    /// Address family restriction; `None` when no family was selected.
    pub family: Option<AddressFamily>,
    /// Transport protocol being queried.
    pub protocol: Protocol,
    /// State bitmask, formed by OR-ing `TcpState::mask()` values.
    pub states: u32,
    /// Extension bitmask, formed by OR-ing `InetExtension::mask()` values.
    pub extensions: u8,
    /// Local address selector, if any.
    pub local_addr: Option<IpAddr>,
    /// Local port selector, if any.
    pub local_port: Option<u16>,
    /// Remote address selector, if any.
    pub remote_addr: Option<IpAddr>,
    /// Remote port selector, if any.
    pub remote_port: Option<u16>,
    /// Interface index selector, if any.
    pub interface: Option<u32>,
    /// `(value, mask)` socket-mark selector, if any.
    pub mark: Option<(u32, u32)>,
    /// Cgroup id selector, if any.
    pub cgroup_id: Option<u64>,
}

impl Default for InetFilter {
    /// TCP with no family restriction, every state, no extensions and no
    /// address/port/interface/mark/cgroup selectors.
    fn default() -> Self {
        Self {
            family: None,
            protocol: Protocol::Tcp,
            states: TcpState::all_mask(),
            extensions: 0,
            local_addr: None,
            local_port: None,
            remote_addr: None,
            remote_port: None,
            interface: None,
            mark: None,
            cgroup_id: None,
        }
    }
}
/// Builder for an internet-socket [`SocketFilter`].
///
/// Every method takes the builder by value and returns it, so calls chain.
/// The configuration only becomes a [`SocketFilter`] via
/// [`InetFilterBuilder::build`]; dropping an intermediate result silently
/// discards that step, hence `#[must_use]`.
#[derive(Debug, Clone)]
#[must_use = "a filter builder does nothing until `.build()` is called"]
pub struct InetFilterBuilder {
    filter: InetFilter,
}

impl InetFilterBuilder {
    /// Creates a builder for `protocol`; every other setting starts from
    /// `InetFilter::default()`.
    pub fn new(protocol: Protocol) -> Self {
        Self {
            filter: InetFilter {
                protocol,
                ..Default::default()
            },
        }
    }

    /// Restricts the filter to `family`.
    pub fn family(mut self, family: AddressFamily) -> Self {
        self.filter.family = Some(family);
        self
    }

    /// Shorthand for `family(AddressFamily::Inet)`.
    pub fn ipv4(self) -> Self {
        self.family(AddressFamily::Inet)
    }

    /// Shorthand for `family(AddressFamily::Inet6)`.
    pub fn ipv6(self) -> Self {
        self.family(AddressFamily::Inet6)
    }

    /// Replaces the state mask with the union of the masks of `states`.
    ///
    /// An empty slice leaves a state mask of `0`.
    pub fn states(mut self, states: &[TcpState]) -> Self {
        self.filter.states = states.iter().fold(0, |acc, s| acc | s.mask());
        self
    }

    /// Replaces the state mask with `TcpState::all_mask()`.
    pub fn all_states(mut self) -> Self {
        self.filter.states = TcpState::all_mask();
        self
    }

    /// Replaces the state mask with `TcpState::connected_mask()`.
    pub fn connected(mut self) -> Self {
        self.filter.states = TcpState::connected_mask();
        self
    }

    /// Replaces the state mask with the `Listen` state only.
    ///
    /// NOTE(review): on Linux, unconnected UDP sockets report `TCP_CLOSE`
    /// rather than `TCP_LISTEN`, so with `Protocol::Udp` this may select
    /// nothing. Confirm whether CLOSE should also be included for UDP.
    pub fn listening(mut self) -> Self {
        self.filter.states = TcpState::Listen.mask();
        self
    }

    /// Adds the `InetExtension::MemInfo` bit to the extension mask.
    pub fn with_mem_info(mut self) -> Self {
        self.filter.extensions |= InetExtension::MemInfo.mask();
        self
    }

    /// Adds the `InetExtension::Info` bit to the extension mask.
    pub fn with_tcp_info(mut self) -> Self {
        self.filter.extensions |= InetExtension::Info.mask();
        self
    }

    /// Adds the `InetExtension::Cong` bit to the extension mask.
    pub fn with_congestion(mut self) -> Self {
        self.filter.extensions |= InetExtension::Cong.mask();
        self
    }

    /// Adds the `InetExtension::Tos` bit to the extension mask.
    pub fn with_tos(mut self) -> Self {
        self.filter.extensions |= InetExtension::Tos.mask();
        self
    }

    /// Sets every bit of the 8-bit extension mask.
    pub fn with_all_extensions(mut self) -> Self {
        self.filter.extensions = u8::MAX;
        self
    }

    /// Sets the local address selector.
    pub fn local_addr(mut self, addr: IpAddr) -> Self {
        self.filter.local_addr = Some(addr);
        self
    }

    /// Sets the local port selector.
    pub fn local_port(mut self, port: u16) -> Self {
        self.filter.local_port = Some(port);
        self
    }

    /// Sets the remote address selector.
    pub fn remote_addr(mut self, addr: IpAddr) -> Self {
        self.filter.remote_addr = Some(addr);
        self
    }

    /// Sets the remote port selector.
    pub fn remote_port(mut self, port: u16) -> Self {
        self.filter.remote_port = Some(port);
        self
    }

    /// Sets the interface-index selector.
    pub fn interface(mut self, ifindex: u32) -> Self {
        self.filter.interface = Some(ifindex);
        self
    }

    /// Sets the `(value, mask)` socket-mark selector.
    pub fn mark(mut self, value: u32, mask: u32) -> Self {
        self.filter.mark = Some((value, mask));
        self
    }

    /// Sets the cgroup id selector.
    pub fn cgroup(mut self, cgroup_id: u64) -> Self {
        self.filter.cgroup_id = Some(cgroup_id);
        self
    }

    /// Finishes the builder, wrapping the configuration in [`FilterKind::Inet`].
    pub fn build(self) -> SocketFilter {
        SocketFilter {
            kind: FilterKind::Inet(self.filter),
        }
    }
}
/// Options for a Unix-domain socket query.
///
/// Usually assembled through [`UnixFilterBuilder`] rather than by hand.
#[derive(Clone, Debug)]
pub struct UnixFilter {
    /// Socket-type bitmask with one bit per `SOCK_*` value (`1 << SOCK_*`);
    /// every bit is set by default.
    pub socket_types: u32,
    /// State bitmask, formed by OR-ing `TcpState::mask()` values.
    pub states: u32,
    /// Bitmask of `UnixShow` flags selecting which attributes are reported.
    pub show: u32,
    /// Optional inode selector.
    pub inode: Option<u32>,
    /// Optional path pattern; its matching semantics are defined by whatever
    /// consumes this filter.
    pub path_pattern: Option<String>,
}

impl Default for UnixFilter {
    /// Every socket type and state, reporting `Name`, `Peer` and `RqLen`,
    /// with no inode or path selector.
    fn default() -> Self {
        let states = TcpState::all_mask();
        let show = UnixShow::combine(&[UnixShow::Name, UnixShow::Peer, UnixShow::RqLen]);
        UnixFilter {
            socket_types: u32::MAX,
            states,
            show,
            inode: None,
            path_pattern: None,
        }
    }
}
/// Builder for a Unix-domain [`SocketFilter`].
///
/// Socket-type and state selectors *replace* the current selection, so
/// `.stream().dgram()` selects only datagram sockets. The `show_*` methods
/// *add* to the set of reported attributes. Every method consumes and returns
/// the builder; nothing takes effect until [`UnixFilterBuilder::build`].
#[derive(Debug, Clone)]
#[must_use = "a filter builder does nothing until `.build()` is called"]
pub struct UnixFilterBuilder {
    filter: UnixFilter,
}

impl UnixFilterBuilder {
    /// Creates a builder starting from `UnixFilter::default()`.
    pub fn new() -> Self {
        Self {
            filter: UnixFilter::default(),
        }
    }

    /// Selects only `SOCK_STREAM` sockets, replacing any previous type selection.
    pub fn stream(mut self) -> Self {
        self.filter.socket_types = 1 << libc::SOCK_STREAM;
        self
    }

    /// Selects only `SOCK_DGRAM` sockets, replacing any previous type selection.
    pub fn dgram(mut self) -> Self {
        self.filter.socket_types = 1 << libc::SOCK_DGRAM;
        self
    }

    /// Selects only `SOCK_SEQPACKET` sockets, replacing any previous type selection.
    pub fn seqpacket(mut self) -> Self {
        self.filter.socket_types = 1 << libc::SOCK_SEQPACKET;
        self
    }

    /// Replaces the state mask with the union of the masks of `states`.
    ///
    /// An empty slice leaves a state mask of `0`.
    pub fn states(mut self, states: &[TcpState]) -> Self {
        self.filter.states = states.iter().fold(0, |acc, s| acc | s.mask());
        self
    }

    /// Replaces the state mask with the `Listen` state only.
    pub fn listening(mut self) -> Self {
        self.filter.states = TcpState::Listen.mask();
        self
    }

    /// Replaces the state mask with `TcpState::connected_mask()`.
    pub fn connected(mut self) -> Self {
        self.filter.states = TcpState::connected_mask();
        self
    }

    /// Adds `UnixShow::Name` to the reported attributes.
    pub fn show_name(mut self) -> Self {
        self.filter.show |= UnixShow::Name.mask();
        self
    }

    /// Adds `UnixShow::Vfs` to the reported attributes.
    pub fn show_vfs(mut self) -> Self {
        self.filter.show |= UnixShow::Vfs.mask();
        self
    }

    /// Adds `UnixShow::Peer` to the reported attributes.
    pub fn show_peer(mut self) -> Self {
        self.filter.show |= UnixShow::Peer.mask();
        self
    }

    /// Adds `UnixShow::Icons` to the reported attributes.
    pub fn show_icons(mut self) -> Self {
        self.filter.show |= UnixShow::Icons.mask();
        self
    }

    /// Adds `UnixShow::RqLen` to the reported attributes.
    pub fn show_rqlen(mut self) -> Self {
        self.filter.show |= UnixShow::RqLen.mask();
        self
    }

    /// Adds `UnixShow::MemInfo` to the reported attributes.
    pub fn show_meminfo(mut self) -> Self {
        self.filter.show |= UnixShow::MemInfo.mask();
        self
    }

    /// Adds `UnixShow::Uid` to the reported attributes.
    pub fn show_uid(mut self) -> Self {
        self.filter.show |= UnixShow::Uid.mask();
        self
    }

    /// Replaces the reported attributes with every `UnixShow` flag this
    /// builder exposes.
    pub fn show_all(mut self) -> Self {
        self.filter.show = UnixShow::combine(&[
            UnixShow::Name,
            UnixShow::Vfs,
            UnixShow::Peer,
            UnixShow::Icons,
            UnixShow::RqLen,
            UnixShow::MemInfo,
            UnixShow::Uid,
        ]);
        self
    }

    /// Sets the inode selector.
    pub fn inode(mut self, inode: u32) -> Self {
        self.filter.inode = Some(inode);
        self
    }

    /// Sets the path pattern.
    pub fn path(mut self, pattern: impl Into<String>) -> Self {
        self.filter.path_pattern = Some(pattern.into());
        self
    }

    /// Finishes the builder, wrapping the configuration in [`FilterKind::Unix`].
    pub fn build(self) -> SocketFilter {
        SocketFilter {
            kind: FilterKind::Unix(self.filter),
        }
    }
}

impl Default for UnixFilterBuilder {
    /// Equivalent to [`UnixFilterBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Options for a netlink socket query.
///
/// Usually assembled through [`NetlinkFilterBuilder`] rather than by hand.
#[derive(Clone, Debug)]
pub struct NetlinkFilter {
    /// Netlink protocol number to select; `None` when none was selected.
    pub protocol: Option<u8>,
    /// Whether memory information is requested.
    pub show_meminfo: bool,
    /// Whether group membership is requested.
    pub show_groups: bool,
}

impl Default for NetlinkFilter {
    /// No protocol selected, memory info off, groups on.
    fn default() -> Self {
        NetlinkFilter {
            show_groups: true,
            show_meminfo: false,
            protocol: None,
        }
    }
}
/// Builder for a netlink [`SocketFilter`].
///
/// Every method consumes and returns the builder; nothing takes effect until
/// [`NetlinkFilterBuilder::build`] is called.
#[derive(Debug, Clone)]
#[must_use = "a filter builder does nothing until `.build()` is called"]
pub struct NetlinkFilterBuilder {
    filter: NetlinkFilter,
}

impl NetlinkFilterBuilder {
    /// Creates a builder starting from `NetlinkFilter::default()`: no
    /// protocol selected, memory info off, groups on.
    pub fn new() -> Self {
        Self {
            filter: NetlinkFilter::default(),
        }
    }

    /// Selects netlink protocol number `protocol`.
    pub fn protocol(mut self, protocol: u8) -> Self {
        self.filter.protocol = Some(protocol);
        self
    }

    /// Requests memory information.
    pub fn show_meminfo(mut self) -> Self {
        self.filter.show_meminfo = true;
        self
    }

    /// Requests group membership.
    ///
    /// Groups are already requested by default, so this only matters after
    /// an earlier [`hide_groups`](Self::hide_groups).
    pub fn show_groups(mut self) -> Self {
        self.filter.show_groups = true;
        self
    }

    /// Stops requesting group membership, which is otherwise on by default.
    ///
    /// Without this method there was no way to opt out, because the default
    /// is already `true` and `show_groups` can only set it to `true`.
    pub fn hide_groups(mut self) -> Self {
        self.filter.show_groups = false;
        self
    }

    /// Finishes the builder, wrapping the configuration in [`FilterKind::Netlink`].
    pub fn build(self) -> SocketFilter {
        SocketFilter {
            kind: FilterKind::Netlink(self.filter),
        }
    }
}

impl Default for NetlinkFilterBuilder {
    /// Equivalent to [`NetlinkFilterBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Options for a packet-socket query.
///
/// Usually assembled through [`PacketFilterBuilder`] rather than by hand.
#[derive(Clone, Debug)]
pub struct PacketFilter {
    /// Whether memory information is requested.
    pub show_meminfo: bool,
    /// Whether basic socket information is requested.
    pub show_info: bool,
    /// Whether fanout information is requested.
    pub show_fanout: bool,
}

impl Default for PacketFilter {
    /// Memory info off; info and fanout on.
    fn default() -> Self {
        let (show_info, show_fanout) = (true, true);
        PacketFilter {
            show_meminfo: false,
            show_info,
            show_fanout,
        }
    }
}
/// Builder for a packet-socket [`SocketFilter`].
///
/// Every method consumes and returns the builder; nothing takes effect until
/// [`PacketFilterBuilder::build`] is called.
#[derive(Debug, Clone)]
#[must_use = "a filter builder does nothing until `.build()` is called"]
pub struct PacketFilterBuilder {
    filter: PacketFilter,
}

impl PacketFilterBuilder {
    /// Creates a builder starting from `PacketFilter::default()`: memory
    /// info off, info and fanout on.
    pub fn new() -> Self {
        Self {
            filter: PacketFilter::default(),
        }
    }

    /// Requests memory information.
    pub fn show_meminfo(mut self) -> Self {
        self.filter.show_meminfo = true;
        self
    }

    /// Requests basic socket information.
    ///
    /// This is on by default, so it only matters after an earlier
    /// [`hide_info`](Self::hide_info).
    pub fn show_info(mut self) -> Self {
        self.filter.show_info = true;
        self
    }

    /// Stops requesting basic socket information, which is on by default.
    pub fn hide_info(mut self) -> Self {
        self.filter.show_info = false;
        self
    }

    /// Requests fanout information.
    ///
    /// This is on by default, so it only matters after an earlier
    /// [`hide_fanout`](Self::hide_fanout).
    pub fn show_fanout(mut self) -> Self {
        self.filter.show_fanout = true;
        self
    }

    /// Stops requesting fanout information, which is on by default.
    pub fn hide_fanout(mut self) -> Self {
        self.filter.show_fanout = false;
        self
    }

    /// Finishes the builder, wrapping the configuration in [`FilterKind::Packet`].
    pub fn build(self) -> SocketFilter {
        SocketFilter {
            kind: FilterKind::Packet(self.filter),
        }
    }
}

impl Default for PacketFilterBuilder {
    /// Equivalent to [`PacketFilterBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}