Skip to main content

pkarr_client/extra/
lmdb_cache.rs

//! Persistent [crate::Cache] implementation backed by LMDB, via the [heed] bindings.
2
3use std::{
4    borrow::Cow,
5    fmt::Debug,
6    fs,
7    path::Path,
8    sync::{Arc, RwLock},
9};
10
11use byteorder::BigEndian;
12use heed::{
13    types::U64, BoxedError, BytesDecode, BytesEncode, Database, Env, EnvOpenOptions, RwTxn,
14};
15
16use tracing::debug;
17
18use ntimestamp::Timestamp;
19
20use crate::{Cache, CacheKey, SignedPacket};
21
/// Largest LMDB map size [LmdbCache::open] will request: 10 TiB.
const MAX_MAP_SIZE: usize = 10 * 1024 * 1024 * 1024 * 1024;
/// Smallest LMDB map size [LmdbCache::open] will request: 10 MiB.
const MIN_MAP_SIZE: usize = 10 * 1024 * 1024;

/// LMDB table mapping each cache key to its stored packet.
const SIGNED_PACKET_TABLE: &str = "pkarrcache:signed_packet";
/// LMDB table mapping each cache key to its last-used timestamp.
const KEY_TO_TIME_TABLE: &str = "pkarrcache:key_to_time";
/// LMDB table mapping each last-used timestamp back to its cache key.
const TIME_TO_KEY_TABLE: &str = "pkarrcache:time_to_key";
28
29type SignedPacketsTable = Database<CacheKeyCodec, SignedPacket>;
30type KeyToTimeTable = Database<CacheKeyCodec, U64<BigEndian>>;
31type TimeToKeyTable = Database<U64<BigEndian>, CacheKeyCodec>;
32
33/// A wrapper for [CacheKey] to implement [BytesEncode] and [BytesDecode].
34pub struct CacheKeyCodec;
35
36impl<'a> BytesEncode<'a> for CacheKeyCodec {
37    type EItem = CacheKey;
38
39    fn bytes_encode(key: &Self::EItem) -> Result<Cow<'a, [u8]>, BoxedError> {
40        Ok(Cow::Owned(key.to_vec()))
41    }
42}
43
44impl<'a> BytesDecode<'a> for CacheKeyCodec {
45    type DItem = CacheKey;
46
47    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, BoxedError> {
48        let key: [u8; 20] = bytes.try_into()?;
49        Ok(key)
50    }
51}
52
53impl<'a> BytesEncode<'a> for SignedPacket {
54    type EItem = SignedPacket;
55
56    fn bytes_encode(signed_packet: &Self::EItem) -> Result<Cow<'a, [u8]>, BoxedError> {
57        Ok(Cow::Owned(signed_packet.serialize().to_vec()))
58    }
59}
60
61impl<'a> BytesDecode<'a> for SignedPacket {
62    type DItem = SignedPacket;
63
64    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, BoxedError> {
65        Ok(SignedPacket::deserialize(bytes)?)
66    }
67}
68
69#[derive(Clone)]
70/// Persistent [crate::Cache] implementation using LMDB's bindings [heed]
71pub struct LmdbCache {
72    capacity: usize,
73    env: Env,
74    signed_packets_table: SignedPacketsTable,
75    key_to_time_table: KeyToTimeTable,
76    time_to_key_table: TimeToKeyTable,
77    batch: Arc<RwLock<Vec<CacheKey>>>,
78}
79
80impl Debug for LmdbCache {
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        f.debug_struct("LmdbCache")
83            .field("capacity", &self.capacity)
84            .field("env", &self.env)
85            .finish_non_exhaustive()
86    }
87}
88
89impl LmdbCache {
90    /// Creates a new [LmdbCache] at the `env_path` and set the [heed::EnvOpenOptions::map_size]
91    /// to a multiple of the `capacity` by [SignedPacket::MAX_BYTES], aligned to system's page size,
92    /// a maximum of 10 TB, and a minimum of 10 MB.
93    ///
94    /// # Safety
95    /// LmdbCache uses LMDB, [opening][heed::EnvOpenOptions::open] which is marked unsafe,
96    /// because the possible Undefined Behavior (UB) if the lock file is broken.
97    pub unsafe fn open(env_path: &Path, capacity: usize) -> Result<Self, Error> {
98        let page_size = page_size::get();
99
100        // Page aligned but more than enough bytes for `capacity` many SignedPacket
101        let map_size = capacity
102            .checked_mul(SignedPacket::MAX_BYTES as usize)
103            .and_then(|x| x.checked_add(page_size))
104            .and_then(|x| x.checked_div(page_size))
105            .and_then(|x| x.checked_mul(page_size))
106            .unwrap_or(MAX_MAP_SIZE)
107            .max(MIN_MAP_SIZE);
108
109        fs::create_dir_all(env_path)?;
110
111        let env = unsafe {
112            EnvOpenOptions::new()
113                .map_size(map_size)
114                .max_dbs(3)
115                .open(env_path)?
116        };
117
118        let mut wtxn = env.write_txn()?;
119
120        let signed_packets_table: SignedPacketsTable =
121            env.create_database(&mut wtxn, Some(SIGNED_PACKET_TABLE))?;
122        let key_to_time_table: KeyToTimeTable =
123            env.create_database(&mut wtxn, Some(KEY_TO_TIME_TABLE))?;
124        let time_to_key_table: TimeToKeyTable =
125            env.create_database(&mut wtxn, Some(TIME_TO_KEY_TABLE))?;
126
127        wtxn.commit()?;
128
129        let instance = Self {
130            capacity,
131            env,
132            signed_packets_table,
133            key_to_time_table,
134            time_to_key_table,
135            batch: Arc::new(RwLock::new(vec![])),
136        };
137
138        Ok(instance)
139    }
140
141    /// Convenient wrapper around [Self::open].
142    ///
143    /// Make sure to read the safety section in [Self::open]
144    pub fn open_unsafe(env_path: &Path, capacity: usize) -> Result<Self, Error> {
145        unsafe { Self::open(env_path, capacity) }
146    }
147
148    fn internal_len(&self) -> Result<usize, heed::Error> {
149        let rtxn = self.env.read_txn()?;
150        let len = self.signed_packets_table.len(&rtxn)? as usize;
151        rtxn.commit()?;
152
153        Ok(len)
154    }
155
156    fn internal_put(
157        &self,
158        key: &CacheKey,
159        signed_packet: &SignedPacket,
160    ) -> Result<(), heed::Error> {
161        if self.capacity == 0 {
162            return Ok(());
163        }
164
165        let mut wtxn = self.env.write_txn()?;
166
167        let packets = self.signed_packets_table;
168        let key_to_time = self.key_to_time_table;
169        let time_to_key = self.time_to_key_table;
170
171        let mut batch = self.batch.write().expect("LmdbCache::batch.write()");
172        update_lru(&mut wtxn, packets, key_to_time, time_to_key, &batch)?;
173
174        let len = packets.len(&wtxn)? as usize;
175
176        if len >= self.capacity {
177            debug!(?len, ?self.capacity, "Reached cache capacity, deleting extra item.");
178
179            let mut iter = time_to_key.iter(&wtxn)?;
180
181            if let Some((time, key)) = iter.next().transpose()? {
182                drop(iter);
183
184                time_to_key.delete(&mut wtxn, &time)?;
185                key_to_time.delete(&mut wtxn, &key)?;
186                packets.delete(&mut wtxn, &key)?;
187            };
188        }
189
190        batch.clear();
191
192        if let Some(old_time) = key_to_time.get(&wtxn, key)? {
193            time_to_key.delete(&mut wtxn, &old_time)?;
194        }
195
196        let new_time = Timestamp::now();
197
198        time_to_key.put(&mut wtxn, &new_time.as_u64(), key)?;
199        key_to_time.put(&mut wtxn, key, &new_time.as_u64())?;
200
201        packets.put(&mut wtxn, key, signed_packet)?;
202
203        wtxn.commit()?;
204
205        Ok(())
206    }
207
208    fn internal_get(&self, key: &CacheKey) -> Result<Option<SignedPacket>, heed::Error> {
209        self.batch
210            .write()
211            .expect("LmdbCache::batch.write()")
212            .push(*key);
213
214        self.internal_get_read_only(key)
215    }
216
217    fn internal_get_read_only(&self, key: &CacheKey) -> Result<Option<SignedPacket>, heed::Error> {
218        let rtxn = self.env.read_txn()?;
219
220        if let Some(signed_packet) = self.signed_packets_table.get(&rtxn, key)? {
221            return Ok(Some(signed_packet));
222        }
223
224        rtxn.commit()?;
225
226        Ok(None)
227    }
228}
229
230fn update_lru(
231    wtxn: &mut RwTxn,
232    packets: SignedPacketsTable,
233    key_to_time: KeyToTimeTable,
234    time_to_key: TimeToKeyTable,
235    to_update: &[CacheKey],
236) -> Result<(), heed::Error> {
237    for key in to_update {
238        if packets.get(wtxn, key)?.is_some() {
239            if let Some(time) = key_to_time.get(wtxn, key)? {
240                time_to_key.delete(wtxn, &time)?;
241            };
242
243            let new_time = Timestamp::now();
244
245            time_to_key.put(wtxn, &new_time.as_u64(), key)?;
246            key_to_time.put(wtxn, key, &new_time.as_u64())?;
247        }
248    }
249
250    Ok(())
251}
252
253impl Cache for LmdbCache {
254    fn capacity(&self) -> usize {
255        self.capacity
256    }
257
258    fn len(&self) -> usize {
259        match self.internal_len() {
260            Ok(result) => result,
261            Err(error) => {
262                debug!(?error, "Error in LmdbCache::len");
263                0
264            }
265        }
266    }
267
268    fn put(&self, key: &CacheKey, signed_packet: &SignedPacket) {
269        if let Err(error) = self.internal_put(key, signed_packet) {
270            debug!(?error, "Error in LmdbCache::put");
271        };
272    }
273
274    fn get(&self, key: &CacheKey) -> Option<SignedPacket> {
275        match self.internal_get(key) {
276            Ok(result) => result,
277            Err(error) => {
278                debug!(?error, "Error in LmdbCache::get");
279
280                None
281            }
282        }
283    }
284
285    fn get_read_only(&self, key: &CacheKey) -> Option<SignedPacket> {
286        match self.internal_get_read_only(key) {
287            Ok(result) => result,
288            Err(error) => {
289                debug!(?error, "Error in LmdbCache::get");
290
291                None
292            }
293        }
294    }
295}
296
297#[derive(thiserror::Error, Debug)]
298/// Pkarr crate error enum.
299pub enum Error {
300    #[error(transparent)]
301    /// Transparent [heed::Error]
302    Lmdb(#[from] heed::Error),
303
304    #[error(transparent)]
305    /// Transparent [std::io::Error]
306    IO(#[from] std::io::Error),
307}
308
#[cfg(test)]
mod tests {
    use crate::Keypair;

    use super::*;

    /// A directory path under the system temp dir, named after the current [Timestamp].
    fn fresh_env_path() -> std::path::PathBuf {
        std::env::temp_dir().join(Timestamp::now().to_string())
    }

    /// A packet signed by a random keypair, holding a single `foo` TXT record.
    fn random_packet(ttl: u32) -> SignedPacket {
        SignedPacket::builder()
            .txt("foo".try_into().unwrap(), "bar".try_into().unwrap(), ttl)
            .sign(&Keypair::random())
            .unwrap()
    }

    /// Puts `count` random packets into `cache` and returns them, with their keys,
    /// in insertion order.
    fn put_random_packets(cache: &LmdbCache, count: u32) -> Vec<(CacheKey, SignedPacket)> {
        (0..count)
            .map(|ttl| {
                let signed_packet = random_packet(ttl);
                let key = CacheKey::from(signed_packet.public_key());
                cache.put(&key, &signed_packet);
                (key, signed_packet)
            })
            .collect()
    }

    #[test]
    fn max_map_size() {
        LmdbCache::open_unsafe(&fresh_env_path(), usize::MAX).unwrap();
    }

    #[test]
    fn lru_capacity() {
        let cache = LmdbCache::open_unsafe(&fresh_env_path(), 2).unwrap();

        let keys = put_random_packets(&cache, 2);
        let (first_key, first_packet) = &keys[0];
        let (last_key, last_packet) = &keys[1];

        assert_eq!(
            cache.get_read_only(first_key).unwrap(),
            *first_packet,
            "first key saved"
        );
        assert_eq!(
            cache.get_read_only(last_key).unwrap(),
            *last_packet,
            "second key saved"
        );

        assert_eq!(cache.len(), 2);

        // Put another one, effectively deleting the oldest.
        let newest_packet = random_packet(3);
        let newest_key = CacheKey::from(newest_packet.public_key());
        cache.put(&newest_key, &newest_packet);

        assert_eq!(cache.len(), 2);

        assert!(
            cache.get_read_only(first_key).is_none(),
            "oldest key dropped"
        );
        assert_eq!(
            cache.get_read_only(last_key).unwrap(),
            *last_packet,
            "more recent key survived"
        );
        assert_eq!(
            cache.get_read_only(&newest_key).unwrap(),
            newest_packet,
            "most recent key survived"
        )
    }

    #[test]
    fn lru_capacity_refresh_oldest() {
        let cache = LmdbCache::open_unsafe(&fresh_env_path(), 2).unwrap();

        let keys = put_random_packets(&cache, 2);
        let (first_key, first_packet) = &keys[0];
        let (last_key, last_packet) = &keys[1];

        assert_eq!(
            cache.get_read_only(first_key).unwrap(),
            *first_packet,
            "first key saved"
        );
        assert_eq!(
            cache.get_read_only(last_key).unwrap(),
            *last_packet,
            "second key saved"
        );

        // refresh the oldest
        cache.get(first_key).unwrap();

        assert_eq!(cache.len(), 2);

        // Put another one, effectively deleting the oldest.
        let newest_packet = random_packet(3);
        let newest_key = CacheKey::from(newest_packet.public_key());
        cache.put(&newest_key, &newest_packet);

        assert_eq!(cache.len(), 2);

        assert!(
            cache.get_read_only(last_key).is_none(),
            "oldest key dropped"
        );
        assert_eq!(
            cache.get_read_only(first_key).unwrap(),
            *first_packet,
            "refreshed key survived"
        );
        assert_eq!(
            cache.get_read_only(&newest_key).unwrap(),
            newest_packet,
            "most recent key survived"
        )
    }
}