linux-taskstats 0.1.3

Rust interface to Linux taskstats
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
use crate::AsBuf;
use libc;
use log::debug;
use netlink_sys::{self as nl, Protocol, Socket, SocketAddr};
use std::io;
use std::mem;
use std::process;
use std::slice;
use thiserror::Error;

/// Size in bytes of the fixed payload buffer (`GenNlMsg::buf`) that follows the
/// netlink and generic-netlink headers in a message.
const MAX_MESSAGE_SIZE: usize = 1024;

/// Errors produced while talking to the netlink socket.
#[derive(Debug, Error)]
pub enum Error {
    /// The underlying socket operation failed; wraps the originating `io::Error`.
    #[error("error in I/O with netlink socket: {0}")]
    SocketIo(#[from] io::Error),
    /// A received message failed the length sanity checks; the string carries
    /// the details of the mismatch.
    #[error("corrupted data read from netlink socket: {0}")]
    Protocol(String),
    /// The received message had type `NLMSG_ERROR`.
    #[error("error response received from remote")]
    ErrorResponse,
}

/// Result alias whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;

/// Size and validity helpers for netlink message headers.
mod nlmsg {
    use crate::c_headers::NLMSG_ALIGNTO;
    use std::mem;

    /// Message alignment boundary as a `usize`.
    const ALIGNTO: usize = NLMSG_ALIGNTO as usize;

    /// Rounds `len` up to the next multiple of the netlink message alignment.
    pub const fn align(len: usize) -> usize {
        (len + ALIGNTO - 1) & !(ALIGNTO - 1)
    }

    /// Aligned size of the netlink message header.
    pub const HDRLEN: usize = align(mem::size_of::<libc::nlmsghdr>());
    /// Aligned size of the generic netlink header.
    pub const GENL_HDRLEN: usize = align(mem::size_of::<libc::genlmsghdr>());

    /// Returns `true` when `len` received bytes can hold a header and the
    /// length recorded in `nlh` lies between the header size and `len`.
    #[inline]
    pub fn is_valid(nlh: &libc::nlmsghdr, len: usize) -> bool {
        let header_size = mem::size_of::<libc::nlmsghdr>();
        if len < header_size {
            return false;
        }
        let claimed = nlh.nlmsg_len as usize;
        claimed >= header_size && claimed <= len
    }
}

/// Size and pointer helpers for netlink attributes.
mod nla {
    use std::mem;

    /// Attribute alignment boundary as a `usize`.
    const ALIGNTO: usize = libc::NLA_ALIGNTO as usize;

    /// Rounds `len` up to the next multiple of the attribute alignment.
    pub const fn align(len: usize) -> usize {
        (len + ALIGNTO - 1) & !(ALIGNTO - 1)
    }

    /// Aligned size of the attribute header.
    pub const HDRLEN: usize = align(mem::size_of::<libc::nlattr>());

    /// Returns a pointer to the byte located `HDRLEN` bytes after the start of `na`.
    ///
    /// Only the address is computed here; nothing is read through it.
    #[inline]
    pub fn payload(na: &libc::nlattr) -> *const u8 {
        let base = na as *const libc::nlattr as *const u8;
        // SAFETY: nothing here verifies that memory exists past the header;
        // callers must only pass headers that sit inside a larger buffer.
        unsafe { base.add(HDRLEN) }
    }

    /// Returns a reference to the attribute header located `align(na.nla_len)`
    /// bytes after the start of `na`.
    ///
    /// No bounds are checked here; the caller must know that another header
    /// really exists at that position.
    #[inline]
    pub fn next(na: &libc::nlattr) -> &libc::nlattr {
        let base = na as *const libc::nlattr as *const u8;
        let stride = align(na.nla_len as usize);
        // SAFETY: not verified in this function; relies on the caller's
        // knowledge of the surrounding buffer.
        unsafe { &*(base.add(stride) as *const libc::nlattr) }
    }
}

/// Trait abstracting netlink socket IO.
/// This trait is only meant to replace socket implementation at unit testing.
pub trait NlSocket {
    /// Address type identifying the peer that `send_to` targets.
    type Addr;

    /// Sends `buf` to `addr`, returning the number of bytes that were sent.
    fn send_to(&self, buf: &[u8], addr: &Self::Addr) -> io::Result<usize>;

    /// Receives a message into `buf`, returning the number of bytes received.
    fn recv(&self, buf: &mut [u8]) -> io::Result<usize>;
}

/// Production implementation backed by a real netlink socket.
impl NlSocket for nl::Socket {
    type Addr = nl::SocketAddr;

    /// Forwards to the socket's own `send_to`, passing `0` as its extra argument.
    fn send_to(&self, buf: &[u8], addr: &Self::Addr) -> io::Result<usize> {
        // Spelled as a path so it is obvious the inherent method is called,
        // not this trait method.
        nl::Socket::send_to(self, buf, addr, 0)
    }

    /// Forwards to the socket's own `recv`, passing `0` as its extra argument.
    fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
        nl::Socket::recv(self, buf, 0)
    }
}

/// Netlink protocol implementation specifically for taskstats querying.
pub struct Netlink<S: NlSocket = nl::Socket> {
    /// Socket used for both sending requests and receiving responses.
    sock: S,
    /// Destination address passed to every `send_to` call.
    remote_addr: S::Addr,
    /// Value written into the `nlmsg_pid` field of outgoing messages.
    mypid: u32,
}

impl Netlink<nl::Socket> {
    /// Creates a generic-netlink socket, binds it to address `(0, 0)` and
    /// wraps it together with the current process id.
    ///
    /// # Errors
    ///
    /// Returns [`Error::SocketIo`] if creating or binding the socket fails.
    pub fn open() -> Result<Netlink<nl::Socket>> {
        let mut sock = Socket::new(Protocol::Generic)?;
        sock.bind(&SocketAddr::new(0, 0))?;

        let remote_addr = SocketAddr::new(0, 0);
        let mypid = process::id();
        Ok(Self {
            sock,
            remote_addr,
            mypid,
        })
    }
}

impl<S: NlSocket> Netlink<S> {
    pub fn send_cmd(
        &self,
        nlmsg_type: u16,
        genl_cmd: u8,
        nla_type: u16,
        nla_data: &[u8],
    ) -> Result<()> {
        debug!(
            "Sending nl cmd: type={}, genl_cmd={}, nla_type={} nla_data.len={}",
            nlmsg_type,
            genl_cmd,
            nla_type,
            nla_data.len()
        );

        let attr = libc::nlattr {
            nla_type,
            nla_len: nla::align(nla::HDRLEN + nla_data.len()) as u16,
        };
        let mut buf = [0u8; MAX_MESSAGE_SIZE];
        let bufp = buf.as_mut_ptr();
        unsafe {
            std::ptr::copy_nonoverlapping(
                &attr as *const libc::nlattr as *const u8,
                bufp,
                mem::size_of::<libc::nlattr>(),
            );
            std::ptr::copy_nonoverlapping(
                nla_data.as_ptr() as *const u8,
                bufp.offset(nla::HDRLEN as isize),
                nla_data.len(),
            );
        }

        let nlmsg_len = nlmsg::HDRLEN + nlmsg::GENL_HDRLEN + attr.nla_len as usize;
        let msg = GenNlMsg {
            nlmsg_header: libc::nlmsghdr {
                nlmsg_len: nlmsg_len as u32,
                nlmsg_type,
                nlmsg_flags: libc::NLM_F_REQUEST as u16,
                nlmsg_seq: 0,
                nlmsg_pid: self.mypid,
            },
            genlmsg_header: libc::genlmsghdr {
                cmd: genl_cmd,
                version: 0x1,
                reserved: 0x0,
            },
            buf,
        };
        debug!("Sending msg of size={}", nlmsg_len);

        let mut send_buf = &msg.as_buf()[..msg.nlmsg_header.nlmsg_len as usize];
        loop {
            let sent_size = self.sock.send_to(&send_buf, &self.remote_addr)?;
            if sent_size == send_buf.len() {
                break;
            }
            send_buf = &send_buf[sent_size..];
        }
        Ok(())
    }

    pub fn recv_response(&self) -> Result<GenNlMsg> {
        let mut msg: GenNlMsg = unsafe { mem::zeroed() };
        let rep_len = self.sock.recv(msg.as_buf_mut())?;

        debug!(
            "Received msg: size={}, type={}, nlmsg_len={}",
            rep_len, msg.nlmsg_header.nlmsg_type, msg.nlmsg_header.nlmsg_len
        );

        if !nlmsg::is_valid(&msg.nlmsg_header, rep_len) {
            return Err(Error::Protocol(format!(
                "header len: {}, recv size: {}",
                msg.nlmsg_header.nlmsg_len, rep_len
            )));
        }
        if msg.nlmsg_header.nlmsg_len as usize > mem::size_of::<GenNlMsg>() {
            return Err(Error::Protocol(format!(
                "too large message size: {}",
                msg.nlmsg_header.nlmsg_len
            )));
        }

        if msg.nlmsg_header.nlmsg_type == libc::NLMSG_ERROR as u16 {
            return Err(Error::ErrorResponse);
        }

        Ok(msg)
    }
}

pub trait NlPayload {
    fn payload(&self) -> &[u8];

    #[inline]
    fn payload_len(&self) -> usize {
        self.payload().len()
    }

    fn payload_as<T>(&self) -> &T {
        if mem::size_of::<T>() > self.payload_len() {
            panic!(
                "attempt to cast buffer into type that has larger size than buf length: {} > {}",
                mem::size_of::<T>(),
                self.payload_len()
            );
        }
        unsafe { &*(self.payload().as_ptr() as *const T) }
    }

    fn payload_as_nlattrs(&self) -> NlAttrs<'_> {
        NlAttrs {
            next: Some(self.payload_as::<libc::nlattr>()),
            rem_size: self.payload_len(),
        }
    }
}

/// A generic netlink message: netlink header, generic netlink header and a
/// fixed-size payload buffer, laid out in C order.
#[repr(C)]
pub struct GenNlMsg {
    /// Netlink message header; `nlmsg_len` records the total message length.
    pub nlmsg_header: libc::nlmsghdr,
    /// Generic netlink header (command, version, reserved).
    pub genlmsg_header: libc::genlmsghdr,
    /// Payload storage; only the part covered by `nlmsg_len` is meaningful.
    pub buf: [u8; MAX_MESSAGE_SIZE],
}

impl NlPayload for GenNlMsg {
    /// Returns the part of `buf` covered by the length recorded in the netlink
    /// header, i.e. `nlmsg_len` minus both header sizes.
    ///
    /// A recorded length shorter than the two headers yields an empty payload,
    /// and a length past the end of `buf` is clamped to `buf`, so a malformed
    /// header cannot cause an arithmetic underflow or an out-of-range slice.
    fn payload(&self) -> &[u8] {
        let total = self.nlmsg_header.nlmsg_len as usize;
        let len = total
            .saturating_sub(nlmsg::HDRLEN + nlmsg::GENL_HDRLEN)
            .min(self.buf.len());
        &self.buf[..len]
    }
}

/// A view of a single netlink attribute, identified by a reference to its header.
pub struct NlAttr<'a> {
    /// The attribute header; the payload is read from the bytes that follow it.
    pub header: &'a libc::nlattr,
}

impl<'a> NlPayload for NlAttr<'a> {
    /// Returns the `nla_len - HDRLEN` bytes that follow the attribute header.
    ///
    /// An `nla_len` smaller than the header size yields an empty payload; without
    /// this the subtraction would underflow and produce a bogus slice length.
    fn payload(&self) -> &[u8] {
        let len = (self.header.nla_len as usize).saturating_sub(nla::HDRLEN);
        // SAFETY: relies on `header` pointing into a buffer that really contains
        // `nla_len` bytes for this attribute; that is not verified here.
        unsafe { slice::from_raw_parts(nla::payload(self.header), len) }
    }
}

/// Iterator over consecutive netlink attributes stored in a payload.
pub struct NlAttrs<'a> {
    /// Header of the attribute to yield next, or `None` once iteration is done.
    next: Option<&'a libc::nlattr>,
    /// Number of payload bytes not yet consumed by yielded attributes.
    rem_size: usize,
}

impl<'a> Iterator for NlAttrs<'a> {
    type Item = NlAttr<'a>;

    /// Yields the pending attribute and advances to the following one while at
    /// least one attribute header's worth of bytes remains.
    ///
    /// Iteration stops (returns `None`) when an attribute records a length
    /// smaller than its own header or larger than the bytes remaining. Such a
    /// length would otherwise underflow the remaining-size counter or, for a
    /// zero length, make the iterator yield the same attribute forever.
    fn next(&mut self) -> Option<Self::Item> {
        let current = self.next.take()?;
        let attr_len = current.nla_len as usize;
        if attr_len < nla::HDRLEN || attr_len > self.rem_size {
            // Malformed length: do not hand out an attribute whose payload
            // would extend past the data we know about.
            return None;
        }

        // The last attribute may lack trailing padding, hence saturating.
        self.rem_size = self.rem_size.saturating_sub(nla::align(attr_len));
        if self.rem_size >= nla::HDRLEN {
            self.next = Some(nla::next(current));
        }
        Some(NlAttr { header: current })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{SocketAddr, UdpSocket};
    use std::ptr;

    // Arbitrary fixture values used to build and check messages in the tests below.
    const NLMSG_TYPE: u16 = 32;
    const GENL_CMD: u8 = 3;
    const NLA_TYPE: u16 = 17;
    const PID: u32 = 1234;
    const PAYLOAD: &'static str = "Hello";

    // Test-only stand-in for the netlink socket: messages travel over UDP so
    // the tests can play the remote side themselves.
    impl NlSocket for UdpSocket {
        type Addr = SocketAddr;

        // Forwards to `UdpSocket`'s own `send_to`.
        fn send_to(&self, buf: &[u8], addr: &Self::Addr) -> io::Result<usize> {
            self.send_to(buf, addr)
        }

        // Forwards to `UdpSocket`'s own `recv`.
        fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
            self.recv(buf)
        }
    }

    /// Binds a UDP socket on `localhost:0`, panicking if binding fails.
    fn nl_sock() -> UdpSocket {
        let bound = UdpSocket::bind("localhost:0");
        bound.unwrap()
    }

    /// Builds a `Netlink` over a fresh UDP socket whose remote address is the
    /// local address of `serv_sock` and whose pid is the `PID` fixture.
    fn nl(serv_sock: &UdpSocket) -> Netlink<UdpSocket> {
        let client = nl_sock();
        let server_addr = serv_sock.local_addr().unwrap();
        Netlink {
            sock: client,
            remote_addr: server_addr,
            mypid: PID,
        }
    }

    // Sends a command through `send_cmd` and inspects the raw bytes that
    // arrive on the server-side socket.
    #[test]
    fn test_send_cmd() {
        let serv_sock = nl_sock();
        let nl = nl(&serv_sock);

        nl.send_cmd(NLMSG_TYPE, GENL_CMD, NLA_TYPE, PAYLOAD.as_bytes())
            .unwrap();
        let mut buf = [0u8; 256];
        let size = serv_sock.recv(&mut buf).unwrap();

        // Both headers, the attribute header and the padded payload.
        let expect_size =
            nlmsg::HDRLEN + nlmsg::GENL_HDRLEN + nla::HDRLEN + nla::align(PAYLOAD.as_bytes().len());
        assert_eq!(expect_size, size);

        // The netlink header sits at the very start of the datagram.
        let n = unsafe { &*(&buf as *const u8 as *const libc::nlmsghdr) };
        assert_eq!(expect_size, n.nlmsg_len as usize);
        assert_eq!(NLMSG_TYPE, n.nlmsg_type);
        assert_eq!(PID, n.nlmsg_pid);

        // The generic netlink header follows the netlink header.
        let g = unsafe {
            &*((&buf as *const u8).offset(nlmsg::HDRLEN as isize) as *const libc::genlmsghdr)
        };
        assert_eq!(GENL_CMD, g.cmd);

        // The attribute payload follows both message headers and the attribute header.
        let payload = unsafe {
            slice::from_raw_parts(
                (&buf as *const u8)
                    .offset((nlmsg::HDRLEN + nlmsg::GENL_HDRLEN + nla::HDRLEN) as isize),
                PAYLOAD.len(),
            )
        };
        assert_eq!(PAYLOAD.as_bytes(), payload);
    }

    // Hand-assembles a message (netlink header, generic netlink header, raw
    // payload), sends it from the server socket and checks that
    // `recv_response` parses it back.
    #[test]
    fn test_recv_response() {
        let serv_sock = nl_sock();
        let nl = nl(&serv_sock);

        // Write cursor into `buf`.
        let mut pos = 0;

        let mut buf = [0u8; 256];
        let nlmsg_len = nlmsg::HDRLEN + nlmsg::GENL_HDRLEN + PAYLOAD.len();
        let addr = nl.sock.local_addr().unwrap();
        let n = libc::nlmsghdr {
            nlmsg_len: nlmsg_len as u32,
            nlmsg_type: NLMSG_TYPE,
            nlmsg_flags: 0,
            nlmsg_seq: 0,
            nlmsg_pid: PID,
        };
        unsafe {
            ptr::copy_nonoverlapping(
                &n as *const libc::nlmsghdr as *const u8,
                buf.as_mut_ptr().offset(pos as isize),
                mem::size_of::<libc::nlmsghdr>(),
            );
        }
        pos += nlmsg::HDRLEN;

        let g = libc::genlmsghdr {
            cmd: GENL_CMD,
            version: 0x1,
            reserved: 0x0,
        };
        unsafe {
            ptr::copy_nonoverlapping(
                &g as *const libc::genlmsghdr as *const u8,
                buf.as_mut_ptr().offset(pos as isize),
                mem::size_of::<libc::genlmsghdr>(),
            );
        }
        pos += nlmsg::GENL_HDRLEN;

        // The payload is written raw, without an attribute header.
        unsafe {
            ptr::copy_nonoverlapping(
                PAYLOAD.as_ptr(),
                buf.as_mut_ptr().offset(pos as isize),
                PAYLOAD.len(),
            );
        }
        pos += PAYLOAD.len();

        serv_sock.send_to(&buf[..pos], &addr).unwrap();

        let resp = nl.recv_response().unwrap();
        assert_eq!(n.nlmsg_len, resp.nlmsg_header.nlmsg_len);
        assert_eq!(n.nlmsg_type, resp.nlmsg_header.nlmsg_type);
        assert_eq!(n.nlmsg_pid, resp.nlmsg_header.nlmsg_pid);
        assert_eq!(g.cmd, resp.genlmsg_header.cmd);
        assert_eq!(PAYLOAD.as_bytes(), &resp.buf[..PAYLOAD.len()]);
    }

    // Checks the provided `payload_len` and `payload_as` methods using a
    // minimal implementor that simply wraps a byte slice.
    #[test]
    fn test_nlpayload() {
        struct Msg<'a>(&'a [u8]);
        impl<'a> NlPayload for Msg<'a> {
            fn payload(&self) -> &[u8] {
                self.0
            }
        }

        let n: u32 = 1234;
        // View the bytes of `n` in place so the payload keeps `u32` alignment.
        let m = Msg(unsafe {
            slice::from_raw_parts(&n as *const u32 as *const u8, mem::size_of::<u32>())
        });
        assert_eq!(mem::size_of::<u32>(), m.payload_len());
        assert_eq!(n, *m.payload_as());
    }

    // Builds an outer attribute whose payload holds three nested attributes
    // and checks that `payload_as_nlattrs` yields exactly those three.
    #[test]
    fn test_nlpayload_nlattrs() {
        let mut buf = [0u8; 256];

        // Writes one attribute (header plus the bytes of `val`) at `*pos` and
        // advances `*pos` past it.
        fn add_na<T>(buf: &mut [u8], pos: &mut usize, val: T) {
            let header =
                unsafe { &mut *(buf.as_mut_ptr().offset(*pos as isize) as *mut libc::nlattr) };
            header.nla_type = 0;
            header.nla_len = nla::align(nla::HDRLEN + mem::size_of::<T>()) as u16;
            unsafe {
                ptr::copy_nonoverlapping(
                    &val as *const T as *const u8,
                    buf.as_mut_ptr().offset((*pos + nla::HDRLEN) as isize),
                    mem::size_of::<T>(),
                )
            };
            *pos += header.nla_len as usize;
        }

        // Outer attribute: its length covers its own header plus three nested
        // attributes, each carrying one `char`.
        let header = unsafe { &mut *(buf.as_mut_ptr() as *mut libc::nlattr) };
        header.nla_type = 0;
        header.nla_len =
            nla::align(nla::HDRLEN + nla::align(nla::HDRLEN + mem::size_of::<char>()) * 3) as u16;

        let mut pos = nla::HDRLEN;
        add_na(&mut buf, &mut pos, 'a');
        add_na(&mut buf, &mut pos, 'b');
        add_na(&mut buf, &mut pos, 'c');

        let outer = NlAttr {
            header: unsafe { &*(buf.as_ptr() as *const libc::nlattr) },
        };
        let mut iter = outer.payload_as_nlattrs();
        // NOTE(review): each payload is read back as a `u8`, i.e. the first byte of
        // the stored `char`; this presumably assumes a little-endian host - confirm.
        assert_eq!(Some('a' as u8), iter.next().map(|x| *x.payload_as()));
        assert_eq!(Some('b' as u8), iter.next().map(|x| *x.payload_as()));
        assert_eq!(Some('c' as u8), iter.next().map(|x| *x.payload_as()));
        assert_eq!(None, iter.next().map(|x| *x.payload_as::<u8>()));
    }

    // `GenNlMsg::payload` must start at `buf` and span `nlmsg_len` minus both headers.
    #[test]
    fn test_gennlmsg_payload() {
        const LEN: usize = 3;
        let mut msg: GenNlMsg = unsafe { mem::zeroed() };
        msg.nlmsg_header.nlmsg_len = nlmsg::align(nlmsg::HDRLEN + nlmsg::GENL_HDRLEN + LEN) as u32;
        let p = msg.payload();
        assert_eq!(msg.buf.as_ptr(), p.as_ptr());
        assert_eq!(nlmsg::align(LEN), p.len());
    }

    // `NlAttr::payload` must start right after the attribute header and span
    // `nla_len` minus the header size.
    #[test]
    fn test_nlattr_payload() {
        const LEN: usize = 3;
        let na = libc::nlattr {
            nla_len: nla::align(nla::HDRLEN + LEN) as u16,
            nla_type: 0,
        };
        let nlattr = NlAttr { header: &na };
        // Only the pointer and length of the slice are inspected; its contents
        // are never read, since no payload bytes actually follow `na`.
        let p = nlattr.payload();
        let expect_p =
            unsafe { (&na as *const libc::nlattr as *const u8).offset(nla::HDRLEN as isize) };
        assert_eq!(expect_p, p.as_ptr());
        // Attribute lengths are padded with the attribute alignment helper,
        // so that is the one the expectation is computed with.
        assert_eq!(nla::align(LEN), p.len());
    }
}