1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
use std::{
fmt::Debug,
io::Cursor,
os::unix::io::{AsRawFd, IntoRawFd, RawFd},
};
use log::trace;
use crate::{
consts::{nl::*, socket::*},
err::SocketError,
iter::NlBufferIter,
nl::Nlmsghdr,
socket::shared::NlSocket,
types::NlBuffer,
utils::{
synchronous::{BufferPool, BufferPoolGuard},
Groups, NetlinkBitArray,
},
FromBytesWithInput, Size, ToBytes,
};
/// Higher level handle for netlink socket operations.
///
/// Bundles an [`NlSocket`] with the port ID it was bound to and a pool from
/// which the receive methods acquire their buffers.
pub struct NlSocketHandle {
    /// The wrapped netlink socket that all operations delegate to.
    pub(super) socket: NlSocket,
    /// Port ID of `socket`, recorded once when the handle is connected.
    pid: u32,
    /// Source of the buffers used by `recv` and `recv_all`.
    pool: BufferPool,
}
impl NlSocketHandle {
    /// Equivalent of `socket` and `bind` calls.
    ///
    /// `proto`, `pid` and `groups` are forwarded to [`NlSocket::connect`]. The
    /// resulting socket is put into blocking mode (via `NlSocket::block`), and the
    /// PID it reports after binding is cached so that [`NlSocketHandle::pid`] can
    /// return it without touching the socket again.
    ///
    /// # Errors
    ///
    /// Returns [`SocketError`] if connecting, switching to blocking mode, or
    /// querying the socket's PID fails.
    pub fn connect(proto: NlFamily, pid: Option<u32>, groups: Groups) -> Result<Self, SocketError> {
        let socket = NlSocket::connect(proto, pid, groups)?;
        socket.block()?;
        let pid = socket.pid()?;
        Ok(NlSocketHandle {
            socket,
            pid,
            pool: BufferPool::default(),
        })
    }

    /// Join multicast groups for a socket.
    ///
    /// # Errors
    ///
    /// Returns [`SocketError`] if the underlying socket operation fails.
    pub fn add_mcast_membership(&self, groups: Groups) -> Result<(), SocketError> {
        self.socket
            .add_mcast_membership(groups)
            .map_err(SocketError::from)
    }

    /// Leave multicast groups for a socket.
    ///
    /// # Errors
    ///
    /// Returns [`SocketError`] if the underlying socket operation fails.
    pub fn drop_mcast_membership(&self, groups: Groups) -> Result<(), SocketError> {
        self.socket
            .drop_mcast_membership(groups)
            .map_err(SocketError::from)
    }

    /// List joined groups for a socket.
    ///
    /// # Errors
    ///
    /// Returns [`SocketError`] if the underlying socket operation fails.
    pub fn list_mcast_membership(&self) -> Result<NetlinkBitArray, SocketError> {
        self.socket
            .list_mcast_membership()
            .map_err(SocketError::from)
    }

    /// Get the PID for the current socket.
    ///
    /// This is the value cached by [`NlSocketHandle::connect`]; no system call is made.
    pub fn pid(&self) -> u32 {
        self.pid
    }

    /// Convenience function to serialize and send an [`Nlmsghdr`] struct.
    ///
    /// The message is written into a freshly allocated, zero-filled buffer of
    /// `msg.padded_size()` bytes, which is then handed to the socket in a single
    /// send call.
    ///
    /// # Errors
    ///
    /// Returns [`SocketError`] if serialization or the send call fails.
    pub fn send<T, P>(&self, msg: &Nlmsghdr<T, P>) -> Result<(), SocketError>
    where
        T: NlType + Debug,
        P: Size + ToBytes + Debug,
    {
        trace!("Message sent:\n{msg:?}");
        let mut buffer = Cursor::new(vec![0; msg.padded_size()]);
        msg.to_bytes(&mut buffer)?;
        trace!("Buffer sent: {:?}", buffer.get_ref());
        self.socket.send(buffer.get_ref(), Msg::empty())?;
        Ok(())
    }

    /// Convenience function to read a stream of [`Nlmsghdr`]
    /// structs one by one using an iterator.
    ///
    /// Performs a single [`NlSocket::recv`][crate::socket::NlSocket::recv] call
    /// into a buffer taken from this handle's pool. It returns an iterator over
    /// that buffer together with the multicast groups reported by the call. The
    /// iterator yields [`None`] once every message in that one buffer has been
    /// processed; call this method again to read more.
    ///
    /// The returned iterator owns the pooled buffer guard, which borrows `self`.
    ///
    /// See [`NlBufferIter`] for more detailed information.
    ///
    /// # Errors
    ///
    /// Returns [`SocketError`] if the receive call fails. Parse errors surface
    /// from the iterator, not from this method.
    pub fn recv<T, P>(
        &self,
    ) -> Result<(NlBufferIter<T, P, BufferPoolGuard<'_>>, Groups), SocketError>
    where
        T: NlType + Debug,
        P: Size + FromBytesWithInput<Input = usize> + Debug,
    {
        let mut buffer = self.pool.acquire();
        let (mem_read, groups) = self.socket.recv(&mut buffer, Msg::empty())?;
        // Trim the pooled buffer to the bytes actually received so the iterator
        // does not parse stale data past the end of this datagram.
        buffer.reduce_size(mem_read);
        trace!("Buffer received: {:?}", buffer.as_ref());
        Ok((NlBufferIter::new(Cursor::new(buffer)), groups))
    }

    /// Parse all [`Nlmsghdr`] structs sent in
    /// one network packet and return them all in a list.
    ///
    /// The multicast groups reported by the receive call are returned alongside
    /// the messages. If the receive call reads zero bytes, an empty buffer is
    /// returned.
    ///
    /// Failure to parse any packet will cause the entire operation
    /// to fail. For a more granular approach, use [`NlSocketHandle::recv`].
    ///
    /// # Errors
    ///
    /// Returns [`SocketError`] if the receive call fails or if any message in
    /// the received buffer cannot be parsed.
    pub fn recv_all<T, P>(&self) -> Result<(NlBuffer<T, P>, Groups), SocketError>
    where
        T: NlType + Debug,
        P: Size + FromBytesWithInput<Input = usize> + Debug,
    {
        let mut buffer = self.pool.acquire();
        let (mem_read, groups) = self.socket.recv(&mut buffer, Msg::empty())?;
        if mem_read == 0 {
            // Nothing to parse. Still report the groups from this receive call,
            // matching the non-empty path below, rather than discarding them.
            return Ok((NlBuffer::new(), groups));
        }
        buffer.reduce_size(mem_read);
        // NOTE(review): earlier docs stated that application-level errors cause
        // non-error messages to be discarded. Nothing in this body does that;
        // confirm whether `NlBuffer` parsing does before documenting it again.
        let vec = NlBuffer::from_bytes_with_input(&mut Cursor::new(buffer), mem_read)?;
        trace!("Messages received: {vec:?}");
        Ok((vec, groups))
    }

    /// Set the size of the receive buffer for the socket.
    ///
    /// This can be useful when communicating with a service that sends a high volume of
    /// messages (especially multicast), and your application cannot process them fast enough,
    /// leading to the kernel dropping messages. A larger buffer may help mitigate this.
    ///
    /// The value passed is a hint to the kernel to set the size of the receive buffer.
    /// The kernel will double the value provided to account for bookkeeping overhead.
    /// The doubled value is capped by the value in `/proc/sys/net/core/rmem_max`.
    ///
    /// The default value is `/proc/sys/net/core/rmem_default`.
    ///
    /// See `socket(7)` documentation for `SO_RCVBUF` for more information.
    ///
    /// # Errors
    ///
    /// Returns [`SocketError`] if the underlying socket option call fails.
    pub fn set_recv_buffer_size(&self, size: usize) -> Result<(), SocketError> {
        self.socket
            .set_recv_buffer_size(size)
            .map_err(SocketError::from)
    }

    /// If [`true`] is passed in, enable extended ACKs for this socket. If [`false`]
    /// is passed in, disable extended ACKs for this socket.
    ///
    /// # Errors
    ///
    /// Returns [`SocketError`] if the underlying socket option call fails.
    pub fn enable_ext_ack(&self, enable: bool) -> Result<(), SocketError> {
        self.socket
            .enable_ext_ack(enable)
            .map_err(SocketError::from)
    }

    /// Return [`true`] if an extended ACK is enabled for this socket.
    ///
    /// # Errors
    ///
    /// Returns [`SocketError`] if the underlying socket option query fails.
    pub fn get_ext_ack_enabled(&self) -> Result<bool, SocketError> {
        self.socket.get_ext_ack_enabled().map_err(SocketError::from)
    }

    /// If [`true`] is passed in, enable strict checking for this socket. If [`false`]
    /// is passed in, disable strict checking for this socket.
    /// Only supported by `NlFamily::Route` sockets.
    /// Requires Linux >= 4.20.
    ///
    /// # Errors
    ///
    /// Returns [`SocketError`] if the underlying socket option call fails.
    pub fn enable_strict_checking(&self, enable: bool) -> Result<(), SocketError> {
        self.socket
            .enable_strict_checking(enable)
            .map_err(SocketError::from)
    }

    /// Return [`true`] if strict checking is enabled for this socket.
    /// Only supported by `NlFamily::Route` sockets.
    /// Requires Linux >= 4.20.
    ///
    /// # Errors
    ///
    /// Returns [`SocketError`] if the underlying socket option query fails.
    pub fn get_strict_checking_enabled(&self) -> Result<bool, SocketError> {
        self.socket
            .get_strict_checking_enabled()
            .map_err(SocketError::from)
    }

    /// Put the underlying socket into non-blocking mode (via `NlSocket::nonblock`).
    ///
    /// Visible only to the grandparent module.
    pub(in super::super) fn set_nonblock(&self) -> Result<(), SocketError> {
        self.socket.nonblock().map_err(SocketError::from)
    }
}
/// Borrows the file descriptor of the wrapped [`NlSocket`]; the handle keeps ownership.
impl AsRawFd for NlSocketHandle {
    fn as_raw_fd(&self) -> RawFd {
        let inner = &self.socket;
        inner.as_raw_fd()
    }
}
impl IntoRawFd for NlSocketHandle {
fn into_raw_fd(self) -> RawFd {
self.socket.into_raw_fd()
}
}