use std::net::{Ipv4Addr, Ipv6Addr};
use cidr::IpCidr;
use futures::TryStreamExt;
use netlink_packet_route::{
AddressFamily,
route::{RouteAddress, RouteAttribute, RouteProtocol, RouteScope, RouteType},
};
use rtnetlink::{Error, Handle, RouteMessageBuilder, new_connection};
use sysctl::{Sysctl, SysctlError};
/// Best-effort: ensures a loopback `local` route for `subnet` exists.
///
/// Opens a dedicated rtnetlink connection, drives it on a spawned tokio task, and delegates the
/// work to `add_route`. Every failure (opening the socket or applying the route) is logged at
/// trace level and swallowed, so this never panics and never returns an error to the caller.
pub async fn sysctl_route_add_cidr(subnet: &IpCidr) {
    let (connection, handle, _) = match new_connection() {
        Ok(parts) => parts,
        Err(e) => {
            // Previously an `unwrap()`: a missing netlink socket must not abort a best-effort setup.
            tracing::trace!("Failed to open netlink connection: {}", e);
            return;
        }
    };
    // The connection future must be polled for requests made through `handle` to complete.
    tokio::spawn(connection);
    if let Err(e) = add_route(handle, subnet).await {
        tracing::trace!("Failed to apply route: {}", e);
    }
}
async fn add_route(handle: Handle, cidr: &IpCidr) -> Result<(), Error> {
const LOCAL_TABLE_ID: u8 = 255;
let iface_idx = handle
.link()
.get()
.match_name("lo".to_owned())
.execute()
.try_next()
.await?
.unwrap()
.header
.index;
let route = handle.route();
let (route_message, address_family, destination_prefix_length, route_address) = match cidr {
IpCidr::V4(v4) => (
RouteMessageBuilder::<Ipv4Addr>::new().build(),
AddressFamily::Inet,
v4.network_length(),
RouteAddress::Inet(v4.first_address()),
),
IpCidr::V6(v6) => (
RouteMessageBuilder::<Ipv6Addr>::new().build(),
AddressFamily::Inet6,
v6.network_length(),
RouteAddress::Inet6(v6.first_address()),
),
};
let mut routes = route.get(route_message).execute();
while let Some(route) = routes.try_next().await? {
let header = route.header;
tracing::trace!(
"route attributes: {:?}\nroute header: {:?}",
route.attributes,
header
);
if header.address_family == address_family
&& header.destination_prefix_length == destination_prefix_length
&& header.table == LOCAL_TABLE_ID
{
for attr in route.attributes.iter() {
if let RouteAttribute::Destination(dest) = attr {
if dest == &route_address {
tracing::info!("Route {} already exists on loopback interface", cidr);
return Ok(());
}
}
}
}
}
let route_message = match cidr {
IpCidr::V4(v4) => RouteMessageBuilder::<Ipv4Addr>::new()
.destination_prefix(v4.first_address(), v4.network_length())
.kind(RouteType::Local)
.protocol(RouteProtocol::Boot)
.scope(RouteScope::Universe)
.output_interface(iface_idx)
.priority(1024)
.table_id(LOCAL_TABLE_ID.into())
.build(),
IpCidr::V6(v6) => RouteMessageBuilder::<Ipv6Addr>::new()
.destination_prefix(v6.first_address(), v6.network_length())
.kind(RouteType::Local)
.protocol(RouteProtocol::Boot)
.scope(RouteScope::Universe)
.output_interface(iface_idx)
.priority(1024)
.table_id(LOCAL_TABLE_ID.into())
.build(),
};
route.add(route_message).execute().await?;
tracing::info!("Added route {} to loopback interface", cidr);
Ok(())
}
/// Sets the `net.ipv6.ip_nonlocal_bind` sysctl to `1` when `subnet` is an IPv6 range.
///
/// IPv4 subnets are ignored. A failure to apply the sysctl is logged at trace level only.
pub fn sysctl_ipv6_no_local_bind(subnet: &IpCidr) {
    if !subnet.is_ipv6() {
        return;
    }
    let outcome = execute_sysctl("net.ipv6.ip_nonlocal_bind", "1");
    if let Err(failure) = outcome {
        tracing::trace!("Failed to execute sysctl: {}", failure)
    }
}
/// Sets `net.ipv6.conf.all.disable_ipv6` to `0` when `subnet` is an IPv6 range.
///
/// IPv4 subnets are ignored. A failure to apply the sysctl is logged at trace level only.
pub fn sysctl_ipv6_all_enable_ipv6(subnet: &IpCidr) {
    if !subnet.is_ipv6() {
        return;
    }
    let outcome = execute_sysctl("net.ipv6.conf.all.disable_ipv6", "0");
    if let Err(failure) = outcome {
        tracing::trace!("Failed to execute sysctl: {}", failure)
    }
}
/// Sets the sysctl named `command` to `value`.
///
/// The current value is read and logged first. If it already equals `value`, ignoring
/// surrounding whitespace, no write is attempted.
///
/// # Errors
/// Returns the [`SysctlError`] produced while looking up, reading, or writing the key.
fn execute_sysctl(command: &str, value: &str) -> Result<(), SysctlError> {
    // The former `assert_eq!(command, ctl.name()?)` could panic a best-effort setup path and
    // added no safety; it has been removed.
    let ctl = <sysctl::Ctl as Sysctl>::new(command)?;
    let old_value = ctl.value_string()?;
    tracing::trace!("Sysctl '{}' old value: {}", command, old_value);
    // Skip the (possibly privileged) write when the key already holds the desired value.
    if old_value.trim() == value {
        return Ok(());
    }
    let new_value = ctl.set_value_string(value)?;
    tracing::trace!("Sysctl '{}' value: {}", command, new_value);
    Ok(())
}