use async_trait::async_trait;
use blake3;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use super::bounded::BoundedCache;
use super::policy::{Cache, RidKeyed};
use crate::commit::TenantId;
/// Composite lookup key for one cached query result.
///
/// The key pairs the owning tenant and the requested `top_k` with three
/// 32-byte digests. [`QueryCacheKeyBuilder::build`] fills the digests from
/// BLAKE3 hashers, but every field is public, so callers may also construct
/// a key by hand.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct QueryCacheKey {
    /// Digest of the query text.
    pub query_hash: [u8; 32],
    /// Tenant the query was issued for.
    pub tenant_id: TenantId,
    /// Digest of the namespace the query targets.
    pub namespace_hash: [u8; 32],
    /// Number of results requested.
    pub top_k: u32,
    /// Digest of every additional parameter fed to the builder.
    pub params_hash: [u8; 32],
}
/// Incremental builder for [`QueryCacheKey`].
///
/// Keeps one running BLAKE3 hasher per digest field of the key, plus the
/// plain values (`tenant_id`, `top_k`) that are copied into the key as-is.
pub struct QueryCacheKeyBuilder {
    /// Tenant copied verbatim into the finished key.
    tenant_id: TenantId,
    /// Requested result count copied verbatim into the finished key.
    top_k: u32,
    /// Running hash of the query text.
    query_hasher: blake3::Hasher,
    /// Running hash of the namespace.
    namespace_hasher: blake3::Hasher,
    /// Running hash of all extra parameters, in the order they were supplied.
    params_hasher: blake3::Hasher,
}
impl QueryCacheKeyBuilder {
    /// Starts a key for `tenant_id` with empty query, namespace and parameter
    /// digests and `top_k` set to `0`.
    pub fn new(tenant_id: TenantId) -> Self {
        Self {
            tenant_id,
            query_hasher: blake3::Hasher::new(),
            namespace_hasher: blake3::Hasher::new(),
            params_hasher: blake3::Hasher::new(),
            top_k: 0,
        }
    }

    /// Feeds `text` into the query digest.
    ///
    /// Calls accumulate in a single hasher, so `.query("ab").query("c")`
    /// produces the same digest as `.query("abc")`.
    pub fn query(mut self, text: &str) -> Self {
        self.query_hasher.update(text.as_bytes());
        self
    }

    /// Feeds `namespace` into the namespace digest.
    ///
    /// As with [`Self::query`], repeated calls are equivalent to hashing the
    /// concatenation of their arguments.
    pub fn namespace(mut self, namespace: &str) -> Self {
        self.namespace_hasher.update(namespace.as_bytes());
        self
    }

    /// Sets the requested result count; the last call wins.
    pub fn top_k(mut self, top_k: u32) -> Self {
        self.top_k = top_k;
        self
    }

    /// Records the entity-expansion flag as the parameter `expand` with
    /// value `"1"` (true) or `"0"` (false).
    pub fn expand_entities(mut self, expand: bool) -> Self {
        let flag: &[u8] = if expand { b"1" } else { b"0" };
        Self::absorb_param(&mut self.params_hasher, b"expand", flag);
        self
    }

    /// Records the model identifier as the parameter `model`.
    pub fn model_version(mut self, model: &str) -> Self {
        Self::absorb_param(&mut self.params_hasher, b"model", model.as_bytes());
        self
    }

    /// Records an arbitrary `name`/`value` parameter.
    ///
    /// Parameters are hashed in call order, so supplying the same parameters
    /// in a different order yields a different key.
    pub fn param(mut self, name: &str, value: &str) -> Self {
        Self::absorb_param(&mut self.params_hasher, name.as_bytes(), value.as_bytes());
        self
    }

    /// Finalizes the three hashers and assembles the key.
    pub fn build(self) -> QueryCacheKey {
        QueryCacheKey {
            query_hash: *self.query_hasher.finalize().as_bytes(),
            tenant_id: self.tenant_id,
            namespace_hash: *self.namespace_hasher.finalize().as_bytes(),
            top_k: self.top_k,
            params_hash: *self.params_hasher.finalize().as_bytes(),
        }
    }

    /// Hashes one parameter as two length-prefixed fields (name, then value).
    ///
    /// A plain `name=value;` encoding is ambiguous: a value containing `;` or
    /// a name containing `=` can imitate a different parameter list and make
    /// two distinct queries share a cache key. Length prefixes make every
    /// field self-delimiting, so distinct parameter lists always feed
    /// distinct byte streams to the hasher.
    fn absorb_param(hasher: &mut blake3::Hasher, name: &[u8], value: &[u8]) {
        Self::absorb_field(hasher, name);
        Self::absorb_field(hasher, value);
    }

    /// Hashes `bytes` preceded by its length as a little-endian `u64`.
    fn absorb_field(hasher: &mut blake3::Hasher, bytes: &[u8]) {
        // `usize` is at most 64 bits wide on every target Rust supports, so
        // this widening cast never truncates.
        let len = bytes.len() as u64;
        hasher.update(&len.to_le_bytes());
        hasher.update(bytes);
    }
}
/// Value stored in the query-result cache.
///
/// `R` is the caller-defined row type.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CachedQueryResult<R> {
    /// Tenant that owns this result set; reported through `RidKeyed::tenant_id`.
    pub tenant_id: TenantId,
    /// Record identifiers reported through `RidKeyed::rids`.
    // NOTE(review): presumably one rid per entry in `results` — confirm with the producer.
    pub rids: Vec<String>,
    /// The cached rows themselves.
    pub results: Vec<R>,
}
/// Exposes the tenant and record ids of a cached result set.
impl<R> RidKeyed for CachedQueryResult<R> {
    /// Returns the tenant stored alongside the results.
    fn tenant_id(&self) -> TenantId {
        self.tenant_id
    }

    /// Returns an owned copy of the stored record ids.
    fn rids(&self) -> Vec<String> {
        self.rids.to_vec()
    }
}
/// Sizing and expiry settings for [`QueryResultCache`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct QueryResultCacheConfig {
    /// Entry limit handed to the underlying `BoundedCache`.
    pub max_entries: usize,
    /// Time-to-live in whole seconds, converted to a `Duration` when the
    /// cache is constructed.
    pub ttl_secs: u64,
}
impl Default for QueryResultCacheConfig {
fn default() -> Self {
Self {
max_entries: 10_000,
ttl_secs: 60,
}
}
}
/// Query-result cache keyed by [`QueryCacheKey`].
///
/// A thin wrapper over `BoundedCache` that fixes the key and value types;
/// `R` is the caller-defined row type stored inside [`CachedQueryResult`].
#[derive(Clone)]
pub struct QueryResultCache<R: Clone + Send + Sync + 'static> {
    /// Backing store that holds the entries.
    inner: BoundedCache<QueryCacheKey, CachedQueryResult<R>>,
}
impl<R> QueryResultCache<R>
where
R: Clone + Send + Sync + 'static,
{
pub fn new(config: QueryResultCacheConfig) -> Self {
Self {
inner: BoundedCache::new(
config.max_entries,
Some(Duration::from_secs(config.ttl_secs)),
),
}
}
pub fn config_max_entries(&self) -> usize {
self.inner.max_entries()
}
pub fn config_ttl(&self) -> Option<Duration> {
self.inner.ttl()
}
pub fn sweep_expired(&self) -> usize {
self.inner.sweep_expired()
}
}
/// Async `Cache` facade over the backing `BoundedCache`.
///
/// Each method forwards directly to the method of the same name on the
/// backing cache; none of those calls is awaited.
#[async_trait]
impl<R: Clone + Send + Sync + 'static> Cache<QueryCacheKey, CachedQueryResult<R>>
    for QueryResultCache<R>
{
    /// Looks up `key` in the backing cache.
    async fn get(&self, key: &QueryCacheKey) -> Option<CachedQueryResult<R>> {
        self.inner.get(key)
    }

    /// Stores `value` under `key` in the backing cache.
    async fn put(&self, key: QueryCacheKey, value: CachedQueryResult<R>) {
        self.inner.put(key, value);
    }

    /// Forwards an invalidation of `key` to the backing cache.
    async fn invalidate(&self, key: &QueryCacheKey) {
        self.inner.invalidate(key);
    }

    /// Forwards a clear request to the backing cache.
    async fn clear(&self) {
        self.inner.clear();
    }

    /// Number of entries reported by the backing cache.
    async fn len(&self) -> usize {
        self.inner.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::policy::{NoopTombstoneProvider, TombstoneAwareCache, TombstoneProvider};
    use std::collections::HashSet;
    use std::sync::Arc;
    use std::thread;

    /// Builds a tenant-1 key with fixed expansion and model parameters.
    fn key(text: &str, namespace: &str, top_k: u32) -> QueryCacheKey {
        let builder = QueryCacheKeyBuilder::new(TenantId::new(1))
            .query(text)
            .namespace(namespace)
            .top_k(top_k);
        builder
            .expand_entities(false)
            .model_version("minilm-l6-v2")
            .build()
    }

    /// Builds a tenant-1 result whose rows are `row[<rid>]` for each rid.
    fn cached_result(rids: Vec<&str>) -> CachedQueryResult<String> {
        let owned_rids: Vec<String> = rids.iter().map(|rid| rid.to_string()).collect();
        let rows: Vec<String> = rids.iter().map(|rid| format!("row[{rid}]")).collect();
        CachedQueryResult {
            tenant_id: TenantId::new(1),
            rids: owned_rids,
            results: rows,
        }
    }

    /// Default-configured cache over `String` rows.
    fn default_cache() -> QueryResultCache<String> {
        QueryResultCache::new(QueryResultCacheConfig::default())
    }

    #[test]
    fn key_builder_is_deterministic() {
        let first = key("query", "ns", 10);
        let second = key("query", "ns", 10);
        assert_eq!(first, second);
    }

    #[test]
    fn key_changes_on_query_text_change() {
        let a = key("query a", "ns", 10);
        let b = key("query b", "ns", 10);
        assert_ne!(a, b);
    }

    #[test]
    fn key_changes_on_namespace_change() {
        let a = key("query", "ns_a", 10);
        let b = key("query", "ns_b", 10);
        assert_ne!(a, b);
    }

    #[test]
    fn key_changes_on_top_k_change() {
        let ten = key("query", "ns", 10);
        let twenty = key("query", "ns", 20);
        assert_ne!(ten, twenty);
    }

    #[test]
    fn key_changes_on_tenant_change() {
        let tenant_one = QueryCacheKeyBuilder::new(TenantId::new(1))
            .query("q")
            .build();
        let tenant_two = QueryCacheKeyBuilder::new(TenantId::new(2))
            .query("q")
            .build();
        assert_ne!(tenant_one, tenant_two);
    }

    #[test]
    fn key_changes_on_expand_entities_flip() {
        let collapsed = QueryCacheKeyBuilder::new(TenantId::new(1))
            .query("q")
            .expand_entities(false)
            .build();
        let expanded = QueryCacheKeyBuilder::new(TenantId::new(1))
            .query("q")
            .expand_entities(true)
            .build();
        assert_ne!(collapsed, expanded);
    }

    #[test]
    fn key_changes_on_model_upgrade() {
        let old_model = QueryCacheKeyBuilder::new(TenantId::new(1))
            .query("q")
            .model_version("minilm-l6-v2")
            .build();
        let new_model = QueryCacheKeyBuilder::new(TenantId::new(1))
            .query("q")
            .model_version("bge-base")
            .build();
        assert_ne!(old_model, new_model);
    }

    #[tokio::test]
    async fn put_then_get_returns_cached_results() {
        let cache = default_cache();
        let k = key("q", "n", 10);
        cache.put(k, cached_result(vec!["r1", "r2", "r3"])).await;

        let hit = cache.get(&k).await.unwrap();
        assert_eq!(hit.rids, vec!["r1", "r2", "r3"]);
        assert_eq!(hit.results.len(), 3);
    }

    #[tokio::test]
    async fn miss_returns_none() {
        let cache = default_cache();
        assert!(cache.get(&key("q", "n", 10)).await.is_none());
    }

    #[tokio::test]
    async fn ttl_expires_entry() {
        let config = QueryResultCacheConfig {
            max_entries: 10,
            ttl_secs: 0,
        };
        let cache: QueryResultCache<String> = QueryResultCache::new(config);
        let k = key("q", "n", 10);
        cache.put(k, cached_result(vec!["r1"])).await;

        // Let some wall-clock time pass so the zero TTL has definitely elapsed.
        thread::sleep(Duration::from_millis(20));
        assert!(cache.get(&k).await.is_none(), "TTL=0 should always be expired");
    }

    #[tokio::test]
    async fn cached_query_result_implements_rid_keyed() {
        let value = cached_result(vec!["r1", "r2"]);
        // Go through the trait explicitly so the impl, not the field, is exercised.
        assert_eq!(RidKeyed::tenant_id(&value), TenantId::new(1));
        assert_eq!(RidKeyed::rids(&value), vec!["r1", "r2"]);
    }

    #[tokio::test]
    async fn empty_rids_means_no_tombstone_check_in_wrapper() {
        let value: CachedQueryResult<String> = CachedQueryResult {
            tenant_id: TenantId::new(1),
            rids: Vec::new(),
            results: Vec::new(),
        };
        assert!(RidKeyed::rids(&value).is_empty());
    }

    #[tokio::test]
    async fn composes_with_tombstone_aware_wrapper() {
        let wrapped: TombstoneAwareCache<QueryCacheKey, CachedQueryResult<String>, _> =
            TombstoneAwareCache::new(default_cache(), Arc::new(NoopTombstoneProvider));
        let k = key("q", "n", 10);
        wrapped.put(k, cached_result(vec!["r1", "r2"])).await;

        let hit = wrapped.get(&k).await.unwrap();
        assert_eq!(hit.rids, vec!["r1", "r2"]);
    }

    /// Tombstone provider that reports a fixed set of rids for one tenant.
    struct FakeTombstones {
        rids: HashSet<String>,
        tenant: TenantId,
    }

    #[async_trait]
    impl TombstoneProvider for FakeTombstones {
        async fn is_tombstoned(&self, tenant_id: TenantId, rid: &str) -> bool {
            tenant_id == self.tenant && self.rids.contains(rid)
        }
    }

    #[tokio::test]
    async fn tombstone_aware_wrapper_invalidates_on_match() {
        let tombstoned: HashSet<String> = std::iter::once("r2".to_string()).collect();
        let provider = FakeTombstones {
            rids: tombstoned,
            tenant: TenantId::new(1),
        };
        let wrapped = TombstoneAwareCache::new(default_cache(), Arc::new(provider));

        let k = key("q", "n", 10);
        wrapped.put(k, cached_result(vec!["r1", "r2", "r3"])).await;

        // One tombstoned rid is enough to drop the whole entry.
        assert!(wrapped.get(&k).await.is_none());
        assert_eq!(wrapped.len().await, 0);
    }

    #[tokio::test]
    async fn capacity_evicts_least_recently_used() {
        let cache: QueryResultCache<String> = QueryResultCache::new(QueryResultCacheConfig {
            max_entries: 2,
            ttl_secs: 60,
        });
        let first = key("q1", "n", 10);
        let second = key("q2", "n", 10);
        let third = key("q3", "n", 10);

        cache.put(first, cached_result(vec!["r1"])).await;
        cache.put(second, cached_result(vec!["r2"])).await;

        // Touch `first` so that `second` becomes the eviction candidate.
        let _ = cache.get(&first).await;
        cache.put(third, cached_result(vec!["r3"])).await;

        assert!(cache.get(&first).await.is_some());
        assert!(cache.get(&second).await.is_none());
        assert!(cache.get(&third).await.is_some());
    }

    #[tokio::test]
    async fn sweep_expired_caller_driven() {
        let cache: QueryResultCache<String> = QueryResultCache::new(QueryResultCacheConfig {
            max_entries: 10,
            ttl_secs: 0,
        });
        cache.put(key("q1", "n", 10), cached_result(vec!["r1"])).await;
        cache.put(key("q2", "n", 10), cached_result(vec!["r2"])).await;

        thread::sleep(Duration::from_millis(20));

        let swept = cache.sweep_expired();
        assert_eq!(swept, 2);
        assert_eq!(cache.len().await, 0);
    }

    #[test]
    fn config_defaults_match_spec() {
        let defaults = QueryResultCacheConfig::default();
        assert_eq!(defaults.max_entries, 10_000);
        assert_eq!(defaults.ttl_secs, 60);
    }
}