1use crate::errors::{NetlinkError, NetlinkErrorKind, Result};
2use bitflags::bitflags;
3use std::fmt;
4use std::mem::size_of;
5
6use crate::core::pack::{NativePack, NativeUnpack};
7
8bitflags! {
9 #[derive(Clone, Copy, PartialEq, PartialOrd)]
11 pub struct MessageFlags: u16 {
12 const REQUEST = 0x0001;
14 const MULTIPART = 0x0002;
16 const ACKNOWLEDGE = 0x0004;
18 const DUMP = 0x0100 | 0x0200;
20 }
21}
22
/// High-level request mode, convertible to and from [`MessageFlags`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageMode {
    /// Plain request: neither the acknowledge nor the dump bits are set.
    None,
    /// Request with the `ACKNOWLEDGE` bit set.
    Acknowledge,
    /// Request with the `DUMP` bits set.
    Dump,
}
35
36impl From<MessageFlags> for MessageMode {
37 fn from(value: MessageFlags) -> MessageMode {
38 if value.intersects(MessageFlags::DUMP) {
39 MessageMode::Dump
40 } else if value.intersects(MessageFlags::ACKNOWLEDGE) {
41 MessageMode::Acknowledge
42 } else {
43 MessageMode::None
44 }
45 }
46}
47
48impl From<MessageMode> for MessageFlags {
49 fn from(value: MessageMode) -> MessageFlags {
50 let flags = MessageFlags::REQUEST;
51 match value {
52 MessageMode::None => flags,
53 MessageMode::Acknowledge => flags | MessageFlags::ACKNOWLEDGE,
54 MessageMode::Dump => flags | MessageFlags::DUMP,
55 }
56 }
57}
58
/// Rounds `len` up to the next multiple of `align_to`.
///
/// `align_to` must be a non-zero power of two. The bit-mask computation used here
/// returns wrong results for any other alignment, so this is checked in debug builds.
#[inline]
pub(crate) fn align_to(len: usize, align_to: usize) -> usize {
    debug_assert!(
        align_to.is_power_of_two(),
        "alignment must be a non-zero power of two"
    );
    let mask = align_to - 1;
    // Adding `mask` directly (instead of `align_to` and then subtracting 1) avoids a
    // spurious overflow when `len` is already aligned and within `align_to` of
    // `usize::MAX`.
    (len + mask) & !mask
}
63
64#[inline]
65pub(crate) fn netlink_align(len: usize) -> usize {
66 align_to(len, 4usize)
67}
68
69#[inline]
70pub(crate) fn netlink_padding(len: usize) -> usize {
71 netlink_align(len) - len
72}
73
/// Fixed-size message header with a C-compatible (`#[repr(C)]`) layout.
///
/// Serialized as 16 bytes: `length` at offset 0, `identifier` at 4, `flags` at 6,
/// `sequence` at 8 and `pid` at 12.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Header {
    /// Declared total message length in bytes, including this header.
    pub length: u32,
    /// Message identifier (type) field.
    pub identifier: u16,
    /// Raw flag bits; `Header::flags()` returns the typed `MessageFlags` view.
    pub flags: u16,
    /// Sequence number.
    pub sequence: u32,
    /// Sender PID; `check_pid` treats 0 as matching any expected PID.
    pub pid: u32,
}
100
101impl Header {
102 const HEADER_SIZE: usize = 16;
103
104 pub fn length(&self) -> usize {
106 self.length as usize
107 }
108
109 pub fn data_length(&self) -> usize {
111 self.length() - size_of::<Header>()
112 }
113
114 pub fn padding(&self) -> usize {
116 netlink_padding(self.length())
117 }
118
119 pub fn aligned_length(&self) -> usize {
121 netlink_align(self.length())
122 }
123
124 pub fn aligned_data_length(&self) -> usize {
126 netlink_align(self.data_length())
127 }
128
129 pub fn check_pid(&self, pid: u32) -> bool {
131 self.pid == 0 || self.pid == pid
132 }
133
134 pub fn flags(&self) -> MessageFlags {
137 MessageFlags::from_bits_truncate(self.flags)
138 }
139}
140
141impl fmt::Display for Header {
142 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
143 write!(
144 f,
145 "Length: {0:08x} {0}\nIdentifier: {1:04x}\nFlags: {2:04x}\n\
146 Sequence: {3:08x} {3}\nPID: {4:08x} {4}",
147 self.length, self.identifier, self.flags, self.sequence, self.pid,
148 )
149 }
150}
151
152impl NativePack for Header {
153 fn pack_size(&self) -> usize {
154 Self::HEADER_SIZE
155 }
156 fn pack_unchecked(&self, buffer: &mut [u8]) {
157 self.length.pack_unchecked(buffer);
158 self.identifier.pack_unchecked(&mut buffer[4..]);
159 self.flags.pack_unchecked(&mut buffer[6..]);
160 self.sequence.pack_unchecked(&mut buffer[8..]);
161 self.pid.pack_unchecked(&mut buffer[12..]);
162 }
163}
164
165impl NativeUnpack for Header {
166 fn unpack_unchecked(buffer: &[u8]) -> Self {
167 let length = u32::unpack_unchecked(&buffer[..]);
168 let identifier = u16::unpack_unchecked(&buffer[4..]);
169 let flags = u16::unpack_unchecked(&buffer[6..]);
170 let sequence = u32::unpack_unchecked(&buffer[8..]);
171 let pid = u32::unpack_unchecked(&buffer[12..]);
172 Header {
173 length: length,
174 identifier: identifier,
175 flags: flags,
176 sequence: sequence,
177 pid: pid,
178 }
179 }
180}
181
/// Decoded error message: its own header, a status code, and the header of the
/// message it refers to.
pub(crate) struct ErrorMessage {
    /// Header of the error message itself.
    pub header: Header,
    /// Signed status code read from the first four bytes of the payload.
    pub code: i32,
    /// Header of the original message, embedded right after `code`.
    pub original_header: Header,
}
198
199impl ErrorMessage {
200 pub fn unpack(data: &[u8], header: Header) -> Result<(usize, ErrorMessage)> {
201 let size = 4 + Header::HEADER_SIZE;
202 if data.len() < size {
203 return Err(NetlinkError::new(NetlinkErrorKind::NotEnoughData).into());
204 }
205 let code = i32::unpack_unchecked(data);
206 let (_, original) = Header::unpack_with_size(&data[4..])?;
207 Ok((
208 size,
209 ErrorMessage {
210 header: header,
211 code: code,
212 original_header: original,
213 },
214 ))
215 }
216}
217
/// A message: its parsed header plus the raw payload that followed it.
pub struct Message {
    /// Parsed header of this message.
    pub header: Header,
    /// Payload bytes. `Message::unpack` stores them without trailing alignment padding.
    pub data: Vec<u8>,
}
234
235impl Message {
236 pub fn unpack(data: &[u8], header: Header) -> Result<(usize, Message)> {
238 let size = header.data_length();
239 let aligned_size = netlink_align(size);
240 if data.len() < aligned_size {
241 return Err(NetlinkError::new(NetlinkErrorKind::NotEnoughData).into());
242 }
243 Ok((
244 aligned_size,
245 Message {
246 header: header,
247 data: (&data[..size]).to_vec(),
248 },
249 ))
250 }
251
252 pub fn pack<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8]> {
254 let slice = self.header.pack(buffer)?;
255 let slice = self.data.pack(slice)?;
256 let padding = self.header.padding();
257 Ok(&mut slice[padding..])
258 }
259}
260
/// An ordered collection of decoded messages.
pub type Messages = Vec<Message>;
262
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn unpack_header() {
        // One byte short of a full header: must be rejected.
        let data = [
            0x12, 0x00, 0x00, 0x00, // length
            0x00, 0x10, // identifier
            0x10, 0x00, // flags
            0x01, 0x00, 0x00, 0x00, // sequence
            0x04, 0x00, 0x00, // pid (truncated)
        ];
        assert!(Header::unpack(&data).is_err());

        let data = [
            0x12, 0x00, 0x00, 0x00, // length
            0x00, 0x10, // identifier
            0x10, 0x00, // flags
            0x01, 0x00, 0x00, 0x00, // sequence
            0x04, 0x00, 0x00, 0x00, // pid
        ];
        let (used, header) = Header::unpack_with_size(&data).unwrap();
        assert_eq!(used, Header::HEADER_SIZE);
        assert_eq!(header.length, 18u32);
        assert_eq!(header.length(), 18usize);
        assert_eq!(header.data_length(), 2usize);
        assert_eq!(header.identifier, 0x1000u16);
        assert_eq!(header.flags, 0x0010u16);
        assert_eq!(header.sequence, 0x00000001u32);
        assert_eq!(header.pid, 0x00000004u32);
    }

    #[test]
    fn pack_header() {
        let header = Header {
            length: 18,
            identifier: 0x1000,
            flags: 0x0010,
            sequence: 1,
            pid: 4,
        };
        let mut buffer = [0u8; 32];
        {
            let slice = header.pack(&mut buffer).unwrap();
            assert_eq!(slice.len(), 16usize);
        }
        let data = [
            0x12, 0x00, 0x00, 0x00, // length
            0x00, 0x10, // identifier
            0x10, 0x00, // flags
            0x01, 0x00, 0x00, 0x00, // sequence
            0x04, 0x00, 0x00, 0x00, // pid
        ];
        assert_eq!(&buffer[..data.len()], data);
    }

    #[test]
    fn unpack_data_message() {
        let data = [
            0x12, 0x00, 0x00, 0x00, // length
            0x00, 0x10, // identifier
            0x10, 0x00, // flags
            0x01, 0x00, 0x00, 0x00, // sequence
            0x04, 0x00, 0x00, 0x00, // pid
            0xaa, 0x55, // payload
            0x00, 0x00, // padding
        ];
        let (used, header) = Header::unpack_with_size(&data).unwrap();
        assert_eq!(used, Header::HEADER_SIZE);
        assert_eq!(header.length, 18u32);
        assert_eq!(header.length(), 18usize);
        assert_eq!(header.data_length(), 2usize);
        assert_eq!(header.aligned_data_length(), 4usize);
        assert_eq!(header.identifier, 0x1000u16);
        assert_eq!(header.flags, 0x0010u16);
        assert_eq!(header.sequence, 0x00000001u32);
        assert_eq!(header.pid, 0x00000004u32);
        let (used, msg) = Message::unpack(&data[used..], header).unwrap();
        assert_eq!(used, 4usize);
        assert_eq!(msg.data.len(), 2usize);
        assert_eq!(msg.data[0], 0xaau8);
        assert_eq!(msg.data[1], 0x55u8);
    }

    #[test]
    fn pack_data_message() {
        let message = Message {
            header: Header {
                length: 18,
                identifier: 0x1000,
                flags: 0x0010,
                sequence: 0x12345678,
                pid: 1,
            },
            data: vec![0xaa, 0x55],
        };
        let mut buffer = [0xffu8; 32];
        {
            let slice = message.pack(&mut buffer).unwrap();
            assert_eq!(slice.len(), 12usize);
        }
        let data = [
            0x12, 0x00, 0x00, 0x00, // length
            0x00, 0x10, // identifier
            0x10, 0x00, // flags
            0x78, 0x56, 0x34, 0x12, // sequence
            0x01, 0x00, 0x00, 0x00, // pid
            0xaa, 0x55, // payload
            0xff, 0xff, // padding is skipped, not overwritten
        ];
        assert_eq!(&buffer[..data.len()], data);
    }

    #[test]
    fn unpack_error_message() {
        let data = [
            0x24, 0x00, 0x00, 0x00, // length
            0x00, 0x10, // identifier
            0x10, 0x00, // flags
            0x01, 0x00, 0x00, 0x00, // sequence
            0x04, 0x00, 0x00, 0x00, // pid
            0xff, 0xff, 0xff, 0xff, // error code (-1)
            0x12, 0x00, 0x00, 0x00, // original length
            0x00, 0x11, // original identifier
            0x11, 0x00, // original flags
            0xff, 0xff, 0xff, 0xff, // original sequence
            0x05, 0x00, 0x00, 0x00, // original pid
        ];
        let (used, header) = Header::unpack_with_size(&data).unwrap();
        assert_eq!(used, Header::HEADER_SIZE);
        assert_eq!(header.length, 36u32);
        assert_eq!(header.length(), 36usize);
        assert_eq!(header.data_length(), 20usize);
        assert_eq!(header.aligned_data_length(), 20usize);
        assert_eq!(header.identifier, 0x1000u16);
        assert_eq!(header.flags, 0x0010u16);
        assert_eq!(header.sequence, 0x00000001u32);
        assert_eq!(header.pid, 0x00000004u32);
        let (used, msg) = ErrorMessage::unpack(&data[used..], header).unwrap();
        assert_eq!(used, 20usize);
        assert_eq!(msg.code, -1);
        assert_eq!(msg.original_header.length, 18u32);
        assert_eq!(msg.original_header.identifier, 0x1100u16);
        assert_eq!(msg.original_header.flags, 0x0011u16);
        assert_eq!(msg.original_header.sequence, u32::MAX);
        assert_eq!(msg.original_header.pid, 5u32);
    }

    #[test]
    fn message_mode_flag_conversions() {
        assert_eq!(MessageFlags::from(MessageMode::None).bits(), 0x0001);
        assert_eq!(MessageFlags::from(MessageMode::Acknowledge).bits(), 0x0005);
        assert_eq!(MessageFlags::from(MessageMode::Dump).bits(), 0x0301);

        assert!(MessageMode::from(MessageFlags::REQUEST) == MessageMode::None);
        assert!(MessageMode::from(MessageFlags::ACKNOWLEDGE) == MessageMode::Acknowledge);
        // Dump wins over acknowledge when both are present.
        assert!(
            MessageMode::from(MessageFlags::DUMP | MessageFlags::ACKNOWLEDGE)
                == MessageMode::Dump
        );
        // Either dump bit on its own selects dump mode.
        assert!(MessageMode::from(MessageFlags::from_bits_truncate(0x0100)) == MessageMode::Dump);
    }

    #[test]
    fn netlink_alignment_helpers() {
        assert_eq!(netlink_align(0), 0);
        assert_eq!(netlink_align(1), 4);
        assert_eq!(netlink_align(4), 4);
        assert_eq!(netlink_align(18), 20);
        assert_eq!(netlink_padding(16), 0);
        assert_eq!(netlink_padding(18), 2);
        assert_eq!(netlink_padding(19), 1);
    }

    #[test]
    fn header_helpers() {
        let header = Header {
            length: 16,
            identifier: 0,
            flags: 0x0010 | 0x0004,
            sequence: 0,
            pid: 4,
        };
        // Unknown bit 0x0010 is dropped; only ACKNOWLEDGE survives.
        assert!(header.flags() == MessageFlags::ACKNOWLEDGE);
        assert_eq!(header.padding(), 0);
        assert_eq!(header.aligned_length(), 16);
        assert_eq!(header.data_length(), 0);
        assert!(header.check_pid(4));
        assert!(!header.check_pid(5));

        let from_pid_zero = Header {
            length: 16,
            identifier: 0,
            flags: 0,
            sequence: 0,
            pid: 0,
        };
        assert!(from_pid_zero.check_pid(7));
    }
}