use std::{collections::HashMap, sync::RwLock};
use super::{
error::Result,
genl::FamilyInfo,
socket::{NetlinkSocket, Protocol},
};
/// Sealing support for [`ProtocolState`].
///
/// This module is not exported, so `private::Sealed` can only be named from
/// within this module (and its children). Because `Sealed` is a supertrait of
/// `ProtocolState`, the set of protocol-state types stays closed.
mod private {
    /// Marker supertrait implemented by each protocol-state type in this module.
    pub trait Sealed {}
}
/// Marker trait tying a protocol-state type to a netlink [`Protocol`].
///
/// Sealed through `private::Sealed`: only types in this module can implement it.
pub trait ProtocolState
where
    Self: private::Sealed,
{
    /// The [`Protocol`] associated with this state type.
    const PROTOCOL: Protocol;
}
/// Asynchronous construction of a protocol state from an existing socket.
///
/// NOTE(review): presumably implemented by the state types that carry
/// runtime-resolved data (e.g. a `family_id`); confirm against implementors.
pub trait AsyncProtocolInit: ProtocolState {
    /// Builds `Self` using the borrowed `socket`.
    ///
    /// The returned future is `Send` and resolves to `Result<Self>`.
    ///
    /// # Errors
    ///
    /// Yields whatever error the implementation reports while resolving.
    fn resolve_async(socket: &NetlinkSocket) -> impl std::future::Future<Output = Result<Self>> + Send
    where
        Self: Sized;
}
/// Protocol state for [`Protocol::Route`] sockets.
///
/// Zero-sized marker carrying no runtime data.
#[derive(Clone, Copy, Debug, Default)]
pub struct Route;

impl ProtocolState for Route {
    const PROTOCOL: Protocol = Protocol::Route;
}

impl private::Sealed for Route {}
/// Protocol state for [`Protocol::SockDiag`] sockets.
///
/// Zero-sized marker carrying no runtime data.
#[derive(Clone, Copy, Debug, Default)]
pub struct SockDiag;

impl ProtocolState for SockDiag {
    const PROTOCOL: Protocol = Protocol::SockDiag;
}

impl private::Sealed for SockDiag {}
/// Protocol state for plain generic-netlink sockets.
///
/// Holds a lock-protected map from `String` keys to [`FamilyInfo`], which
/// starts out empty. NOTE(review): presumably keyed by generic-netlink family
/// name; confirm against the code that populates it.
///
/// `Default` is derived: `RwLock::default()` wraps `HashMap::default()`, which is
/// the same empty map the former hand-written impl built with `HashMap::new()`.
#[derive(Default)]
pub struct Generic {
    pub(crate) cache: RwLock<HashMap<String, FamilyInfo>>,
}
/// Debug output reports only the number of cached entries, not their contents.
impl std::fmt::Debug for Generic {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A poisoned lock still guards the map; recover the guard instead of
        // reporting a misleading size of 0.
        let cache_size = match self.cache.read() {
            Ok(guard) => guard.len(),
            Err(poisoned) => poisoned.into_inner().len(),
        };
        f.debug_struct("Generic")
            .field("cache_size", &cache_size)
            .finish()
    }
}
impl private::Sealed for Generic {}
impl ProtocolState for Generic {
const PROTOCOL: Protocol = Protocol::Generic;
}
/// Protocol state for the WireGuard generic-netlink family.
///
/// Uses [`Protocol::Generic`].
#[derive(Default, Debug)]
pub struct Wireguard {
    /// Family id for this state; `0` under `Default`.
    /// NOTE(review): presumably filled in by family resolution — confirm.
    pub(crate) family_id: u16,
}

impl ProtocolState for Wireguard {
    const PROTOCOL: Protocol = Protocol::Generic;
}

impl private::Sealed for Wireguard {}
/// Protocol state for [`Protocol::KobjectUevent`] sockets.
///
/// Zero-sized marker. Unlike some other markers here, it does not derive `Default`.
#[derive(Clone, Copy, Debug)]
pub struct KobjectUevent;

impl ProtocolState for KobjectUevent {
    const PROTOCOL: Protocol = Protocol::KobjectUevent;
}

impl private::Sealed for KobjectUevent {}
/// Protocol state for [`Protocol::Connector`] sockets.
///
/// Zero-sized marker; does not derive `Default`.
#[derive(Clone, Copy, Debug)]
pub struct Connector;

impl ProtocolState for Connector {
    const PROTOCOL: Protocol = Protocol::Connector;
}

impl private::Sealed for Connector {}
/// Protocol state for [`Protocol::Netfilter`] sockets.
///
/// Zero-sized marker carrying no runtime data.
#[derive(Clone, Copy, Debug, Default)]
pub struct Netfilter;

impl ProtocolState for Netfilter {
    const PROTOCOL: Protocol = Protocol::Netfilter;
}

impl private::Sealed for Netfilter {}
/// Protocol state for [`Protocol::Xfrm`] sockets.
///
/// Zero-sized marker carrying no runtime data.
#[derive(Clone, Copy, Debug, Default)]
pub struct Xfrm;

impl ProtocolState for Xfrm {
    const PROTOCOL: Protocol = Protocol::Xfrm;
}

impl private::Sealed for Xfrm {}
/// Protocol state for [`Protocol::FibLookup`] sockets.
///
/// Zero-sized marker; does not derive `Default`.
#[derive(Clone, Copy, Debug)]
pub struct FibLookup;

impl ProtocolState for FibLookup {
    const PROTOCOL: Protocol = Protocol::FibLookup;
}

impl private::Sealed for FibLookup {}
/// Protocol state for [`Protocol::SELinux`] sockets.
///
/// Zero-sized marker; does not derive `Default`.
#[derive(Clone, Copy, Debug)]
pub struct SELinux;

impl ProtocolState for SELinux {
    const PROTOCOL: Protocol = Protocol::SELinux;
}

impl private::Sealed for SELinux {}
/// Protocol state for [`Protocol::Audit`] sockets.
///
/// Zero-sized marker; does not derive `Default`.
#[derive(Clone, Copy, Debug)]
pub struct Audit;

impl ProtocolState for Audit {
    const PROTOCOL: Protocol = Protocol::Audit;
}

impl private::Sealed for Audit {}
/// Protocol state for the MACsec generic-netlink family.
///
/// Uses [`Protocol::Generic`].
#[derive(Default, Debug)]
pub struct Macsec {
    /// Family id for this state; `0` under `Default`.
    /// NOTE(review): presumably filled in by family resolution — confirm.
    pub(crate) family_id: u16,
}

impl ProtocolState for Macsec {
    const PROTOCOL: Protocol = Protocol::Generic;
}

impl private::Sealed for Macsec {}
/// Protocol state for the MPTCP generic-netlink family.
///
/// Uses [`Protocol::Generic`].
#[derive(Default, Debug)]
pub struct Mptcp {
    /// Family id for this state; `0` under `Default`.
    /// NOTE(review): presumably filled in by family resolution — confirm.
    pub(crate) family_id: u16,
}

impl ProtocolState for Mptcp {
    const PROTOCOL: Protocol = Protocol::Generic;
}

impl private::Sealed for Mptcp {}
/// Protocol state for the devlink generic-netlink family.
///
/// Uses [`Protocol::Generic`].
#[derive(Default, Debug)]
pub struct Devlink {
    /// Family id for this state; `0` under `Default`.
    pub(crate) family_id: u16,
    /// Optional group id, `None` under `Default`.
    /// NOTE(review): presumably the monitor multicast group — confirm.
    pub(crate) monitor_group_id: Option<u32>,
}

impl ProtocolState for Devlink {
    const PROTOCOL: Protocol = Protocol::Generic;
}

impl private::Sealed for Devlink {}
/// Protocol state for nftables sockets.
///
/// Zero-sized marker using the same [`Protocol::Netfilter`] as [`Netfilter`].
/// NOTE(review): presumably kept as a separate type so nftables-specific
/// operations can target their own state — confirm.
#[derive(Clone, Copy, Debug, Default)]
pub struct Nftables;

impl ProtocolState for Nftables {
    const PROTOCOL: Protocol = Protocol::Netfilter;
}

impl private::Sealed for Nftables {}
/// Protocol state for the nl80211 generic-netlink family.
///
/// Uses [`Protocol::Generic`]. Under `Default`, `family_id` is `0` and every
/// optional group id is `None`.
/// NOTE(review): the `*_group_id` fields presumably hold multicast group ids
/// for the named nl80211 groups — confirm against the resolver.
#[derive(Default, Debug)]
pub struct Nl80211 {
    pub(crate) family_id: u16,
    pub(crate) scan_group_id: Option<u32>,
    pub(crate) mlme_group_id: Option<u32>,
    pub(crate) regulatory_group_id: Option<u32>,
    pub(crate) config_group_id: Option<u32>,
}

impl ProtocolState for Nl80211 {
    const PROTOCOL: Protocol = Protocol::Generic;
}

impl private::Sealed for Nl80211 {}
/// Protocol state for the ethtool generic-netlink family.
///
/// Uses [`Protocol::Generic`].
#[derive(Default, Debug)]
pub struct Ethtool {
    /// Family id for this state; `0` under `Default`.
    pub(crate) family_id: u16,
    /// Optional group id, `None` under `Default`.
    /// NOTE(review): presumably the monitor multicast group — confirm.
    pub(crate) monitor_group_id: Option<u32>,
}

impl ProtocolState for Ethtool {
    const PROTOCOL: Protocol = Protocol::Generic;
}

impl private::Sealed for Ethtool {}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn route_is_zero_sized() {
        assert_eq!(std::mem::size_of::<Route>(), 0);
    }

    #[test]
    fn sockdiag_is_zero_sized() {
        assert_eq!(std::mem::size_of::<SockDiag>(), 0);
    }

    #[test]
    fn generic_has_cache() {
        let g = Generic::default();
        assert!(g.cache.read().unwrap().is_empty());
    }

    #[test]
    fn generic_debug_reports_cache_size() {
        let g = Generic::default();
        assert_eq!(format!("{:?}", g), "Generic { cache_size: 0 }");
    }

    #[test]
    fn protocol_constants() {
        assert_eq!(Route::PROTOCOL, Protocol::Route);
        assert_eq!(SockDiag::PROTOCOL, Protocol::SockDiag);
        assert_eq!(Generic::PROTOCOL, Protocol::Generic);
        assert_eq!(Wireguard::PROTOCOL, Protocol::Generic);
        assert_eq!(Macsec::PROTOCOL, Protocol::Generic);
        assert_eq!(Mptcp::PROTOCOL, Protocol::Generic);
        // Devlink was previously missing from this list.
        assert_eq!(Devlink::PROTOCOL, Protocol::Generic);
        assert_eq!(Ethtool::PROTOCOL, Protocol::Generic);
        assert_eq!(Nl80211::PROTOCOL, Protocol::Generic);
        assert_eq!(KobjectUevent::PROTOCOL, Protocol::KobjectUevent);
        assert_eq!(Connector::PROTOCOL, Protocol::Connector);
        assert_eq!(Netfilter::PROTOCOL, Protocol::Netfilter);
        assert_eq!(Nftables::PROTOCOL, Protocol::Netfilter);
        assert_eq!(Xfrm::PROTOCOL, Protocol::Xfrm);
        assert_eq!(FibLookup::PROTOCOL, Protocol::FibLookup);
        assert_eq!(SELinux::PROTOCOL, Protocol::SELinux);
        assert_eq!(Audit::PROTOCOL, Protocol::Audit);
    }

    #[test]
    fn family_states_default_to_zero_id() {
        assert_eq!(Wireguard::default().family_id, 0);
        assert_eq!(Macsec::default().family_id, 0);
        assert_eq!(Mptcp::default().family_id, 0);

        let devlink = Devlink::default();
        assert_eq!(devlink.family_id, 0);
        assert!(devlink.monitor_group_id.is_none());

        let ethtool = Ethtool::default();
        assert_eq!(ethtool.family_id, 0);
        assert!(ethtool.monitor_group_id.is_none());

        let nl = Nl80211::default();
        assert_eq!(nl.family_id, 0);
        assert!(nl.scan_group_id.is_none());
        assert!(nl.mlme_group_id.is_none());
        assert!(nl.regulatory_group_id.is_none());
        assert!(nl.config_group_id.is_none());
    }

    #[test]
    fn new_types_are_zero_sized() {
        assert_eq!(std::mem::size_of::<KobjectUevent>(), 0);
        assert_eq!(std::mem::size_of::<Connector>(), 0);
        assert_eq!(std::mem::size_of::<Netfilter>(), 0);
        assert_eq!(std::mem::size_of::<Xfrm>(), 0);
        assert_eq!(std::mem::size_of::<FibLookup>(), 0);
        assert_eq!(std::mem::size_of::<SELinux>(), 0);
        assert_eq!(std::mem::size_of::<Audit>(), 0);
        assert_eq!(std::mem::size_of::<Nftables>(), 0);
    }
}