use std::{
path::PathBuf,
sync::atomic::{AtomicUsize, Ordering},
};
use bytes::{Bytes, BytesMut};
use bytesize::ByteSize;
use compact_str::CompactString;
use foyer::{
BlockEngineConfig, Code, DeviceBuilder, FileDeviceBuilder, FsDeviceBuilder, HybridCache,
HybridCacheBuilder, HybridCachePolicy, IoEngineConfig,
};
use mixtrics::registry::prometheus_0_14::PrometheusMetricsRegistry;
use crate::types::{BucketName, ObjectKey, ObjectKind, PageId};
/// Settings consumed by [`build_cache`] to assemble the hybrid cache.
#[derive(Debug)]
pub struct CacheConfig {
/// Capacity handed to the builder's in-memory tier, in bytes.
pub memory_size: ByteSize,
/// Disk tier settings; when `None`, `build_cache` configures no disk engine.
pub disk_cache: Option<DiskCacheConfig>,
/// Prometheus registry to attach to the cache builder; no registry is attached when `None`.
pub metrics_registry: Option<prometheus::Registry>,
}
// Selects which foyer device builder backs the disk tier in `build_cache`.
// NOTE: plain `//` comments are used on purpose; `clap::ValueEnum` turns `///` doc comments on
// variants into user-visible help text.
#[derive(Clone, Debug, clap::ValueEnum)]
pub enum DiskCacheKind {
// Device built with `FileDeviceBuilder`; spelled `block` on the command line.
#[value(name = "block")]
BlockFile,
// Device built with `FsDeviceBuilder`; spelled `fs` on the command line.
#[value(name = "fs")]
FileSystem,
}
/// Disk tier settings used by [`build_cache`].
#[derive(Debug)]
pub struct DiskCacheConfig {
/// Path passed to the selected device builder.
pub path: PathBuf,
/// Which device builder to use for the disk tier.
pub kind: DiskCacheKind,
/// Explicit device capacity in bytes; when `None` the device builder's default applies.
pub capacity: Option<ByteSize>,
/// Requests the io_uring I/O engine; only honored on Linux, otherwise the psync engine is used.
pub iouring: bool,
}
pub async fn build_cache(config: CacheConfig) -> foyer::Result<HybridCache<CacheKey, CacheValue>> {
let mut builder = HybridCacheBuilder::new().with_policy(HybridCachePolicy::WriteOnEviction);
if let Some(registry) = config.metrics_registry {
builder = builder.with_metrics_registry(Box::new(PrometheusMetricsRegistry::new(registry)));
}
let mut builder = builder
.memory(config.memory_size.as_u64() as usize)
.with_weighter(|key: &CacheKey, value: &CacheValue| {
key.estimated_size() + value.estimated_size()
})
.storage()
.with_spawner(
tokio::runtime::Builder::new_multi_thread()
.thread_name_fn(|| {
static TID: AtomicUsize = AtomicUsize::new(0);
let id = TID.fetch_add(1, Ordering::Relaxed);
format!("foyer-{id}")
})
.enable_all()
.build()?
.into(),
);
if let Some(disk_config) = config.disk_cache {
let device = match disk_config.kind {
DiskCacheKind::BlockFile => {
let mut file_device = FileDeviceBuilder::new(disk_config.path);
#[cfg(target_os = "linux")]
{
file_device = file_device.with_direct(true);
}
if let Some(cap) = disk_config.capacity {
file_device = file_device.with_capacity(cap.as_u64() as usize);
}
file_device.build()?
}
DiskCacheKind::FileSystem => {
let mut fs_device = FsDeviceBuilder::new(disk_config.path);
#[cfg(target_os = "linux")]
{
fs_device = fs_device.with_direct(true);
}
if let Some(cap) = disk_config.capacity {
fs_device = fs_device.with_capacity(cap.as_u64() as usize);
}
fs_device.build()?
}
};
let engine = BlockEngineConfig::new(device).with_block_size(64 * 1024 * 1024);
builder = builder
.with_engine_config(engine)
.with_io_engine_config(io_engine_config(disk_config.iouring));
}
builder.build().await
}
/// Picks the I/O engine configuration for the disk tier.
///
/// On Linux, `iouring == true` selects the io_uring engine. In every other case (flag unset, or
/// a non-Linux target) the psync engine is returned.
fn io_engine_config(iouring: bool) -> Box<dyn IoEngineConfig> {
    #[cfg(target_os = "linux")]
    {
        if iouring {
            return foyer::UringIoEngineConfig::new().boxed();
        }
    }

    // The flag is meaningless off Linux; consume it so the parameter is not reported as unused.
    #[cfg(not(target_os = "linux"))]
    {
        let _ = iouring;
    }

    foyer::PsyncIoEngineConfig::new().boxed()
}
/// Cache lookup key: an object kind, an object key, and a page id.
///
/// The `foyer::Code` impl serializes it as a 5-byte `CacheKeyHeader` followed by the `kind`
/// bytes and then the `object` bytes.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct CacheKey {
/// Object kind; its byte length must be in `1..=64` for the key to encode.
pub kind: ObjectKind,
/// Object key; its byte length must be in `1..=1024` for the key to encode.
pub object: ObjectKey,
/// Page identifier, stored big-endian in the last two header bytes.
pub page_id: PageId,
}
impl CacheKey {
/// Serialization format version: written as the first header byte by `encode`;
/// `decode` rejects any other value with an `InvalidData` error.
const VERSION: u8 = 3;
}
/// Fixed 5-byte header that prefixes every encoded `CacheKey`.
///
/// Layout:
/// - byte 0: format version
/// - byte 1: bits 7..2 hold `kind_len - 1`, bits 1..0 hold the top two bits of `key_len - 1`
/// - byte 2: low eight bits of `key_len - 1`
/// - bytes 3..=4: page id, big-endian
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
struct CacheKeyHeader([u8; 5]);

impl CacheKeyHeader {
    /// Largest kind length representable in six bits (stored minus one).
    const MAX_KIND_LEN: usize = 1 << 6;
    /// Largest key length representable in ten bits (stored minus one).
    const MAX_KEY_LEN: usize = 1 << 10;

    /// Packs the given fields into a header.
    ///
    /// # Errors
    ///
    /// Returns a static message when `kind_len` is outside `1..=64` or `key_len` is outside
    /// `1..=1024`; the kind length is validated first.
    fn new(
        version: u8,
        kind_len: usize,
        key_len: usize,
        page_id: PageId,
    ) -> Result<Self, &'static str> {
        match kind_len {
            0 => return Err("Kind length cannot be zero"),
            len if len > Self::MAX_KIND_LEN => return Err("Kind length exceeds 6 bits"),
            _ => {}
        }
        match key_len {
            0 => return Err("Key length cannot be zero"),
            len if len > Self::MAX_KEY_LEN => return Err("Key length exceeds 10 bits"),
            _ => {}
        }

        // Both lengths are stored minus one so the full 1..=2^n range fits in n bits.
        let kind_bits = (kind_len - 1) as u8;
        let key_bits = key_len - 1;
        let [page_hi, page_lo] = page_id.to_be_bytes();

        Ok(Self([
            version,
            (kind_bits << 2) | ((key_bits >> 8) as u8 & 0b11),
            (key_bits & 0xFF) as u8,
            page_hi,
            page_lo,
        ]))
    }

    /// Format version stored in byte 0.
    fn version(self) -> u8 {
        let [version, ..] = self.0;
        version
    }

    /// Length in bytes of the kind string that follows the header.
    fn kind_len(self) -> usize {
        usize::from(self.0[1] >> 2) + 1
    }

    /// Length in bytes of the object key that follows the kind string.
    fn key_len(self) -> usize {
        let stored = u16::from_be_bytes([self.0[1] & 0b11, self.0[2]]);
        usize::from(stored) + 1
    }

    /// Page id stored big-endian in the last two bytes.
    fn page_id(self) -> PageId {
        let [.., hi, lo] = self.0;
        u16::from_be_bytes([hi, lo])
    }

    /// Wraps raw header bytes; every bit pattern is accepted, so this never fails.
    fn from_bytes(bytes: [u8; 5]) -> Result<Self, &'static str> {
        Ok(Self(bytes))
    }

    /// Returns the raw header bytes.
    fn to_bytes(self) -> [u8; 5] {
        let Self(raw) = self;
        raw
    }
}
impl foyer::Code for CacheKey {
fn encode(&self, writer: &mut impl std::io::Write) -> foyer::Result<()> {
let flag = CacheKeyHeader::new(
Self::VERSION,
self.kind.len(),
self.object.len(),
self.page_id,
)
.map_err(|msg| std::io::Error::new(std::io::ErrorKind::InvalidData, msg))?;
writer.write_all(&flag.to_bytes())?;
writer.write_all(self.kind.as_bytes())?;
writer.write_all(self.object.as_bytes())?;
Ok(())
}
fn decode(reader: &mut impl std::io::Read) -> foyer::Result<Self>
where
Self: Sized,
{
let header = {
let mut buf = [0u8; 5];
reader.read_exact(&mut buf)?;
CacheKeyHeader::from_bytes(buf)
.map_err(|msg| std::io::Error::new(std::io::ErrorKind::InvalidData, msg))?
};
if header.version() != Self::VERSION {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Unsupported version {}", header.version()),
)
.into());
}
let kind = {
let mut buf = BytesMut::zeroed(header.kind_len());
reader.read_exact(&mut buf)?;
let str = CompactString::from_utf8(buf).map_err(|_| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Invalid UTF-8 in object kind",
)
})?;
ObjectKind::new(str)
.map_err(|msg| std::io::Error::new(std::io::ErrorKind::InvalidData, msg))?
};
let object = {
let mut buf = BytesMut::zeroed(header.key_len());
reader.read_exact(&mut buf)?;
let str = CompactString::from_utf8(buf).map_err(|_| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Invalid UTF-8 in object key",
)
})?;
ObjectKey::new(str)
.map_err(|msg| std::io::Error::new(std::io::ErrorKind::InvalidData, msg))?
};
let page_id = header.page_id();
Ok(Self {
kind,
object,
page_id,
})
}
fn estimated_size(&self) -> usize {
size_of::<CacheKeyHeader>() + self.kind.len() + self.object.len()
}
}
/// Cached page payload together with metadata about the object it came from.
///
/// The `foyer::Code` impl serializes it as a 17-byte `CacheValueHeader` followed by the bucket
/// name bytes and then `data`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheValue {
/// Bucket name; its byte length must be in `1..=64` for the value to encode.
pub bucket: BucketName,
/// Object modification time, stored as 32 bits.
/// NOTE(review): presumably Unix seconds — confirm against the producer.
pub mtime: u32,
/// Cached bytes; at most `1 << 24` bytes can be encoded, and empty data is allowed.
pub data: Bytes,
/// Size recorded for the object; must be below `1 << 40` for the value to encode.
pub object_size: u64,
/// Time the entry was cached, stored as 32 bits.
/// NOTE(review): presumably Unix seconds (tests use 1_700_000_000) — confirm.
pub cached_at: u32,
}
/// Fixed 17-byte header that prefixes every encoded `CacheValue`.
///
/// Layout (multi-byte fields are big-endian):
/// - byte 0: bit 7 reserved (must be zero), bit 6 "data is empty" flag,
///   bits 5..0 hold `bucket_name_len - 1`
/// - bytes 1..=5: object size (40 bits)
/// - bytes 6..=8: `data_len - 1` (24 bits); all zero when the empty flag is set
/// - bytes 9..=12: mtime
/// - bytes 13..=16: cached-at time
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
struct CacheValueHeader([u8; 17]);

impl CacheValueHeader {
    /// Bit 7 of byte 0; never set by `new` and rejected by `from_bytes`.
    const RESERVED_BIT: u8 = 0b1000_0000;
    /// Bit 6 of byte 0; marks a zero-length data section.
    const EMPTY_DATA_BIT: u8 = 0b0100_0000;
    /// Low six bits of byte 0; hold the bucket name length minus one.
    const BUCKET_LEN_MASK: u8 = 0b0011_1111;

    /// Packs the given fields into a header.
    ///
    /// # Errors
    ///
    /// Returns a static message when `bucket_name_len` is outside `1..=64`, when
    /// `object_size` does not fit in 40 bits, or when `data_len` exceeds `1 << 24`.
    fn new(
        bucket_name_len: usize,
        object_size: u64,
        mtime: u32,
        data_len: usize,
        cached_at: u32,
    ) -> Result<Self, &'static str> {
        if bucket_name_len == 0 {
            return Err("Bucket name length cannot be zero");
        }
        if bucket_name_len > (1 << 6) {
            return Err("Bucket name length exceeds limit");
        }
        if object_size >= (1 << 40) {
            return Err("Object size exceeds limit");
        }
        if data_len > (1 << 24) {
            return Err("Data length exceeds limit");
        }

        // The length is stored minus one so 1..=2^24 fits in 24 bits; a zero length is
        // signalled by the empty flag with the length bits left at zero.
        let data_len_minus_one = (data_len as u32).saturating_sub(1);
        let empty_flag = if data_len == 0 {
            Self::EMPTY_DATA_BIT
        } else {
            0
        };

        let mut bytes = [0u8; 17];
        bytes[0] = empty_flag | ((bucket_name_len - 1) as u8 & Self::BUCKET_LEN_MASK);
        // Keep only the low 5 (respectively 3) bytes of the big-endian representations.
        bytes[1..6].copy_from_slice(&object_size.to_be_bytes()[3..]);
        bytes[6..9].copy_from_slice(&data_len_minus_one.to_be_bytes()[1..]);
        bytes[9..13].copy_from_slice(&mtime.to_be_bytes());
        bytes[13..17].copy_from_slice(&cached_at.to_be_bytes());
        Ok(Self(bytes))
    }

    /// Length in bytes of the bucket name that follows the header.
    fn bucket_name_len(self) -> usize {
        usize::from(self.0[0] & Self::BUCKET_LEN_MASK) + 1
    }

    /// 40-bit object size.
    fn object_size(self) -> u64 {
        u64::from_be_bytes([
            0, 0, 0, self.0[1], self.0[2], self.0[3], self.0[4], self.0[5],
        ])
    }

    /// Length in bytes of the data section that follows the bucket name.
    fn data_len(self) -> usize {
        if self.is_empty_data() {
            return 0;
        }
        (self.stored_data_len() + 1) as usize
    }

    /// Modification time field.
    fn mtime(self) -> u32 {
        u32::from_be_bytes([self.0[9], self.0[10], self.0[11], self.0[12]])
    }

    /// Cached-at time field.
    fn cached_at(self) -> u32 {
        u32::from_be_bytes([self.0[13], self.0[14], self.0[15], self.0[16]])
    }

    /// Validates and wraps raw header bytes.
    ///
    /// # Errors
    ///
    /// Returns `"Invalid header"` when the reserved bit is set, or when the empty flag is set
    /// together with non-zero length bits.
    fn from_bytes(bytes: [u8; 17]) -> Result<Self, &'static str> {
        if bytes[0] & Self::RESERVED_BIT != 0 {
            return Err("Invalid header");
        }
        let header = Self(bytes);
        // A non-empty header needs no range check: the stored length is 24 bits wide, so
        // `stored + 1` can never exceed `1 << 24`.
        if header.is_empty_data() && header.stored_data_len() != 0 {
            return Err("Invalid header");
        }
        Ok(header)
    }

    /// Returns the raw header bytes.
    fn to_bytes(self) -> [u8; 17] {
        self.0
    }

    /// Whether the "data is empty" flag is set.
    fn is_empty_data(self) -> bool {
        self.0[0] & Self::EMPTY_DATA_BIT != 0
    }

    /// Raw 24-bit `data_len - 1` field.
    fn stored_data_len(self) -> u32 {
        u32::from_be_bytes([0, self.0[6], self.0[7], self.0[8]])
    }
}
impl foyer::Code for CacheValue {
    /// Writes the 17-byte header, then the bucket name bytes, then the data bytes.
    ///
    /// Fails with an `InvalidData` I/O error when a field does not fit the header, and
    /// propagates any writer error.
    fn encode(&self, writer: &mut impl std::io::Write) -> foyer::Result<()> {
        let header = CacheValueHeader::new(
            self.bucket.len(),
            self.object_size,
            self.mtime,
            self.data.len(),
            self.cached_at,
        )
        .map_err(|reason| std::io::Error::new(std::io::ErrorKind::InvalidData, reason))?;

        writer.write_all(&header.to_bytes())?;
        writer.write_all(self.bucket.as_bytes())?;
        writer.write_all(&self.data)?;
        Ok(())
    }

    /// Reads a value previously written by `encode`.
    ///
    /// Fails with an `InvalidData` I/O error on a malformed header, invalid UTF-8, or a bucket
    /// name rejected by its constructor, and propagates any reader error (including a short
    /// read).
    fn decode(reader: &mut impl std::io::Read) -> foyer::Result<Self>
    where
        Self: Sized,
    {
        /// Wraps `cause` in an `InvalidData` I/O error.
        fn invalid_data<E>(cause: E) -> std::io::Error
        where
            E: Into<Box<dyn std::error::Error + Send + Sync>>,
        {
            std::io::Error::new(std::io::ErrorKind::InvalidData, cause)
        }

        let mut raw_header = [0u8; 17];
        reader.read_exact(&mut raw_header)?;
        let header = CacheValueHeader::from_bytes(raw_header).map_err(invalid_data)?;

        let mut raw_bucket = BytesMut::zeroed(header.bucket_name_len());
        reader.read_exact(&mut raw_bucket)?;
        let bucket_text = CompactString::from_utf8(raw_bucket)
            .map_err(|_| invalid_data("Invalid UTF-8 in bucket name"))?;
        let bucket = BucketName::new(bucket_text).map_err(invalid_data)?;

        let mut payload = BytesMut::zeroed(header.data_len());
        reader.read_exact(&mut payload)?;
        let data = payload.freeze();

        Ok(Self {
            bucket,
            mtime: header.mtime(),
            data,
            object_size: header.object_size(),
            cached_at: header.cached_at(),
        })
    }

    /// Number of bytes `encode` writes: header plus bucket name plus data.
    fn estimated_size(&self) -> usize {
        let header_len = size_of::<CacheValueHeader>();
        let body_len = self.bucket.len() + self.data.len();
        header_len + body_len
    }
}
#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use super::*;
    use crate::service::PAGE_SIZE;

    /// Key header accessors return what was packed, and out-of-range lengths are rejected.
    #[test]
    fn test_cache_key_header() {
        let packed = CacheKeyHeader::new(255, 63, 1024, 65535).unwrap();
        assert_eq!(packed.version(), 255);
        assert_eq!(packed.kind_len(), 63);
        assert_eq!(packed.key_len(), 1024);
        assert_eq!(packed.page_id(), 65535);
        let reparsed = CacheKeyHeader::from_bytes(packed.to_bytes()).unwrap();
        assert_eq!(packed, reparsed);

        // Upper bound of the kind length.
        let widest = CacheKeyHeader::new(1, 64, 1024, 65535).unwrap();
        assert_eq!(widest.kind_len(), 64);
        let widest_reparsed = CacheKeyHeader::from_bytes(widest.to_bytes()).unwrap();
        assert_eq!(widest_reparsed.kind_len(), 64);

        // (kind_len, key_len) pairs that must be refused.
        for (kind_len, key_len) in [(0, 0), (65, 0), (1, 0), (1, 1025)] {
            assert!(CacheKeyHeader::new(0, kind_len, key_len, 0).is_err());
        }
    }

    /// Value header accessors return what was packed, and out-of-range fields are rejected.
    #[test]
    fn test_cache_value_header() {
        let packed =
            CacheValueHeader::new(63, (1 << 40) - 1, u32::MAX, (1 << 24) - 1, 1_700_000_000)
                .unwrap();
        assert_eq!(packed.bucket_name_len(), 63);
        assert_eq!(packed.object_size(), (1 << 40) - 1);
        assert_eq!(packed.mtime(), u32::MAX);
        assert_eq!(packed.data_len(), (1 << 24) - 1);
        assert_eq!(packed.cached_at(), 1_700_000_000);
        let reparsed = CacheValueHeader::from_bytes(packed.to_bytes()).unwrap();
        assert_eq!(packed.0, reparsed.0);

        // Upper bound of the bucket name length.
        let widest =
            CacheValueHeader::new(64, (1 << 40) - 1, u32::MAX, (1 << 24) - 1, 1_700_000_000)
                .unwrap();
        assert_eq!(widest.bucket_name_len(), 64);
        let widest_reparsed = CacheValueHeader::from_bytes(widest.to_bytes()).unwrap();
        assert_eq!(widest_reparsed.bucket_name_len(), 64);

        // (bucket_name_len, object_size, data_len) triples that must be refused.
        for (bucket_name_len, object_size, data_len) in [
            (0, 0, 0),
            (65, 0, 0),
            (1, 1 << 40, 0),
            (1, 0, (1 << 24) + 1),
        ] {
            assert!(CacheValueHeader::new(bucket_name_len, object_size, 0, data_len, 0).is_err());
        }
    }

    /// A full page of data fits in the header's length field.
    #[test]
    fn test_cache_value_header_supports_page_size() {
        let page_len = PAGE_SIZE as usize;
        let header = CacheValueHeader::new(1, PAGE_SIZE, 0, page_len, 0).unwrap();
        assert_eq!(header.data_len(), page_len);
    }

    /// Zero-length data is representable and survives a byte round trip.
    #[test]
    fn test_cache_value_header_supports_empty_data() {
        let header = CacheValueHeader::new(1, 0, 0, 0, 0).unwrap();
        assert_eq!(header.data_len(), 0);
        let reparsed = CacheValueHeader::from_bytes(header.to_bytes()).unwrap();
        assert_eq!(reparsed.data_len(), 0);
    }

    /// A key decodes back to an equal key.
    #[test]
    fn test_cache_key_encode_decode() {
        let original = CacheKey {
            kind: ObjectKind::new("test-kind").unwrap(),
            object: ObjectKey::new("test/object.txt").unwrap(),
            page_id: 42,
        };

        let mut wire = Vec::new();
        original.encode(&mut wire).unwrap();
        let mut cursor = std::io::Cursor::new(wire);
        let roundtripped = CacheKey::decode(&mut cursor).unwrap();

        assert_eq!(original, roundtripped);
    }

    /// A value decodes back with every field intact.
    #[test]
    fn test_cache_value_encode_decode() {
        let original = CacheValue {
            bucket: BucketName::new("test-bucket").unwrap(),
            mtime: 1_234_567_890,
            object_size: 9_876_543_210,
            data: bytes::Bytes::from(vec![1, 2, 3, 4, 5]),
            cached_at: 1_700_000_000,
        };

        let mut wire = Vec::new();
        original.encode(&mut wire).unwrap();
        let mut cursor = std::io::Cursor::new(wire);
        let roundtripped = CacheValue::decode(&mut cursor).unwrap();

        assert_eq!(original.bucket, roundtripped.bucket);
        assert_eq!(original.mtime, roundtripped.mtime);
        assert_eq!(original.object_size, roundtripped.object_size);
        assert_eq!(original.data, roundtripped.data);
        assert_eq!(original.cached_at, roundtripped.cached_at);
    }

    /// 64-byte kind and bucket names (the header maximum) round-trip.
    #[test]
    fn test_max_length_bucket_and_kind() {
        let bucket_64 = "a".repeat(64);
        let kind_64 = "b".repeat(64);
        let object_100 = "c".repeat(100);

        let cache_key = CacheKey {
            kind: ObjectKind::new(kind_64).unwrap(),
            object: ObjectKey::new(object_100).unwrap(),
            page_id: 42,
        };
        let mut key_wire = Vec::new();
        cache_key.encode(&mut key_wire).unwrap();
        let mut key_cursor = std::io::Cursor::new(key_wire);
        let decoded_key = CacheKey::decode(&mut key_cursor).unwrap();
        assert_eq!(cache_key.kind.len(), 64);
        assert_eq!(decoded_key.kind.len(), 64);
        assert_eq!(cache_key, decoded_key);

        let cache_value = CacheValue {
            bucket: BucketName::new(bucket_64).unwrap(),
            mtime: 1_234_567_890,
            object_size: 9_876_543_210,
            data: bytes::Bytes::from(vec![1, 2, 3, 4, 5]),
            cached_at: 1_700_000_000,
        };
        let mut value_wire = Vec::new();
        cache_value.encode(&mut value_wire).unwrap();
        let mut value_cursor = std::io::Cursor::new(value_wire);
        let decoded_value = CacheValue::decode(&mut value_cursor).unwrap();
        assert_eq!(cache_value.bucket.len(), 64);
        assert_eq!(decoded_value.bucket.len(), 64);
        assert_eq!(cache_value.bucket, decoded_value.bucket);
    }

    proptest! {
        /// Every in-range key header survives `to_bytes` / `from_bytes`.
        #[test]
        fn prop_cache_key_header_roundtrip(
            version in 0u8..=255,
            kind_len in 1usize..=64,
            key_len in 1usize..=1024,
            page_id in 0u16..=u16::MAX
        ) {
            let header = CacheKeyHeader::new(version, kind_len, key_len, page_id).unwrap();
            let decoded = CacheKeyHeader::from_bytes(header.to_bytes()).unwrap();

            prop_assert_eq!(header.version(), decoded.version());
            prop_assert_eq!(header.kind_len(), decoded.kind_len());
            prop_assert_eq!(header.key_len(), decoded.key_len());
            prop_assert_eq!(header.page_id(), decoded.page_id());
        }

        /// Every in-range value header survives `to_bytes` / `from_bytes`.
        #[test]
        fn prop_cache_value_header_roundtrip(
            bucket_name_len in 1usize..=64,
            object_size in 0u64..(1u64 << 40),
            mtime in 0u32..=u32::MAX,
            data_len in 0usize..=(1 << 24),
            cached_at in 0u32..=u32::MAX
        ) {
            let header =
                CacheValueHeader::new(bucket_name_len, object_size, mtime, data_len, cached_at)
                    .unwrap();
            let decoded = CacheValueHeader::from_bytes(header.to_bytes()).unwrap();

            prop_assert_eq!(header.bucket_name_len(), decoded.bucket_name_len());
            prop_assert_eq!(header.object_size(), decoded.object_size());
            prop_assert_eq!(header.mtime(), decoded.mtime());
            prop_assert_eq!(header.data_len(), decoded.data_len());
            prop_assert_eq!(header.cached_at(), decoded.cached_at());
        }

        /// Keys round-trip, and `estimated_size` equals the encoded length.
        #[test]
        fn prop_cache_key_roundtrip(
            kind in "[a-z0-9.-]{1,63}",
            object in "[a-zA-Z0-9/_.-]{1,1000}",
            page_id in 0u16..=u16::MAX
        ) {
            // Skip generated strings the domain constructors refuse.
            let kind = ObjectKind::new(kind);
            prop_assume!(kind.is_ok());
            let object = ObjectKey::new(object);
            prop_assume!(object.is_ok());

            let key = CacheKey {
                kind: kind.unwrap(),
                object: object.unwrap(),
                page_id,
            };

            let mut wire = Vec::new();
            key.encode(&mut wire).unwrap();
            prop_assert_eq!(key.estimated_size(), wire.len());

            let mut cursor = std::io::Cursor::new(wire);
            let decoded = CacheKey::decode(&mut cursor).unwrap();
            prop_assert_eq!(key, decoded);
        }

        /// Values round-trip, and `estimated_size` equals the encoded length.
        #[test]
        fn prop_cache_value_roundtrip(
            bucket_name in "[a-z0-9.-]{3,63}",
            mtime in 0u32..=u32::MAX,
            object_size in 0u64..(1u64 << 40),
            data in prop::collection::vec(0u8..=255, 0..1000),
            cached_at in 0u32..=u32::MAX
        ) {
            // Skip generated names the bucket constructor refuses.
            let bucket = BucketName::new(bucket_name);
            prop_assume!(bucket.is_ok());

            let value = CacheValue {
                bucket: bucket.unwrap(),
                mtime,
                object_size,
                data: bytes::Bytes::from(data),
                cached_at,
            };

            let mut wire = Vec::new();
            value.encode(&mut wire).unwrap();
            prop_assert_eq!(value.estimated_size(), wire.len());

            let mut cursor = std::io::Cursor::new(wire);
            let decoded = CacheValue::decode(&mut cursor).unwrap();
            prop_assert_eq!(value.bucket, decoded.bucket);
            prop_assert_eq!(value.mtime, decoded.mtime);
            prop_assert_eq!(value.object_size, decoded.object_size);
            prop_assert_eq!(value.data, decoded.data);
            prop_assert_eq!(value.cached_at, decoded.cached_at);
        }
    }

    /// Malformed key encodings are rejected rather than decoded.
    #[test]
    fn test_cache_key_decode_errors() {
        // Fewer bytes than the 5-byte header.
        let truncated = vec![0xFF, 0xFF, 0xFF, 0xFF];
        let mut cursor = std::io::Cursor::new(truncated);
        assert!(CacheKey::decode(&mut cursor).is_err());

        // Header carrying version 0 instead of `CacheKey::VERSION`.
        let header = CacheKeyHeader::new(0, 4, 4, 0).unwrap();
        let mut wire = Vec::new();
        wire.extend_from_slice(&header.to_bytes());
        wire.extend_from_slice(b"kind");
        wire.extend_from_slice(b"test");
        let mut cursor = std::io::Cursor::new(wire);
        assert!(CacheKey::decode(&mut cursor).is_err());

        // Kind bytes that are not valid UTF-8.
        let header = CacheKeyHeader::new(CacheKey::VERSION, 4, 4, 0).unwrap();
        let mut wire = Vec::new();
        wire.extend_from_slice(&header.to_bytes());
        wire.extend_from_slice(&[0xFF, 0xFF, 0xFF, 0xFF]);
        wire.extend_from_slice(b"test");
        let mut cursor = std::io::Cursor::new(wire);
        assert!(CacheKey::decode(&mut cursor).is_err());

        // Object key bytes that are not valid UTF-8.
        let header = CacheKeyHeader::new(CacheKey::VERSION, 4, 4, 0).unwrap();
        let mut wire = Vec::new();
        wire.extend_from_slice(&header.to_bytes());
        wire.extend_from_slice(b"kind");
        wire.extend_from_slice(&[0xFF, 0xFF, 0xFF, 0xFF]);
        let mut cursor = std::io::Cursor::new(wire);
        assert!(CacheKey::decode(&mut cursor).is_err());
    }

    /// Malformed value encodings are rejected rather than decoded.
    #[test]
    fn test_cache_value_decode_errors() {
        // All bits set: the reserved header bit makes the header invalid.
        let all_ones = vec![0xFF; 17];
        let mut cursor = std::io::Cursor::new(all_ones);
        assert!(CacheValue::decode(&mut cursor).is_err());

        // Bucket name bytes that are not valid UTF-8.
        let header = CacheValueHeader::new(4, 0, 0, 0, 0).unwrap();
        let mut wire = Vec::new();
        wire.extend_from_slice(&header.to_bytes());
        wire.extend_from_slice(&[0xFF, 0xFF, 0xFF, 0xFF]);
        let mut cursor = std::io::Cursor::new(wire);
        assert!(CacheValue::decode(&mut cursor).is_err());
    }
}