use std::{os::unix::io::RawFd, path::Path, time::Duration};
use tracing::{instrument, warn};
use super::{
builder::MessageBuilder,
error::{Error, Result},
interface_ref::InterfaceRef,
message::{
MessageIter, NLM_F_ACK, NLM_F_DUMP, NLM_F_REQUEST, NLMSG_HDRLEN, NlMsgError, NlMsgHdr,
NlMsgType,
},
parse::FromNetlink,
protocol::{ProtocolState, Route},
socket::NetlinkSocket,
tc_handle::TcHandle,
};
/// A netlink connection speaking protocol `P`.
///
/// Bundles the socket, the protocol-specific state `P` (for `Connection<Generic>`
/// this holds the generic-netlink family cache) and an optional per-exchange
/// timeout that the request helpers apply.
pub struct Connection<P: ProtocolState> {
    /// Socket used for every request, reply and notification on this connection.
    socket: NetlinkSocket,
    /// Protocol-specific state; `P::default()` for freshly opened connections.
    state: P,
    /// Per-exchange timeout; `None` lets requests wait indefinitely.
    timeout: Option<Duration>,
}
impl<P: ProtocolState + Default> Connection<P> {
    /// Opens a connection for protocol `P` in the caller's current network
    /// namespace, with `P::default()` state and no timeout.
    ///
    /// # Errors
    ///
    /// Returns an error if the netlink socket cannot be created.
    #[instrument(level = "info", skip_all, fields(protocol = std::any::type_name::<P>()))]
    pub fn new() -> Result<Self> {
        let socket = NetlinkSocket::new(P::PROTOCOL)?;
        Ok(Self::from_parts(socket, P::default()))
    }

    /// Opens a connection for protocol `P` inside the network namespace
    /// referred to by the open descriptor `ns_fd`.
    ///
    /// # Errors
    ///
    /// Returns an error if the socket cannot be created in that namespace.
    // `ns_fd = ns_fd` is spelled out on purpose: in `#[instrument]` a bare
    // field name declares an *empty* field instead of capturing the argument,
    // so the descriptor would never appear on the span.
    #[instrument(level = "info", skip_all, fields(protocol = std::any::type_name::<P>(), ns_fd = ns_fd))]
    pub fn new_in_namespace(ns_fd: RawFd) -> Result<Self> {
        let socket = NetlinkSocket::new_in_namespace(P::PROTOCOL, ns_fd)?;
        Ok(Self::from_parts(socket, P::default()))
    }

    /// Opens a connection for protocol `P` inside the network namespace found
    /// at `ns_path` (for example a file under `/var/run/netns`).
    ///
    /// # Errors
    ///
    /// Returns an error if the namespace cannot be opened or the socket
    /// cannot be created in it.
    #[instrument(level = "info", skip_all, fields(protocol = std::any::type_name::<P>(), ns_path = %ns_path.as_ref().display()))]
    pub fn new_in_namespace_path<T: AsRef<Path>>(ns_path: T) -> Result<Self> {
        let socket = NetlinkSocket::new_in_namespace_path(P::PROTOCOL, ns_path)?;
        Ok(Self::from_parts(socket, P::default()))
    }
}
impl<P: ProtocolState> Connection<P> {
    /// Returns the underlying netlink socket.
    pub fn socket(&self) -> &NetlinkSocket {
        &self.socket
    }

    /// Returns the underlying netlink socket mutably (crate-internal).
    pub(crate) fn socket_mut(&mut self) -> &mut NetlinkSocket {
        &mut self.socket
    }

    /// Returns the protocol-specific state of this connection.
    pub fn state(&self) -> &P {
        &self.state
    }

    /// Sets the timeout applied to each request, ACK and dump exchange.
    ///
    /// An exchange that does not finish in time fails with [`Error::Timeout`].
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Clears the timeout so exchanges may wait indefinitely.
    pub fn no_timeout(mut self) -> Self {
        self.timeout = None;
        self
    }

    /// Returns the configured per-exchange timeout, if any.
    pub fn get_timeout(&self) -> Option<Duration> {
        self.timeout
    }

    /// Awaits `fut`, bounded by the configured timeout.
    ///
    /// # Errors
    ///
    /// [`Error::Timeout`] if the timeout elapses first; otherwise whatever
    /// `fut` resolves to.
    pub(crate) async fn with_timeout<F, T>(&self, fut: F) -> Result<T>
    where
        F: std::future::Future<Output = Result<T>>,
    {
        match self.timeout {
            Some(dur) => tokio::time::timeout(dur, fut)
                .await
                .map_err(|_| Error::Timeout)?,
            None => fut.await,
        }
    }

    /// Assembles a connection from an already-open socket and state, with no
    /// timeout.
    pub(crate) fn from_parts(socket: NetlinkSocket, state: P) -> Self {
        Self {
            socket,
            state,
            timeout: None,
        }
    }

    /// Sends a request and returns the raw datagram carrying its reply.
    ///
    /// # Errors
    ///
    /// Kernel error replies, I/O or parse errors, or [`Error::Timeout`].
    pub(crate) async fn send_request(&self, builder: MessageBuilder) -> Result<Vec<u8>> {
        self.with_timeout(self.send_request_inner(builder)).await
    }

    /// Sends a request and waits for the kernel's ACK.
    ///
    /// # Errors
    ///
    /// Kernel error replies, I/O or parse errors, or [`Error::Timeout`].
    pub(crate) async fn send_ack(&self, builder: MessageBuilder) -> Result<()> {
        self.with_timeout(self.send_ack_inner(builder)).await
    }

    /// Sends a dump request and collects every reply message until the dump
    /// ends. Each element is one complete message, netlink header included.
    ///
    /// # Errors
    ///
    /// Kernel error replies, I/O or parse errors, or [`Error::Timeout`].
    pub(crate) async fn send_dump(&self, builder: MessageBuilder) -> Result<Vec<Vec<u8>>> {
        self.with_timeout(self.send_dump_inner(builder)).await
    }

    /// Stamps `builder` with a fresh sequence number and this socket's port id,
    /// records the sequence on the current span, and sends the message.
    ///
    /// Returns the sequence number the replies will carry.
    async fn send_stamped(&self, mut builder: MessageBuilder) -> Result<u32> {
        let seq = self.socket.next_seq();
        builder.set_seq(seq);
        builder.set_pid(self.socket.pid());
        tracing::Span::current().record("seq", seq);
        let msg = builder.finish();
        self.socket.send(&msg).await?;
        Ok(seq)
    }

    /// Returns whether `data` holds a message with sequence `seq`. When
    /// `error_only` is set, only NLMSG_ERROR messages (ACKs and error replies)
    /// count.
    fn contains_reply(&self, data: &[u8], seq: u32, error_only: bool) -> Result<bool> {
        for result in MessageIter::new(data) {
            let (header, _) = result?;
            if header.nlmsg_seq == seq && (!error_only || header.is_error()) {
                return Ok(true);
            }
        }
        Ok(false)
    }

    /// Receives datagrams until one satisfies
    /// `contains_reply(.., seq, error_only)`, then returns it.
    ///
    /// Datagrams that are not part of this exchange are discarded. These
    /// include late replies or trailing ACKs to earlier requests (for example
    /// one abandoned by a timeout) and multicast notifications on a subscribed
    /// socket. Reading just one datagram would mistake those for the reply to
    /// the current request.
    async fn recv_reply(&self, seq: u32, error_only: bool) -> Result<Vec<u8>> {
        loop {
            let data = self.socket.recv_msg().await?;
            if self.contains_reply(&data, seq, error_only)? {
                return Ok(data);
            }
            tracing::trace!(
                expected_seq = seq,
                "discarding netlink datagram unrelated to this request"
            );
        }
    }

    #[instrument(level = "trace", skip_all, fields(seq))]
    async fn send_request_inner(&self, builder: MessageBuilder) -> Result<Vec<u8>> {
        let seq = self.send_stamped(builder).await?;
        let response = self.recv_reply(seq, false).await?;
        self.process_response(&response, seq).inspect_err(|e| {
            warn!(errno = ?e.errno(), "kernel returned error for request");
        })?;
        Ok(response)
    }

    #[instrument(level = "trace", skip_all, fields(seq))]
    async fn send_ack_inner(&self, builder: MessageBuilder) -> Result<()> {
        let seq = self.send_stamped(builder).await?;
        // Wait for the NLMSG_ERROR (ACK) itself. On a subscribed socket the
        // change notification for this request may arrive first, and it can
        // carry our own sequence number.
        let response = self.recv_reply(seq, true).await?;
        self.process_ack(&response, seq).inspect_err(|e| {
            warn!(errno = ?e.errno(), "kernel returned error for ack");
        })?;
        Ok(())
    }

    #[instrument(level = "trace", skip_all, fields(seq, responses))]
    async fn send_dump_inner(&self, builder: MessageBuilder) -> Result<Vec<Vec<u8>>> {
        let seq = self.send_stamped(builder).await?;
        let mut responses = Vec::new();
        'recv: loop {
            let data = self.socket.recv_msg().await?;
            for result in MessageIter::new(&data) {
                let (header, payload) = result?;
                if header.nlmsg_seq != seq {
                    continue;
                }
                if header.is_error() {
                    let err = NlMsgError::from_bytes(payload)?;
                    if !err.is_ack() {
                        return Err(Error::from_errno(err.error));
                    }
                    // An ACK for our sequence ends the exchange, as NLMSG_DONE
                    // does. It is a control message, not dump data, so it must
                    // not be returned to the caller as a response.
                    break 'recv;
                }
                if header.is_done() {
                    break 'recv;
                }
                // Recover the whole message (header + payload): the payload
                // slice is assumed to start right after the fixed-size header.
                let msg_len = header.nlmsg_len as usize;
                let msg_start = payload.as_ptr() as usize
                    - data.as_ptr() as usize
                    - std::mem::size_of::<NlMsgHdr>();
                if let Some(msg) = data.get(msg_start..msg_start + msg_len) {
                    responses.push(msg.to_vec());
                }
            }
        }
        tracing::Span::current().record("responses", responses.len());
        Ok(responses)
    }

    /// Checks a reply datagram for a kernel error addressed to `expected_seq`.
    ///
    /// Messages with other sequence numbers, and ACKs, are ignored.
    fn process_response(&self, data: &[u8], expected_seq: u32) -> Result<()> {
        for result in MessageIter::new(data) {
            let (header, payload) = result?;
            if header.nlmsg_seq != expected_seq {
                continue;
            }
            if header.is_error() {
                let err = NlMsgError::from_bytes(payload)?;
                if !err.is_ack() {
                    return Err(Error::from_errno(err.error));
                }
            }
        }
        Ok(())
    }

    /// Requires an NLMSG_ERROR for `expected_seq` in `data`.
    ///
    /// Returns `Ok(())` for a successful ACK and the kernel errno otherwise.
    /// Returns `InvalidMessage` if no such message is present.
    fn process_ack(&self, data: &[u8], expected_seq: u32) -> Result<()> {
        for result in MessageIter::new(data) {
            let (header, payload) = result?;
            if header.nlmsg_seq != expected_seq {
                continue;
            }
            if header.is_error() {
                let err = NlMsgError::from_bytes(payload)?;
                if !err.is_ack() {
                    return Err(Error::from_errno(err.error));
                }
                return Ok(());
            }
        }
        Err(Error::InvalidMessage("expected ACK message".into()))
    }
}
/// rtnetlink multicast groups a `Connection<Route>` can subscribe to.
///
/// Each variant maps to one `RTNLGRP_*` constant in `to_group`. The enum is
/// `#[non_exhaustive]` so more groups can be added without a breaking change.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum RtnetlinkGroup {
    /// `RTNLGRP_LINK`
    Link,
    /// `RTNLGRP_IPV4_IFADDR`
    Ipv4Addr,
    /// `RTNLGRP_IPV6_IFADDR`
    Ipv6Addr,
    /// `RTNLGRP_IPV4_ROUTE`
    Ipv4Route,
    /// `RTNLGRP_IPV6_ROUTE`
    Ipv6Route,
    /// `RTNLGRP_NEIGH`
    Neigh,
    /// `RTNLGRP_TC`
    Tc,
    /// `RTNLGRP_NSID`
    NsId,
    /// `RTNLGRP_IPV4_RULE`
    Ipv4Rule,
    /// `RTNLGRP_IPV6_RULE`
    Ipv6Rule,
}
impl RtnetlinkGroup {
    /// Translates this variant into the `RTNLGRP_*` group number passed to
    /// `add_membership`.
    fn to_group(self) -> u32 {
        use super::socket::rtnetlink_groups as grp;
        match self {
            Self::Ipv4Rule => grp::RTNLGRP_IPV4_RULE,
            Self::Ipv6Rule => grp::RTNLGRP_IPV6_RULE,
            Self::NsId => grp::RTNLGRP_NSID,
            Self::Tc => grp::RTNLGRP_TC,
            Self::Neigh => grp::RTNLGRP_NEIGH,
            Self::Ipv4Route => grp::RTNLGRP_IPV4_ROUTE,
            Self::Ipv6Route => grp::RTNLGRP_IPV6_ROUTE,
            Self::Ipv4Addr => grp::RTNLGRP_IPV4_IFADDR,
            Self::Ipv6Addr => grp::RTNLGRP_IPV6_IFADDR,
            Self::Link => grp::RTNLGRP_LINK,
        }
    }
}
impl Connection<Route> {
pub fn for_namespace(spec: super::namespace::NamespaceSpec<'_>) -> Result<Self> {
spec.connection()
}
#[instrument(level = "info", skip(self), fields(groups = ?groups))]
pub fn subscribe(&mut self, groups: &[RtnetlinkGroup]) -> Result<()> {
for group in groups {
self.socket.add_membership(group.to_group())?;
}
Ok(())
}
pub fn subscribe_all(&mut self) -> Result<()> {
self.subscribe(&[
RtnetlinkGroup::Link,
RtnetlinkGroup::Ipv4Addr,
RtnetlinkGroup::Ipv6Addr,
RtnetlinkGroup::Ipv4Route,
RtnetlinkGroup::Ipv6Route,
RtnetlinkGroup::Neigh,
RtnetlinkGroup::Tc,
])
}
pub async fn dump_typed<T: FromNetlink>(&self, msg_type: u16) -> Result<Vec<T>> {
let mut builder = dump_request(msg_type);
let mut header_buf = Vec::new();
T::write_dump_header(&mut header_buf);
builder.append_bytes(&header_buf);
let responses = self.send_dump(builder).await?;
let mut parsed = Vec::with_capacity(responses.len());
for response in responses {
if response.len() < NLMSG_HDRLEN {
continue;
}
let payload = &response[NLMSG_HDRLEN..];
if let Ok(msg) = T::from_bytes(payload) {
parsed.push(msg);
}
}
Ok(parsed)
}
pub fn parse_response<T: FromNetlink>(&self, response: &[u8]) -> Result<T> {
if response.len() < NLMSG_HDRLEN {
return Err(Error::Truncated {
expected: NLMSG_HDRLEN,
actual: response.len(),
});
}
let payload = &response[NLMSG_HDRLEN..];
T::from_bytes(payload)
}
}
/// Starts a dump request: `NLM_F_REQUEST | NLM_F_DUMP`.
pub(crate) fn dump_request(msg_type: u16) -> MessageBuilder {
    let flags = NLM_F_REQUEST | NLM_F_DUMP;
    MessageBuilder::new(msg_type, flags)
}

/// Starts a request that expects an ACK: `NLM_F_REQUEST | NLM_F_ACK`.
pub(crate) fn ack_request(msg_type: u16) -> MessageBuilder {
    let flags = NLM_F_REQUEST | NLM_F_ACK;
    MessageBuilder::new(msg_type, flags)
}

/// Starts an acknowledged create request.
///
/// `0x0400` is `NLM_F_CREATE` (create the object if it does not exist).
pub(crate) fn create_request(msg_type: u16) -> MessageBuilder {
    let flags = NLM_F_REQUEST | NLM_F_ACK | 0x0400;
    MessageBuilder::new(msg_type, flags)
}

/// Starts an acknowledged create-or-replace request.
///
/// `0x0400` is `NLM_F_CREATE` and `0x0100` is `NLM_F_REPLACE` (overwrite an
/// existing object).
pub(crate) fn replace_request(msg_type: u16) -> MessageBuilder {
    let flags = NLM_F_REQUEST | NLM_F_ACK | 0x0400 | 0x0100;
    MessageBuilder::new(msg_type, flags)
}
impl Connection<Route> {
    /// Starts a new request batch bound to this connection.
    pub fn batch(&self) -> super::batch::Batch<'_> {
        use super::batch::Batch;
        Batch::new(self)
    }
}
use super::messages::{
AddressMessage, LinkMessage, NeighborMessage, RouteMessage, RuleMessage, TcMessage,
};
impl Connection<Route> {
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_links"))]
pub async fn get_links(&self) -> Result<Vec<LinkMessage>> {
self.dump_typed(NlMsgType::RTM_GETLINK).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_link_by_name"))]
pub async fn get_link_by_name(
&self,
name: impl Into<InterfaceRef>,
) -> Result<Option<LinkMessage>> {
let iface = name.into();
match iface {
InterfaceRef::Name(ref name_str) => {
let links = self.get_links().await?;
Ok(links
.into_iter()
.find(|l| l.name.as_deref() == Some(name_str)))
}
InterfaceRef::Index(idx) => self.get_link_by_index(idx).await,
}
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_link_by_index"))]
pub async fn get_link_by_index(&self, index: u32) -> Result<Option<LinkMessage>> {
let links = self.get_links().await?;
Ok(links.into_iter().find(|l| l.ifindex() == index))
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "resolve_interface"))]
pub async fn resolve_interface(&self, iface: &InterfaceRef) -> Result<u32> {
match iface {
InterfaceRef::Index(idx) => Ok(*idx),
InterfaceRef::Name(name) => {
let link = self
.get_link_by_name(name)
.await?
.ok_or_else(|| Error::interface_not_found(name))?;
Ok(link.ifindex())
}
}
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "resolve_interface_opt"))]
pub async fn resolve_interface_opt(&self, iface: Option<&InterfaceRef>) -> Result<Option<u32>> {
match iface {
Some(iface) => Ok(Some(self.resolve_interface(iface).await?)),
None => Ok(None),
}
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_interface_names"))]
pub async fn get_interface_names(&self) -> Result<std::collections::HashMap<u32, String>> {
let links = self.get_links().await?;
Ok(links
.into_iter()
.filter_map(|l| l.name.clone().map(|n| (l.ifindex(), n)))
.collect())
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "interface_name"))]
pub async fn interface_name(&self, ifindex: u32) -> Result<Option<String>> {
let link = self.get_link_by_index(ifindex).await?;
Ok(link.and_then(|l| l.name))
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "interface_name_or"))]
pub async fn interface_name_or(&self, ifindex: u32, default: &str) -> Result<String> {
Ok(self
.interface_name(ifindex)
.await?
.unwrap_or_else(|| default.to_string()))
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_bond_info"))]
pub async fn get_bond_info(
&self,
iface: impl Into<InterfaceRef>,
) -> Result<crate::netlink::messages::BondInfo> {
let ifindex = self.resolve_interface(&iface.into()).await?;
let link = self
.get_link_by_index(ifindex)
.await?
.ok_or_else(|| Error::InvalidMessage("interface not found".into()))?;
link.bond_info()
.ok_or_else(|| Error::InvalidMessage("not a bond interface".into()))
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_bond_slaves"))]
pub async fn get_bond_slaves(
&self,
bond: impl Into<InterfaceRef>,
) -> Result<Vec<(LinkMessage, crate::netlink::messages::BondSlaveInfo)>> {
let bond_ifindex = self.resolve_interface(&bond.into()).await?;
let all_links = self.get_links().await?;
let mut slaves = Vec::new();
for link in all_links {
if link.master() == Some(bond_ifindex)
&& let Some(info) = link.bond_slave_info()
{
slaves.push((link, info));
}
}
Ok(slaves)
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_addresses"))]
pub async fn get_addresses(&self) -> Result<Vec<AddressMessage>> {
self.dump_typed(NlMsgType::RTM_GETADDR).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_addresses_by_name"))]
pub async fn get_addresses_by_name(
&self,
iface: impl Into<InterfaceRef>,
) -> Result<Vec<AddressMessage>> {
let ifindex = self.resolve_interface(&iface.into()).await?;
self.get_addresses_by_index(ifindex).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_addresses_by_index"))]
pub async fn get_addresses_by_index(&self, ifindex: u32) -> Result<Vec<AddressMessage>> {
let addresses = self.get_addresses().await?;
Ok(addresses
.into_iter()
.filter(|a| a.ifindex() == ifindex)
.collect())
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_address_by_ip"))]
pub async fn get_address_by_ip(
&self,
addr: std::net::IpAddr,
) -> Result<Option<AddressMessage>> {
let addresses = self.get_addresses().await?;
Ok(addresses.into_iter().find(|a| a.address == Some(addr)))
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_routes"))]
pub async fn get_routes(&self) -> Result<Vec<RouteMessage>> {
self.dump_typed(NlMsgType::RTM_GETROUTE).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_routes_for_table"))]
pub async fn get_routes_for_table(&self, table_id: u32) -> Result<Vec<RouteMessage>> {
let routes = self.get_routes().await?;
Ok(routes
.into_iter()
.filter(|r| r.table_id() == table_id)
.collect())
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_route_v4"))]
pub async fn get_route_v4(
&self,
destination: std::net::Ipv4Addr,
prefix_len: u8,
) -> Result<Option<RouteMessage>> {
use crate::netlink::types::route::{RtMsg, RtaAttr};
let mut builder = MessageBuilder::new(NlMsgType::RTM_GETROUTE, NLM_F_REQUEST);
let rtmsg = RtMsg::new()
.with_family(libc::AF_INET as u8)
.with_dst_len(prefix_len);
builder.append(&rtmsg);
builder.append_attr(RtaAttr::Dst as u16, &destination.octets());
match self.send_request(builder).await {
Ok(response) => {
if response.len() >= NLMSG_HDRLEN {
let payload = &response[NLMSG_HDRLEN..];
if let Ok(msg) = RouteMessage::from_bytes(payload) {
return Ok(Some(msg));
}
}
Ok(None)
}
Err(e) if e.is_not_found() => Ok(None),
Err(e) => Err(e),
}
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_route_v6"))]
pub async fn get_route_v6(
&self,
destination: std::net::Ipv6Addr,
prefix_len: u8,
) -> Result<Option<RouteMessage>> {
use crate::netlink::types::route::{RtMsg, RtaAttr};
let mut builder = MessageBuilder::new(NlMsgType::RTM_GETROUTE, NLM_F_REQUEST);
let rtmsg = RtMsg::new()
.with_family(libc::AF_INET6 as u8)
.with_dst_len(prefix_len);
builder.append(&rtmsg);
builder.append_attr(RtaAttr::Dst as u16, &destination.octets());
match self.send_request(builder).await {
Ok(response) => {
if response.len() >= NLMSG_HDRLEN {
let payload = &response[NLMSG_HDRLEN..];
if let Ok(msg) = RouteMessage::from_bytes(payload) {
return Ok(Some(msg));
}
}
Ok(None)
}
Err(e) if e.is_not_found() => Ok(None),
Err(e) => Err(e),
}
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_neighbors"))]
pub async fn get_neighbors(&self) -> Result<Vec<NeighborMessage>> {
self.dump_typed(NlMsgType::RTM_GETNEIGH).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_neighbors_by_name"))]
pub async fn get_neighbors_by_name(
&self,
iface: impl Into<InterfaceRef>,
) -> Result<Vec<NeighborMessage>> {
let ifindex = self.resolve_interface(&iface.into()).await?;
self.get_neighbors_by_index(ifindex).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_rules"))]
pub async fn get_rules(&self) -> Result<Vec<RuleMessage>> {
self.dump_typed(NlMsgType::RTM_GETRULE).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_rules_for_family"))]
pub async fn get_rules_for_family(&self, family: u8) -> Result<Vec<RuleMessage>> {
let rules = self.get_rules().await?;
Ok(rules.into_iter().filter(|r| r.family() == family).collect())
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_rules_v4"))]
pub async fn get_rules_v4(&self) -> Result<Vec<RuleMessage>> {
self.get_rules_for_family(libc::AF_INET as u8).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_rules_v6"))]
pub async fn get_rules_v6(&self) -> Result<Vec<RuleMessage>> {
self.get_rules_for_family(libc::AF_INET6 as u8).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "add_rule"))]
pub async fn add_rule(&self, rule: super::rule::RuleBuilder) -> Result<()> {
let builder = rule.build()?;
self.send_ack(builder)
.await
.map_err(|e| e.with_context("add_rule"))
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "del_rule"))]
pub async fn del_rule(&self, rule: super::rule::RuleBuilder) -> Result<()> {
let builder = rule.build_delete()?;
self.send_ack(builder)
.await
.map_err(|e| e.with_context("del_rule"))
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "del_rule_by_priority"))]
pub async fn del_rule_by_priority(&self, family: u8, priority: u32) -> Result<()> {
let rule = super::rule::RuleBuilder::new(family).priority(priority);
self.del_rule(rule).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "flush_rules"))]
pub async fn flush_rules(&self, family: u8) -> Result<()> {
let rules = self.get_rules_for_family(family).await?;
for rule in rules {
if rule.priority == 0 || rule.priority == 32766 || rule.priority == 32767 {
continue;
}
let _ = self.del_rule_by_priority(family, rule.priority).await;
}
Ok(())
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_qdiscs"))]
pub async fn get_qdiscs(&self) -> Result<Vec<TcMessage>> {
self.dump_typed(NlMsgType::RTM_GETQDISC).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_qdiscs_by_name"))]
pub async fn get_qdiscs_by_name(
&self,
iface: impl Into<InterfaceRef>,
) -> Result<Vec<TcMessage>> {
let ifindex = self.resolve_interface(&iface.into()).await?;
self.get_qdiscs_by_index(ifindex).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_qdiscs_by_index"))]
pub async fn get_qdiscs_by_index(&self, ifindex: u32) -> Result<Vec<TcMessage>> {
let qdiscs = self.get_qdiscs().await?;
Ok(qdiscs
.into_iter()
.filter(|q| q.ifindex() == ifindex)
.collect())
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_classes"))]
pub async fn get_classes(&self) -> Result<Vec<TcMessage>> {
self.dump_typed(NlMsgType::RTM_GETTCLASS).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_classes_by_name"))]
pub async fn get_classes_by_name(
&self,
iface: impl Into<InterfaceRef>,
) -> Result<Vec<TcMessage>> {
let ifindex = self.resolve_interface(&iface.into()).await?;
self.get_classes_by_index(ifindex).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_classes_by_index"))]
pub async fn get_classes_by_index(&self, ifindex: u32) -> Result<Vec<TcMessage>> {
let classes = self.get_classes().await?;
Ok(classes
.into_iter()
.filter(|c| c.ifindex() == ifindex)
.collect())
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_filters"))]
pub async fn get_filters(&self) -> Result<Vec<TcMessage>> {
self.dump_typed(NlMsgType::RTM_GETTFILTER).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_filters_by_name"))]
pub async fn get_filters_by_name(
&self,
iface: impl Into<InterfaceRef>,
) -> Result<Vec<TcMessage>> {
let ifindex = self.resolve_interface(&iface.into()).await?;
self.get_filters_by_index(ifindex).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_filters_by_index"))]
pub async fn get_filters_by_index(&self, ifindex: u32) -> Result<Vec<TcMessage>> {
let filters = self.get_filters().await?;
Ok(filters
.into_iter()
.filter(|f| f.ifindex() == ifindex)
.collect())
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_filters_by_parent"))]
pub async fn get_filters_by_parent(
&self,
iface: impl Into<InterfaceRef>,
parent: TcHandle,
) -> Result<Vec<TcMessage>> {
let ifindex = self.resolve_interface(&iface.into()).await?;
self.get_filters_by_parent_index(ifindex, parent).await
}
#[tracing::instrument(
level = "debug",
skip_all,
fields(method = "get_filters_by_parent_index")
)]
pub async fn get_filters_by_parent_index(
&self,
ifindex: u32,
parent: TcHandle,
) -> Result<Vec<TcMessage>> {
let filters = self.get_filters_by_index(ifindex).await?;
Ok(filters
.into_iter()
.filter(|f| f.parent() == parent)
.collect())
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_tc_chains"))]
pub async fn get_tc_chains(
&self,
ifname: impl Into<InterfaceRef>,
parent: TcHandle,
) -> Result<Vec<u32>> {
let ifindex = self.resolve_interface(&ifname.into()).await?;
self.get_tc_chains_by_index(ifindex, parent).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_tc_chains_by_index"))]
pub async fn get_tc_chains_by_index(&self, ifindex: u32, parent: TcHandle) -> Result<Vec<u32>> {
use super::types::tc::TcMsg;
let tcmsg = TcMsg::new()
.with_ifindex(ifindex as i32)
.with_parent(parent.as_raw());
let mut builder = dump_request(NlMsgType::RTM_GETCHAIN);
builder.append(&tcmsg);
let responses = self.send_dump(builder).await?;
let mut chains = Vec::new();
for response in responses {
if response.len() < NLMSG_HDRLEN {
continue;
}
let payload = &response[NLMSG_HDRLEN..];
if let Ok(tc) = TcMessage::from_bytes(payload)
&& let Some(chain) = tc.chain()
{
chains.push(chain);
}
}
Ok(chains)
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "add_tc_chain"))]
pub async fn add_tc_chain(
&self,
ifname: impl Into<InterfaceRef>,
parent: TcHandle,
chain: u32,
) -> Result<()> {
let ifindex = self.resolve_interface(&ifname.into()).await?;
self.add_tc_chain_by_index(ifindex, parent, chain).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "add_tc_chain_by_index"))]
pub async fn add_tc_chain_by_index(
&self,
ifindex: u32,
parent: TcHandle,
chain: u32,
) -> Result<()> {
use super::types::tc::{TcMsg, TcaAttr};
let tcmsg = TcMsg::new()
.with_ifindex(ifindex as i32)
.with_parent(parent.as_raw());
let mut builder = create_request(NlMsgType::RTM_NEWCHAIN);
builder.append(&tcmsg);
builder.append_attr_u32(TcaAttr::Chain as u16, chain);
self.send_ack(builder)
.await
.map_err(|e| e.with_context("add_tc_chain"))
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "del_tc_chain"))]
pub async fn del_tc_chain(
&self,
ifname: impl Into<InterfaceRef>,
parent: TcHandle,
chain: u32,
) -> Result<()> {
let ifindex = self.resolve_interface(&ifname.into()).await?;
self.del_tc_chain_by_index(ifindex, parent, chain).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "del_tc_chain_by_index"))]
pub async fn del_tc_chain_by_index(
&self,
ifindex: u32,
parent: TcHandle,
chain: u32,
) -> Result<()> {
use super::types::tc::{TcMsg, TcaAttr};
let tcmsg = TcMsg::new()
.with_ifindex(ifindex as i32)
.with_parent(parent.as_raw());
let mut builder = create_request(NlMsgType::RTM_DELCHAIN);
builder.append(&tcmsg);
builder.append_attr_u32(TcaAttr::Chain as u16, chain);
self.send_ack(builder)
.await
.map_err(|e| e.with_context("del_tc_chain"))
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_root_qdisc_by_name"))]
pub async fn get_root_qdisc_by_name(
&self,
iface: impl Into<InterfaceRef>,
) -> Result<Option<TcMessage>> {
let ifindex = self.resolve_interface(&iface.into()).await?;
self.get_root_qdisc_by_index(ifindex).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_root_qdisc_by_index"))]
pub async fn get_root_qdisc_by_index(&self, ifindex: u32) -> Result<Option<TcMessage>> {
let qdiscs = self.get_qdiscs().await?;
Ok(qdiscs
.into_iter()
.find(|q| q.ifindex() == ifindex && q.is_root()))
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_qdisc_by_handle"))]
pub async fn get_qdisc_by_handle(
&self,
ifname: &str,
handle: TcHandle,
) -> Result<Option<TcMessage>> {
let ifindex = self
.resolve_interface(&InterfaceRef::Name(ifname.to_string()))
.await?;
self.get_qdisc_by_handle_index(ifindex, handle).await
}
#[tracing::instrument(
level = "debug",
skip_all,
fields(method = "get_qdisc_by_handle_index")
)]
pub async fn get_qdisc_by_handle_index(
&self,
ifindex: u32,
handle: TcHandle,
) -> Result<Option<TcMessage>> {
let qdiscs = self.get_qdiscs().await?;
Ok(qdiscs
.into_iter()
.find(|q| q.ifindex() == ifindex && q.handle() == handle))
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_netem_by_name"))]
pub async fn get_netem_by_name(
&self,
iface: impl Into<InterfaceRef>,
) -> Result<Option<super::tc_options::NetemOptions>> {
let ifindex = self.resolve_interface(&iface.into()).await?;
self.get_netem_by_index(ifindex).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_netem_by_index"))]
pub async fn get_netem_by_index(
&self,
ifindex: u32,
) -> Result<Option<super::tc_options::NetemOptions>> {
use super::tc_options::QdiscOptions;
let root = self.get_root_qdisc_by_index(ifindex).await?;
Ok(match root.and_then(|q| q.options()) {
Some(QdiscOptions::Netem(opts)) => Some(opts),
_ => None,
})
}
}
use super::types::link::{IfInfoMsg, iff};
impl Connection<Route> {
#[tracing::instrument(level = "debug", skip_all, fields(method = "set_link_up"))]
pub async fn set_link_up(&self, iface: impl Into<InterfaceRef>) -> Result<()> {
let ifindex = self.resolve_interface(&iface.into()).await?;
self.set_link_up_by_index(ifindex).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "set_link_up_by_index"))]
pub async fn set_link_up_by_index(&self, ifindex: u32) -> Result<()> {
self.set_link_state_by_index(ifindex, true).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "set_link_down"))]
pub async fn set_link_down(&self, iface: impl Into<InterfaceRef>) -> Result<()> {
let ifindex = self.resolve_interface(&iface.into()).await?;
self.set_link_down_by_index(ifindex).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "set_link_down_by_index"))]
pub async fn set_link_down_by_index(&self, ifindex: u32) -> Result<()> {
self.set_link_state_by_index(ifindex, false).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "set_link_state"))]
pub async fn set_link_state(&self, iface: impl Into<InterfaceRef>, up: bool) -> Result<()> {
let ifindex = self.resolve_interface(&iface.into()).await?;
self.set_link_state_by_index(ifindex, up).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "set_link_state_by_index"))]
pub async fn set_link_state_by_index(&self, ifindex: u32, up: bool) -> Result<()> {
let mut ifinfo = IfInfoMsg::new().with_index(ifindex as i32);
if up {
ifinfo.ifi_flags = iff::UP;
ifinfo.ifi_change = iff::UP;
} else {
ifinfo.ifi_flags = 0;
ifinfo.ifi_change = iff::UP;
}
let mut builder = ack_request(NlMsgType::RTM_SETLINK);
builder.append(&ifinfo);
let state = if up { "up" } else { "down" };
self.send_ack(builder).await.map_err(|e| {
if e.is_not_found() {
Error::InterfaceNotFound {
name: format!("ifindex {ifindex}"),
}
} else {
e.with_context(format!("set_link_{state}(ifindex {ifindex})"))
}
})
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "set_link_mtu"))]
pub async fn set_link_mtu(&self, iface: impl Into<InterfaceRef>, mtu: u32) -> Result<()> {
let ifindex = self.resolve_interface(&iface.into()).await?;
self.set_link_mtu_by_index(ifindex, mtu).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "set_link_mtu_by_index"))]
pub async fn set_link_mtu_by_index(&self, ifindex: u32, mtu: u32) -> Result<()> {
use super::types::link::IflaAttr;
let ifinfo = IfInfoMsg::new().with_index(ifindex as i32);
let mut builder = ack_request(NlMsgType::RTM_SETLINK);
builder.append(&ifinfo);
builder.append_attr_u32(IflaAttr::Mtu as u16, mtu);
self.send_ack(builder)
.await
.map_err(|e| e.with_context("set_link_mtu"))
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "del_link"))]
pub async fn del_link(&self, iface: impl Into<InterfaceRef>) -> Result<()> {
let ifindex = self.resolve_interface(&iface.into()).await?;
self.del_link_by_index(ifindex).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "del_link_by_index"))]
pub async fn del_link_by_index(&self, ifindex: u32) -> Result<()> {
let ifinfo = IfInfoMsg::new().with_index(ifindex as i32);
let mut builder = ack_request(NlMsgType::RTM_DELLINK);
builder.append(&ifinfo);
self.send_ack(builder).await.map_err(|e| {
if e.is_not_found() {
Error::InterfaceNotFound {
name: format!("ifindex {ifindex}"),
}
} else {
e.with_context(format!("del_link(ifindex {ifindex})"))
}
})
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "set_link_txqlen"))]
pub async fn set_link_txqlen(&self, iface: impl Into<InterfaceRef>, txqlen: u32) -> Result<()> {
let ifindex = self.resolve_interface(&iface.into()).await?;
self.set_link_txqlen_by_index(ifindex, txqlen).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "set_link_txqlen_by_index"))]
pub async fn set_link_txqlen_by_index(&self, ifindex: u32, txqlen: u32) -> Result<()> {
use super::types::link::IflaAttr;
let ifinfo = IfInfoMsg::new().with_index(ifindex as i32);
let mut builder = ack_request(NlMsgType::RTM_SETLINK);
builder.append(&ifinfo);
builder.append_attr_u32(IflaAttr::TxqLen as u16, txqlen);
self.send_ack(builder)
.await
.map_err(|e| e.with_context("set_link_txqlen"))
}
}
use super::{
messages::NsIdMessage,
types::nsid::{RTM_GETNSID, RtGenMsg, netnsa},
};
impl Connection<Route> {
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_nsid"))]
pub async fn get_nsid(&self, ns_fd: RawFd) -> Result<u32> {
let mut builder = ack_request(RTM_GETNSID);
builder.append(&RtGenMsg::new());
builder.append_bytes(&[0u8; 3]);
builder.append_attr_u32(netnsa::FD, ns_fd as u32);
let response = self.send_request(builder).await?;
if response.len() >= super::message::NLMSG_HDRLEN {
let payload = &response[super::message::NLMSG_HDRLEN..];
if let Some(nsid_msg) = NsIdMessage::parse(payload)
&& let Some(nsid) = nsid_msg.nsid
{
return Ok(nsid);
}
}
Err(Error::InvalidMessage(
"namespace ID not found in response".into(),
))
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_nsid_for_pid"))]
pub async fn get_nsid_for_pid(&self, pid: u32) -> Result<u32> {
let mut builder = ack_request(RTM_GETNSID);
builder.append(&RtGenMsg::new());
builder.append_bytes(&[0u8; 3]);
builder.append_attr_u32(netnsa::PID, pid);
let response = self.send_request(builder).await?;
if response.len() >= super::message::NLMSG_HDRLEN {
let payload = &response[super::message::NLMSG_HDRLEN..];
if let Some(nsid_msg) = NsIdMessage::parse(payload)
&& let Some(nsid) = nsid_msg.nsid
{
return Ok(nsid);
}
}
Err(Error::InvalidMessage(
"namespace ID not found in response".into(),
))
}
}
use std::collections::HashMap;
use super::{
genl::{
CtrlAttr, CtrlAttrMcastGrp, CtrlCmd, FamilyInfo, GENL_HDRLEN, GENL_ID_CTRL, GenlMsgHdr,
},
protocol::Generic,
};
impl Connection<Generic> {
    /// Resolves a generic netlink family by name, consulting the per-connection cache first.
    ///
    /// On a cache miss the family is queried from the controller family (`GENL_ID_CTRL`) and
    /// the result is cached under `name`. The span fields `id` and `cached` record the resolved
    /// family ID and whether the cache answered.
    ///
    /// # Errors
    /// - [`Error::FamilyNotFound`] when the kernel answers `ENOENT` or no family message arrives.
    /// - [`Error::Timeout`] when the connection timeout elapses.
    /// - Transport or parse errors otherwise.
    #[instrument(level = "info", skip(self), fields(family = %name, id, cached))]
    pub async fn get_family(&self, name: &str) -> Result<FamilyInfo> {
        if let Some(info) = self.cached_family(name) {
            let span = tracing::Span::current();
            span.record("id", info.id);
            span.record("cached", true);
            return Ok(info);
        }
        let info = self.query_family(name).await?;
        let span = tracing::Span::current();
        span.record("id", info.id);
        span.record("cached", false);
        // A poisoned lock only means another thread panicked while holding it. Every write to
        // the cache is a single insert/clear, so the map is still consistent and usable.
        self.state
            .cache
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .insert(name.to_owned(), info.clone());
        Ok(info)
    }

    /// Returns only the numeric ID of the family `name` (see [`Self::get_family`]).
    #[tracing::instrument(level = "debug", skip_all, fields(method = "get_family_id"))]
    pub async fn get_family_id(&self, name: &str) -> Result<u16> {
        Ok(self.get_family(name).await?.id)
    }

    /// Drops every cached family, so the next [`Self::get_family`] call re-queries the kernel.
    pub fn clear_cache(&self) {
        self.state
            .cache
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clear();
    }

    /// Returns a clone of the cached [`FamilyInfo`] for `name`, if present.
    ///
    /// The read guard lives only inside this synchronous call, so it is never held across an
    /// `.await`.
    fn cached_family(&self, name: &str) -> Option<FamilyInfo> {
        self.state
            .cache
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .get(name)
            .cloned()
    }

    /// Asks the controller family to describe `name`, honouring the connection timeout.
    async fn query_family(&self, name: &str) -> Result<FamilyInfo> {
        let mut builder = MessageBuilder::new(GENL_ID_CTRL, NLM_F_REQUEST | NLM_F_ACK);
        builder.append(&GenlMsgHdr::new(CtrlCmd::GetFamily as u8, 1));
        builder.append_attr_str(CtrlAttr::FamilyName as u16, name);
        self.with_timeout(self.exchange_family(builder, name)).await
    }

    /// Sends the `GetFamily` request in `builder` and parses the family description out of the
    /// first datagram answering it.
    ///
    /// The request carries `NLM_F_ACK`. When the ACK arrives in a later datagram than the reply,
    /// it is drained here; otherwise it would be read as the answer to the next request.
    async fn exchange_family(&self, builder: MessageBuilder, name: &str) -> Result<FamilyInfo> {
        let seq = self.send_genl(builder).await?;
        let (response, terminated) = self.recv_for_seq(seq).await?;
        let info = self.parse_family_response(&response, seq, name)?;
        if !terminated {
            self.drain_until_ack(seq).await?;
        }
        Ok(info)
    }

    /// Parses the first data message for `seq` in `data` as a family description.
    ///
    /// Messages with a different sequence number, ACKs and `NLMSG_DONE` are skipped. A kernel
    /// `ENOENT` becomes [`Error::FamilyNotFound`]; other kernel errors go through
    /// [`Error::from_errno`].
    fn parse_family_response(&self, data: &[u8], seq: u32, name: &str) -> Result<FamilyInfo> {
        for result in MessageIter::new(data) {
            let (header, payload) = result?;
            if header.nlmsg_seq != seq {
                continue;
            }
            if header.is_error() {
                let err = NlMsgError::from_bytes(payload)?;
                if !err.is_ack() {
                    if err.error == -libc::ENOENT {
                        return Err(Error::FamilyNotFound {
                            name: name.to_string(),
                        });
                    }
                    return Err(Error::from_errno(err.error));
                }
                continue;
            }
            if header.is_done() {
                continue;
            }
            if payload.len() < GENL_HDRLEN {
                return Err(Error::InvalidMessage("GENL header too short".into()));
            }
            return self.parse_family_attrs(&payload[GENL_HDRLEN..]);
        }
        Err(Error::FamilyNotFound {
            name: name.to_string(),
        })
    }

    /// Builds a [`FamilyInfo`] from the controller attributes that follow the GENL header.
    ///
    /// Unknown attributes are ignored. The family ID is mandatory; the other fields default to 0
    /// or to an empty group map.
    fn parse_family_attrs(&self, data: &[u8]) -> Result<FamilyInfo> {
        use super::attr::{AttrIter, get};
        let mut id: Option<u16> = None;
        let mut version: u8 = 0;
        let mut hdr_size: u32 = 0;
        let mut max_attr: u32 = 0;
        let mut mcast_groups = HashMap::new();
        for (attr_type, payload) in AttrIter::new(data) {
            match attr_type {
                t if t == CtrlAttr::FamilyId as u16 => {
                    id = Some(get::u16_ne(payload)?);
                }
                t if t == CtrlAttr::Version as u16 => {
                    // The attribute is read as a u32, but `FamilyInfo` stores an 8-bit version
                    // (the width `command` accepts). Values above 255 would be truncated.
                    version = get::u32_ne(payload)? as u8;
                }
                t if t == CtrlAttr::HdrSize as u16 => {
                    hdr_size = get::u32_ne(payload)?;
                }
                t if t == CtrlAttr::MaxAttr as u16 => {
                    max_attr = get::u32_ne(payload)?;
                }
                t if t == CtrlAttr::McastGroups as u16 => {
                    mcast_groups = self.parse_mcast_groups(payload)?;
                }
                _ => {}
            }
        }
        let id = id.ok_or_else(|| Error::InvalidMessage("missing family ID".into()))?;
        Ok(FamilyInfo {
            id,
            version,
            hdr_size,
            max_attr,
            mcast_groups,
        })
    }

    /// Parses the nested multicast-group list into a `name -> group id` map.
    ///
    /// Groups missing either their name or their ID are skipped.
    fn parse_mcast_groups(&self, data: &[u8]) -> Result<HashMap<String, u32>> {
        use super::attr::{AttrIter, get};
        let mut groups = HashMap::new();
        for (_group_idx, group_payload) in AttrIter::new(data) {
            let mut name: Option<String> = None;
            let mut grp_id: Option<u32> = None;
            for (attr_type, payload) in AttrIter::new(group_payload) {
                match attr_type {
                    t if t == CtrlAttrMcastGrp::Name as u16 => {
                        name = Some(get::string(payload)?.to_string());
                    }
                    t if t == CtrlAttrMcastGrp::Id as u16 => {
                        grp_id = Some(get::u32_ne(payload)?);
                    }
                    _ => {}
                }
            }
            if let (Some(name), Some(id)) = (name, grp_id) {
                groups.insert(name, id);
            }
        }
        Ok(groups)
    }

    /// Sends generic netlink command `cmd` to `family_id` and waits for the kernel to ACK it.
    ///
    /// `build_attrs` appends the command's attributes after the GENL header. The returned bytes
    /// are the first datagram answering this request, in raw netlink framing: the reply, or the
    /// ACK when the command produces no reply. A trailing ACK that arrives separately is consumed
    /// so it cannot leak into the next request.
    ///
    /// # Errors
    /// - The kernel's error when it rejects the command.
    /// - [`Error::Timeout`] when the connection timeout elapses.
    /// - Transport or framing errors otherwise.
    #[tracing::instrument(level = "debug", skip_all, fields(method = "command"))]
    pub async fn command(
        &self,
        family_id: u16,
        cmd: u8,
        version: u8,
        build_attrs: impl FnOnce(&mut MessageBuilder),
    ) -> Result<Vec<u8>> {
        let mut builder = MessageBuilder::new(family_id, NLM_F_REQUEST | NLM_F_ACK);
        builder.append(&GenlMsgHdr::new(cmd, version));
        build_attrs(&mut builder);
        self.with_timeout(self.transact(builder)).await
    }

    /// Sends a dump request for `cmd` and collects every reply payload until `NLMSG_DONE`.
    ///
    /// Each returned element is one message's payload: the GENL header followed by its
    /// attributes.
    ///
    /// # Errors
    /// - The kernel's error if it reports a non-zero error code for this dump.
    /// - [`Error::Timeout`] when the connection timeout elapses.
    /// - Transport or framing errors otherwise.
    #[tracing::instrument(level = "debug", skip_all, fields(method = "dump_command"))]
    pub async fn dump_command(
        &self,
        family_id: u16,
        cmd: u8,
        version: u8,
        build_attrs: impl FnOnce(&mut MessageBuilder),
    ) -> Result<Vec<Vec<u8>>> {
        let mut builder = MessageBuilder::new(family_id, NLM_F_REQUEST | NLM_F_DUMP);
        builder.append(&GenlMsgHdr::new(cmd, version));
        build_attrs(&mut builder);
        self.with_timeout(self.transact_dump(builder)).await
    }

    /// Stamps `builder` with a fresh sequence number and this socket's port ID, then sends it.
    ///
    /// Returns the sequence number that replies to this request carry.
    async fn send_genl(&self, mut builder: MessageBuilder) -> Result<u32> {
        let seq = self.socket.next_seq();
        builder.set_seq(seq);
        builder.set_pid(self.socket.pid());
        let msg = builder.finish();
        self.socket.send(&msg).await?;
        Ok(seq)
    }

    /// Sends an ACK-requesting message and returns the first datagram answering it.
    ///
    /// If that datagram does not already contain the ACK, the ACK is drained afterwards.
    async fn transact(&self, builder: MessageBuilder) -> Result<Vec<u8>> {
        let seq = self.send_genl(builder).await?;
        let (response, terminated) = self.recv_for_seq(seq).await?;
        self.process_genl_response(&response, seq)?;
        if !terminated {
            self.drain_until_ack(seq).await?;
        }
        Ok(response)
    }

    /// Sends a dump request and gathers message payloads for it until `NLMSG_DONE`.
    async fn transact_dump(&self, builder: MessageBuilder) -> Result<Vec<Vec<u8>>> {
        let seq = self.send_genl(builder).await?;
        let mut responses = Vec::new();
        loop {
            let data = self.socket.recv_msg().await?;
            for result in MessageIter::new(&data) {
                let (header, payload) = result?;
                if header.nlmsg_seq != seq {
                    continue;
                }
                if header.is_error() {
                    let err = NlMsgError::from_bytes(payload)?;
                    if !err.is_ack() {
                        return Err(Error::from_errno(err.error));
                    }
                    continue;
                }
                if header.is_done() {
                    return Ok(responses);
                }
                responses.push(payload.to_vec());
            }
        }
    }

    /// Receives datagrams until one carries a message for `seq`.
    ///
    /// Unrelated datagrams are discarded, for example leftovers from an earlier request that
    /// was cancelled before its replies were read. Returns that datagram and whether it already
    /// holds the terminating `NLMSG_ERROR` (ACK or error) for `seq`.
    async fn recv_for_seq(&self, seq: u32) -> Result<(Vec<u8>, bool)> {
        loop {
            let data = self.socket.recv_msg().await?;
            let (relevant, terminated) = Self::classify_datagram(&data, seq)?;
            if relevant {
                return Ok((data, terminated));
            }
        }
    }

    /// Consumes datagrams until the terminating ACK or error for `seq` arrives.
    ///
    /// # Errors
    /// Returns the kernel's error when the terminating message carries a non-zero error code.
    async fn drain_until_ack(&self, seq: u32) -> Result<()> {
        loop {
            let data = self.socket.recv_msg().await?;
            self.process_genl_response(&data, seq)?;
            let (_, terminated) = Self::classify_datagram(&data, seq)?;
            if terminated {
                return Ok(());
            }
        }
    }

    /// Reports whether `data` holds any message for `seq`, and whether one of them is an
    /// `NLMSG_ERROR` (ACK or error code), which ends a non-dump request.
    fn classify_datagram(data: &[u8], seq: u32) -> Result<(bool, bool)> {
        let mut relevant = false;
        for result in MessageIter::new(data) {
            let (header, _) = result?;
            if header.nlmsg_seq != seq {
                continue;
            }
            relevant = true;
            if header.is_error() {
                return Ok((true, true));
            }
        }
        Ok((relevant, false))
    }

    /// Converts a non-zero kernel error code for `seq` in `data` into an `Err`.
    ///
    /// Returns `Ok(())` when `data` has no messages for `seq` at all, so callers must check
    /// relevance separately (see [`Self::recv_for_seq`]).
    fn process_genl_response(&self, data: &[u8], seq: u32) -> Result<()> {
        for result in MessageIter::new(data) {
            let (header, payload) = result?;
            if header.nlmsg_seq != seq {
                continue;
            }
            if header.is_error() {
                let err = NlMsgError::from_bytes(payload)?;
                if !err.is_ack() {
                    return Err(Error::from_errno(err.error));
                }
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod send_sync_tests {
    use super::*;

    /// Compiles only when `T` is both `Send` and `Sync`.
    fn require_send_and_sync<T: Send + Sync>() {}

    /// Compile-time check that a route connection can be moved to and shared between threads.
    #[test]
    fn connection_is_send_sync() {
        require_send_and_sync::<Connection<Route>>();
    }
}