1use std::{fmt, io, mem::size_of, num::NonZeroI32};
4
5use crate::{emit_i32, parse_i32, Emitable, Field, Parseable, Rest};
6
// Byte layout of an error message: the first four bytes hold the error code,
// and everything after them is the payload.
const CODE: Field = 0..4;
const PAYLOAD: Rest = 4..;
// Minimum length of a valid error buffer: everything that precedes the payload.
const ERROR_HEADER_LEN: usize = PAYLOAD.start;
10
/// Attaches a human-readable context message to an error-carrying value.
///
/// Implementations return a value of the same type, so context can be
/// layered while an error is propagated.
pub trait ErrorContext<M>
where
    M: std::fmt::Display,
{
    /// Returns `self` with `msg` recorded as additional context.
    fn context(self, msg: M) -> Self;
}
14
/// Error returned when a byte buffer cannot be decoded.
///
/// It carries only a textual message. Context added through
/// [`ErrorContext`] is prepended to that message.
#[derive(Debug)]
pub struct DecodeError {
    // Full, already-formatted error text.
    msg: String,
}
19
20impl<T: std::fmt::Display> ErrorContext<T> for DecodeError {
21 fn context(self, msg: T) -> Self {
22 Self {
23 msg: format!("{} caused by {}", msg, self.msg),
24 }
25 }
26}
27
28impl<T, M> ErrorContext<M> for Result<T, DecodeError>
29where
30 M: std::fmt::Display,
31{
32 fn context(self, msg: M) -> Result<T, DecodeError> {
33 match self {
34 Ok(t) => Ok(t),
35 Err(e) => Err(e.context(msg)),
36 }
37 }
38}
39
40impl From<&str> for DecodeError {
41 fn from(msg: &str) -> Self {
42 Self {
43 msg: msg.to_string(),
44 }
45 }
46}
47
48impl From<String> for DecodeError {
49 fn from(msg: String) -> Self {
50 Self { msg }
51 }
52}
53
54impl From<std::string::FromUtf8Error> for DecodeError {
55 fn from(err: std::string::FromUtf8Error) -> Self {
56 Self {
57 msg: format!("Invalid UTF-8 sequence: {}", err),
58 }
59 }
60}
61
62impl std::fmt::Display for DecodeError {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.msg)
65 }
66}
67
68impl std::error::Error for DecodeError {}
69
70impl DecodeError {
71 pub fn invalid_buffer(
72 name: &str,
73 received: usize,
74 minimum_length: usize,
75 ) -> Self {
76 Self {
77 msg: format!(
78 "Invalid buffer {name}. Expected at least {minimum_length} \
79 bytes, received {received} bytes"
80 ),
81 }
82 }
83 pub fn invalid_mac_address(received: usize) -> Self {
84 Self {
85 msg: format!(
86 "Invalid MAC address. Expected 6 bytes, received {received} \
87 bytes"
88 ),
89 }
90 }
91
92 pub fn invalid_ip_address(received: usize) -> Self {
93 Self {
94 msg: format!(
95 "Invalid IP address. Expected 4 or 16 bytes, received \
96 {received} bytes"
97 ),
98 }
99 }
100
101 pub fn invalid_number(expected: usize, received: usize) -> Self {
102 Self {
103 msg: format!(
104 "Invalid number. Expected {expected} bytes, received \
105 {received} bytes"
106 ),
107 }
108 }
109
110 pub fn nla_buffer_too_small(buffer_len: usize, nla_len: usize) -> Self {
111 Self {
112 msg: format!(
113 "buffer has length {buffer_len}, but an NLA header is \
114 {nla_len} bytes"
115 ),
116 }
117 }
118
119 pub fn nla_length_mismatch(buffer_len: usize, nla_len: usize) -> Self {
120 Self {
121 msg: format!(
122 "buffer has length: {buffer_len}, but the NLA is {nla_len} \
123 bytes"
124 ),
125 }
126 }
127
128 pub fn nla_invalid_length(buffer_len: usize, nla_len: usize) -> Self {
129 Self {
130 msg: format!(
131 "NLA has invalid length: {nla_len} (should be at least \
132 {buffer_len} bytes)"
133 ),
134 }
135 }
136
137 pub fn buffer_too_small(buffer_len: usize, value_len: usize) -> Self {
138 Self {
139 msg: format!(
140 "Buffer too small: {buffer_len} (should be at least \
141 {value_len} bytes"
142 ),
143 }
144 }
145}
146
/// Typed view over a byte buffer holding an error message: a 4-byte error
/// code followed by a payload.
///
/// `T` is the backing storage; the accessor `impl` blocks require it to be
/// `AsRef<[u8]>` (and `AsMut<[u8]>` for setters).
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct ErrorBuffer<T> {
    // Underlying bytes; not validated unless constructed via `new_checked`.
    buffer: T,
}
152
153impl<T: AsRef<[u8]>> ErrorBuffer<T> {
154 pub fn new(buffer: T) -> ErrorBuffer<T> {
155 ErrorBuffer { buffer }
156 }
157
158 pub fn into_inner(self) -> T {
160 self.buffer
161 }
162
163 pub fn new_checked(buffer: T) -> Result<Self, DecodeError> {
164 let packet = Self::new(buffer);
165 packet
166 .check_buffer_length()
167 .context("invalid ErrorBuffer length")?;
168 Ok(packet)
169 }
170
171 fn check_buffer_length(&self) -> Result<(), DecodeError> {
172 let len = self.buffer.as_ref().len();
173 if len < ERROR_HEADER_LEN {
174 Err(DecodeError {
175 msg: format!(
176 "invalid ErrorBuffer: length is {len} but ErrorBuffer are \
177 at least {ERROR_HEADER_LEN} bytes"
178 ),
179 })
180 } else {
181 Ok(())
182 }
183 }
184
185 pub fn code(&self) -> Option<NonZeroI32> {
191 let data = self.buffer.as_ref();
192 NonZeroI32::new(parse_i32(&data[CODE]).unwrap())
193 }
194}
195
196impl<'a, T: AsRef<[u8]> + ?Sized> ErrorBuffer<&'a T> {
197 pub fn payload(&self) -> &'a [u8] {
199 let data = self.buffer.as_ref();
200 &data[PAYLOAD]
201 }
202}
203
204impl<T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> ErrorBuffer<&mut T> {
205 pub fn payload_mut(&mut self) -> &mut [u8] {
207 let data = self.buffer.as_mut();
208 &mut data[PAYLOAD]
209 }
210}
211
212impl<T: AsRef<[u8]> + AsMut<[u8]>> ErrorBuffer<T> {
213 pub fn set_code(&mut self, value: i32) {
215 let data = self.buffer.as_mut();
216 emit_i32(&mut data[CODE], value).unwrap();
217 }
218}
219
/// Decoded error message: an optional error code plus the raw bytes that
/// followed it in the buffer.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub struct ErrorMessage {
    /// Error code; `None` represents a stored value of zero.
    pub code: Option<NonZeroI32>,
    /// Bytes that follow the error code. They are not interpreted here —
    /// presumably the header of the message that triggered the error.
    pub header: Vec<u8>,
}
242
243impl Emitable for ErrorMessage {
244 fn buffer_len(&self) -> usize {
245 size_of::<i32>() + self.header.len()
246 }
247 fn emit(&self, buffer: &mut [u8]) {
248 let mut buffer = ErrorBuffer::new(buffer);
249 buffer.set_code(self.raw_code());
250 buffer.payload_mut().copy_from_slice(&self.header)
251 }
252}
253
254impl<T: AsRef<[u8]>> Parseable<ErrorBuffer<&T>> for ErrorMessage {
255 fn parse(buf: &ErrorBuffer<&T>) -> Result<ErrorMessage, DecodeError> {
256 Ok(ErrorMessage {
265 code: buf.code(),
266 header: buf.payload().to_vec(),
267 })
268 }
269}
270
271impl ErrorMessage {
272 pub fn raw_code(&self) -> i32 {
274 self.code.map_or(0, NonZeroI32::get)
275 }
276
277 pub fn to_io(&self) -> io::Error {
283 io::Error::from_raw_os_error(self.raw_code().abs())
284 }
285}
286
287impl fmt::Display for ErrorMessage {
288 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
289 fmt::Display::fmt(&self.to_io(), f)
290 }
291}
292
293impl From<ErrorMessage> for io::Error {
294 fn from(e: ErrorMessage) -> io::Error {
295 e.to_io()
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    // A negative code (-95) converts to OS error 95, and `Display` on the
    // message matches `Display` on the equivalent `io::Error`.
    #[test]
    fn into_io_error() {
        let io_err = io::Error::from_raw_os_error(95);
        let err_msg = ErrorMessage {
            code: NonZeroI32::new(-95),
            header: vec![],
        };

        let to_io: io::Error = err_msg.to_io();

        assert_eq!(err_msg.to_string(), io_err.to_string());
        assert_eq!(to_io.raw_os_error(), io_err.raw_os_error());
    }

    // A 4-byte all-zero buffer parses to `code: None` with an empty header,
    // and `raw_code` reports 0 for it.
    #[test]
    fn parse_ack() {
        let bytes = vec![0, 0, 0, 0];
        let msg = ErrorBuffer::new_checked(&bytes)
            .and_then(|buf| ErrorMessage::parse(&buf))
            .expect("failed to parse NLMSG_ERROR");
        assert_eq!(
            ErrorMessage {
                code: None,
                header: Vec::new()
            },
            msg
        );
        assert_eq!(msg.raw_code(), 0);
    }

    // A non-zero code written with `emit_i32` is read back as `Some(code)`,
    // and `raw_code` returns the same value.
    #[test]
    fn parse_nack() {
        const ERROR_CODE: NonZeroI32 = NonZeroI32::new(-1234).unwrap();
        let mut bytes = vec![0, 0, 0, 0];
        emit_i32(&mut bytes, ERROR_CODE.get()).unwrap();
        let msg = ErrorBuffer::new_checked(&bytes)
            .and_then(|buf| ErrorMessage::parse(&buf))
            .expect("failed to parse NLMSG_ERROR");
        assert_eq!(
            ErrorMessage {
                code: Some(ERROR_CODE),
                header: Vec::new()
            },
            msg
        );
        assert_eq!(msg.raw_code(), ERROR_CODE.get());
    }
}
351}