use crate::backend::interface::{BackendKind, CacheBackend, CacheConnector, CacheReader, CacheWriter};
use crate::backend::score::BackendScore;
use crate::error::{CacheError, Result};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tracing::instrument;
/// A single backend in a cache chain, together with the metadata used to
/// order and describe it.
///
/// Clones share the same backend instance through the `Arc`.
#[derive(Clone)]
pub struct ChainLink {
/// Type-erased handle to the wrapped backend.
backend: Arc<dyn CacheBackend>,
/// Ordering priority; `ChainCacheBuilder::build` sorts links by this value, highest first.
score: u8,
/// Persistence flag recorded for this link (used by the persistent/non-persistent filters).
is_persistent: bool,
/// Static label; appears in stats output and in per-backend error reports.
name: &'static str,
}
impl ChainLink {
    /// Wraps `backend` in a link using caller-supplied metadata.
    ///
    /// `score`, `is_persistent` and `name` are stored exactly as given;
    /// nothing is queried from the backend.
    pub fn new<B>(backend: B, score: u8, is_persistent: bool, name: &'static str) -> Self
    where
        B: CacheBackend + BackendScore + 'static,
    {
        let backend: Arc<dyn CacheBackend> = Arc::new(backend);
        Self {
            backend,
            score,
            is_persistent,
            name,
        }
    }

    /// Wraps `backend` in a link, taking the score, persistence flag and
    /// name from the backend's own `score()`, `is_persistent()` and
    /// `backend_name()`.
    pub fn from_backend<B>(backend: B) -> Self
    where
        B: CacheBackend + BackendScore + 'static,
    {
        // Read the metadata first: `new` takes ownership of the backend.
        let (score, is_persistent, name) = (
            backend.score(),
            backend.is_persistent(),
            backend.backend_name(),
        );
        Self::new(backend, score, is_persistent, name)
    }

    /// Shared handle to the wrapped backend.
    pub fn backend(&self) -> &Arc<dyn CacheBackend> {
        &self.backend
    }

    /// Ordering priority recorded for this link.
    pub fn score(&self) -> u8 {
        self.score
    }

    /// Persistence flag recorded for this link.
    pub fn is_persistent(&self) -> bool {
        self.is_persistent
    }

    /// Static label recorded for this link.
    pub fn name(&self) -> &'static str {
        self.name
    }
}
impl std::fmt::Debug for ChainLink {
    /// Formats the link's metadata (`score`, `is_persistent`, `name`).
    /// The backend handle is left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            score,
            is_persistent,
            name,
            ..
        } = self;
        let mut out = f.debug_struct("ChainLink");
        out.field("score", score);
        out.field("is_persistent", is_persistent);
        out.field("name", name);
        out.finish()
    }
}
/// A cache composed of several backends ("links") that are consulted in order.
///
/// `ChainCacheBuilder::build` stores the links sorted by descending score.
pub struct ChainCache {
/// Backends in lookup order.
links: Vec<ChainLink>,
/// When true, a value found in a later link is copied into the links before it on read.
backfill_enabled: bool,
/// TTL used by writes that do not pass an explicit TTL.
default_ttl: Option<Duration>,
}
impl ChainCache {
    /// Creates a chain from `links` using the builder's defaults
    /// (backfill disabled, no default TTL).
    pub fn new(links: Vec<ChainLink>) -> Self {
        Self::builder().links(links).build()
    }

    /// Returns an empty [`ChainCacheBuilder`].
    pub fn builder() -> ChainCacheBuilder {
        ChainCacheBuilder::default()
    }

    /// Links in lookup order.
    pub fn links(&self) -> &[ChainLink] {
        &self.links
    }

    /// Number of links in the chain (not the number of cached entries).
    pub fn len(&self) -> usize {
        self.links.len()
    }

    /// Returns `true` when the chain has no links.
    pub fn is_empty(&self) -> bool {
        self.links.is_empty()
    }

    /// First link whose score equals `score`, if any.
    pub fn get_by_score(&self, score: u8) -> Option<&ChainLink> {
        self.links.iter().find(|link| link.score() == score)
    }

    /// First link of the chain; the builder sorts by descending score, so
    /// this is the link with the highest score.
    pub fn highest_score_backend(&self) -> Option<&ChainLink> {
        self.links.first()
    }

    /// Last link of the chain, i.e. the one with the lowest score.
    pub fn lowest_score_backend(&self) -> Option<&ChainLink> {
        self.links.last()
    }

    /// All links flagged as persistent, in chain order.
    pub fn persistent_backends(&self) -> Vec<&ChainLink> {
        self.links.iter().filter(|link| link.is_persistent()).collect()
    }

    /// All links not flagged as persistent, in chain order.
    pub fn non_persistent_backends(&self) -> Vec<&ChainLink> {
        self.links.iter().filter(|link| !link.is_persistent()).collect()
    }

    /// Looks `key` up in each link in order and returns the first hit.
    ///
    /// Reads are best-effort: a backend that returns an error is treated
    /// like a miss and the next link is tried, so `Ok(None)` is returned
    /// both when no link holds the key and when every link failed.
    /// When backfill is enabled and the hit came from a link other than the
    /// first, the value is copied into all earlier links before returning.
    #[instrument(skip(self), fields(key = %key))]
    async fn read_from_chain(&self, key: &str) -> Result<Option<Vec<u8>>> {
        for (index, link) in self.links.iter().enumerate() {
            match link.backend().get(key).await {
                Ok(Some(value)) => {
                    if self.backfill_enabled && index > 0 {
                        self.backfill_to_higher_backends(key, &value, index).await;
                    }
                    return Ok(Some(value));
                }
                // A miss and a backend error both fall through to the next link.
                Ok(None) | Err(_) => {}
            }
        }
        Ok(None)
    }

    /// Copies `value` into every link that precedes `from_index`.
    ///
    /// Backfill is best-effort: individual write failures are ignored so a
    /// successful read is never turned into an error. The chain's
    /// `default_ttl` is applied, matching what a regular write without an
    /// explicit TTL would use; the remaining TTL of the source entry is not
    /// propagated.
    async fn backfill_to_higher_backends(&self, key: &str, value: &[u8], from_index: usize) {
        for link in self.links.iter().take(from_index) {
            let _ = link
                .backend()
                .set(key, value.to_vec(), self.default_ttl)
                .await;
        }
    }

    /// Writes `value` under `key` to every link.
    ///
    /// `ttl` falls back to the chain's `default_ttl` when `None`.
    ///
    /// # Errors
    ///
    /// Returns `CacheError::Operation` only when every link failed; the
    /// message lists each failing backend with its error. A partial failure
    /// still returns `Ok(())`. An empty chain is a no-op.
    #[instrument(skip(self, value), fields(key = %key))]
    async fn write_to_all_backends(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<()> {
        // Splitting off the last link lets it take ownership of `value`,
        // saving one clone.
        let (last, rest) = match self.links.split_last() {
            Some(parts) => parts,
            None => return Ok(()),
        };
        let effective_ttl = ttl.or(self.default_ttl);
        let mut errors = Vec::new();
        for link in rest {
            if let Err(e) = link.backend().set(key, value.clone(), effective_ttl).await {
                errors.push((link.name(), e));
            }
        }
        if let Err(e) = last.backend().set(key, value, effective_ttl).await {
            errors.push((last.name(), e));
        }
        if errors.len() == self.links.len() {
            return Err(CacheError::Operation(format!(
                "All backends failed to write: {:?}",
                errors
            )));
        }
        Ok(())
    }

    /// Deletes `key` from every link.
    ///
    /// # Errors
    ///
    /// Returns `CacheError::Operation` only when every link failed; the
    /// message lists each failing backend with its error. An empty chain is
    /// a no-op.
    #[instrument(skip(self), fields(key = %key))]
    async fn delete_from_all_backends(&self, key: &str) -> Result<()> {
        // Without this guard zero failures out of zero links would be
        // reported as "all backends failed".
        if self.links.is_empty() {
            return Ok(());
        }
        let mut errors = Vec::new();
        for link in &self.links {
            if let Err(e) = link.backend().delete(key).await {
                errors.push((link.name(), e));
            }
        }
        if errors.len() == self.links.len() {
            return Err(CacheError::Operation(format!(
                "All backends failed to delete: {:?}",
                errors
            )));
        }
        Ok(())
    }
}
#[async_trait]
impl CacheReader for ChainCache {
    /// Returns the first value found for `key` along the chain, or `None`
    /// when the chain is empty or no link yields a value.
    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
        if self.links.is_empty() {
            Ok(None)
        } else {
            self.read_from_chain(key).await
        }
    }

    /// Returns `true` as soon as any link reports that `key` exists.
    /// Links that answer `false` or fail are skipped.
    async fn exists(&self, key: &str) -> Result<bool> {
        for link in &self.links {
            if matches!(link.backend().exists(key).await, Ok(true)) {
                return Ok(true);
            }
        }
        Ok(false)
    }

    /// Returns the first TTL any link reports for `key`.
    /// Links that report no TTL or fail are skipped.
    async fn ttl(&self, key: &str) -> Result<Option<Duration>> {
        for link in &self.links {
            if let Ok(Some(remaining)) = link.backend().ttl(key).await {
                return Ok(Some(remaining));
            }
        }
        Ok(None)
    }

    /// Entry count reported by the first link only; `0` for an empty chain.
    async fn len(&self) -> Result<u64> {
        match self.links.first() {
            Some(front) => front.backend().len().await,
            None => Ok(0),
        }
    }

    /// Emptiness as reported by the first link only; `true` for an empty chain.
    async fn is_empty(&self) -> Result<bool> {
        match self.links.first() {
            Some(front) => front.backend().is_empty().await,
            None => Ok(true),
        }
    }

    /// Capacity reported by the first link only; `0` for an empty chain.
    async fn capacity(&self) -> Result<u64> {
        match self.links.first() {
            Some(front) => front.backend().capacity().await,
            None => Ok(0),
        }
    }

    /// Describes the chain itself: its type, the number of links, and the
    /// name and score of each link keyed by position.
    async fn stats(&self) -> Result<HashMap<String, String>> {
        let mut report = HashMap::new();
        report.insert("type".to_string(), "chain".to_string());
        report.insert("backend_count".to_string(), self.links.len().to_string());
        for (position, link) in self.links.iter().enumerate() {
            let name_key = format!("backend_{}_name", position);
            let score_key = format!("backend_{}_score", position);
            report.insert(name_key, link.name().to_string());
            report.insert(score_key, link.score().to_string());
        }
        Ok(report)
    }
}
#[async_trait]
impl CacheWriter for ChainCache {
    /// Writes `value` to every link.
    ///
    /// # Errors
    ///
    /// Fails when the chain has no links, or with whatever
    /// `write_to_all_backends` reports.
    async fn set(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<()> {
        if self.links.is_empty() {
            Err(CacheError::Operation("Chain has no backends".to_string()))
        } else {
            self.write_to_all_backends(key, value, ttl).await
        }
    }

    /// Deletes `key` from every link; a no-op on an empty chain.
    async fn delete(&self, key: &str) -> Result<()> {
        if self.links.is_empty() {
            Ok(())
        } else {
            self.delete_from_all_backends(key).await
        }
    }

    /// Clears every link.
    ///
    /// # Errors
    ///
    /// Fails only when the chain is non-empty and every link failed to clear.
    async fn clear(&self) -> Result<()> {
        let mut failures = Vec::new();
        for link in &self.links {
            if let Err(e) = link.backend().clear().await {
                failures.push((link.name(), e));
            }
        }
        let every_link_failed = !self.links.is_empty() && failures.len() == self.links.len();
        if every_link_failed {
            return Err(CacheError::Operation(format!(
                "All backends failed to clear: {:?}",
                failures
            )));
        }
        Ok(())
    }

    /// Applies `ttl` to `key` on every link and reports whether at least one
    /// link answered `Ok(true)`. Links that answer `false` or fail are ignored.
    async fn expire(&self, key: &str, ttl: Duration) -> Result<bool> {
        let mut any_success = false;
        for link in &self.links {
            if let Ok(true) = link.backend().expire(key, ttl).await {
                any_success = true;
            }
        }
        Ok(any_success)
    }
}
#[async_trait]
impl CacheConnector for ChainCache {
    /// Checks every link in chain order and stops at the first failure,
    /// returning that link's error. An empty chain is reported as healthy.
    async fn health_check(&self) -> Result<()> {
        // Iterating an empty chain performs no checks and falls through to `Ok`.
        for link in self.links.iter() {
            link.backend().health_check().await?;
        }
        Ok(())
    }

    /// Shuts down every link, one after another, in chain order.
    async fn shutdown(&self) {
        for link in self.links.iter() {
            link.backend().shutdown().await;
        }
    }

    /// Identifies this cache as a chain.
    fn backend_kind(&self) -> BackendKind {
        BackendKind::Chain
    }
}
/// Builder for [`ChainCache`].
///
/// The derived `Default` starts with no links, backfill disabled and no default TTL.
#[derive(Default)]
pub struct ChainCacheBuilder {
/// Links collected so far, in insertion order; sorted by score in `build`.
links: Vec<ChainLink>,
/// Whether the built chain copies values found in later links into earlier ones on read.
backfill_enabled: bool,
/// TTL the built chain uses for writes without an explicit TTL.
default_ttl: Option<Duration>,
}
impl ChainCacheBuilder {
    /// Adds one link to the chain.
    pub fn link(mut self, link: ChainLink) -> Self {
        self.links.push(link);
        self
    }

    /// Adds several links to the chain, keeping any that were added earlier.
    pub fn links(mut self, links: Vec<ChainLink>) -> Self {
        self.links.extend(links);
        self
    }

    /// Adds a backend, deriving its link metadata via `ChainLink::from_backend`.
    pub fn backend<B>(self, backend: B) -> Self
    where
        B: CacheBackend + BackendScore + 'static,
    {
        let link = ChainLink::from_backend(backend);
        self.link(link)
    }

    /// Sets the TTL used for writes that do not specify one.
    pub fn default_time_to_live(self, ttl: Duration) -> Self {
        Self {
            default_ttl: Some(ttl),
            ..self
        }
    }

    /// Turns read-time backfill on.
    pub fn enable_backfill(self) -> Self {
        Self {
            backfill_enabled: true,
            ..self
        }
    }

    /// Turns read-time backfill off.
    pub fn disable_backfill(self) -> Self {
        Self {
            backfill_enabled: false,
            ..self
        }
    }

    /// Finishes the chain, ordering links by descending score.
    ///
    /// The sort is stable, so links with equal scores keep their insertion order.
    pub fn build(self) -> ChainCache {
        let Self {
            mut links,
            backfill_enabled,
            default_ttl,
        } = self;
        links.sort_by(|a, b| b.score().cmp(&a.score()));
        ChainCache {
            links,
            backfill_enabled,
            default_ttl,
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `ChainLink`, `ChainCache` and `ChainCacheBuilder`.
    //!
    //! Tests that need to observe individual tiers keep a clone of each
    //! `MokaMemoryBackend` and inspect it directly after going through the chain.

    use super::*;
    use crate::backend::MokaMemoryBackend;
    use crate::testing::mock::MockBackend;

    // `from_backend` takes score, persistence and name from the backend.
    #[test]
    fn test_chain_link_creation() {
        let backend = MockBackend::new("test", 50, false);
        let link = ChainLink::from_backend(backend);
        assert_eq!(link.score(), 50);
        assert!(!link.is_persistent());
        assert_eq!(link.name(), "test");
    }

    // Links are ordered by descending score regardless of insertion order.
    #[test]
    fn test_chain_cache_builder() {
        let high = MockBackend::new("high", 100, false);
        let low = MockBackend::new("low", 50, true);
        let chain = ChainCache::builder()
            .backend(low)
            .backend(high)
            .enable_backfill()
            .build();
        assert_eq!(chain.links().len(), 2);
        assert_eq!(chain.links()[0].score(), 100);
        assert_eq!(chain.links()[1].score(), 50);
    }

    // A value written through the chain can be read back through it.
    #[tokio::test]
    async fn test_chain_cache_get_set() {
        let high = MockBackend::new("high", 100, false);
        let low = MockBackend::new("low", 50, true);
        let chain = ChainCache::builder().backend(high).backend(low).build();
        chain.set("key", b"value".to_vec(), None).await.unwrap();
        let value = chain.get("key").await.unwrap();
        assert_eq!(value, Some(b"value".to_vec()));
    }

    // After a delete the key is no longer reported as existing.
    #[tokio::test]
    async fn test_chain_cache_delete() {
        let high = MockBackend::new("high", 100, false);
        let low = MockBackend::new("low", 50, true);
        let chain = ChainCache::builder().backend(high).backend(low).build();
        chain.set("key", b"value".to_vec(), None).await.unwrap();
        chain.delete("key").await.unwrap();
        let exists = chain.exists("key").await.unwrap();
        assert!(!exists);
    }

    // Round trip with backfill enabled and explicitly constructed links.
    #[tokio::test]
    async fn test_chain_cache_backfill() {
        let chain = ChainCache::builder()
            .link(ChainLink::new(MockBackend::new("high", 100, false), 100, false, "high"))
            .link(ChainLink::new(MockBackend::new("low", 50, true), 50, true, "low"))
            .enable_backfill()
            .build();
        chain.set("key", b"value".to_vec(), None).await.unwrap();
        let value = chain.get("key").await.unwrap();
        assert_eq!(value, Some(b"value".to_vec()));
    }

    // Reads on a chain without links report a miss rather than an error.
    #[tokio::test]
    async fn test_empty_chain() {
        let chain = ChainCache::new(vec![]);
        let value = chain.get("key").await.unwrap();
        assert!(value.is_none());
        let exists = chain.exists("key").await.unwrap();
        assert!(!exists);
    }

    // `new` stores the supplied metadata, not the backend's own.
    #[test]
    fn test_chain_link_new_constructor() {
        let backend = MokaMemoryBackend::new();
        let link = ChainLink::new(backend, 75, true, "custom");
        assert_eq!(link.score(), 75);
        assert!(link.is_persistent());
        assert_eq!(link.name(), "custom");
        let _backend_ref = link.backend();
    }

    // Metadata the Moka backend reports about itself.
    #[test]
    fn test_chain_link_from_backend_moka() {
        let backend = MokaMemoryBackend::new();
        let link = ChainLink::from_backend(backend);
        assert_eq!(link.score(), 100);
        assert!(!link.is_persistent());
        assert_eq!(link.name(), "moka");
    }

    // Debug output carries the type name and the metadata values.
    #[test]
    fn test_chain_link_debug() {
        let backend = MokaMemoryBackend::new();
        let link = ChainLink::new(backend, 80, true, "dbg");
        let debug_str = format!("{:?}", link);
        assert!(debug_str.contains("ChainLink"));
        assert!(debug_str.contains("80"));
        assert!(debug_str.contains("dbg"));
    }

    #[test]
    fn test_chain_cache_new_constructor() {
        let link = ChainLink::from_backend(MokaMemoryBackend::new());
        let chain = ChainCache::new(vec![link]);
        assert_eq!(chain.len(), 1);
        assert!(!chain.is_empty());
    }

    // The inherent `len` / `is_empty` count links, not cached entries.
    #[test]
    fn test_chain_cache_len_is_empty() {
        let empty = ChainCache::new(vec![]);
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
        let chain = ChainCache::builder().backend(MokaMemoryBackend::new()).build();
        assert!(!chain.is_empty());
        assert_eq!(chain.len(), 1);
    }

    #[test]
    fn test_chain_cache_get_by_score() {
        let chain = ChainCache::builder()
            .link(ChainLink::new(MokaMemoryBackend::new(), 100, false, "high"))
            .link(ChainLink::new(MokaMemoryBackend::new(), 50, true, "low"))
            .build();
        assert!(chain.get_by_score(100).is_some());
        assert!(chain.get_by_score(50).is_some());
        assert!(chain.get_by_score(75).is_none());
    }

    // Links are inserted low-then-high to prove the accessors rely on sorting.
    #[test]
    fn test_chain_cache_highest_lowest_backend() {
        let chain = ChainCache::builder()
            .link(ChainLink::new(MokaMemoryBackend::new(), 50, true, "low"))
            .link(ChainLink::new(MokaMemoryBackend::new(), 100, false, "high"))
            .build();
        let highest = chain.highest_score_backend().unwrap();
        assert_eq!(highest.score(), 100);
        assert_eq!(highest.name(), "high");
        let lowest = chain.lowest_score_backend().unwrap();
        assert_eq!(lowest.score(), 50);
        assert_eq!(lowest.name(), "low");
    }

    #[test]
    fn test_chain_cache_highest_lowest_empty() {
        let chain = ChainCache::new(vec![]);
        assert!(chain.highest_score_backend().is_none());
        assert!(chain.lowest_score_backend().is_none());
    }

    #[test]
    fn test_chain_cache_persistent_filters() {
        let chain = ChainCache::builder()
            .link(ChainLink::new(MokaMemoryBackend::new(), 100, false, "high"))
            .link(ChainLink::new(MokaMemoryBackend::new(), 50, true, "low"))
            .build();
        let persistent = chain.persistent_backends();
        assert_eq!(persistent.len(), 1);
        assert_eq!(persistent[0].name(), "low");
        let non_persistent = chain.non_persistent_backends();
        assert_eq!(non_persistent.len(), 1);
        assert_eq!(non_persistent[0].name(), "high");
    }

    #[test]
    fn test_chain_cache_links_accessor() {
        let chain = ChainCache::builder()
            .link(ChainLink::new(MokaMemoryBackend::new(), 100, false, "high"))
            .build();
        let links = chain.links();
        assert_eq!(links.len(), 1);
        assert_eq!(links[0].name(), "high");
    }

    #[test]
    fn test_builder_link_method() {
        let link = ChainLink::new(MokaMemoryBackend::new(), 100, false, "moka");
        let chain = ChainCache::builder().link(link).build();
        assert_eq!(chain.len(), 1);
    }

    #[test]
    fn test_builder_links_method() {
        let links = vec![
            ChainLink::new(MokaMemoryBackend::new(), 100, false, "high"),
            ChainLink::new(MokaMemoryBackend::new(), 50, true, "low"),
        ];
        let chain = ChainCache::builder().links(links).build();
        assert_eq!(chain.len(), 2);
        assert_eq!(chain.links()[0].score(), 100);
        assert_eq!(chain.links()[1].score(), 50);
    }

    // A chain with a default TTL still serves a freshly written value.
    #[tokio::test]
    async fn test_builder_default_time_to_live() {
        let chain = ChainCache::builder()
            .backend(MokaMemoryBackend::new())
            .default_time_to_live(Duration::from_secs(60))
            .build();
        chain.set("key", b"value".to_vec(), None).await.unwrap();
        let value = chain.get("key").await.unwrap();
        assert_eq!(value, Some(b"value".to_vec()));
    }

    // `disable_backfill` must win over an earlier `enable_backfill`: a read
    // served by the lower link must leave the higher link untouched.
    #[tokio::test]
    async fn test_builder_disable_backfill() {
        let high = MokaMemoryBackend::new();
        let low = MokaMemoryBackend::new();
        let high_ref = high.clone();
        let low_ref = low.clone();
        let chain = ChainCache::builder()
            .link(ChainLink::new(high, 100, false, "high"))
            .link(ChainLink::new(low, 50, true, "low"))
            .enable_backfill()
            .disable_backfill()
            .build();
        assert_eq!(chain.len(), 2);
        low_ref.set("key", b"value".to_vec(), None).await.unwrap();
        let value = chain.get("key").await.unwrap();
        assert_eq!(value, Some(b"value".to_vec()));
        assert!(high_ref.get("key").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_chain_cache_get_bytes_set_bytes() {
        use crate::UnifiedCache;
        let chain = ChainCache::builder().backend(MokaMemoryBackend::new()).build();
        chain.set_bytes("key", b"value".to_vec(), None).await.unwrap();
        let value = chain.get_bytes("key").await.unwrap();
        assert_eq!(value, Some(b"value".to_vec()));
    }

    #[tokio::test]
    async fn test_chain_cache_get_bytes_missing() {
        use crate::UnifiedCache;
        let chain = ChainCache::builder().backend(MokaMemoryBackend::new()).build();
        let value = chain.get_bytes("missing").await.unwrap();
        assert!(value.is_none());
    }

    #[tokio::test]
    async fn test_chain_cache_clear() {
        let chain = ChainCache::builder().backend(MokaMemoryBackend::new()).build();
        chain.set("key", b"value".to_vec(), None).await.unwrap();
        assert!(chain.exists("key").await.unwrap());
        chain.clear().await.unwrap();
        assert!(!chain.exists("key").await.unwrap());
    }

    // Clearing a chain without links is not an error.
    #[tokio::test]
    async fn test_chain_cache_clear_empty() {
        let chain = ChainCache::new(vec![]);
        assert!(chain.clear().await.is_ok());
    }

    // Pins the current result of `expire` on a Moka-backed chain: `false`,
    // even for a key that was just written.
    #[tokio::test]
    async fn test_chain_cache_expire() {
        let chain = ChainCache::builder().backend(MokaMemoryBackend::new()).build();
        chain.set("key", b"value".to_vec(), None).await.unwrap();
        let result = chain.expire("key", Duration::from_secs(60)).await.unwrap();
        assert!(!result);
    }

    #[tokio::test]
    async fn test_chain_cache_expire_missing_key() {
        let chain = ChainCache::builder().backend(MokaMemoryBackend::new()).build();
        let result = chain.expire("missing", Duration::from_secs(60)).await.unwrap();
        assert!(!result);
    }

    // Unlike reads and deletes, a write on a chain without links is an error.
    #[tokio::test]
    async fn test_chain_cache_set_empty_chain_error() {
        let chain = ChainCache::new(vec![]);
        let result = chain.set("key", b"value".to_vec(), None).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_chain_cache_delete_empty_chain() {
        let chain = ChainCache::new(vec![]);
        assert!(chain.delete("key").await.is_ok());
    }

    #[tokio::test]
    async fn test_chain_cache_set_with_explicit_ttl() {
        let chain = ChainCache::builder().backend(MokaMemoryBackend::new()).build();
        chain
            .set("key", b"value".to_vec(), Some(Duration::from_secs(60)))
            .await
            .unwrap();
        let value = chain.get("key").await.unwrap();
        assert_eq!(value, Some(b"value".to_vec()));
    }

    // A write through the chain lands in every link.
    #[tokio::test]
    async fn test_chain_cache_multi_backend_set_writes_all() {
        let high = MokaMemoryBackend::new();
        let low = MokaMemoryBackend::new();
        let high_ref = high.clone();
        let low_ref = low.clone();
        let chain = ChainCache::builder()
            .link(ChainLink::new(high, 100, false, "high"))
            .link(ChainLink::new(low, 50, true, "low"))
            .build();
        chain.set("key", b"value".to_vec(), None).await.unwrap();
        assert_eq!(high_ref.get("key").await.unwrap(), Some(b"value".to_vec()));
        assert_eq!(low_ref.get("key").await.unwrap(), Some(b"value".to_vec()));
    }

    // A delete through the chain removes the key from every link.
    #[tokio::test]
    async fn test_chain_cache_delete_removes_from_all() {
        let high = MokaMemoryBackend::new();
        let low = MokaMemoryBackend::new();
        let high_ref = high.clone();
        let low_ref = low.clone();
        let chain = ChainCache::builder()
            .link(ChainLink::new(high, 100, false, "high"))
            .link(ChainLink::new(low, 50, true, "low"))
            .build();
        chain.set("key", b"value".to_vec(), None).await.unwrap();
        chain.delete("key").await.unwrap();
        assert!(high_ref.get("key").await.unwrap().is_none());
        assert!(low_ref.get("key").await.unwrap().is_none());
    }

    // With backfill on, a hit in the lower link is copied into the higher one.
    #[tokio::test]
    async fn test_chain_cache_backfill_populates_higher() {
        let high = MokaMemoryBackend::new();
        let low = MokaMemoryBackend::new();
        let high_ref = high.clone();
        let low_ref = low.clone();
        let chain = ChainCache::builder()
            .link(ChainLink::new(high, 100, false, "high"))
            .link(ChainLink::new(low, 50, true, "low"))
            .enable_backfill()
            .build();
        // Seed only the lower link, bypassing the chain.
        low_ref.set("key", b"low_value".to_vec(), None).await.unwrap();
        assert!(high_ref.get("key").await.unwrap().is_none());
        let value = chain.get("key").await.unwrap();
        assert_eq!(value, Some(b"low_value".to_vec()));
        let high_value = high_ref.get("key").await.unwrap();
        assert_eq!(high_value, Some(b"low_value".to_vec()));
    }

    // With backfill left at its default (off), the higher link stays empty.
    #[tokio::test]
    async fn test_chain_cache_no_backfill_when_disabled() {
        let high = MokaMemoryBackend::new();
        let low = MokaMemoryBackend::new();
        let high_ref = high.clone();
        let low_ref = low.clone();
        let chain = ChainCache::builder()
            .link(ChainLink::new(high, 100, false, "high"))
            .link(ChainLink::new(low, 50, true, "low"))
            .build();
        low_ref.set("key", b"low_value".to_vec(), None).await.unwrap();
        let value = chain.get("key").await.unwrap();
        assert_eq!(value, Some(b"low_value".to_vec()));
        assert!(high_ref.get("key").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_chain_cache_ttl_len_capacity() {
        let chain = ChainCache::builder().backend(MokaMemoryBackend::new()).build();
        chain.set("key", b"value".to_vec(), None).await.unwrap();
        let ttl = chain.ttl("key").await.unwrap();
        assert!(ttl.is_none());
        // Fully qualified: the inherent `ChainCache::len` counts links instead.
        let len = CacheReader::len(&chain).await.unwrap();
        assert!(len <= 100, "len should be reasonable after single insert");
        let capacity = chain.capacity().await.unwrap();
        assert!(capacity > 0);
    }

    // Reader metrics on a chain without links fall back to neutral values.
    #[tokio::test]
    async fn test_chain_cache_reader_empty() {
        let chain = ChainCache::new(vec![]);
        assert_eq!(CacheReader::len(&chain).await.unwrap(), 0);
        assert!(CacheReader::is_empty(&chain).await.unwrap());
        assert_eq!(chain.capacity().await.unwrap(), 0);
        let ttl = chain.ttl("key").await.unwrap();
        assert!(ttl.is_none());
    }

    #[tokio::test]
    async fn test_chain_cache_stats() {
        let chain = ChainCache::builder()
            .link(ChainLink::new(MokaMemoryBackend::new(), 100, false, "high"))
            .link(ChainLink::new(MokaMemoryBackend::new(), 50, true, "low"))
            .build();
        let stats = chain.stats().await.unwrap();
        assert_eq!(stats.get("type"), Some(&"chain".to_string()));
        assert_eq!(stats.get("backend_count"), Some(&"2".to_string()));
        assert_eq!(stats.get("backend_0_name"), Some(&"high".to_string()));
        assert_eq!(stats.get("backend_0_score"), Some(&"100".to_string()));
        assert_eq!(stats.get("backend_1_name"), Some(&"low".to_string()));
        assert_eq!(stats.get("backend_1_score"), Some(&"50".to_string()));
    }

    #[tokio::test]
    async fn test_chain_cache_stats_empty() {
        let chain = ChainCache::new(vec![]);
        let stats = chain.stats().await.unwrap();
        assert_eq!(stats.get("type"), Some(&"chain".to_string()));
        assert_eq!(stats.get("backend_count"), Some(&"0".to_string()));
    }

    #[tokio::test]
    async fn test_chain_cache_health_check() {
        let chain = ChainCache::builder().backend(MokaMemoryBackend::new()).build();
        assert!(chain.health_check().await.is_ok());
    }

    #[tokio::test]
    async fn test_chain_cache_health_check_empty() {
        let chain = ChainCache::new(vec![]);
        assert!(chain.health_check().await.is_ok());
    }

    // Smoke test: shutdown completes without panicking.
    #[tokio::test]
    async fn test_chain_cache_shutdown() {
        let chain = ChainCache::builder().backend(MokaMemoryBackend::new()).build();
        chain.shutdown().await;
    }

    #[test]
    fn test_chain_cache_backend_kind() {
        let chain = ChainCache::builder().backend(MokaMemoryBackend::new()).build();
        assert_eq!(chain.backend_kind(), BackendKind::Chain);
    }
}