#![feature(async_await)]
#![deny(missing_docs)]
use chrono::{DateTime, Utc};
use crossbeam::queue::SegQueue;
use futures::channel::oneshot;
use hashbrown::HashMap;
use hex::ToHex as _;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use serde_cbor as cbor;
use serde_hashkey as hashkey;
use serde_json as json;
use std::{
error, fmt,
future::Future,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};
pub use chrono::Duration;
pub use sled;
/// Errors that can be raised by the cache.
#[derive(Debug)]
pub enum Error {
    /// Failed to CBOR-encode or -decode a stored entry or key.
    Cbor(cbor::error::Error),
    /// Failed to convert a serializable key into a hashable key.
    HashKey(hashkey::Error),
    /// Failed to JSON-encode a value (used when listing entries).
    Json(json::error::Error),
    /// An underlying database operation failed.
    Sled(sled::Error),
    /// A coalesced operation failed in the task that was driving it.
    Failed,
}
impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::Cbor(e) => write!(fmt, "CBOR error: {}", e),
Error::HashKey(e) => write!(fmt, "HashKey error: {}", e),
Error::Json(e) => write!(fmt, "JSON error: {}", e),
Error::Sled(e) => write!(fmt, "Database error: {}", e),
Error::Failed => write!(fmt, "Operation failed"),
}
}
}
impl error::Error for Error {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
match self {
Error::Cbor(e) => Some(e),
Error::HashKey(e) => Some(e),
Error::Json(e) => Some(e),
Error::Sled(e) => Some(e),
_ => None,
}
}
}
impl From<json::error::Error> for Error {
fn from(error: json::error::Error) -> Self {
Error::Json(error)
}
}
impl From<cbor::error::Error> for Error {
fn from(error: cbor::error::Error) -> Self {
Error::Cbor(error)
}
}
impl From<hashkey::Error> for Error {
fn from(error: hashkey::Error) -> Self {
Error::HashKey(error)
}
}
impl From<sled::Error> for Error {
fn from(error: sled::Error) -> Self {
Error::Sled(error)
}
}
/// The lookup state of a cached entry.
pub enum State<T> {
    /// The entry exists and has not yet expired.
    Fresh(StoredEntry<T>),
    /// The entry exists but its deadline has passed; the stale value is
    /// still carried so callers may use it if they wish.
    Expired(StoredEntry<T>),
    /// No entry was stored under the key (or it failed to deserialize).
    Missing,
}
impl<T> State<T> {
    /// Consumes the state, yielding the stored value regardless of whether
    /// it was fresh or expired. Returns `None` only when missing.
    pub fn get(self) -> Option<T> {
        match self {
            State::Missing => None,
            State::Fresh(entry) | State::Expired(entry) => Some(entry.value),
        }
    }
}
/// A single cache entry rendered as JSON, as produced by `Cache::list_json`.
#[derive(Debug, Serialize, Deserialize)]
pub struct JsonEntry {
    /// The storage key, decoded from CBOR into JSON.
    pub key: serde_json::Value,
    /// The stored entry; flattened so `expires_at` and `value` appear at the
    /// top level of the serialized object.
    #[serde(flatten)]
    pub stored: StoredEntry<serde_json::Value>,
}
/// A value as stored in the database, paired with its expiration deadline.
#[derive(Debug, Serialize, Deserialize)]
pub struct StoredEntry<T> {
    // The instant after which the entry is considered expired.
    expires_at: DateTime<Utc>,
    // The cached value itself.
    value: T,
}
/// Borrowing variant of `StoredEntry`, used when serializing a value for
/// insertion without taking ownership of it. Must serialize to the same
/// shape as `StoredEntry` so reads round-trip.
#[derive(Debug, Serialize)]
pub struct StoredEntryRef<'a, T> {
    expires_at: DateTime<Utc>,
    value: &'a T,
}
impl<T> StoredEntry<T> {
    /// Tests whether this entry's deadline has passed at the instant `now`.
    fn is_expired(&self, now: DateTime<Utc>) -> bool {
        now > self.expires_at
    }
}
/// A stored entry decoded only far enough to read its expiration metadata,
/// without touching (or needing a type for) the stored value.
#[derive(Debug, Serialize, Deserialize)]
struct PartialStoredEntry {
    expires_at: DateTime<Utc>,
}
impl PartialStoredEntry {
fn is_expired(&self, now: DateTime<Utc>) -> bool {
self.expires_at < now
}
fn into_stored_entry(self) -> StoredEntry<()> {
StoredEntry {
expires_at: self.expires_at,
value: (),
}
}
}
/// Coordination state for tasks coalescing on the same cache key.
#[derive(Default)]
struct Waker {
    // Number of tasks currently participating for this key (the one driving
    // the operation plus any waiting on channels).
    pending: AtomicUsize,
    // One-shot senders for the waiting tasks; the boolean sent indicates
    // whether the driving operation errored.
    channels: SegQueue<oneshot::Sender<bool>>,
}
impl Waker {
    /// Notifies every waiting task and drives the pending counter back to
    /// zero. `error` is forwarded to each waiter: `true` means the driving
    /// operation failed, `false` that the waiter should re-check the cache.
    fn cleanup(&self, error: bool) {
        let mut previous = self.pending.load(Ordering::Acquire);
        loop {
            // While more than our own registration is pending, drain the
            // queued channels and subtract however many we notified.
            while previous > 1 {
                let mut received = 0usize;
                // Old crossbeam API: `pop` returns `Err` when empty. A send
                // failure just means the waiter dropped its receiver.
                while let Ok(waker) = self.channels.pop() {
                    received += 1;
                    let _ = waker.send(error);
                }
                // NOTE: fetch_sub returns the value *before* subtraction, so
                // `previous` is a stale count; the inner loop runs again (with
                // an empty queue, received == 0) until the counter reads 1.
                // If a racer has done fetch_add but not yet pushed its
                // channel, this spins until the push becomes visible.
                previous = self.pending.fetch_sub(received, Ordering::AcqRel);
            }
            // Try to retire our own registration (1 -> 0). If the swap fails,
            // a new participant raced in and we go around again.
            // NOTE(review): assumes pending >= 1 on entry (every caller did a
            // fetch_add first); a count of 0 here would spin forever.
            previous = self.pending.compare_and_swap(1, 0, Ordering::AcqRel);
            if previous == 1 {
                break;
            }
        }
    }
}
/// Shared interior of a `Cache` handle.
struct Inner {
    // Optional namespace prefix folded into every composed key.
    ns: Option<hashkey::Key>,
    // The sled tree backing the cache.
    db: Arc<sled::Tree>,
    // Per-key coalescing state for in-flight `wrap` operations, keyed by the
    // composed (CBOR-encoded) storage key.
    wakers: RwLock<HashMap<Vec<u8>, Arc<Waker>>>,
}
/// A cheaply clonable handle to a sled-backed cache of expiring,
/// CBOR-serialized entries.
#[derive(Clone)]
pub struct Cache {
    inner: Arc<Inner>,
}
impl Cache {
    /// Loads a cache backed by the given database tree, purging expired or
    /// unreadable entries as part of loading.
    pub fn load(db: Arc<sled::Tree>) -> Result<Cache, Error> {
        let cache = Cache {
            inner: Arc::new(Inner {
                ns: None,
                db,
                wakers: Default::default(),
            }),
        };
        cache.cleanup()?;
        Ok(cache)
    }

    /// Deletes the entry under `key`, scoped by the explicitly provided
    /// namespace (`None` addresses the un-namespaced key space).
    pub fn delete_with_ns<N, K>(&self, ns: Option<&N>, key: &K) -> Result<(), Error>
    where
        N: Serialize,
        K: Serialize,
    {
        // Convert the namespace (if any) into a hash key before composing
        // the full storage key.
        let ns = match ns {
            Some(ns) => Some(hashkey::to_key(ns)?),
            None => None,
        };
        let key = self.key_with_ns(ns.as_ref(), key)?;
        self.inner.db.del(&key)?;
        Ok(())
    }

    /// Lists every entry in the tree as JSON, silently skipping entries
    /// whose key or value can no longer be deserialized.
    pub fn list_json(&self) -> Result<Vec<JsonEntry>, Error> {
        let mut out = Vec::new();
        for result in self.inner.db.range::<&[u8], _>(..) {
            let (key, value) = result?;
            let key: json::Value = match cbor::from_slice(&*key) {
                Ok(key) => key,
                // Skip unreadable keys rather than failing the whole listing.
                Err(_) => continue,
            };
            let stored = match cbor::from_slice(&*value) {
                Ok(storage) => storage,
                Err(_) => continue,
            };
            out.push(JsonEntry { key, stored });
        }
        Ok(out)
    }

    /// Scans the whole tree, deleting entries that have expired or that fail
    /// to deserialize (after logging a warning for the latter).
    fn cleanup(&self) -> Result<(), Error> {
        let now = Utc::now();
        for result in self.inner.db.range::<&[u8], _>(..) {
            let (key, value) = result?;
            // Only the expiration metadata is decoded; the value payload is
            // never touched here.
            let entry: PartialStoredEntry = match cbor::from_slice(&*value) {
                Ok(entry) => entry,
                Err(e) => {
                    // At trace level, also include the raw stored bytes.
                    if log::log_enabled!(log::Level::Trace) {
                        log::warn!(
                            "{}: failed to load: {}: {}",
                            KeyFormat(&*key),
                            e,
                            KeyFormat(&*value)
                        );
                    } else {
                        log::warn!("{}: failed to load: {}", KeyFormat(&*key), e);
                    }
                    // Unreadable entries are purged so they don't accumulate.
                    self.inner.db.del(key)?;
                    continue;
                }
            };
            if entry.is_expired(now) {
                self.inner.db.del(key)?;
            }
        }
        Ok(())
    }

    /// Creates a handle over the same tree whose keys are scoped under the
    /// given serializable namespace.
    ///
    /// NOTE(review): the new handle gets a fresh `wakers` map, so `wrap`
    /// calls through it don't coordinate with the parent handle — presumably
    /// fine since the key spaces are disjoint; confirm.
    pub fn namespaced<N>(&self, ns: &N) -> Result<Self, Error>
    where
        N: Serialize,
    {
        Ok(Self {
            inner: Arc::new(Inner {
                ns: Some(hashkey::to_key(ns)?),
                db: self.inner.db.clone(),
                wakers: Default::default(),
            }),
        })
    }

    /// Inserts `value` under `key`, expiring `age` from now.
    pub fn insert<K, T>(&self, key: K, age: Duration, value: &T) -> Result<(), Error>
    where
        K: Serialize,
        T: Serialize,
    {
        let key = self.key(&key)?;
        self.inner_insert(&key, age, value)
    }

    /// Stores a CBOR-encoded `StoredEntryRef` under an already-composed key.
    #[inline(always)]
    fn inner_insert<T>(&self, key: &Vec<u8>, age: Duration, value: &T) -> Result<(), Error>
    where
        T: Serialize,
    {
        let expires_at = Utc::now() + age;
        let value = match cbor::to_vec(&StoredEntryRef { expires_at, value }) {
            Ok(value) => value,
            Err(e) => {
                log::trace!("store:{} *errored*", KeyFormat(key));
                return Err(e.into());
            }
        };
        log::trace!("store:{}", KeyFormat(key));
        self.inner.db.set(key, value)?;
        Ok(())
    }

    /// Tests the state of the entry under `key` without deserializing its
    /// value payload.
    pub fn test<K>(&self, key: K) -> Result<State<()>, Error>
    where
        K: Serialize,
    {
        let key = self.key(&key)?;
        self.inner_test(&key)
    }

    /// Like `inner_get`, but only decodes the expiration metadata. Note
    /// that, unlike `cleanup`, a deserialization failure here reports
    /// `Missing` but does not delete the entry.
    #[inline(always)]
    fn inner_test(&self, key: &[u8]) -> Result<State<()>, Error> {
        let value = match self.inner.db.get(&key)? {
            Some(value) => value,
            None => {
                log::trace!("test:{} -> null (missing)", KeyFormat(key));
                return Ok(State::Missing);
            }
        };
        let stored: PartialStoredEntry = match cbor::from_slice(&value) {
            Ok(value) => value,
            Err(e) => {
                if log::log_enabled!(log::Level::Trace) {
                    log::warn!(
                        "{}: failed to deserialize: {}: {}",
                        KeyFormat(key),
                        e,
                        KeyFormat(&value)
                    );
                } else {
                    log::warn!("{}: failed to deserialize: {}", KeyFormat(key), e);
                }
                log::trace!("test:{} -> null (deserialize error)", KeyFormat(key));
                return Ok(State::Missing);
            }
        };
        if stored.is_expired(Utc::now()) {
            log::trace!("test:{} -> null (expired)", KeyFormat(key));
            return Ok(State::Expired(stored.into_stored_entry()));
        }
        log::trace!("test:{} -> *value*", KeyFormat(key));
        Ok(State::Fresh(stored.into_stored_entry()))
    }

    /// Loads the entry under `key`, reporting whether it is fresh, expired,
    /// or missing. Undeserializable entries are reported as `Missing`.
    pub fn get<K, T>(&self, key: K) -> Result<State<T>, Error>
    where
        K: Serialize,
        T: serde::de::DeserializeOwned,
    {
        let key = self.key(&key)?;
        self.inner_get(&key)
    }

    /// Loads and fully decodes the entry under an already-composed key.
    #[inline(always)]
    fn inner_get<T>(&self, key: &[u8]) -> Result<State<T>, Error>
    where
        T: serde::de::DeserializeOwned,
    {
        let value = match self.inner.db.get(key)? {
            Some(value) => value,
            None => {
                log::trace!("load:{} -> null (missing)", KeyFormat(key));
                return Ok(State::Missing);
            }
        };
        let stored: StoredEntry<T> = match cbor::from_slice(&value) {
            Ok(value) => value,
            Err(e) => {
                if log::log_enabled!(log::Level::Trace) {
                    log::warn!(
                        "{}: failed to deserialize: {}: {}",
                        KeyFormat(key),
                        e,
                        KeyFormat(&value)
                    );
                } else {
                    log::warn!("{}: failed to deserialize: {}", KeyFormat(key), e);
                }
                log::trace!("load:{} -> null (deserialize error)", KeyFormat(key));
                return Ok(State::Missing);
            }
        };
        if stored.is_expired(Utc::now()) {
            log::trace!("load:{} -> null (expired)", KeyFormat(key));
            return Ok(State::Expired(stored));
        }
        log::trace!("load:{} -> *value*", KeyFormat(key));
        Ok(State::Fresh(stored))
    }

    /// Gets (or lazily creates) the coalescing waker for the given composed
    /// key. Takes the read lock first and only upgrades to a write lock on
    /// a miss.
    fn waker(&self, key: &[u8]) -> Arc<Waker> {
        let wakers = self.inner.wakers.read();
        match wakers.get(key) {
            Some(waker) => return waker.clone(),
            // Drop the read guard before taking the write lock below.
            None => drop(wakers),
        }
        self.inner
            .wakers
            .write()
            .entry(key.to_vec())
            .or_default()
            .clone()
    }

    /// Returns a cached value under `key`, or drives `future` to produce and
    /// store it (expiring after `age`). Concurrent calls for the same key are
    /// coalesced: one task runs the future while the rest wait.
    pub async fn wrap<'a, K, F, T, E>(&'a self, key: K, age: Duration, future: F) -> Result<T, E>
    where
        K: Serialize,
        F: Future<Output = Result<T, E>>,
        T: Serialize + serde::de::DeserializeOwned,
        E: From<Error>,
    {
        let key = self.key(&key)?;
        loop {
            // Fast path: a fresh value is already stored.
            if let State::Fresh(e) = self.inner_get(&key)? {
                return Ok(e.value);
            }
            let waker = self.waker(&key);
            // A prior count > 0 means another task is already driving the
            // operation: register a channel and wait for its outcome.
            if waker.pending.fetch_add(1, Ordering::AcqRel) > 0 {
                let (tx, rx) = oneshot::channel();
                waker.channels.push(tx);
                let result = rx.await;
                match result {
                    // `true` signals the driver errored; propagate as Failed.
                    Ok(true) => return Err(E::from(Error::Failed)),
                    // `false` or a dropped driver: retry from the top.
                    Err(oneshot::Canceled) | Ok(false) => continue,
                }
            }
            // We are the driver. Re-check the cache: a previous driver may
            // have stored a value between our miss and our registration.
            if let State::Fresh(e) = self.inner_get(&key)? {
                waker.cleanup(false);
                return Ok(e.value);
            }
            // The guard releases waiters if the future is dropped mid-poll
            // (e.g. cancellation); on normal completion it is forgotten.
            let result = Guard::new(|| waker.cleanup(false)).wrap(future).await;
            match result {
                Ok(output) => {
                    self.inner_insert(&key, age, &output)?;
                    waker.cleanup(false);
                    return Ok(output);
                }
                Err(e) => {
                    // Signal waiters that the operation failed.
                    waker.cleanup(true);
                    return Err(e);
                }
            }
        }

        // Drop guard: runs `f` unless disarmed via `mem::forget` in `wrap`.
        struct Guard<F>
        where
            F: FnMut(),
        {
            f: F,
        }

        impl<F> Guard<F>
        where
            F: FnMut(),
        {
            pub fn new(f: F) -> Self {
                Self { f }
            }

            /// Awaits `future`; if it completes, the guard is forgotten so
            /// its cleanup does not run a second time.
            pub async fn wrap<O>(self, future: O) -> O::Output
            where
                O: Future,
            {
                let result = future.await;
                std::mem::forget(self);
                result
            }
        }

        impl<F> Drop for Guard<F>
        where
            F: FnMut(),
        {
            fn drop(&mut self) {
                (self.f)();
            }
        }
    }

    /// Composes the storage key for `key` under this handle's namespace.
    fn key<T>(&self, key: &T) -> Result<Vec<u8>, Error>
    where
        T: Serialize,
    {
        self.key_with_ns(self.inner.ns.as_ref(), key)
    }

    /// Composes a storage key as the CBOR encoding of `(namespace, key)`.
    fn key_with_ns<T>(&self, ns: Option<&hashkey::Key>, key: &T) -> Result<Vec<u8>, Error>
    where
        T: Serialize,
    {
        let key = Key(ns, hashkey::to_key(key)?);
        return Ok(cbor::to_vec(&key)?);

        // Item declaration after `return` is fine: items are hoisted.
        #[derive(Serialize)]
        struct Key<'a>(Option<&'a hashkey::Key>, hashkey::Key);
    }
}
/// Display helper that renders a raw storage key as JSON when possible and
/// as hex otherwise.
struct KeyFormat<'a>(&'a [u8]);
impl fmt::Display for KeyFormat<'_> {
    /// Renders the key bytes as JSON (via a CBOR round-trip) when they
    /// decode cleanly, falling back to a hex dump otherwise.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let decoded = cbor::from_slice::<cbor::Value>(self.0)
            .ok()
            .and_then(|value| json::to_string(&value).ok());

        match decoded {
            // Delegate to the string's Display so formatter flags apply.
            Some(text) => text.fmt(fmt),
            None => self.0.write_hex(fmt),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{Cache, Duration, Error};
    use std::{error, fs, sync::Arc, thread};
    use tempdir::TempDir;

    /// Opens a fresh sled tree inside a temporary directory for one test.
    fn db(name: &str) -> Result<Arc<sled::Tree>, Box<dyn error::Error>> {
        let path = TempDir::new(name)?;
        let path = path.path();
        if !path.is_dir() {
            fs::create_dir_all(path)?;
        }
        let db = sled::Db::start_default(path)?;
        Ok(db.open_tree("test")?)
    }

    /// Two concurrent `wrap` calls for the same key must run the underlying
    /// future exactly once and both observe its result.
    #[test]
    fn test_cached() -> Result<(), Box<dyn error::Error>> {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let db = db("test_cached")?;
        let cache = Cache::load(db)?;
        let count = Arc::new(AtomicUsize::default());
        let c = count.clone();
        let op1 = cache.wrap("a", Duration::hours(12), async move {
            let _ = c.fetch_add(1, Ordering::SeqCst);
            Ok::<_, Error>(String::from("foo"))
        });
        let c = count.clone();
        let op2 = cache.wrap("a", Duration::hours(12), async move {
            let _ = c.fetch_add(1, Ordering::SeqCst);
            Ok::<_, Error>(String::from("foo"))
        });
        ::futures::executor::block_on(async move {
            let (a, b) = ::futures::future::join(op1, op2).await;
            assert_eq!("foo", a.expect("ok result"));
            assert_eq!("foo", b.expect("ok result"));
            // The future body must have executed exactly once.
            assert_eq!(1, count.load(Ordering::SeqCst));
        });
        Ok(())
    }

    /// Stress test: many threads racing on the same key must still execute
    /// the underlying future exactly once.
    #[test]
    fn test_contended() -> Result<(), Box<dyn error::Error>> {
        use crossbeam::queue::SegQueue;
        use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
        const THREAD_COUNT: usize = 1_000;
        let db = db("test_contended")?;
        let cache = Cache::load(db)?;
        let started = Arc::new(AtomicBool::new(false));
        let count = Arc::new(AtomicUsize::default());
        let results = Arc::new(SegQueue::new());
        let mut threads = Vec::with_capacity(THREAD_COUNT);
        for _ in 0..THREAD_COUNT {
            let started = started.clone();
            let cache = cache.clone();
            let results = results.clone();
            let count = count.clone();
            let t = thread::spawn(move || {
                let op = cache.wrap("a", Duration::hours(12), async move {
                    let _ = count.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, Error>(String::from("foo"))
                });
                // Spin until all threads are released at once to maximize
                // contention.
                while !started.load(Ordering::Acquire) {}
                ::futures::executor::block_on(async move {
                    results.push(op.await);
                });
            });
            threads.push(t);
        }
        started.store(true, Ordering::Release);
        for t in threads {
            t.join().expect("thread to join");
        }
        assert_eq!(1, count.load(Ordering::SeqCst));
        Ok(())
    }

    /// Polls two coalescing operations step by step and checks the waker's
    /// pending count at each stage of the handshake.
    #[test]
    fn test_guards() -> Result<(), Box<dyn error::Error>> {
        use self::futures::PollOnce;
        use ::futures::channel::oneshot;
        use std::sync::atomic::Ordering;
        let db = db("test_guards")?;
        let cache = Cache::load(db)?;
        ::futures::executor::block_on(async move {
            let (op1_tx, op1_rx) = oneshot::channel::<()>();
            let op1 = cache.wrap("a", Duration::hours(12), async move {
                let _ = op1_rx.await;
                Ok::<_, Error>(String::from("foo"))
            });
            pin_utils::pin_mut!(op1);
            let (op2_tx, op2_rx) = oneshot::channel::<()>();
            let op2 = cache.wrap("a", Duration::hours(12), async move {
                let _ = op2_rx.await;
                Ok::<_, Error>(String::from("foo"))
            });
            pin_utils::pin_mut!(op2);
            // First poll of op1 registers it as the driver.
            assert!(PollOnce::new(&mut op1).await.is_none());
            let k = cache.key(&"a")?;
            let waker = cache.inner.wakers.read().get(&k).cloned();
            assert!(waker.is_some());
            let waker = waker.expect("waker to be registered");
            assert_eq!(1, waker.pending.load(Ordering::SeqCst));
            // op2 registers as a waiter, bumping the pending count.
            assert!(PollOnce::new(&mut op2).await.is_none());
            assert_eq!(2, waker.pending.load(Ordering::SeqCst));
            op1_tx.send(()).expect("send to op1");
            op2_tx.send(()).expect("send to op2");
            // Completing op1 must wake the waiter and retire both counts.
            assert!(PollOnce::new(&mut op1).await.is_some());
            assert_eq!(0, waker.pending.load(Ordering::SeqCst));
            assert!(PollOnce::new(&mut op2).await.is_some());
            Ok(())
        })
    }

    /// Minimal future combinator used to poll a future exactly once in tests.
    mod futures {
        use std::{
            future::Future,
            pin::Pin,
            task::{Context, Poll},
        };

        /// Resolves with `Some(output)` if the wrapped future is ready on
        /// the first poll, or `None` if it is still pending.
        pub struct PollOnce<F> {
            future: F,
        }

        impl<F> PollOnce<F> {
            pub fn new(future: F) -> Self {
                Self { future }
            }
        }

        impl<F> PollOnce<F> {
            // Pin projection to the inner future.
            pin_utils::unsafe_pinned!(future: F);
        }

        impl<F> Future for PollOnce<F>
        where
            F: Future,
        {
            type Output = Option<F::Output>;

            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                match self.future().poll(cx) {
                    Poll::Ready(output) => Poll::Ready(Some(output)),
                    // Note: reports Pending as a *ready* None instead of
                    // blocking, which is the point of this helper.
                    Poll::Pending => Poll::Ready(None),
                }
            }
        }
    }
}