use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use parking_lot::RwLock;
use serde::Serialize;
use smallvec::SmallVec;
use crate::config::QueryDbConfig;
use crate::dependency::FileDep;
use crate::input::FileInputStore;
use crate::persistence::QueryDeps;
use crate::query::{DerivedQuery, QueryKey};
/// Snapshot of one persistent cache entry, handed to the persistence layer
/// for serialization: the already-encoded key/result bytes plus the
/// dependency metadata needed to re-validate the entry on reload.
/// Produced by `ShardedCache::iter_persistent`.
#[allow(dead_code)]
pub(crate) struct PersistableEntry {
/// Stable identifier of the query type that produced this entry.
pub query_type_id: u32,
/// Postcard-encoded query key.
pub raw_key_bytes: Arc<[u8]>,
/// Postcard-encoded query result.
pub raw_result_bytes: Arc<[u8]>,
/// File/edge/metadata revisions the result was computed against.
pub deps: QueryDeps,
}
/// One cached query result.
///
/// `value` is the live, type-erased result. For entries restored from
/// persistence it starts as a `()` placeholder and is lazily decoded from
/// `raw_result_bytes` (see `ShardedCache::get_cold_if_valid`).
pub struct CachedResult {
// Type-erased result; `()` for cold (not-yet-decoded) restored entries.
value: Box<dyn Any + Send + Sync>,
// (file, revision) pairs this result depends on.
file_deps: SmallVec<[FileDep; 8]>,
// Graph-edge revision at computation time, if tracked.
edge_revision: Option<u64>,
// Metadata revision at computation time, if tracked.
metadata_revision: Option<u64>,
// Encoded key/result bytes; empty for non-persistent entries.
raw_key_bytes: Arc<[u8]>,
raw_result_bytes: Arc<[u8]>,
// 0 for non-persistent entries.
query_type_id: u32,
// Whether this entry is eligible for persistence.
persistent: bool,
}
impl CachedResult {
    /// Builds a non-persistent entry wrapping a live `value`.
    ///
    /// The raw key/result buffers are left empty and `query_type_id` is 0,
    /// since non-persistent entries are never serialized.
    pub fn new<V: Clone + Send + Sync + 'static>(
        value: V,
        file_deps: SmallVec<[FileDep; 8]>,
        edge_revision: Option<u64>,
        metadata_revision: Option<u64>,
    ) -> Self {
        // One empty allocation, shared by both raw-byte fields.
        let empty: Arc<[u8]> = Arc::from(Vec::<u8>::new().into_boxed_slice());
        Self {
            value: Box::new(value),
            file_deps,
            edge_revision,
            metadata_revision,
            raw_key_bytes: Arc::clone(&empty),
            raw_result_bytes: empty,
            query_type_id: 0,
            persistent: false,
        }
    }

    /// Builds a persistent entry: a live `value` plus the already-encoded
    /// key/result bytes that back it in the persistence layer.
    fn new_persistent<V: Clone + Send + Sync + 'static>(
        value: V,
        file_deps: SmallVec<[FileDep; 8]>,
        edge_revision: Option<u64>,
        metadata_revision: Option<u64>,
        raw_key_bytes: Arc<[u8]>,
        raw_result_bytes: Arc<[u8]>,
        query_type_id: u32,
    ) -> Self {
        Self {
            value: Box::new(value),
            file_deps,
            edge_revision,
            metadata_revision,
            raw_key_bytes,
            raw_result_bytes,
            query_type_id,
            persistent: true,
        }
    }

    /// Borrows the stored value as `V`, or `None` if the stored type
    /// differs (including the `()` placeholder of cold restored entries).
    ///
    /// The previous `V: Clone` bound was unnecessary for a borrowed
    /// downcast and has been relaxed to `'static`; existing callers (which
    /// all use `Clone` types) are unaffected.
    #[must_use]
    pub fn downcast_value<V: 'static>(&self) -> Option<&V> {
        self.value.downcast_ref::<V>()
    }

    /// Graph-edge revision this result was computed against, if recorded.
    #[inline]
    #[must_use]
    pub fn edge_revision(&self) -> Option<u64> {
        self.edge_revision
    }

    /// Metadata revision this result was computed against, if recorded.
    #[inline]
    #[must_use]
    pub fn metadata_revision(&self) -> Option<u64> {
        self.metadata_revision
    }

    /// `(file, revision)` dependencies recorded for this result.
    #[inline]
    #[must_use]
    pub fn file_deps(&self) -> &SmallVec<[FileDep; 8]> {
        &self.file_deps
    }

    /// Encoded query key; empty for non-persistent entries.
    #[inline]
    #[must_use]
    pub fn raw_key_bytes(&self) -> &Arc<[u8]> {
        &self.raw_key_bytes
    }

    /// Encoded query result; empty for non-persistent entries.
    #[inline]
    #[must_use]
    pub fn raw_result_bytes(&self) -> &Arc<[u8]> {
        &self.raw_result_bytes
    }

    /// Query-type identifier; 0 for non-persistent entries.
    #[inline]
    #[must_use]
    pub fn query_type_id(&self) -> u32 {
        self.query_type_id
    }

    /// Whether this entry is eligible for persistence.
    #[inline]
    #[must_use]
    pub fn persistent(&self) -> bool {
        self.persistent
    }

    /// Returns `true` iff every recorded file dependency still matches the
    /// current revision in `inputs`. A file missing from `inputs` yields
    /// `None` from `revision()` and therefore invalidates the result.
    #[must_use]
    pub fn validate_file_deps(&self, inputs: &FileInputStore) -> bool {
        self.file_deps
            .iter()
            .all(|&(fid, rev)| inputs.revision(fid) == Some(rev))
    }
}
impl std::fmt::Debug for CachedResult {
    // Manual impl: the raw byte buffers are reported by length only, and
    // the type-erased `value` field is omitted entirely (hence the
    // `finish_non_exhaustive` trailer).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CachedResult");
        out.field("file_deps", &self.file_deps);
        out.field("edge_revision", &self.edge_revision);
        out.field("metadata_revision", &self.metadata_revision);
        out.field("raw_key_bytes_len", &self.raw_key_bytes.len());
        out.field("raw_result_bytes_len", &self.raw_result_bytes.len());
        out.field("query_type_id", &self.query_type_id);
        out.field("persistent", &self.persistent);
        out.finish_non_exhaustive()
    }
}
/// Lock-striped query-result cache: keys are partitioned across a fixed,
/// power-of-two number of shards, each guarded by its own `RwLock`, so
/// lookups on different shards never contend.
pub struct ShardedCache {
shards: Vec<RwLock<HashMap<QueryKey, CachedResult>>>,
}
impl ShardedCache {
    /// Creates a cache with `shard_count` independently locked shards.
    ///
    /// # Panics
    /// Panics if `shard_count` is zero or not a power of two (shard
    /// selection elsewhere relies on power-of-two masking).
    #[must_use]
    pub fn new(shard_count: usize) -> Self {
        assert!(shard_count > 0 && shard_count.is_power_of_two());
        let shards = (0..shard_count)
            .map(|_| RwLock::new(HashMap::new()))
            .collect();
        Self { shards }
    }

    /// Number of shards (always a power of two).
    #[inline]
    #[must_use]
    pub fn shard_count(&self) -> usize {
        self.shards.len()
    }

    /// Looks up `key` in shard `shard_idx` and, if `validate` accepts the
    /// entry, returns a clone of its value downcast to `V`.
    ///
    /// Returns `None` on a miss, a rejected validation, or a type mismatch
    /// — including un-hydrated `()` placeholders restored from
    /// persistence; use [`Self::get_cold_if_valid`] for those.
    pub fn get_if_valid<V: Clone + 'static>(
        &self,
        shard_idx: usize,
        key: &QueryKey,
        validate: impl FnOnce(&CachedResult) -> bool,
    ) -> Option<V> {
        let shard = self.shards[shard_idx].read();
        let cached = shard.get(key)?;
        if !validate(cached) {
            return None;
        }
        cached.downcast_value::<V>().cloned()
    }

    /// Cold-path lookup: decodes a persistence-restored entry (whose value
    /// is a `()` placeholder) from its raw postcard bytes, hydrates the
    /// entry in place, and returns the decoded value.
    ///
    /// Decoding happens outside any lock so concurrent readers are not
    /// blocked; the write lock is re-taken only to install the decoded
    /// value, and only if the entry is still the same un-hydrated
    /// placeholder (another thread may have raced us).
    pub fn get_cold_if_valid<V: Clone + Send + Sync + serde::de::DeserializeOwned + 'static>(
        &self,
        shard_idx: usize,
        key: &QueryKey,
        validate: impl FnOnce(&CachedResult) -> bool,
    ) -> Option<V> {
        // Phase 1 (read lock): validate and snapshot the raw bytes.
        let raw_bytes_snapshot: Arc<[u8]> = {
            let shard = self.shards[shard_idx].read();
            let cached = shard.get(key)?;
            if !validate(cached) {
                return None;
            }
            // Fix: if another caller already hydrated this entry to `V`,
            // serve the live value instead of reporting a miss. The old
            // code returned `None` for any non-placeholder value, turning
            // a valid hit into a spurious recompute.
            if let Some(hydrated) = cached.value.downcast_ref::<V>() {
                return Some(hydrated.clone());
            }
            // Only `()` placeholders are decodable from raw bytes; any
            // other stored type is a genuine type mismatch.
            cached.value.downcast_ref::<()>()?;
            Arc::clone(&cached.raw_result_bytes)
        };
        // Phase 2 (no lock): decode; a corrupt payload is treated as a miss.
        let decoded: V = postcard::from_bytes(&raw_bytes_snapshot).ok()?;
        // Phase 3 (write lock): hydrate, unless the entry was replaced or
        // hydrated concurrently while we were decoding.
        {
            let mut shard = self.shards[shard_idx].write();
            if let Some(cached) = shard.get_mut(key) {
                let still_placeholder = cached.value.downcast_ref::<()>().is_some();
                let bytes_unchanged = Arc::ptr_eq(&cached.raw_result_bytes, &raw_bytes_snapshot);
                if still_placeholder && bytes_unchanged {
                    cached.value = Box::new(decoded.clone());
                }
            }
        }
        Some(decoded)
    }

    /// Inserts (or replaces) a pre-built entry into shard `shard_idx`.
    pub fn insert(&self, shard_idx: usize, key: QueryKey, result: CachedResult) {
        let mut shard = self.shards[shard_idx].write();
        shard.insert(key, result);
    }

    /// Caches the result of executing a derived query.
    ///
    /// Non-persistent queries (`Q::PERSISTENT == false`) are cached
    /// in-memory only, with empty raw bytes. Persistent queries are
    /// additionally postcard-encoded; an entry whose encoded result
    /// exceeds `config.max_entry_size_bytes` is skipped entirely (logged,
    /// nothing cached) but still returns `Ok(())` — a soft failure.
    ///
    /// # Errors
    /// Returns the `postcard` error if encoding the key or value fails.
    #[allow(clippy::too_many_arguments)]
    pub fn insert_query<Q: DerivedQuery>(
        &self,
        shard_idx: usize,
        query_key: QueryKey,
        key: &Q::Key,
        value: Q::Value,
        file_deps: SmallVec<[FileDep; 8]>,
        edge_revision: Option<u64>,
        metadata_revision: Option<u64>,
        config: &QueryDbConfig,
    ) -> Result<(), postcard::Error>
    where
        Q::Key: Serialize,
        Q::Value: Serialize,
    {
        if !Q::PERSISTENT {
            // No serialization needed; reuse the plain insert path.
            self.insert(
                shard_idx,
                query_key,
                CachedResult::new(value, file_deps, edge_revision, metadata_revision),
            );
            return Ok(());
        }
        let raw_key = postcard::to_allocvec(key)?;
        let raw_value = postcard::to_allocvec(&value)?;
        if raw_value.len() > config.max_entry_size_bytes {
            log::debug!(
                "sqry-db: skipping oversized cache entry (query_type_id={:#06x}, \
                 raw_result_bytes={} bytes, max={})",
                Q::QUERY_TYPE_ID,
                raw_value.len(),
                config.max_entry_size_bytes,
            );
            return Ok(());
        }
        let raw_key_bytes: Arc<[u8]> = Arc::from(raw_key.into_boxed_slice());
        let raw_result_bytes: Arc<[u8]> = Arc::from(raw_value.into_boxed_slice());
        let result = CachedResult::new_persistent(
            value,
            file_deps,
            edge_revision,
            metadata_revision,
            raw_key_bytes,
            raw_result_bytes,
            Q::QUERY_TYPE_ID,
        );
        self.insert(shard_idx, query_key, result);
        Ok(())
    }

    /// Removes `key` from shard `shard_idx`; returns whether it was present.
    pub fn remove(&self, shard_idx: usize, key: &QueryKey) -> bool {
        let mut shard = self.shards[shard_idx].write();
        shard.remove(key).is_some()
    }

    /// Clears every shard. Shards are cleared one at a time, so concurrent
    /// readers may briefly observe a partially cleared cache.
    pub fn clear_all(&self) {
        for shard in &self.shards {
            shard.write().clear();
        }
    }

    /// Total entry count across all shards (each shard read-locked briefly;
    /// the sum is not an atomic snapshot under concurrent writes).
    #[must_use]
    pub fn total_entries(&self) -> usize {
        self.shards.iter().map(|s| s.read().len()).sum()
    }

    /// Per-shard entry counts, in shard order.
    #[must_use]
    pub fn shard_entry_counts(&self) -> Vec<usize> {
        self.shards.iter().map(|s| s.read().len()).collect()
    }

    /// Inserts an entry restored (and already validated) from persistence.
    ///
    /// The value is stored as a `()` placeholder and decoded lazily via
    /// [`Self::get_cold_if_valid`]. The shard and cache key are derived
    /// from `query_type_id` plus a `DefaultHasher` hash of the raw key
    /// bytes.
    ///
    /// NOTE(review): this assumes the hot path (`QueryKey::new`) derives
    /// its key hash the same way, so restored entries are found by later
    /// lookups — confirm against `crate::query::QueryKey`.
    pub(crate) fn insert_validated(
        &self,
        query_type_id: u32,
        raw_key_bytes: Arc<[u8]>,
        raw_result_bytes: Arc<[u8]>,
        deps: crate::persistence::QueryDeps,
    ) {
        use std::hash::{Hash, Hasher};
        let shard_idx =
            crate::query::QueryRegistry::shard_for_query_type_id(query_type_id, self.shards.len());
        let mut hasher = std::hash::DefaultHasher::new();
        raw_key_bytes.hash(&mut hasher);
        let hash = hasher.finish();
        let file_deps: SmallVec<[crate::dependency::FileDep; 8]> =
            deps.file_deps.iter().copied().collect();
        let result = CachedResult {
            value: Box::new(()),
            file_deps,
            edge_revision: deps.edge_revision,
            metadata_revision: deps.metadata_revision,
            raw_key_bytes,
            raw_result_bytes,
            query_type_id,
            persistent: true,
        };
        let shard_key = QueryKey::from_raw(u64::from(query_type_id), hash);
        let mut shard = self.shards[shard_idx].write();
        shard.insert(shard_key, result);
    }

    /// Iterates over snapshots of all persistent entries.
    ///
    /// Each shard is copied into an owned `Vec` under a short-lived read
    /// lock, so the iterator itself holds no locks while consumed; the
    /// result is a point-in-time snapshot, not a live view.
    #[allow(dead_code)]
    pub(crate) fn iter_persistent(&self) -> impl Iterator<Item = PersistableEntry> + '_ {
        self.shards.iter().flat_map(|shard| {
            let guard = shard.read();
            let entries: Vec<PersistableEntry> = guard
                .values()
                .filter(|e| e.persistent)
                .map(|e| PersistableEntry {
                    query_type_id: e.query_type_id,
                    raw_key_bytes: Arc::clone(&e.raw_key_bytes),
                    raw_result_bytes: Arc::clone(&e.raw_result_bytes),
                    deps: QueryDeps {
                        file_deps: e.file_deps.to_vec(),
                        edge_revision: e.edge_revision,
                        metadata_revision: e.metadata_revision,
                    },
                })
                .collect();
            drop(guard);
            entries.into_iter()
        })
    }
}
// SAFETY: all shard contents are behind `parking_lot::RwLock`, and
// `CachedResult`'s type-erased value is constrained to
// `Box<dyn Any + Send + Sync>`, so no non-thread-safe data escapes.
//
// NOTE(review): every field visible in this file (`Vec`, `RwLock`,
// `HashMap`, `Arc<[u8]>`, `SmallVec`, the `Send + Sync` trait object)
// appears to make `ShardedCache` auto-`Send + Sync` already, which would
// make these unsafe impls redundant — and dangerous, since they would
// silence the compiler if a non-`Send` field is ever added. Confirm that
// `QueryKey` and `FileDep` are `Send + Sync` and consider removing both
// impls.
unsafe impl Send for ShardedCache {}
unsafe impl Sync for ShardedCache {}
#[cfg(test)]
mod tests {
use super::*;
use serde::{Deserialize, Serialize};
use sqry_core::graph::unified::concurrent::CodeGraph;
use sqry_core::graph::unified::file::id::FileId;
use crate::query::QueryKey;
// Helper: a snapshot of an empty code graph, for queries that need one.
fn empty_snapshot() -> Arc<sqry_core::graph::unified::concurrent::GraphSnapshot> {
Arc::new(CodeGraph::new().snapshot())
}
// Fixture: a PERSISTENT = true query whose execute() returns an empty Vec.
struct PersistentTestQuery;
#[derive(Serialize, Deserialize, Hash, Eq, PartialEq, Clone)]
struct PersistentTestKey(u32);
impl DerivedQuery for PersistentTestQuery {
type Key = PersistentTestKey;
type Value = Vec<u8>;
const QUERY_TYPE_ID: u32 = 0xF100;
const PERSISTENT: bool = true;
fn execute(
_key: &Self::Key,
_db: &crate::QueryDb,
_snapshot: &sqry_core::graph::unified::concurrent::GraphSnapshot,
) -> Self::Value {
vec![]
}
}
// Fixture: a PERSISTENT = false query that formats its key into a String.
struct NonPersistentTestQuery;
#[derive(Serialize, Deserialize, Hash, Eq, PartialEq, Clone)]
struct NonPersistentTestKey(u32);
impl DerivedQuery for NonPersistentTestQuery {
type Key = NonPersistentTestKey;
type Value = String;
const QUERY_TYPE_ID: u32 = 0xF101;
const PERSISTENT: bool = false;
fn execute(
key: &Self::Key,
_db: &crate::QueryDb,
_snapshot: &sqry_core::graph::unified::concurrent::GraphSnapshot,
) -> Self::Value {
format!("result_{}", key.0)
}
}
// insert / get / remove round-trip, with entry counting along the way.
#[test]
fn sharded_cache_basic_ops() {
let cache = ShardedCache::new(4);
assert_eq!(cache.shard_count(), 4);
assert_eq!(cache.total_entries(), 0);
let key = QueryKey::from_raw(42, 0);
let result = CachedResult::new(vec![1u32, 2, 3], SmallVec::new(), None, None);
cache.insert(0, key.clone(), result);
assert_eq!(cache.total_entries(), 1);
let val: Option<Vec<u32>> = cache.get_if_valid(0, &key, |_| true);
assert_eq!(val, Some(vec![1u32, 2, 3]));
assert!(cache.remove(0, &key));
assert_eq!(cache.total_entries(), 0);
}
// The validate closure gates hits: false => miss, true => hit.
#[test]
fn sharded_cache_validation_rejects() {
let cache = ShardedCache::new(4);
let key = QueryKey::from_raw(1, 0);
cache.insert(
0,
key.clone(),
CachedResult::new(42u32, SmallVec::new(), None, None),
);
let val: Option<u32> = cache.get_if_valid(0, &key, |_| false);
assert!(val.is_none());
let val: Option<u32> = cache.get_if_valid(0, &key, |_| true);
assert_eq!(val, Some(42));
}
// clear_all must empty every shard, not just the first.
#[test]
fn sharded_cache_clear_all() {
let cache = ShardedCache::new(4);
for i in 0..4 {
let key = QueryKey::from_raw(i as u64, 0);
cache.insert(i, key, CachedResult::new(i, SmallVec::new(), None, None));
}
assert_eq!(cache.total_entries(), 4);
cache.clear_all();
assert_eq!(cache.total_entries(), 0);
}
// File-dep validation passes while revisions match and fails after a bump.
#[test]
fn cached_result_validates_file_deps() {
let mut store = crate::input::FileInputStore::new();
store.insert(
FileId::new(1),
crate::input::FileInput::new(Default::default()),
);
store.insert(
FileId::new(2),
crate::input::FileInput::new(Default::default()),
);
let mut deps: SmallVec<[FileDep; 8]> = SmallVec::new();
// Both files recorded at revision 1.
deps.push((FileId::new(1), 1)); deps.push((FileId::new(2), 1));
let result = CachedResult::new(42u32, deps, None, None);
assert!(result.validate_file_deps(&store));
// Bump file 1's revision via update(); the cached deps are now stale.
store
.get_mut(FileId::new(1))
.unwrap()
.update(Default::default());
assert!(
!result.validate_file_deps(&store),
"should invalidate after revision bump"
);
}
// ShardedCache::new asserts on non-power-of-two shard counts; the panic
// message contains the asserted expression, hence the expected substring.
#[test]
#[should_panic(expected = "is_power_of_two")]
fn sharded_cache_rejects_non_power_of_two() {
let _ = ShardedCache::new(3);
}
// Per-shard counts reflect exactly where entries were inserted.
#[test]
fn shard_entry_counts() {
let cache = ShardedCache::new(4);
cache.insert(
0,
QueryKey::from_raw(1, 0),
CachedResult::new(1u32, SmallVec::new(), None, None),
);
cache.insert(
0,
QueryKey::from_raw(2, 0),
CachedResult::new(2u32, SmallVec::new(), None, None),
);
cache.insert(
2,
QueryKey::from_raw(3, 0),
CachedResult::new(3u32, SmallVec::new(), None, None),
);
let counts = cache.shard_entry_counts();
assert_eq!(counts, vec![2, 0, 1, 0]);
}
// Helper: default config for insert_query tests.
fn default_config() -> QueryDbConfig {
QueryDbConfig::default()
}
// Non-persistent constructor leaves raw bytes empty and type id zeroed.
#[test]
fn cached_result_new_has_empty_raw_bytes() {
let r = CachedResult::new(42u32, SmallVec::new(), None, None);
assert!(r.raw_key_bytes().is_empty());
assert!(r.raw_result_bytes().is_empty());
assert_eq!(r.query_type_id(), 0);
assert!(!r.persistent());
}
// Persistent insert stores encoded bytes and is visible to iter_persistent.
#[test]
fn insert_query_persistent_stores_raw_bytes() {
let cache = ShardedCache::new(4);
let cfg = default_config();
let key = PersistentTestKey(7);
let value: Vec<u8> = vec![0xDE, 0xAD, 0xBE, 0xEF];
let query_key = QueryKey::new::<PersistentTestQuery>(&key);
// Deterministic shard index derived from the query type's TypeId hash.
let shard_idx = {
use std::hash::{Hash, Hasher};
let tid = std::any::TypeId::of::<PersistentTestQuery>();
let mut h = std::collections::hash_map::DefaultHasher::new();
tid.hash(&mut h);
(h.finish() & 3) as usize
};
cache
.insert_query::<PersistentTestQuery>(
shard_idx,
query_key.clone(),
&key,
value.clone(),
SmallVec::new(),
None,
None,
&cfg,
)
.expect("insert_query should not fail");
let got: Option<Vec<u8>> = cache.get_if_valid(shard_idx, &query_key, |_| true);
assert_eq!(got, Some(value));
let persistent: Vec<_> = cache.iter_persistent().collect();
assert_eq!(persistent.len(), 1);
assert_eq!(persistent[0].query_type_id, 0xF100);
assert!(!persistent[0].raw_key_bytes.is_empty());
assert!(!persistent[0].raw_result_bytes.is_empty());
}
// An entry above max_entry_size_bytes is soft-skipped: Ok(()) but not cached.
#[test]
fn insert_query_oversize_entry_skipped() {
let cache = ShardedCache::new(4);
let cfg = QueryDbConfig::builder().max_entry_size_bytes(1024).build();
let key = PersistentTestKey(99);
let value: Vec<u8> = vec![0xABu8; 2048];
let query_key = QueryKey::new::<PersistentTestQuery>(&key);
// Deterministic shard index derived from the query type's TypeId hash.
let shard_idx = {
use std::hash::{Hash, Hasher};
let tid = std::any::TypeId::of::<PersistentTestQuery>();
let mut h = std::collections::hash_map::DefaultHasher::new();
tid.hash(&mut h);
(h.finish() & 3) as usize
};
cache
.insert_query::<PersistentTestQuery>(
shard_idx,
query_key.clone(),
&key,
value,
SmallVec::new(),
None,
None,
&cfg,
)
.expect("oversize soft-skip must still return Ok");
let got: Option<Vec<u8>> = cache.get_if_valid(shard_idx, &query_key, |_| true);
assert!(got.is_none(), "oversize entry must not be present in cache");
let persistent: Vec<_> = cache.iter_persistent().collect();
assert!(
persistent.is_empty(),
"oversize entry must not appear in iter_persistent"
);
}
// PERSISTENT = false entries are cached hot but carry no raw bytes and
// never surface through iter_persistent.
#[test]
fn insert_query_non_persistent_invisible_to_iter_persistent() {
let cache = ShardedCache::new(4);
let cfg = default_config();
let key = NonPersistentTestKey(42);
let value = "hello".to_owned();
let query_key = QueryKey::new::<NonPersistentTestQuery>(&key);
// Deterministic shard index derived from the query type's TypeId hash.
let shard_idx = {
use std::hash::{Hash, Hasher};
let tid = std::any::TypeId::of::<NonPersistentTestQuery>();
let mut h = std::collections::hash_map::DefaultHasher::new();
tid.hash(&mut h);
(h.finish() & 3) as usize
};
cache
.insert_query::<NonPersistentTestQuery>(
shard_idx,
query_key.clone(),
&key,
value.clone(),
SmallVec::new(),
None,
None,
&cfg,
)
.expect("non-persistent insert must succeed");
let got: Option<String> = cache.get_if_valid(shard_idx, &query_key, |_| true);
assert_eq!(got, Some(value));
{
// White-box check of the stored entry's internal flags.
let shard = cache.shards[shard_idx].read();
let entry = shard.get(&query_key).expect("entry must be present");
assert!(
entry.raw_key_bytes().is_empty(),
"non-persistent entry must have empty raw_key_bytes"
);
assert!(
entry.raw_result_bytes().is_empty(),
"non-persistent entry must have empty raw_result_bytes"
);
assert!(
!entry.persistent(),
"PERSISTENT=false must set persistent=false"
);
}
let persistent: Vec<_> = cache.iter_persistent().collect();
assert!(
persistent.is_empty(),
"non-persistent entry must not appear in iter_persistent"
);
}
// File/edge/metadata deps survive the round-trip into PersistableEntry.
#[test]
fn insert_query_deps_propagated_to_persistable_entry() {
let cache = ShardedCache::new(4);
let cfg = default_config();
let key = PersistentTestKey(1);
let value: Vec<u8> = vec![1, 2, 3];
let query_key = QueryKey::new::<PersistentTestQuery>(&key);
// Deterministic shard index derived from the query type's TypeId hash.
let shard_idx = {
use std::hash::{Hash, Hasher};
let tid = std::any::TypeId::of::<PersistentTestQuery>();
let mut h = std::collections::hash_map::DefaultHasher::new();
tid.hash(&mut h);
(h.finish() & 3) as usize
};
let mut file_deps: SmallVec<[FileDep; 8]> = SmallVec::new();
file_deps.push((FileId::new(10), 5));
cache
.insert_query::<PersistentTestQuery>(
shard_idx,
query_key,
&key,
value,
file_deps,
Some(42),
Some(7),
&cfg,
)
.expect("insert_query should succeed");
let entries: Vec<_> = cache.iter_persistent().collect();
assert_eq!(entries.len(), 1);
let deps = &entries[0].deps;
assert_eq!(deps.file_deps.len(), 1);
assert_eq!(deps.file_deps[0], (FileId::new(10), 5));
assert_eq!(deps.edge_revision, Some(42));
assert_eq!(deps.metadata_revision, Some(7));
}
// Mixed persistent and non-persistent inserts: only persistent ones counted.
#[test]
fn iter_persistent_counts_correctly() {
let cache = ShardedCache::new(4);
let cfg = default_config();
let k1 = PersistentTestKey(1);
let qk1 = QueryKey::new::<PersistentTestQuery>(&k1);
// Deterministic shard index derived from the query type's TypeId hash.
let si1 = {
use std::hash::{Hash, Hasher};
let tid = std::any::TypeId::of::<PersistentTestQuery>();
let mut h = std::collections::hash_map::DefaultHasher::new();
tid.hash(&mut h);
(h.finish() & 3) as usize
};
cache
.insert_query::<PersistentTestQuery>(
si1,
qk1,
&k1,
vec![1u8],
SmallVec::new(),
None,
None,
&cfg,
)
.unwrap();
let k2 = PersistentTestKey(2);
let qk2 = QueryKey::new::<PersistentTestQuery>(&k2);
cache
.insert_query::<PersistentTestQuery>(
si1,
qk2,
&k2,
vec![2u8],
SmallVec::new(),
None,
None,
&cfg,
)
.unwrap();
let nk = NonPersistentTestKey(3);
let nqk = QueryKey::new::<NonPersistentTestQuery>(&nk);
let nsi = {
use std::hash::{Hash, Hasher};
let tid = std::any::TypeId::of::<NonPersistentTestQuery>();
let mut h = std::collections::hash_map::DefaultHasher::new();
tid.hash(&mut h);
(h.finish() & 3) as usize
};
cache
.insert_query::<NonPersistentTestQuery>(
nsi,
nqk,
&nk,
"skip".to_owned(),
SmallVec::new(),
None,
None,
&cfg,
)
.unwrap();
let count = cache.iter_persistent().count();
assert_eq!(count, 2, "only the two persistent entries should appear");
}
// Smoke test: the empty_snapshot helper itself builds and runs.
#[test]
fn empty_snapshot_compiles() {
let _ = empty_snapshot();
}
}