1#![deny(missing_docs)]
111
112use chrono::{DateTime, Utc};
113use crossbeam::queue::SegQueue;
114use futures_channel::oneshot;
115use hashbrown::HashMap;
116use hex::ToHex as _;
117use parking_lot::RwLock;
118use serde::{Deserialize, Serialize};
119use serde_cbor as cbor;
120use serde_hashkey as hashkey;
121use serde_json as json;
122use std::error;
123use std::fmt;
124use std::future::Future;
125use std::sync::atomic::{AtomicUsize, Ordering};
126use std::sync::Arc;
127
128pub use chrono::Duration;
129pub use sled;
130
131#[derive(Debug)]
133pub enum Error {
134 Cbor(cbor::error::Error),
136 HashKey(hashkey::Error),
138 Json(json::error::Error),
140 Sled(sled::Error),
142 Failed,
144}
145
146impl fmt::Display for Error {
147 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
148 match self {
149 Error::Cbor(e) => write!(fmt, "CBOR error: {}", e),
150 Error::HashKey(e) => write!(fmt, "HashKey error: {}", e),
151 Error::Json(e) => write!(fmt, "JSON error: {}", e),
152 Error::Sled(e) => write!(fmt, "Database error: {}", e),
153 Error::Failed => write!(fmt, "Operation failed"),
154 }
155 }
156}
157
158impl error::Error for Error {
159 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
160 match self {
161 Error::Cbor(e) => Some(e),
162 Error::HashKey(e) => Some(e),
163 Error::Json(e) => Some(e),
164 Error::Sled(e) => Some(e),
165 _ => None,
166 }
167 }
168}
169
170impl From<json::error::Error> for Error {
171 fn from(error: json::error::Error) -> Self {
172 Error::Json(error)
173 }
174}
175
176impl From<cbor::error::Error> for Error {
177 fn from(error: cbor::error::Error) -> Self {
178 Error::Cbor(error)
179 }
180}
181
182impl From<hashkey::Error> for Error {
183 fn from(error: hashkey::Error) -> Self {
184 Error::HashKey(error)
185 }
186}
187
188impl From<sled::Error> for Error {
189 fn from(error: sled::Error) -> Self {
190 Error::Sled(error)
191 }
192}
193
194pub enum State<T> {
196 Fresh(StoredEntry<T>),
198 Expired(StoredEntry<T>),
201 Missing,
203}
204
205impl<T> State<T> {
206 pub fn get(self) -> Option<T> {
208 match self {
209 State::Fresh(e) | State::Expired(e) => Some(e.value),
210 State::Missing => None,
211 }
212 }
213}
214
215#[derive(Debug, Serialize, Deserialize)]
219pub struct JsonEntry {
220 pub key: serde_json::Value,
222 #[serde(flatten)]
224 pub stored: StoredEntry<serde_json::Value>,
225}
226
227#[derive(Debug, Serialize, Deserialize)]
229pub struct StoredEntry<T> {
230 expires_at: DateTime<Utc>,
231 value: T,
232}
233
234#[derive(Debug, Serialize)]
238pub struct StoredEntryRef<'a, T> {
239 expires_at: DateTime<Utc>,
240 value: &'a T,
241}
242
243impl<T> StoredEntry<T> {
244 fn is_expired(&self, now: DateTime<Utc>) -> bool {
246 self.expires_at < now
247 }
248}
249
250#[derive(Debug, Serialize, Deserialize)]
252struct PartialStoredEntry {
253 expires_at: DateTime<Utc>,
254}
255
256impl PartialStoredEntry {
257 fn is_expired(&self, now: DateTime<Utc>) -> bool {
259 self.expires_at < now
260 }
261
262 fn into_stored_entry(self) -> StoredEntry<()> {
264 StoredEntry {
265 expires_at: self.expires_at,
266 value: (),
267 }
268 }
269}
270
271#[derive(Default)]
272struct Waker {
273 pending: AtomicUsize,
275 channels: SegQueue<oneshot::Sender<bool>>,
277}
278
279impl Waker {
280 fn cleanup(&self, error: bool) {
283 let mut previous = self.pending.load(Ordering::Acquire);
284
285 loop {
286 while previous > 1 {
287 let mut received = 0usize;
288
289 while let Some(waker) = self.channels.pop() {
290 received += 1;
291 let _ = waker.send(error);
292 }
293
294 previous = self.pending.fetch_sub(received, Ordering::AcqRel);
298 }
299
300 previous =
301 match self
302 .pending
303 .compare_exchange(1, 0, Ordering::AcqRel, Ordering::Acquire)
304 {
305 Ok(n) => n,
306 Err(n) => n,
307 };
308
309 if previous == 1 {
310 break;
311 }
312 }
313 }
314}
315
316struct Inner {
317 ns: Option<hashkey::Key>,
319 db: sled::Tree,
321 wakers: RwLock<HashMap<Vec<u8>, Arc<Waker>>>,
324}
325
326#[derive(Clone)]
330pub struct Cache {
331 inner: Arc<Inner>,
332}
333
334impl Cache {
335 pub fn load(db: sled::Tree) -> Result<Cache, Error> {
337 let cache = Cache {
338 inner: Arc::new(Inner {
339 ns: None,
340 db,
341 wakers: Default::default(),
342 }),
343 };
344 cache.cleanup()?;
345 Ok(cache)
346 }
347
348 pub fn delete_with_ns<N, K>(&self, ns: Option<&N>, key: &K) -> Result<(), Error>
350 where
351 N: Serialize,
352 K: Serialize,
353 {
354 let ns = match ns {
355 Some(ns) => Some(hashkey::to_key(ns)?.normalize()),
356 None => None,
357 };
358
359 let key = self.key_with_ns(ns.as_ref(), key)?;
360 self.inner.db.remove(key)?;
361 Ok(())
362 }
363
364 pub fn list_json(&self) -> Result<Vec<JsonEntry>, Error> {
366 let mut out = Vec::new();
367
368 for result in self.inner.db.range::<&[u8], _>(..) {
369 let (key, value) = result?;
370
371 let key: json::Value = match cbor::from_slice(&key) {
372 Ok(key) => key,
373 Err(_) => continue,
375 };
376
377 let stored = match cbor::from_slice(&value) {
378 Ok(storage) => storage,
379 Err(_) => continue,
381 };
382
383 out.push(JsonEntry { key, stored });
384 }
385
386 Ok(out)
387 }
388
389 fn cleanup(&self) -> Result<(), Error> {
393 let now = Utc::now();
394
395 for result in self.inner.db.range::<&[u8], _>(..) {
396 let (key, value) = result?;
397
398 let entry: PartialStoredEntry = match cbor::from_slice(&value) {
399 Ok(entry) => entry,
400 Err(e) => {
401 if log::log_enabled!(log::Level::Trace) {
402 log::warn!(
403 "{}: failed to load: {}: {}",
404 KeyFormat(&key),
405 e,
406 KeyFormat(&value)
407 );
408 } else {
409 log::warn!("{}: failed to load: {}", KeyFormat(&key), e);
410 }
411
412 self.inner.db.remove(key)?;
414 continue;
415 }
416 };
417
418 if entry.is_expired(now) {
419 self.inner.db.remove(key)?;
420 }
421 }
422
423 Ok(())
424 }
425
426 pub fn namespaced<N>(&self, ns: &N) -> Result<Self, Error>
432 where
433 N: Serialize,
434 {
435 Ok(Self {
436 inner: Arc::new(Inner {
437 ns: Some(hashkey::to_key(ns)?.normalize()),
438 db: self.inner.db.clone(),
439 wakers: Default::default(),
440 }),
441 })
442 }
443
444 pub fn insert<K, T>(&self, key: K, age: Duration, value: &T) -> Result<(), Error>
446 where
447 K: Serialize,
448 T: Serialize,
449 {
450 let key = self.key(&key)?;
451 self.inner_insert(&key, age, value)
452 }
453
454 #[inline(always)]
456 fn inner_insert<T>(&self, key: &Vec<u8>, age: Duration, value: &T) -> Result<(), Error>
457 where
458 T: Serialize,
459 {
460 let expires_at = Utc::now() + age;
461
462 let value = match cbor::to_vec(&StoredEntryRef { expires_at, value }) {
463 Ok(value) => value,
464 Err(e) => {
465 log::trace!("store:{} *errored*", KeyFormat(key));
466 return Err(e.into());
467 }
468 };
469
470 log::trace!("store:{}", KeyFormat(key));
471 self.inner.db.insert(key, value)?;
472 Ok(())
473 }
474
475 pub fn test<K>(&self, key: K) -> Result<State<()>, Error>
477 where
478 K: Serialize,
479 {
480 let key = self.key(&key)?;
481 self.inner_test(&key)
482 }
483
484 #[inline(always)]
486 fn inner_test(&self, key: &[u8]) -> Result<State<()>, Error> {
487 let value = match self.inner.db.get(key)? {
488 Some(value) => value,
489 None => {
490 log::trace!("test:{} -> null (missing)", KeyFormat(key));
491 return Ok(State::Missing);
492 }
493 };
494
495 let stored: PartialStoredEntry = match cbor::from_slice(&value) {
496 Ok(value) => value,
497 Err(e) => {
498 if log::log_enabled!(log::Level::Trace) {
499 log::warn!(
500 "{}: failed to deserialize: {}: {}",
501 KeyFormat(key),
502 e,
503 KeyFormat(&value)
504 );
505 } else {
506 log::warn!("{}: failed to deserialize: {}", KeyFormat(key), e);
507 }
508
509 log::trace!("test:{} -> null (deserialize error)", KeyFormat(key));
510 return Ok(State::Missing);
511 }
512 };
513
514 if stored.is_expired(Utc::now()) {
515 log::trace!("test:{} -> null (expired)", KeyFormat(key));
516 return Ok(State::Expired(stored.into_stored_entry()));
517 }
518
519 log::trace!("test:{} -> *value*", KeyFormat(key));
520 Ok(State::Fresh(stored.into_stored_entry()))
521 }
522
523 pub fn get<K, T>(&self, key: K) -> Result<State<T>, Error>
525 where
526 K: Serialize,
527 T: serde::de::DeserializeOwned,
528 {
529 let key = self.key(&key)?;
530 self.inner_get(&key)
531 }
532
533 #[inline(always)]
535 fn inner_get<T>(&self, key: &[u8]) -> Result<State<T>, Error>
536 where
537 T: serde::de::DeserializeOwned,
538 {
539 let value = match self.inner.db.get(key)? {
540 Some(value) => value,
541 None => {
542 log::trace!("load:{} -> null (missing)", KeyFormat(key));
543 return Ok(State::Missing);
544 }
545 };
546
547 let stored: StoredEntry<T> = match cbor::from_slice(&value) {
548 Ok(value) => value,
549 Err(e) => {
550 if log::log_enabled!(log::Level::Trace) {
551 log::warn!(
552 "{}: failed to deserialize: {}: {}",
553 KeyFormat(key),
554 e,
555 KeyFormat(&value)
556 );
557 } else {
558 log::warn!("{}: failed to deserialize: {}", KeyFormat(key), e);
559 }
560
561 log::trace!("load:{} -> null (deserialize error)", KeyFormat(key));
562 return Ok(State::Missing);
563 }
564 };
565
566 if stored.is_expired(Utc::now()) {
567 log::trace!("load:{} -> null (expired)", KeyFormat(key));
568 return Ok(State::Expired(stored));
569 }
570
571 log::trace!("load:{} -> *value*", KeyFormat(key));
572 Ok(State::Fresh(stored))
573 }
574
575 fn waker(&self, key: &[u8]) -> Arc<Waker> {
577 let wakers = self.inner.wakers.read();
578
579 match wakers.get(key) {
580 Some(waker) => return waker.clone(),
581 None => drop(wakers),
582 }
583
584 self.inner
585 .wakers
586 .write()
587 .entry(key.to_vec())
588 .or_default()
589 .clone()
590 }
591
592 pub async fn wrap<K, F, T, E>(&self, key: K, age: Duration, future: F) -> Result<T, E>
594 where
595 K: Serialize,
596 F: Future<Output = Result<T, E>>,
597 T: Serialize + serde::de::DeserializeOwned,
598 E: From<Error>,
599 {
600 let key = self.key(&key)?;
601
602 loop {
603 if let State::Fresh(e) = self.inner_get(&key)? {
608 return Ok(e.value);
609 }
610
611 let waker = self.waker(&key);
612
613 if waker.pending.fetch_add(1, Ordering::AcqRel) > 0 {
615 let (tx, rx) = oneshot::channel();
616 waker.channels.push(tx);
617
618 let result = rx.await;
619
620 match result {
622 Ok(true) => return Err(E::from(Error::Failed)),
623 Err(oneshot::Canceled) | Ok(false) => continue,
624 }
625 }
626
627 if let State::Fresh(e) = self.inner_get(&key)? {
634 waker.cleanup(false);
635 return Ok(e.value);
636 }
637
638 let result = Guard::new(|| waker.cleanup(false)).wrap(future).await;
640
641 match result {
644 Ok(output) => {
645 self.inner_insert(&key, age, &output)?;
646 waker.cleanup(false);
647 return Ok(output);
648 }
649 Err(e) => {
650 waker.cleanup(true);
651 return Err(e);
652 }
653 }
654 }
655
656 struct Guard<F>
658 where
659 F: FnMut(),
660 {
661 f: F,
662 }
663
664 impl<F> Guard<F>
665 where
666 F: FnMut(),
667 {
668 pub fn new(f: F) -> Self {
670 Self { f }
671 }
672
673 pub async fn wrap<O>(self, future: O) -> O::Output
675 where
676 O: Future,
677 {
678 let result = future.await;
679 std::mem::forget(self);
680 result
681 }
682 }
683
684 impl<F> Drop for Guard<F>
685 where
686 F: FnMut(),
687 {
688 fn drop(&mut self) {
689 (self.f)();
690 }
691 }
692 }
693
694 fn key<T>(&self, key: &T) -> Result<Vec<u8>, Error>
696 where
697 T: Serialize,
698 {
699 self.key_with_ns(self.inner.ns.as_ref(), key)
700 }
701
702 fn key_with_ns<T>(&self, ns: Option<&hashkey::Key>, key: &T) -> Result<Vec<u8>, Error>
704 where
705 T: Serialize,
706 {
707 let key = hashkey::to_key(key)?.normalize();
708 let key = Key(ns, key);
709 return Ok(cbor::to_vec(&key)?);
710
711 #[derive(Serialize)]
712 struct Key<'a>(Option<&'a hashkey::Key>, hashkey::Key);
713 }
714}
715
716struct KeyFormat<'a>(&'a [u8]);
718
719impl fmt::Display for KeyFormat<'_> {
720 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
721 let value = match cbor::from_slice::<cbor::Value>(self.0) {
722 Ok(value) => value,
723 Err(_) => return self.0.encode_hex::<String>().fmt(fmt),
724 };
725
726 let value = match json::to_string(&value) {
727 Ok(value) => value,
728 Err(_) => return self.0.encode_hex::<String>().fmt(fmt),
729 };
730
731 value.fmt(fmt)
732 }
733}
734
#[cfg(test)]
mod tests {
    use super::{Cache, Duration, Error};
    use std::{
        error,
        sync::{Arc, Barrier},
        thread,
    };
    use tempdir::TempDir;

    /// Opens a sled tree inside a fresh temporary directory.
    ///
    /// The `TempDir` is returned alongside the tree and must be kept alive
    /// (bind it to a named variable such as `_dir`, not `_`) for as long as
    /// the tree is in use: dropping it deletes the directory.
    fn db(name: &str) -> Result<(TempDir, sled::Tree), Box<dyn error::Error>> {
        // `TempDir::new` creates the directory itself.
        let dir = TempDir::new(name)?;
        let db = sled::open(dir.path())?;
        let tree = db.open_tree("test")?;
        Ok((dir, tree))
    }

    #[test]
    fn test_cached() -> Result<(), Box<dyn error::Error>> {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let (_dir, db) = db("test_cached")?;
        let cache = Cache::load(db)?;

        let count = Arc::new(AtomicUsize::default());

        let c = count.clone();

        let op1 = cache.wrap("a", Duration::hours(12), async move {
            let _ = c.fetch_add(1, Ordering::SeqCst);
            Ok::<_, Error>(String::from("foo"))
        });

        let c = count.clone();

        let op2 = cache.wrap("a", Duration::hours(12), async move {
            let _ = c.fetch_add(1, Ordering::SeqCst);
            Ok::<_, Error>(String::from("foo"))
        });

        ::futures::executor::block_on(async move {
            let (a, b) = ::futures::future::join(op1, op2).await;
            assert_eq!("foo", a.expect("ok result"));
            assert_eq!("foo", b.expect("ok result"));
            assert_eq!(1, count.load(Ordering::SeqCst));
        });

        Ok(())
    }

    #[test]
    fn test_contended() -> Result<(), Box<dyn error::Error>> {
        use crossbeam::queue::SegQueue;
        use std::sync::atomic::{AtomicUsize, Ordering};

        const THREAD_COUNT: usize = 1_000;

        let (_dir, db) = db("test_contended")?;
        let cache = Cache::load(db)?;

        // All workers plus the main thread meet here, so every `wrap` call
        // starts as close together as possible without busy-spinning.
        let start = Arc::new(Barrier::new(THREAD_COUNT + 1));
        let count = Arc::new(AtomicUsize::default());
        let results = Arc::new(SegQueue::new());
        let mut threads = Vec::with_capacity(THREAD_COUNT);

        for _ in 0..THREAD_COUNT {
            let start = start.clone();
            let cache = cache.clone();
            let results = results.clone();
            let count = count.clone();

            let t = thread::spawn(move || {
                let op = cache.wrap("a", Duration::hours(12), async move {
                    let _ = count.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, Error>(String::from("foo"))
                });

                start.wait();

                ::futures::executor::block_on(async move {
                    results.push(op.await);
                });
            });

            threads.push(t);
        }

        start.wait();

        for t in threads {
            t.join().expect("thread to join");
        }

        assert_eq!(1, count.load(Ordering::SeqCst));

        // Every caller must have observed the single computed value.
        let mut completed = 0;

        while let Some(result) = results.pop() {
            assert_eq!("foo", result.expect("ok result"));
            completed += 1;
        }

        assert_eq!(THREAD_COUNT, completed);
        Ok(())
    }

    #[test]
    fn test_guards() -> Result<(), Box<dyn error::Error>> {
        use self::futures::PollOnce;
        use ::futures::channel::oneshot;
        use std::sync::atomic::Ordering;

        let (_dir, db) = db("test_guards")?;
        let cache = Cache::load(db)?;

        ::futures::executor::block_on(async move {
            let (op1_tx, op1_rx) = oneshot::channel::<()>();

            let op1 = cache.wrap("a", Duration::hours(12), async move {
                let _ = op1_rx.await;
                Ok::<_, Error>(String::from("foo"))
            });

            pin_utils::pin_mut!(op1);

            let (op2_tx, op2_rx) = oneshot::channel::<()>();

            let op2 = cache.wrap("a", Duration::hours(12), async move {
                let _ = op2_rx.await;
                Ok::<_, Error>(String::from("foo"))
            });

            pin_utils::pin_mut!(op2);

            assert!(PollOnce::new(&mut op1).await.is_none());

            let k = cache.key(&"a")?;
            let waker = cache.inner.wakers.read().get(&k).cloned();
            assert!(waker.is_some());
            let waker = waker.expect("waker to be registered");

            assert_eq!(1, waker.pending.load(Ordering::SeqCst));
            assert!(PollOnce::new(&mut op2).await.is_none());
            assert_eq!(2, waker.pending.load(Ordering::SeqCst));

            op1_tx.send(()).expect("send to op1");
            op2_tx.send(()).expect("send to op2");

            assert!(PollOnce::new(&mut op1).await.is_some());
            assert_eq!(0, waker.pending.load(Ordering::SeqCst));
            assert!(PollOnce::new(&mut op2).await.is_some());

            Ok(())
        })
    }

    mod futures {
        use std::{
            future::Future,
            pin::Pin,
            task::{Context, Poll},
        };

        /// Polls the wrapped future exactly once, yielding `Some(output)` if
        /// it completed and `None` if it was still pending.
        pub struct PollOnce<F> {
            future: F,
        }

        impl<F> PollOnce<F> {
            pub fn new(future: F) -> Self {
                Self { future }
            }
        }

        impl<F> PollOnce<F> {
            pin_utils::unsafe_pinned!(future: F);
        }

        impl<F> Future for PollOnce<F>
        where
            F: Future,
        {
            type Output = Option<F::Output>;

            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                match self.future().poll(cx) {
                    Poll::Ready(output) => Poll::Ready(Some(output)),
                    Poll::Pending => Poll::Ready(None),
                }
            }
        }
    }
}