use std::any::TypeId;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use sqry_core::graph::unified::concurrent::GraphSnapshot;
use crate::QueryDb;
/// A memoizable computation over a graph snapshot.
///
/// Implementors name their input (`Key`) and output (`Value`) types and
/// provide `execute`. The numeric `QUERY_TYPE_ID` becomes the type component
/// of every `QueryKey` built for this query (see `QueryKey::new`). It is also
/// the value `QueryRegistry::shard_for` routes on, so it should be unique per
/// query type.
pub trait DerivedQuery: Send + Sync + 'static {
    /// Numeric identifier of this query type (see trait docs for uniqueness).
    const QUERY_TYPE_ID: u32;

    /// Defaults to `true`.
    /// NOTE(review): presumably controls whether results may be persisted
    /// by the cache layer — confirm against the consumer of this flag.
    const PERSISTENT: bool = true;

    /// Defaults to `false`.
    /// NOTE(review): presumably opts the query into invalidation on edge
    /// revision changes — confirm against the cache implementation.
    const TRACKS_EDGE_REVISION: bool = false;

    /// Defaults to `false`.
    /// NOTE(review): presumably opts the query into invalidation on metadata
    /// revision changes — confirm against the cache implementation.
    const TRACKS_METADATA_REVISION: bool = false;

    /// Input of the query. `QueryKey::new` identifies it by hashing its
    /// serde serialization, so equal keys must serialize identically.
    type Key: 'static + Clone + Eq + Hash + Send + serde::Serialize;

    /// Output of the query; must round-trip through serde.
    type Value: 'static
        + Clone
        + Send
        + Sync
        + serde::Serialize
        + serde::de::DeserializeOwned;

    /// Computes the value for `key` against `snapshot`.
    ///
    /// `db` is the owning query database (presumably usable for dependent
    /// queries — confirm with `QueryDb`).
    fn execute(key: &Self::Key, db: &QueryDb, snapshot: &GraphSnapshot) -> Self::Value;
}
/// Cache key identifying a single invocation of a derived query.
///
/// It pairs a type component (`type_hash`) with a hash of the serialized
/// query input (`key_hash`). Two keys are equal only when both components
/// match.
#[derive(PartialEq, Eq, Clone)]
pub struct QueryKey {
    /// Type component; `QueryKey::new` fills it from `QUERY_TYPE_ID`.
    type_hash: u64,
    /// Hash of the serialized query input.
    key_hash: u64,
}
impl QueryKey {
    /// Builds the key for invoking query `Q` with `key`.
    ///
    /// The type component is `Q::QUERY_TYPE_ID` widened to `u64`. The key
    /// component hashes the postcard encoding of `key`.
    ///
    /// # Panics
    ///
    /// Panics if postcard fails to serialize `key`.
    pub fn new<Q: DerivedQuery>(key: &Q::Key) -> Self {
        Self::from_raw(u64::from(Q::QUERY_TYPE_ID), Self::hash_serialized_key(key))
    }

    /// Reassembles a key from its two raw components, with no hashing.
    #[doc(hidden)]
    #[must_use]
    pub fn from_raw(type_hash: u64, key_hash: u64) -> Self {
        Self { type_hash, key_hash }
    }

    /// Serializes `key` with postcard and hashes the resulting byte buffer.
    ///
    /// NOTE(review): `std::hash::DefaultHasher`'s algorithm is not guaranteed
    /// to stay the same across Rust releases. If these hashes are ever
    /// persisted, a toolchain upgrade may change them. Confirm whether that
    /// matters for persisted queries.
    fn hash_serialized_key<K: serde::Serialize>(key: &K) -> u64 {
        let mut state = std::hash::DefaultHasher::new();
        let encoded = postcard::to_allocvec(key).expect(
            "DerivedQuery::Key requires serde::Serialize to be infallible; \
             postcard::to_allocvec must not fail",
        );
        // Hashing the Vec (not a raw slice write) keeps the length prefix,
        // matching the established hash values.
        encoded.hash(&mut state);
        state.finish()
    }
}
/// Feeds both components into the hasher, type component first.
/// This stays consistent with the derived `PartialEq`, which compares the
/// same two fields.
impl Hash for QueryKey {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let Self { type_hash, key_hash } = self;
        type_hash.hash(state);
        key_hash.hash(state);
    }
}
/// Renders both components as zero-padded, `0x`-prefixed 16-digit hex strings.
impl std::fmt::Debug for QueryKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let type_hex = format!("{:#018x}", self.type_hash);
        let key_hex = format!("{:#018x}", self.key_hash);
        let mut builder = f.debug_struct("QueryKey");
        builder.field("type_hash", &type_hex);
        builder.field("key_hash", &key_hex);
        builder.finish()
    }
}
/// Set of derived query types known to the database.
///
/// Entries are keyed by the Rust `TypeId` of each registered query type.
/// The number of entries is what `registered_count` reports.
pub struct QueryRegistry {
    assignments: std::collections::HashMap<std::any::TypeId, usize>,
}
impl QueryRegistry {
    /// Creates an empty registry with no query types registered.
    #[must_use]
    pub fn new() -> Self {
        Self {
            assignments: HashMap::new(),
        }
    }

    /// Registers query type `Q`.
    ///
    /// Registration is idempotent. Registering the same `Q` again is a no-op
    /// and does not change `registered_count`.
    ///
    /// # Panics
    ///
    /// Panics if a *different* query type with the same `QUERY_TYPE_ID` is
    /// already registered. `QueryKey::new` derives the key's type component
    /// solely from `QUERY_TYPE_ID`, so two such types would produce identical
    /// keys for identically-serialized inputs.
    pub fn register<Q: DerivedQuery>(&mut self) {
        let tid = TypeId::of::<Q>();
        if self.assignments.contains_key(&tid) {
            return;
        }
        // The map value records the query's numeric ID so that collisions
        // between distinct types can be detected. u32 -> usize is lossless
        // on every >= 32-bit target.
        let query_type_id = Q::QUERY_TYPE_ID as usize;
        assert!(
            !self
                .assignments
                .values()
                .any(|&existing| existing == query_type_id),
            "QUERY_TYPE_ID {:#x} is already registered by a different query type",
            Q::QUERY_TYPE_ID
        );
        self.assignments.insert(tid, query_type_id);
    }

    /// Returns the shard index (in `0..shard_count`) for query type `Q`.
    ///
    /// # Panics
    ///
    /// Panics if `shard_count` is zero.
    #[must_use]
    pub fn shard_for<Q: DerivedQuery>(&self, shard_count: usize) -> usize {
        Self::shard_for_query_type_id(Q::QUERY_TYPE_ID, shard_count)
    }

    /// Maps a raw `QUERY_TYPE_ID` onto a shard index in `0..shard_count`.
    ///
    /// For power-of-two `shard_count` the result equals the low-bit mask
    /// `id & (shard_count - 1)`. Other counts use the same modulo, so every
    /// shard stays reachable. A plain mask would skip shards for such counts
    /// (e.g. with 3 shards it never yields 1).
    ///
    /// # Panics
    ///
    /// Panics if `shard_count` is zero. Previously this underflowed: it
    /// panicked in debug builds and returned an out-of-range index in release.
    #[must_use]
    pub fn shard_for_query_type_id(query_type_id: u32, shard_count: usize) -> usize {
        assert!(shard_count > 0, "shard_count must be non-zero");
        // usize -> u64 is lossless on all supported targets; the remainder is
        // < shard_count, so narrowing back to usize cannot truncate.
        (u64::from(query_type_id) % shard_count as u64) as usize
    }

    /// Number of distinct query types registered so far.
    #[must_use]
    pub fn registered_count(&self) -> usize {
        self.assignments.len()
    }
}
impl Default for QueryRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal query used to exercise key construction and registration.
    struct TestQuery;

    impl DerivedQuery for TestQuery {
        type Key = u32;
        type Value = String;
        const QUERY_TYPE_ID: u32 = 0xF001;

        fn execute(key: &u32, _db: &QueryDb, _snapshot: &GraphSnapshot) -> String {
            format!("result_{key}")
        }
    }

    /// Second query with the same `Key` type but a distinct ID, and with the
    /// edge-revision flag switched on.
    struct EdgeTrackingQuery;

    impl DerivedQuery for EdgeTrackingQuery {
        type Key = u32;
        type Value = Vec<u32>;
        const QUERY_TYPE_ID: u32 = 0xF002;
        const TRACKS_EDGE_REVISION: bool = true;

        fn execute(_key: &u32, _db: &QueryDb, _snapshot: &GraphSnapshot) -> Vec<u32> {
            Vec::new()
        }
    }

    #[test]
    fn query_key_equality() {
        let [first, same, other] = [42, 42, 99].map(|k| QueryKey::new::<TestQuery>(&k));
        assert_eq!(first, same);
        assert_ne!(first, other);
    }

    #[test]
    fn query_key_different_types_differ() {
        let input = 42;
        let plain = QueryKey::new::<TestQuery>(&input);
        let edge = QueryKey::new::<EdgeTrackingQuery>(&input);
        assert_ne!(
            plain, edge,
            "different query types should produce different keys"
        );
    }

    #[test]
    fn registry_shard_assignment_deterministic() {
        const SHARDS: usize = 64;
        let registry = QueryRegistry::new();
        let first = registry.shard_for::<TestQuery>(SHARDS);
        assert_eq!(first, registry.shard_for::<TestQuery>(SHARDS));
        assert!(first < SHARDS);
    }

    #[test]
    fn registry_register_and_count() {
        let mut registry = QueryRegistry::new();
        assert_eq!(registry.registered_count(), 0);

        registry.register::<TestQuery>();
        assert_eq!(registry.registered_count(), 1);

        registry.register::<EdgeTrackingQuery>();
        assert_eq!(registry.registered_count(), 2);
    }

    #[test]
    fn registry_register_is_idempotent() {
        let mut registry = QueryRegistry::new();
        registry.register::<TestQuery>();
        let shard_before = registry.shard_for::<TestQuery>(64);
        let count_before = registry.registered_count();

        for _ in 0..3 {
            registry.register::<TestQuery>();
        }

        assert_eq!(
            registry.registered_count(),
            count_before,
            "duplicate register::<Q>() must not bump registered_count"
        );
        assert_eq!(
            registry.shard_for::<TestQuery>(64),
            shard_before,
            "duplicate register::<Q>() must not change shard routing"
        );
    }
}