1#![crate_name = "pond_cache"]
2
3use std::hash::{DefaultHasher, Hash, Hasher};
4use std::path::PathBuf;
5
6use chrono::{DateTime, Duration, Utc};
7use rusqlite::Connection;
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10
11pub use rusqlite::Error;
12
13pub struct Cache<K, T> {
15 path: PathBuf,
16 ttl: Duration,
17 key: std::marker::PhantomData<K>,
18 data: std::marker::PhantomData<T>,
19}
20
21#[derive(Debug)]
22struct CacheEntry<T>
23where
24 T: Serialize + DeserializeOwned + Clone,
25{
26 key: u32,
27 value: T,
28 expiration: DateTime<Utc>,
29}
30
31impl<K: Hash, T: Serialize + DeserializeOwned + Clone> Cache<K, T> {
32 pub fn new(path: PathBuf) -> Result<Self, Error> {
51 Self::with_time_to_live(path, Duration::minutes(10))
52 }
53
54 pub fn with_time_to_live(path: PathBuf, ttl: Duration) -> Result<Self, Error> {
75 let db = Connection::open(path.as_path())?;
76
77 db.execute(
78 "CREATE TABLE IF NOT EXISTS items (
79 id TEXT PRIMARY KEY,
80 expires TEXT NOT NULL,
81 data BLOB NOT NULL
82 )",
83 (),
84 )?;
85
86 db.close().expect("Failed to close database connection");
87
88 Ok(Self {
89 path,
90 ttl,
91 key: std::marker::PhantomData,
92 data: std::marker::PhantomData,
93 })
94 }
95
96 pub fn get(&self, key: K) -> Result<Option<T>, Error> {
118 let db = Connection::open(self.path.as_path())?;
119
120 let mut stmt = db.prepare(
121 "SELECT id, expires, data
122 FROM items
123 WHERE id = ?1",
124 )?;
125
126 let mut hasher = DefaultHasher::new();
127 let hash = {
128 key.hash(&mut hasher);
129 hasher.finish() as u32
130 };
131 let mut rows = stmt.query([hash]).unwrap();
132
133 let Some(row) = rows.next().unwrap() else {
134 return Ok(None);
135 };
136
137 let expires: DateTime<Utc> = row
138 .get::<usize, String>(1)
139 .map(|expires_string| {
140 DateTime::parse_from_rfc3339(&expires_string)
141 .unwrap()
142 .with_timezone(&Utc)
143 })
144 .unwrap();
145 let data: Vec<u8> = row.get(2).unwrap();
146
147 drop(rows);
148 drop(stmt);
149 db.close().expect("Failed to close database connection");
150
151 let data: T = bitcode::deserialize(&data).unwrap();
152
153 if expires < Utc::now() {
154 Ok(None)
155 } else {
156 Ok(Some(data))
157 }
158 }
159
160 pub fn store(&self, key: K, value: T) -> Result<(), Error> {
186 self.store_with_expiration(key, value, Utc::now() + self.ttl)
187 }
188
189 pub fn store_with_expiration(
218 &self,
219 key: K,
220 value: T,
221 expiration: DateTime<Utc>,
222 ) -> Result<(), Error> {
223 let mut hasher = DefaultHasher::new();
224 let hash = {
225 key.hash(&mut hasher);
226 hasher.finish() as u32
227 };
228
229 let value = CacheEntry {
230 key: hash,
231 value,
232 expiration,
233 };
234
235 let db = Connection::open(self.path.as_path())?;
236
237 db.execute(
238 "INSERT OR REPLACE INTO items (id, expires, data) VALUES (?1, ?2, ?3);",
239 (
240 &value.key.to_string(),
241 &value.expiration.to_rfc3339(),
242 &bitcode::serialize(&value.value).unwrap(),
243 ),
244 )?;
245
246 db.close().expect("Failed to close database connection");
247
248 Ok(())
249 }
250
251 pub fn clean(&self) -> Result<(), Error> {
270 let db = Connection::open(self.path.as_path())?;
271
272 db.execute(
273 "DELETE FROM items WHERE expires < ?1;",
274 (&Utc::now().to_rfc3339(),),
275 )?;
276
277 db.close().expect("Failed to close database connection");
278
279 Ok(())
280 }
281}
282
#[cfg(test)]
mod tests {
    use serde::Deserialize;
    use serde::Serialize;
    use uuid::Uuid;

    use super::*;

    #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
    struct User {
        id: Uuid,
        name: String,
    }

    /// Returns a unique path in the system temp directory for a throwaway database.
    fn temp_db_path() -> PathBuf {
        std::env::temp_dir().join(format!(
            "pond-test-{}-{}.sqlite",
            Uuid::new_v4(),
            rand::random::<u8>()
        ))
    }

    /// Computes a row id the same way `Cache` does (`DefaultHasher`, truncated
    /// to 32 bits), so rows written by the helpers below are visible to the cache.
    ///
    /// The key must have the same type as the cache's `K`: a `Uuid` and its
    /// `to_string()` form hash to different ids.
    fn manual_hash<K: Hash>(key: &K) -> u32 {
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        hasher.finish() as u32
    }

    /// Writes a raw row (pre-encoded `value`, explicit `expires`) straight into
    /// the `items` table, bypassing `Cache`.
    fn store_manual<K: Hash>(
        path: PathBuf,
        key: K,
        value: Vec<u8>,
        expires: DateTime<Utc>,
    ) -> Result<(), Error> {
        let hash = manual_hash(&key);

        let db = Connection::open(path.as_path())?;
        db.execute(
            "INSERT OR REPLACE INTO items (id, expires, data) VALUES (?1, ?2, ?3);",
            (hash, expires.to_rfc3339(), value),
        )?;
        db.close().map_err(|(_, err)| err)
    }

    /// Reads the raw row for `key` straight from the `items` table, ignoring
    /// its expiration, and decodes it into a `CacheEntry`.
    fn get_manual<K: Hash, T: Serialize + DeserializeOwned + Clone>(
        path: PathBuf,
        key: K,
    ) -> Result<Option<CacheEntry<T>>, Error> {
        let hash = manual_hash(&key);
        let db = Connection::open(path.as_path())?;

        let mut stmt = db.prepare(
            "SELECT id, expires, data
            FROM items
            WHERE id = ?1",
        )?;
        let mut rows = stmt.query([hash])?;
        let row: Option<(String, Vec<u8>)> = match rows.next()? {
            Some(row) => Some((row.get(1)?, row.get(2)?)),
            None => None,
        };
        drop(rows);
        drop(stmt);
        db.close().map_err(|(_, err)| err)?;

        let Some((expires, data)) = row else {
            return Ok(None);
        };

        Ok(Some(CacheEntry {
            key: hash,
            value: bitcode::deserialize(&data).unwrap(),
            expiration: DateTime::parse_from_rfc3339(&expires)
                .unwrap()
                .with_timezone(&Utc),
        }))
    }

    #[test]
    fn test_new() {
        let filename = temp_db_path();
        let cache: Cache<Uuid, String> = Cache::new(filename.clone()).unwrap();
        assert_eq!(cache.path, filename);
        assert_eq!(cache.ttl, Duration::minutes(10));
    }

    #[test]
    fn test_load_existing() {
        let filename = temp_db_path();
        let _: Cache<Uuid, String> = Cache::new(filename.clone()).unwrap();
        let _: Cache<Uuid, String> = Cache::new(filename).unwrap();
    }

    #[test]
    fn test_time_to_live() {
        let filename = temp_db_path();
        let cache: Cache<Uuid, String> =
            Cache::with_time_to_live(filename.clone(), Duration::minutes(5)).unwrap();
        assert_eq!(cache.path, filename);
        assert_eq!(cache.ttl, Duration::minutes(5));
    }

    #[test]
    fn test_store_get() {
        let cache = Cache::new(temp_db_path()).unwrap();

        let key = Uuid::new_v4();
        let value = String::from("Hello, world!");

        cache.store(key, value.clone()).unwrap();
        let result: Option<_> = cache.get(key).unwrap();

        assert_eq!(result, Some(value));
    }

    #[test]
    fn test_store_get_struct() {
        let cache = Cache::new(temp_db_path()).unwrap();

        let key = Uuid::new_v4();
        let value = User {
            id: Uuid::new_v4(),
            name: String::from("Alice"),
        };

        cache.store(key, value.clone()).unwrap();
        let result: Option<_> = cache.get(key).unwrap();

        assert_eq!(result, Some(value));
    }

    #[test]
    fn test_store_existing() {
        let cache = Cache::new(temp_db_path()).unwrap();

        let key = Uuid::new_v4();
        cache.store(key, String::from("Hello, world!")).unwrap();

        let value = String::from("Hello, world! 2");
        cache.store(key, value.clone()).unwrap();
        let result: Option<_> = cache.get(key).unwrap();

        assert_eq!(result, Some(value));
    }

    #[test]
    fn test_get_expired() {
        let filename = temp_db_path();
        let cache = Cache::new(filename.clone()).unwrap();

        let key = Uuid::new_v4();
        let value = String::from("Hello, world!");

        // Store under the `Uuid` itself (the cache's key type), not its string
        // form, so the row id matches what `cache.get(key)` looks up.
        store_manual(
            filename.clone(),
            key,
            bitcode::serialize(&value).unwrap(),
            Utc::now() - Duration::minutes(5),
        )
        .unwrap();

        // The row must exist, so the `None` below is caused by expiry rather
        // than by a missing entry.
        let stored: Option<CacheEntry<String>> = get_manual(filename, key).unwrap();
        assert_eq!(stored.map(|entry| entry.value), Some(value));

        let result: Option<String> = cache.get(key).unwrap();
        assert_eq!(result, None);
    }

    #[test]
    fn test_store_with_expiration_in_past() {
        let filename = temp_db_path();
        let cache = Cache::new(filename.clone()).unwrap();

        let key = Uuid::new_v4();
        let value = String::from("Hello, world!");

        cache
            .store_with_expiration(key, value.clone(), Utc::now() - Duration::minutes(1))
            .unwrap();

        // The already-expired row is still written ...
        let stored: Option<CacheEntry<String>> = get_manual(filename, key).unwrap();
        assert_eq!(stored.map(|entry| entry.value), Some(value));

        // ... but the cache reports it as missing.
        let result: Option<String> = cache.get(key).unwrap();
        assert_eq!(result, None);
    }

    #[test]
    fn test_get_nonexistent() {
        let cache = Cache::new(temp_db_path()).unwrap();

        let key = Uuid::new_v4();

        let result: Option<String> = cache.get(key).unwrap();

        assert_eq!(result, None);
    }

    #[test]
    fn test_invalid_path() {
        let cache: Result<Cache<Uuid, String>, Error> =
            Cache::new(PathBuf::from("invalid/path/db.sqlite"));

        assert!(cache.is_err());
    }

    #[test]
    fn test_clean() {
        let filename = temp_db_path();
        let cache: Cache<Uuid, String> =
            Cache::with_time_to_live(filename.clone(), Duration::minutes(5)).unwrap();

        let key = Uuid::new_v4().to_string();
        let value = String::from("Hello, world!");

        store_manual(
            filename.clone(),
            key.clone(),
            bitcode::serialize(&value).unwrap(),
            Utc::now() - Duration::minutes(5),
        )
        .unwrap();

        let result: Option<CacheEntry<String>> = get_manual(filename.clone(), key.clone()).unwrap();
        assert_eq!(result.map(|entry| entry.value), Some(value));

        cache.clean().unwrap();
        let result: Option<CacheEntry<String>> = get_manual(filename, key).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_clean_leaves_unexpired() {
        let filename = temp_db_path();
        let cache: Cache<Uuid, String> =
            Cache::with_time_to_live(filename.clone(), Duration::minutes(5)).unwrap();

        let key = Uuid::new_v4().to_string();
        let value = String::from("Hello, world!");

        store_manual(
            filename.clone(),
            key.clone(),
            bitcode::serialize(&value).unwrap(),
            Utc::now() + Duration::minutes(15),
        )
        .unwrap();

        let result: Option<CacheEntry<String>> = get_manual(filename.clone(), key.clone()).unwrap();
        assert_eq!(result.map(|entry| entry.value), Some(value.clone()));

        cache.clean().unwrap();

        let result: Option<CacheEntry<String>> = get_manual(filename, key).unwrap();
        assert_eq!(result.map(|entry| entry.value), Some(value));
    }
}
589}