netlink_packet_core/
message.rs

1// SPDX-License-Identifier: MIT
2
3use std::fmt::Debug;
4
5use crate::{
6    done::DONE_HEADER_LEN,
7    payload::{NLMSG_DONE, NLMSG_ERROR, NLMSG_NOOP, NLMSG_OVERRUN},
8    DecodeError, DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorContext,
9    ErrorMessage, NetlinkBuffer, NetlinkDeserializable, NetlinkHeader,
10    NetlinkPayload, NetlinkSerializable, Parseable,
11};
12
13/// Represent a netlink message.
14#[derive(Debug, PartialEq, Eq, Clone)]
15#[non_exhaustive]
16pub struct NetlinkMessage<I> {
17    /// Message header (this is common to all the netlink protocols)
18    pub header: NetlinkHeader,
19    /// Inner message, which depends on the netlink protocol being used.
20    pub payload: NetlinkPayload<I>,
21}
22
23impl<I> NetlinkMessage<I> {
24    /// Create a new netlink message from the given header and payload
25    pub fn new(header: NetlinkHeader, payload: NetlinkPayload<I>) -> Self {
26        NetlinkMessage { header, payload }
27    }
28
29    /// Consume this message and return its header and payload
30    pub fn into_parts(self) -> (NetlinkHeader, NetlinkPayload<I>) {
31        (self.header, self.payload)
32    }
33}
34
35impl<I> NetlinkMessage<I>
36where
37    I: NetlinkDeserializable,
38{
39    /// Parse the given buffer as a netlink message
40    pub fn deserialize(buffer: &[u8]) -> Result<Self, DecodeError> {
41        let netlink_buffer = NetlinkBuffer::new_checked(&buffer)
42            .context("failed deserializing NetlinkMessage")?;
43        <Self as Parseable<NetlinkBuffer<&&[u8]>>>::parse(&netlink_buffer)
44    }
45}
46
47impl<I> NetlinkMessage<I>
48where
49    I: NetlinkSerializable,
50{
51    /// Return the length of this message in bytes
52    pub fn buffer_len(&self) -> usize {
53        <Self as Emitable>::buffer_len(self)
54    }
55
56    /// Serialize this message and write the serialized data into the
57    /// given buffer. `buffer` must big large enough for the whole
58    /// message to fit, otherwise, this method will panic. To know how
59    /// big the serialized message is, call `buffer_len()`.
60    ///
61    /// # Panic
62    ///
63    /// This method panics if the buffer is not big enough.
64    pub fn serialize(&self, buffer: &mut [u8]) {
65        self.emit(buffer)
66    }
67
68    /// Ensure the header (`NetlinkHeader`) is consistent with the payload
69    /// (`NetlinkPayload`):
70    ///
71    /// - compute the payload length and set the header's length field
72    /// - check the payload type and set the header's message type field
73    ///   accordingly
74    ///
75    /// If you are not 100% sure the header is correct, this method should be
76    /// called before calling [`Emitable::emit()`](trait.Emitable.html#
77    /// tymethod.emit), as it could panic if the header is inconsistent with
78    /// the rest of the message.
79    pub fn finalize(&mut self) {
80        self.header.length = self.buffer_len() as u32;
81        self.header.message_type = self.payload.message_type();
82    }
83}
84
85impl<B, I> Parseable<NetlinkBuffer<&B>> for NetlinkMessage<I>
86where
87    B: AsRef<[u8]>,
88    I: NetlinkDeserializable,
89{
90    fn parse(buf: &NetlinkBuffer<&B>) -> Result<Self, DecodeError> {
91        use self::NetlinkPayload::*;
92
93        let header =
94            <NetlinkHeader as Parseable<NetlinkBuffer<&B>>>::parse(buf)
95                .context("failed parsing NetlinkHeader")?;
96
97        let bytes = buf.payload();
98        let payload = match header.message_type {
99            NLMSG_ERROR => {
100                let msg = ErrorBuffer::new_checked(&bytes)
101                    .and_then(|buf| ErrorMessage::parse(&buf))
102                    .context("failed parsing NLMSG_ERROR")?;
103                Error(msg)
104            }
105            NLMSG_NOOP => Noop,
106            NLMSG_DONE => {
107                // Linux kernel allows zero sized of NLMSG_DONE
108                let msg = if bytes.is_empty() {
109                    DoneBuffer::new_checked(&[0u8; DONE_HEADER_LEN])
110                        .and_then(|buf| DoneMessage::parse(&buf))
111                        .context("failed to parse NLMSG_DONE")?
112                } else {
113                    DoneBuffer::new_checked(&bytes)
114                        .and_then(|buf| DoneMessage::parse(&buf))
115                        .context("failed to parse NLMSG_DONE")?
116                };
117                Done(msg)
118            }
119            NLMSG_OVERRUN => Overrun(bytes.to_vec()),
120            message_type => match I::deserialize(&header, bytes) {
121                Err(e) => {
122                    return Err(format!(
123                        "Failed to parse message with type {message_type}: {e}"
124                    )
125                    .into())
126                }
127                Ok(inner_msg) => InnerMessage(inner_msg),
128            },
129        };
130        Ok(NetlinkMessage { header, payload })
131    }
132}
133
134impl<I> Emitable for NetlinkMessage<I>
135where
136    I: NetlinkSerializable,
137{
138    fn buffer_len(&self) -> usize {
139        use self::NetlinkPayload::*;
140
141        let payload_len = match self.payload {
142            Noop => 0,
143            Done(ref msg) => msg.buffer_len(),
144            Overrun(ref bytes) => bytes.len(),
145            Error(ref msg) => msg.buffer_len(),
146            InnerMessage(ref msg) => msg.buffer_len(),
147        };
148
149        self.header.buffer_len() + payload_len
150    }
151
152    fn emit(&self, buffer: &mut [u8]) {
153        use self::NetlinkPayload::*;
154
155        self.header.emit(buffer);
156
157        let buffer =
158            &mut buffer[self.header.buffer_len()..self.header.length as usize];
159        match self.payload {
160            Noop => {}
161            Done(ref msg) => msg.emit(buffer),
162            Overrun(ref bytes) => buffer.copy_from_slice(bytes),
163            Error(ref msg) => msg.emit(buffer),
164            InnerMessage(ref msg) => msg.serialize(buffer),
165        }
166    }
167}
168
169impl<T> From<T> for NetlinkMessage<T>
170where
171    T: Into<NetlinkPayload<T>>,
172{
173    fn from(inner_message: T) -> Self {
174        NetlinkMessage {
175            header: NetlinkHeader::default(),
176            payload: inner_message.into(),
177        }
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    use std::{convert::Infallible, mem::size_of, num::NonZeroI32};

    /// Placeholder inner message. The tests below only exercise the built-in
    /// payload kinds, so none of its methods are ever called.
    #[derive(Clone, Debug, Default, PartialEq)]
    struct FakeNetlinkInnerMessage;

    impl NetlinkSerializable for FakeNetlinkInnerMessage {
        fn message_type(&self) -> u16 {
            unimplemented!("unused by tests")
        }

        fn buffer_len(&self) -> usize {
            unimplemented!("unused by tests")
        }

        fn serialize(&self, _buffer: &mut [u8]) {
            unimplemented!("unused by tests")
        }
    }

    impl NetlinkDeserializable for FakeNetlinkInnerMessage {
        type Error = Infallible;

        fn deserialize(
            _header: &NetlinkHeader,
            _payload: &[u8],
        ) -> Result<Self, Self::Error> {
            unimplemented!("unused by tests")
        }
    }

    #[test]
    fn test_done() {
        let header = NetlinkHeader::default();
        let done_msg = DoneMessage {
            code: 0,
            extended_ack: vec![6, 7, 8, 9],
        };
        let mut want = NetlinkMessage::new(
            header,
            NetlinkPayload::<FakeNetlinkInnerMessage>::Done(done_msg.clone()),
        );
        want.finalize();

        let len = want.buffer_len();
        assert_eq!(
            len,
            header.buffer_len()
                + size_of::<i32>()
                + done_msg.extended_ack.len()
        );

        let mut buf = vec![1; len];
        want.emit(&mut buf);

        let done_buf = DoneBuffer::new(&buf[header.buffer_len()..]);
        assert_eq!(done_buf.code(), done_msg.code);
        assert_eq!(done_buf.extended_ack(), &done_msg.extended_ack);

        let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
        assert_eq!(got, want);
    }

    /// The kernel may send NLMSG_DONE with no payload at all. It must decode
    /// as a DONE message with code 0 and no extended ack.
    #[test]
    fn test_done_empty_payload() {
        let mut header = NetlinkHeader::default();
        header.length = header.buffer_len() as u32;
        header.message_type = NLMSG_DONE;

        let mut buf = vec![0; header.buffer_len()];
        header.emit(&mut buf);

        let got = NetlinkMessage::<FakeNetlinkInnerMessage>::parse(
            &NetlinkBuffer::new(&buf),
        )
        .unwrap();
        assert_eq!(got.header, header);
        assert_eq!(
            got.payload,
            NetlinkPayload::Done(DoneMessage {
                code: 0,
                extended_ack: vec![],
            })
        );
    }

    #[test]
    fn test_noop() {
        let header = NetlinkHeader::default();
        let mut want = NetlinkMessage::new(
            header,
            NetlinkPayload::<FakeNetlinkInnerMessage>::Noop,
        );
        want.finalize();

        let len = want.buffer_len();
        assert_eq!(len, header.buffer_len());

        let mut buf = vec![1; len];
        want.emit(&mut buf);

        let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
        assert_eq!(got, want);
    }

    #[test]
    fn test_overrun() {
        let header = NetlinkHeader::default();
        let raw = vec![1, 2, 3, 4, 5];
        let mut want = NetlinkMessage::new(
            header,
            NetlinkPayload::<FakeNetlinkInnerMessage>::Overrun(raw.clone()),
        );
        want.finalize();
        assert_eq!(want.header.message_type, NLMSG_OVERRUN);

        let len = want.buffer_len();
        assert_eq!(len, header.buffer_len() + raw.len());

        let mut buf = vec![0; len];
        want.emit(&mut buf);
        assert_eq!(&buf[header.buffer_len()..], &raw[..]);

        let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
        assert_eq!(got, want);
    }

    #[test]
    fn test_error() {
        // `NonZeroI32::new` only returns `None` for 0, so this `unwrap` cannot
        // fail. Being a `const`, it is also checked at compile time.
        const ERROR_CODE: NonZeroI32 = NonZeroI32::new(-8765).unwrap();

        let header = NetlinkHeader::default();
        let error_msg = ErrorMessage {
            code: Some(ERROR_CODE),
            header: vec![],
        };
        let mut want = NetlinkMessage::new(
            header,
            NetlinkPayload::<FakeNetlinkInnerMessage>::Error(error_msg.clone()),
        );
        want.finalize();

        let len = want.buffer_len();
        assert_eq!(len, header.buffer_len() + error_msg.buffer_len());

        let mut buf = vec![1; len];
        want.emit(&mut buf);

        let error_buf = ErrorBuffer::new(&buf[header.buffer_len()..]);
        assert_eq!(error_buf.code(), error_msg.code);

        let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
        assert_eq!(got, want);
    }
}