use std::{
collections::HashMap,
env, mem,
net::{IpAddr, Ipv4Addr},
os::unix::io::{AsRawFd, FromRawFd},
process,
time::Duration,
};
extern crate libc;
use libc::{c_int, c_void, socklen_t};
extern crate errno;
use errno::Errno;
extern crate mio;
use mio::{net::UdpSocket, Events, Interest, Poll, Token};
extern crate rsmnl as mnl;
use mnl::{AttrTbl, CbResult, CbStatus, GenError, MsgVec, Msghdr, Socket};
extern crate rsmnl_linux as linux;
use linux::netfilter::{
nfnetlink::Nfgenmsg,
nfnetlink_conntrack as nfct,
nfnetlink_conntrack::{CtattrType, CtattrTypeTbl},
};
mod timerfd;
#[derive(Debug)]
struct Nstats {
pkts: u64,
bytes: u64,
}
fn data_cb(hmap: &mut HashMap<IpAddr, Box<Nstats>>) -> impl FnMut(&Msghdr) -> CbResult + '_ {
move |nlh: &Msghdr| {
let tb = CtattrTypeTbl::from_nlmsg(mem::size_of::<Nfgenmsg>(), nlh)?;
let mut addr = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)); if let Some(tuple_tb) = tb.tuple_orig()? {
if let Some(ip_tb) = tuple_tb.ip()? {
ip_tb.v4_src()?.map(|r| {
addr = IpAddr::V4(*r);
});
ip_tb.v6_src()?.map(|r| {
addr = IpAddr::V6(*r);
});
}
}
let ns = hmap
.entry(addr)
.or_insert(Box::new(Nstats { pkts: 0, bytes: 0 }));
if let Some(counters_tb) = tb.counters_orig()? {
counters_tb.packets()?.map(|c| {
ns.pkts += u64::from_be(*c);
});
counters_tb.bytes()?.map(|c| {
ns.bytes += u64::from_be(*c);
});
}
Ok(CbStatus::Ok)
}
}
fn handle(nl: &mut Socket, hmap: &mut HashMap<IpAddr, Box<Nstats>>) -> CbResult {
let mut buf = mnl::dump_buffer();
loop {
match nl.recvfrom(&mut buf) {
Ok(nrecv) => match mnl::cb_run(&buf[0..nrecv], 0, 0, Some(data_cb(hmap))) {
Ok(CbStatus::Ok) => continue,
ret => return ret,
},
Err(errno) => {
if errno.0 == libc::EAGAIN {
return Ok(CbStatus::Ok);
}
if errno.0 == libc::ENOBUFS {
println!(
"The daemon has hit ENOBUFS, you can \
increase the size of your receiver \
buffer to mitigate this or enable \
reliable delivery."
);
} else {
println!("mnl_socket_recvfrom: {}", errno);
}
return mnl::gen_errno!(errno.0);
}
}
}
}
pub const SO_RECVBUFFORCE: c_int = 33;
fn main() -> Result<(), String> {
let args: Vec<_> = env::args().collect();
if args.len() != 2 {
println!("\nUsage: {} <poll-secs>", args[0]);
process::exit(libc::EXIT_FAILURE);
}
let secs = args[1].parse::<u32>().unwrap();
println!("Polling every {} seconds from kernel...", secs);
unsafe {
libc::nice(-20);
};
let mut nl = Socket::open(libc::NETLINK_NETFILTER, 0)
.map_err(|errno| format!("mnl_socket_open: {}", errno))?;
nl.bind(nfct::NF_NETLINK_CONNTRACK_DESTROY, mnl::SOCKET_AUTOPID)
.map_err(|errno| format!("mnl_socket_bind: {}", errno))?;
unsafe {
let buffersize: c_int = 1 << 22;
libc::setsockopt(
nl.as_raw_fd(),
libc::SOL_SOCKET,
SO_RECVBUFFORCE,
&buffersize as *const _ as *const c_void,
mem::size_of::<socklen_t>() as u32,
);
}
let _ = nl.set_broadcast_error(true);
let _ = nl.set_no_enobufs(true);
let mut nlv = MsgVec::new();
let mut nlh = nlv.put_header();
nlh.nlmsg_type = (libc::NFNL_SUBSYS_CTNETLINK << 8) as u16 | nfct::IPCTNL_MSG_CT_GET_CTRZERO;
nlh.nlmsg_flags = (libc::NLM_F_REQUEST | libc::NLM_F_DUMP) as u16;
let nfh = nlv.put_extra_header::<Nfgenmsg>().unwrap();
nfh.nfgen_family = libc::AF_INET as u8;
nfh.version = libc::NFNETLINK_V0 as u8;
nfh.res_id = 0;
CtattrType::put_mark(&mut nlv, &0u32.to_be()).unwrap();
CtattrType::put_mark_mask(&mut nlv, &0xffffffffu32.to_be()).unwrap();
let mut hmap = HashMap::<IpAddr, Box<Nstats>>::new();
let token = Token(nl.as_raw_fd() as usize);
let mut listener = unsafe { UdpSocket::from_raw_fd(nl.as_raw_fd()) };
let mut timer = timerfd::Timerfd::create(libc::CLOCK_MONOTONIC, 0).unwrap();
timer
.settime(
0,
&timerfd::Itimerspec {
it_interval: Duration::new(secs as u64, 0),
it_value: Duration::new(0, 1),
},
)
.unwrap();
let mut poll = Poll::new().unwrap();
poll.registry()
.register(&mut listener, token, Interest::READABLE)
.unwrap();
poll.registry()
.register(&mut timer, Token(0), Interest::READABLE)
.unwrap();
let mut events = Events::with_capacity(256);
loop {
poll.poll(&mut events, None).unwrap();
for event in events.iter() {
match usize::from(event.token()) {
0 => {
timer.read().unwrap(); nl.sendto(&nlv)
.unwrap_or_else(|errno| panic!("mnl_socket_sendto: {}", errno));
for (addr, nstats) in hmap.iter() {
print!("src={:?} ", addr);
println!("counters {} {}", nstats.pkts, nstats.bytes);
}
}
_ => {
let _ = handle(&mut nl, &mut hmap).unwrap();
}
}
}
}
}