use std::sync::Arc;
use async_trait::async_trait;
use crate::db_cache::stats::DbCacheStats;
use crate::stats::StatRegistry;
use crate::{
block::Block, db_state::SsTableId, filter::BloomFilter, flatbuffer_types::SsTableIndexOwned,
};
/// Cache backend module compiled only with the `foyer` feature
/// (presumably an adapter over the `foyer` crate — confirm in `foyer.rs`).
#[cfg(feature = "foyer")]
pub mod foyer;
/// Cache backend module compiled only with the `moka` feature
/// (presumably an adapter over the `moka` crate — confirm in `moka.rs`).
#[cfg(feature = "moka")]
pub mod moka;
/// Default maximum cache capacity: `64 * 1024 * 1024` (64 Mi units).
// NOTE(review): presumably a byte budget weighed with `CachedEntry::size` — confirm in the backends.
pub const DEFAULT_MAX_CAPACITY: u64 = 64 * 1024 * 1024;
/// Asynchronous cache interface for SST-derived items (data blocks, SST indexes and bloom
/// filters), keyed by [`CachedKey`] and storing [`CachedEntry`] values.
///
/// The `Send + Sync` bound lets one instance be shared as `Arc<dyn DbCache>`, which is how
/// [`DbCacheWrapper`] holds its inner cache.
#[async_trait]
pub trait DbCache: Send + Sync {
    /// Looks up the entry for `key` on the data-block path; `None` means a miss.
    async fn get_block(&self, key: CachedKey) -> Option<CachedEntry>;
    /// Looks up the entry for `key` on the SST-index path; `None` means a miss.
    async fn get_index(&self, key: CachedKey) -> Option<CachedEntry>;
    /// Looks up the entry for `key` on the bloom-filter path; `None` means a miss.
    async fn get_filter(&self, key: CachedKey) -> Option<CachedEntry>;
    /// Stores `value` under `key`.
    async fn insert(&self, key: CachedKey, value: CachedEntry);
    /// Removes the entry stored under `key`, if any.
    // NOTE(review): the dead_code allowance suggests no production caller uses this yet — confirm.
    #[allow(dead_code)]
    async fn remove(&self, key: CachedKey);
    /// Returns the number of entries the cache currently holds.
    // NOTE(review): exactness is backend-defined (may be approximate for concurrent caches).
    #[allow(dead_code)]
    fn entry_count(&self) -> u64;
}
/// Key of a cached item: the owning SST's id plus a `u64` discriminator within that SST.
///
/// The fields are private, so keys are built through the `From<(SsTableId, u64)>` conversion.
#[non_exhaustive]
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct CachedKey(SsTableId, u64);

impl From<(SsTableId, u64)> for CachedKey {
    /// Converts an `(sst_id, block_id)` pair into a cache key.
    fn from(pair: (SsTableId, u64)) -> Self {
        let (sst_id, block_id) = pair;
        CachedKey(sst_id, block_id)
    }
}
/// The payload held by a [`CachedEntry`]: one variant per cacheable SST component, each
/// behind an `Arc` so clones share the underlying data.
// NOTE(review): `#[non_exhaustive]` only restricts code in *other* crates; on this private
// enum it has no effect.
#[non_exhaustive]
#[derive(Clone)]
enum CachedItem {
    /// A data block.
    Block(Arc<Block>),
    /// An SST index.
    SsTableIndex(Arc<SsTableIndexOwned>),
    /// A bloom filter.
    BloomFilter(Arc<BloomFilter>),
}
/// A value stored in a [`DbCache`].
///
/// Cloning only bumps the `Arc` reference count of the wrapped item.
#[derive(Clone)]
pub struct CachedEntry {
    // The wrapped payload; read back through the typed accessors (`block`, `sst_index`,
    // `bloom_filter`).
    item: CachedItem,
}
impl CachedEntry {
    /// Builds an entry wrapping a data block.
    pub(crate) fn with_block(block: Arc<Block>) -> Self {
        let item = CachedItem::Block(block);
        Self { item }
    }

    /// Builds an entry wrapping an SST index.
    pub(crate) fn with_sst_index(sst_index: Arc<SsTableIndexOwned>) -> Self {
        let item = CachedItem::SsTableIndex(sst_index);
        Self { item }
    }

    /// Builds an entry wrapping a bloom filter.
    pub(crate) fn with_bloom_filter(bloom_filter: Arc<BloomFilter>) -> Self {
        let item = CachedItem::BloomFilter(bloom_filter);
        Self { item }
    }

    /// Returns a shared handle to the data block, or `None` if this entry holds another kind.
    pub(crate) fn block(&self) -> Option<Arc<Block>> {
        if let CachedItem::Block(inner) = &self.item {
            Some(Arc::clone(inner))
        } else {
            None
        }
    }

    /// Returns a shared handle to the SST index, or `None` if this entry holds another kind.
    pub(crate) fn sst_index(&self) -> Option<Arc<SsTableIndexOwned>> {
        if let CachedItem::SsTableIndex(inner) = &self.item {
            Some(Arc::clone(inner))
        } else {
            None
        }
    }

    /// Returns a shared handle to the bloom filter, or `None` if this entry holds another kind.
    pub(crate) fn bloom_filter(&self) -> Option<Arc<BloomFilter>> {
        if let CachedItem::BloomFilter(inner) = &self.item {
            Some(Arc::clone(inner))
        } else {
            None
        }
    }

    /// Size of the wrapped item, as reported by that item's own `size()` method.
    pub fn size(&self) -> usize {
        let item = &self.item;
        match item {
            CachedItem::Block(inner) => inner.size(),
            CachedItem::SsTableIndex(inner) => inner.size(),
            CachedItem::BloomFilter(inner) => inner.size(),
        }
    }

    /// Returns a new entry of the same kind whose item is the result of calling
    /// `clamp_allocated_size()` on this entry's item; `self` is left unchanged.
    pub fn clamp_allocated_size(&self) -> Self {
        let clamped = match &self.item {
            CachedItem::Block(inner) => CachedItem::Block(Arc::new(inner.clamp_allocated_size())),
            CachedItem::SsTableIndex(inner) => {
                CachedItem::SsTableIndex(Arc::new(inner.clamp_allocated_size()))
            }
            CachedItem::BloomFilter(inner) => {
                CachedItem::BloomFilter(Arc::new(inner.clamp_allocated_size()))
            }
        };
        Self { item: clamped }
    }
}
/// A [`DbCache`] decorator that counts hits and misses per lookup kind and clamps entries
/// before they reach the inner cache.
pub struct DbCacheWrapper {
    // Hit/miss counters, registered in the `StatRegistry` handed to `new`.
    stats: DbCacheStats,
    // The backing cache that actually stores entries.
    cache: Arc<dyn DbCache>,
}

impl DbCacheWrapper {
    /// Wraps `cache` and registers this wrapper's hit/miss counters with `stats_registry`.
    pub fn new(cache: Arc<dyn DbCache>, stats_registry: &StatRegistry) -> Self {
        let stats = DbCacheStats::new(stats_registry);
        Self { cache, stats }
    }
}
#[async_trait]
impl DbCache for DbCacheWrapper {
    /// Forwards to the inner cache and bumps `data_block_hit` or `data_block_miss`.
    async fn get_block(&self, key: CachedKey) -> Option<CachedEntry> {
        let found = self.cache.get_block(key).await;
        let counter = match &found {
            Some(_) => &self.stats.data_block_hit,
            None => &self.stats.data_block_miss,
        };
        counter.inc();
        found
    }

    /// Forwards to the inner cache and bumps `index_hit` or `index_miss`.
    async fn get_index(&self, key: CachedKey) -> Option<CachedEntry> {
        let found = self.cache.get_index(key).await;
        let counter = match &found {
            Some(_) => &self.stats.index_hit,
            None => &self.stats.index_miss,
        };
        counter.inc();
        found
    }

    /// Forwards to the inner cache and bumps `filter_hit` or `filter_miss`.
    async fn get_filter(&self, key: CachedKey) -> Option<CachedEntry> {
        let found = self.cache.get_filter(key).await;
        let counter = match &found {
            Some(_) => &self.stats.filter_hit,
            None => &self.stats.filter_miss,
        };
        counter.inc();
        found
    }

    /// Stores a `clamp_allocated_size()` copy of `value` in the inner cache rather than
    /// `value` itself.
    async fn insert(&self, key: CachedKey, value: CachedEntry) {
        let clamped = value.clamp_allocated_size();
        self.cache.insert(key, clamped).await;
    }

    /// Forwards the removal to the inner cache.
    async fn remove(&self, key: CachedKey) {
        self.cache.remove(key).await;
    }

    /// Reports the inner cache's entry count.
    fn entry_count(&self) -> u64 {
        self.cache.entry_count()
    }
}
pub mod stats {
    use crate::stats::{Counter, StatRegistry};
    use std::sync::Arc;

    /// Scopes a stat suffix under `dbcache` using the crate-wide `stat_name!` macro.
    macro_rules! dbcache_stat_name {
        ($suffix:expr) => {
            crate::stat_name!("dbcache", $suffix)
        };
    }

    pub const DB_CACHE_FILTER_HIT: &str = dbcache_stat_name!("filter_hit");
    pub const DB_CACHE_FILTER_MISS: &str = dbcache_stat_name!("filter_miss");
    pub const DB_CACHE_INDEX_HIT: &str = dbcache_stat_name!("index_hit");
    pub const DB_CACHE_INDEX_MISS: &str = dbcache_stat_name!("index_miss");
    pub const DB_CACHE_DATA_BLOCK_HIT: &str = dbcache_stat_name!("data_block_hit");
    pub const DB_CACHE_DATA_BLOCK_MISS: &str = dbcache_stat_name!("data_block_miss");

    /// Hit/miss counters for each lookup kind of the DB cache wrapper.
    pub(super) struct DbCacheStats {
        pub(super) filter_hit: Arc<Counter>,
        pub(super) filter_miss: Arc<Counter>,
        pub(super) index_hit: Arc<Counter>,
        pub(super) index_miss: Arc<Counter>,
        pub(super) data_block_hit: Arc<Counter>,
        pub(super) data_block_miss: Arc<Counter>,
    }

    impl DbCacheStats {
        /// Creates zeroed counters and registers each one in `registry` under its
        /// `DB_CACHE_*` name; the registry and the returned struct share the same counters.
        pub(super) fn new(registry: &StatRegistry) -> Self {
            let fresh = || Arc::new(Counter::default());
            let stats = Self {
                filter_hit: fresh(),
                filter_miss: fresh(),
                index_hit: fresh(),
                index_miss: fresh(),
                data_block_hit: fresh(),
                data_block_miss: fresh(),
            };
            let registrations = [
                (DB_CACHE_FILTER_HIT, &stats.filter_hit),
                (DB_CACHE_FILTER_MISS, &stats.filter_miss),
                (DB_CACHE_INDEX_HIT, &stats.index_hit),
                (DB_CACHE_INDEX_MISS, &stats.index_miss),
                (DB_CACHE_DATA_BLOCK_HIT, &stats.data_block_hit),
                (DB_CACHE_DATA_BLOCK_MISS, &stats.data_block_miss),
            ];
            for (name, counter) in registrations {
                registry.register(name, Arc::clone(counter));
            }
            stats
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::db_cache::{CachedEntry, CachedKey, DbCache, DbCacheWrapper};
    use crate::db_state::SsTableId;
    use crate::flatbuffer_types::test_utils::assert_index_clamped;
    use crate::sst::SsTableFormat;
    use crate::stats::{ReadableStat, StatRegistry};
    use crate::test_utils::{build_test_sst, SstData};
    use async_trait::async_trait;
    use rstest::{fixture, rstest};
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};
    use ulid::Ulid;

    /// SST id shared by every key in these tests.
    const SST_ID: SsTableId = SsTableId::Compacted(Ulid::from_parts(0u64, 0u128));

    #[rstest]
    #[tokio::test]
    async fn test_should_count_filter_hits(
        cache: DbCacheWrapper,
        sst_format: SsTableFormat,
        sst: SstData,
    ) {
        let filter = sst_format
            .read_filter_raw(&sst.info, &sst.data)
            .unwrap()
            .unwrap();
        let key = CachedKey::from((SST_ID, 12345u64));
        cache
            .insert(key.clone(), CachedEntry::with_bloom_filter(filter))
            .await;
        for i in 1..4 {
            let _ = cache.get_filter(key.clone()).await;
            assert_eq!(0, cache.stats.filter_miss.get());
            assert_eq!(i, cache.stats.filter_hit.get());
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_should_count_filter_misses(cache: DbCacheWrapper) {
        let key = CachedKey::from((SST_ID, 12345u64));
        for i in 1..4 {
            let _ = cache.get_filter(key.clone()).await;
            assert_eq!(i, cache.stats.filter_miss.get());
            assert_eq!(0, cache.stats.filter_hit.get());
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_should_count_index_hits(
        cache: DbCacheWrapper,
        sst_format: SsTableFormat,
        sst: SstData,
    ) {
        let index = sst_format.read_index_raw(&sst.info, &sst.data).unwrap();
        let key = CachedKey::from((SST_ID, 12345u64));
        cache
            .insert(key.clone(), CachedEntry::with_sst_index(Arc::new(index)))
            .await;
        for i in 1..4 {
            let _ = cache.get_index(key.clone()).await;
            assert_eq!(0, cache.stats.index_miss.get());
            assert_eq!(i, cache.stats.index_hit.get());
        }
    }

    /// Inserting through the wrapper must store a clamped copy of the entry.
    #[rstest]
    #[tokio::test]
    async fn test_should_clamp_entries_to_cache(
        cache: DbCacheWrapper,
        sst_format: SsTableFormat,
        sst: SstData,
    ) {
        let index = Arc::new(sst_format.read_index_raw(&sst.info, &sst.data).unwrap());
        let key = CachedKey::from((SST_ID, 12345u64));
        cache
            .insert(key.clone(), CachedEntry::with_sst_index(index.clone()))
            .await;
        let cached = cache.get_index(key).await.unwrap();
        assert_index_clamped(index.as_ref(), cached.sst_index().unwrap().as_ref());
    }

    #[rstest]
    #[tokio::test]
    async fn test_should_count_index_misses(cache: DbCacheWrapper) {
        let key = CachedKey::from((SST_ID, 12345u64));
        for i in 1..4 {
            let _ = cache.get_index(key.clone()).await;
            assert_eq!(i, cache.stats.index_miss.get());
            assert_eq!(0, cache.stats.index_hit.get());
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_should_count_data_block_hits(
        cache: DbCacheWrapper,
        sst_format: SsTableFormat,
        sst: SstData,
    ) {
        let index = sst_format.read_index_raw(&sst.info, &sst.data).unwrap();
        let block = sst_format
            .read_block_raw(&sst.info, &index, 0, &sst.data)
            .unwrap();
        let key = CachedKey::from((SST_ID, 12345u64));
        cache
            .insert(key.clone(), CachedEntry::with_block(Arc::new(block)))
            .await;
        for i in 1..4 {
            let _ = cache.get_block(key.clone()).await;
            assert_eq!(0, cache.stats.data_block_miss.get());
            assert_eq!(i, cache.stats.data_block_hit.get());
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_should_count_data_block_misses(cache: DbCacheWrapper) {
        let key = CachedKey::from((SST_ID, 12345u64));
        for i in 1..4 {
            let _ = cache.get_block(key.clone()).await;
            assert_eq!(i, cache.stats.data_block_miss.get());
            assert_eq!(0, cache.stats.data_block_hit.get());
        }
    }

    /// A `DbCacheWrapper` over an empty in-memory `TestCache` with a fresh stat registry.
    #[fixture]
    fn cache() -> DbCacheWrapper {
        let registry = StatRegistry::new();
        // Borrow the registry; this was previously mis-encoded as `®istry` and did not compile.
        DbCacheWrapper::new(Arc::new(TestCache::new()), &registry)
    }

    /// SST format with a small block size.
    #[fixture]
    fn sst_format() -> SsTableFormat {
        SsTableFormat {
            block_size: 128,
            ..SsTableFormat::default()
        }
    }

    /// A test SST built with `build_test_sst(&sst_format, 1)`.
    #[fixture]
    fn sst(sst_format: SsTableFormat) -> SstData {
        build_test_sst(&sst_format, 1)
    }

    /// Minimal `HashMap`-backed `DbCache`. Lookups ignore the entry kind and return whatever
    /// is stored under the key.
    struct TestCache {
        items: Mutex<HashMap<CachedKey, CachedEntry>>,
    }

    impl TestCache {
        fn new() -> Self {
            Self {
                items: Mutex::new(HashMap::new()),
            }
        }
    }

    // The std `Mutex` guard is never held across an `.await`, so using it in async fns is safe here.
    #[async_trait]
    impl DbCache for TestCache {
        async fn get_block(&self, key: CachedKey) -> Option<CachedEntry> {
            let guard = self.items.lock().unwrap();
            guard.get(&key).cloned()
        }

        async fn get_index(&self, key: CachedKey) -> Option<CachedEntry> {
            let guard = self.items.lock().unwrap();
            guard.get(&key).cloned()
        }

        async fn get_filter(&self, key: CachedKey) -> Option<CachedEntry> {
            let guard = self.items.lock().unwrap();
            guard.get(&key).cloned()
        }

        async fn insert(&self, key: CachedKey, value: CachedEntry) {
            let mut guard = self.items.lock().unwrap();
            guard.insert(key, value);
        }

        async fn remove(&self, key: CachedKey) {
            let mut guard = self.items.lock().unwrap();
            guard.remove(&key);
        }

        fn entry_count(&self) -> u64 {
            let guard = self.items.lock().unwrap();
            // `len()` is O(1); `iter().count()` walked the whole map for the same answer.
            guard.len() as u64
        }
    }
}