nftnl_rs/netlink/
socket.rs

1/*-
2 * nftnl-rs - a netlink NFtables firewall.
3 * Copyright (C) 2020 Aleksandr Morozov
4 * 
5 * This Source Code Form is subject to the terms of the Mozilla Public
6 * License, v. 2.0. If a copy of the MPL was not distributed with this
7 *  file, You can obtain one at https://mozilla.org/MPL/2.0/.
8 */
9
10use std::os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd};
11
12use crate::{error::NtflRes, int_error_code};
13
14#[allow(non_camel_case_types)]
15#[repr(C)]
16#[derive(Default)]
17pub struct sockaddr_nl 
18{
19    pub nl_family: libc::sa_family_t,
20    pub nl_pad: libc::c_ushort,
21    pub nl_pid: u32,
22    pub nl_groups: u32,
23}
24
25
26pub struct MnlSocket 
27{
28    /// A raw fd of the netlink socket.
29	fd: OwnedFd,
30
31    /// A socket bind information.
32	addr: sockaddr_nl,
33}
34
35
36impl MnlSocket
37{
38    unsafe 
39    fn mnl_socket_open_intern(bus: i32, flags: libc::c_int) -> NtflRes<Self>
40    {
41        let fd = 
42            unsafe { libc::socket(libc::AF_NETLINK, libc::SOCK_RAW | flags, bus) };
43
44        if fd == -1
45        {
46            int_error_code!(libc::EINVAL, "can not open socket, error: '{}'", std::io::Error::last_os_error());
47        }
48
49        let ret = 
50            Self
51            {
52                fd: unsafe { OwnedFd::from_raw_fd(fd) },
53                addr: sockaddr_nl::default()
54            };
55
56        return Ok(ret);
57    }
58
59    
60    /// mnl_socket_open - open a netlink socket
61    /// 
62    /// # Arguments 
63    /// * `bus` - bus the netlink socket bus ID (see NETLINK_* constants)
64    ///
65    /// # Returns
66    /// 
67    /// A [Result] is retuened which has alias [NtflRes].
68    /// 
69    /// * [Result::Ok] is returned with valid instance.
70    /// 
71    /// * [Result::Err] is returned with error description.
72    pub 
73    fn mnl_socket_open(bus: i32) -> NtflRes<Self>
74    {
75        return unsafe{ Self::mnl_socket_open_intern(bus, 0) };
76    }
77
78    pub
79    fn mnl_socket_bind(&mut self, groups: u32, pid: libc::pid_t) -> NtflRes<()>
80    {
81        //int ret;
82        //socklen_t addr_len;
83
84        self.addr.nl_family = libc::AF_NETLINK as u16;
85        self.addr.nl_groups = groups;
86        self.addr.nl_pid = pid as u32;
87
88       // let mut addr_len = std::mem::size_of_val(&self.addr) as libc::socklen_t;
89       let mut addr_len = std::mem::size_of::<sockaddr_nl>() as libc::socklen_t;
90
91
92        let ret = 
93            unsafe
94            {
95                libc::bind(
96                    self.fd.as_raw_fd(), 
97                    (&self.addr) as *const _ as *const libc::sockaddr, 
98                    addr_len
99                )
100            };
101
102        if ret < 0
103        {
104            int_error_code!(libc::EINVAL, "bind error, '{}'", std::io::Error::last_os_error());
105        }
106        
107        let ret = 
108            unsafe
109            {
110                libc::getsockname(self.fd.as_raw_fd(), &mut self.addr as *mut _ as *mut libc::sockaddr, &mut addr_len as *mut _ as *mut u32)
111            };
112
113        if ret < 0
114        {
115            int_error_code!(libc::EINVAL, "getsockname error, '{}'", std::io::Error::last_os_error());
116        }
117
118        if addr_len != std::mem::size_of_val(&self.addr) as u32
119        {
120            int_error_code!(libc::EINVAL, "size mismatch {} != {}", addr_len, std::mem::size_of_val(&self.addr));
121        }
122
123        if self.addr.nl_family as i32 != libc::AF_NETLINK 
124        {
125            int_error_code!(libc::EINVAL, "wring addr family {} != {}", self.addr.nl_family, libc::AF_NETLINK);
126        }
127
128        return Ok(());
129    }
130
131    /// mnl_socket_get_portid - obtain Netlink PortID from netlink socket
132    /// 
133    /// Copypaste from libmnl:
134    /// This function returns the Netlink PortID of a given netlink socket.
135    /// It's a common mistake to assume that this PortID equals the process ID
136    /// which is not always true. This is the case if you open more than one
137    /// socket that is binded to the same Netlink subsystem from the same process.
138    pub 
139    fn mnl_socket_get_portid(&self) -> u32
140    {
141        return self.addr.nl_pid;
142    }
143
144    
145    /// mnl_socket_sendto - send a netlink message of a certain size
146    /// 
147    /// # Arguments
148    /// 
149    /// * `buf` - buffer containing the netlink message to be sent
150    /// * `len` - number of bytes in the buffer that you want to send
151    ///
152    /// # Result
153    /// 
154    /// A [Result] is retuened which has alias [NtflRes].
155    /// 
156    /// * [Result::Ok] is returned a number of written bytes
157    /// 
158    /// * [Result::Err] is returned with error description.
159    pub 
160    fn mnl_socket_sendto(&self, buf: &[u8], len: usize) -> NtflRes<isize>
161    {
162        let snl: sockaddr_nl = 
163            sockaddr_nl
164            {
165                nl_family: libc::AF_NETLINK as u16,
166                ..Default::default()
167            };
168
169        let res = 
170            unsafe
171            {
172                libc::sendto(self.fd.as_raw_fd(), buf.as_ptr() as *const libc::c_void, 
173                    len as libc::size_t, 0, &snl as *const _ as *const libc::sockaddr, 
174                    std::mem::size_of_val(&snl) as u32)
175            };
176
177        if res < 0
178        {
179            int_error_code!(libc::EINVAL, "sendto() error, '{}'", std::io::Error::last_os_error());
180        }
181
182        return Ok(res);
183    }
184
185    /**
186     * mnl_socket_recvfrom - receive a netlink message
187     * \param nl netlink socket obtained via mnl_socket_open()
188     * \param buf buffer that you want to use to store the netlink message
189     * \param bufsiz size of the buffer passed to store the netlink message
190     *
191     * On error, it returns -1 and errno is appropriately set. If errno is set
192     * to ENOSPC, it means that the buffer that you have passed to store the
193     * netlink message is too small, so you have received a truncated message.
194     * To avoid this, you have to allocate a buffer of MNL_SOCKET_BUFFER_SIZE
195     * (which is 8KB, see linux/netlink.h for more information). Using this
196     * buffer size ensures that your buffer is big enough to store the netlink
197     * message without truncating it.
198     */
199    pub 
200    fn mnl_socket_recvfrom(&self, bufsiz: usize) -> NtflRes<Vec<u8>>
201    {
202        let mut buf: Vec<u8> = vec![0_u8; bufsiz];
203
204        //ssize_t ret;
205        let mut addr: sockaddr_nl = unsafe{ std::mem::zeroed() };
206
207        let mut iov: libc::iovec = 
208            libc::iovec
209            {
210                iov_base: buf.as_mut_ptr() as *mut libc::c_void,
211                iov_len: bufsiz,
212            };
213
214        let mut msg: libc::msghdr = 
215            libc::msghdr
216            {
217                msg_name: &mut addr as *mut _ as *mut libc::c_void,
218                msg_namelen: std::mem::size_of::<sockaddr_nl>() as u32,
219                msg_iov: &mut iov as *mut _ as *mut libc::iovec,
220                msg_iovlen: 1,
221                msg_control: std::ptr::null_mut(),
222                msg_controllen: 0,
223                msg_flags: 0,
224            };
225
226        let ret = 
227            unsafe 
228            {
229                libc::recvmsg(self.fd.as_raw_fd(), &mut msg as *mut _ as *mut libc::msghdr, 0)
230            };
231        
232        if ret == -1
233        {
234            int_error_code!(libc::EINVAL, "recvmsg() error, {}", std::io::Error::last_os_error());
235        }
236
237        if msg.msg_flags & libc::MSG_TRUNC > 0
238        {
239            int_error_code!(libc::EINVAL, "recvmsg() error, {}", std::io::Error::from_raw_os_error(libc::ENOSPC));
240        }
241
242        if msg.msg_namelen != std::mem::size_of::<sockaddr_nl>() as u32 
243        {
244            int_error_code!(libc::EINVAL, "recvmsg() error, {}", std::io::Error::from_raw_os_error(libc::EINVAL));
245        }
246
247        unsafe { buf.set_len(ret as usize) };
248
249        return Ok(buf);
250    }
251
252}
253
254/// Obtains the pid of the current process.
255#[inline]
256pub 
257fn mnl_socket_get_pid() -> libc::pid_t
258{
259    return unsafe{ libc::getpid() };
260}
261
262
263/// Helps to initialize and bind the socket.
264/// 
265/// # Arguments 
266/// 
267/// * `bus_opt` - an optional bus value. If not set, then [crate::NETLINK_NETFILTER]
268///     will be used.
269/// 
270/// * `groups` - a groups number, usually 0.
271/// 
272/// * `pid_opt` - an optional pid value. If not set, then [crate::netlink::MNL_SOCKET_AUTOPID]
273///     will be used.
274/// 
275/// # Returns 
276/// 
277/// A [Result] as type [NtflRes] is returned.
278/// 
279/// * [Result::Ok] - with the instance [MnlSocket] is returned.
280/// 
281/// * [Result::Err] - with the error description.
282pub 
283fn mnl_socket_helper(bus_opt: Option<i32>, groups: u32, pid_opt: Option<libc::pid_t>) -> NtflRes<MnlSocket>
284{
285    let mut nl = 
286        MnlSocket::mnl_socket_open(bus_opt.map_or(crate::NETLINK_NETFILTER, |f| f) )?;
287
288    let pid = pid_opt.map_or(crate::netlink::MNL_SOCKET_AUTOPID, |v| v);
289
290    nl.mnl_socket_bind(groups, pid).unwrap();
291
292    return Ok(nl);
293}