netlink_rust/core/message.rs

1use crate::errors::{NetlinkError, NetlinkErrorKind, Result};
2use bitflags::bitflags;
3use std::fmt;
4use std::mem::size_of;
5
6use crate::core::pack::{NativePack, NativeUnpack};
7
8bitflags! {
9    /// Message flags
10    #[derive(Clone, Copy, PartialEq, PartialOrd)]
11    pub struct MessageFlags: u16 {
12        /// Request message
13        const REQUEST     = 0x0001;
14        /// Multo-part message
15        const MULTIPART   = 0x0002;
16        /// Acknowledge message
17        const ACKNOWLEDGE = 0x0004;
18        /// Dump message
19        const DUMP        = 0x0100 | 0x0200;
20    }
21}
22
/// Message mode
///
/// Flags which describe how the messages will be handled. Converts to and
/// from [`MessageFlags`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageMode {
    /// No special flags
    None,
    /// Acknowledge message
    Acknowledge,
    /// Dump message
    Dump,
}
35
36impl From<MessageFlags> for MessageMode {
37    fn from(value: MessageFlags) -> MessageMode {
38        if value.intersects(MessageFlags::DUMP) {
39            MessageMode::Dump
40        } else if value.intersects(MessageFlags::ACKNOWLEDGE) {
41            MessageMode::Acknowledge
42        } else {
43            MessageMode::None
44        }
45    }
46}
47
48impl From<MessageMode> for MessageFlags {
49    fn from(value: MessageMode) -> MessageFlags {
50        let flags = MessageFlags::REQUEST;
51        match value {
52            MessageMode::None => flags,
53            MessageMode::Acknowledge => flags | MessageFlags::ACKNOWLEDGE,
54            MessageMode::Dump => flags | MessageFlags::DUMP,
55        }
56    }
57}
58
/// Rounds `len` up to the next multiple of `align_to`.
///
/// `align_to` must be a non-zero power of two: the result is computed with a
/// bit mask, which silently gives wrong answers for any other alignment.
///
/// # Panics
///
/// In debug builds, panics if `align_to` is not a power of two or if
/// `len + align_to - 1` overflows `usize`.
#[inline]
pub(crate) fn align_to(len: usize, align_to: usize) -> usize {
    debug_assert!(
        align_to.is_power_of_two(),
        "alignment must be a non-zero power of two"
    );
    (len + align_to - 1) & !(align_to - 1)
}
63
64#[inline]
65pub(crate) fn netlink_align(len: usize) -> usize {
66    align_to(len, 4usize)
67}
68
69#[inline]
70pub(crate) fn netlink_padding(len: usize) -> usize {
71    netlink_align(len) - len
72}
73
/// Netlink message header
///
/// ```text
/// | length | identifier | flags | sequence | pid |
/// |--------|------------|-------|----------|-----|
/// |   u32  |     u16    |  u16  |   u32    | u32 |
/// ```
///
/// Length is the total length of the message in bytes, including the header.
/// Message data comes after the header. The data is 4 byte aligned, which
/// means that the number of bytes the message occupies might be larger than
/// indicated by the length field.
///
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Header {
    /// Message length, including the header
    pub length: u32,
    /// Message identifier
    pub identifier: u16,
    /// Message flags, see `MessageFlags`
    pub flags: u16,
    /// Message sequence
    pub sequence: u32,
    /// Message process identifier
    pub pid: u32,
}
100
101impl Header {
102    const HEADER_SIZE: usize = 16;
103
104    /// Returns the length including the header
105    pub fn length(&self) -> usize {
106        self.length as usize
107    }
108
109    /// Returns the length of the data section
110    pub fn data_length(&self) -> usize {
111        self.length() - size_of::<Header>()
112    }
113
114    /// Returns padding length in octets
115    pub fn padding(&self) -> usize {
116        netlink_padding(self.length())
117    }
118
119    /// Returns length including header and padding
120    pub fn aligned_length(&self) -> usize {
121        netlink_align(self.length())
122    }
123
124    /// Returns length of the data section header and padding
125    pub fn aligned_data_length(&self) -> usize {
126        netlink_align(self.data_length())
127    }
128
129    /// Check if the message pid equals provided pid or broadcast (0)
130    pub fn check_pid(&self, pid: u32) -> bool {
131        self.pid == 0 || self.pid == pid
132    }
133
134    /// Check if the message sequence number equals the  provided sequence
135    /// number or broadcast (0)
136    pub fn flags(&self) -> MessageFlags {
137        MessageFlags::from_bits_truncate(self.flags)
138    }
139}
140
141impl fmt::Display for Header {
142    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
143        write!(
144            f,
145            "Length: {0:08x} {0}\nIdentifier: {1:04x}\nFlags: {2:04x}\n\
146             Sequence: {3:08x} {3}\nPID: {4:08x} {4}",
147            self.length, self.identifier, self.flags, self.sequence, self.pid,
148        )
149    }
150}
151
152impl NativePack for Header {
153    fn pack_size(&self) -> usize {
154        Self::HEADER_SIZE
155    }
156    fn pack_unchecked(&self, buffer: &mut [u8]) {
157        self.length.pack_unchecked(buffer);
158        self.identifier.pack_unchecked(&mut buffer[4..]);
159        self.flags.pack_unchecked(&mut buffer[6..]);
160        self.sequence.pack_unchecked(&mut buffer[8..]);
161        self.pid.pack_unchecked(&mut buffer[12..]);
162    }
163}
164
165impl NativeUnpack for Header {
166    fn unpack_unchecked(buffer: &[u8]) -> Self {
167        let length = u32::unpack_unchecked(&buffer[..]);
168        let identifier = u16::unpack_unchecked(&buffer[4..]);
169        let flags = u16::unpack_unchecked(&buffer[6..]);
170        let sequence = u32::unpack_unchecked(&buffer[8..]);
171        let pid = u32::unpack_unchecked(&buffer[12..]);
172        Header {
173            length: length,
174            identifier: identifier,
175            flags: flags,
176            sequence: sequence,
177            pid: pid,
178        }
179    }
180}
181
182/// Netlink error message
183///
184/// ```text
185/// | header |  error code  | Original Header |
186/// |--------|--------------|-----------------|
187/// | Header |      i32     |     Header      |
188/// ```
189///
190/// Header is the message header, See [Header](struct.Header.html).
191/// The error code is an errno number reported by the kernel.
192/// The original header is the header of the message that caused this error.
193pub(crate) struct ErrorMessage {
194    pub header: Header,
195    pub code: i32,
196    pub original_header: Header,
197}
198
199impl ErrorMessage {
200    pub fn unpack(data: &[u8], header: Header) -> Result<(usize, ErrorMessage)> {
201        let size = 4 + Header::HEADER_SIZE;
202        if data.len() < size {
203            return Err(NetlinkError::new(NetlinkErrorKind::NotEnoughData).into());
204        }
205        let code = i32::unpack_unchecked(data);
206        let (_, original) = Header::unpack_with_size(&data[4..])?;
207        Ok((
208            size,
209            ErrorMessage {
210                header: header,
211                code: code,
212                original_header: original,
213            },
214        ))
215    }
216}
217
218/// Netlink data message
219///
220/// ```text
221/// | header |    data     | padding |
222/// |--------|-------------|---------|
223/// | Header | u8 * length |         |
224/// ```
225///
226/// Header is the message header, See [Header](struct.Header.html).
227/// The data is 4 byte aligned.
228pub struct Message {
229    /// Message header
230    pub header: Header,
231    /// Message data
232    pub data: Vec<u8>,
233}
234
235impl Message {
236    /// Unpack Message from byte slice and message header
237    pub fn unpack(data: &[u8], header: Header) -> Result<(usize, Message)> {
238        let size = header.data_length();
239        let aligned_size = netlink_align(size);
240        if data.len() < aligned_size {
241            return Err(NetlinkError::new(NetlinkErrorKind::NotEnoughData).into());
242        }
243        Ok((
244            aligned_size,
245            Message {
246                header: header,
247                data: (&data[..size]).to_vec(),
248            },
249        ))
250    }
251
252    /// Pack data into byte slice
253    pub fn pack<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8]> {
254        let slice = self.header.pack(buffer)?;
255        let slice = self.data.pack(slice)?;
256        let padding = self.header.padding();
257        Ok(&mut slice[padding..])
258    }
259}
260
261pub type Messages = Vec<Message>;
262
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn unpack_header() {
        // One byte short of a complete header.
        let data = [
            0x12, 0x00, 0x00, 0x00, // size
            0x00, 0x10, // identifier
            0x10, 0x00, // flags
            0x01, 0x00, 0x00, 0x00, // sequence
            0x04, 0x00, 0x00, // pid (truncated)
        ];
        assert!(Header::unpack(&data).is_err());
        let data = [
            0x12, 0x00, 0x00, 0x00, // size
            0x00, 0x10, // identifier
            0x10, 0x00, // flags
            0x01, 0x00, 0x00, 0x00, // sequence
            0x04, 0x00, 0x00, 0x00, // pid
        ];
        let (used, header) = Header::unpack_with_size(&data).unwrap();
        assert_eq!(used, Header::HEADER_SIZE);
        assert_eq!(header.length, 18u32);
        assert_eq!(header.length(), 18usize);
        assert_eq!(header.data_length(), 2usize);
        assert_eq!(header.identifier, 0x1000u16);
        assert_eq!(header.flags, 0x0010u16);
        assert_eq!(header.sequence, 0x00000001u32);
        assert_eq!(header.pid, 0x00000004u32);
    }

    #[test]
    fn pack_header() {
        let header = Header {
            length: 18,
            identifier: 0x1000,
            flags: 0x0010,
            sequence: 1,
            pid: 4,
        };
        let mut buffer = [0u8; 32];
        {
            let slice = header.pack(&mut buffer).unwrap();
            assert_eq!(slice.len(), 16usize);
        }
        let data = [
            0x12, 0x00, 0x00, 0x00, // size
            0x00, 0x10, // identifier
            0x10, 0x00, // flags
            0x01, 0x00, 0x00, 0x00, // sequence
            0x04, 0x00, 0x00, 0x00, // pid
        ];
        assert_eq!(&buffer[..data.len()], data);
    }

    #[test]
    fn header_flags_drop_unknown_bits() {
        let header = Header {
            length: 16,
            identifier: 0,
            flags: 0x0315,
            sequence: 0,
            pid: 0,
        };
        // 0x0010 is not a known flag and is truncated away.
        assert_eq!(header.flags().bits(), 0x0305);
    }

    #[test]
    fn header_check_pid() {
        let header = Header {
            length: 16,
            identifier: 0,
            flags: 0,
            sequence: 0,
            pid: 7,
        };
        assert!(header.check_pid(7));
        assert!(!header.check_pid(8));
        let zero_pid = Header {
            length: 16,
            identifier: 0,
            flags: 0,
            sequence: 0,
            pid: 0,
        };
        assert!(zero_pid.check_pid(8));
    }

    #[test]
    fn message_mode_to_flags() {
        assert_eq!(MessageFlags::from(MessageMode::None).bits(), 0x0001);
        assert_eq!(MessageFlags::from(MessageMode::Acknowledge).bits(), 0x0005);
        assert_eq!(MessageFlags::from(MessageMode::Dump).bits(), 0x0301);
    }

    #[test]
    fn flags_to_message_mode() {
        assert!(MessageMode::from(MessageFlags::REQUEST) == MessageMode::None);
        assert!(MessageMode::from(MessageFlags::ACKNOWLEDGE) == MessageMode::Acknowledge);
        assert!(
            MessageMode::from(MessageFlags::DUMP | MessageFlags::ACKNOWLEDGE) == MessageMode::Dump
        );
        // A single dump bit is enough to select dump mode.
        assert!(MessageMode::from(MessageFlags::from_bits_truncate(0x0100)) == MessageMode::Dump);
    }

    #[test]
    fn unpack_data_message() {
        let data = [
            0x12, 0x00, 0x00, 0x00, // size
            0x00, 0x10, // identifier
            0x10, 0x00, // flags
            0x01, 0x00, 0x00, 0x00, // sequence
            0x04, 0x00, 0x00, 0x00, // pid
            0xaa, 0x55, 0x00, 0x00, // data with padding
        ];
        let (used, header) = Header::unpack_with_size(&data).unwrap();
        assert_eq!(used, Header::HEADER_SIZE);
        assert_eq!(header.length, 18u32);
        assert_eq!(header.length(), 18usize);
        assert_eq!(header.data_length(), 2usize);
        assert_eq!(header.aligned_data_length(), 4usize);
        assert_eq!(header.identifier, 0x1000u16);
        assert_eq!(header.flags, 0x0010u16);
        assert_eq!(header.sequence, 0x00000001u32);
        assert_eq!(header.pid, 0x00000004u32);
        let (used, msg) = Message::unpack(&data[used..], header).unwrap();
        assert_eq!(used, 4usize);
        assert_eq!(msg.data.len(), 2usize);
        assert_eq!(msg.data[0], 0xaau8);
        assert_eq!(msg.data[1], 0x55u8);
    }

    #[test]
    fn pack_data_message() {
        let message = Message {
            header: Header {
                length: 18,
                identifier: 0x1000,
                flags: 0x0010,
                sequence: 0x12345678,
                pid: 1,
            },
            data: vec![0xaa, 0x55],
        };
        let mut buffer = [0xffu8; 32];
        {
            let slice = message.pack(&mut buffer).unwrap();
            assert_eq!(slice.len(), 12usize);
        }
        let data = [
            0x12, 0x00, 0x00, 0x00, // size
            0x00, 0x10, // identifier
            0x10, 0x00, // flags
            0x78, 0x56, 0x34, 0x12, // sequence
            0x01, 0x00, 0x00, 0x00, // pid
            0xaa, 0x55, 0xff, 0xff, // data; padding bytes are left untouched
        ];
        assert_eq!(&buffer[..data.len()], data);
    }

    #[test]
    fn unpack_error_message() {
        let data = [
            0x24, 0x00, 0x00, 0x00, // size
            0x00, 0x10, // identifier
            0x10, 0x00, // flags
            0x01, 0x00, 0x00, 0x00, // sequence
            0x04, 0x00, 0x00, 0x00, // pid
            0xff, 0xff, 0xff, 0xff, // error code
            0x12, 0x00, 0x00, 0x00, // size
            0x00, 0x11, // identifier
            0x11, 0x00, // flags
            0xff, 0xff, 0xff, 0xff, // sequence
            0x05, 0x00, 0x00, 0x00, // pid
        ];
        let (used, header) = Header::unpack_with_size(&data).unwrap();
        assert_eq!(used, Header::HEADER_SIZE);
        assert_eq!(header.length, 36u32);
        assert_eq!(header.length(), 36usize);
        assert_eq!(header.data_length(), 20usize);
        assert_eq!(header.aligned_data_length(), 20usize);
        assert_eq!(header.identifier, 0x1000u16);
        assert_eq!(header.flags, 0x0010u16);
        assert_eq!(header.sequence, 0x00000001u32);
        assert_eq!(header.pid, 0x00000004u32);
        let (used, msg) = ErrorMessage::unpack(&data[used..], header).unwrap();
        assert_eq!(used, 20usize);
        assert_eq!(msg.code, -1);
        assert_eq!(msg.original_header.length, 18u32);
        assert_eq!(msg.original_header.identifier, 0x1100u16);
        assert_eq!(msg.original_header.flags, 0x0011u16);
        assert_eq!(msg.original_header.sequence, u32::MAX);
        assert_eq!(msg.original_header.pid, 5u32);
    }
}