netlink_packet_core/
error.rs

1// SPDX-License-Identifier: MIT
2
3use std::{fmt, io, mem::size_of, num::NonZeroI32};
4
5use crate::{emit_i32, parse_i32, Emitable, Field, Parseable, Rest};
6
7const CODE: Field = 0..4;
8const PAYLOAD: Rest = 4..;
9const ERROR_HEADER_LEN: usize = PAYLOAD.start;
10
/// Attach a human-readable message to an error (or to the error side of a
/// `Result`), producing a value of the same type.
pub trait ErrorContext<T>
where
    T: std::fmt::Display,
{
    /// Wrap `self` with the additional message `msg`.
    fn context(self, msg: T) -> Self;
}
14
/// Error returned when a netlink buffer cannot be decoded.
///
/// It carries a single human-readable message; [`ErrorContext`] prepends
/// further context to it.
#[derive(Debug)]
pub struct DecodeError {
    /// The complete, already-formatted error description.
    msg: String,
}
19
20impl<T: std::fmt::Display> ErrorContext<T> for DecodeError {
21    fn context(self, msg: T) -> Self {
22        Self {
23            msg: format!("{} caused by {}", msg, self.msg),
24        }
25    }
26}
27
28impl<T, M> ErrorContext<M> for Result<T, DecodeError>
29where
30    M: std::fmt::Display,
31{
32    fn context(self, msg: M) -> Result<T, DecodeError> {
33        match self {
34            Ok(t) => Ok(t),
35            Err(e) => Err(e.context(msg)),
36        }
37    }
38}
39
40impl From<&str> for DecodeError {
41    fn from(msg: &str) -> Self {
42        Self {
43            msg: msg.to_string(),
44        }
45    }
46}
47
48impl From<String> for DecodeError {
49    fn from(msg: String) -> Self {
50        Self { msg }
51    }
52}
53
54impl From<std::string::FromUtf8Error> for DecodeError {
55    fn from(err: std::string::FromUtf8Error) -> Self {
56        Self {
57            msg: format!("Invalid UTF-8 sequence: {}", err),
58        }
59    }
60}
61
62impl std::fmt::Display for DecodeError {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        write!(f, "{}", self.msg)
65    }
66}
67
68impl std::error::Error for DecodeError {}
69
70impl DecodeError {
71    pub fn invalid_buffer(
72        name: &str,
73        received: usize,
74        minimum_length: usize,
75    ) -> Self {
76        Self {
77            msg: format!(
78                "Invalid buffer {name}. Expected at least {minimum_length} \
79                 bytes, received {received} bytes"
80            ),
81        }
82    }
83    pub fn invalid_mac_address(received: usize) -> Self {
84        Self {
85            msg: format!(
86                "Invalid MAC address. Expected 6 bytes, received {received} \
87                 bytes"
88            ),
89        }
90    }
91
92    pub fn invalid_ip_address(received: usize) -> Self {
93        Self {
94            msg: format!(
95                "Invalid IP address. Expected 4 or 16 bytes, received \
96                 {received} bytes"
97            ),
98        }
99    }
100
101    pub fn invalid_number(expected: usize, received: usize) -> Self {
102        Self {
103            msg: format!(
104                "Invalid number. Expected {expected} bytes, received \
105                 {received} bytes"
106            ),
107        }
108    }
109
110    pub fn nla_buffer_too_small(buffer_len: usize, nla_len: usize) -> Self {
111        Self {
112            msg: format!(
113                "buffer has length {buffer_len}, but an NLA header is \
114                 {nla_len} bytes"
115            ),
116        }
117    }
118
119    pub fn nla_length_mismatch(buffer_len: usize, nla_len: usize) -> Self {
120        Self {
121            msg: format!(
122                "buffer has length: {buffer_len}, but the NLA is {nla_len} \
123                 bytes"
124            ),
125        }
126    }
127
128    pub fn nla_invalid_length(buffer_len: usize, nla_len: usize) -> Self {
129        Self {
130            msg: format!(
131                "NLA has invalid length: {nla_len} (should be at least \
132                 {buffer_len} bytes)"
133            ),
134        }
135    }
136
137    pub fn buffer_too_small(buffer_len: usize, value_len: usize) -> Self {
138        Self {
139            msg: format!(
140                "Buffer too small: {buffer_len} (should be at least \
141                 {value_len} bytes"
142            ),
143        }
144    }
145}
146
/// Typed view over the payload of an `NLMSG_ERROR` message: an error code
/// followed by the rest of the payload.
///
/// `T` is the byte container being viewed (for example `&[u8]`,
/// `&mut [u8]` or `Vec<u8>`).
#[non_exhaustive]
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct ErrorBuffer<T> {
    /// The raw bytes; their length is only validated by `new_checked`.
    buffer: T,
}
152
153impl<T: AsRef<[u8]>> ErrorBuffer<T> {
154    pub fn new(buffer: T) -> ErrorBuffer<T> {
155        ErrorBuffer { buffer }
156    }
157
158    /// Consume the packet, returning the underlying buffer.
159    pub fn into_inner(self) -> T {
160        self.buffer
161    }
162
163    pub fn new_checked(buffer: T) -> Result<Self, DecodeError> {
164        let packet = Self::new(buffer);
165        packet
166            .check_buffer_length()
167            .context("invalid ErrorBuffer length")?;
168        Ok(packet)
169    }
170
171    fn check_buffer_length(&self) -> Result<(), DecodeError> {
172        let len = self.buffer.as_ref().len();
173        if len < ERROR_HEADER_LEN {
174            Err(DecodeError {
175                msg: format!(
176                    "invalid ErrorBuffer: length is {len} but ErrorBuffer are \
177                     at least {ERROR_HEADER_LEN} bytes"
178                ),
179            })
180        } else {
181            Ok(())
182        }
183    }
184
185    /// Return the error code.
186    ///
187    /// Returns `None` when there is no error to report (the message is an ACK),
188    /// or a `Some(e)` if there is a non-zero error code `e` to report (the
189    /// message is a NACK).
190    pub fn code(&self) -> Option<NonZeroI32> {
191        let data = self.buffer.as_ref();
192        NonZeroI32::new(parse_i32(&data[CODE]).unwrap())
193    }
194}
195
196impl<'a, T: AsRef<[u8]> + ?Sized> ErrorBuffer<&'a T> {
197    /// Return a pointer to the payload.
198    pub fn payload(&self) -> &'a [u8] {
199        let data = self.buffer.as_ref();
200        &data[PAYLOAD]
201    }
202}
203
204impl<T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> ErrorBuffer<&mut T> {
205    /// Return a mutable pointer to the payload.
206    pub fn payload_mut(&mut self) -> &mut [u8] {
207        let data = self.buffer.as_mut();
208        &mut data[PAYLOAD]
209    }
210}
211
212impl<T: AsRef<[u8]> + AsMut<[u8]>> ErrorBuffer<T> {
213    /// set the error code field
214    pub fn set_code(&mut self, value: i32) {
215        let data = self.buffer.as_mut();
216        emit_i32(&mut data[CODE], value).unwrap();
217    }
218}
219
/// An `NLMSG_ERROR` message.
///
/// Per [RFC 3549 section 2.3.2.2], this message carries the return code for
/// a request, indicating either success (an ACK) or failure (a NACK).
///
/// [RFC 3549 section 2.3.2.2]: https://datatracker.ietf.org/doc/html/rfc3549#section-2.3.2.2
#[non_exhaustive]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ErrorMessage {
    /// The error code.
    ///
    /// `None` means there is no error to report (the message is an ACK);
    /// `Some(e)` carries the non-zero error code `e` (the message is a
    /// NACK).
    ///
    /// See [Netlink message types] for details.
    ///
    /// [Netlink message types]: https://kernel.org/doc/html/next/userspace-api/netlink/intro.html#netlink-message-types
    pub code: Option<NonZeroI32>,
    /// The original request's header, kept as raw bytes.
    pub header: Vec<u8>,
}
242
243impl Emitable for ErrorMessage {
244    fn buffer_len(&self) -> usize {
245        size_of::<i32>() + self.header.len()
246    }
247    fn emit(&self, buffer: &mut [u8]) {
248        let mut buffer = ErrorBuffer::new(buffer);
249        buffer.set_code(self.raw_code());
250        buffer.payload_mut().copy_from_slice(&self.header)
251    }
252}
253
254impl<T: AsRef<[u8]>> Parseable<ErrorBuffer<&T>> for ErrorMessage {
255    fn parse(buf: &ErrorBuffer<&T>) -> Result<ErrorMessage, DecodeError> {
256        // FIXME: The payload of an error is basically a truncated packet, which
257        // requires custom logic to parse correctly. For now we just
258        // return it as a Vec<u8> let header: NetlinkHeader = {
259        //     NetlinkBuffer::new_checked(self.payload())
260        //         .context("failed to parse netlink header")?
261        //         .parse()
262        //         .context("failed to parse nelink header")?
263        // };
264        Ok(ErrorMessage {
265            code: buf.code(),
266            header: buf.payload().to_vec(),
267        })
268    }
269}
270
271impl ErrorMessage {
272    /// Returns the raw error code.
273    pub fn raw_code(&self) -> i32 {
274        self.code.map_or(0, NonZeroI32::get)
275    }
276
277    /// According to [`netlink(7)`](https://linux.die.net/man/7/netlink)
278    /// the `NLMSG_ERROR` return Negative errno or 0 for acknowledgements.
279    ///
280    /// convert into [`std::io::Error`](https://doc.rust-lang.org/std/io/struct.Error.html)
281    /// using the absolute value from errno code
282    pub fn to_io(&self) -> io::Error {
283        io::Error::from_raw_os_error(self.raw_code().abs())
284    }
285}
286
287impl fmt::Display for ErrorMessage {
288    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
289        fmt::Display::fmt(&self.to_io(), f)
290    }
291}
292
293impl From<ErrorMessage> for io::Error {
294    fn from(e: ErrorMessage) -> io::Error {
295        e.to_io()
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn into_io_error() {
        let io_err = io::Error::from_raw_os_error(95);
        let err_msg = ErrorMessage {
            code: NonZeroI32::new(-95),
            header: vec![],
        };

        let to_io: io::Error = err_msg.to_io();

        assert_eq!(err_msg.to_string(), io_err.to_string());
        assert_eq!(to_io.raw_os_error(), io_err.raw_os_error());
    }

    #[test]
    fn parse_ack() {
        let bytes = vec![0, 0, 0, 0];
        let msg = ErrorBuffer::new_checked(&bytes)
            .and_then(|buf| ErrorMessage::parse(&buf))
            .expect("failed to parse NLMSG_ERROR");
        assert_eq!(
            ErrorMessage {
                code: None,
                header: Vec::new()
            },
            msg
        );
        assert_eq!(msg.raw_code(), 0);
    }

    #[test]
    fn parse_nack() {
        // `unwrap` is evaluated at compile time in this const, so a zero
        // value would fail the build rather than panic at runtime.
        const ERROR_CODE: NonZeroI32 = NonZeroI32::new(-1234).unwrap();
        let mut bytes = vec![0, 0, 0, 0];
        emit_i32(&mut bytes, ERROR_CODE.get()).unwrap();
        let msg = ErrorBuffer::new_checked(&bytes)
            .and_then(|buf| ErrorMessage::parse(&buf))
            .expect("failed to parse NLMSG_ERROR");
        assert_eq!(
            ErrorMessage {
                code: Some(ERROR_CODE),
                header: Vec::new()
            },
            msg
        );
        assert_eq!(msg.raw_code(), ERROR_CODE.get());
    }

    #[test]
    fn new_checked_rejects_short_buffer() {
        let bytes = vec![0u8, 0, 0];
        let err = ErrorBuffer::new_checked(&bytes)
            .expect_err("a 3-byte buffer cannot hold the error code");
        assert!(err.to_string().starts_with("invalid ErrorBuffer length"));
    }

    #[test]
    fn emit_then_parse_roundtrip() {
        let original = ErrorMessage {
            code: NonZeroI32::new(-22),
            header: vec![1, 2, 3, 4],
        };
        let mut bytes = vec![0u8; original.buffer_len()];
        original.emit(&mut bytes);
        let parsed = ErrorBuffer::new_checked(&bytes)
            .and_then(|buf| ErrorMessage::parse(&buf))
            .expect("failed to parse emitted NLMSG_ERROR");
        assert_eq!(parsed, original);
    }
}