1use super::{nlmsg_length, nlmsg_header_length};
2use std::mem::{size_of};
3use std::slice::{from_raw_parts};
4use std::io::{self, ErrorKind, Cursor};
5
6use byteorder::{NativeEndian, ReadBytesExt};
7
/// The `nlmsg_type` field of a netlink message header.
///
/// Values recognized here map to named variants; every other value is carried through
/// unchanged as `UserDefined`, so converting `u16 -> MsgType -> u16` is lossless.
///
/// Note: the raw value `0x10` decodes as `MinType`. Protocol code that treats 0x10 as a
/// concrete message type should match `MinType` or compare the raw `u16`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MsgType {
    /// Type 0.
    Request,
    /// `NLMSG_NOOP` (1): a message to be ignored.
    Noop,
    /// `NLMSG_ERROR` (2): an error/acknowledgement message.
    Error,
    /// `NLMSG_DONE` (3): terminates a multipart message.
    Done,
    /// `NLMSG_OVERRUN` (4): data was lost.
    Overrun,
    /// `NLMSG_MIN_TYPE` (0x10): values below this are reserved for control messages.
    MinType,
    /// Any other type value, passed through verbatim.
    UserDefined(u16),
}

impl From<MsgType> for u16 {
    fn from(t: MsgType) -> u16 {
        match t {
            MsgType::Request => 0,
            MsgType::Noop => 1,
            MsgType::Error => 2,
            MsgType::Done => 3,
            MsgType::Overrun => 4,
            // NLMSG_MIN_TYPE is 0x10 (decimal 16) in <linux/netlink.h>, not decimal 10.
            MsgType::MinType => 0x10,
            MsgType::UserDefined(i) => i,
        }
    }
}

impl From<u16> for MsgType {
    fn from(t: u16) -> MsgType {
        match t {
            0 => MsgType::Request,
            1 => MsgType::Noop,
            2 => MsgType::Error,
            3 => MsgType::Done,
            4 => MsgType::Overrun,
            // Must stay in sync with the MsgType -> u16 mapping above.
            0x10 => MsgType::MinType,
            i => MsgType::UserDefined(i),
        }
    }
}
55
/// Standard `nlmsg_flags` bits (`NLM_F_*` in `<linux/netlink.h>`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Flags {
    /// `NLM_F_REQUEST` (0x1): the message is a request.
    Request,
    /// `NLM_F_MULTI` (0x2): the message is part of a multipart message.
    Multi,
    /// `NLM_F_ACK` (0x4): request an acknowledgement.
    Ack,
    /// `NLM_F_ECHO` (0x8): echo this request.
    Echo,
}

// Implement `From` rather than `Into`: the standard blanket impl still provides
// `Flags: Into<u16>`, and `u16::from(flag)` becomes available as well.
impl From<Flags> for u16 {
    fn from(f: Flags) -> u16 {
        match f {
            Flags::Request => 0x1,
            Flags::Multi => 0x2,
            Flags::Ack => 0x4,
            Flags::Echo => 0x8,
        }
    }
}
79
/// Modifier bits for GET-style requests (`NLM_F_ROOT`, `NLM_F_MATCH`, `NLM_F_ATOMIC`,
/// `NLM_F_DUMP`).
///
/// These reuse the 0x100/0x200/0x400 bit positions that `NewFlags` also uses, so the
/// meaning of those bits depends on the kind of request they are set on.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum GetFlags {
    /// `NLM_F_ROOT` (0x100).
    Root,
    /// `NLM_F_MATCH` (0x200).
    Match,
    /// `NLM_F_ATOMIC` (0x400).
    Atomic,
    /// `NLM_F_DUMP`: `NLM_F_ROOT | NLM_F_MATCH` (0x300).
    Dump,
}

impl From<GetFlags> for u16 {
    fn from(f: GetFlags) -> u16 {
        match f {
            GetFlags::Root => 0x100,
            GetFlags::Match => 0x200,
            GetFlags::Atomic => 0x400,
            // Derived from the component flags so the two definitions cannot drift apart.
            GetFlags::Dump => u16::from(GetFlags::Root) | u16::from(GetFlags::Match),
        }
    }
}
104
/// Modifier bits for NEW-style (create/modify) requests (`NLM_F_REPLACE`,
/// `NLM_F_EXCL`, `NLM_F_CREATE`, `NLM_F_APPEND`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum NewFlags {
    /// `NLM_F_REPLACE` (0x100): replace an existing object.
    Replace,
    /// `NLM_F_EXCL` (0x200): do not touch an object that already exists.
    Excl,
    /// `NLM_F_CREATE` (0x400): create the object if it does not exist.
    Create,
    /// `NLM_F_APPEND` (0x800): add to the end of the object list.
    Append,
}

impl From<NewFlags> for u16 {
    fn from(f: NewFlags) -> u16 {
        match f {
            NewFlags::Replace => 0x100,
            NewFlags::Excl => 0x200,
            NewFlags::Create => 0x400,
            NewFlags::Append => 0x800,
        }
    }
}
129
/// Netlink message header, laid out like the C `struct nlmsghdr`.
///
/// `#[repr(C)]` pins the field order and layout: a `u32`, two `u16`s and two `u32`s,
/// which is 16 bytes with no padding. Do not reorder or retype the fields.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[repr(C)]
pub struct NlMsgHeader {
    /// Total message length in bytes, header included (`nlmsg_len`).
    msg_length: u32,
    /// Message type (`nlmsg_type`).
    nl_type: u16,
    /// Flag bits (`nlmsg_flags`).
    flags: u16,
    /// Sequence number (`nlmsg_seq`).
    seq: u32,
    /// Sender port id (`nlmsg_pid`).
    pid: u32,
}
145
146impl NlMsgHeader {
147 pub fn user_defined(t: u16) -> NlMsgHeader {
148 NlMsgHeader {
149 msg_length: nlmsg_header_length() as u32,
150 nl_type: t,
151 flags: Flags::Request.into(),
152 seq: 0,
153 pid: 0,
154 }
155 }
156
157 pub fn request() -> NlMsgHeader {
158 NlMsgHeader {
159 msg_length: nlmsg_header_length() as u32,
160 nl_type: MsgType::Request.into(),
161 flags: Flags::Request.into(),
162 seq: 0,
163 pid: 0,
164 }
165 }
166
167 pub fn done() -> NlMsgHeader {
168 NlMsgHeader {
169 msg_length: nlmsg_header_length() as u32,
170 nl_type: MsgType::Done.into(),
171 flags: Flags::Multi.into(),
172 seq: 0,
173 pid: 0,
174 }
175 }
176
177 pub fn error() -> NlMsgHeader {
178 NlMsgHeader {
179 msg_length: nlmsg_length(nlmsg_header_length() + 4) as u32, nl_type: MsgType::Error.into(),
181 flags: 0,
182 seq: 0,
183 pid: 0,
184 }
185 }
186
187 pub fn from_bytes(bytes: &[u8]) -> io::Result<(NlMsgHeader, usize)> {
188 let buf_len = bytes.len() as u32;
189 let mut cursor = Cursor::new(bytes);
190 let len = try!(cursor.read_u32::<NativeEndian>());
191 let nl_type = try!(cursor.read_u16::<NativeEndian>());
192 let flags = try!(cursor.read_u16::<NativeEndian>());
193 let seq = try!(cursor.read_u32::<NativeEndian>());
194 let pid = try!(cursor.read_u32::<NativeEndian>());
195
196 if len < nlmsg_header_length() as u32 {
197 Err(io::Error::new(ErrorKind::InvalidInput, "length smaller than msg header size"))
198 } else {
199 Ok((NlMsgHeader{
200 msg_length: len,
201 nl_type: nl_type,
202 flags: flags,
203 seq: seq,
204 pid: pid,
205 }, cursor.position() as usize))
206 }
207 }
208
209 pub fn bytes(&self) -> &[u8] {
210 let size = size_of::<NlMsgHeader>();
211 unsafe {
212 let head = self as *const NlMsgHeader as *const u8;
213 from_raw_parts(head, size)
214 }
215 }
216
217 pub fn msg_type(&self) -> MsgType {
218 self.nl_type.into()
219 }
220
221 pub fn msg_length(&self) -> u32 {
222 self.msg_length
223 }
224
225 pub fn data_length(&mut self, len: u32) -> &mut NlMsgHeader {
227 self.msg_length = nlmsg_length(len as usize) as u32;
228 self
229 }
230
231 pub fn multipart(&mut self) -> &mut NlMsgHeader {
233 self.flags |= Flags::Multi.into();
234 self
235 }
236
237 pub fn ack(&mut self) -> &mut NlMsgHeader {
239 self.flags |= Flags::Ack.into();
240 self
241 }
242
243 pub fn echo(&mut self) -> &mut NlMsgHeader {
245 self.flags |= Flags::Echo.into();
246 self
247 }
248
249 pub fn seq(&mut self, n: u32) -> &mut NlMsgHeader {
251 self.seq = n;
252 self
253 }
254
255 pub fn pid(&mut self, n: u32) -> &mut NlMsgHeader {
257 self.pid = n;
258 self
259 }
260
261 pub fn replace(&mut self) -> &mut NlMsgHeader {
263 self.flags |= NewFlags::Replace.into();
264 self
265 }
266
267 pub fn excl(&mut self) -> &mut NlMsgHeader {
269 self.flags |= NewFlags::Excl.into();
270 self
271 }
272
273 pub fn create(&mut self) -> &mut NlMsgHeader {
275 self.flags |= NewFlags::Create.into();
276 self
277 }
278
279 pub fn append(&mut self) -> &mut NlMsgHeader {
281 self.flags |= NewFlags::Append.into();
282 self
283 }
284
285 pub fn root(&mut self) -> &mut NlMsgHeader {
287 self.flags |= GetFlags::Root.into();
288 self
289 }
290
291 pub fn match_provided(&mut self) -> &mut NlMsgHeader {
293 self.flags |= GetFlags::Match.into();
294 self
295 }
296
297 pub fn atomic(&mut self) -> &mut NlMsgHeader {
299 self.flags |= GetFlags::Atomic.into();
300 self
301 }
302
303 pub fn dump(&mut self) -> &mut NlMsgHeader {
305 self.flags |= GetFlags::Dump.into();
306 self
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use byteorder::{NativeEndian, WriteBytesExt};

    /// Serializes header fields in native byte order, matching what
    /// `NlMsgHeader::bytes` emits and `NlMsgHeader::from_bytes` reads.
    ///
    /// Using this instead of literal byte arrays keeps the tests valid on both
    /// little- and big-endian targets.
    fn encode(len: u32, nl_type: u16, flags: u16, seq: u32, pid: u32) -> Vec<u8> {
        let mut buf = Vec::new();
        // Writing into a Vec<u8> cannot fail.
        buf.write_u32::<NativeEndian>(len).unwrap();
        buf.write_u16::<NativeEndian>(nl_type).unwrap();
        buf.write_u16::<NativeEndian>(flags).unwrap();
        buf.write_u32::<NativeEndian>(seq).unwrap();
        buf.write_u32::<NativeEndian>(pid).unwrap();
        buf
    }

    #[test]
    fn test_encoding() {
        // Flags are REQUEST (0x1) | DUMP (0x300) = 0x301. On a little-endian host this is
        // [20,0,0,0, 0,0, 1,3, 1,0,0,0, 9,0,0,0].
        let expected = encode(20, 0, 0x301, 1, 9);
        let mut hdr = NlMsgHeader::request();
        let bytes = hdr.data_length(4).pid(9).seq(1).dump().bytes();

        assert_eq!(bytes, &expected[..]);
    }

    #[test]
    fn test_decoding() {
        // A 16-byte header followed by trailing payload bytes that must not be consumed.
        let mut bytes = encode(16, 0, 0x301, 1, 9);
        bytes.extend_from_slice(&[1, 1, 1]);
        let mut h = NlMsgHeader::request();
        let expected = h.data_length(0).pid(9).seq(1).dump();

        let (hdr, n) = NlMsgHeader::from_bytes(&bytes).unwrap();
        assert_eq!(hdr, *expected);
        assert_eq!(n, 16);
    }

    #[test]
    fn test_decoding_error() {
        // A zero length field is smaller than the header itself.
        let bytes = [0u8; 19];
        let res = NlMsgHeader::from_bytes(&bytes);
        assert!(res.is_err());
    }

    #[test]
    fn test_decoding_truncated() {
        // Too few bytes to hold a complete header.
        let res = NlMsgHeader::from_bytes(&[16, 0, 0, 0]);
        assert!(res.is_err());
    }
}
381}