Skip to main content

aya_friday/maps/
lpm_trie.rs

//! A Longest Prefix Match (LPM) trie map.
2
3use std::{
4    borrow::{Borrow, BorrowMut},
5    marker::PhantomData,
6};
7
8use crate::{
9    Pod,
10    maps::{IterableMap, MapData, MapError, MapIter, MapKeys, check_kv_size, hash_map},
11};
12
/// A Longest Prefix Match Trie.
///
/// Lookups return the value stored under the most specific (longest) prefix
/// that matches the queried key.
///
/// # Minimum kernel version
///
/// The minimum kernel version required to use this feature is 4.20.
///
/// # Examples
///
/// ```no_run
/// # let mut bpf = aya::Ebpf::load(&[])?;
/// use aya::maps::lpm_trie::{LpmTrie, Key};
/// use std::net::Ipv4Addr;
///
/// let mut trie = LpmTrie::try_from(bpf.map_mut("LPM_TRIE").unwrap())?;
/// let ipaddr = Ipv4Addr::new(8, 8, 8, 8);
/// // A key for the "8.8.8.8/16" subnet.
/// // `Key::new` takes the prefix length (how many leading bits take part in the
/// // match) followed by the data to match against.
/// let key = Key::new(16, u32::from(ipaddr).to_be());
/// trie.insert(&key, 1, 0)?;
///
/// // Lookups match against the longest (most specific) stored prefix.
/// let lookup = Key::new(32, u32::from(ipaddr).to_be());
/// let value = trie.get(&lookup, 0)?;
/// assert_eq!(value, 1);
///
/// // Once a key with a longer `prefix_len` is inserted, the same lookup
/// // matches it instead.
/// let longer_key = Key::new(24, u32::from(ipaddr).to_be());
/// trie.insert(&longer_key, 2, 0)?;
/// let value = trie.get(&lookup, 0)?;
/// assert_eq!(value, 2);
/// # Ok::<(), aya::EbpfError>(())
/// ```
#[doc(alias = "BPF_MAP_TYPE_LPM_TRIE")]
#[derive(Debug)]
pub struct LpmTrie<T, K, V> {
    /// The wrapped map handle (the impls in this module require `T: Borrow<MapData>`).
    pub(crate) inner: T,
    /// Ties the key/value element types to the trie without storing any.
    _kv: PhantomData<(K, V)>,
}
52
/// A key for an [`LpmTrie`] map: a prefix length followed by the data it applies to.
///
/// # Examples
///
/// ```no_run
/// use aya::maps::lpm_trie::{LpmTrie, Key};
/// use std::net::Ipv4Addr;
///
/// let ipaddr = Ipv4Addr::new(8, 8, 8, 8);
/// let key = Key::new(16, u32::from(ipaddr).to_be());
/// ```
#[derive(Clone, Copy)]
#[repr(C, packed)]
pub struct Key<K> {
    // `repr(C, packed)`: fields are laid out in declaration order with no padding,
    // so a key occupies exactly `4 + size_of::<K>()` bytes.
    /// Number of leading bits of `data` that take part in the match.
    prefix_len: u32,
    /// The value being matched, e.g. an IP address.
    data: K,
}
70
71impl<K: Pod> Key<K> {
72    /// Creates a new key.
73    ///
74    /// `prefix_len` is the number of bits in the data to match against.
75    /// `data` is the data in the key which is typically an IPv4 or IPv6 address.
76    /// If using a key to perform a longest prefix match on you would use a `prefix_len`
77    /// of 32 for IPv4 and 128 for IPv6.
78    ///
79    /// # Examples
80    ///
81    /// ```no_run
82    /// use aya::maps::lpm_trie::{LpmTrie, Key};
83    /// use std::net::Ipv4Addr;
84    ///
85    /// let ipaddr = Ipv4Addr::new(8, 8, 8, 8);
86    /// let key =  Key::new(16, u32::from(ipaddr).to_be());
87    /// ```
88    pub const fn new(prefix_len: u32, data: K) -> Self {
89        Self { prefix_len, data }
90    }
91
92    /// Returns the number of bits in the data to be matched.
93    pub const fn prefix_len(&self) -> u32 {
94        self.prefix_len
95    }
96
97    /// Returns the data stored in the Key.
98    pub const fn data(&self) -> K {
99        self.data
100    }
101}
102
// A Pod impl is required as Key struct is a key for a map.
//
// SAFETY: `Key` is declared `#[repr(C, packed)]`, so it contains no padding
// bytes, and its only fields are a `u32` and a `K` that is itself `Pod`.
// NOTE(review): this relies on `Pod` meaning "plain data, no padding, safe to
// copy to/from raw bytes". Confirm against the trait's documented contract.
unsafe impl<K: Pod> Pod for Key<K> {}
105
106impl<T: Borrow<MapData>, K: Pod, V: Pod> LpmTrie<T, K, V> {
107    pub(crate) fn new(map: T) -> Result<Self, MapError> {
108        let data = map.borrow();
109        check_kv_size::<Key<K>, V>(data)?;
110
111        Ok(Self {
112            inner: map,
113            _kv: PhantomData,
114        })
115    }
116
117    /// Returns a copy of the value associated with the longest prefix matching key in the [`LpmTrie`].
118    pub fn get(&self, key: &Key<K>, flags: u64) -> Result<V, MapError> {
119        hash_map::get(self.inner.borrow(), key, flags)
120    }
121
122    /// An iterator visiting all key-value pairs. The
123    /// iterator item type is `Result<(K, V), MapError>`.
124    pub fn iter(&self) -> MapIter<'_, Key<K>, V, Self> {
125        MapIter::new(self)
126    }
127
128    /// An iterator visiting all keys. The iterator element
129    /// type is `Result<Key<K>, MapError>`.
130    pub fn keys(&self) -> MapKeys<'_, Key<K>> {
131        MapKeys::new(self.inner.borrow())
132    }
133}
134
135impl<'a, T: Borrow<MapData>, K: Pod, V: Pod> IntoIterator for &'a LpmTrie<T, K, V> {
136    type Item = Result<(Key<K>, V), MapError>;
137    type IntoIter = MapIter<'a, Key<K>, V, LpmTrie<T, K, V>>;
138
139    fn into_iter(self) -> Self::IntoIter {
140        self.iter()
141    }
142}
143
144impl<T: BorrowMut<MapData>, K: Pod, V: Pod> LpmTrie<T, K, V> {
145    /// Inserts a key value pair into the map.
146    pub fn insert(
147        &mut self,
148        key: &Key<K>,
149        value: impl Borrow<V>,
150        flags: u64,
151    ) -> Result<(), MapError> {
152        hash_map::insert(self.inner.borrow_mut(), key, value.borrow(), flags)
153    }
154
155    /// Removes an element from the map.
156    ///
157    /// Both the prefix and data must match exactly - this method does not do a longest prefix match.
158    pub fn remove(&mut self, key: &Key<K>) -> Result<(), MapError> {
159        hash_map::remove(self.inner.borrow_mut(), key)
160    }
161}
162
163impl<T: Borrow<MapData>, K: Pod, V: Pod> IterableMap<Key<K>, V> for LpmTrie<T, K, V> {
164    fn map(&self) -> &MapData {
165        self.inner.borrow()
166    }
167
168    fn get(&self, key: &Key<K>) -> Result<V, MapError> {
169        self.get(key, 0)
170    }
171}
172
// Unit tests for `LpmTrie`. No real kernel is involved: each test installs a
// fake syscall handler with `override_syscall` and checks how `LpmTrie` maps
// the result.
#[cfg(test)]
mod tests {
    use std::{io, net::Ipv4Addr};

    use assert_matches::assert_matches;
    use aya_obj::generated::{bpf_cmd, bpf_map_type};
    use libc::{EFAULT, ENOENT};

    use super::*;
    use crate::{
        maps::{
            Map,
            test_utils::{self, new_map},
        },
        sys::{SysResult, Syscall, SyscallError, override_syscall},
    };

    // Builds an LPM-trie map definition whose key type is `Key<u32>`.
    fn new_obj_map() -> aya_obj::Map {
        test_utils::new_obj_map::<Key<u32>>(bpf_map_type::BPF_MAP_TYPE_LPM_TRIE)
    }

    // Fake syscall failure: a -1 return paired with an `io::Error` for errno `value`.
    fn sys_error(value: i32) -> SysResult {
        Err((-1, io::Error::from_raw_os_error(value)))
    }

    // A `Key<u16>` is 4 + 2 = 6 bytes (packed), which must be rejected
    // against the map's 8-byte `Key<u32>` key.
    #[test]
    fn test_wrong_key_size() {
        let map = new_map(new_obj_map());
        assert_matches!(
            LpmTrie::<_, u16, u32>::new(&map),
            Err(MapError::InvalidKeySize {
                size: 6,
                expected: 8 // four bytes for prefixlen and four bytes for data.
            })
        );
    }

    // A 2-byte value type must be rejected against the map's 4-byte value size.
    #[test]
    fn test_wrong_value_size() {
        let map = new_map(new_obj_map());
        assert_matches!(
            LpmTrie::<_, u32, u16>::new(&map),
            Err(MapError::InvalidValueSize {
                size: 2,
                expected: 4
            })
        );
    }

    // Converting a `Map::Array` into an `LpmTrie` must fail with a map-type error.
    #[test]
    fn test_try_from_wrong_map() {
        let map = new_map(test_utils::new_obj_map::<u32>(
            bpf_map_type::BPF_MAP_TYPE_ARRAY,
        ));
        let map = Map::Array(map);

        assert_matches!(
            LpmTrie::<_, u32, u32>::try_from(&map),
            Err(MapError::InvalidMapType { .. })
        );
    }

    // Matching key/value sizes construct successfully.
    #[test]
    fn test_new_ok() {
        let map = new_map(new_obj_map());

        let _: LpmTrie<_, u32, u32> = LpmTrie::new(&map).unwrap();
    }

    // A `Map::LpmTrie` converts into an `LpmTrie` via `TryInto`.
    #[test]
    fn test_try_from_ok() {
        let map = new_map(new_obj_map());

        let map = Map::LpmTrie(map);
        let _unused: LpmTrie<_, u32, u32> = map.try_into().unwrap();
    }

    // A failing update syscall surfaces as `SyscallError` tagged "bpf_map_update_elem".
    #[test]
    fn test_insert_syscall_error() {
        let mut map = new_map(new_obj_map());
        let mut trie = LpmTrie::<_, u32, u32>::new(&mut map).unwrap();
        let ipaddr = Ipv4Addr::new(8, 8, 8, 8);
        let key = Key::new(16, u32::from(ipaddr).to_be());

        override_syscall(|_| sys_error(EFAULT));

        assert_matches!(
            trie.insert(&key, 1, 0),
            Err(MapError::SyscallError(SyscallError { call: "bpf_map_update_elem", io_error })) if io_error.raw_os_error() == Some(EFAULT)
        );
    }

    // Only BPF_MAP_UPDATE_ELEM succeeds; any other command would fail the insert.
    #[test]
    fn test_insert_ok() {
        let mut map = new_map(new_obj_map());
        let mut trie = LpmTrie::<_, u32, u32>::new(&mut map).unwrap();
        let ipaddr = Ipv4Addr::new(8, 8, 8, 8);
        let key = Key::new(16, u32::from(ipaddr).to_be());

        override_syscall(|call| match call {
            Syscall::Ebpf {
                cmd: bpf_cmd::BPF_MAP_UPDATE_ELEM,
                ..
            } => Ok(0),
            _ => sys_error(EFAULT),
        });

        assert_matches!(trie.insert(&key, 1, 0), Ok(()));
    }

    // A failing delete syscall surfaces as `SyscallError` tagged "bpf_map_delete_elem".
    #[test]
    fn test_remove_syscall_error() {
        let mut map = new_map(new_obj_map());
        let mut trie = LpmTrie::<_, u32, u32>::new(&mut map).unwrap();
        let ipaddr = Ipv4Addr::new(8, 8, 8, 8);
        let key = Key::new(16, u32::from(ipaddr).to_be());

        override_syscall(|_| sys_error(EFAULT));

        assert_matches!(
            trie.remove(&key),
            Err(MapError::SyscallError(SyscallError { call: "bpf_map_delete_elem", io_error })) if io_error.raw_os_error() == Some(EFAULT)
        );
    }

    // Only BPF_MAP_DELETE_ELEM succeeds; any other command would fail the remove.
    #[test]
    fn test_remove_ok() {
        let mut map = new_map(new_obj_map());
        let mut trie = LpmTrie::<_, u32, u32>::new(&mut map).unwrap();
        let ipaddr = Ipv4Addr::new(8, 8, 8, 8);
        let key = Key::new(16, u32::from(ipaddr).to_be());

        override_syscall(|call| match call {
            Syscall::Ebpf {
                cmd: bpf_cmd::BPF_MAP_DELETE_ELEM,
                ..
            } => Ok(0),
            _ => sys_error(EFAULT),
        });

        assert_matches!(trie.remove(&key), Ok(()));
    }

    // A non-ENOENT lookup failure surfaces as `SyscallError` tagged "bpf_map_lookup_elem".
    #[test]
    fn test_get_syscall_error() {
        let map = new_map(new_obj_map());
        let trie = LpmTrie::<_, u32, u32>::new(&map).unwrap();
        let ipaddr = Ipv4Addr::new(8, 8, 8, 8);
        let key = Key::new(16, u32::from(ipaddr).to_be());

        override_syscall(|_| sys_error(EFAULT));

        assert_matches!(
            trie.get(&key, 0),
            Err(MapError::SyscallError(SyscallError { call: "bpf_map_lookup_elem", io_error })) if io_error.raw_os_error() == Some(EFAULT)
        );
    }

    // ENOENT from the lookup syscall is translated into `MapError::KeyNotFound`.
    #[test]
    fn test_get_not_found() {
        let map = new_map(new_obj_map());
        let trie = LpmTrie::<_, u32, u32>::new(&map).unwrap();
        let ipaddr = Ipv4Addr::new(8, 8, 8, 8);
        let key = Key::new(16, u32::from(ipaddr).to_be());

        override_syscall(|call| match call {
            Syscall::Ebpf {
                cmd: bpf_cmd::BPF_MAP_LOOKUP_ELEM,
                ..
            } => sys_error(ENOENT),
            _ => sys_error(EFAULT),
        });

        assert_matches!(trie.get(&key, 0), Err(MapError::KeyNotFound));
    }
}