use std::collections::HashMap;
use xlog_core::{RelId, ScalarType, Schema};
use xlog_cuda::{CudaBuffer, JoinIndexV2};
/// Identity of one cached join index; used as the map key of `JoinIndexCache::entries`.
///
/// Equality and hashing are derived over every field, so a difference in any of
/// them selects a different cache slot.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct JoinIndexKey {
/// Relation the index belongs to; `JoinIndexCache::invalidate_rel` matches on this field.
pub(crate) rel: RelId,
/// Version (generation) of the relation, so indexes for older versions never match.
pub(crate) version: u64,
/// Join-key columns — presumably column indexes into the relation's schema; confirm with callers.
pub(crate) key_cols: Vec<usize>,
/// Column-type and row-size signature captured from the schema passed to `JoinIndexKey::new`.
pub(crate) schema: JoinIndexSchemaSignature,
/// Device ordinal the index is associated with.
pub(crate) device_ordinal: u32,
}
/// Hashable summary of a `Schema`, embedded in `JoinIndexKey` so that keys built
/// from different schemas compare unequal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct JoinIndexSchemaSignature {
/// Types of the schema's columns in column order (columns without a reported type are omitted).
column_types: Vec<ScalarType>,
/// Value of `Schema::row_size_bytes` at the time the signature was taken.
row_size_bytes: usize,
}
impl JoinIndexSchemaSignature {
    /// Captures the signature of `schema`.
    ///
    /// Walks column positions `0..schema.arity()` in order and records every type
    /// that `Schema::column_type` reports (positions yielding `None` are skipped),
    /// then records `Schema::row_size_bytes`.
    fn from_schema(schema: &Schema) -> Self {
        let mut column_types = Vec::new();
        for position in 0..schema.arity() {
            if let Some(ty) = schema.column_type(position) {
                column_types.push(ty);
            }
        }
        let row_size_bytes = schema.row_size_bytes();
        Self {
            column_types,
            row_size_bytes,
        }
    }
}
impl JoinIndexKey {
pub(crate) fn new(
rel: RelId,
version: u64,
key_cols: Vec<usize>,
schema: &Schema,
device_ordinal: u32,
) -> Self {
Self {
rel,
version,
key_cols,
schema: JoinIndexSchemaSignature::from_schema(schema),
device_ordinal,
}
}
}
/// One cache slot: the stored payload plus the bookkeeping used for the byte
/// budget and least-recently-used eviction.
struct CachedJoinIndex {
/// The stored payload; `JoinIndexCache::get` returns the index inside a `Ready` payload.
index: CachedJoinIndexPayload,
/// Bytes charged against the cache budget for this entry.
bytes: u64,
/// Value of the cache's logical clock when the entry was last inserted or looked up.
last_used: u64,
}
/// Payload held by a cache slot.
// NOTE(review): the lint is presumably silenced because, in test builds, the unit
// `Placeholder` variant sits beside the much larger `Ready(JoinIndexV2)` — confirm.
#[allow(clippy::large_enum_variant)]
enum CachedJoinIndexPayload {
/// A built index that lookups can hand out.
Ready(JoinIndexV2),
/// Test-only stand-in that carries no index; `JoinIndexCache::get` reports it as a miss.
#[cfg(test)]
Placeholder,
}
/// Counters describing the activity and current occupancy of a `JoinIndexCache`.
///
/// All counters are advanced with saturating arithmetic by the cache; a snapshot
/// is obtained through `JoinIndexCache::stats`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct JoinIndexCacheStats {
/// Number of `get` calls.
pub lookups: u64,
/// Lookups that returned a ready index.
pub hits: u64,
/// Lookups that found no entry (or, in tests, only a placeholder).
pub misses: u64,
/// Indexes accepted by `insert`.
pub builds: u64,
/// Entries dropped by `evict_until_fits` to make room.
pub evictions: u64,
/// Entries dropped by `invalidate_rel` or `clear`.
pub invalidations: u64,
/// Entries dropped through `remove_stale`.
pub stale_rejections: u64,
/// Calls to `record_background_build_request`.
pub background_build_requests: u64,
/// Calls to `record_background_build_complete`.
pub background_builds_completed: u64,
/// Calls to `record_background_build_deferred`.
pub background_builds_deferred: u64,
/// Number of entries held when the snapshot was taken (filled in by `stats`).
pub entries: usize,
/// Bytes accounted to held entries when the snapshot was taken (filled in by `stats`).
pub total_bytes: u64,
}
/// Byte-budgeted cache of join indexes with least-recently-used eviction.
pub(crate) struct JoinIndexCache {
/// Cached entries keyed by index identity.
entries: HashMap<JoinIndexKey, CachedJoinIndex>,
/// Logical clock, advanced on every insert and every lookup that finds an entry;
/// entries record its value as their recency stamp.
clock: u64,
/// Sum of the `bytes` charged for all current entries.
total_bytes: u64,
/// Budget: inserts larger than this are refused and eviction keeps the total within it.
pub(crate) max_bytes: u64,
/// Running counters; `entries`/`total_bytes` inside are only filled in by `stats()`.
stats: JoinIndexCacheStats,
}
/// Estimates how many bytes a join index over `right`, keyed on `right_keys`, would occupy.
///
/// The estimate is the sum of:
/// * the packed key bytes (`num_rows` × the summed widths of the key columns),
/// * 8 bytes per bucket, where the bucket count is `max(2 × num_rows, 1024)` rounded
///   up to a power of two, and
/// * 12 bytes per row.
///
/// NOTE(review): the 8- and 12-byte factors presumably mirror the bucket and per-row
/// bookkeeping of `JoinIndexV2` — confirm against the index layout.
///
/// Returns `u64::MAX` when `right_keys` is empty or names a column the schema has no
/// type for. Every step saturates, so an input too large to represent also yields
/// `u64::MAX` instead of wrapping or panicking.
pub(crate) fn estimate_join_index_bytes(right: &CudaBuffer, right_keys: &[usize]) -> u64 {
    if right_keys.is_empty() {
        return u64::MAX;
    }
    let mut key_bytes_per_row: u64 = 0;
    for &k in right_keys {
        let Some(ty) = right.schema().column_type(k) else {
            return u64::MAX;
        };
        key_bytes_per_row = key_bytes_per_row.saturating_add(ty.size_bytes() as u64);
    }
    let num_rows = right.num_rows();
    let packed_bytes = num_rows.saturating_mul(key_bytes_per_row);
    let target = num_rows.saturating_mul(2).max(1024);
    // `next_power_of_two` panics in debug builds and wraps to 0 in release builds once
    // `target` exceeds 2^63; saturate instead so the estimate can only err on the high side.
    let num_buckets = target.checked_next_power_of_two().unwrap_or(u64::MAX);
    packed_bytes
        .saturating_add(num_buckets.saturating_mul(8))
        .saturating_add(num_rows.saturating_mul(12))
}
impl JoinIndexCache {
pub(crate) fn new(max_bytes: u64) -> Self {
Self {
entries: HashMap::new(),
clock: 0,
total_bytes: 0,
max_bytes,
stats: JoinIndexCacheStats::default(),
}
}
pub(crate) fn should_build(
&self,
est_index_bytes: u64,
build_heat: f32,
remaining_device_bytes: u64,
device_budget_bytes: u64,
) -> bool {
let heat_threshold = if self.max_bytes > 0 && est_index_bytes > self.max_bytes / 2 {
0.6
} else {
0.3
};
let has_room =
remaining_device_bytes >= est_index_bytes.saturating_add(device_budget_bytes / 10);
build_heat >= heat_threshold && est_index_bytes <= self.max_bytes && has_room
}
pub(crate) fn clear(&mut self) {
let removed = self.entries.len() as u64;
self.entries.clear();
self.clock = 0;
self.total_bytes = 0;
self.stats.invalidations = self.stats.invalidations.saturating_add(removed);
}
pub(crate) fn get(&mut self, key: &JoinIndexKey) -> Option<&JoinIndexV2> {
self.stats.lookups = self.stats.lookups.saturating_add(1);
let Some(entry) = self.entries.get_mut(key) else {
self.stats.misses = self.stats.misses.saturating_add(1);
return None;
};
self.clock = self.clock.saturating_add(1);
entry.last_used = self.clock;
match &entry.index {
CachedJoinIndexPayload::Ready(index) => {
self.stats.hits = self.stats.hits.saturating_add(1);
Some(index)
}
#[cfg(test)]
CachedJoinIndexPayload::Placeholder => {
self.stats.misses = self.stats.misses.saturating_add(1);
None
}
}
}
pub(crate) fn insert(&mut self, key: JoinIndexKey, index: JoinIndexV2) {
let bytes = index.estimated_bytes();
if bytes > self.max_bytes {
return;
}
self.evict_until_fits(bytes);
self.clock = self.clock.saturating_add(1);
let last_used = self.clock;
if let Some(prev) = self.entries.remove(&key) {
self.total_bytes = self.total_bytes.saturating_sub(prev.bytes);
}
self.total_bytes = self.total_bytes.saturating_add(bytes);
self.entries.insert(
key,
CachedJoinIndex {
index: CachedJoinIndexPayload::Ready(index),
bytes,
last_used,
},
);
self.stats.builds = self.stats.builds.saturating_add(1);
}
pub(crate) fn remove(&mut self, key: &JoinIndexKey) {
if let Some(prev) = self.entries.remove(key) {
self.total_bytes = self.total_bytes.saturating_sub(prev.bytes);
}
}
pub(crate) fn remove_stale(&mut self, key: &JoinIndexKey) {
let before = self.entries.len();
self.remove(key);
if self.entries.len() < before {
self.stats.stale_rejections = self.stats.stale_rejections.saturating_add(1);
}
}
pub(crate) fn invalidate_rel(&mut self, rel: RelId) {
let keys: Vec<JoinIndexKey> = self
.entries
.keys()
.filter(|k| k.rel == rel)
.cloned()
.collect();
for key in keys {
if let Some(entry) = self.entries.remove(&key) {
self.total_bytes = self.total_bytes.saturating_sub(entry.bytes);
self.stats.invalidations = self.stats.invalidations.saturating_add(1);
}
}
}
pub(crate) fn evict_until_fits(&mut self, additional_bytes: u64) {
while !self.entries.is_empty()
&& self.total_bytes.saturating_add(additional_bytes) > self.max_bytes
{
let mut oldest_key: Option<JoinIndexKey> = None;
let mut oldest_clock = u64::MAX;
for (k, v) in &self.entries {
if v.last_used < oldest_clock {
oldest_clock = v.last_used;
oldest_key = Some(k.clone());
}
}
let Some(key) = oldest_key else {
break;
};
if let Some(entry) = self.entries.remove(&key) {
self.total_bytes = self.total_bytes.saturating_sub(entry.bytes);
self.stats.evictions = self.stats.evictions.saturating_add(1);
} else {
break;
}
}
}
pub(crate) fn record_background_build_request(&mut self) {
self.stats.background_build_requests =
self.stats.background_build_requests.saturating_add(1);
}
pub(crate) fn record_background_build_complete(&mut self) {
self.stats.background_builds_completed =
self.stats.background_builds_completed.saturating_add(1);
}
pub(crate) fn record_background_build_deferred(&mut self) {
self.stats.background_builds_deferred =
self.stats.background_builds_deferred.saturating_add(1);
}
pub(crate) fn stats(&self) -> JoinIndexCacheStats {
let mut stats = self.stats.clone();
stats.entries = self.entries.len();
stats.total_bytes = self.total_bytes;
stats
}
#[cfg(test)]
fn insert_test_entry(&mut self, key: JoinIndexKey, bytes: u64) {
if bytes > self.max_bytes {
return;
}
self.evict_until_fits(bytes);
self.clock = self.clock.saturating_add(1);
self.total_bytes = self.total_bytes.saturating_add(bytes);
self.entries.insert(
key,
CachedJoinIndex {
index: CachedJoinIndexPayload::Placeholder,
bytes,
last_used: self.clock,
},
);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use xlog_core::{ScalarType, Schema};

    /// Builds a `Schema` from `(column name, column type)` pairs.
    fn schema(cols: Vec<(&str, ScalarType)>) -> Schema {
        Schema::new(
            cols.into_iter()
                .map(|(name, ty)| (String::from(name), ty))
                .collect(),
        )
    }

    /// Every component of the key must take part in equality.
    #[test]
    fn persistent_key_includes_schema_generation_key_and_device() {
        let narrow = schema(vec![("k", ScalarType::U32)]);
        let wide = schema(vec![("k", ScalarType::U64)]);
        // All keys in this test share the relation and key columns; only the
        // remaining components vary.
        let make = |version: u64, layout: &Schema, device: u32| {
            JoinIndexKey::new(RelId(7), version, vec![0], layout, device)
        };

        let base = make(3, &narrow, 0);
        assert_eq!(base.rel, RelId(7));
        assert_eq!(base.version, 3);
        assert_eq!(base.key_cols, vec![0]);
        assert_eq!(base.device_ordinal, 0);

        assert_ne!(
            base,
            make(4, &narrow, 0),
            "generation/version must partition stale indexes"
        );
        assert_ne!(
            base,
            make(3, &wide, 0),
            "schema changes must partition indexes"
        );
        assert_ne!(
            base,
            make(3, &narrow, 1),
            "device ordinal must partition indexes"
        );
    }

    /// Two 60-byte entries cannot share a 100-byte budget: the older one is evicted.
    #[test]
    fn persistent_cache_budget_evicts_lru_and_records_stats() {
        let layout = schema(vec![("k", ScalarType::U32)]);
        let mut cache = JoinIndexCache::new(100);
        let first = JoinIndexKey::new(RelId(1), 1, vec![0], &layout, 0);
        let second = JoinIndexKey::new(RelId(2), 1, vec![0], &layout, 0);

        cache.insert_test_entry(first, 60);
        cache.insert_test_entry(second, 60);

        let snapshot = cache.stats();
        assert_eq!(snapshot.entries, 1);
        assert_eq!(snapshot.total_bytes, 60);
        assert_eq!(snapshot.evictions, 1);
    }

    /// Invalidating a relation removes its entry and is reflected in the counters.
    #[test]
    fn persistent_cache_invalidation_records_removed_entries() {
        let layout = schema(vec![("k", ScalarType::U32)]);
        let mut cache = JoinIndexCache::new(100);
        cache.insert_test_entry(JoinIndexKey::new(RelId(1), 1, vec![0], &layout, 0), 32);

        cache.invalidate_rel(RelId(1));

        let snapshot = cache.stats();
        assert_eq!(snapshot.entries, 0);
        assert_eq!(snapshot.total_bytes, 0);
        assert_eq!(snapshot.invalidations, 1);
    }
}