use std::net::IpAddr;
use winnow::{binary::le_u16, prelude::*, token::take};
use zerocopy::FromBytes;
use crate::netlink::{
parse::{FromNetlink, PResult, parse_ip_addr},
types::rule::{FibRuleAction, FibRuleHdr, FibRulePortRange, FibRuleUidRange},
};
/// Netlink attribute type ids recognised by `RuleMessage::parse`.
///
/// The numeric values mirror the `FRA_*` enum in Linux `linux/fib_rules.h`.
/// Ids 5, 7-9 and 18 are not named here, so attributes carrying them are
/// skipped by the parser.
mod attr_ids {
/// Destination address; stored in `RuleMessage::destination`.
pub const FRA_DST: u16 = 1;
/// Source address; stored in `RuleMessage::source`.
pub const FRA_SRC: u16 = 2;
/// Input interface name (NUL-terminated string); stored in `RuleMessage::iifname`.
pub const FRA_IIFNAME: u16 = 3;
/// Goto target (u32); stored in `RuleMessage::goto`.
pub const FRA_GOTO: u16 = 4;
/// Rule priority (u32); stored in `RuleMessage::priority`.
pub const FRA_PRIORITY: u16 = 6;
/// Firewall mark (u32); stored in `RuleMessage::fwmark`.
pub const FRA_FWMARK: u16 = 10;
/// Flow value (u32); stored in `RuleMessage::flow`.
pub const FRA_FLOW: u16 = 11;
/// Tunnel id (u64, decoded big-endian by the parser); stored in `RuleMessage::tun_id`.
pub const FRA_TUN_ID: u16 = 12;
/// Suppress-ifgroup value (u32); stored in `RuleMessage::suppress_ifgroup`.
pub const FRA_SUPPRESS_IFGROUP: u16 = 13;
/// Suppress-prefixlen value (u32); stored in `RuleMessage::suppress_prefixlen`.
pub const FRA_SUPPRESS_PREFIXLEN: u16 = 14;
/// Table id (u32); overrides the table taken from the fixed header.
pub const FRA_TABLE: u16 = 15;
/// Firewall mark mask (u32); stored in `RuleMessage::fwmask`.
pub const FRA_FWMASK: u16 = 16;
/// Output interface name (NUL-terminated string); stored in `RuleMessage::oifname`.
pub const FRA_OIFNAME: u16 = 17;
/// L3 master device flag (single byte); stored in `RuleMessage::l3mdev`.
pub const FRA_L3MDEV: u16 = 19;
/// UID range; decoded via `FibRuleUidRange::from_bytes`.
pub const FRA_UID_RANGE: u16 = 20;
/// Originating protocol of the rule (single byte); stored in `RuleMessage::protocol`.
pub const FRA_PROTOCOL: u16 = 21;
/// IP protocol selector (single byte); stored in `RuleMessage::ip_proto`.
pub const FRA_IP_PROTO: u16 = 22;
/// Source port range; decoded via `FibRulePortRange::from_bytes`.
pub const FRA_SPORT_RANGE: u16 = 23;
/// Destination port range; decoded via `FibRulePortRange::from_bytes`.
pub const FRA_DPORT_RANGE: u16 = 24;
}
/// A routing-policy rule decoded from a netlink message.
///
/// `header` holds the fixed-size part of the message; every other field is
/// filled from the attributes that follow it. Optional fields stay `None`
/// when the corresponding attribute is absent or could not be decoded.
#[derive(Debug, Clone, Default)]
pub struct RuleMessage {
/// Fixed-size rule header copied verbatim from the message.
pub header: FibRuleHdr,
/// Value of `FRA_PRIORITY`; remains 0 when the attribute is absent.
pub priority: u32,
/// Address from `FRA_SRC`, parsed according to the header's address family.
pub source: Option<IpAddr>,
/// Address from `FRA_DST`, parsed according to the header's address family.
pub destination: Option<IpAddr>,
/// Input interface name from `FRA_IIFNAME`.
pub iifname: Option<String>,
/// Output interface name from `FRA_OIFNAME`.
pub oifname: Option<String>,
/// Value of `FRA_FWMARK`.
pub fwmark: Option<u32>,
/// Value of `FRA_FWMASK`.
pub fwmask: Option<u32>,
/// Table id: starts as the header's 8-bit `table` and is replaced by
/// `FRA_TABLE` when that attribute is present.
pub table: u32,
/// Value of `FRA_GOTO`.
pub goto: Option<u32>,
/// Value of `FRA_FLOW`.
pub flow: Option<u32>,
/// Value of `FRA_TUN_ID`, decoded from big-endian bytes.
pub tun_id: Option<u64>,
/// Value of `FRA_SUPPRESS_IFGROUP`.
pub suppress_ifgroup: Option<u32>,
/// Value of `FRA_SUPPRESS_PREFIXLEN`.
pub suppress_prefixlen: Option<u32>,
/// First payload byte of `FRA_L3MDEV`.
pub l3mdev: Option<u8>,
/// Range from `FRA_UID_RANGE`.
pub uid_range: Option<FibRuleUidRange>,
/// First payload byte of `FRA_PROTOCOL`.
pub protocol: Option<u8>,
/// First payload byte of `FRA_IP_PROTO`.
pub ip_proto: Option<u8>,
/// Range from `FRA_SPORT_RANGE`.
pub sport_range: Option<FibRulePortRange>,
/// Range from `FRA_DPORT_RANGE`.
pub dport_range: Option<FibRulePortRange>,
}
impl RuleMessage {
    /// Creates an empty message; identical to `RuleMessage::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Address family byte stored in the rule header.
    pub fn family(&self) -> u8 {
        self.header.family
    }

    /// Returns `true` when the header's family is `AF_INET`.
    pub fn is_ipv4(&self) -> bool {
        let inet = libc::AF_INET as u8;
        self.header.family == inet
    }

    /// Returns `true` when the header's family is `AF_INET6`.
    pub fn is_ipv6(&self) -> bool {
        let inet6 = libc::AF_INET6 as u8;
        self.header.family == inet6
    }

    /// Source prefix length stored in the rule header.
    pub fn src_len(&self) -> u8 {
        self.header.src_len
    }

    /// Destination prefix length stored in the rule header.
    pub fn dst_len(&self) -> u8 {
        self.header.dst_len
    }

    /// The header's raw action value converted into a [`FibRuleAction`].
    pub fn action(&self) -> FibRuleAction {
        self.header.action.into()
    }

    /// Returns `true` when the rule's action is `FibRuleAction::ToTbl`.
    pub fn is_lookup(&self) -> bool {
        let action = self.action();
        action == FibRuleAction::ToTbl
    }

    /// Returns `true` when the rule's action is `FibRuleAction::Blackhole`.
    pub fn is_blackhole(&self) -> bool {
        let action = self.action();
        action == FibRuleAction::Blackhole
    }

    /// Returns `true` when the rule's action is `FibRuleAction::Unreachable`.
    pub fn is_unreachable(&self) -> bool {
        let action = self.action();
        action == FibRuleAction::Unreachable
    }

    /// Returns `true` when the rule's action is `FibRuleAction::Prohibit`.
    pub fn is_prohibit(&self) -> bool {
        let action = self.action();
        action == FibRuleAction::Prohibit
    }

    /// Table id of the rule (see the `table` field).
    pub fn table_id(&self) -> u32 {
        self.table
    }

    /// Returns `true` when the priority is 0, 32766 or 32767.
    ///
    /// Only the priority is inspected; these are the priorities Linux gives
    /// its built-in local, main and default rules.
    pub fn is_default(&self) -> bool {
        matches!(self.priority, 0 | 32766 | 32767)
    }
}
impl FromNetlink for RuleMessage {
    /// Appends the bytes of a freshly constructed `FibRuleHdr` to `buf`.
    ///
    /// Existing contents of `buf` are left untouched.
    fn write_dump_header(buf: &mut Vec<u8>) {
        let header = FibRuleHdr::new();
        buf.extend_from_slice(header.as_bytes());
    }

    /// Decodes one rule message: a fixed `FibRuleHdr` followed by attributes.
    ///
    /// `input` is advanced past everything that was consumed.
    ///
    /// # Errors
    ///
    /// Fails when fewer than `FibRuleHdr::SIZE` bytes are available or when the
    /// header bytes are rejected by `FibRuleHdr::ref_from_bytes`.
    ///
    /// Attribute decoding is best-effort: a malformed or truncated attribute
    /// stops the attribute loop, and an attribute whose payload is too short
    /// or otherwise undecodable is ignored. In both cases the message built so
    /// far is still returned.
    fn parse(input: &mut &[u8]) -> PResult<Self> {
        // Size of the attribute header: 16-bit length followed by 16-bit type.
        const NLA_HDRLEN: usize = 4;
        // Attributes are padded so the next one starts on this alignment.
        const NLA_ALIGNTO: usize = 4;
        // Clears NLA_F_NESTED (bit 15) and NLA_F_NET_BYTEORDER (bit 14), leaving
        // only the attribute id.
        const NLA_TYPE_MASK: u16 = 0x3fff;

        // Reads a native-endian u32 from the first four payload bytes, if present.
        fn ne_u32(data: &[u8]) -> Option<u32> {
            let bytes: [u8; 4] = data.get(..4)?.try_into().ok()?;
            Some(u32::from_ne_bytes(bytes))
        }

        // Reads a big-endian u64 from the first eight payload bytes, if present.
        fn be_u64(data: &[u8]) -> Option<u64> {
            let bytes: [u8; 8] = data.get(..8)?.try_into().ok()?;
            Some(u64::from_be_bytes(bytes))
        }

        let header_bytes: &[u8] = take(FibRuleHdr::SIZE).parse_next(input)?;
        let header = *FibRuleHdr::ref_from_bytes(header_bytes)
            .map_err(|_| winnow::error::ErrMode::Cut(winnow::error::ContextError::new()))?;
        let mut msg = RuleMessage {
            // The 8-bit header table is the fallback; FRA_TABLE overrides it below.
            table: header.table as u32,
            header,
            ..Default::default()
        };

        while input.len() >= NLA_HDRLEN {
            let raw_len: u16 = le_u16.parse_next(input)?;
            let raw_type: u16 = le_u16.parse_next(input)?;
            // Netlink attribute headers are in host byte order. `le_u16` consumed
            // the two wire bytes as little-endian, so recover those bytes and
            // re-interpret them natively (a no-op on little-endian hosts).
            let attr_len = usize::from(u16::from_ne_bytes(raw_len.to_le_bytes()));
            let attr_type = u16::from_ne_bytes(raw_type.to_le_bytes());

            // The length includes the header itself, so anything smaller is corrupt.
            if attr_len < NLA_HDRLEN {
                break;
            }
            let data_len = attr_len - NLA_HDRLEN;
            if input.len() < data_len {
                // Truncated payload: stop and keep what was decoded so far.
                break;
            }
            let attr_data: &[u8] = take(data_len).parse_next(input)?;

            // Skip alignment padding; the final attribute may legitimately lack it.
            let padding = (NLA_ALIGNTO - (attr_len % NLA_ALIGNTO)) % NLA_ALIGNTO;
            if input.len() >= padding {
                let _ = take(padding).parse_next(input)?;
            }

            match attr_type & NLA_TYPE_MASK {
                attr_ids::FRA_PRIORITY => {
                    if let Some(value) = ne_u32(attr_data) {
                        msg.priority = value;
                    }
                }
                attr_ids::FRA_SRC => {
                    if let Ok(addr) = parse_ip_addr(attr_data, msg.header.family) {
                        msg.source = Some(addr);
                    }
                }
                attr_ids::FRA_DST => {
                    if let Ok(addr) = parse_ip_addr(attr_data, msg.header.family) {
                        msg.destination = Some(addr);
                    }
                }
                attr_ids::FRA_IIFNAME => {
                    msg.iifname = parse_string(attr_data);
                }
                attr_ids::FRA_OIFNAME => {
                    msg.oifname = parse_string(attr_data);
                }
                attr_ids::FRA_FWMARK => {
                    if let Some(value) = ne_u32(attr_data) {
                        msg.fwmark = Some(value);
                    }
                }
                attr_ids::FRA_FWMASK => {
                    if let Some(value) = ne_u32(attr_data) {
                        msg.fwmask = Some(value);
                    }
                }
                attr_ids::FRA_TABLE => {
                    if let Some(value) = ne_u32(attr_data) {
                        msg.table = value;
                    }
                }
                attr_ids::FRA_GOTO => {
                    if let Some(value) = ne_u32(attr_data) {
                        msg.goto = Some(value);
                    }
                }
                attr_ids::FRA_FLOW => {
                    if let Some(value) = ne_u32(attr_data) {
                        msg.flow = Some(value);
                    }
                }
                attr_ids::FRA_TUN_ID => {
                    // Decoded big-endian, unlike the native-endian u32 attributes.
                    if let Some(value) = be_u64(attr_data) {
                        msg.tun_id = Some(value);
                    }
                }
                attr_ids::FRA_SUPPRESS_IFGROUP => {
                    if let Some(value) = ne_u32(attr_data) {
                        msg.suppress_ifgroup = Some(value);
                    }
                }
                attr_ids::FRA_SUPPRESS_PREFIXLEN => {
                    if let Some(value) = ne_u32(attr_data) {
                        msg.suppress_prefixlen = Some(value);
                    }
                }
                attr_ids::FRA_L3MDEV => {
                    if let Some(&byte) = attr_data.first() {
                        msg.l3mdev = Some(byte);
                    }
                }
                attr_ids::FRA_UID_RANGE => {
                    if let Some(range) = FibRuleUidRange::from_bytes(attr_data) {
                        msg.uid_range = Some(*range);
                    }
                }
                attr_ids::FRA_PROTOCOL => {
                    if let Some(&byte) = attr_data.first() {
                        msg.protocol = Some(byte);
                    }
                }
                attr_ids::FRA_IP_PROTO => {
                    if let Some(&byte) = attr_data.first() {
                        msg.ip_proto = Some(byte);
                    }
                }
                attr_ids::FRA_SPORT_RANGE => {
                    if let Some(range) = FibRulePortRange::from_bytes(attr_data) {
                        msg.sport_range = Some(*range);
                    }
                }
                attr_ids::FRA_DPORT_RANGE => {
                    if let Some(range) = FibRulePortRange::from_bytes(attr_data) {
                        msg.dport_range = Some(*range);
                    }
                }
                // Unknown attribute ids are skipped.
                _ => {}
            }
        }
        Ok(msg)
    }
}
/// Decodes a possibly NUL-terminated byte payload into an owned string.
///
/// Only the bytes before the first NUL are considered (the whole slice when
/// there is none). Returns `None` if those bytes are not valid UTF-8.
fn parse_string(data: &[u8]) -> Option<String> {
    // `split` always yields at least one (possibly empty) piece, so the
    // fallback is never taken in practice.
    let text = data.split(|&byte| byte == 0).next().unwrap_or(data);
    match std::str::from_utf8(text) {
        Ok(valid) => Some(valid.to_owned()),
        Err(_) => None,
    }
}