// netlink_socket/sys.rs

//! Netlink socket related functions: thin wrappers around the libc socket API.
use libc;
use std::io::{Error, ErrorKind, Result};
use std::mem;
use std::os::unix::io::{AsRawFd, RawFd};

use super::Protocol;

/// A Netlink socket, wrapping the raw file descriptor returned by `socket(2)`.
///
/// The descriptor is closed when the `Socket` is dropped.
#[derive(Debug)]
pub struct Socket(RawFd);

impl AsRawFd for Socket {
    /// Returns the underlying file descriptor.
    ///
    /// Ownership is not transferred: the descriptor stays owned by this `Socket`.
    fn as_raw_fd(&self) -> RawFd {
        self.0
    }
}

17impl Drop for Socket {
18    fn drop(&mut self) {
19        unsafe { libc::close(self.0) };
20    }
21}
22
23#[derive(Copy, Clone)]
24pub struct SocketAddr(libc::sockaddr_nl);
25
26impl SocketAddr {
27    pub fn new(port_number: u32, multicast_groups: u32) -> Self {
28        let mut addr: libc::sockaddr_nl = unsafe { mem::zeroed() };
29        addr.nl_family = libc::PF_NETLINK as libc::sa_family_t;
30        addr.nl_pid = port_number;
31        addr.nl_groups = multicast_groups;
32        SocketAddr(addr)
33    }
34
35    pub fn port_number(&self) -> u32 {
36        self.0.nl_pid
37    }
38
39    pub fn multicast_groups(&self) -> u32 {
40        self.0.nl_groups
41    }
42
43    fn as_raw(&self) -> (*const libc::sockaddr, libc::socklen_t) {
44        let addr_ptr = &self.0 as *const libc::sockaddr_nl as *const libc::sockaddr;
45        //             \                                 / \                      /
46        //              +---------------+---------------+   +----------+---------+
47        //                               |                             |
48        //                               v                             |
49        //             create a raw pointer to the sockaddr_nl         |
50        //                                                             v
51        //                                                cast *sockaddr_nl -> *sockaddr
52        //
53        // This kind of things seems to be pretty usual when using C APIs from Rust. It could be
54        // written in a shorter way thank to type inference:
55        //
56        //      let addr_ptr: *const libc:sockaddr = &self.0 as *const _ as *const _;
57        //
58        // But since this is my first time dealing with this kind of things I chose the most
59        // explicit form.
60
61        let addr_len = mem::size_of::<libc::sockaddr_nl>() as libc::socklen_t;
62        (addr_ptr, addr_len)
63    }
64
65    fn as_raw_mut(&mut self) -> (*mut libc::sockaddr, libc::socklen_t) {
66        let addr_ptr = &mut self.0 as *mut libc::sockaddr_nl as *mut libc::sockaddr;
67        let addr_len = mem::size_of::<libc::sockaddr_nl>() as libc::socklen_t;
68        (addr_ptr, addr_len)
69    }
70}
71
72impl Socket {
73    pub fn new(protocol: Protocol) -> Result<Self> {
74        let res =
75            unsafe { libc::socket(libc::PF_NETLINK, libc::SOCK_DGRAM, protocol as libc::c_int) };
76        if res < 0 {
77            return Err(Error::last_os_error());
78        }
79        Ok(Socket(res))
80    }
81
82    pub fn bind(&mut self, addr: &SocketAddr) -> Result<()> {
83        let (addr_ptr, addr_len) = addr.as_raw();
84        let res = unsafe { libc::bind(self.0, addr_ptr, addr_len) };
85        if res < 0 {
86            return Err(Error::last_os_error());
87        }
88        Ok(())
89    }
90
91    pub fn bind_auto(&mut self) -> Result<SocketAddr> {
92        let mut addr = SocketAddr::new(0, 0);
93        self.bind(&addr)?;
94        self.get_address(&mut addr)?;
95        Ok(addr)
96    }
97
98    pub fn get_address(&self, addr: &mut SocketAddr) -> Result<()> {
99        let (addr_ptr, mut addr_len) = addr.as_raw_mut();
100        let addr_len_copy = addr_len;
101        let addr_len_ptr = &mut addr_len as *mut libc::socklen_t;
102        let res = unsafe { libc::getsockname(self.0, addr_ptr, addr_len_ptr) };
103        if res < 0 {
104            return Err(Error::last_os_error());
105        }
106        assert_eq!(addr_len, addr_len_copy);
107        Ok(())
108    }
109
110    pub fn set_non_blocking(&self, non_blocking: bool) -> Result<()> {
111        let mut non_blocking = non_blocking as libc::c_int;
112        let res = unsafe { libc::ioctl(self.0, libc::FIONBIO, &mut non_blocking) };
113        if res < 0 {
114            return Err(Error::last_os_error());
115        }
116        Ok(())
117    }
118
119    pub fn connect(&self, remote_addr: &SocketAddr) -> Result<()> {
120        // Event though for SOCK_DGRAM sockets there's no IO, since our socket is non-blocking,
121        // connect() might return EINPROGRESS. In theory, the right way to treat EINPROGRESS would
122        // be to ignore the error, and let the user poll the socket to check when it becomes
123        // writable, indicating that the connection succeeded. The code already exists in mio for
124        // TcpStream:
125        //
126        // > pub fn connect(stream: net::TcpStream, addr: &SocketAddr) -> io::Result<TcpStream> {
127        // >     set_non_block(stream.as_raw_fd())?;
128        // >     match stream.connect(addr) {
129        // >         Ok(..) => {}
130        // >         Err(ref e) if e.raw_os_error() == Some(libc::EINPROGRESS) => {}
131        // >         Err(e) => return Err(e),
132        // >     }
133        // >     Ok(TcpStream {  inner: stream })
134        // > }
135        //
136        // The polling to wait for the connection is available in the tokio-tcp crate. See:
137        // https://github.com/tokio-rs/tokio/blob/363b207f2b6c25857c70d76b303356db87212f59/tokio-tcp/src/stream.rs#L706
138        //
139        // In practice, since the connection does not require any IO for SOCK_DGRAM sockets, it
140        // almost never returns EINPROGRESS and so for now, we just return whatever libc::connect
141        // returns. If it returns EINPROGRESS, the caller will have to handle the error themself
142        //
143        // Refs:
144        //
145        // - https://stackoverflow.com/a/14046386/1836144
146        // - https://lists.isc.org/pipermail/bind-users/2009-August/077527.html
147        let (addr, addr_len) = remote_addr.as_raw();
148        let res = unsafe { libc::connect(self.0, addr, addr_len) };
149        if res < 0 {
150            return Err(Error::last_os_error());
151        }
152        Ok(())
153    }
154
155    // Most of the comments in this method come from a discussion on rust users forum.
156    // [thread]: https://users.rust-lang.org/t/help-understanding-libc-call/17308/9
157    pub fn recv_from(&self, buf: &mut [u8], flags: libc::c_int) -> Result<(usize, SocketAddr)> {
158        // Create an empty storage for the address. Note that Rust standard library create a
159        // sockaddr_storage so that it works for any address family, but here, we already know that
160        // we'll have a Netlink address, so we can create the appropriate storage.
161        let mut addr = unsafe { mem::zeroed::<libc::sockaddr_nl>() };
162
163        // recvfrom takes a *sockaddr as parameter so that it can accept any kind of address
164        // storage, so we need to create such a pointer for the sockaddr_nl we just initialized.
165        //
166        //                     Create a raw pointer to        Cast our raw pointer to a
167        //                     our storage. We cannot         generic pointer to *sockaddr
168        //                     pass it to recvfrom yet.       that recvfrom can use
169        //                                 ^                              ^
170        //                                 |                              |
171        //                  +--------------+---------------+    +---------+--------+
172        //                 /                                \  /                    \
173        let addr_ptr = &mut addr as *mut libc::sockaddr_nl as *mut libc::sockaddr;
174
175        // Why do we need to pass the address length? We're passing a generic *sockaddr to
176        // recvfrom. Somehow recvfrom needs to make sure that the address of the received packet
177        // would fit into the actual type that is behind *sockaddr: it could be a sockaddr_nl but
178        // also a sockaddr_in, a sockaddr_in6, or even the generic sockaddr_storage that can store
179        // any address.
180        let mut addrlen = mem::size_of_val(&addr);
181        // recvfrom does not take the address length by value (see [thread]), so we need to create
182        // a pointer to it.
183        let addrlen_ptr = &mut addrlen as *mut usize as *mut libc::socklen_t;
184
185        //                      Cast the *mut u8 into *mut void.
186        //               This is equivalent to casting a *char into *void
187        //                                 See [thread]
188        //                                       ^
189        //           Create a *mut u8            |
190        //                   ^                   |
191        //                   |                   |
192        //             +-----+-----+    +--------+-------+
193        //            /             \  /                  \
194        let buf_ptr = buf.as_mut_ptr() as *mut libc::c_void;
195        let buf_len = buf.len() as libc::size_t;
196
197        let res = unsafe { libc::recvfrom(self.0, buf_ptr, buf_len, flags, addr_ptr, addrlen_ptr) };
198        if res < 0 {
199            return Err(Error::last_os_error());
200        }
201        Ok((res as usize, SocketAddr(addr)))
202    }
203
204    pub fn recv(&self, buf: &mut [u8], flags: libc::c_int) -> Result<usize> {
205        let buf_ptr = buf.as_mut_ptr() as *mut libc::c_void;
206        let buf_len = buf.len() as libc::size_t;
207
208        let res = unsafe { libc::recv(self.0, buf_ptr, buf_len, flags) };
209        if res < 0 {
210            return Err(Error::last_os_error());
211        }
212        Ok(res as usize)
213    }
214
215    pub fn send_to(&self, buf: &[u8], addr: &SocketAddr, flags: libc::c_int) -> Result<usize> {
216        let (addr_ptr, addr_len) = addr.as_raw();
217        let buf_ptr = buf.as_ptr() as *const libc::c_void;
218        let buf_len = buf.len() as libc::size_t;
219
220        let res = unsafe { libc::sendto(self.0, buf_ptr, buf_len, flags, addr_ptr, addr_len) };
221        if res < 0 {
222            return Err(Error::last_os_error());
223        }
224        Ok(res as usize)
225    }
226
227    pub fn send(&self, buf: &[u8], flags: libc::c_int) -> Result<usize> {
228        let buf_ptr = buf.as_ptr() as *const libc::c_void;
229        let buf_len = buf.len() as libc::size_t;
230
231        let res = unsafe { libc::send(self.0, buf_ptr, buf_len, flags) };
232        if res < 0 {
233            return Err(Error::last_os_error());
234        }
235        Ok(res as usize)
236    }
237
238    pub fn set_pktinfo(&mut self, set: bool) -> Result<()> {
239        setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_PKTINFO, set)
240    }
241
242    pub fn get_pktinfo(&self) -> Result<bool> {
243        getsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_PKTINFO)
244    }
245
246    pub fn add_membership(&mut self, group: u32) -> Result<()> {
247        setsockopt(
248            self.0,
249            libc::SOL_NETLINK,
250            libc::NETLINK_ADD_MEMBERSHIP,
251            group,
252        )
253    }
254
255    pub fn drop_membership(&mut self, group: u32) -> Result<()> {
256        setsockopt(
257            self.0,
258            libc::SOL_NETLINK,
259            libc::NETLINK_DROP_MEMBERSHIP,
260            group,
261        )
262    }
263
264    pub fn list_membership(&self) -> Vec<u32> {
265        unimplemented!();
266        // getsockopt won't be enough here, because we may need to perform 2 calls, and because the
267        // length of the list returned by libc::getsockopt is returned by mutating the length
268        // argument, which our implementation of getsockopt forbids.
269    }
270
271    pub fn set_broadcast_error(&mut self, set: bool) -> Result<()> {
272        setsockopt(
273            self.0,
274            libc::SOL_NETLINK,
275            libc::NETLINK_BROADCAST_ERROR,
276            set,
277        )
278    }
279
280    pub fn get_broadcast_error(&self) -> Result<bool> {
281        getsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_BROADCAST_ERROR)
282    }
283
284    pub fn set_no_enobufs(&mut self, set: bool) -> Result<()> {
285        setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_NO_ENOBUFS, set)
286    }
287
288    pub fn get_no_enobufs(&self) -> Result<bool> {
289        getsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_NO_ENOBUFS)
290    }
291
292    pub fn set_listen_all_namespaces(&mut self, set: bool) -> Result<()> {
293        setsockopt(
294            self.0,
295            libc::SOL_NETLINK,
296            libc::NETLINK_LISTEN_ALL_NSID,
297            set,
298        )
299    }
300
301    pub fn get_listen_all_namespaces(&self) -> Result<bool> {
302        getsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_LISTEN_ALL_NSID)
303    }
304
305    pub fn set_cap_ack(&mut self, set: bool) -> Result<()> {
306        setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_CAP_ACK, set)
307    }
308
309    pub fn get_cap_ack(&self) -> Result<bool> {
310        getsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_CAP_ACK)
311    }
312}
313
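// NOTE: Netlink socket options are C `int`s. The kernel's Netlink `getsockopt` fails with EINVAL
// (documented as "The specified option is invalid at the specified socket level or the socket
// has been shut down") when the caller's buffer is smaller than `sizeof(int)`. Its `setsockopt`
// only reads the value when `optlen >= sizeof(int)`, and otherwise treats the option as 0. The
// helpers below must therefore be called with an int-sized payload (e.g. `libc::c_int`), never
// with a `bool`.
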
320// adapted from rust standard library
321fn getsockopt<T: Copy>(fd: RawFd, opt: libc::c_int, val: libc::c_int) -> Result<T> {
322    unsafe {
323        // Create storage for the options we're fetching
324        let mut slot: T = mem::zeroed();
325
326        // Create a mutable raw pointer to the storage so that getsockopt can fill the value
327        let slot_ptr = &mut slot as *mut T as *mut libc::c_void;
328
329        // Let getsockopt know how big our storage is
330        let mut slot_len = mem::size_of::<T>() as libc::socklen_t;
331
332        // getsockopt takes a mutable pointer to the length, because for some options like
333        // NETLINK_LIST_MEMBERSHIP where the option value is a list with arbitrary length,
334        // getsockopt uses this parameter to signal how big the storage needs to be.
335        let slot_len_ptr = &mut slot_len as *mut libc::socklen_t;
336
337        let res = libc::getsockopt(fd, opt, val, slot_ptr, slot_len_ptr);
338        if res < 0 {
339            return Err(Error::last_os_error());
340        }
341
342        // Ignore the options that require the legnth to be set by getsockopt.
343        // We'll deal with them individually.
344        assert_eq!(slot_len as usize, mem::size_of::<T>());
345
346        Ok(slot)
347    }
348}
349
350// adapted from rust standard library
351fn setsockopt<T>(fd: RawFd, opt: libc::c_int, val: libc::c_int, payload: T) -> Result<()> {
352    unsafe {
353        let payload = &payload as *const T as *const libc::c_void;
354        let payload_len = mem::size_of::<T>() as libc::socklen_t;
355
356        let res = libc::setsockopt(fd, opt, val, payload, payload_len);
357        if res < 0 {
358            return Err(Error::last_os_error());
359        }
360    }
361    Ok(())
362}
363
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn new() {
        Socket::new(Protocol::Route).unwrap();
    }

    #[test]
    fn connect() {
        let sock = Socket::new(Protocol::Route).unwrap();
        sock.connect(&SocketAddr::new(0, 0)).unwrap();
    }

    #[test]
    fn bind() {
        let mut sock = Socket::new(Protocol::Route).unwrap();
        sock.bind(&SocketAddr::new(4321, 0)).unwrap();
    }

    #[test]
    fn bind_auto() {
        let mut sock = Socket::new(Protocol::Route).unwrap();
        let addr = sock.bind_auto().unwrap();
        // The port number read back from the kernel must have been filled in.
        assert_ne!(addr.port_number(), 0);
    }

    #[test]
    fn set_non_blocking() {
        let sock = Socket::new(Protocol::Route).unwrap();
        sock.set_non_blocking(true).unwrap();
        sock.set_non_blocking(false).unwrap();
    }

    /// Round-trips the boolean Netlink options through their setters and getters.
    ///
    /// Ignored by default because the environment may not support it. `NETLINK_LISTEN_ALL_NSID`
    /// needs `CAP_NET_BROADCAST`, and `NETLINK_CAP_ACK` needs Linux 4.3 or later. Run it with
    /// `cargo test -- --ignored`.
    #[test]
    #[ignore]
    fn options() {
        let mut sock = Socket::new(Protocol::Route).unwrap();

        sock.set_no_enobufs(true).unwrap();
        assert!(sock.get_no_enobufs().unwrap());
        sock.set_no_enobufs(false).unwrap();
        assert!(!sock.get_no_enobufs().unwrap());

        sock.set_broadcast_error(true).unwrap();
        assert!(sock.get_broadcast_error().unwrap());
        sock.set_broadcast_error(false).unwrap();
        assert!(!sock.get_broadcast_error().unwrap());

        sock.set_cap_ack(true).unwrap();
        assert!(sock.get_cap_ack().unwrap());
        sock.set_cap_ack(false).unwrap();
        assert!(!sock.get_cap_ack().unwrap());

        sock.set_listen_all_namespaces(true).unwrap();
        assert!(sock.get_listen_all_namespaces().unwrap());
        sock.set_listen_all_namespaces(false).unwrap();
        assert!(!sock.get_listen_all_namespaces().unwrap());
    }

    #[test]
    fn address() {
        let mut addr = SocketAddr::new(42, 1234);
        assert_eq!(addr.port_number(), 42);
        assert_eq!(addr.multicast_groups(), 1234);

        // Reading through the raw `*const sockaddr` must see the same `sockaddr_nl` fields.
        {
            let (addr_ptr, _) = addr.as_raw();
            let inner_addr = unsafe { *(addr_ptr as *const libc::sockaddr_nl) };
            assert_eq!(inner_addr.nl_pid, 42);
            assert_eq!(inner_addr.nl_groups, 1234);
        }

        // Writing through the raw `*mut sockaddr` must be visible through the accessors.
        {
            let (addr_ptr, _) = addr.as_raw_mut();
            let sockaddr_nl = addr_ptr as *mut libc::sockaddr_nl;
            unsafe {
                sockaddr_nl.as_mut().unwrap().nl_pid = 24;
                sockaddr_nl.as_mut().unwrap().nl_groups = 4321;
            }
        }
        assert_eq!(addr.port_number(), 24);
        assert_eq!(addr.multicast_groups(), 4321);
    }
}
450}