use crate::{error::GenetlinkError, GenetlinkHandle};
use futures::{future::Either, StreamExt};
use netlink_packet_core::{NetlinkMessage, NetlinkPayload, NLM_F_REQUEST};
use netlink_packet_generic::{
ctrl::{nlas::GenlCtrlAttrs, GenlCtrl, GenlCtrlCmd},
GenlMessage,
};
use std::{collections::HashMap, future::Future};
/// Resolves generic netlink family names to their numeric family ids and
/// remembers every id it has successfully resolved.
///
/// Lookups go through [`Resolver::query_family_id`]; ids that were already
/// resolved can be read back synchronously with
/// [`Resolver::get_cache_by_name`].
#[derive(Clone, Debug, Default)]
pub struct Resolver {
// Family name -> family id. Keys are `&'static str`, which is why
// `query_family_id` only accepts `'static` family names.
cache: HashMap<&'static str, u16>,
}
impl Resolver {
pub fn new() -> Self {
Self {
cache: HashMap::new(),
}
}
pub fn get_cache_by_name(&self, family_name: &str) -> Option<u16> {
self.cache.get(family_name).copied()
}
pub fn query_family_id(
&mut self,
handle: &GenetlinkHandle,
family_name: &'static str,
) -> impl Future<Output = Result<u16, GenetlinkError>> + '_ {
if let Some(id) = self.get_cache_by_name(family_name) {
Either::Left(futures::future::ready(Ok(id)))
} else {
let mut handle = handle.clone();
Either::Right(async move {
let mut genlmsg: GenlMessage<GenlCtrl> =
GenlMessage::from_payload(GenlCtrl {
cmd: GenlCtrlCmd::GetFamily,
nlas: vec![GenlCtrlAttrs::FamilyName(
family_name.to_owned(),
)],
});
genlmsg.finalize();
let mut nlmsg = NetlinkMessage::from(genlmsg);
nlmsg.header.flags = NLM_F_REQUEST;
nlmsg.finalize();
let mut res = handle.send_request(nlmsg)?;
while let Some(result) = res.next().await {
let rx_packet = result?;
match rx_packet.payload {
NetlinkPayload::InnerMessage(genlmsg) => {
let family_id = genlmsg
.payload
.nlas
.iter()
.find_map(|nla| {
if let GenlCtrlAttrs::FamilyId(id) = nla {
Some(*id)
} else {
None
}
})
.ok_or_else(|| {
GenetlinkError::AttributeNotFound(
"CTRL_ATTR_FAMILY_ID".to_owned(),
)
})?;
self.cache.insert(family_name, family_id);
return Ok(family_id);
}
NetlinkPayload::Error(e) => return Err(e.into()),
_ => (),
}
}
Err(GenetlinkError::NoMessageReceived)
})
}
}
pub fn clear_cache(&mut self) {
self.cache.clear();
}
}
#[cfg(all(test, feature = "tokio_socket"))]
mod test {
    use super::*;
    use crate::new_connection;
    use std::io::ErrorKind;

    /// Family names probed by `test_resolver_cache`. A family that resolves
    /// with a `NotFound` netlink error is skipped rather than failing the test.
    const TEST_FAMILIES: &[&str] = &[
        "devlink",
        "ethtool",
        "acpi_event",
        "tcp_metrics",
        "TASKSTATS",
        "nl80211",
    ];

    /// Resolving "nlctrl" must yield family id 0x10.
    #[tokio::test]
    async fn test_resolver_nlctrl() {
        let (conn, handle, _) = new_connection().unwrap();
        tokio::spawn(conn);

        let mut resolver = Resolver::new();
        let resolved = resolver.query_family_id(&handle, "nlctrl").await;
        let nlctrl_fid = resolved.unwrap();
        assert_eq!(nlctrl_fid, 0x10);
    }

    /// Every family that resolves must afterwards be readable from the cache
    /// with the same id.
    #[tokio::test]
    async fn test_resolver_cache() {
        let (conn, handle, _) = new_connection().unwrap();
        tokio::spawn(conn);

        let mut resolver = Resolver::new();
        for &name in TEST_FAMILIES {
            // Map "family not found" to the placeholder id 0; any other error
            // stays an `Err` and makes the `unwrap` below panic.
            let outcome = match resolver.query_family_id(&handle, name).await {
                Err(GenetlinkError::NetlinkError(ref io_err))
                    if io_err.kind() == ErrorKind::NotFound =>
                {
                    Ok(0)
                }
                other => other,
            };
            let id = outcome.unwrap();

            if id == 0 {
                log::warn!(
                    "Generic family \"{name}\" not exist or not loaded \
                     in this environment. Ignored."
                );
                continue;
            }

            let cache = resolver.get_cache_by_name(name).unwrap();
            assert_eq!(id, cache);
            log::warn!("{:?}", (name, cache));
        }
    }
}