1mod socket_impl;
2
3mod address;
4pub use self::address::*;
5
6mod msg;
7pub use self::msg::*;
8
9use socket::socket_impl::Socket as SocketImpl;
10
11use std::mem::{size_of};
12
13use libc::{AF_NETLINK, SOCK_RAW};
14
15use std::convert::Into;
16use std::io::{self, Write, Cursor};
17use std::iter::{repeat};
18
19use byteorder::{NativeEndian, WriteBytesExt, ReadBytesExt};
20
/// Alignment, in bytes, that netlink message boundaries are rounded up to
/// (applied by `nlmsg_align`).
const NLMSG_ALIGNTO: usize = 4;
23
24
/// Converts a 16-bit value from host byte order to network (big-endian) byte order.
#[inline]
fn htons(hostshort: u16) -> u16 {
    u16::to_be(hostshort)
}
30
31
/// Converts a 32-bit value from host byte order to network (big-endian) byte order.
#[inline]
fn htonl(hostlong: u32) -> u32 {
    u32::to_be(hostlong)
}
43
/// Converts a 32-bit value from network (big-endian) byte order to host byte order.
#[inline]
fn ntohl(netlong: u32) -> u32 {
    <u32>::from_be(netlong)
}
49
50#[derive(Clone, Eq, PartialEq, Debug)]
51pub enum Payload<'a> {
52 None,
53 Data(&'a [u8]),
54 Ack(NlMsgHeader),
55 Err(NlMsgHeader),
56}
57
58impl<'a> Payload<'a> {
59 fn data(bytes: &'a [u8], len: usize) -> io::Result<(Payload<'a>, usize)> {
60 use std::io::{ErrorKind, Error};
61
62 let l = bytes.len();
63 if l < len {
64 Err(Error::new(ErrorKind::InvalidData, "length of bytes too small"))
65 } else {
66 Ok((Payload::Data(&bytes[..len]), len))
67 }
68 }
69
70 fn nlmsg_error(bytes: &'a [u8]) -> io::Result<(Payload<'a>, usize)> {
71 let mut cursor = Cursor::new(bytes);
72 let err = try!(cursor.read_u32::<NativeEndian>());
73 let n = cursor.position() as usize;
74 let (hdr, n2) = try!(NlMsgHeader::from_bytes(&bytes[n..]));
75 let num = n + n2;
76 if err == 0 {
77 Ok((Payload::Ack(hdr), num))
78 } else {
79 Ok((Payload::Err(hdr), num))
80 }
81 }
82
83 fn bytes(&self) -> io::Result<Vec<u8>> {
84 match *self {
85 Payload::None => {
86 Ok(vec!())
87 },
88 Payload::Data(b) => {
89 Ok(b.into())
90 },
91 Payload::Ack(h) => {
92 let mut vec = vec![];
93 try!(vec.write_u32::<NativeEndian>(0));
94 try!(vec.write(h.bytes()));
95 Ok(vec)
96 },
97 Payload::Err(h) => {
98 let mut vec = vec![];
99 try!(vec.write_u32::<NativeEndian>(1));
100 try!(vec.write(h.bytes()));
101 Ok(vec)
102 },
103 }
104 }
105}
106
107#[derive(Clone, Eq, PartialEq, Debug)]
108pub struct Msg<'a> {
109 header: NlMsgHeader,
110 payload: Payload<'a>,
111}
112
113impl<'a> Msg<'a> {
114 pub fn from_bytes(bytes: &'a [u8]) -> io::Result<(Msg<'a>, usize)> {
115 let (hdr, n) = try!(NlMsgHeader::from_bytes(bytes));
116 let (payload, n2) = match hdr.msg_type() {
117 MsgType::Done => {
118 (Payload::None, 0)
119 },
120 MsgType::Error => {
121 try!(Payload::nlmsg_error(&bytes[n..]))
122 },
123 _ => {
124 let msg_len = hdr.msg_length() as usize - nlmsg_header_length();
125 try!(Payload::data(&bytes[n..], msg_len))
126 },
127 };
128
129 Ok((Msg{
130 header: hdr,
131 payload: payload,
132 }, n + n2))
133 }
134
135 pub fn new(hdr: NlMsgHeader, payload: Payload<'a>) -> Msg<'a> {
136 Msg{
137 header: hdr,
138 payload: payload,
139 }
140 }
141
142 pub fn bytes(&self) -> io::Result<Vec<u8>> {
143 let mut bytes: Vec<u8> = self.header.bytes().into();
144 let mut payload = try!(self.payload.bytes());
145 bytes.append(&mut payload);
146 Ok(bytes)
147 }
148
149 pub fn header(&self) -> NlMsgHeader {
150 self.header
151 }
152
153 pub fn payload(&self) -> &Payload<'a> {
154 &self.payload
155 }
156}
157
158pub struct Socket {
168 inner: SocketImpl,
169 buf: Vec<u8>,
170}
171
172impl Socket {
173 pub fn new<P: Into<i32>>(protocol: P) -> io::Result<Socket> {
174 let s = try!(SocketImpl::new(AF_NETLINK, SOCK_RAW, protocol.into()));
175 let bytes = 4096;
176 let mut buf = Vec::with_capacity(bytes);
177 buf.extend(repeat(0u8).take(bytes));
178 Ok(Socket {
179 inner: s,
180 buf: buf,
181 })
182 }
183
184 pub fn bind(&self, addr: NetlinkAddr) -> io::Result<()> {
185 self.inner.bind(&addr.as_sockaddr())
186 }
187
188 pub fn close(&self) -> io::Result<()> {
189 self.inner.close()
190 }
191
192 pub fn send<'a>(&self, message: Msg<'a>, addr: &NetlinkAddr)
193 -> io::Result<usize> {
194 let b = try!(message.bytes());
195 self.inner.sendto(b.as_slice(), 0, &addr.as_sockaddr())
196 }
197
198 pub fn send_multi<'a>(&self, messages: Vec<Msg<'a>>, addr: &NetlinkAddr)
199 -> io::Result<usize> {
200 let mut bytes = vec![];
201 for m in messages {
202 let mut b = try!(m.bytes());
203 bytes.append(&mut b);
204 }
205
206 self.inner.sendto(bytes.as_slice(), 0, &addr.as_sockaddr())
207 }
208
209 pub fn recv(&mut self) -> io::Result<(NetlinkAddr, Vec<Msg>)> {
210 let buffer = &mut self.buf[..];
211 let (saddr, _) = try!(self.inner.recvfrom_into(buffer, 0));
212 let addr = try!(sockaddr_to_netlinkaddr(&saddr));
213 let mut messages = vec![];
214
215 let mut n = 0;
216 while let Ok((msg, num_bytes)) = Msg::from_bytes(&buffer[n..]) {
217 n += num_bytes;
218 let t = msg.header().msg_type();
219 match t {
220 MsgType::Done => {
221 break
222 },
223 _ => {
224 messages.push(msg);
225 },
226 }
227 }
228
229 Ok((addr, messages))
230 }
231}
232
233#[inline]
237fn nlmsg_align(len: usize) -> usize {
238 (len + (NLMSG_ALIGNTO - 1)) & !(NLMSG_ALIGNTO - 1)
239}
240
241#[inline]
243fn nlmsg_header_length() -> usize {
244 nlmsg_align(size_of::<NlMsgHeader>())
245}
246
247#[inline]
252fn nlmsg_length(len: usize) -> usize {
253 len + nlmsg_align(nlmsg_header_length())
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use byteorder::{NativeEndian, WriteBytesExt};
    use Protocol;
    use std::io::Write;

    // Round-trips one data message between two sockets bound to fixed port
    // ids (101 -> 102) and checks the sender address and payload.
    #[test]
    fn test_send_recv() {
        let send = Socket::new(Protocol::Usersock).unwrap();
        let mut recv = Socket::new(Protocol::Usersock).unwrap();
        let send_addr = NetlinkAddr::new(101, 0);
        let recv_addr = NetlinkAddr::new(102, 0);

        send.bind(send_addr).unwrap();
        recv.bind(recv_addr).unwrap();

        let bytes = [0,1,2,3,4,5];
        let mut shdr = NlMsgHeader::request();
        shdr.data_length(6).seq(1).pid(102);
        let msg = Msg::new(shdr, Payload::Data(&bytes));

        send.send(msg, &recv_addr).unwrap();

        let (ref addr, ref vec) = recv.recv().unwrap();
        assert_eq!(vec.len(), 1);

        let ref msg = vec.first().unwrap();
        assert_eq!(addr, &send_addr);
        if let &Payload::Data(b) = msg.payload() {
            assert_eq!(b, &bytes);
        } else {
            panic!("msg is not Data enum");
        }
    }

    // Sends two multipart data messages followed by a Done message in one
    // datagram; `recv` must return only the two data messages.
    #[test]
    fn test_send_multi_recv() {
        let send = Socket::new(Protocol::Usersock).unwrap();
        let mut recv = Socket::new(Protocol::Usersock).unwrap();
        let send_addr = NetlinkAddr::new(99, 0);
        let recv_addr = NetlinkAddr::new(100, 0);

        send.bind(send_addr).unwrap();
        recv.bind(recv_addr).unwrap();

        let bytes = [0,1,2,3,4,5];
        let mut shdr = NlMsgHeader::request();
        shdr.data_length(6).multipart().seq(1).pid(100);
        let msg = Msg::new(shdr, Payload::Data(&bytes));
        let msg2 = msg.clone();

        let mut donehdr = NlMsgHeader::done();
        donehdr.pid(100);
        let donemsg = Msg::new(donehdr, Payload::None);

        send.send_multi(vec![msg, msg2, donemsg], &recv_addr).unwrap();

        let (ref addr, ref vec) = recv.recv().unwrap();
        assert_eq!(vec.len(), 2);

        let ref msg = vec.first().unwrap();
        assert_eq!(addr, &send_addr);
        if let &Payload::Data(b) = msg.payload() {
            assert_eq!(b, &bytes);
        } else {
            panic!("msg is not Data enum");
        }
    }

    #[test]
    fn test_payload_decode() {
        let bytes = [0,1,2,3,4,5];
        let (payload, n) = Payload::data(&bytes, bytes.len()).unwrap();
        assert_eq!(n, bytes.len());

        if let Payload::Data(b) = payload {
            assert_eq!(b, &bytes);
        } else {
            panic!("payload is not Data enum");
        }
    }

    #[test]
    fn test_payload_decode_with_err() {
        let mut bytes = vec![];
        bytes.write_u32::<NativeEndian>(1).unwrap();

        // Hand-serialized header: length 20, type 0, flags 0x0301, seq 1,
        // pid 9. The multi-byte fields are spelled out in little-endian order,
        // so this test assumes a little-endian host.
        let expected = [20, 0, 0, 0, 0, 0, 1, 3, 1, 0, 0, 0, 9, 0, 0, 0];
        let mut hdr = NlMsgHeader::request();
        hdr.data_length(4).pid(9).seq(1).dump();

        bytes.write_all(&expected).unwrap();

        let (p, n) = Payload::nlmsg_error(&bytes).unwrap();

        assert_eq!(n, bytes.len());
        if let Payload::Err(h) = p {
            assert_eq!(h, hdr);
        } else {
            panic!("payload is not Err enum");
        }
    }

    #[test]
    fn test_payload_decode_with_ack() {
        let mut bytes = vec![];
        bytes.write_u32::<NativeEndian>(0).unwrap();

        let mut hdr = NlMsgHeader::request();
        hdr.data_length(4).pid(9).seq(1).dump();

        bytes.write_all(&hdr.bytes()).unwrap();

        let (p, n) = Payload::nlmsg_error(&bytes).unwrap();

        assert_eq!(n, bytes.len());
        if let Payload::Ack(h) = p {
            assert_eq!(h, hdr);
        } else {
            panic!("payload is not Ack enum");
        }
    }

    // A message followed by trailing garbage must consume only its own bytes.
    #[test]
    fn test_msg_decode() {
        let mut hdr = NlMsgHeader::request();
        hdr.data_length(4).pid(9).seq(1).dump();
        let hdr_bytes = hdr.bytes();

        let data = [0,1,2,3];

        let mut bytes = vec![];
        bytes.write_all(&hdr_bytes).unwrap();
        bytes.write_all(&data).unwrap();
        bytes.write_all(&[1,1,1,1,1,1,1]).unwrap();

        let (msg, n) = Msg::from_bytes(&bytes).unwrap();
        assert_eq!(n, hdr_bytes.len() + data.len());
        assert_eq!(hdr, msg.header());

        if let &Payload::Data(b) = msg.payload() {
            assert_eq!(b, &data);
        } else {
            panic!("msg is not Data enum");
        }
    }

    #[test]
    fn test_msg_decode_with_err() {
        let mut hdr = NlMsgHeader::error();
        hdr.pid(9).seq(1);
        let hdr_bytes = hdr.bytes();

        let mut bytes = vec![];
        bytes.write_all(&hdr_bytes).unwrap();

        bytes.write_u32::<NativeEndian>(1).unwrap();
        let mut err_hdr = NlMsgHeader::request();
        err_hdr.data_length(4).pid(9).seq(1).dump();
        bytes.write_all(&err_hdr.bytes()).unwrap();

        let (msg, n) = Msg::from_bytes(&bytes).unwrap();
        assert_eq!(n, bytes.len());
        assert_eq!(hdr, msg.header());

        if let &Payload::Err(h) = msg.payload() {
            assert_eq!(h, err_hdr);
        } else {
            panic!("msg is not Err enum");
        }
    }
}
432}