// Generated gRPC/protobuf items for the `blip.cache` package, brought in by
// `tonic::include_proto!`. The rest of this file uses `Key`, `Value`,
// `cache_client::CacheClient` and `cache_server::CacheServer` from here.
mod proto {
tonic::include_proto!("blip.cache");
}
use crate::{ExposedService, MeshService, MultiNodeCut, Subscription};
use bytes::Bytes;
use cache_2q::Cache as Cache2q;
use consistent_hash_ring::Ring;
use futures::future::FutureExt;
use once_cell::sync::OnceCell;
use proto::{cache_client::CacheClient, cache_server::CacheServer, Key, Value};
use rand::{thread_rng, Rng};
use std::{
cmp,
collections::{hash_map::Entry, HashMap},
net::SocketAddr,
sync::Arc,
};
use tokio::sync::{Mutex, RwLock, Semaphore};
use tonic::{transport::Channel, Request, Response, Status};
/// A backing store that produces the value for a key.
///
/// The cache calls [`Source::get`] when a key it is responsible for is not
/// already held locally.
#[crate::async_trait]
pub trait Source: Send + Sync + 'static {
    /// Fetches the value associated with `key`.
    ///
    /// # Errors
    ///
    /// Returns a [`Status`] describing why the value could not be produced.
    async fn get(&self, key: &[u8]) -> Result<Vec<u8>, Status>;
}
/// Adapter that lets a plain function or closure act as a [`Source`].
struct FnSource<F>(F);

#[crate::async_trait]
impl<F> Source for FnSource<F>
where
    F: Fn(&[u8]) -> Vec<u8> + Send + Sync + 'static,
{
    /// Invokes the wrapped function; this adapter never reports an error.
    async fn get(&self, key: &[u8]) -> Result<Vec<u8>, Status> {
        let fetch = &self.0;
        Ok(fetch(key))
    }
}
/// Cheaply cloneable handle to the shared cache state.
///
/// The state lives behind an `Arc`, so clones refer to the same underlying
/// cache. `S` is the value source and defaults to the trait object
/// `dyn Source`.
pub struct Cache<S: ?Sized = dyn Source>(Arc<Inner<S>>);
/// Shared state behind every [`Cache`] handle.
struct Inner<S: ?Sized> {
/// Fetches currently in progress, keyed by the requested key, so that
/// concurrent callers asking for the same key share one fetch.
inflight: Mutex<HashMap<Bytes, Arc<Lazy>>>,
/// Most recently received membership cut and the ring of shard addresses
/// built from it.
remote: RwLock<Remote>,
/// Values loaded from `source` for keys that are not routed to a remote
/// shard.
local_keys: Mutex<Cache2q<Bytes, Bytes>>,
/// Values fetched from remote shards; only a random subset of those
/// fetches is stored here.
hot_keys: Mutex<Cache2q<Bytes, Bytes>>,
/// Backing store consulted on a local cache miss. Declared last because it
/// may be unsized (`S: ?Sized`).
source: S,
}
/// Result slot for a single in-progress fetch.
struct Lazy {
/// Created with zero permits; the caller performing the fetch adds permits
/// once `val` has been set, which lets waiting callers proceed.
sem: Semaphore,
/// Outcome of the fetch, written once by the caller that performed it.
val: OnceCell<Result<Bytes, Status>>,
}
/// Role a caller plays for a key's in-progress fetch.
enum Flight {
/// This caller inserted the in-flight entry and performs the fetch itself.
Leader(Arc<Lazy>),
/// Another caller already inserted the entry; this caller waits for its
/// result.
Follower(Arc<Lazy>),
}
/// Cluster view used to route keys to the member that owns them.
#[derive(Default)]
struct Remote {
/// Last membership cut received, or `None` before the first one arrives.
config: Option<MultiNodeCut>,
/// Consistent-hash ring of member addresses, rebuilt on every cut from the
/// members selected by this service's metadata key.
shards: Ring<SocketAddr>,
}
impl<S: ?Sized> Clone for Cache<S> {
#[inline]
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
#[crate::async_trait]
impl MeshService for Cache {
async fn accept(self: Box<Self>, mut cuts: Subscription) {
while let Ok(cut) = cuts.recv().await {
let mut r = self.0.remote.write().await;
r.shards.clear();
r.shards
.extend(cut.with_meta(key!(Self)).map(|m| m.0.addr()));
r.config = Some(cut);
}
}
}
impl ExposedService for Cache {
    type Service = CacheServer<Self>;

    /// Advertises this service by adding its metadata key with an empty
    /// payload.
    #[inline]
    fn add_metadata<K: Extend<(String, Vec<u8>)>>(&self, keys: &mut K) {
        let marker = (key!(Self).to_owned(), Vec::new());
        keys.extend(std::iter::once(marker));
    }

    /// Wraps this cache in the generated gRPC server type.
    #[inline]
    fn into_service(self) -> Self::Service {
        CacheServer::new(self)
    }
}
#[crate::async_trait]
impl proto::cache_server::Cache for Cache {
#[inline]
async fn get(&self, req: Request<Key>) -> Result<Response<Value>, Status> {
self.get(req.into_inner().key)
.await
.map(|b| Value { buf: b.to_vec() })
.map(Response::new)
}
}
impl Cache {
pub fn new<S: Source>(max_keys: usize, source: S) -> Self {
let max_hot = cmp::max(1, max_keys / 8);
let inner = Inner {
inflight: Mutex::default(),
remote: RwLock::default(),
local_keys: Cache2q::new(max_keys).into(),
hot_keys: Cache2q::new(max_hot).into(),
source,
};
Self(Arc::new(inner))
}
pub fn from_fn<F>(max_keys: usize, source: F) -> Self
where F: Sync + Send + 'static + Fn(&[u8]) -> Vec<u8> {
Self::new(max_keys, FnSource(source))
}
pub async fn get<K: Into<Bytes>>(&self, key: K) -> Result<Bytes, Status> {
let key = key.into();
match self.liftoff(key.clone()).await {
Flight::Leader(lazy) => {
let call = self.get_inner(key.clone()).await;
lazy.val.set(call).unwrap();
lazy.sem.add_permits(2 ^ 24);
self.0.inflight.lock().await.remove(&key);
(lazy.val.get().unwrap().as_ref())
.map(|v| v.clone())
.map_err(clone_status)
}
Flight::Follower(lazy) => {
drop(lazy.sem.acquire().await);
(lazy.val.get().unwrap().as_ref())
.map(|v| v.clone())
.map_err(clone_status)
}
}
}
#[inline]
async fn liftoff(&self, key: Bytes) -> Flight {
match self.0.inflight.lock().await.entry(key) {
Entry::Vacant(v) => Flight::Leader(Arc::clone(v.insert(Arc::new(Lazy {
sem: Semaphore::new(0),
val: OnceCell::new(),
})))),
Entry::Occupied(o) => Flight::Follower(Arc::clone(o.get())),
}
}
async fn get_inner(&self, key: Bytes) -> Result<Bytes, Status> {
if let Some(buf) = load(&self.0.hot_keys, &key).await {
return Ok(buf);
}
if let Some(shard) = self.lookup_shard(&key).await {
let mut c = CacheClient::new(shard);
let val = c.get(Key { key: key.to_vec() }).await?;
let buf = Bytes::from(val.into_inner().buf);
if thread_rng().gen_range(0..8) == 4 {
store(&self.0.hot_keys, key, buf.clone()).await;
}
return Ok(buf);
}
if let Some(buf) = load(&self.0.local_keys, &key).await {
return Ok(buf);
}
let buf = Bytes::from(self.0.source.get(&key).await?);
store(&self.0.local_keys, key, buf.clone()).await;
Ok(buf)
}
#[inline]
async fn lookup_shard(&self, key: &[u8]) -> Option<Channel> {
let conn = self.0.remote.read().map(|r| {
let cut = r.config.as_ref()?;
match *r.shards.try_get(key)? {
s if s == cut.local_addr() => None,
s => Some(cut[s].channel()),
}
});
conn.await
}
}
#[inline]
async fn load(cache: &Mutex<Cache2q<Bytes, Bytes>>, key: &[u8]) -> Option<Bytes> {
cache.lock().await.get(key).cloned()
}
/// Inserts `buf` under `key` in `cache`, discarding whatever the insert returns.
#[inline]
async fn store(cache: &Mutex<Cache2q<Bytes, Bytes>>, key: Bytes, buf: Bytes) {
    let mut entries = cache.lock().await;
    entries.insert(key, buf);
}
/// Builds a new `Status` carrying the same code, message, details and
/// metadata as `status`, so one stored error can be handed to several callers.
fn clone_status(status: &Status) -> Status {
    let code = status.code();
    let message = status.message();
    let details = Bytes::copy_from_slice(status.details());
    let metadata = status.metadata().clone();
    Status::with_details_and_metadata(code, message, details, metadata)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Each key must reach the source at most once, however often it is
    /// requested.
    #[quickcheck_async::tokio]
    async fn values_are_cached(keys: HashSet<Vec<u8>>) {
        // The cache is sized by `keys.len()`, so an empty set is skipped.
        if keys.is_empty() {
            return;
        }

        /// Source that echoes the key back and panics if asked for the same
        /// key more than once.
        struct Src(Mutex<HashSet<Vec<u8>>>);

        #[crate::async_trait]
        impl Source for Src {
            async fn get(&self, key: &[u8]) -> Result<Vec<u8>, Status> {
                let mut seen = self.0.lock().await;
                assert!(seen.insert(key.to_vec()), "key was fetched twice :(");
                Ok(key.to_vec())
            }
        }

        let cache = Cache::new(keys.len(), Src(Mutex::default()));
        for key in keys {
            // Request every key twice: the first call populates the cache and
            // the second must be served from it. With a single request per
            // key the assertion in `Src::get` could never fire.
            for _ in 0..2 {
                assert_eq!(key, cache.get(key.clone()).await.unwrap());
            }
        }
    }
}