netlink_packet_utils/
nla.rs

1// SPDX-License-Identifier: MIT
2
3use crate::{
4    traits::{Emitable, Parseable},
5    DecodeError,
6};
7use byteorder::{ByteOrder, NativeEndian};
8use core::ops::Range;
9use thiserror::Error;
10
/// Represent a multi-byte field with a fixed size in a packet
type Field = Range<usize>;

/// Identify the bits that represent the "nested" flag of a netlink attribute.
pub const NLA_F_NESTED: u16 = 0x8000;
/// Identify the bits that represent the "byte order" flag of a netlink
/// attribute.
pub const NLA_F_NET_BYTEORDER: u16 = 0x4000;
/// Identify the bits that represent the type of a netlink attribute.
pub const NLA_TYPE_MASK: u16 = !(NLA_F_NET_BYTEORDER | NLA_F_NESTED);
/// NLA (RTA) alignment: attributes are padded to multiples of 4 bytes.
pub const NLA_ALIGNTO: usize = 4;
/// NLA (RTA) header size. (unsigned short rta_len) + (unsigned short rta_type)
pub const NLA_HEADER_SIZE: usize = 4;
25
26#[derive(Debug, Error)]
27pub enum NlaError {
28    #[error("buffer has length {buffer_len}, but an NLA header is {} bytes", TYPE.end)]
29    BufferTooSmall { buffer_len: usize },
30
31    #[error("buffer has length: {buffer_len}, but the NLA is {nla_len} bytes")]
32    LengthMismatch { buffer_len: usize, nla_len: u16 },
33
34    #[error(
35        "NLA has invalid length: {nla_len} (should be at least {} bytes", TYPE.end
36    )]
37    InvalidLength { nla_len: u16 },
38}
39
/// Round `$len` up to the nearest multiple of [`NLA_ALIGNTO`] (4 bytes).
///
/// Note: `NLA_ALIGNTO` must be in scope at the call site. The addition
/// can overflow (panicking in debug builds) when `$len` is within
/// `NLA_ALIGNTO - 1` of the integer type's maximum value.
#[macro_export]
macro_rules! nla_align {
    ($len: expr) => {
        ($len + NLA_ALIGNTO - 1) & !(NLA_ALIGNTO - 1)
    };
}
46
47const LENGTH: Field = 0..2;
48const TYPE: Field = 2..4;
49#[allow(non_snake_case)]
50fn VALUE(length: usize) -> Field {
51    TYPE.end..TYPE.end + length
52}
53
// with Copy, NlaBuffer<&'buffer T> can be copied, which turns out to be pretty
// convenient. And since it boils down to copying a reference it's pretty
// cheap
/// A byte-level view over a single netlink attribute: a 4-byte header
/// (`length` + `type` fields, read in native byte order) followed by the
/// value bytes.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct NlaBuffer<T: AsRef<[u8]>> {
    buffer: T,
}
61
62impl<T: AsRef<[u8]>> NlaBuffer<T> {
63    pub fn new(buffer: T) -> NlaBuffer<T> {
64        NlaBuffer { buffer }
65    }
66
67    pub fn new_checked(buffer: T) -> Result<NlaBuffer<T>, NlaError> {
68        let buffer = Self::new(buffer);
69        buffer.check_buffer_length()?;
70        Ok(buffer)
71    }
72
73    pub fn check_buffer_length(&self) -> Result<(), NlaError> {
74        let len = self.buffer.as_ref().len();
75        if len < TYPE.end {
76            Err(NlaError::BufferTooSmall { buffer_len: len })
77        } else if len < self.length() as usize {
78            Err(NlaError::LengthMismatch {
79                buffer_len: len,
80                nla_len: self.length(),
81            })
82        } else if (self.length() as usize) < TYPE.end {
83            Err(NlaError::InvalidLength {
84                nla_len: self.length(),
85            })
86        } else {
87            Ok(())
88        }
89    }
90
91    /// Consume the buffer, returning the underlying buffer.
92    pub fn into_inner(self) -> T {
93        self.buffer
94    }
95
96    /// Return a reference to the underlying buffer
97    pub fn inner(&mut self) -> &T {
98        &self.buffer
99    }
100
101    /// Return a mutable reference to the underlying buffer
102    pub fn inner_mut(&mut self) -> &mut T {
103        &mut self.buffer
104    }
105
106    /// Return the `type` field
107    pub fn kind(&self) -> u16 {
108        let data = self.buffer.as_ref();
109        NativeEndian::read_u16(&data[TYPE]) & NLA_TYPE_MASK
110    }
111
112    pub fn nested_flag(&self) -> bool {
113        let data = self.buffer.as_ref();
114        (NativeEndian::read_u16(&data[TYPE]) & NLA_F_NESTED) != 0
115    }
116
117    pub fn network_byte_order_flag(&self) -> bool {
118        let data = self.buffer.as_ref();
119        (NativeEndian::read_u16(&data[TYPE]) & NLA_F_NET_BYTEORDER) != 0
120    }
121
122    /// Return the `length` field. The `length` field corresponds to the length
123    /// of the nla header (type and length fields, and the value field).
124    /// However, it does not account for the potential padding that follows
125    /// the value field.
126    pub fn length(&self) -> u16 {
127        let data = self.buffer.as_ref();
128        NativeEndian::read_u16(&data[LENGTH])
129    }
130
131    /// Return the length of the `value` field
132    ///
133    /// # Panic
134    ///
135    /// This panics if the length field value is less than the attribut header
136    /// size.
137    pub fn value_length(&self) -> usize {
138        self.length() as usize - TYPE.end
139    }
140}
141
142impl<T: AsRef<[u8]> + AsMut<[u8]>> NlaBuffer<T> {
143    /// Set the `type` field
144    pub fn set_kind(&mut self, kind: u16) {
145        let data = self.buffer.as_mut();
146        NativeEndian::write_u16(&mut data[TYPE], kind & NLA_TYPE_MASK)
147    }
148
149    pub fn set_nested_flag(&mut self) {
150        let kind = self.kind();
151        let data = self.buffer.as_mut();
152        NativeEndian::write_u16(&mut data[TYPE], kind | NLA_F_NESTED)
153    }
154
155    pub fn set_network_byte_order_flag(&mut self) {
156        let kind = self.kind();
157        let data = self.buffer.as_mut();
158        NativeEndian::write_u16(&mut data[TYPE], kind | NLA_F_NET_BYTEORDER)
159    }
160
161    /// Set the `length` field
162    pub fn set_length(&mut self, length: u16) {
163        let data = self.buffer.as_mut();
164        NativeEndian::write_u16(&mut data[LENGTH], length)
165    }
166}
167
168impl<T: AsRef<[u8]> + ?Sized> NlaBuffer<&T> {
169    /// Return the `value` field
170    pub fn value(&self) -> &[u8] {
171        &self.buffer.as_ref()[VALUE(self.value_length())]
172    }
173}
174
175impl<T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> NlaBuffer<&mut T> {
176    /// Return the `value` field
177    pub fn value_mut(&mut self) -> &mut [u8] {
178        let length = VALUE(self.value_length());
179        &mut self.buffer.as_mut()[length]
180    }
181}
182
/// A generic, owned netlink attribute: the raw type field and the raw
/// value bytes. Useful for carrying attributes a caller does not know
/// how to decode.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct DefaultNla {
    // Raw type field; may carry NLA_F_NESTED / NLA_F_NET_BYTEORDER bits.
    kind: u16,
    // Raw value bytes, without any alignment padding.
    value: Vec<u8>,
}
188
189impl DefaultNla {
190    pub fn new(kind: u16, value: Vec<u8>) -> Self {
191        Self { kind, value }
192    }
193}
194
195impl Nla for DefaultNla {
196    fn value_len(&self) -> usize {
197        self.value.len()
198    }
199    fn kind(&self) -> u16 {
200        self.kind
201    }
202    fn emit_value(&self, buffer: &mut [u8]) {
203        buffer.copy_from_slice(self.value.as_slice());
204    }
205}
206
207impl<'buffer, T: AsRef<[u8]> + ?Sized> Parseable<NlaBuffer<&'buffer T>>
208    for DefaultNla
209{
210    type Error = DecodeError;
211
212    fn parse(buf: &NlaBuffer<&'buffer T>) -> Result<Self, Self::Error> {
213        let mut kind = buf.kind();
214
215        if buf.network_byte_order_flag() {
216            kind |= NLA_F_NET_BYTEORDER;
217        }
218
219        if buf.nested_flag() {
220            kind |= NLA_F_NESTED;
221        }
222
223        Ok(DefaultNla {
224            kind,
225            value: buf.value().to_vec(),
226        })
227    }
228}
229
/// A type that can be serialized as a netlink attribute.
///
/// Implementors provide the raw type field (`kind`), the size of the
/// value field, and a way to write the value into a buffer; the blanket
/// `Emitable` impl in this module takes care of the header and the
/// alignment padding.
pub trait Nla {
    /// Length of the value field in bytes, excluding header and padding.
    fn value_len(&self) -> usize;
    /// The attribute type, possibly with flag bits set.
    fn kind(&self) -> u16;
    /// Write the value field into `buffer` (sized to `value_len()`).
    fn emit_value(&self, buffer: &mut [u8]);

    /// Whether `kind()` carries the `NLA_F_NESTED` flag.
    #[inline]
    fn is_nested(&self) -> bool {
        (self.kind() & NLA_F_NESTED) != 0
    }

    /// Whether `kind()` carries the `NLA_F_NET_BYTEORDER` flag.
    #[inline]
    fn is_network_byteorder(&self) -> bool {
        (self.kind() & NLA_F_NET_BYTEORDER) != 0
    }
}
245
246impl<T: Nla> Emitable for T {
247    fn buffer_len(&self) -> usize {
248        nla_align!(self.value_len()) + NLA_HEADER_SIZE
249    }
250    fn emit(&self, buffer: &mut [u8]) {
251        let mut buffer = NlaBuffer::new(buffer);
252        buffer.set_kind(self.kind());
253
254        if self.is_network_byteorder() {
255            buffer.set_network_byte_order_flag()
256        }
257
258        if self.is_nested() {
259            buffer.set_nested_flag()
260        }
261
262        // do not include the padding here, but do include the header
263        buffer.set_length(self.value_len() as u16 + NLA_HEADER_SIZE as u16);
264
265        self.emit_value(buffer.value_mut());
266
267        let padding = nla_align!(self.value_len()) - self.value_len();
268        for i in 0..padding {
269            buffer.inner_mut()[NLA_HEADER_SIZE + self.value_len() + i] = 0;
270        }
271    }
272}
273
// FIXME: when specialization lands, we can actually have
//
// impl<'a, T: Nla, I: Iterator<Item=T>> Emitable for I { ...}
//
// The reason this does not work today is because it conflicts with
//
// impl<T: Nla> Emitable for T { ... }
281impl<T: Nla> Emitable for &[T] {
282    fn buffer_len(&self) -> usize {
283        self.iter().fold(0, |acc, nla| {
284            assert_eq!(nla.buffer_len() % NLA_ALIGNTO, 0);
285            acc + nla.buffer_len()
286        })
287    }
288
289    fn emit(&self, buffer: &mut [u8]) {
290        let mut start = 0;
291        let mut end: usize;
292        for nla in self.iter() {
293            let attr_len = nla.buffer_len();
294            assert_eq!(nla.buffer_len() % NLA_ALIGNTO, 0);
295            end = start + attr_len;
296            nla.emit(&mut buffer[start..end]);
297            start = end;
298        }
299    }
300}
301
/// An iterator that iterates over nlas without decoding them. This is useful
/// when looking for specific nlas.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NlasIterator<T> {
    // Byte offset of the next attribute within `buffer`.
    position: usize,
    buffer: T,
}
309
310impl<T> NlasIterator<T> {
311    pub fn new(buffer: T) -> Self {
312        NlasIterator {
313            position: 0,
314            buffer,
315        }
316    }
317}
318
319impl<'buffer, T: AsRef<[u8]> + ?Sized + 'buffer> Iterator
320    for NlasIterator<&'buffer T>
321{
322    type Item = Result<NlaBuffer<&'buffer [u8]>, NlaError>;
323
324    fn next(&mut self) -> Option<Self::Item> {
325        if self.position >= self.buffer.as_ref().len() {
326            return None;
327        }
328
329        match NlaBuffer::new_checked(&self.buffer.as_ref()[self.position..]) {
330            Ok(nla_buffer) => {
331                self.position += nla_align!(nla_buffer.length() as usize);
332                Some(Ok(nla_buffer))
333            }
334            Err(e) => {
335                // Make sure next time we call `next()`, we return None. We
336                // don't try to continue iterating after we
337                // failed to return a buffer.
338                self.position = self.buffer.as_ref().len();
339                Some(Err(e))
340            }
341        }
342    }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Check that the network-byte-order flag survives a parse/emit
    /// round trip of a real attribute.
    #[test]
    fn network_byteorder() {
        // The IPSET_ATTR_TIMEOUT attribute should have the network byte order
        // flag set. IPSET_ATTR_TIMEOUT(3600)
        static TEST_ATTRIBUTE: &[u8] =
            &[0x08, 0x00, 0x06, 0x40, 0x00, 0x00, 0x0e, 0x10];
        let buffer = NlaBuffer::new(TEST_ATTRIBUTE);
        let buffer_is_net = buffer.network_byte_order_flag();
        let buffer_is_nest = buffer.nested_flag();

        let nla = DefaultNla::parse(&buffer).unwrap();
        let mut emitted_buffer = vec![0; nla.buffer_len()];

        nla.emit(&mut emitted_buffer);

        let attr_is_net = nla.is_network_byteorder();
        let attr_is_nest = nla.is_nested();

        let emit = NlaBuffer::new(emitted_buffer);
        let emit_is_net = emit.network_byte_order_flag();
        let emit_is_nest = emit.nested_flag();

        assert_eq!(
            [buffer_is_net, buffer_is_nest],
            [attr_is_net, attr_is_nest]
        );
        assert_eq!([attr_is_net, attr_is_nest], [emit_is_net, emit_is_nest]);
    }

    // Fixed: the previous 64-bit literal 18446744073709551615 would not
    // even compile on a 32-bit target (overflowing_literals); usize::MAX
    // is portable and says what it means.
    fn get_len() -> usize {
        usize::MAX
    }

    #[test]
    fn test_align() {
        assert_eq!(nla_align!(13), 16);
        assert_eq!(nla_align!(16), 16);
        assert_eq!(nla_align!(0), 0);
        assert_eq!(nla_align!(1), 4);
        assert_eq!(nla_align!(get_len() - 4), usize::MAX - 3);
    }

    #[test]
    #[should_panic]
    fn test_align_overflow() {
        // nla_align!(usize::MAX - 3) evaluates to usize::MAX - 3 (the
        // addition does not overflow), so it is the assert_eq! that
        // fails and panics here.
        assert_eq!(nla_align!(get_len() - 3), usize::MAX);
    }
}