use std::{
marker::PhantomData,
net::{Ipv4Addr, Ipv6Addr},
};
use futures_util::stream::StreamExt;
use netlink_packet_core::{
NetlinkMessage, NLM_F_ACK, NLM_F_CREATE, NLM_F_EXCL, NLM_F_REPLACE,
NLM_F_REQUEST,
};
use netlink_packet_route::{
route::RouteHeader,
rule::{RuleAction, RuleAttribute, RuleMessage},
AddressFamily, RouteNetlinkMessage,
};
use crate::{try_nl, Error, Handle};
#[derive(Debug, Clone)]
pub struct RuleAddRequest<T = ()> {
handle: Handle,
message: RuleMessage,
replace: bool,
_phantom: PhantomData<T>,
}
impl<T> RuleAddRequest<T> {
    /// Creates a request whose header targets the main routing table with an
    /// unspecified action.
    ///
    /// All other fields, including the address family, keep their
    /// `RuleMessage::default()` values until a builder method changes them.
    pub(crate) fn new(handle: Handle) -> Self {
        let mut message = RuleMessage::default();
        message.header.table = RouteHeader::RT_TABLE_MAIN;
        message.header.action = RuleAction::Unspec;
        RuleAddRequest {
            handle,
            message,
            replace: false,
            _phantom: Default::default(),
        }
    }

    /// Matches on the input interface name by appending an `Iifname`
    /// attribute.
    ///
    /// Calling this again appends another attribute; it does not overwrite
    /// the previous one.
    pub fn input_interface(mut self, ifname: String) -> Self {
        self.message.attributes.push(RuleAttribute::Iifname(ifname));
        self
    }

    /// Matches on the output interface name by appending an `Oifname`
    /// attribute.
    ///
    /// Calling this again appends another attribute; it does not overwrite
    /// the previous one.
    pub fn output_interface(mut self, ifname: String) -> Self {
        self.message.attributes.push(RuleAttribute::Oifname(ifname));
        self
    }

    /// Sets the 8-bit table id stored directly in the rule header.
    #[deprecated(note = "Please use `table_id` instead")]
    pub fn table(mut self, table: u8) -> Self {
        self.message.header.table = table;
        self
    }

    /// Sets the routing table the rule points to.
    ///
    /// The header's table field is only 8 bits wide. Ids up to 255 are
    /// written there. Larger ids are sent as a 32-bit `Table` attribute
    /// instead, and the header field is left as it was.
    ///
    /// NOTE(review): this relies on the kernel preferring the attribute over
    /// the header value when both are present. iproute2 does the same.
    pub fn table_id(mut self, table: u32) -> Self {
        if table > 255 {
            self.message.attributes.push(RuleAttribute::Table(table));
        } else {
            // Lossless: the branch guarantees `table <= 255`.
            self.message.header.table = table as u8;
        }
        self
    }

    /// Sets the rule header's `tos` field.
    pub fn tos(mut self, tos: u8) -> Self {
        self.message.header.tos = tos;
        self
    }

    /// Sets the rule action. `new` initialises it to `RuleAction::Unspec`.
    pub fn action(mut self, action: RuleAction) -> Self {
        self.message.header.action = action;
        self
    }

    /// Sets the rule priority by appending a `Priority` attribute.
    ///
    /// Calling this again appends another attribute.
    pub fn priority(mut self, priority: u32) -> Self {
        self.message
            .attributes
            .push(RuleAttribute::Priority(priority));
        self
    }

    /// Matches on the firewall mark by appending a `FwMark` attribute.
    ///
    /// Calling this again appends another attribute.
    pub fn fw_mark(mut self, fw_mark: u32) -> Self {
        self.message.attributes.push(RuleAttribute::FwMark(fw_mark));
        self
    }

    /// Switches the request to the IPv4 family (`AddressFamily::Inet`).
    ///
    /// This unlocks the IPv4 prefix methods. All state set so far is kept,
    /// including a previous `replace()` call.
    pub fn v4(mut self) -> RuleAddRequest<Ipv4Addr> {
        self.message.header.family = AddressFamily::Inet;
        RuleAddRequest {
            handle: self.handle,
            message: self.message,
            // Preserve the flag so `.replace().v4()` behaves like
            // `.v4().replace()`; previously it was reset to `false`.
            replace: self.replace,
            _phantom: Default::default(),
        }
    }

    /// Switches the request to the IPv6 family (`AddressFamily::Inet6`).
    ///
    /// This unlocks the IPv6 prefix methods. All state set so far is kept,
    /// including a previous `replace()` call.
    pub fn v6(mut self) -> RuleAddRequest<Ipv6Addr> {
        self.message.header.family = AddressFamily::Inet6;
        RuleAddRequest {
            handle: self.handle,
            message: self.message,
            // Preserve the flag so `.replace().v6()` behaves like
            // `.v6().replace()`; previously it was reset to `false`.
            replace: self.replace,
            _phantom: Default::default(),
        }
    }

    /// Makes `execute` send `NLM_F_REPLACE` instead of `NLM_F_EXCL`.
    ///
    /// This can be called before or after `v4()`/`v6()`.
    pub fn replace(self) -> Self {
        Self {
            replace: true,
            ..self
        }
    }

    /// Sends the rule as a `NewRule` message and waits for the responses.
    ///
    /// The request flags are `NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE`,
    /// plus `NLM_F_REPLACE` when `replace()` was called and `NLM_F_EXCL`
    /// otherwise.
    ///
    /// # Errors
    ///
    /// Returns an error if the handle cannot send the request. Each response
    /// is passed through `try_nl!`, which presumably returns early with an
    /// `Error` on a netlink error reply (defined elsewhere; confirm).
    pub async fn execute(self) -> Result<(), Error> {
        let RuleAddRequest {
            mut handle,
            message,
            replace,
            ..
        } = self;
        let mut req =
            NetlinkMessage::from(RouteNetlinkMessage::NewRule(message));
        let mode = if replace { NLM_F_REPLACE } else { NLM_F_EXCL };
        req.header.flags = NLM_F_REQUEST | NLM_F_ACK | mode | NLM_F_CREATE;
        let mut response = handle.request(req)?;
        while let Some(message) = response.next().await {
            try_nl!(message);
        }
        Ok(())
    }

    /// Gives mutable access to the underlying message.
    ///
    /// Use this for fields that have no dedicated builder method.
    pub fn message_mut(&mut self) -> &mut RuleMessage {
        &mut self.message
    }
}
impl RuleAddRequest<Ipv4Addr> {
    /// Restricts the rule to the IPv4 source prefix `addr/prefix_length`.
    ///
    /// Writes the prefix length into the header's `src_len` and appends a
    /// `Source` attribute carrying the address.
    pub fn source_prefix(mut self, addr: Ipv4Addr, prefix_length: u8) -> Self {
        let source = RuleAttribute::Source(addr.into());
        self.message.header.src_len = prefix_length;
        self.message.attributes.push(source);
        self
    }

    /// Restricts the rule to the IPv4 destination prefix `addr/prefix_length`.
    ///
    /// Writes the prefix length into the header's `dst_len` and appends a
    /// `Destination` attribute carrying the address.
    pub fn destination_prefix(
        mut self,
        addr: Ipv4Addr,
        prefix_length: u8,
    ) -> Self {
        let destination = RuleAttribute::Destination(addr.into());
        self.message.header.dst_len = prefix_length;
        self.message.attributes.push(destination);
        self
    }
}
impl RuleAddRequest<Ipv6Addr> {
    /// Restricts the rule to the IPv6 source prefix `addr/prefix_length`.
    ///
    /// Writes the prefix length into the header's `src_len` and appends a
    /// `Source` attribute carrying the address.
    pub fn source_prefix(mut self, addr: Ipv6Addr, prefix_length: u8) -> Self {
        let source = RuleAttribute::Source(addr.into());
        self.message.header.src_len = prefix_length;
        self.message.attributes.push(source);
        self
    }

    /// Restricts the rule to the IPv6 destination prefix `addr/prefix_length`.
    ///
    /// Writes the prefix length into the header's `dst_len` and appends a
    /// `Destination` attribute carrying the address.
    pub fn destination_prefix(
        mut self,
        addr: Ipv6Addr,
        prefix_length: u8,
    ) -> Self {
        let destination = RuleAttribute::Destination(addr.into());
        self.message.header.dst_len = prefix_length;
        self.message.attributes.push(destination);
        self
    }
}