netlink_rs/socket/mod.rs

1mod socket_impl;
2
3mod address;
4pub use self::address::*;
5
6mod msg;
7pub use self::msg::*;
8
9use socket::socket_impl::Socket as SocketImpl;
10
11use std::mem::{size_of};
12
13use libc::{AF_NETLINK, SOCK_RAW};
14
15use std::convert::Into;
16use std::io::{self, Write, Cursor};
17use std::iter::{repeat};
18
19use byteorder::{NativeEndian, WriteBytesExt, ReadBytesExt};
20
/// Alignment, in bytes, that netlink message lengths are rounded up to.
///
/// Mirrors the C definition `#define NLMSG_ALIGNTO 4`.
const NLMSG_ALIGNTO: usize = 4;
23
24
/// Host-to-network conversion for a 16-bit value.
///
/// Network byte order is big-endian: on little-endian targets the two bytes
/// are swapped, on big-endian targets the value is returned unchanged.
#[inline]
fn htons(hostshort: u16) -> u16 {
    u16::to_be(hostshort)
}
30
31
/// Converts a 32-bit value from host byte order to network (big-endian)
/// byte order.
///
/// On big-endian targets this is the identity; on little-endian targets the
/// bytes are swapped.
#[inline]
fn htonl(hostlong: u32) -> u32 {
    hostlong.to_be()
}
43
/// Network-to-host conversion for a 32-bit value.
///
/// `netlong` is interpreted as big-endian; the result is in the target's
/// native byte order (bytes swapped on little-endian targets only).
#[inline]
fn ntohl(netlong: u32) -> u32 {
    if cfg!(target_endian = "big") {
        netlong
    } else {
        netlong.swap_bytes()
    }
}
49
50#[derive(Clone, Eq, PartialEq, Debug)]
51pub enum Payload<'a> {
52    None,
53    Data(&'a [u8]),
54    Ack(NlMsgHeader),
55    Err(NlMsgHeader),
56}
57
58impl<'a> Payload<'a> {
59    fn data(bytes: &'a [u8], len: usize) -> io::Result<(Payload<'a>, usize)> {
60        use std::io::{ErrorKind, Error};
61
62        let l = bytes.len();
63        if l < len {
64            Err(Error::new(ErrorKind::InvalidData, "length of bytes too small"))
65        } else {
66            Ok((Payload::Data(&bytes[..len]), len))
67        }
68    }
69
70    fn nlmsg_error(bytes: &'a [u8]) -> io::Result<(Payload<'a>, usize)> {
71        let mut cursor = Cursor::new(bytes);
72        let err = try!(cursor.read_u32::<NativeEndian>());
73        let n = cursor.position() as usize;
74        let (hdr, n2) = try!(NlMsgHeader::from_bytes(&bytes[n..]));
75        let num = n + n2;
76        if err == 0 {
77            Ok((Payload::Ack(hdr), num))
78        } else {
79            Ok((Payload::Err(hdr), num))
80        }
81    }
82
83    fn bytes(&self) -> io::Result<Vec<u8>> {
84        match *self {
85            Payload::None => {
86                Ok(vec!())
87            },
88            Payload::Data(b) => {
89                Ok(b.into())
90            },
91            Payload::Ack(h) => {
92                let mut vec = vec![];
93                try!(vec.write_u32::<NativeEndian>(0));
94                try!(vec.write(h.bytes()));
95                Ok(vec)
96            },
97            Payload::Err(h) => {
98                let mut vec = vec![];
99                try!(vec.write_u32::<NativeEndian>(1));
100                try!(vec.write(h.bytes()));
101                Ok(vec)
102            },
103        }
104    }
105}
106
107#[derive(Clone, Eq, PartialEq, Debug)]
108pub struct Msg<'a> {
109    header: NlMsgHeader,
110    payload: Payload<'a>,
111}
112
113impl<'a> Msg<'a> {
114    pub fn from_bytes(bytes: &'a [u8]) -> io::Result<(Msg<'a>, usize)> {
115        let (hdr, n) = try!(NlMsgHeader::from_bytes(bytes));
116        let (payload, n2) = match hdr.msg_type() {
117            MsgType::Done => {
118                (Payload::None, 0)
119            },
120            MsgType::Error => {
121                try!(Payload::nlmsg_error(&bytes[n..]))
122            },
123            _ => {
124                let msg_len = hdr.msg_length() as usize - nlmsg_header_length();
125                try!(Payload::data(&bytes[n..], msg_len))
126            },
127        };
128
129        Ok((Msg{
130            header: hdr,
131            payload: payload,
132        }, n + n2))
133    }
134
135    pub fn new(hdr: NlMsgHeader, payload: Payload<'a>) -> Msg<'a> {
136        Msg{
137            header: hdr,
138            payload: payload,
139        }
140    }
141
142    pub fn bytes(&self) -> io::Result<Vec<u8>> {
143        let mut bytes: Vec<u8> = self.header.bytes().into();
144        let mut payload = try!(self.payload.bytes());
145        bytes.append(&mut payload);
146        Ok(bytes)
147    }
148
149    pub fn header(&self) -> NlMsgHeader {
150        self.header
151    }
152
153    pub fn payload(&self) -> &Payload<'a> {
154        &self.payload
155    }
156}
157
158// #[repr(C)]
159// #[derive(Clone, Copy, Eq, PartialEq, Debug)]
160// struct NlErr {
161//     /// 0 if used as acknowledgement
162//     err: u32,
163//     /// Msg header that caused the error
164//     hdr: NlMsgHeader,
165// }
166
167pub struct Socket {
168    inner: SocketImpl,
169    buf: Vec<u8>,
170}
171
172impl Socket {
173    pub fn new<P: Into<i32>>(protocol: P) -> io::Result<Socket> {
174        let s = try!(SocketImpl::new(AF_NETLINK, SOCK_RAW, protocol.into()));
175        let bytes = 4096;
176        let mut buf = Vec::with_capacity(bytes);
177        buf.extend(repeat(0u8).take(bytes));
178        Ok(Socket {
179            inner: s,
180            buf: buf,
181        })
182    }
183
184    pub fn bind(&self, addr: NetlinkAddr) -> io::Result<()> {
185        self.inner.bind(&addr.as_sockaddr())
186    }
187
188    pub fn close(&self) -> io::Result<()> {
189        self.inner.close()
190    }
191
192    pub fn send<'a>(&self, message: Msg<'a>, addr: &NetlinkAddr)
193        -> io::Result<usize> {
194            let b = try!(message.bytes());
195            self.inner.sendto(b.as_slice(), 0, &addr.as_sockaddr())
196        }
197
198    pub fn send_multi<'a>(&self, messages: Vec<Msg<'a>>, addr: &NetlinkAddr)
199        -> io::Result<usize> {
200            let mut bytes = vec![];
201            for m in messages {
202                let mut b = try!(m.bytes());
203                bytes.append(&mut b);
204            }
205
206            self.inner.sendto(bytes.as_slice(), 0, &addr.as_sockaddr())
207        }
208
209    pub fn recv(&mut self) -> io::Result<(NetlinkAddr, Vec<Msg>)> {
210        let buffer = &mut self.buf[..];
211        let (saddr, _) = try!(self.inner.recvfrom_into(buffer, 0));
212        let addr = try!(sockaddr_to_netlinkaddr(&saddr));
213        let mut messages = vec![];
214
215        let mut n = 0;
216        while let Ok((msg, num_bytes)) = Msg::from_bytes(&buffer[n..]) {
217            n += num_bytes;
218            let t = msg.header().msg_type();
219            match t {
220                MsgType::Done => {
221                    break
222                },
223                _ => {
224                    messages.push(msg);
225                },
226            }
227        }
228
229        Ok((addr, messages))
230    }
231}
232
233// NLMSG_ALIGN()
234//       Round the length of a netlink message up to align it properly.
235// #define NLMSG_ALIGN(len) ( ((len)+NLMSG_ALIGNTO-1) & ~(NLMSG_ALIGNTO-1) )
236#[inline]
237fn nlmsg_align(len: usize) -> usize {
238    (len + (NLMSG_ALIGNTO - 1)) & !(NLMSG_ALIGNTO - 1)
239}
240
241// #define NLMSG_HDRLEN     ((int) NLMSG_ALIGN(sizeof(struct nlmsghdr)))
242#[inline]
243fn nlmsg_header_length() -> usize {
244    nlmsg_align(size_of::<NlMsgHeader>())
245}
246
247// NLMSG_LENGTH()
248//        Given the payload length, len, this macro returns the aligned
249//        length to store in the nlmsg_len field of the nlmsghdr.
250// #define NLMSG_LENGTH(len) ((len)+NLMSG_ALIGN(NLMSG_HDRLEN))
251#[inline]
252fn nlmsg_length(len: usize) -> usize {
253    len + nlmsg_align(nlmsg_header_length())
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use byteorder::{NativeEndian, WriteBytesExt};
    use Protocol;
    use std::io::Write;

    /// Panics unless `msg` carries a `Payload::Data` equal to `expected`.
    fn assert_data_payload(msg: &Msg, expected: &[u8]) {
        if let &Payload::Data(b) = msg.payload() {
            assert_eq!(b, expected);
        } else {
            panic!("msg is not Data enum");
        }
    }

    #[test]
    fn test_send_recv() {
        // Port ids are fixed; they must differ from those used by other tests.
        let send = Socket::new(Protocol::Usersock).unwrap();
        let mut recv = Socket::new(Protocol::Usersock).unwrap();
        let send_addr = NetlinkAddr::new(101, 0);
        let recv_addr = NetlinkAddr::new(102, 0);

        send.bind(send_addr).unwrap();
        recv.bind(recv_addr).unwrap();

        let bytes = [0, 1, 2, 3, 4, 5];
        let mut shdr = NlMsgHeader::request();
        shdr.data_length(6).seq(1).pid(102);
        let msg = Msg::new(shdr, Payload::Data(&bytes));

        send.send(msg, &recv_addr).unwrap();

        let (addr, msgs) = recv.recv().unwrap();
        assert_eq!(addr, send_addr);
        assert_eq!(msgs.len(), 1);
        assert_data_payload(&msgs[0], &bytes);
    }

    #[test]
    fn test_send_multi_recv() {
        let send = Socket::new(Protocol::Usersock).unwrap();
        let mut recv = Socket::new(Protocol::Usersock).unwrap();
        let send_addr = NetlinkAddr::new(99, 0);
        let recv_addr = NetlinkAddr::new(100, 0);

        send.bind(send_addr).unwrap();
        recv.bind(recv_addr).unwrap();

        let bytes = [0, 1, 2, 3, 4, 5];
        let mut shdr = NlMsgHeader::request();
        shdr.data_length(6).multipart().seq(1).pid(100);
        let msg = Msg::new(shdr, Payload::Data(&bytes));
        let msg2 = msg.clone();

        let mut donehdr = NlMsgHeader::done();
        donehdr.pid(100);
        let donemsg = Msg::new(donehdr, Payload::None);

        send.send_multi(vec![msg, msg2, donemsg], &recv_addr).unwrap();

        let (addr, msgs) = recv.recv().unwrap();
        assert_eq!(addr, send_addr);
        // The trailing Done message terminates decoding and is not returned.
        assert_eq!(msgs.len(), 2);
        for m in &msgs {
            assert_data_payload(m, &bytes);
        }
    }

    #[test]
    fn test_payload_decode() {
        let bytes = [0, 1, 2, 3, 4, 5];
        let (payload, n) = Payload::data(&bytes, bytes.len()).unwrap();
        assert_eq!(n, bytes.len());

        if let Payload::Data(b) = payload {
            assert_eq!(b, &bytes);
        } else {
            panic!("payload is not Data enum");
        }
    }

    #[test]
    fn test_payload_decode_with_err() {
        let mut bytes = vec![];
        // Non-zero error code.
        bytes.write_u32::<NativeEndian>(1).unwrap();

        // Embedded header written field by field in native byte order, so the
        // test does not depend on the target's endianness.
        bytes.write_u32::<NativeEndian>(20).unwrap(); // length: 16-byte header + 4
        bytes.write_u16::<NativeEndian>(0).unwrap(); // type
        bytes.write_u16::<NativeEndian>(0x0301).unwrap(); // flags: request | dump
        bytes.write_u32::<NativeEndian>(1).unwrap(); // seq
        bytes.write_u32::<NativeEndian>(9).unwrap(); // pid

        let mut hdr = NlMsgHeader::request();
        hdr.data_length(4).pid(9).seq(1).dump();

        let (p, n) = Payload::nlmsg_error(&bytes).unwrap();

        assert_eq!(n, bytes.len());
        if let Payload::Err(h) = p {
            assert_eq!(h, hdr);
        } else {
            panic!("payload is not Err enum");
        }
    }

    #[test]
    fn test_payload_decode_with_ack() {
        let mut bytes = vec![];
        bytes.write_u32::<NativeEndian>(0).unwrap();

        let mut hdr = NlMsgHeader::request();
        hdr.data_length(4).pid(9).seq(1).dump();

        bytes.write_all(&hdr.bytes()).unwrap();

        let (p, n) = Payload::nlmsg_error(&bytes).unwrap();

        assert_eq!(n, bytes.len());
        if let Payload::Ack(h) = p {
            assert_eq!(h, hdr);
        } else {
            panic!("payload is not Ack enum");
        }
    }

    #[test]
    fn test_msg_decode() {
        let mut hdr = NlMsgHeader::request();
        hdr.data_length(4).pid(9).seq(1).dump();
        let hdr_bytes = hdr.bytes();

        let data = [0, 1, 2, 3];

        let mut bytes = vec![];
        bytes.write_all(&hdr_bytes).unwrap();
        bytes.write_all(&data).unwrap();
        // Trailing bytes that belong to no message; they must not be consumed.
        bytes.write_all(&[1, 1, 1, 1, 1, 1, 1]).unwrap();

        let (msg, n) = Msg::from_bytes(&bytes).unwrap();
        assert_eq!(n, hdr_bytes.len() + data.len());
        assert_eq!(hdr, msg.header());

        if let &Payload::Data(b) = msg.payload() {
            assert_eq!(b, &data);
        } else {
            panic!("msg is not Data enum");
        }
    }

    #[test]
    fn test_msg_decode_with_err() {
        let mut hdr = NlMsgHeader::error();
        hdr.pid(9).seq(1);
        let hdr_bytes = hdr.bytes();

        let mut bytes = vec![];
        bytes.write_all(&hdr_bytes).unwrap();

        bytes.write_u32::<NativeEndian>(1).unwrap();
        let mut err_hdr = NlMsgHeader::request();
        err_hdr.data_length(4).pid(9).seq(1).dump();
        bytes.write_all(&err_hdr.bytes()).unwrap();

        let (msg, n) = Msg::from_bytes(&bytes).unwrap();
        assert_eq!(n, bytes.len());
        assert_eq!(hdr, msg.header());

        if let &Payload::Err(h) = msg.payload() {
            assert_eq!(h, err_hdr);
        } else {
            panic!("msg is not Err enum");
        }
    }
}