use std::collections::BTreeMap;
use std::sync::{Arc, Mutex};
use bytes::Bytes;
use tokio::sync::Mutex as AsyncMutex;
use crate::processor::record::RecordContext;
use crate::store::byte::ByteKeyValueStore;
use crate::store::cache::entry::LruCacheEntry;
use crate::store::cache::named::NamedCache;
use crate::store::session_schema::{
session_end_of, session_key, session_key_bytes_of, session_start_of,
};
/// Session store that layers a shared `NamedCache` in front of a byte-oriented
/// key-value store.
///
/// Reads consult the cache first and fall back to (or are merged with) the inner
/// store; `put`/`remove` only touch the cache, and `flush`/`flush_with_old` push
/// the cached entries down into the inner store.
pub(crate) struct CachingSessionStore {
/// Cache shared with other holders of the `Arc`; guarded by a blocking
/// `std::sync::Mutex`, so its guard is never kept across an `.await`.
cache: Arc<Mutex<NamedCache>>,
/// Backing store; guarded by the async (tokio) mutex because its guard is
/// held across `.await` points while the store is queried or written.
inner: AsyncMutex<Box<dyn ByteKeyValueStore>>,
/// Name handed to `NamedCache::new` when `clear` rebuilds the cache
/// (empty when the store was built with `new`).
name: String,
}
impl CachingSessionStore {
    /// Creates a store over `cache` and `inner` with an empty name.
    ///
    /// NOTE(review): `clear` rebuilds the cache with `NamedCache::new(self.name.clone())`,
    /// so a store built through this constructor re-creates its cache under the empty
    /// name, whatever name the supplied cache was created with — confirm that is intended,
    /// or prefer `with_name`.
    pub fn new(cache: Arc<Mutex<NamedCache>>, inner: Box<dyn ByteKeyValueStore>) -> Self {
        Self::with_name(cache, inner, String::new())
    }

    /// Creates a store over `cache` and `inner`, remembering `name` so that
    /// `clear` can rebuild the cache under it.
    pub fn with_name(
        cache: Arc<Mutex<NamedCache>>,
        inner: Box<dyn ByteKeyValueStore>,
        name: String,
    ) -> Self {
        Self {
            cache,
            inner: AsyncMutex::new(inner),
            name,
        }
    }

    /// Locks the shared cache.
    ///
    /// The guard belongs to a blocking `std::sync::Mutex`; every caller drops it
    /// before reaching an `.await`.
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex is poisoned, i.e. another holder panicked while
    /// it had the cache locked.
    fn lock_cache(&self) -> std::sync::MutexGuard<'_, NamedCache> {
        self.cache.lock().expect("session cache mutex poisoned")
    }

    /// Applies cached entries on top of a snapshot of the inner store.
    ///
    /// A `Some` value inserts or overwrites the key; a `None` value is a
    /// tombstone and removes the key from the snapshot.
    fn overlay_cached<I>(merged: &mut BTreeMap<Bytes, Bytes>, cached: I)
    where
        I: IntoIterator<Item = (Bytes, Option<Bytes>)>,
    {
        for (key, value) in cached {
            match value {
                Some(live) => {
                    merged.insert(key, live);
                }
                None => {
                    merged.remove(&key);
                }
            }
        }
    }

    /// Returns the value stored under `key`.
    ///
    /// A cache entry always wins: if the cache holds an entry for `key`, its value
    /// is returned as-is, so a cached tombstone yields `None` without consulting the
    /// inner store. Only a cache miss falls through to the inner store.
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex is poisoned.
    pub async fn get(&self, key: &[u8]) -> Option<Bytes> {
        let key = Bytes::copy_from_slice(key);
        // Scoped so the blocking guard is released before the inner store is awaited.
        let cached = {
            let mut cache = self.lock_cache();
            cache.get_promote(&key).map(|entry| entry.value.clone())
        };
        match cached {
            Some(value) => value,
            None => self.inner.lock().await.get(&key).await,
        }
    }

    /// Returns the entries between `lo` and `hi`, with cached entries overriding
    /// (or, for tombstones, hiding) those of the inner store.
    ///
    /// The bounds are forwarded unchanged to both the inner store's and the cache's
    /// `range`. The result is in ascending key order.
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex is poisoned.
    pub async fn range(&self, lo: &[u8], hi: &[u8]) -> Vec<(Bytes, Bytes)> {
        let mut merged: BTreeMap<Bytes, Bytes> = {
            let inner = self.inner.lock().await;
            inner.range(lo, hi).await.into_iter().collect()
        };
        let cached = {
            let cache = self.lock_cache();
            cache.range(lo, hi)
        };
        Self::overlay_cached(
            &mut merged,
            cached.into_iter().map(|(k, entry)| (k, entry.value)),
        );
        merged.into_iter().collect()
    }

    /// Returns every entry of the inner store merged with every cached entry,
    /// cached entries taking precedence and tombstones hiding inner entries.
    ///
    /// The result is in ascending key order.
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex is poisoned.
    pub async fn scan_all(&self) -> Vec<(Bytes, Bytes)> {
        let mut merged: BTreeMap<Bytes, Bytes> = {
            let inner = self.inner.lock().await;
            inner.scan_all().await.into_iter().collect()
        };
        let cached = {
            let cache = self.lock_cache();
            cache.all()
        };
        Self::overlay_cached(
            &mut merged,
            cached.into_iter().map(|(k, entry)| (k, entry.value)),
        );
        merged.into_iter().collect()
    }

    /// Writes `key`/`value` straight to the inner store, bypassing the cache.
    ///
    /// A cache entry for the same key, if any, is left untouched and keeps
    /// shadowing the inner value on reads.
    pub async fn put_inner(&self, key: Bytes, value: Bytes) {
        self.inner.lock().await.put(key, value).await;
    }

    /// Deletes `key` straight from the inner store, bypassing the cache.
    pub async fn delete_inner(&self, key: &[u8]) {
        self.inner.lock().await.delete(key).await;
    }

    /// Empties both layers: replaces the shared cache with a fresh
    /// `NamedCache::new(self.name.clone())` and then clears the inner store.
    ///
    /// Cached entries are discarded, not flushed. The replacement happens in place
    /// behind the shared mutex, so other holders of the same `Arc` observe it too.
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex is poisoned.
    pub async fn clear(&self) {
        {
            let mut cache = self.lock_cache();
            *cache = NamedCache::new(self.name.clone());
        }
        self.inner.lock().await.clear().await;
    }

    /// Stages `value` under `session_key_bytes` in the cache only; the inner store
    /// is not touched until a flush.
    ///
    /// The entry is built with `LruCacheEntry::new(Some(value), true, ctx)`; the
    /// `true` flag is presumably the dirty marker — confirm against `LruCacheEntry::new`.
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex is poisoned.
    // The body never awaits; the method is presumably kept `async` so that callers
    // can await it like the rest of the store API.
    #[allow(clippy::unused_async)]
    pub async fn put(&self, session_key_bytes: Bytes, value: Bytes, ctx: RecordContext) {
        let mut cache = self.lock_cache();
        cache.put(
            session_key_bytes,
            LruCacheEntry::new(Some(value), true, ctx),
        );
    }

    /// Records a delete for `session_key_bytes` in the cache only; the inner store
    /// is not touched until a flush, but merged reads stop returning the key.
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex is poisoned.
    // The body never awaits; the method is presumably kept `async` so that callers
    // can await it like the rest of the store API.
    #[allow(clippy::unused_async)]
    pub async fn remove(&self, session_key_bytes: Bytes, ctx: RecordContext) {
        let mut cache = self.lock_cache();
        cache.delete(session_key_bytes, ctx);
    }

    /// Finds the sessions of `key` whose end is at or after `earliest_end` and whose
    /// start is at or before `latest_start`.
    ///
    /// Both layers are scanned from `session_key(key, 0, max(earliest_end, 0))` to
    /// `session_key(key, i64::MAX, i64::MAX)`; entries whose embedded key differs
    /// from `key` are dropped, cached entries override or hide inner ones, and the
    /// time bounds are then checked exactly on the decoded start/end.
    ///
    /// Returns `(start, end, value)` triples in ascending order of the encoded
    /// session key.
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex is poisoned.
    pub async fn find_sessions(
        &self,
        key: &[u8],
        earliest_end: i64,
        latest_start: i64,
    ) -> Vec<(i64, i64, Bytes)> {
        let lo = session_key(key, 0, earliest_end.max(0));
        let hi = session_key(key, i64::MAX, i64::MAX);
        let mut merged: BTreeMap<Bytes, Bytes> = {
            let inner = self.inner.lock().await;
            inner
                .range(&lo, &hi)
                .await
                .into_iter()
                .filter(|(k, _)| session_key_bytes_of(k) == key)
                .collect()
        };
        let cached = {
            let cache = self.lock_cache();
            cache.range(&lo, &hi)
        };
        // The byte range can include other record keys, so filter before overlaying.
        let same_key = cached
            .into_iter()
            .filter(|(k, _)| session_key_bytes_of(k) == key)
            .map(|(k, entry)| (k, entry.value));
        Self::overlay_cached(&mut merged, same_key);
        merged
            .into_iter()
            .filter_map(|(k, v)| {
                let end = session_end_of(&k);
                let start = session_start_of(&k);
                (end >= earliest_end && start <= latest_start).then_some((start, end, v))
            })
            .collect()
    }

    /// Flushes the cache into the inner store.
    ///
    /// Every entry the cache hands to its flush listener is collected, then written
    /// through: entries with a value are `put`, tombstones are `delete`d. The
    /// collected entries are returned in the order the cache reported them.
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex is poisoned.
    pub async fn flush(&self) -> Vec<(Bytes, LruCacheEntry)> {
        let mut collected: Vec<(Bytes, LruCacheEntry)> = Vec::new();
        {
            // Drain under the blocking lock only; the writes below need `.await`.
            let mut cache = self.lock_cache();
            let mut listener =
                |k: &Bytes, e: &LruCacheEntry| collected.push((k.clone(), e.clone()));
            cache.flush(&mut listener);
        }
        {
            let mut inner = self.inner.lock().await;
            for (k, e) in &collected {
                match &e.value {
                    Some(v) => inner.put(k.clone(), v.clone()).await,
                    None => {
                        inner.delete(k).await;
                    }
                }
            }
        }
        collected
    }

    /// Flushes the cache into the inner store and reports what each write replaced.
    ///
    /// For every entry the cache hands to its flush listener, the inner store's
    /// current value is read first, then the entry is written through (`put` for a
    /// value, `delete` for a tombstone).
    ///
    /// Returns `(key, old_inner_value, new_value, context)` tuples, where
    /// `new_value` is `None` for a tombstone.
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex is poisoned.
    pub async fn flush_with_old(
        &self,
    ) -> Vec<(Bytes, Option<Bytes>, Option<Bytes>, RecordContext)> {
        let mut dirty: Vec<(Bytes, LruCacheEntry)> = Vec::new();
        {
            // Drain under the blocking lock only; the reads/writes below need `.await`.
            let mut cache = self.lock_cache();
            let mut listener = |k: &Bytes, e: &LruCacheEntry| dirty.push((k.clone(), e.clone()));
            cache.flush(&mut listener);
        }
        let mut out = Vec::with_capacity(dirty.len());
        {
            let mut inner = self.inner.lock().await;
            for (k, e) in dirty {
                // Read before writing so `old` is the value being replaced.
                let old = inner.get(&k).await;
                match &e.value {
                    Some(v) => inner.put(k.clone(), v.clone()).await,
                    None => {
                        inner.delete(&k).await;
                    }
                }
                out.push((k, old, e.value, e.context));
            }
        }
        out
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::byte::InMemoryBytes;

    /// Fixed record context used wherever a test has to supply one.
    fn ctx() -> RecordContext {
        RecordContext {
            topic: "t".to_string(),
            partition: 0,
            offset: 0,
            timestamp: 0,
        }
    }

    /// Fresh shared cache named `"s"`.
    fn cache() -> Arc<Mutex<NamedCache>> {
        let named = NamedCache::new("s".to_string());
        Arc::new(Mutex::new(named))
    }

    /// Shorthand for a `Bytes` over a static byte string.
    fn b(v: &'static [u8]) -> Bytes {
        Bytes::from_static(v)
    }

    /// Store over a fresh cache and the given, possibly pre-populated, inner store.
    fn store_over(inner: InMemoryBytes) -> CachingSessionStore {
        CachingSessionStore::new(cache(), Box::new(inner))
    }

    /// Store over a fresh cache and an empty inner store.
    fn empty_store() -> CachingSessionStore {
        store_over(InMemoryBytes::default())
    }

    #[tokio::test]
    async fn find_sessions_returns_cached() {
        let store = empty_store();
        store.put(session_key(b"k", 0, 10), b(b"v"), ctx()).await;

        let sessions = store.find_sessions(b"k", 0, 100).await;
        assert_eq!(sessions, vec![(0, 10, b(b"v"))]);
    }

    #[tokio::test]
    async fn flush_writes_through_and_returns_entries() {
        let store = empty_store();
        let session = session_key(b"k", 0, 10);
        store.put(session.clone(), b(b"v"), ctx()).await;

        let flushed = store.flush().await;
        assert_eq!(flushed.len(), 1);
        assert_eq!(flushed[0].0, session);
        assert_eq!(flushed[0].1.value, Some(b(b"v")));

        // The session is still visible after the flush.
        let sessions = store.find_sessions(b"k", 0, 100).await;
        assert_eq!(sessions, vec![(0, 10, b(b"v"))]);
    }

    #[tokio::test]
    async fn find_sessions_merges_cache_and_underlying() {
        let mut backing = InMemoryBytes::default();
        backing.put(session_key(b"k", 0, 10), b(b"i1")).await;
        backing.put(session_key(b"k", 0, 30), b(b"i2")).await;
        let store = store_over(backing);

        // One cached entry shadows an inner one, the other exists only in the cache.
        store.put(session_key(b"k", 0, 10), b(b"c1"), ctx()).await;
        store.put(session_key(b"k", 0, 20), b(b"c2"), ctx()).await;

        let sessions = store.find_sessions(b"k", 0, 100).await;
        let expected = vec![
            (0, 10, b(b"c1")),
            (0, 20, b(b"c2")),
            (0, 30, b(b"i2")),
        ];
        assert_eq!(sessions, expected);
    }

    #[tokio::test]
    async fn tombstone_hides_underlying_session() {
        let mut backing = InMemoryBytes::default();
        backing.put(session_key(b"k", 0, 10), b(b"i1")).await;
        let store = store_over(backing);

        assert_eq!(
            store.find_sessions(b"k", 0, 100).await,
            vec![(0, 10, b(b"i1"))]
        );

        store.remove(session_key(b"k", 0, 10), ctx()).await;
        assert!(store.find_sessions(b"k", 0, 100).await.is_empty());
    }

    #[tokio::test]
    async fn other_key_prefix_is_not_returned() {
        let store = empty_store();
        store.put(session_key(b"k", 0, 10), b(b"a"), ctx()).await;
        store.put(session_key(b"kk", 0, 10), b(b"b"), ctx()).await;

        let sessions = store.find_sessions(b"k", 0, 100).await;
        assert_eq!(sessions, vec![(0, 10, b(b"a"))]);
    }

    #[tokio::test]
    async fn get_is_cache_first_then_falls_through() {
        let mut backing = InMemoryBytes::default();
        let session = session_key(b"k", 0, 10);
        backing.put(session.clone(), b(b"inner")).await;
        let store = store_over(backing);

        // Cache miss: the inner value is served.
        assert_eq!(store.get(&session).await, Some(b(b"inner")));

        // Cache hit: the cached value shadows the inner one.
        store.put(session.clone(), b(b"cached"), ctx()).await;
        assert_eq!(store.get(&session).await, Some(b(b"cached")));

        // Present in neither layer.
        assert_eq!(store.get(&session_key(b"k", 0, 99)).await, None);
    }

    #[tokio::test]
    async fn range_merges_cache_over_inner_with_tombstone() {
        let first = session_key(b"k", 0, 10);
        let second = session_key(b"k", 0, 20);
        let third = session_key(b"k", 0, 30);

        let mut backing = InMemoryBytes::default();
        backing.put(first.clone(), b(b"i0")).await;
        backing.put(second.clone(), b(b"i1")).await;
        backing.put(third.clone(), b(b"i2")).await;
        let store = store_over(backing);

        store.put(second.clone(), b(b"c1"), ctx()).await;
        store.remove(third.clone(), ctx()).await;

        let lo = session_key(b"k", 0, 0);
        let hi = session_key(b"k", i64::MAX, i64::MAX);
        let ranged = store.range(&lo, &hi).await;
        assert_eq!(ranged, vec![(first, b(b"i0")), (second, b(b"c1"))]);
    }

    #[tokio::test]
    async fn scan_all_merges_cache_and_underlying() {
        let first = session_key(b"k", 0, 10);
        let second = session_key(b"k", 0, 20);
        let fourth = session_key(b"k", 0, 40);

        let mut backing = InMemoryBytes::default();
        backing.put(first.clone(), b(b"i0")).await;
        backing.put(second.clone(), b(b"i1")).await;
        backing.put(fourth.clone(), b(b"i3")).await;
        let store = store_over(backing);

        store.put(second.clone(), b(b"c1"), ctx()).await;
        let third = session_key(b"k", 0, 30);
        store.put(third.clone(), b(b"c2"), ctx()).await;
        store.remove(fourth.clone(), ctx()).await;

        let scanned = store.scan_all().await;
        assert_eq!(
            scanned,
            vec![(first, b(b"i0")), (second, b(b"c1")), (third, b(b"c2"))]
        );
    }

    #[tokio::test]
    async fn put_and_delete_inner_bypass_the_cache() {
        let store = empty_store();
        let session = session_key(b"k", 0, 10);

        store.put_inner(session.clone(), b(b"v")).await;
        assert_eq!(store.get(&session).await, Some(b(b"v")));
        assert!(store.flush().await.is_empty());

        store.delete_inner(&session).await;
        assert_eq!(store.get(&session).await, None);
        assert!(store.flush().await.is_empty());
    }

    #[tokio::test]
    async fn clear_empties_cache_and_inner() {
        let mut backing = InMemoryBytes::default();
        backing.put(session_key(b"k", 0, 10), b(b"i")).await;
        let store = store_over(backing);
        store.put(session_key(b"k", 0, 20), b(b"c"), ctx()).await;

        store.clear().await;

        assert!(store.find_sessions(b"k", 0, 1000).await.is_empty());
        assert!(store.scan_all().await.is_empty());
        assert!(store.flush().await.is_empty());
    }

    #[tokio::test]
    async fn flush_deletes_tombstone_through() {
        let mut backing = InMemoryBytes::default();
        let session = session_key(b"k", 0, 10);
        backing.put(session.clone(), b(b"old")).await;
        let store = store_over(backing);

        store.remove(session.clone(), ctx()).await;

        let flushed = store.flush().await;
        assert_eq!(flushed.len(), 1);
        assert_eq!(flushed[0].1.value, None);
        assert_eq!(store.get(&session).await, None);
    }

    #[tokio::test]
    async fn flush_with_old_returns_inner_old_then_writes_through() {
        let mut backing = InMemoryBytes::default();
        let session = session_key(b"k", 0, 10);
        backing.put(session.clone(), b(b"old")).await;
        let store = store_over(backing);

        store.put(session.clone(), b(b"new"), ctx()).await;

        let drained = store.flush_with_old().await;
        assert_eq!(drained.len(), 1);
        let (key, previous, current, _ctx) = &drained[0];
        assert_eq!(key, &session);
        assert_eq!(previous.as_ref(), Some(&b(b"old")));
        assert_eq!(current.as_ref(), Some(&b(b"new")));

        assert_eq!(
            store.find_sessions(b"k", 0, 100).await,
            vec![(0, 10, b(b"new"))]
        );
    }
}