1use libc;
3use std::io::{Error, Result};
4use std::mem;
5use std::os::unix::io::{AsRawFd, RawFd};
6
7use super::Protocol;
8
/// An owned netlink socket file descriptor.
///
/// Deriving `Debug` lets the socket appear in diagnostics and `assert_eq!` failures; the
/// derived output simply shows the raw descriptor number.
#[derive(Debug)]
pub struct Socket(RawFd);
10
11impl AsRawFd for Socket {
12 fn as_raw_fd(&self) -> RawFd {
13 self.0
14 }
15}
16
17impl Drop for Socket {
18 fn drop(&mut self) {
19 unsafe { libc::close(self.0) };
20 }
21}
22
23#[derive(Copy, Clone)]
24pub struct SocketAddr(libc::sockaddr_nl);
25
26impl SocketAddr {
27 pub fn new(port_number: u32, multicast_groups: u32) -> Self {
28 let mut addr: libc::sockaddr_nl = unsafe { mem::zeroed() };
29 addr.nl_family = libc::PF_NETLINK as libc::sa_family_t;
30 addr.nl_pid = port_number;
31 addr.nl_groups = multicast_groups;
32 SocketAddr(addr)
33 }
34
35 pub fn port_number(&self) -> u32 {
36 self.0.nl_pid
37 }
38
39 pub fn multicast_groups(&self) -> u32 {
40 self.0.nl_groups
41 }
42
43 fn as_raw(&self) -> (*const libc::sockaddr, libc::socklen_t) {
44 let addr_ptr = &self.0 as *const libc::sockaddr_nl as *const libc::sockaddr;
45 let addr_len = mem::size_of::<libc::sockaddr_nl>() as libc::socklen_t;
62 (addr_ptr, addr_len)
63 }
64
65 fn as_raw_mut(&mut self) -> (*mut libc::sockaddr, libc::socklen_t) {
66 let addr_ptr = &mut self.0 as *mut libc::sockaddr_nl as *mut libc::sockaddr;
67 let addr_len = mem::size_of::<libc::sockaddr_nl>() as libc::socklen_t;
68 (addr_ptr, addr_len)
69 }
70}
71
72impl Socket {
73 pub fn new(protocol: Protocol) -> Result<Self> {
74 let res =
75 unsafe { libc::socket(libc::PF_NETLINK, libc::SOCK_DGRAM, protocol as libc::c_int) };
76 if res < 0 {
77 return Err(Error::last_os_error());
78 }
79 Ok(Socket(res))
80 }
81
82 pub fn bind(&mut self, addr: &SocketAddr) -> Result<()> {
83 let (addr_ptr, addr_len) = addr.as_raw();
84 let res = unsafe { libc::bind(self.0, addr_ptr, addr_len) };
85 if res < 0 {
86 return Err(Error::last_os_error());
87 }
88 Ok(())
89 }
90
91 pub fn bind_auto(&mut self) -> Result<SocketAddr> {
92 let mut addr = SocketAddr::new(0, 0);
93 self.bind(&addr)?;
94 self.get_address(&mut addr)?;
95 Ok(addr)
96 }
97
98 pub fn get_address(&self, addr: &mut SocketAddr) -> Result<()> {
99 let (addr_ptr, mut addr_len) = addr.as_raw_mut();
100 let addr_len_copy = addr_len;
101 let addr_len_ptr = &mut addr_len as *mut libc::socklen_t;
102 let res = unsafe { libc::getsockname(self.0, addr_ptr, addr_len_ptr) };
103 if res < 0 {
104 return Err(Error::last_os_error());
105 }
106 assert_eq!(addr_len, addr_len_copy);
107 Ok(())
108 }
109
110 pub fn set_non_blocking(&self, non_blocking: bool) -> Result<()> {
111 let mut non_blocking = non_blocking as libc::c_int;
112 let res = unsafe { libc::ioctl(self.0, libc::FIONBIO, &mut non_blocking) };
113 if res < 0 {
114 return Err(Error::last_os_error());
115 }
116 Ok(())
117 }
118
119 pub fn connect(&self, remote_addr: &SocketAddr) -> Result<()> {
120 let (addr, addr_len) = remote_addr.as_raw();
148 let res = unsafe { libc::connect(self.0, addr, addr_len) };
149 if res < 0 {
150 return Err(Error::last_os_error());
151 }
152 Ok(())
153 }
154
155 pub fn recv_from(&self, buf: &mut [u8], flags: libc::c_int) -> Result<(usize, SocketAddr)> {
158 let mut addr = unsafe { mem::zeroed::<libc::sockaddr_nl>() };
162
163 let addr_ptr = &mut addr as *mut libc::sockaddr_nl as *mut libc::sockaddr;
174
175 let mut addrlen = mem::size_of_val(&addr);
181 let addrlen_ptr = &mut addrlen as *mut usize as *mut libc::socklen_t;
184
185 let buf_ptr = buf.as_mut_ptr() as *mut libc::c_void;
195 let buf_len = buf.len() as libc::size_t;
196
197 let res = unsafe { libc::recvfrom(self.0, buf_ptr, buf_len, flags, addr_ptr, addrlen_ptr) };
198 if res < 0 {
199 return Err(Error::last_os_error());
200 }
201 Ok((res as usize, SocketAddr(addr)))
202 }
203
204 pub fn recv(&self, buf: &mut [u8], flags: libc::c_int) -> Result<usize> {
205 let buf_ptr = buf.as_mut_ptr() as *mut libc::c_void;
206 let buf_len = buf.len() as libc::size_t;
207
208 let res = unsafe { libc::recv(self.0, buf_ptr, buf_len, flags) };
209 if res < 0 {
210 return Err(Error::last_os_error());
211 }
212 Ok(res as usize)
213 }
214
215 pub fn send_to(&self, buf: &[u8], addr: &SocketAddr, flags: libc::c_int) -> Result<usize> {
216 let (addr_ptr, addr_len) = addr.as_raw();
217 let buf_ptr = buf.as_ptr() as *const libc::c_void;
218 let buf_len = buf.len() as libc::size_t;
219
220 let res = unsafe { libc::sendto(self.0, buf_ptr, buf_len, flags, addr_ptr, addr_len) };
221 if res < 0 {
222 return Err(Error::last_os_error());
223 }
224 Ok(res as usize)
225 }
226
227 pub fn send(&self, buf: &[u8], flags: libc::c_int) -> Result<usize> {
228 let buf_ptr = buf.as_ptr() as *const libc::c_void;
229 let buf_len = buf.len() as libc::size_t;
230
231 let res = unsafe { libc::send(self.0, buf_ptr, buf_len, flags) };
232 if res < 0 {
233 return Err(Error::last_os_error());
234 }
235 Ok(res as usize)
236 }
237
238 pub fn set_pktinfo(&mut self, set: bool) -> Result<()> {
239 setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_PKTINFO, set)
240 }
241
242 pub fn get_pktinfo(&self) -> Result<bool> {
243 getsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_PKTINFO)
244 }
245
246 pub fn add_membership(&mut self, group: u32) -> Result<()> {
247 setsockopt(
248 self.0,
249 libc::SOL_NETLINK,
250 libc::NETLINK_ADD_MEMBERSHIP,
251 group,
252 )
253 }
254
255 pub fn drop_membership(&mut self, group: u32) -> Result<()> {
256 setsockopt(
257 self.0,
258 libc::SOL_NETLINK,
259 libc::NETLINK_DROP_MEMBERSHIP,
260 group,
261 )
262 }
263
264 pub fn list_membership(&self) -> Vec<u32> {
265 unimplemented!();
266 }
270
271 pub fn set_broadcast_error(&mut self, set: bool) -> Result<()> {
272 setsockopt(
273 self.0,
274 libc::SOL_NETLINK,
275 libc::NETLINK_BROADCAST_ERROR,
276 set,
277 )
278 }
279
280 pub fn get_broadcast_error(&self) -> Result<bool> {
281 getsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_BROADCAST_ERROR)
282 }
283
284 pub fn set_no_enobufs(&mut self, set: bool) -> Result<()> {
285 setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_NO_ENOBUFS, set)
286 }
287
288 pub fn get_no_enobufs(&self) -> Result<bool> {
289 getsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_NO_ENOBUFS)
290 }
291
292 pub fn set_listen_all_namespaces(&mut self, set: bool) -> Result<()> {
293 setsockopt(
294 self.0,
295 libc::SOL_NETLINK,
296 libc::NETLINK_LISTEN_ALL_NSID,
297 set,
298 )
299 }
300
301 pub fn get_listen_all_namespaces(&self) -> Result<bool> {
302 getsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_LISTEN_ALL_NSID)
303 }
304
305 pub fn set_cap_ack(&mut self, set: bool) -> Result<()> {
306 setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_CAP_ACK, set)
307 }
308
309 pub fn get_cap_ack(&self) -> Result<bool> {
310 getsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_CAP_ACK)
311 }
312}
313
314fn getsockopt<T: Copy>(fd: RawFd, opt: libc::c_int, val: libc::c_int) -> Result<T> {
322 unsafe {
323 let mut slot: T = mem::zeroed();
325
326 let slot_ptr = &mut slot as *mut T as *mut libc::c_void;
328
329 let mut slot_len = mem::size_of::<T>() as libc::socklen_t;
331
332 let slot_len_ptr = &mut slot_len as *mut libc::socklen_t;
336
337 let res = libc::getsockopt(fd, opt, val, slot_ptr, slot_len_ptr);
338 if res < 0 {
339 return Err(Error::last_os_error());
340 }
341
342 assert_eq!(slot_len as usize, mem::size_of::<T>());
345
346 Ok(slot)
347 }
348}
349
350fn setsockopt<T>(fd: RawFd, opt: libc::c_int, val: libc::c_int, payload: T) -> Result<()> {
352 unsafe {
353 let payload = &payload as *const T as *const libc::c_void;
354 let payload_len = mem::size_of::<T>() as libc::socklen_t;
355
356 let res = libc::setsockopt(fd, opt, val, payload, payload_len);
357 if res < 0 {
358 return Err(Error::last_os_error());
359 }
360 }
361 Ok(())
362}
363
#[cfg(test)]
mod test {
    use super::*;

    /// Opens a fresh `NETLINK_ROUTE` socket, failing the test if that is not possible.
    fn route_socket() -> Socket {
        Socket::new(Protocol::Route).unwrap()
    }

    #[test]
    fn new() {
        route_socket();
    }

    #[test]
    fn connect() {
        route_socket().connect(&SocketAddr::new(0, 0)).unwrap();
    }

    #[test]
    fn bind() {
        route_socket().bind(&SocketAddr::new(4321, 0)).unwrap();
    }

    #[test]
    fn bind_auto() {
        let addr = route_socket().bind_auto().unwrap();
        assert_ne!(addr.port_number(), 0);
    }

    #[test]
    fn set_non_blocking() {
        let sock = route_socket();
        for &flag in &[true, false] {
            sock.set_non_blocking(flag).unwrap();
        }
    }

    #[test]
    fn address() {
        let mut addr = SocketAddr::new(42, 1234);
        assert_eq!((addr.port_number(), addr.multicast_groups()), (42, 1234));

        let (const_ptr, _) = addr.as_raw();
        // SAFETY: `as_raw` points at the `sockaddr_nl` stored inside `addr`.
        let snapshot = unsafe { *(const_ptr as *const libc::sockaddr_nl) };
        assert_eq!((snapshot.nl_pid, snapshot.nl_groups), (42, 1234));

        let (mut_ptr, _) = addr.as_raw_mut();
        // SAFETY: `as_raw_mut` points at the `sockaddr_nl` stored inside `addr`, which is not
        // otherwise borrowed while it is written through.
        unsafe {
            let inner = &mut *(mut_ptr as *mut libc::sockaddr_nl);
            inner.nl_pid = 24;
            inner.nl_groups = 4321;
        }
        assert_eq!((addr.port_number(), addr.multicast_groups()), (24, 4321));
    }
}