netlink_packet_utils/
nla.rs

1// SPDX-License-Identifier: MIT
2
3use core::ops::Range;
4
5use anyhow::Context;
6use byteorder::{ByteOrder, NativeEndian};
7
8use crate::{
9    traits::{Emitable, Parseable},
10    DecodeError,
11};
12
/// Represent a multi-bytes field with a fixed size in a packet
type Field = Range<usize>;

/// Identify the bits that represent the "nested" flag of a netlink attribute.
pub const NLA_F_NESTED: u16 = 0x8000;
/// Identify the bits that represent the "byte order" flag of a netlink
/// attribute.
pub const NLA_F_NET_BYTEORDER: u16 = 0x4000;
/// Identify the bits that represent the type of a netlink attribute.
pub const NLA_TYPE_MASK: u16 = !(NLA_F_NET_BYTEORDER | NLA_F_NESTED);
/// NLA (RTA) alignment: attributes are padded to multiples of 4 bytes.
pub const NLA_ALIGNTO: usize = 4;
/// NLA (RTA) header size: (unsigned short rta_len) + (unsigned short
/// rta_type), i.e. the `length` and `type` fields.
pub const NLA_HEADER_SIZE: usize = 4;
27
/// Round `$len` up to the nearest multiple of `NLA_ALIGNTO` (4 bytes).
///
/// `NLA_ALIGNTO` must be in scope at the call site. Note that the
/// addition can overflow when `$len` is within `NLA_ALIGNTO - 1` of the
/// integer's maximum value (this panics when overflow checks are
/// enabled — see `test_align_overflow` below).
#[macro_export]
macro_rules! nla_align {
    ($len: expr) => {
        ($len + NLA_ALIGNTO - 1) & !(NLA_ALIGNTO - 1)
    };
}
34
35const LENGTH: Field = 0..2;
36const TYPE: Field = 2..4;
37#[allow(non_snake_case)]
38fn VALUE(length: usize) -> Field {
39    TYPE.end..TYPE.end + length
40}
41
// With Copy, NlaBuffer<&'buffer T> can be copied, which turns out to be
// pretty convenient. And since it boils down to copying a reference, it is
// pretty cheap.
/// A raw, unparsed view over the bytes of a single netlink attribute
/// (header plus value). `T` is any byte container (`&[u8]`, `&mut [u8]`,
/// `Vec<u8>`, ...); the accessor methods read the NLA header fields
/// directly from it.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct NlaBuffer<T: AsRef<[u8]>> {
    // Underlying storage holding the attribute bytes.
    buffer: T,
}
49
50impl<T: AsRef<[u8]>> NlaBuffer<T> {
51    pub fn new(buffer: T) -> NlaBuffer<T> {
52        NlaBuffer { buffer }
53    }
54
55    pub fn new_checked(buffer: T) -> Result<NlaBuffer<T>, DecodeError> {
56        let buffer = Self::new(buffer);
57        buffer.check_buffer_length().context("invalid NLA buffer")?;
58        Ok(buffer)
59    }
60
61    pub fn check_buffer_length(&self) -> Result<(), DecodeError> {
62        let len = self.buffer.as_ref().len();
63        if len < TYPE.end {
64            Err(format!(
65                "buffer has length {}, but an NLA header is {} bytes",
66                len, TYPE.end
67            )
68            .into())
69        } else if len < self.length() as usize {
70            Err(format!(
71                "buffer has length: {}, but the NLA is {} bytes",
72                len,
73                self.length()
74            )
75            .into())
76        } else if (self.length() as usize) < TYPE.end {
77            Err(format!(
78                "NLA has invalid length: {} (should be at least {} bytes",
79                self.length(),
80                TYPE.end,
81            )
82            .into())
83        } else {
84            Ok(())
85        }
86    }
87
88    /// Consume the buffer, returning the underlying buffer.
89    pub fn into_inner(self) -> T {
90        self.buffer
91    }
92
93    /// Return a reference to the underlying buffer
94    pub fn inner(&mut self) -> &T {
95        &self.buffer
96    }
97
98    /// Return a mutable reference to the underlying buffer
99    pub fn inner_mut(&mut self) -> &mut T {
100        &mut self.buffer
101    }
102
103    /// Return the `type` field
104    pub fn kind(&self) -> u16 {
105        let data = self.buffer.as_ref();
106        NativeEndian::read_u16(&data[TYPE]) & NLA_TYPE_MASK
107    }
108
109    pub fn nested_flag(&self) -> bool {
110        let data = self.buffer.as_ref();
111        (NativeEndian::read_u16(&data[TYPE]) & NLA_F_NESTED) != 0
112    }
113
114    pub fn network_byte_order_flag(&self) -> bool {
115        let data = self.buffer.as_ref();
116        (NativeEndian::read_u16(&data[TYPE]) & NLA_F_NET_BYTEORDER) != 0
117    }
118
119    /// Return the `length` field. The `length` field corresponds to the length
120    /// of the nla header (type and length fields, and the value field).
121    /// However, it does not account for the potential padding that follows
122    /// the value field.
123    pub fn length(&self) -> u16 {
124        let data = self.buffer.as_ref();
125        NativeEndian::read_u16(&data[LENGTH])
126    }
127
128    /// Return the length of the `value` field
129    ///
130    /// # Panic
131    ///
132    /// This panics if the length field value is less than the attribut header
133    /// size.
134    pub fn value_length(&self) -> usize {
135        self.length() as usize - TYPE.end
136    }
137}
138
139impl<T: AsRef<[u8]> + AsMut<[u8]>> NlaBuffer<T> {
140    /// Set the `type` field
141    pub fn set_kind(&mut self, kind: u16) {
142        let data = self.buffer.as_mut();
143        NativeEndian::write_u16(&mut data[TYPE], kind & NLA_TYPE_MASK)
144    }
145
146    pub fn set_nested_flag(&mut self) {
147        let kind = self.kind();
148        let data = self.buffer.as_mut();
149        NativeEndian::write_u16(&mut data[TYPE], kind | NLA_F_NESTED)
150    }
151
152    pub fn set_network_byte_order_flag(&mut self) {
153        let kind = self.kind();
154        let data = self.buffer.as_mut();
155        NativeEndian::write_u16(&mut data[TYPE], kind | NLA_F_NET_BYTEORDER)
156    }
157
158    /// Set the `length` field
159    pub fn set_length(&mut self, length: u16) {
160        let data = self.buffer.as_mut();
161        NativeEndian::write_u16(&mut data[LENGTH], length)
162    }
163}
164
165impl<'buffer, T: AsRef<[u8]> + ?Sized> NlaBuffer<&'buffer T> {
166    /// Return the `value` field
167    pub fn value(&self) -> &[u8] {
168        &self.buffer.as_ref()[VALUE(self.value_length())]
169    }
170}
171
172impl<'buffer, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> NlaBuffer<&'buffer mut T> {
173    /// Return the `value` field
174    pub fn value_mut(&mut self) -> &mut [u8] {
175        let length = VALUE(self.value_length());
176        &mut self.buffer.as_mut()[length]
177    }
178}
179
/// A catch-all netlink attribute: the raw `kind` (which may include the
/// `NLA_F_NESTED` / `NLA_F_NET_BYTEORDER` flag bits) together with the
/// undecoded value bytes.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct DefaultNla {
    // Attribute type, possibly with flag bits set (see `parse` below).
    kind: u16,
    // Raw value bytes as found in the packet.
    value: Vec<u8>,
}
185
impl DefaultNla {
    /// Create an attribute from its raw `kind` (which may include flag
    /// bits) and its raw value bytes.
    pub fn new(kind: u16, value: Vec<u8>) -> Self {
        Self { kind, value }
    }
}
191
impl Nla for DefaultNla {
    fn value_len(&self) -> usize {
        self.value.len()
    }
    fn kind(&self) -> u16 {
        self.kind
    }
    fn emit_value(&self, buffer: &mut [u8]) {
        // `buffer` must be exactly `self.value_len()` bytes long,
        // otherwise `copy_from_slice` panics.
        buffer.copy_from_slice(self.value.as_slice());
    }
}
203
204impl<'buffer, T: AsRef<[u8]> + ?Sized> Parseable<NlaBuffer<&'buffer T>>
205    for DefaultNla
206{
207    fn parse(buf: &NlaBuffer<&'buffer T>) -> Result<Self, DecodeError> {
208        let mut kind = buf.kind();
209
210        if buf.network_byte_order_flag() {
211            kind |= NLA_F_NET_BYTEORDER;
212        }
213
214        if buf.nested_flag() {
215            kind |= NLA_F_NESTED;
216        }
217
218        Ok(DefaultNla {
219            kind,
220            value: buf.value().to_vec(),
221        })
222    }
223}
224
/// A typed netlink attribute that knows how to serialize itself.
pub trait Nla {
    /// Length in bytes of the value field, excluding the header and any
    /// trailing alignment padding.
    fn value_len(&self) -> usize;
    /// The attribute type, possibly with flag bits set.
    fn kind(&self) -> u16;
    /// Write the value bytes into `buffer` (the `Emitable` impl below
    /// passes a slice of exactly `value_len()` bytes).
    fn emit_value(&self, buffer: &mut [u8]);

    /// Whether `kind()` carries the `NLA_F_NESTED` flag.
    #[inline]
    fn is_nested(&self) -> bool {
        (self.kind() & NLA_F_NESTED) != 0
    }

    /// Whether `kind()` carries the `NLA_F_NET_BYTEORDER` flag.
    #[inline]
    fn is_network_byteorder(&self) -> bool {
        (self.kind() & NLA_F_NET_BYTEORDER) != 0
    }
}
240
241impl<T: Nla> Emitable for T {
242    fn buffer_len(&self) -> usize {
243        nla_align!(self.value_len()) + NLA_HEADER_SIZE
244    }
245    fn emit(&self, buffer: &mut [u8]) {
246        let mut buffer = NlaBuffer::new(buffer);
247        buffer.set_kind(self.kind());
248
249        if self.is_network_byteorder() {
250            buffer.set_network_byte_order_flag()
251        }
252
253        if self.is_nested() {
254            buffer.set_nested_flag()
255        }
256
257        // do not include the padding here, but do include the header
258        buffer.set_length(self.value_len() as u16 + NLA_HEADER_SIZE as u16);
259
260        self.emit_value(buffer.value_mut());
261
262        let padding = nla_align!(self.value_len()) - self.value_len();
263        for i in 0..padding {
264            buffer.inner_mut()[NLA_HEADER_SIZE + self.value_len() + i] = 0;
265        }
266    }
267}
268
// FIXME: when specialization lands, we can actually have
270//
271// impl<'a, T: Nla, I: Iterator<Item=T>> Emitable for I { ...}
272//
273// The reason this does not work today is because it conflicts with
274//
275// impl<T: Nla> Emitable for T { ... }
276impl<'a, T: Nla> Emitable for &'a [T] {
277    fn buffer_len(&self) -> usize {
278        self.iter().fold(0, |acc, nla| {
279            assert_eq!(nla.buffer_len() % NLA_ALIGNTO, 0);
280            acc + nla.buffer_len()
281        })
282    }
283
284    fn emit(&self, buffer: &mut [u8]) {
285        let mut start = 0;
286        let mut end: usize;
287        for nla in self.iter() {
288            let attr_len = nla.buffer_len();
289            assert_eq!(nla.buffer_len() % NLA_ALIGNTO, 0);
290            end = start + attr_len;
291            nla.emit(&mut buffer[start..end]);
292            start = end;
293        }
294    }
295}
296
/// An iterator that iterates over nlas without decoding them. This is
/// useful when looking for specific nlas.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NlasIterator<T> {
    // Byte offset of the next attribute within `buffer`.
    position: usize,
    buffer: T,
}
304
305impl<T> NlasIterator<T> {
306    pub fn new(buffer: T) -> Self {
307        NlasIterator {
308            position: 0,
309            buffer,
310        }
311    }
312}
313
314impl<'buffer, T: AsRef<[u8]> + ?Sized + 'buffer> Iterator
315    for NlasIterator<&'buffer T>
316{
317    type Item = Result<NlaBuffer<&'buffer [u8]>, DecodeError>;
318
319    fn next(&mut self) -> Option<Self::Item> {
320        if self.position >= self.buffer.as_ref().len() {
321            return None;
322        }
323
324        match NlaBuffer::new_checked(&self.buffer.as_ref()[self.position..]) {
325            Ok(nla_buffer) => {
326                self.position += nla_align!(nla_buffer.length() as usize);
327                Some(Ok(nla_buffer))
328            }
329            Err(e) => {
330                // Make sure next time we call `next()`, we return None. We
331                // don't try to continue iterating after we
332                // failed to return a buffer.
333                self.position = self.buffer.as_ref().len();
334                Some(Err(e))
335            }
336        }
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trip test: parsing an attribute and re-emitting it must
    // preserve the nested and network-byte-order flags.
    #[test]
    fn network_byteorder() {
        // The IPSET_ATTR_TIMEOUT attribute should have the network byte order
        // flag set. IPSET_ATTR_TIMEOUT(3600)
        //
        // Layout: length = 8, type = 0x4006 (NLA_F_NET_BYTEORDER | 6),
        // value = 3600 big-endian (0x00000e10).
        // NOTE(review): the header bytes are read with `NativeEndian`, so
        // this fixture appears to assume a little-endian host — confirm.
        static TEST_ATTRIBUTE: &[u8] =
            &[0x08, 0x00, 0x06, 0x40, 0x00, 0x00, 0x0e, 0x10];
        let buffer = NlaBuffer::new(TEST_ATTRIBUTE);
        let buffer_is_net = buffer.network_byte_order_flag();
        let buffer_is_nest = buffer.nested_flag();

        let nla = DefaultNla::parse(&buffer).unwrap();
        let mut emitted_buffer = vec![0; nla.buffer_len()];

        nla.emit(&mut emitted_buffer);

        let attr_is_net = nla.is_network_byteorder();
        let attr_is_nest = nla.is_nested();

        let emit = NlaBuffer::new(emitted_buffer);
        let emit_is_net = emit.network_byte_order_flag();
        let emit_is_nest = emit.nested_flag();

        // Flags must survive parse (buffer -> nla) and emit (nla -> buffer).
        assert_eq!(
            [buffer_is_net, buffer_is_nest],
            [attr_is_net, attr_is_nest]
        );
        assert_eq!([attr_is_net, attr_is_nest], [emit_is_net, emit_is_nest]);
    }

    // Returns usize::MAX via a function call (likely to keep the
    // overflowing arithmetic below out of compile-time const evaluation).
    // NOTE(review): the literal is 2^64 - 1, so these tests assume a
    // 64-bit target — confirm.
    fn get_len() -> usize {
        // usize::MAX
        18446744073709551615
    }

    #[test]
    fn test_align() {
        assert_eq!(nla_align!(13), 16);
        assert_eq!(nla_align!(16), 16);
        assert_eq!(nla_align!(0), 0);
        assert_eq!(nla_align!(1), 4);
        // Largest input that still rounds up without overflowing.
        assert_eq!(nla_align!(get_len() - 4), usize::MAX - 3);
    }
    // `usize::MAX - 3 + NLA_ALIGNTO - 1` overflows; this relies on
    // overflow checks being enabled (the default for test/debug builds).
    #[test]
    #[should_panic]
    fn test_align_overflow() {
        assert_eq!(nla_align!(get_len() - 3), usize::MAX);
    }
}