#![cfg(target_os = "linux")]
// Build-time guard: this module refuses to compile for big-endian targets.
// NOTE(review): netlink payloads are normally produced in the host's byte
// order; confirm whether the native-endian reads named in the message really
// misparse on big-endian before acting on its `from_le_bytes` advice.
#[cfg(target_endian = "big")]
compile_error!(
"tcp_diag is not supported on big-endian Linux. \
`tcpi_bytes_acked`/`tcpi_bytes_received` extraction uses native-endian \
reads that would misparse on big-endian hosts. \
If you need this target, port `extract_tcp_counters` to explicit \
`from_le_bytes` and remove this guard."
);
use std::collections::HashMap;
use std::panic::{self, AssertUnwindSafe};
use netlink_packet_core::{
NLM_F_DUMP, NLM_F_REQUEST, NetlinkHeader, NetlinkMessage, NetlinkPayload,
};
use netlink_packet_sock_diag::{
SockDiagMessage,
constants::{AF_INET, AF_INET6, IPPROTO_TCP},
inet::{ExtensionFlags, InetRequest, SocketId, StateFlags, nlas::Nla},
};
use netlink_sys::{Socket, SocketAddr, protocols::NETLINK_SOCK_DIAG};
/// Upper bound on the number of entries (one per local port) collected into
/// the result map during a single sample.
const MAX_SOCKETS: usize = 4096;
/// Size in bytes of the buffer handed to each netlink receive call.
const RECV_BUF_LEN: usize = 32 * 1024;
/// Samples TCP byte counters for established sockets, keyed by local port.
///
/// Each address family (`AF_INET`, then `AF_INET6`) is dumped in turn through
/// `dump_family`. A failing family is logged and skipped; it never aborts the
/// sample.
///
/// # Returns
///
/// A map from local port to `(bytes_received, bytes_acked)`. When several
/// sockets share a local port, the entry reflects whichever response was
/// merged last.
///
/// Collection stops once the map holds `MAX_SOCKETS` entries. Any data left
/// behind by that cap, including an address family that was never queried,
/// is reported through a single truncation warning.
pub fn sample_per_local_port() -> HashMap<u16, (u64, u64)> {
    let families = [AF_INET, AF_INET6];
    let mut out = HashMap::new();
    let mut truncated = false;

    for (idx, family) in families.iter().copied().enumerate() {
        if let Err(e) = dump_family(family, &mut out, &mut truncated) {
            log::warn!("[external] netlink INET_DIAG family={family} dump failed: {e}",);
        }
        if out.len() >= MAX_SOCKETS {
            // Stopping here with families still pending drops their sockets
            // just as surely as overflowing the cap mid-dump does, so make
            // that visible instead of leaving the loop silently.
            if idx + 1 < families.len() {
                truncated = true;
            }
            break;
        }
    }

    if truncated {
        log::warn!(
            "[external] netlink dump truncated at {MAX_SOCKETS} sockets; counters for additional sockets are skipped this tick",
        );
    }
    log::debug!("[external] netlink returned {} socket(s)", out.len());
    out
}
fn dump_family(
family: u8,
out: &mut HashMap<u16, (u64, u64)>,
truncated: &mut bool,
) -> Result<(), DumpError> {
let socket_id = match family {
AF_INET => SocketId::new_v4(),
AF_INET6 => SocketId::new_v6(),
_ => return Err(DumpError::UnsupportedFamily(family)),
};
let mut socket = Socket::new(NETLINK_SOCK_DIAG).map_err(DumpError::OpenSocket)?;
socket.bind_auto().map_err(DumpError::Bind)?;
let request = InetRequest {
family,
protocol: IPPROTO_TCP,
extensions: ExtensionFlags::INFO,
states: StateFlags::ESTABLISHED,
socket_id,
};
let mut header = NetlinkHeader::default();
header.flags = NLM_F_REQUEST | NLM_F_DUMP;
let mut packet = NetlinkMessage::new(
header,
NetlinkPayload::from(SockDiagMessage::InetRequest(request)),
);
packet.finalize();
let mut send_buf = vec![0u8; packet.buffer_len()];
packet.serialize(&mut send_buf);
let kernel_addr = SocketAddr::new(0, 0);
socket
.send_to(&send_buf, &kernel_addr, 0)
.map_err(DumpError::Send)?;
catch_panic(|| drain_responses(&socket, out, truncated))
}
/// Runs `f`, turning a panic raised inside it into `DumpError::ParsePanic`.
///
/// The closure's own `Ok`/`Err` result is passed through untouched. The
/// panic payload is discarded.
fn catch_panic<F>(f: F) -> Result<(), DumpError>
where
    F: FnOnce() -> Result<(), DumpError>,
{
    panic::catch_unwind(AssertUnwindSafe(f))
        .unwrap_or_else(|_payload| Err(DumpError::ParsePanic))
}
/// Reads reply datagrams from `socket` and feeds each one to `process_batch`.
///
/// The loop ends when `process_batch` reports `BatchOutcome::Done`, or once
/// `out` has grown to `MAX_SOCKETS` entries.
///
/// When the cap ends the loop, the dump has not been read to completion, so
/// `truncated` is raised to keep the early exit visible to the caller. This
/// may over-report in the rare case where only the terminating message was
/// still pending.
///
/// # Errors
///
/// * `DumpError::Recv` when receiving from the socket fails.
/// * Any error returned by `process_batch`.
fn drain_responses(
    socket: &Socket,
    out: &mut HashMap<u16, (u64, u64)>,
    truncated: &mut bool,
) -> Result<(), DumpError> {
    let mut recv_buf = vec![0u8; RECV_BUF_LEN];
    loop {
        // A fresh view of the whole buffer is handed to every receive call.
        let mut slice = &mut recv_buf[..];
        let (size, _addr) = socket.recv_from(&mut slice, 0).map_err(DumpError::Recv)?;
        let bytes = &recv_buf[..size];
        match process_batch(bytes, out, truncated)? {
            BatchOutcome::Continue => {}
            BatchOutcome::Done => return Ok(()),
        }
        if out.len() >= MAX_SOCKETS {
            // The batch ended without a terminator, so more of the dump is
            // being abandoned here; record that rather than exiting silently.
            *truncated = true;
            return Ok(());
        }
    }
}
/// Walks every netlink message packed into one received datagram.
///
/// Messages are located through their length prefix. Each is deserialized as
/// a `SockDiagMessage`, and `InetResponse` payloads are merged into `out`
/// through `merge_response`. Once `out` already holds `MAX_SOCKETS` entries,
/// further responses are skipped and `truncated` is set.
///
/// # Returns
///
/// * `BatchOutcome::Done` when the dump is over: a `Done` or `Overrun`
///   payload was seen, or the framing of the batch is unusable (truncated
///   header, or a declared length that is smaller than a netlink header or
///   larger than the remaining bytes).
/// * `BatchOutcome::Continue` when the whole batch was consumed without a
///   terminator (including an empty batch).
///
/// # Errors
///
/// `DumpError::KernelError` when the batch contains an error payload.
fn process_batch(
    bytes: &[u8],
    out: &mut HashMap<u16, (u64, u64)>,
    truncated: &mut bool,
) -> Result<BatchOutcome, DumpError> {
    // Size of the fixed netlink message header. A declared length below this
    // cannot describe a real message, and stepping over it would resume
    // parsing in the middle of a header.
    const NLMSG_HDRLEN: usize = 16;

    let mut offset = 0usize;
    while offset < bytes.len() {
        let remaining = &bytes[offset..];
        let raw_len = match peek_msg_len(remaining) {
            Some(l) => l,
            None => {
                log::debug!(
                    "[external] netlink: header truncated ({} bytes), aborting batch",
                    remaining.len()
                );
                return Ok(BatchOutcome::Done);
            }
        };
        if raw_len < NLMSG_HDRLEN || raw_len > remaining.len() {
            log::debug!("[external] netlink: invalid header length ({raw_len}), aborting batch");
            return Ok(BatchOutcome::Done);
        }
        // Hand the parser exactly the bytes this message claims, nothing more.
        match NetlinkMessage::<SockDiagMessage>::deserialize(&remaining[..raw_len]) {
            Ok(msg) => match msg.payload {
                NetlinkPayload::Done(_) => return Ok(BatchOutcome::Done),
                NetlinkPayload::Error(err) => {
                    return Err(DumpError::KernelError(format!("{err:?}")));
                }
                NetlinkPayload::Overrun(_) => {
                    log::warn!(
                        "[external] netlink OVERRUN: kernel dropped messages, counters incomplete this tick",
                    );
                    return Ok(BatchOutcome::Done);
                }
                NetlinkPayload::InnerMessage(SockDiagMessage::InetResponse(resp)) => {
                    if out.len() >= MAX_SOCKETS {
                        *truncated = true;
                    } else {
                        merge_response(&resp, out);
                    }
                }
                // Other payload kinds carry nothing this sampler uses.
                _ => {}
            },
            Err(e) => {
                // A well-framed but unparsable message is skipped so that the
                // messages after it are still processed.
                log::debug!(
                    "[external] netlink: skipping malformed message at offset {offset}: {e}"
                );
            }
        }
        // Messages are laid out on 4-byte boundaries.
        offset += align(raw_len);
    }
    Ok(BatchOutcome::Continue)
}
/// Reads the native-endian `u32` at the start of `bytes` as a length.
///
/// Returns `None` when fewer than four bytes are available.
fn peek_msg_len(bytes: &[u8]) -> Option<usize> {
    if let [b0, b1, b2, b3, ..] = *bytes {
        Some(u32::from_ne_bytes([b0, b1, b2, b3]) as usize)
    } else {
        None
    }
}
/// Byte offset of the 8-byte `bytes_acked` counter inside a `TcpInfo` payload.
const TCP_INFO_BYTES_ACKED_OFFSET: usize = 120;
/// Byte offset of the 8-byte `bytes_received` counter inside a `TcpInfo` payload.
const TCP_INFO_BYTES_RECEIVED_OFFSET: usize = 128;
/// Shortest payload that still contains both counters in full.
const TCP_INFO_MIN_LEN: usize = TCP_INFO_BYTES_RECEIVED_OFFSET + 8;
/// Records the counters carried by `resp` under its local (source) port.
///
/// Responses whose source port is zero are ignored. Only the first `TcpInfo`
/// attribute is considered: if its payload is too short to decode, nothing
/// is stored. An existing entry for the same port is replaced.
fn merge_response(
    resp: &netlink_packet_sock_diag::inet::InetResponse,
    out: &mut HashMap<u16, (u64, u64)>,
) {
    let local_port = resp.header.socket_id.source_port;
    if local_port == 0 {
        return;
    }

    // Locate the first TcpInfo attribute; later ones are never inspected.
    let first_tcp_info = resp.nlas.iter().find_map(|nla| match nla {
        Nla::TcpInfo(bytes) => Some(bytes),
        _ => None,
    });

    if let Some(counters) = first_tcp_info.and_then(|bytes| extract_tcp_counters(bytes)) {
        out.insert(local_port, counters);
    }
}
fn extract_tcp_counters(bytes: &[u8]) -> Option<(u64, u64)> {
if bytes.len() < TCP_INFO_MIN_LEN {
return None;
}
let acked = u64::from_ne_bytes(
bytes[TCP_INFO_BYTES_ACKED_OFFSET..TCP_INFO_BYTES_ACKED_OFFSET + 8]
.try_into()
.ok()?,
);
let received = u64::from_ne_bytes(
bytes[TCP_INFO_BYTES_RECEIVED_OFFSET..TCP_INFO_BYTES_RECEIVED_OFFSET + 8]
.try_into()
.ok()?,
);
Some((received, acked))
}
/// Rounds `len` up to the next multiple of four; multiples of four are returned unchanged.
fn align(len: usize) -> usize {
    const ALIGNMENT: usize = 4;
    match len % ALIGNMENT {
        0 => len,
        rem => len + (ALIGNMENT - rem),
    }
}
/// Reasons a single address-family dump can fail.
#[derive(Debug)]
enum DumpError {
    /// Creating the netlink socket failed.
    OpenSocket(std::io::Error),
    /// Binding the netlink socket failed.
    Bind(std::io::Error),
    /// Sending the dump request failed.
    Send(std::io::Error),
    /// Receiving a reply failed.
    Recv(std::io::Error),
    /// A reply carried an error payload; holds its debug rendering.
    KernelError(String),
    /// The requested address family is not one this module can dump.
    UnsupportedFamily(u8),
    /// A panic was caught while replies were being processed.
    ParsePanic,
}

impl std::fmt::Display for DumpError {
    /// Renders a short description of the failure, including the underlying
    /// cause where the variant carries one.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::OpenSocket(err) => write!(f, "open netlink socket: {err}"),
            Self::Bind(err) => write!(f, "bind netlink socket: {err}"),
            Self::Send(err) => write!(f, "send INET_DIAG request: {err}"),
            Self::Recv(err) => write!(f, "recv from netlink: {err}"),
            Self::KernelError(detail) => write!(f, "kernel returned NLMSG_ERROR: {detail}"),
            Self::UnsupportedFamily(family) => write!(f, "unsupported family {family}"),
            Self::ParsePanic => f.write_str("panic in netlink parser"),
        }
    }
}
/// Result of processing one received batch of netlink messages.
#[derive(Debug)]
enum BatchOutcome {
    /// The batch was consumed without a terminator; keep receiving.
    Continue,
    /// Stop receiving: the dump ended or the batch could not be walked further.
    Done,
}
// Unit tests for the pure parsing helpers; none of them opens a netlink socket.
#[cfg(test)]
mod tests {
use super::*;
use netlink_packet_sock_diag::inet::{InetResponse, InetResponseHeader};
// Builds a minimum-length TcpInfo payload with both counters written in
// native byte order at their fixed offsets; every other byte is zero.
fn tcp_info_bytes(bytes_received: u64, bytes_acked: u64) -> Vec<u8> {
let mut v = vec![0u8; TCP_INFO_MIN_LEN];
v[TCP_INFO_BYTES_ACKED_OFFSET..TCP_INFO_BYTES_ACKED_OFFSET + 8]
.copy_from_slice(&bytes_acked.to_ne_bytes());
v[TCP_INFO_BYTES_RECEIVED_OFFSET..TCP_INFO_BYTES_RECEIVED_OFFSET + 8]
.copy_from_slice(&bytes_received.to_ne_bytes());
v
}
// Builds an IPv4 response for `local_port` whose only attribute is a TcpInfo
// carrying `rx` as bytes_received and `tx` as bytes_acked.
fn make_response(local_port: u16, rx: u64, tx: u64) -> InetResponse {
let mut socket_id = SocketId::new_v4();
socket_id.source_port = local_port;
let mut resp = InetResponse {
header: InetResponseHeader {
family: AF_INET,
state: 1, timer: None,
socket_id,
recv_queue: 0,
send_queue: 0,
uid: 1000,
inode: 12345,
},
nlas: Default::default(),
};
resp.nlas.push(Nla::TcpInfo(tcp_info_bytes(rx, tx)));
resp
}
// The result tuple is ordered (bytes_received, bytes_acked).
#[test]
fn extract_tcp_counters_reads_offsets() {
let bytes = tcp_info_bytes(1_024_000, 64_000);
assert_eq!(extract_tcp_counters(&bytes), Some((1_024_000, 64_000)));
}
// A payload shorter than the minimum length yields no counters.
#[test]
fn extract_tcp_counters_returns_none_for_short_payload() {
assert_eq!(extract_tcp_counters(&[0u8; 100]), None);
}
// Trailing bytes beyond the minimum length are ignored.
#[test]
fn extract_tcp_counters_accepts_longer_payload() {
let mut bytes = tcp_info_bytes(99, 88);
bytes.extend_from_slice(&[0u8; 64]); assert_eq!(extract_tcp_counters(&bytes), Some((99, 88)));
}
#[test]
fn merge_response_extracts_rx_and_tx() {
let resp = make_response(54321, 1_024_000, 64_000);
let mut out = HashMap::new();
merge_response(&resp, &mut out);
assert_eq!(out.get(&54321), Some(&(1_024_000, 64_000)));
}
// Responses with source port 0 are never recorded.
#[test]
fn merge_response_skips_zero_local_port() {
let resp = make_response(0, 999, 999);
let mut out = HashMap::new();
merge_response(&resp, &mut out);
assert!(out.is_empty());
}
#[test]
fn merge_response_skips_when_no_tcp_info_nla() {
let mut resp = make_response(54321, 1, 2);
resp.nlas.clear();
let mut out = HashMap::new();
merge_response(&resp, &mut out);
assert!(out.is_empty());
}
// A TcpInfo attribute that is too short to decode leaves the map untouched.
#[test]
fn merge_response_skips_when_tcp_info_payload_truncated() {
let mut resp = make_response(54321, 1, 2);
resp.nlas.clear();
resp.nlas.push(Nla::TcpInfo(vec![0u8; 50])); let mut out = HashMap::new();
merge_response(&resp, &mut out);
assert!(out.is_empty());
}
// Pins the last-write-wins behaviour for responses sharing a local port.
#[test]
fn merge_response_overwrites_existing_port_entry() {
let mut out = HashMap::new();
merge_response(&make_response(8080, 100, 50), &mut out);
merge_response(&make_response(8080, 200, 75), &mut out);
assert_eq!(out.get(&8080), Some(&(200, 75)));
}
#[test]
fn align_rounds_up_to_4_bytes() {
assert_eq!(align(0), 0);
assert_eq!(align(1), 4);
assert_eq!(align(4), 4);
assert_eq!(align(5), 8);
assert_eq!(align(7), 8);
assert_eq!(align(8), 8);
assert_eq!(align(9), 12);
}
// An empty batch is not a terminator: the caller is told to keep receiving.
#[test]
fn process_batch_handles_empty_input() {
let mut out = HashMap::new();
let mut truncated = false;
let outcome = process_batch(&[], &mut out, &mut truncated).unwrap();
assert!(matches!(outcome, BatchOutcome::Continue));
assert!(out.is_empty());
}
// A zero length prefix would never advance the cursor, so the batch is abandoned.
#[test]
fn process_batch_aborts_on_zero_header_length() {
let buf = vec![0u8; 16];
let mut out = HashMap::new();
let mut truncated = false;
let outcome = process_batch(&buf, &mut out, &mut truncated).unwrap();
assert!(matches!(outcome, BatchOutcome::Done));
assert!(out.is_empty());
}
// Serializes `resp` as a complete netlink message with a default header.
fn serialize_inet_response(resp: InetResponse) -> Vec<u8> {
let header = NetlinkHeader::default();
let payload = NetlinkPayload::from(SockDiagMessage::InetResponse(Box::new(resp)));
let mut packet = NetlinkMessage::new(header, payload);
packet.finalize();
let mut buf = vec![0u8; packet.buffer_len()];
packet.serialize(&mut buf);
buf
}
// Serializes a message whose payload is the `Done` terminator.
fn serialize_nlmsg_done() -> Vec<u8> {
let header = NetlinkHeader::default();
let packet: NetlinkMessage<SockDiagMessage> =
NetlinkMessage::new(header, NetlinkPayload::Done(Default::default()));
let mut packet = packet;
packet.finalize();
let mut buf = vec![0u8; packet.buffer_len()];
packet.serialize(&mut buf);
buf
}
// Two back-to-back responses in one buffer are both merged.
#[test]
fn process_batch_walks_multiple_inet_responses() {
let mut buf = serialize_inet_response(make_response(8080, 1_000, 500));
buf.extend(serialize_inet_response(make_response(9090, 2_000, 1_000)));
let mut out = HashMap::new();
let mut truncated = false;
let outcome = process_batch(&buf, &mut out, &mut truncated).unwrap();
assert!(matches!(outcome, BatchOutcome::Continue));
assert_eq!(out.get(&8080), Some(&(1_000, 500)));
assert_eq!(out.get(&9090), Some(&(2_000, 1_000)));
assert!(!truncated);
}
// Responses preceding the terminator are kept; the terminator ends the dump.
#[test]
fn process_batch_stops_on_nlmsg_done_after_responses() {
let mut buf = serialize_inet_response(make_response(7777, 42, 24));
buf.extend(serialize_nlmsg_done());
let mut out = HashMap::new();
let mut truncated = false;
let outcome = process_batch(&buf, &mut out, &mut truncated).unwrap();
assert!(matches!(outcome, BatchOutcome::Done));
assert_eq!(out.get(&7777), Some(&(42, 24)));
}
#[test]
fn process_batch_advances_past_malformed_message_inside_batch() {
let valid_a = serialize_inet_response(make_response(11111, 5, 3));
let valid_b = serialize_inet_response(make_response(22222, 7, 11));
let mut bad: Vec<u8> = Vec::new();
// Header-only message declaring length 16 and type 99; the remaining header fields are zero.
bad.extend_from_slice(&16u32.to_ne_bytes()); bad.extend_from_slice(&99u16.to_ne_bytes()); bad.extend_from_slice(&0u16.to_ne_bytes()); bad.extend_from_slice(&0u32.to_ne_bytes()); bad.extend_from_slice(&0u32.to_ne_bytes());
let mut buf = valid_a;
buf.extend(&bad);
buf.extend(valid_b);
let mut out = HashMap::new();
let mut truncated = false;
let outcome = process_batch(&buf, &mut out, &mut truncated).unwrap();
assert!(matches!(outcome, BatchOutcome::Continue));
assert_eq!(out.get(&11111), Some(&(5, 3)));
assert_eq!(
out.get(&22222),
Some(&(7, 11)),
"the malformed middle message must not silently discard the valid one that follows"
);
}
#[test]
fn process_batch_sets_truncated_flag_when_cap_reached() {
// Pre-fill the map to exactly the cap so the next response must be rejected.
let mut out: HashMap<u16, (u64, u64)> = HashMap::new();
for port in 1u16..=MAX_SOCKETS as u16 {
out.insert(port, (0, 0));
}
assert_eq!(out.len(), MAX_SOCKETS);
let buf = serialize_inet_response(make_response(50_000, 999, 888));
let mut truncated = false;
let _ = process_batch(&buf, &mut out, &mut truncated).unwrap();
assert!(truncated, "truncated flag must flip when cap is reached");
assert!(
!out.contains_key(&50_000),
"capped insertion must be skipped"
);
assert_eq!(out.len(), MAX_SOCKETS, "size must not exceed cap");
}
#[test]
fn process_batch_returns_kernel_error_on_nlmsg_error_bytes() {
const NLMSG_ERROR: u16 = 2;
let mut bytes: Vec<u8> = Vec::new();
// Hand-built 36-byte error message: 16-byte header, an i32 code of -1, then 16 zero bytes.
bytes.extend_from_slice(&36u32.to_ne_bytes()); bytes.extend_from_slice(&NLMSG_ERROR.to_ne_bytes()); bytes.extend_from_slice(&0u16.to_ne_bytes()); bytes.extend_from_slice(&0u32.to_ne_bytes()); bytes.extend_from_slice(&0u32.to_ne_bytes()); bytes.extend_from_slice(&(-1i32).to_ne_bytes()); bytes.extend_from_slice(&[0u8; 16]);
let mut out = HashMap::new();
let mut truncated = false;
let result = process_batch(&bytes, &mut out, &mut truncated);
assert!(
matches!(result, Err(DumpError::KernelError(_))),
"expected DumpError::KernelError, got {result:?}",
);
}
// Boundary check: one byte short of the minimum length is rejected.
#[test]
fn extract_tcp_counters_returns_none_at_min_len_minus_one() {
let bytes = vec![0u8; TCP_INFO_MIN_LEN - 1];
assert_eq!(extract_tcp_counters(&bytes), None);
}
#[test]
fn extract_tcp_counters_returns_zeros_for_zeroed_payload() {
let bytes = tcp_info_bytes(0, 0);
assert_eq!(extract_tcp_counters(&bytes), Some((0, 0)));
}
#[test]
fn extract_tcp_counters_handles_max_u64_counters() {
let bytes = tcp_info_bytes(u64::MAX, u64::MAX);
assert_eq!(extract_tcp_counters(&bytes), Some((u64::MAX, u64::MAX)));
}
#[test]
fn peek_msg_len_reads_first_4_bytes_native_endian() {
let mut buf = Vec::new();
buf.extend_from_slice(&64u32.to_ne_bytes());
buf.extend_from_slice(&[0u8; 12]);
assert_eq!(peek_msg_len(&buf), Some(64));
}
#[test]
fn peek_msg_len_returns_none_for_short_input() {
assert_eq!(peek_msg_len(&[]), None);
assert_eq!(peek_msg_len(&[0u8; 3]), None);
}
#[test]
fn catch_panic_returns_ok_for_ok_closure() {
let result = catch_panic(|| Ok(()));
assert!(matches!(result, Ok(())));
}
// Errors returned by the closure pass through unchanged.
#[test]
fn catch_panic_propagates_inner_dump_error() {
let result = catch_panic(|| Err(DumpError::UnsupportedFamily(99)));
assert!(matches!(result, Err(DumpError::UnsupportedFamily(99))));
}
#[test]
fn catch_panic_converts_panic_into_parse_panic_variant() {
let result = catch_panic(|| {
panic!("simulated upstream parser panic");
});
assert!(
matches!(result, Err(DumpError::ParsePanic)),
"panic in closure must convert to DumpError::ParsePanic, got {result:?}",
);
}
// Writes the counters explicitly little-endian; this matches the native-endian
// reader only because big-endian targets are rejected at compile time.
#[test]
fn extract_tcp_counters_pins_little_endian_assumption_on_supported_targets() {
let mut bytes = vec![0u8; TCP_INFO_MIN_LEN];
bytes[TCP_INFO_BYTES_ACKED_OFFSET..TCP_INFO_BYTES_ACKED_OFFSET + 8]
.copy_from_slice(&64_000_u64.to_le_bytes());
bytes[TCP_INFO_BYTES_RECEIVED_OFFSET..TCP_INFO_BYTES_RECEIVED_OFFSET + 8]
.copy_from_slice(&1_024_000_u64.to_le_bytes());
assert_eq!(extract_tcp_counters(&bytes), Some((1_024_000, 64_000)));
}
}