netlink_packet_core/nla.rs

// SPDX-License-Identifier: MIT

use crate::{
    emit_u16, parse_u16,
    traits::{Emitable, Parseable},
    DecodeError,
};
use core::ops::Range;

/// Represent a multi-byte field with a fixed size in a packet.
type Field = Range<usize>;

/// Identify the bits that represent the "nested" flag of a netlink attribute.
pub const NLA_F_NESTED: u16 = 0x8000;
/// Identify the bits that represent the "byte order" flag of a netlink
/// attribute.
pub const NLA_F_NET_BYTEORDER: u16 = 0x4000;
/// Identify the bits that represent the type of a netlink attribute.
pub const NLA_TYPE_MASK: u16 = !(NLA_F_NET_BYTEORDER | NLA_F_NESTED);
/// NLA (RTA) alignment size.
pub const NLA_ALIGNTO: usize = 4;
/// NLA (RTA) header size: (unsigned short rta_len) + (unsigned short
/// rta_type).
pub const NLA_HEADER_SIZE: usize = 4;

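/// Round `$len` up to the nearest multiple of `NLA_ALIGNTO` (4 bytes), e.g.
/// `nla_align!(5) == 8` and `nla_align!(8) == 8`.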
#[macro_export]
macro_rules! nla_align {
    ($len: expr) => {
        ($len + NLA_ALIGNTO - 1) & !(NLA_ALIGNTO - 1)
    };
}

const LENGTH: Field = 0..2;
const TYPE: Field = 2..4;
#[allow(non_snake_case)]
fn VALUE(length: usize) -> Field {
    TYPE.end..TYPE.end + length
}

// With `Copy`, `NlaBuffer<&'buffer T>` can be copied, which turns out to be
// pretty convenient. And since it boils down to copying a reference, it's
// pretty cheap.
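/// A buffer holding a single netlink attribute: a 4-byte header (length and
/// type fields) followed by the value.
///
/// A minimal read sketch, assuming a little-endian host and that this module
/// is reachable as `netlink_packet_core::nla` (hence marked `ignore`):
///
/// ```ignore
/// use netlink_packet_core::nla::NlaBuffer;
///
/// // length = 8, type = 1, value = [0xde, 0xad, 0xbe, 0xef]
/// let raw = [0x08u8, 0x00, 0x01, 0x00, 0xde, 0xad, 0xbe, 0xef];
/// let buffer = NlaBuffer::new_checked(&raw[..]).unwrap();
/// assert_eq!(buffer.length(), 8);
/// assert_eq!(buffer.kind(), 1);
/// assert_eq!(buffer.value(), &[0xde, 0xad, 0xbe, 0xef][..]);
/// ```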
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct NlaBuffer<T: AsRef<[u8]>> {
    buffer: T,
}

impl<T: AsRef<[u8]>> NlaBuffer<T> {
    pub fn new(buffer: T) -> NlaBuffer<T> {
        NlaBuffer { buffer }
    }

    pub fn new_checked(buffer: T) -> Result<NlaBuffer<T>, DecodeError> {
        let buffer = Self::new(buffer);
        buffer.check_buffer_length()?;
        Ok(buffer)
    }

    pub fn check_buffer_length(&self) -> Result<(), DecodeError> {
        let len = self.buffer.as_ref().len();
        if len < TYPE.end {
            Err(DecodeError::nla_buffer_too_small(len, TYPE.end))
        } else if len < self.length() as usize {
            Err(DecodeError::nla_length_mismatch(
                len,
                self.length() as usize,
            ))
        } else if (self.length() as usize) < TYPE.end {
            Err(DecodeError::nla_invalid_length(len, self.length() as usize))
        } else {
            Ok(())
        }
    }

    /// Consume the buffer, returning the underlying buffer.
    pub fn into_inner(self) -> T {
        self.buffer
    }

    /// Return a reference to the underlying buffer
    pub fn inner(&self) -> &T {
        &self.buffer
    }

    /// Return a mutable reference to the underlying buffer
    pub fn inner_mut(&mut self) -> &mut T {
        &mut self.buffer
    }

    /// Return the `type` field, with the flag bits masked out
    pub fn kind(&self) -> u16 {
        let data = self.buffer.as_ref();
        parse_u16(&data[TYPE]).unwrap() & NLA_TYPE_MASK
    }

    /// Return true if the `NLA_F_NESTED` flag is set
    pub fn nested_flag(&self) -> bool {
        let data = self.buffer.as_ref();
        (parse_u16(&data[TYPE]).unwrap() & NLA_F_NESTED) != 0
    }

    /// Return true if the `NLA_F_NET_BYTEORDER` flag is set
    pub fn network_byte_order_flag(&self) -> bool {
        let data = self.buffer.as_ref();
        (parse_u16(&data[TYPE]).unwrap() & NLA_F_NET_BYTEORDER) != 0
    }

    /// Return the `length` field. The `length` field corresponds to the
    /// length of the nla header (the type and length fields) plus the value
    /// field. It does not account for the potential padding that follows
    /// the value field.
    pub fn length(&self) -> u16 {
        let data = self.buffer.as_ref();
        parse_u16(&data[LENGTH]).unwrap()
    }

    /// Return the length of the `value` field
    ///
    /// # Panics
    ///
    /// This panics if the length field value is less than the attribute
    /// header size.
    pub fn value_length(&self) -> usize {
        self.length() as usize - TYPE.end
    }
}

impl<T: AsRef<[u8]> + AsMut<[u8]>> NlaBuffer<T> {
    /// Set the `type` field
    pub fn set_kind(&mut self, kind: u16) {
        let data = self.buffer.as_mut();
        emit_u16(&mut data[TYPE], kind & NLA_TYPE_MASK).unwrap()
    }

    pub fn set_nested_flag(&mut self) {
        let data = self.buffer.as_mut();
        // Read the raw `type` field rather than `self.kind()`, which masks
        // the flag bits, so that an already-set `NLA_F_NET_BYTEORDER` flag
        // is preserved.
        let raw_kind = parse_u16(&data[TYPE]).unwrap();
        emit_u16(&mut data[TYPE], raw_kind | NLA_F_NESTED).unwrap()
    }

    pub fn set_network_byte_order_flag(&mut self) {
        let data = self.buffer.as_mut();
        // As above, preserve an already-set `NLA_F_NESTED` flag.
        let raw_kind = parse_u16(&data[TYPE]).unwrap();
        emit_u16(&mut data[TYPE], raw_kind | NLA_F_NET_BYTEORDER).unwrap()
    }

    /// Set the `length` field
    pub fn set_length(&mut self, length: u16) {
        let data = self.buffer.as_mut();
        emit_u16(&mut data[LENGTH], length).unwrap()
    }
}

impl<T: AsRef<[u8]> + ?Sized> NlaBuffer<&T> {
    /// Return the `value` field
    pub fn value(&self) -> &[u8] {
        &self.buffer.as_ref()[VALUE(self.value_length())]
    }
}

impl<T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> NlaBuffer<&mut T> {
    /// Return a mutable reference to the `value` field
    pub fn value_mut(&mut self) -> &mut [u8] {
        let range = VALUE(self.value_length());
        &mut self.buffer.as_mut()[range]
    }
}

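/// A netlink attribute in its most generic form: the kind (including any
/// flag bits) and the value as raw bytes.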
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct DefaultNla {
    kind: u16,
    value: Vec<u8>,
}

impl DefaultNla {
    pub fn new(kind: u16, value: Vec<u8>) -> Self {
        Self { kind, value }
    }
}

impl Nla for DefaultNla {
    fn value_len(&self) -> usize {
        self.value.len()
    }
    fn kind(&self) -> u16 {
        self.kind
    }
    fn emit_value(&self, buffer: &mut [u8]) {
        buffer.copy_from_slice(self.value.as_slice());
    }
}

impl<'buffer, T: AsRef<[u8]> + ?Sized> Parseable<NlaBuffer<&'buffer T>>
    for DefaultNla
{
    fn parse(buf: &NlaBuffer<&'buffer T>) -> Result<Self, DecodeError> {
        let mut kind = buf.kind();

        if buf.network_byte_order_flag() {
            kind |= NLA_F_NET_BYTEORDER;
        }

        if buf.nested_flag() {
            kind |= NLA_F_NESTED;
        }

        Ok(DefaultNla {
            kind,
            value: buf.value().to_vec(),
        })
    }
}

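/// A netlink attribute: implementors describe the attribute type (`kind`),
/// the length of the value, and how to serialize the value. The header and
/// the padding are handled by the blanket `Emitable` impl below.
///
/// A minimal sketch of a custom attribute carrying a `u32`; the type name
/// and the kind number are illustrative, not part of this crate (hence
/// marked `ignore`):
///
/// ```ignore
/// struct Mtu(u32);
///
/// impl Nla for Mtu {
///     fn value_len(&self) -> usize {
///         4
///     }
///     fn kind(&self) -> u16 {
///         1 // hypothetical attribute type
///     }
///     fn emit_value(&self, buffer: &mut [u8]) {
///         // `buffer` is exactly `value_len()` bytes long.
///         buffer.copy_from_slice(&self.0.to_ne_bytes());
///     }
/// }
/// ```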
pub trait Nla {
    fn value_len(&self) -> usize;
    fn kind(&self) -> u16;
    fn emit_value(&self, buffer: &mut [u8]);

    #[inline]
    fn is_nested(&self) -> bool {
        (self.kind() & NLA_F_NESTED) != 0
    }

    #[inline]
    fn is_network_byteorder(&self) -> bool {
        (self.kind() & NLA_F_NET_BYTEORDER) != 0
    }
}

impl<T: Nla> Emitable for T {
    fn buffer_len(&self) -> usize {
        nla_align!(self.value_len()) + NLA_HEADER_SIZE
    }
    fn emit(&self, buffer: &mut [u8]) {
        let mut buffer = NlaBuffer::new(buffer);
        buffer.set_kind(self.kind());

        if self.is_network_byteorder() {
            buffer.set_network_byte_order_flag()
        }

        if self.is_nested() {
            buffer.set_nested_flag()
        }

        // Do not include the padding here, but do include the header.
        buffer.set_length(self.value_len() as u16 + NLA_HEADER_SIZE as u16);

        self.emit_value(buffer.value_mut());

        // Zero out the padding bytes that follow the value.
        let padding = nla_align!(self.value_len()) - self.value_len();
        for i in 0..padding {
            buffer.inner_mut()[NLA_HEADER_SIZE + self.value_len() + i] = 0;
        }
    }
}

// FIXME: when specialization lands, we can actually have
//
// impl<'a, T: Nla, I: Iterator<Item=T>> Emitable for I { ... }
//
// The reason this does not work today is that it conflicts with
//
// impl<T: Nla> Emitable for T { ... }
impl<T: Nla> Emitable for &[T] {
    fn buffer_len(&self) -> usize {
        self.iter().fold(0, |acc, nla| {
            assert_eq!(nla.buffer_len() % NLA_ALIGNTO, 0);
            acc + nla.buffer_len()
        })
    }

    fn emit(&self, buffer: &mut [u8]) {
        let mut start = 0;
        let mut end: usize;
        for nla in self.iter() {
            let attr_len = nla.buffer_len();
            assert_eq!(attr_len % NLA_ALIGNTO, 0);
            end = start + attr_len;
            nla.emit(&mut buffer[start..end]);
            start = end;
        }
    }
}

/// An iterator that iterates over nlas without decoding them. This is useful
/// when looking for specific nlas.
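///
/// A minimal sketch, assuming a little-endian host (hence marked `ignore`):
///
/// ```ignore
/// // Two 8-byte attributes packed back to back.
/// let raw: [u8; 16] = [
///     0x08, 0x00, 0x01, 0x00, 0x01, 0x02, 0x03, 0x04, // kind 1
///     0x08, 0x00, 0x02, 0x00, 0x05, 0x06, 0x07, 0x08, // kind 2
/// ];
/// let kinds = NlasIterator::new(&raw[..])
///     .map(|nla| nla.unwrap().kind())
///     .collect::<Vec<u16>>();
/// assert_eq!(kinds, vec![1, 2]);
/// ```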
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NlasIterator<T> {
    position: usize,
    buffer: T,
}

impl<T> NlasIterator<T> {
    pub fn new(buffer: T) -> Self {
        NlasIterator {
            position: 0,
            buffer,
        }
    }
}

impl<'buffer, T: AsRef<[u8]> + ?Sized + 'buffer> Iterator
    for NlasIterator<&'buffer T>
{
    type Item = Result<NlaBuffer<&'buffer [u8]>, DecodeError>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.position >= self.buffer.as_ref().len() {
            return None;
        }

        match NlaBuffer::new_checked(&self.buffer.as_ref()[self.position..]) {
            Ok(nla_buffer) => {
                self.position += nla_align!(nla_buffer.length() as usize);
                Some(Ok(nla_buffer))
            }
            Err(e) => {
                // Make sure the next call to `next()` returns None: we don't
                // try to continue iterating after we failed to return a
                // buffer.
                self.position = self.buffer.as_ref().len();
                Some(Err(e))
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn network_byteorder() {
        // The IPSET_ATTR_TIMEOUT attribute should have the network byte order
        // flag set. IPSET_ATTR_TIMEOUT(3600)
        static TEST_ATTRIBUTE: &[u8] =
            &[0x08, 0x00, 0x06, 0x40, 0x00, 0x00, 0x0e, 0x10];
        let buffer = NlaBuffer::new(TEST_ATTRIBUTE);
        let buffer_is_net = buffer.network_byte_order_flag();
        let buffer_is_nest = buffer.nested_flag();

        let nla = DefaultNla::parse(&buffer).unwrap();
        let mut emitted_buffer = vec![0; nla.buffer_len()];

        nla.emit(&mut emitted_buffer);

        let attr_is_net = nla.is_network_byteorder();
        let attr_is_nest = nla.is_nested();

        let emit = NlaBuffer::new(emitted_buffer);
        let emit_is_net = emit.network_byte_order_flag();
        let emit_is_nest = emit.nested_flag();

        assert_eq!(
            [buffer_is_net, buffer_is_nest],
            [attr_is_net, attr_is_nest]
        );
        assert_eq!([attr_is_net, attr_is_nest], [emit_is_net, emit_is_nest]);
    }

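    // A round-trip sketch: emit two `DefaultNla`s back to back, then walk
    // them with `NlasIterator` and re-parse them. The kinds and values here
    // are arbitrary illustrative data.
    #[test]
    fn emit_and_iterate_round_trip() {
        let nlas = vec![
            DefaultNla::new(1, vec![0xaa; 3]), // 3-byte value, 1 padding byte
            DefaultNla::new(2, vec![0xbb; 4]), // 4-byte value, no padding
        ];
        let mut buffer = vec![0; nlas.as_slice().buffer_len()];
        nlas.as_slice().emit(&mut buffer);

        let parsed = NlasIterator::new(buffer.as_slice())
            .map(|nla| DefaultNla::parse(&nla.unwrap()).unwrap())
            .collect::<Vec<_>>();
        assert_eq!(parsed, nlas);
    }
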
    fn get_len() -> usize {
        // usize::MAX
        18446744073709551615
    }

    #[test]
    fn test_align() {
        assert_eq!(nla_align!(13), 16);
        assert_eq!(nla_align!(16), 16);
        assert_eq!(nla_align!(0), 0);
        assert_eq!(nla_align!(1), 4);
        assert_eq!(nla_align!(get_len() - 4), usize::MAX - 3);
    }

    #[test]
    #[should_panic]
    fn test_align_overflow() {
        assert_eq!(nla_align!(get_len() - 3), usize::MAX);
    }
}
377}