netlink_rust/core/
socket.rs1use std::collections::HashMap;
2use std::io;
3use std::mem::size_of;
4use std::os::unix::io::{AsRawFd, RawFd};
5
6use libc;
7
8use crate::errors::{NetlinkError, NetlinkErrorKind, Result};
9
10use crate::core::message::{
11 netlink_align, ErrorMessage, Header, Message, MessageFlags, MessageMode, Messages,
12};
13use crate::core::pack::{NativePack, NativeUnpack};
14use crate::core::system;
15use crate::core::Protocol;
16
17pub trait SendMessage {
19 fn pack(&self, data: &mut [u8]) -> Result<usize>;
21 fn message_type(&self) -> u16;
23 fn query_flags(&self) -> MessageFlags;
25}
26
/// Control message type: no-op; the receive loop skips it.
const NLMSG_NOOP: u16 = 0x1;
/// Control message type: error report. An error code of 0 is an acknowledgement.
const NLMSG_ERROR: u16 = 0x2;
/// Control message type: terminates a sequence of replies.
const NLMSG_DONE: u16 = 0x3;
/// `setsockopt` option name (level `SOL_NETLINK`) used to join a multicast group.
const NETLINK_ADD_MEMBERSHIP: i32 = 0x1;
33pub struct Socket {
47 local: system::Address,
48 peer: system::Address,
49 socket: RawFd,
50 sequence_next: u32,
51 page_size: usize,
52 receive_buffer: Vec<u8>,
53 send_buffer: Vec<u8>,
54 sent: HashMap<u32, MessageMode>,
55}
56
57impl Socket {
58 pub fn new(protocol: Protocol) -> Result<Socket> {
60 Socket::new_multicast(protocol, 0)
61 }
62
63 pub fn new_multicast(protocol: Protocol, groups: u32) -> Result<Socket> {
65 let socket = system::netlink_socket(protocol as i32)?;
66 system::set_socket_option(socket, libc::SOL_SOCKET, libc::SO_SNDBUF, 32768)?;
67 system::set_socket_option(socket, libc::SOL_SOCKET, libc::SO_RCVBUF, 32768)?;
68 let mut local_addr = system::Address {
69 family: libc::AF_NETLINK as u16,
70 _pad: 0,
71 pid: 0,
72 groups: groups,
73 };
74 system::bind(socket, &mut local_addr)?;
75 system::get_socket_address(socket, &mut local_addr)?;
76 let page_size = netlink_align(system::get_page_size());
77 let peer_addr = system::Address {
78 family: libc::AF_NETLINK as u16,
79 _pad: 0,
80 pid: 0,
81 groups: groups,
82 };
83 Ok(Socket {
84 local: local_addr,
85 peer: peer_addr,
86 socket: socket,
87 sequence_next: 1,
88 page_size: page_size,
89 receive_buffer: vec![0u8; page_size],
90 send_buffer: vec![0u8; page_size],
91 sent: HashMap::new(),
92 })
93 }
94
95 pub fn multicast_group_subscribe(&mut self, group: u32) -> Result<()> {
97 system::set_socket_option(
98 self.socket,
99 libc::SOL_NETLINK,
100 NETLINK_ADD_MEMBERSHIP,
101 group,
102 )?;
103 Ok(())
104 }
105
106 fn message_header(&mut self, iov: &mut [libc::iovec]) -> libc::msghdr {
107 let addr_ptr = &mut self.peer as *mut system::Address;
108 #[cfg(not(target_env = "musl"))]
109 let hdr = {
110 let iov_len = iov.len();
111 let hdr = libc::msghdr {
112 msg_iovlen: iov_len,
113 msg_iov: iov.as_mut_ptr(),
114 msg_namelen: size_of::<system::Address>() as u32,
115 msg_name: addr_ptr as *mut libc::c_void,
116 msg_flags: 0,
117 msg_controllen: 0,
118 msg_control: 0 as *mut libc::c_void,
119 };
120 hdr
121 };
122 #[cfg(target_env = "musl")]
123 let hdr = {
124 let iov_len = iov.len() as libc::c_int;
125 let mut hdr: libc::msghdr = unsafe { mem::zeroed() };
126 hdr.msg_iovlen = iov_len;
127 hdr.msg_iov = iov.as_mut_ptr();
128 hdr.msg_namelen = size_of::<system::Address>() as u32;
129 hdr.msg_name = addr_ptr as *mut libc::c_void;
130 hdr.msg_flags = 0;
131 hdr.msg_controllen = 0;
132 hdr.msg_control = 0 as *mut libc::c_void;
133 hdr
134 };
135 hdr
136 }
137
138 pub fn send_message<S: SendMessage>(&mut self, payload: &S) -> Result<usize> {
140 let hdr_size = size_of::<Header>();
141 let flags = payload.query_flags();
142 let payload_size = payload.pack(&mut self.send_buffer[hdr_size..])?;
143 let size = hdr_size + payload_size;
144 let hdr = Header {
145 length: size as u32,
146 identifier: payload.message_type(),
147 flags: flags.bits(),
148 sequence: self.sequence_next,
149 pid: self.local.pid,
150 };
151 let _slice = hdr.pack(&mut self.send_buffer[..hdr_size])?;
152
153 self.sent
154 .insert(self.sequence_next, MessageMode::from(flags));
155 self.sequence_next += 1;
156
157 let sent_size = system::send(self.socket, &self.send_buffer[..size], 0)?;
158 Ok(sent_size)
159 }
160
161 fn receive_bytes(&mut self) -> Result<usize> {
162 let mut iov = [libc::iovec {
163 iov_base: self.receive_buffer.as_mut_ptr() as *mut libc::c_void,
164 iov_len: self.page_size,
165 }];
166
167 let mut msg_header = self.message_header(&mut iov);
168 let result = system::receive_message(self.socket, &mut msg_header);
169 match result {
170 Err(err) => {
171 if err.raw_os_error() == Some(libc::EAGAIN) {
172 return Ok(0);
173 }
174 Err(err.into())
175 }
176 Ok(bytes) => Ok(bytes),
177 }
178 }
179
180 pub fn receive(&mut self) -> Result<Vec<u8>> {
182 let bytes = self.receive_bytes()?;
183 Ok(self.receive_buffer[0..bytes].to_vec())
184 }
185
186 pub fn receive_messages(&mut self) -> Result<Messages> {
188 let mut more_messages = true;
189 let mut result_messages = Vec::new();
190 while more_messages {
191 match self.receive_bytes() {
192 Err(err) => {
193 return Err(err);
194 }
195 Ok(bytes) => {
196 if bytes == 0 {
197 break;
198 }
199 more_messages = self.unpack_data(bytes, &mut result_messages)?;
200 }
201 }
202 }
203 Ok(result_messages)
204 }
205
206 fn check_sequence(&self, sequence: &u32) -> bool {
207 if *sequence == 0 {
208 return true;
209 }
210 self.sent.contains_key(sequence)
211 }
212
213 fn expect_more(&self, sequence: &u32) -> bool {
214 if *sequence == 0 {
215 return false;
216 }
217 assert!(self.sent.contains_key(sequence));
218 if let Some(f) = self.sent.get(sequence) {
219 return *f != MessageMode::None;
220 }
221 false
222 }
223
224 fn unpack_data(&mut self, bytes: usize, messages: &mut Messages) -> Result<bool> {
225 let mut more_messages = false;
226 let data = &self.receive_buffer[..bytes];
227 let mut pos = 0;
228 while pos < bytes {
229 let (used, header) = Header::unpack_with_size(&data[pos..])?;
230
231 pos = pos + used;
232 if !header.check_pid(self.local.pid) {
233 return Err(NetlinkError::new(NetlinkErrorKind::InvalidValue).into());
234 }
235 if !self.check_sequence(&header.sequence) {
236 return Err(NetlinkError::new(NetlinkErrorKind::InvalidValue).into());
237 }
238 let sequence = header.sequence;
239 if header.identifier == NLMSG_NOOP {
240 continue;
241 } else if header.identifier == NLMSG_ERROR {
242 self.sent.remove(&sequence);
243 let (used, emsg) = ErrorMessage::unpack(&data[pos..], header)?;
244 pos = pos + used;
245 if emsg.code != 0 {
246 return Err(io::Error::from_raw_os_error(-emsg.code).into());
247 } else {
248 more_messages = false;
249 }
250 } else if header.identifier == NLMSG_DONE {
251 self.sent.remove(&sequence);
252 more_messages = false;
253 pos = pos + header.aligned_data_length();
254 } else {
255 let flags = MessageFlags::from_bits(header.flags).unwrap_or(MessageFlags::empty());
256 more_messages =
257 flags.contains(MessageFlags::MULTIPART) || self.expect_more(&sequence);
258 let (used, msg) = Message::unpack(&data[pos..], header)?;
259 pos = pos + used;
260 messages.push(msg);
261 }
262 }
263 return Ok(more_messages);
264 }
265}
266
267impl AsRawFd for Socket {
268 fn as_raw_fd(&self) -> RawFd {
269 self.socket
270 }
271}