1#[cfg(target_os = "linux")]
2mod linux {
3 use netlink_packet_core::{
4 NetlinkDeserializable, NetlinkMessage, NetlinkPayload, NetlinkSerializable,
5 NETLINK_HEADER_LEN, NLM_F_ACK, NLM_F_CREATE, NLM_F_EXCL, NLM_F_REQUEST,
6 };
7 use netlink_packet_generic::{
8 constants::GENL_HDRLEN,
9 ctrl::{nlas::GenlCtrlAttrs, GenlCtrl, GenlCtrlCmd},
10 GenlFamily, GenlHeader, GenlMessage,
11 };
12 use netlink_packet_route::RouteNetlinkMessage;
13 use netlink_packet_utils::{Emitable, ParseableParametrized};
14 use netlink_sys::{constants::NETLINK_GENERIC, protocols::NETLINK_ROUTE, Socket};
15 use nix::unistd::{sysconf, SysconfVar};
16 use once_cell::sync::OnceCell;
17 use std::{fmt::Debug, io};
18
19 macro_rules! get_nla_value {
20 ($nlas:expr, $e:ident, $v:ident) => {
21 $nlas.iter().find_map(|attr| match attr {
22 $e::$v(value) => Some(value),
23 _ => None,
24 })
25 };
26 }
27
28 pub fn max_netlink_buffer_length() -> usize {
29 static LENGTH: OnceCell<usize> = OnceCell::new();
30 *LENGTH.get_or_init(|| {
31 const MIN_NELINK_BUFFER_LENGTH: usize = 8 * 1024;
35 let page_size = sysconf(SysconfVar::PAGE_SIZE).unwrap().unwrap() as usize;
40 std::cmp::max(MIN_NELINK_BUFFER_LENGTH, page_size)
41 })
42 }
43
44 pub fn max_genl_payload_length() -> usize {
45 max_netlink_buffer_length() - NETLINK_HEADER_LEN - GENL_HDRLEN
46 }
47
48 pub fn netlink_request_genl<F>(
49 mut message: GenlMessage<F>,
50 flags: Option<u16>,
51 ) -> Result<Vec<NetlinkMessage<GenlMessage<F>>>, io::Error>
52 where
53 F: GenlFamily + Clone + Debug + Eq + Emitable + ParseableParametrized<[u8], GenlHeader>,
54 GenlMessage<F>: Clone + Debug + Eq + NetlinkSerializable + NetlinkDeserializable,
55 {
56 if message.family_id() == 0 {
57 let genlmsg: GenlMessage<GenlCtrl> = GenlMessage::from_payload(GenlCtrl {
58 cmd: GenlCtrlCmd::GetFamily,
59 nlas: vec![GenlCtrlAttrs::FamilyName(F::family_name().to_string())],
60 });
61 let responses =
62 netlink_request_genl::<GenlCtrl>(genlmsg, Some(NLM_F_REQUEST | NLM_F_ACK))?;
63
64 match responses.first() {
65 Some(NetlinkMessage {
66 payload:
67 NetlinkPayload::InnerMessage(GenlMessage {
68 payload: GenlCtrl { nlas, .. },
69 ..
70 }),
71 ..
72 }) => {
73 let family_id = get_nla_value!(nlas, GenlCtrlAttrs, FamilyId)
74 .ok_or_else(|| io::ErrorKind::NotFound)?;
75 message.set_resolved_family_id(*family_id);
76 },
77 _ => {
78 return Err(io::Error::new(
79 io::ErrorKind::InvalidData,
80 "Unexpected netlink payload",
81 ))
82 },
83 };
84 }
85 netlink_request(message, flags, NETLINK_GENERIC)
86 }
87
88 pub fn netlink_request_rtnl(
89 message: RouteNetlinkMessage,
90 flags: Option<u16>,
91 ) -> Result<Vec<NetlinkMessage<RouteNetlinkMessage>>, io::Error> {
92 netlink_request(message, flags, NETLINK_ROUTE)
93 }
94
95 pub fn netlink_request<I>(
96 message: I,
97 flags: Option<u16>,
98 socket: isize,
99 ) -> Result<Vec<NetlinkMessage<I>>, io::Error>
100 where
101 NetlinkPayload<I>: From<I>,
102 I: Clone + Debug + Eq + Emitable + NetlinkSerializable + NetlinkDeserializable,
103 {
104 let mut req = NetlinkMessage::from(message);
105
106 let max_buffer_len = max_netlink_buffer_length();
107 if req.buffer_len() > max_buffer_len {
108 return Err(io::Error::new(
109 io::ErrorKind::InvalidInput,
110 format!(
111 "Serialized netlink packet ({} bytes) larger than maximum size {}: {:?}",
112 req.buffer_len(),
113 max_buffer_len,
114 req
115 ),
116 ));
117 }
118
119 req.header.flags = flags.unwrap_or(NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE);
120 req.finalize();
121 let mut buf = vec![0; max_buffer_len];
122 req.serialize(&mut buf);
123 let len = req.buffer_len();
124
125 let socket = Socket::new(socket)?;
126 let kernel_addr = netlink_sys::SocketAddr::new(0, 0);
127 socket.connect(&kernel_addr)?;
128 let n_sent = socket.send(&buf[..len], 0)?;
129 if n_sent != len {
130 return Err(io::Error::new(
131 io::ErrorKind::UnexpectedEof,
132 "failed to send netlink request",
133 ));
134 }
135
136 let mut responses = vec![];
137 loop {
138 let n_received = socket.recv(&mut &mut buf[..], 0)?;
139 let mut offset = 0;
140 loop {
141 let bytes = &buf[offset..];
142 let response = NetlinkMessage::<I>::deserialize(bytes)
143 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
144 match response.payload {
145 NetlinkPayload::Error(e) if e.code.is_some() => return Err(e.into()),
147 NetlinkPayload::Done(_) | NetlinkPayload::Error(_) => return Ok(responses),
148 _ => {},
149 }
150 responses.push(response.clone());
151 offset += response.header.length as usize;
152 if offset == n_received || response.header.length == 0 {
153 break;
156 }
157 }
158 }
159 }
160}
161
162#[cfg(target_os = "linux")]
163pub use linux::{
164 max_genl_payload_length, max_netlink_buffer_length, netlink_request, netlink_request_genl,
165 netlink_request_rtnl,
166};