use crate::constants::network_values::{PROTOCOL_TCP, PROTOCOL_UDP};
use crate::constants::{ebpf, TCP_ACK, TCP_FIN, TCP_RST, TCP_SYN};
use crate::stats::{Direction, StatItem, StatKey, StatValues};
use pnet::packet::ethernet::EthernetPacket;
use pnet::packet::{
ip::IpNextHeaderProtocols, ipv4::Ipv4Packet, tcp::TcpPacket, udp::UdpPacket, Packet,
};
use std::collections::HashSet;
use std::net::Ipv4Addr;
/// Determines the traffic direction of a packet.
///
/// When `hook_source` identifies one of the known eBPF hooks
/// (`ebpf::HOOK_XDP` or `ebpf::HOOK_TC_EGRESS`), the direction is derived
/// from the hook alone. For any other value, including `None`, the direction
/// is derived from the addresses using `local_ips` and `gateway_ip`.
///
/// * `src_ip` / `dst_ip` - source and destination addresses of the packet.
/// * `local_ips` - addresses considered local; only used by the fallback.
/// * `gateway_ip` - optional gateway address, used for logging by the hook
///   paths and for classification by the fallback.
/// * `hook_source` - optional identifier of the hook that captured the packet.
pub fn determine_packet_direction(
    src_ip: Ipv4Addr,
    dst_ip: Ipv4Addr,
    local_ips: &HashSet<Ipv4Addr>,
    gateway_ip: Option<Ipv4Addr>,
    hook_source: Option<u8>,
) -> Direction {
    if let Some(hook) = hook_source {
        if hook == ebpf::HOOK_XDP {
            return direction_from_xdp(src_ip, dst_ip, gateway_ip);
        }
        if hook == ebpf::HOOK_TC_EGRESS {
            return direction_from_tc_egress(src_ip, dst_ip, gateway_ip);
        }
    }

    // No hook information, or an unrecognised hook id: classify by address.
    direction_from_ip_fallback(src_ip, dst_ip, local_ips, gateway_ip)
}
/// Classifies a packet that was tagged with the XDP hook.
///
/// Always returns [`Direction::Incoming`]. `gateway_ip` only selects which
/// debug message is emitted: a packet whose source is the gateway is logged
/// as an internet download, anything else as a plain download.
fn direction_from_xdp(
    src_ip: Ipv4Addr,
    dst_ip: Ipv4Addr,
    gateway_ip: Option<Ipv4Addr>,
) -> Direction {
    // The message goes straight to `tracing::debug!` instead of being
    // pre-built with `format!`, so this per-packet path does not allocate a
    // `String` purely for logging.
    if gateway_ip == Some(src_ip) {
        tracing::debug!("Direction: INTERNET DOWNLOAD (from gateway) src={src_ip} dst={dst_ip}");
    } else {
        tracing::debug!("Direction: DOWNLOAD (XDP hook) src={src_ip} dst={dst_ip}");
    }
    Direction::Incoming
}
/// Classifies a packet that was tagged with the TC egress hook.
///
/// Always returns [`Direction::Outgoing`]. `gateway_ip` only selects which
/// debug message is emitted: a packet whose destination is the gateway is
/// logged as an internet upload, anything else as a plain upload.
fn direction_from_tc_egress(
    src_ip: Ipv4Addr,
    dst_ip: Ipv4Addr,
    gateway_ip: Option<Ipv4Addr>,
) -> Direction {
    // The message goes straight to `tracing::debug!` instead of being
    // pre-built with `format!`, so this per-packet path does not allocate a
    // `String` purely for logging.
    if gateway_ip == Some(dst_ip) {
        tracing::debug!("Direction: INTERNET UPLOAD (to gateway) src={src_ip} dst={dst_ip}");
    } else {
        tracing::debug!("Direction: UPLOAD (TC egress hook) src={src_ip} dst={dst_ip}");
    }
    Direction::Outgoing
}
/// Classifies a packet purely from its addresses.
///
/// Rules are evaluated in this order:
///
/// 1. If a gateway is known and the packet goes from a local address to the
///    gateway, the result is [`Direction::Outgoing`].
/// 2. If a gateway is known and the packet goes from the gateway to a local
///    address, the result is [`Direction::Incoming`].
/// 3. Otherwise the result depends on which endpoints are in `local_ips`:
///    both -> [`Direction::Local`], only the source -> [`Direction::Outgoing`],
///    only the destination -> [`Direction::Incoming`], neither ->
///    [`Direction::None`].
///
/// The gateway rules run first, so a gateway address that is also listed in
/// `local_ips` is still reported as upload/download rather than local.
fn direction_from_ip_fallback(
    src_ip: Ipv4Addr,
    dst_ip: Ipv4Addr,
    local_ips: &HashSet<Ipv4Addr>,
    gateway_ip: Option<Ipv4Addr>,
) -> Direction {
    let src_is_local = local_ips.contains(&src_ip);
    let dst_is_local = local_ips.contains(&dst_ip);

    // Gateway-specific rules take precedence over the generic locality rules.
    match gateway_ip {
        Some(gw) if src_is_local && dst_ip == gw => {
            tracing::debug!(
                "Direction: INTERNET UPLOAD (IP-based) src={} -> dst={} (gateway)",
                src_ip,
                dst_ip
            );
            return Direction::Outgoing;
        }
        Some(gw) if src_ip == gw && dst_is_local => {
            tracing::debug!(
                "Direction: INTERNET DOWNLOAD (IP-based) src={} (gateway) -> dst={}",
                src_ip,
                dst_ip
            );
            return Direction::Incoming;
        }
        _ => {}
    }

    match (src_is_local, dst_is_local) {
        (true, true) => {
            tracing::debug!(
                "LOCAL: src_ip={} (local) -> dst_ip={} (local)",
                src_ip,
                dst_ip
            );
            Direction::Local
        }
        (true, false) => {
            tracing::debug!(
                "UPLOAD (IP-based): src_ip={} (local) -> dst_ip={} (remote)",
                src_ip,
                dst_ip
            );
            Direction::Outgoing
        }
        (false, true) => {
            tracing::debug!(
                "DOWNLOAD (IP-based): src_ip={} (remote) -> dst_ip={} (local)",
                src_ip,
                dst_ip
            );
            Direction::Incoming
        }
        // Neither endpoint is known to be local: no direction can be assigned.
        (false, false) => Direction::None,
    }
}
/// Returns the size of a packet in bits.
///
/// If `original_len` is provided it is taken as the packet size in bytes.
/// Otherwise the size is the IPv4 header's total-length field plus
/// `network::ETHERNET_HEADER_SIZE`. Either way the byte count is multiplied
/// by eight.
pub fn calculate_packet_size_bits(ipv4_packet: &Ipv4Packet, original_len: Option<u32>) -> u128 {
    use crate::constants::network;

    const BITS_PER_BYTE: u128 = 8;

    let size_bytes = match original_len {
        Some(len) => u128::from(len),
        None => {
            u128::from(ipv4_packet.get_total_length()) + network::ETHERNET_HEADER_SIZE as u128
        }
    };
    size_bytes * BITS_PER_BYTE
}
/// Builds a [`StatItem`] for a TCP segment carried in `ipv4_packet`.
///
/// The key records the given `direction`, both ports, both IP addresses,
/// `PROTOCOL_TCP` and the SYN/ACK/FIN/RST flag bits. The value records
/// `size_bits` together with the segment's sequence and acknowledgement
/// numbers; `last_timestamp` is left unset.
///
/// Returns `None` when `TcpPacket::new` cannot parse the IPv4 payload.
pub fn create_tcp_stat(
    ipv4_packet: &Ipv4Packet,
    direction: Direction,
    size_bits: u128,
) -> Option<StatItem> {
    let segment = TcpPacket::new(ipv4_packet.payload())?;

    // Decode the flag bits once, up front.
    let flag_bits = segment.get_flags();
    let tcp_syn = (flag_bits & TCP_SYN) != 0;
    let tcp_ack = (flag_bits & TCP_ACK) != 0;
    let tcp_fin = (flag_bits & TCP_FIN) != 0;
    let tcp_rst = (flag_bits & TCP_RST) != 0;

    let key = StatKey {
        direction,
        src_port: segment.get_source(),
        dst_port: segment.get_destination(),
        src_ip: ipv4_packet.get_source(),
        dst_ip: ipv4_packet.get_destination(),
        protocol: PROTOCOL_TCP,
        tcp_syn,
        tcp_ack,
        tcp_fin,
        tcp_rst,
    };
    let value = StatValues {
        size: size_bits,
        last_timestamp: None,
        last_seq: Some(segment.get_sequence()),
        last_ack: Some(segment.get_acknowledgement()),
    };

    Some(StatItem { key, value })
}
/// Builds a [`StatItem`] for a UDP datagram carried in `ipv4_packet`.
///
/// The key records the given `direction`, both ports, both IP addresses and
/// `PROTOCOL_UDP`; all TCP flag fields are `false`. The value records only
/// `size_bits`; the timestamp, sequence and acknowledgement fields are left
/// unset.
///
/// Returns `None` when `UdpPacket::new` cannot parse the IPv4 payload.
pub fn create_udp_stat(
    ipv4_packet: &Ipv4Packet,
    direction: Direction,
    size_bits: u128,
) -> Option<StatItem> {
    let udp = UdpPacket::new(ipv4_packet.payload())?;

    let key = StatKey {
        direction,
        src_port: udp.get_source(),
        dst_port: udp.get_destination(),
        src_ip: ipv4_packet.get_source(),
        dst_ip: ipv4_packet.get_destination(),
        protocol: PROTOCOL_UDP,
        // The TCP flag fields are always cleared for UDP traffic.
        tcp_syn: false,
        tcp_ack: false,
        tcp_fin: false,
        tcp_rst: false,
    };
    let value = StatValues {
        size: size_bits,
        last_timestamp: None,
        last_seq: None,
        last_ack: None,
    };

    Some(StatItem { key, value })
}
/// Turns an Ethernet frame into a [`StatItem`], if it carries TCP or UDP
/// over IPv4.
///
/// The frame payload is parsed as an IPv4 packet, its direction is resolved
/// with [`determine_packet_direction`] and its size with
/// [`calculate_packet_size_bits`]. When `original_len` is provided, the size
/// and direction are also logged at info level.
///
/// Returns `None` when the payload cannot be parsed as IPv4, when the
/// next-level protocol is neither TCP nor UDP, or when the transport header
/// cannot be parsed.
///
/// NOTE(review): the frame's EtherType is not inspected before the payload is
/// parsed as IPv4 — presumably callers only pass IPv4 frames; confirm.
pub fn get_stats(
    ethernet_packet: &EthernetPacket,
    local_ips: &HashSet<Ipv4Addr>,
    gateway_ip: Option<Ipv4Addr>,
    hook_source: Option<u8>,
    original_len: Option<u32>,
) -> Option<StatItem> {
    let ip_packet = Ipv4Packet::new(ethernet_packet.payload())?;

    let direction = determine_packet_direction(
        ip_packet.get_source(),
        ip_packet.get_destination(),
        local_ips,
        gateway_ip,
        hook_source,
    );
    let size_bits = calculate_packet_size_bits(&ip_packet, original_len);

    if let Some(len) = original_len {
        tracing::info!(
            "BANDWIDTH: original_len={} bytes -> {} bits direction={:?}",
            len,
            size_bits,
            direction
        );
    }

    let protocol = ip_packet.get_next_level_protocol();
    if protocol == IpNextHeaderProtocols::Tcp {
        create_tcp_stat(&ip_packet, direction, size_bits)
    } else if protocol == IpNextHeaderProtocols::Udp {
        create_udp_stat(&ip_packet, direction, size_bits)
    } else {
        // Only TCP and UDP produce statistics.
        None
    }
}