netlink_packet_core/
message.rs1use std::fmt::Debug;
4
5use crate::{
6 done::DONE_HEADER_LEN,
7 payload::{NLMSG_DONE, NLMSG_ERROR, NLMSG_NOOP, NLMSG_OVERRUN},
8 DecodeError, DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorContext,
9 ErrorMessage, NetlinkBuffer, NetlinkDeserializable, NetlinkHeader,
10 NetlinkPayload, NetlinkSerializable, Parseable,
11};
12
13#[derive(Debug, PartialEq, Eq, Clone)]
15#[non_exhaustive]
16pub struct NetlinkMessage<I> {
17 pub header: NetlinkHeader,
19 pub payload: NetlinkPayload<I>,
21}
22
23impl<I> NetlinkMessage<I> {
24 pub fn new(header: NetlinkHeader, payload: NetlinkPayload<I>) -> Self {
26 NetlinkMessage { header, payload }
27 }
28
29 pub fn into_parts(self) -> (NetlinkHeader, NetlinkPayload<I>) {
31 (self.header, self.payload)
32 }
33}
34
35impl<I> NetlinkMessage<I>
36where
37 I: NetlinkDeserializable,
38{
39 pub fn deserialize(buffer: &[u8]) -> Result<Self, DecodeError> {
41 let netlink_buffer = NetlinkBuffer::new_checked(&buffer)
42 .context("failed deserializing NetlinkMessage")?;
43 <Self as Parseable<NetlinkBuffer<&&[u8]>>>::parse(&netlink_buffer)
44 }
45}
46
47impl<I> NetlinkMessage<I>
48where
49 I: NetlinkSerializable,
50{
51 pub fn buffer_len(&self) -> usize {
53 <Self as Emitable>::buffer_len(self)
54 }
55
56 pub fn serialize(&self, buffer: &mut [u8]) {
65 self.emit(buffer)
66 }
67
68 pub fn finalize(&mut self) {
80 self.header.length = self.buffer_len() as u32;
81 self.header.message_type = self.payload.message_type();
82 }
83}
84
85impl<B, I> Parseable<NetlinkBuffer<&B>> for NetlinkMessage<I>
86where
87 B: AsRef<[u8]>,
88 I: NetlinkDeserializable,
89{
90 fn parse(buf: &NetlinkBuffer<&B>) -> Result<Self, DecodeError> {
91 use self::NetlinkPayload::*;
92
93 let header =
94 <NetlinkHeader as Parseable<NetlinkBuffer<&B>>>::parse(buf)
95 .context("failed parsing NetlinkHeader")?;
96
97 let bytes = buf.payload();
98 let payload = match header.message_type {
99 NLMSG_ERROR => {
100 let msg = ErrorBuffer::new_checked(&bytes)
101 .and_then(|buf| ErrorMessage::parse(&buf))
102 .context("failed parsing NLMSG_ERROR")?;
103 Error(msg)
104 }
105 NLMSG_NOOP => Noop,
106 NLMSG_DONE => {
107 let msg = if bytes.is_empty() {
109 DoneBuffer::new_checked(&[0u8; DONE_HEADER_LEN])
110 .and_then(|buf| DoneMessage::parse(&buf))
111 .context("failed to parse NLMSG_DONE")?
112 } else {
113 DoneBuffer::new_checked(&bytes)
114 .and_then(|buf| DoneMessage::parse(&buf))
115 .context("failed to parse NLMSG_DONE")?
116 };
117 Done(msg)
118 }
119 NLMSG_OVERRUN => Overrun(bytes.to_vec()),
120 message_type => match I::deserialize(&header, bytes) {
121 Err(e) => {
122 return Err(format!(
123 "Failed to parse message with type {message_type}: {e}"
124 )
125 .into())
126 }
127 Ok(inner_msg) => InnerMessage(inner_msg),
128 },
129 };
130 Ok(NetlinkMessage { header, payload })
131 }
132}
133
134impl<I> Emitable for NetlinkMessage<I>
135where
136 I: NetlinkSerializable,
137{
138 fn buffer_len(&self) -> usize {
139 use self::NetlinkPayload::*;
140
141 let payload_len = match self.payload {
142 Noop => 0,
143 Done(ref msg) => msg.buffer_len(),
144 Overrun(ref bytes) => bytes.len(),
145 Error(ref msg) => msg.buffer_len(),
146 InnerMessage(ref msg) => msg.buffer_len(),
147 };
148
149 self.header.buffer_len() + payload_len
150 }
151
152 fn emit(&self, buffer: &mut [u8]) {
153 use self::NetlinkPayload::*;
154
155 self.header.emit(buffer);
156
157 let buffer =
158 &mut buffer[self.header.buffer_len()..self.header.length as usize];
159 match self.payload {
160 Noop => {}
161 Done(ref msg) => msg.emit(buffer),
162 Overrun(ref bytes) => buffer.copy_from_slice(bytes),
163 Error(ref msg) => msg.emit(buffer),
164 InnerMessage(ref msg) => msg.serialize(buffer),
165 }
166 }
167}
168
169impl<T> From<T> for NetlinkMessage<T>
170where
171 T: Into<NetlinkPayload<T>>,
172{
173 fn from(inner_message: T) -> Self {
174 NetlinkMessage {
175 header: NetlinkHeader::default(),
176 payload: inner_message.into(),
177 }
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    use std::{convert::Infallible, mem::size_of, num::NonZeroI32};

    /// Placeholder inner message type. These tests only exercise the built-in
    /// control payloads, so none of its methods should ever be called.
    #[derive(Clone, Debug, Default, PartialEq)]
    struct FakeNetlinkInnerMessage;

    impl NetlinkSerializable for FakeNetlinkInnerMessage {
        fn message_type(&self) -> u16 {
            unimplemented!("unused by tests")
        }

        fn buffer_len(&self) -> usize {
            unimplemented!("unused by tests")
        }

        fn serialize(&self, _buffer: &mut [u8]) {
            unimplemented!("unused by tests")
        }
    }

    impl NetlinkDeserializable for FakeNetlinkInnerMessage {
        type Error = Infallible;

        fn deserialize(
            _header: &NetlinkHeader,
            _payload: &[u8],
        ) -> Result<Self, Self::Error> {
            unimplemented!("unused by tests")
        }
    }

    #[test]
    fn test_done() {
        let header = NetlinkHeader::default();
        let done_msg = DoneMessage {
            code: 0,
            extended_ack: vec![6, 7, 8, 9],
        };
        let mut want = NetlinkMessage::new(
            header,
            NetlinkPayload::<FakeNetlinkInnerMessage>::Done(done_msg.clone()),
        );
        want.finalize();

        let len = want.buffer_len();
        assert_eq!(
            len,
            header.buffer_len()
                + size_of::<i32>()
                + done_msg.extended_ack.len()
        );

        let mut buf = vec![1; len];
        want.emit(&mut buf);

        let done_buf = DoneBuffer::new(&buf[header.buffer_len()..]);
        assert_eq!(done_buf.code(), done_msg.code);
        assert_eq!(done_buf.extended_ack(), &done_msg.extended_ack);

        let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
        assert_eq!(got, want);
    }

    #[test]
    fn test_done_empty_payload() {
        // A bare NLMSG_DONE (header only, no payload bytes) must be accepted
        // and decoded as code 0 with no extended ack.
        let mut header = NetlinkHeader::default();
        header.length = header.buffer_len() as u32;
        header.message_type = NLMSG_DONE;

        let mut buf = vec![0; header.buffer_len()];
        header.emit(&mut buf);

        let got: NetlinkMessage<FakeNetlinkInnerMessage> =
            NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
        let want = NetlinkMessage::new(
            header,
            NetlinkPayload::Done(DoneMessage {
                code: 0,
                extended_ack: vec![],
            }),
        );
        assert_eq!(got, want);
    }

    #[test]
    fn test_noop() {
        let mut want = NetlinkMessage::new(
            NetlinkHeader::default(),
            NetlinkPayload::<FakeNetlinkInnerMessage>::Noop,
        );
        want.finalize();

        // A NOOP has no payload: the serialized message is just the header.
        let len = want.buffer_len();
        assert_eq!(len, want.header.buffer_len());
        assert_eq!(want.header.length as usize, len);
        assert_eq!(want.header.message_type, NLMSG_NOOP);

        let mut buf = vec![1; len];
        want.emit(&mut buf);

        let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
        assert_eq!(got, want);
    }

    #[test]
    fn test_error() {
        const ERROR_CODE: NonZeroI32 = NonZeroI32::new(-8765).unwrap();

        let header = NetlinkHeader::default();
        let error_msg = ErrorMessage {
            code: Some(ERROR_CODE),
            header: vec![],
        };
        let mut want = NetlinkMessage::new(
            header,
            NetlinkPayload::<FakeNetlinkInnerMessage>::Error(error_msg.clone()),
        );
        want.finalize();

        let len = want.buffer_len();
        assert_eq!(len, header.buffer_len() + error_msg.buffer_len());

        let mut buf = vec![1; len];
        want.emit(&mut buf);

        let error_buf = ErrorBuffer::new(&buf[header.buffer_len()..]);
        assert_eq!(error_buf.code(), error_msg.code);

        let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
        assert_eq!(got, want);
    }
}