netlink_rust/core/
socket.rs

1use std::collections::HashMap;
2use std::io;
3use std::mem::size_of;
4use std::os::unix::io::{AsRawFd, RawFd};
5
6use libc;
7
8use crate::errors::{NetlinkError, NetlinkErrorKind, Result};
9
10use crate::core::message::{
11    netlink_align, ErrorMessage, Header, Message, MessageFlags, MessageMode, Messages,
12};
13use crate::core::pack::{NativePack, NativeUnpack};
14use crate::core::system;
15use crate::core::Protocol;
16
17/// Trait for message to be sent by the socket
18pub trait SendMessage {
19    /// Pack the message into the provided byte slice
20    fn pack(&self, data: &mut [u8]) -> Result<usize>;
21    /// Get the message type
22    fn message_type(&self) -> u16;
23    /// Get the query flags
24    fn query_flags(&self) -> MessageFlags;
25}
26
// Control message types recognised while unpacking received data.

/// Message carrying nothing to deliver; it is skipped.
const NLMSG_NOOP: u16 = 1;
/// Error report; an error code of zero is treated as success.
const NLMSG_ERROR: u16 = 2;
/// Marks the end of a reply: no further messages are expected afterwards.
const NLMSG_DONE: u16 = 3;
// Not handled by this module:
// const NLMSG_OVERRUN: u16 = 4;

// Socket options at level `SOL_NETLINK`.

/// Option used to join a multicast group.
const NETLINK_ADD_MEMBERSHIP: i32 = 1;
// Remaining options, currently unused by this module:
// const NETLINK_DROP_MEMBERSHIP: i32 = 2;
// const NETLINK_PKTINFO: i32 = 3;
// const NETLINK_BROADCAST_ERROR: i32 = 4;
// const NETLINK_NO_ENOBUFS: i32 = 5;
// const NETLINK_RX_RING: i32 = 6;
// const NETLINK_TX_RING: i32 = 7;
// const NETLINK_LISTEN_ALL_NSID: i32 = 8;
// const NETLINK_LIST_MEMBERSHIPS: i32 = 9;
// const NETLINK_CAP_ACK: i32 = 10;
// const NETLINK_EXT_ACK: i32 = 11;
43
44/// Netlink Socket can be used to communicate with the Linux kernel using the
45/// netlink protocol.
46pub struct Socket {
47    local: system::Address,
48    peer: system::Address,
49    socket: RawFd,
50    sequence_next: u32,
51    page_size: usize,
52    receive_buffer: Vec<u8>,
53    send_buffer: Vec<u8>,
54    sent: HashMap<u32, MessageMode>,
55}
56
57impl Socket {
58    /// Create a new Socket
59    pub fn new(protocol: Protocol) -> Result<Socket> {
60        Socket::new_multicast(protocol, 0)
61    }
62
63    /// Create a new Socket which subscribes to the provided multi-cast groups
64    pub fn new_multicast(protocol: Protocol, groups: u32) -> Result<Socket> {
65        let socket = system::netlink_socket(protocol as i32)?;
66        system::set_socket_option(socket, libc::SOL_SOCKET, libc::SO_SNDBUF, 32768)?;
67        system::set_socket_option(socket, libc::SOL_SOCKET, libc::SO_RCVBUF, 32768)?;
68        let mut local_addr = system::Address {
69            family: libc::AF_NETLINK as u16,
70            _pad: 0,
71            pid: 0,
72            groups: groups,
73        };
74        system::bind(socket, &mut local_addr)?;
75        system::get_socket_address(socket, &mut local_addr)?;
76        let page_size = netlink_align(system::get_page_size());
77        let peer_addr = system::Address {
78            family: libc::AF_NETLINK as u16,
79            _pad: 0,
80            pid: 0,
81            groups: groups,
82        };
83        Ok(Socket {
84            local: local_addr,
85            peer: peer_addr,
86            socket: socket,
87            sequence_next: 1,
88            page_size: page_size,
89            receive_buffer: vec![0u8; page_size],
90            send_buffer: vec![0u8; page_size],
91            sent: HashMap::new(),
92        })
93    }
94
95    /// Subscribe to the multi-cast group provided
96    pub fn multicast_group_subscribe(&mut self, group: u32) -> Result<()> {
97        system::set_socket_option(
98            self.socket,
99            libc::SOL_NETLINK,
100            NETLINK_ADD_MEMBERSHIP,
101            group,
102        )?;
103        Ok(())
104    }
105
106    fn message_header(&mut self, iov: &mut [libc::iovec]) -> libc::msghdr {
107        let addr_ptr = &mut self.peer as *mut system::Address;
108        #[cfg(not(target_env = "musl"))]
109        let hdr = {
110            let iov_len = iov.len();
111            let hdr = libc::msghdr {
112                msg_iovlen: iov_len,
113                msg_iov: iov.as_mut_ptr(),
114                msg_namelen: size_of::<system::Address>() as u32,
115                msg_name: addr_ptr as *mut libc::c_void,
116                msg_flags: 0,
117                msg_controllen: 0,
118                msg_control: 0 as *mut libc::c_void,
119            };
120            hdr
121        };
122        #[cfg(target_env = "musl")]
123        let hdr = {
124            let iov_len = iov.len() as libc::c_int;
125            let mut hdr: libc::msghdr = unsafe { mem::zeroed() };
126            hdr.msg_iovlen = iov_len;
127            hdr.msg_iov = iov.as_mut_ptr();
128            hdr.msg_namelen = size_of::<system::Address>() as u32;
129            hdr.msg_name = addr_ptr as *mut libc::c_void;
130            hdr.msg_flags = 0;
131            hdr.msg_controllen = 0;
132            hdr.msg_control = 0 as *mut libc::c_void;
133            hdr
134        };
135        hdr
136    }
137
138    /// Send the provided package on the socket
139    pub fn send_message<S: SendMessage>(&mut self, payload: &S) -> Result<usize> {
140        let hdr_size = size_of::<Header>();
141        let flags = payload.query_flags();
142        let payload_size = payload.pack(&mut self.send_buffer[hdr_size..])?;
143        let size = hdr_size + payload_size;
144        let hdr = Header {
145            length: size as u32,
146            identifier: payload.message_type(),
147            flags: flags.bits(),
148            sequence: self.sequence_next,
149            pid: self.local.pid,
150        };
151        let _slice = hdr.pack(&mut self.send_buffer[..hdr_size])?;
152
153        self.sent
154            .insert(self.sequence_next, MessageMode::from(flags));
155        self.sequence_next += 1;
156
157        let sent_size = system::send(self.socket, &self.send_buffer[..size], 0)?;
158        Ok(sent_size)
159    }
160
161    fn receive_bytes(&mut self) -> Result<usize> {
162        let mut iov = [libc::iovec {
163            iov_base: self.receive_buffer.as_mut_ptr() as *mut libc::c_void,
164            iov_len: self.page_size,
165        }];
166
167        let mut msg_header = self.message_header(&mut iov);
168        let result = system::receive_message(self.socket, &mut msg_header);
169        match result {
170            Err(err) => {
171                if err.raw_os_error() == Some(libc::EAGAIN) {
172                    return Ok(0);
173                }
174                Err(err.into())
175            }
176            Ok(bytes) => Ok(bytes),
177        }
178    }
179
180    /// Receive binary data on the socket
181    pub fn receive(&mut self) -> Result<Vec<u8>> {
182        let bytes = self.receive_bytes()?;
183        Ok(self.receive_buffer[0..bytes].to_vec())
184    }
185
186    /// Receive Messages pending on the socket
187    pub fn receive_messages(&mut self) -> Result<Messages> {
188        let mut more_messages = true;
189        let mut result_messages = Vec::new();
190        while more_messages {
191            match self.receive_bytes() {
192                Err(err) => {
193                    return Err(err);
194                }
195                Ok(bytes) => {
196                    if bytes == 0 {
197                        break;
198                    }
199                    more_messages = self.unpack_data(bytes, &mut result_messages)?;
200                }
201            }
202        }
203        Ok(result_messages)
204    }
205
206    fn check_sequence(&self, sequence: &u32) -> bool {
207        if *sequence == 0 {
208            return true;
209        }
210        self.sent.contains_key(sequence)
211    }
212
213    fn expect_more(&self, sequence: &u32) -> bool {
214        if *sequence == 0 {
215            return false;
216        }
217        assert!(self.sent.contains_key(sequence));
218        if let Some(f) = self.sent.get(sequence) {
219            return *f != MessageMode::None;
220        }
221        false
222    }
223
224    fn unpack_data(&mut self, bytes: usize, messages: &mut Messages) -> Result<bool> {
225        let mut more_messages = false;
226        let data = &self.receive_buffer[..bytes];
227        let mut pos = 0;
228        while pos < bytes {
229            let (used, header) = Header::unpack_with_size(&data[pos..])?;
230
231            pos = pos + used;
232            if !header.check_pid(self.local.pid) {
233                return Err(NetlinkError::new(NetlinkErrorKind::InvalidValue).into());
234            }
235            if !self.check_sequence(&header.sequence) {
236                return Err(NetlinkError::new(NetlinkErrorKind::InvalidValue).into());
237            }
238            let sequence = header.sequence;
239            if header.identifier == NLMSG_NOOP {
240                continue;
241            } else if header.identifier == NLMSG_ERROR {
242                self.sent.remove(&sequence);
243                let (used, emsg) = ErrorMessage::unpack(&data[pos..], header)?;
244                pos = pos + used;
245                if emsg.code != 0 {
246                    return Err(io::Error::from_raw_os_error(-emsg.code).into());
247                } else {
248                    more_messages = false;
249                }
250            } else if header.identifier == NLMSG_DONE {
251                self.sent.remove(&sequence);
252                more_messages = false;
253                pos = pos + header.aligned_data_length();
254            } else {
255                let flags = MessageFlags::from_bits(header.flags).unwrap_or(MessageFlags::empty());
256                more_messages =
257                    flags.contains(MessageFlags::MULTIPART) || self.expect_more(&sequence);
258                let (used, msg) = Message::unpack(&data[pos..], header)?;
259                pos = pos + used;
260                messages.push(msg);
261            }
262        }
263        return Ok(more_messages);
264    }
265}
266
267impl AsRawFd for Socket {
268    fn as_raw_fd(&self) -> RawFd {
269        self.socket
270    }
271}