1use std::{
4 borrow::Cow,
5 fmt::Debug,
6 fs,
7 path::Path,
8 sync::{Arc, RwLock},
9};
10
11use byteorder::BigEndian;
12use heed::{
13 types::U64, BoxedError, BytesDecode, BytesEncode, Database, Env, EnvOpenOptions, RwTxn,
14};
15
16use tracing::debug;
17
18use ntimestamp::Timestamp;
19
20use crate::{Cache, CacheKey, SignedPacket};
21
/// Largest LMDB map size this cache intends to request (10 TiB); also the fallback when the
/// capacity-derived size overflows `usize`.
const MAX_MAP_SIZE: usize = 10 * 1024 * 1024 * 1024 * 1024;
/// Smallest LMDB map size requested, regardless of capacity (10 MiB).
const MIN_MAP_SIZE: usize = 10 * 1024 * 1024;

/// LMDB table: cache key -> serialized signed packet.
const SIGNED_PACKET_TABLE: &str = "pkarrcache:signed_packet";
/// LMDB table: cache key -> last-access timestamp.
const KEY_TO_TIME_TABLE: &str = "pkarrcache:key_to_time";
/// LMDB table: last-access timestamp -> cache key.
const TIME_TO_KEY_TABLE: &str = "pkarrcache:time_to_key";
28
29type SignedPacketsTable = Database<CacheKeyCodec, SignedPacket>;
30type KeyToTimeTable = Database<CacheKeyCodec, U64<BigEndian>>;
31type TimeToKeyTable = Database<U64<BigEndian>, CacheKeyCodec>;
32
33pub struct CacheKeyCodec;
35
36impl<'a> BytesEncode<'a> for CacheKeyCodec {
37 type EItem = CacheKey;
38
39 fn bytes_encode(key: &Self::EItem) -> Result<Cow<'a, [u8]>, BoxedError> {
40 Ok(Cow::Owned(key.to_vec()))
41 }
42}
43
44impl<'a> BytesDecode<'a> for CacheKeyCodec {
45 type DItem = CacheKey;
46
47 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, BoxedError> {
48 let key: [u8; 20] = bytes.try_into()?;
49 Ok(key)
50 }
51}
52
53impl<'a> BytesEncode<'a> for SignedPacket {
54 type EItem = SignedPacket;
55
56 fn bytes_encode(signed_packet: &Self::EItem) -> Result<Cow<'a, [u8]>, BoxedError> {
57 Ok(Cow::Owned(signed_packet.serialize().to_vec()))
58 }
59}
60
61impl<'a> BytesDecode<'a> for SignedPacket {
62 type DItem = SignedPacket;
63
64 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, BoxedError> {
65 Ok(SignedPacket::deserialize(bytes)?)
66 }
67}
68
69#[derive(Clone)]
70pub struct LmdbCache {
72 capacity: usize,
73 env: Env,
74 signed_packets_table: SignedPacketsTable,
75 key_to_time_table: KeyToTimeTable,
76 time_to_key_table: TimeToKeyTable,
77 batch: Arc<RwLock<Vec<CacheKey>>>,
78}
79
80impl Debug for LmdbCache {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 f.debug_struct("LmdbCache")
83 .field("capacity", &self.capacity)
84 .field("env", &self.env)
85 .finish_non_exhaustive()
86 }
87}
88
89impl LmdbCache {
90 pub unsafe fn open(env_path: &Path, capacity: usize) -> Result<Self, Error> {
98 let page_size = page_size::get();
99
100 let map_size = capacity
102 .checked_mul(SignedPacket::MAX_BYTES as usize)
103 .and_then(|x| x.checked_add(page_size))
104 .and_then(|x| x.checked_div(page_size))
105 .and_then(|x| x.checked_mul(page_size))
106 .unwrap_or(MAX_MAP_SIZE)
107 .max(MIN_MAP_SIZE);
108
109 fs::create_dir_all(env_path)?;
110
111 let env = unsafe {
112 EnvOpenOptions::new()
113 .map_size(map_size)
114 .max_dbs(3)
115 .open(env_path)?
116 };
117
118 let mut wtxn = env.write_txn()?;
119
120 let signed_packets_table: SignedPacketsTable =
121 env.create_database(&mut wtxn, Some(SIGNED_PACKET_TABLE))?;
122 let key_to_time_table: KeyToTimeTable =
123 env.create_database(&mut wtxn, Some(KEY_TO_TIME_TABLE))?;
124 let time_to_key_table: TimeToKeyTable =
125 env.create_database(&mut wtxn, Some(TIME_TO_KEY_TABLE))?;
126
127 wtxn.commit()?;
128
129 let instance = Self {
130 capacity,
131 env,
132 signed_packets_table,
133 key_to_time_table,
134 time_to_key_table,
135 batch: Arc::new(RwLock::new(vec![])),
136 };
137
138 Ok(instance)
139 }
140
141 pub fn open_unsafe(env_path: &Path, capacity: usize) -> Result<Self, Error> {
145 unsafe { Self::open(env_path, capacity) }
146 }
147
148 fn internal_len(&self) -> Result<usize, heed::Error> {
149 let rtxn = self.env.read_txn()?;
150 let len = self.signed_packets_table.len(&rtxn)? as usize;
151 rtxn.commit()?;
152
153 Ok(len)
154 }
155
156 fn internal_put(
157 &self,
158 key: &CacheKey,
159 signed_packet: &SignedPacket,
160 ) -> Result<(), heed::Error> {
161 if self.capacity == 0 {
162 return Ok(());
163 }
164
165 let mut wtxn = self.env.write_txn()?;
166
167 let packets = self.signed_packets_table;
168 let key_to_time = self.key_to_time_table;
169 let time_to_key = self.time_to_key_table;
170
171 let mut batch = self.batch.write().expect("LmdbCache::batch.write()");
172 update_lru(&mut wtxn, packets, key_to_time, time_to_key, &batch)?;
173
174 let len = packets.len(&wtxn)? as usize;
175
176 if len >= self.capacity {
177 debug!(?len, ?self.capacity, "Reached cache capacity, deleting extra item.");
178
179 let mut iter = time_to_key.iter(&wtxn)?;
180
181 if let Some((time, key)) = iter.next().transpose()? {
182 drop(iter);
183
184 time_to_key.delete(&mut wtxn, &time)?;
185 key_to_time.delete(&mut wtxn, &key)?;
186 packets.delete(&mut wtxn, &key)?;
187 };
188 }
189
190 batch.clear();
191
192 if let Some(old_time) = key_to_time.get(&wtxn, key)? {
193 time_to_key.delete(&mut wtxn, &old_time)?;
194 }
195
196 let new_time = Timestamp::now();
197
198 time_to_key.put(&mut wtxn, &new_time.as_u64(), key)?;
199 key_to_time.put(&mut wtxn, key, &new_time.as_u64())?;
200
201 packets.put(&mut wtxn, key, signed_packet)?;
202
203 wtxn.commit()?;
204
205 Ok(())
206 }
207
208 fn internal_get(&self, key: &CacheKey) -> Result<Option<SignedPacket>, heed::Error> {
209 self.batch
210 .write()
211 .expect("LmdbCache::batch.write()")
212 .push(*key);
213
214 self.internal_get_read_only(key)
215 }
216
217 fn internal_get_read_only(&self, key: &CacheKey) -> Result<Option<SignedPacket>, heed::Error> {
218 let rtxn = self.env.read_txn()?;
219
220 if let Some(signed_packet) = self.signed_packets_table.get(&rtxn, key)? {
221 return Ok(Some(signed_packet));
222 }
223
224 rtxn.commit()?;
225
226 Ok(None)
227 }
228}
229
230fn update_lru(
231 wtxn: &mut RwTxn,
232 packets: SignedPacketsTable,
233 key_to_time: KeyToTimeTable,
234 time_to_key: TimeToKeyTable,
235 to_update: &[CacheKey],
236) -> Result<(), heed::Error> {
237 for key in to_update {
238 if packets.get(wtxn, key)?.is_some() {
239 if let Some(time) = key_to_time.get(wtxn, key)? {
240 time_to_key.delete(wtxn, &time)?;
241 };
242
243 let new_time = Timestamp::now();
244
245 time_to_key.put(wtxn, &new_time.as_u64(), key)?;
246 key_to_time.put(wtxn, key, &new_time.as_u64())?;
247 }
248 }
249
250 Ok(())
251}
252
253impl Cache for LmdbCache {
254 fn capacity(&self) -> usize {
255 self.capacity
256 }
257
258 fn len(&self) -> usize {
259 match self.internal_len() {
260 Ok(result) => result,
261 Err(error) => {
262 debug!(?error, "Error in LmdbCache::len");
263 0
264 }
265 }
266 }
267
268 fn put(&self, key: &CacheKey, signed_packet: &SignedPacket) {
269 if let Err(error) = self.internal_put(key, signed_packet) {
270 debug!(?error, "Error in LmdbCache::put");
271 };
272 }
273
274 fn get(&self, key: &CacheKey) -> Option<SignedPacket> {
275 match self.internal_get(key) {
276 Ok(result) => result,
277 Err(error) => {
278 debug!(?error, "Error in LmdbCache::get");
279
280 None
281 }
282 }
283 }
284
285 fn get_read_only(&self, key: &CacheKey) -> Option<SignedPacket> {
286 match self.internal_get_read_only(key) {
287 Ok(result) => result,
288 Err(error) => {
289 debug!(?error, "Error in LmdbCache::get");
290
291 None
292 }
293 }
294 }
295}
296
297#[derive(thiserror::Error, Debug)]
298pub enum Error {
300 #[error(transparent)]
301 Lmdb(#[from] heed::Error),
303
304 #[error(transparent)]
305 IO(#[from] std::io::Error),
307}
308
#[cfg(test)]
mod tests {
    use crate::Keypair;

    use super::*;

    /// A fresh directory path under the system temp dir, named after the current timestamp.
    fn temp_env_path() -> std::path::PathBuf {
        std::env::temp_dir().join(Timestamp::now().to_string())
    }

    /// Signs a one-TXT-record packet with a random keypair; returns it with its cache key.
    fn random_packet(ttl: u32) -> (CacheKey, SignedPacket) {
        let signed_packet = SignedPacket::builder()
            .txt("foo".try_into().unwrap(), "bar".try_into().unwrap(), ttl)
            .sign(&Keypair::random())
            .unwrap();

        (CacheKey::from(signed_packet.public_key()), signed_packet)
    }

    /// Puts `count` random packets (TTLs `0..count`) in order and returns them.
    fn put_random_packets(cache: &LmdbCache, count: u32) -> Vec<(CacheKey, SignedPacket)> {
        (0..count)
            .map(|ttl| {
                let (key, packet) = random_packet(ttl);
                cache.put(&key, &packet);
                (key, packet)
            })
            .collect()
    }

    #[test]
    fn max_map_size() {
        LmdbCache::open_unsafe(&temp_env_path(), usize::MAX).unwrap();
    }

    #[test]
    fn lru_capacity() {
        let cache = LmdbCache::open_unsafe(&temp_env_path(), 2).unwrap();

        let keys = put_random_packets(&cache, 2);
        let (first_key, first_packet) = &keys[0];
        let (last_key, last_packet) = &keys[1];

        assert_eq!(
            cache.get_read_only(first_key).unwrap(),
            *first_packet,
            "first key saved"
        );
        assert_eq!(
            cache.get_read_only(last_key).unwrap(),
            *last_packet,
            "second key saved"
        );

        assert_eq!(cache.len(), 2);

        let (newest_key, newest_packet) = random_packet(3);
        cache.put(&newest_key, &newest_packet);

        assert_eq!(cache.len(), 2);

        assert!(
            cache.get_read_only(first_key).is_none(),
            "oldest key dropped"
        );
        assert_eq!(
            cache.get_read_only(last_key).unwrap(),
            *last_packet,
            "more recent key survived"
        );
        assert_eq!(
            cache.get_read_only(&newest_key).unwrap(),
            newest_packet,
            "most recent key survived"
        )
    }

    #[test]
    fn lru_capacity_refresh_oldest() {
        let cache = LmdbCache::open_unsafe(&temp_env_path(), 2).unwrap();

        let keys = put_random_packets(&cache, 2);
        let (first_key, first_packet) = &keys[0];
        let (last_key, last_packet) = &keys[1];

        assert_eq!(
            cache.get_read_only(first_key).unwrap(),
            *first_packet,
            "first key saved"
        );
        assert_eq!(
            cache.get_read_only(last_key).unwrap(),
            *last_packet,
            "second key saved"
        );

        // A non-read-only get marks the first key as recently used.
        cache.get(first_key).unwrap();

        assert_eq!(cache.len(), 2);

        let (newest_key, newest_packet) = random_packet(3);
        cache.put(&newest_key, &newest_packet);

        assert_eq!(cache.len(), 2);

        assert!(
            cache.get_read_only(last_key).is_none(),
            "oldest key dropped"
        );
        assert_eq!(
            cache.get_read_only(first_key).unwrap(),
            *first_packet,
            "refreshed key survived"
        );
        assert_eq!(
            cache.get_read_only(&newest_key).unwrap(),
            newest_packet,
            "most recent key survived"
        )
    }
}