use pgrx::callbacks::{register_xact_callback, PgXactCallbackEvent};
use pgrx::prelude::*;
use pgrx::{pg_shmem_init, PGRXSharedMemory, PgAtomic, PgLwLock};
use std::cell::RefCell;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
/// Number of entries in the shared cache table (the length of the `DICT_CACHE` array).
const SLOTS: usize = 16_384;
/// Maximum number of consecutive slots, starting at a key's home index
/// (`h1 % SLOTS`), that `lookup` and `insert_slot` examine before giving up.
const PROBE_DEPTH: usize = 8;
/// One entry of the shared dictionary cache table.
///
/// The fields add up to 8 + 8 + 8 + 8 + 1 + 7 = 40 bytes, so with `repr(C)` the
/// layout contains no implicit padding.
#[derive(Copy, Clone, Default)]
#[repr(C)]
pub(crate) struct DictCacheSlot {
/// First 64-bit half of the key fingerprint produced by `fingerprint`.
key_hash1: u64,
/// Second 64-bit half of the key fingerprint produced by `fingerprint`.
key_hash2: u64,
/// Value of the `GENERATION` counter sampled when the slot was written;
/// `lookup` ignores slots whose generation differs from the current one.
generation: u64,
/// The cached dictionary id that `lookup` returns on a hit.
dict_id: i64,
/// Non-zero once the slot has been written; zero for a never-used slot.
occupied: u8,
/// Explicit padding bytes; every writer in this file stores zeros here.
_pad: [u8; 7],
}
// SAFETY: NOTE(review) — presumably sound because the struct is plain `Copy`
// data holding only integers (no pointers or heap references); confirm against
// the contract documented for pgrx's `PGRXSharedMemory`.
unsafe impl PGRXSharedMemory for DictCacheSlot {}
impl DictCacheSlot {
/// `const fn` constructor returning an all-zero (unoccupied) slot.
///
/// Used by `init_in_postmaster` to build the initial table contents.
const fn default_const() -> Self {
Self {
key_hash1: 0,
key_hash2: 0,
generation: 0,
dict_id: 0,
occupied: 0,
_pad: [0; 7],
}
}
}
// SAFETY (applies to every `unsafe { ...::new(..) }` below): NOTE(review) —
// presumably the constructors require a unique name per object and that each
// static is registered through `pg_shmem_init!` before use (done in
// `init_in_postmaster`); confirm against the pgrx documentation.

/// The cache table itself: `SLOTS` entries guarded by a pgrx lightweight lock.
/// Readers take it with `share()`, writers with `exclusive()`.
static DICT_CACHE: PgLwLock<[DictCacheSlot; SLOTS]> =
unsafe { PgLwLock::new(c"pgrdf_dict_cache_v1") };
/// Count of `lookup` calls that found a matching slot.
pub(crate) static HITS: PgAtomic<AtomicU64> = unsafe { PgAtomic::new(c"pgrdf_dict_cache_hits") };
/// Count of `lookup` calls that found no matching slot within the probe window.
static MISSES: PgAtomic<AtomicU64> = unsafe { PgAtomic::new(c"pgrdf_dict_cache_misses") };
/// Count of slot writes that stored a new entry (including evicting writes).
static INSERTS: PgAtomic<AtomicU64> = unsafe { PgAtomic::new(c"pgrdf_dict_cache_inserts") };
/// Count of writes that overwrote a live entry because the probe window was full.
static EVICTIONS: PgAtomic<AtomicU64> = unsafe { PgAtomic::new(c"pgrdf_dict_cache_evictions") };
/// Generation counter; `reset` increments it, which makes every slot stamped
/// with an older value invisible to `lookup`.
static GENERATION: PgAtomic<AtomicU64> = unsafe { PgAtomic::new(c"pgrdf_dict_cache_generation") };
/// Registers every shared object of this module via `pg_shmem_init!` and then
/// marks the cache as ready.
///
/// The table starts out filled with unoccupied slots and the generation
/// counter starts at 1; the four statistics counters use the macro's default
/// initial value.
///
/// NOTE(review): judging by the name this is meant to run in the postmaster
/// while the library is being preloaded — confirm against the caller.
pub fn init_in_postmaster() {
pg_shmem_init!(DICT_CACHE = [DictCacheSlot::default_const(); SLOTS]);
pg_shmem_init!(HITS);
pg_shmem_init!(MISSES);
pg_shmem_init!(INSERTS);
pg_shmem_init!(EVICTIONS);
pg_shmem_init!(GENERATION = AtomicU64::new(1));
// Flip the readiness flag only after every object has been registered.
mark_ready();
}
/// Returns the current value of the shared `GENERATION` counter.
///
/// Yields `0` while the cache has not been marked ready, without touching
/// shared memory. `GENERATION` is initialised to 1, so `0` is not a value the
/// counter starts from.
fn current_generation() -> u64 {
    if is_ready() {
        GENERATION.get().load(Ordering::Relaxed)
    } else {
        0
    }
}
/// Invalidates every cached entry by advancing the shared generation counter.
///
/// Slots are not cleared; they simply stop matching because `lookup` and
/// `insert_slot` compare each slot's stamped generation with the current one.
/// Does nothing while the cache has not been marked ready.
pub fn reset() {
    if is_ready() {
        GENERATION.get().fetch_add(1, Ordering::Relaxed);
    }
}
/// Readiness flag for this module. It is an ordinary Rust `static` (not one of
/// the `PgAtomic` shared-memory objects) and starts out `false`.
static SHMEM_READY: AtomicBool = AtomicBool::new(false);
/// Sets the readiness flag; called at the end of `init_in_postmaster`.
pub fn mark_ready() {
SHMEM_READY.store(true, Ordering::Relaxed);
}
/// Returns `true` once `mark_ready` has been called. Every public entry point
/// of this module checks this before touching the shared objects.
pub fn is_ready() -> bool {
SHMEM_READY.load(Ordering::Relaxed)
}
/// Seed mixed into the first half of a key fingerprint.
const SEED_A: u64 = 0x9E37_79B9_7F4A_7C15;
/// Seed mixed into the second half of a key fingerprint.
const SEED_B: u64 = 0xC4F1_7B5E_9D0A_3E27;

/// Computes the 128-bit fingerprint (as two `u64` halves) of a dictionary key.
///
/// Each half is a `DefaultHasher` digest of a seed followed by `term_type`,
/// `value`, `datatype_id` and `language`, in that order; the halves differ
/// only in the seed (`SEED_A` for the first, `SEED_B` for the second).
fn fingerprint(
    term_type: i16,
    value: &str,
    datatype_id: Option<i64>,
    language: Option<&str>,
) -> (u64, u64) {
    // One digest pass; the feeding order must stay fixed so both halves (and
    // every caller) agree on the same key encoding.
    let digest = |seed: u64| -> u64 {
        let mut hasher = DefaultHasher::new();
        seed.hash(&mut hasher);
        term_type.hash(&mut hasher);
        value.hash(&mut hasher);
        datatype_id.hash(&mut hasher);
        language.hash(&mut hasher);
        hasher.finish()
    };

    let first = digest(SEED_A);
    let second = digest(SEED_B);
    (first, second)
}
/// Looks up the cached dictionary id for a term.
///
/// Probes up to `PROBE_DEPTH` consecutive slots (wrapping around the table)
/// starting at `h1 % SLOTS`, under the shared lock. A slot matches when it is
/// occupied, stamped with the current generation, and both fingerprint halves
/// are equal.
///
/// Returns `Some(dict_id)` on a match and bumps `HITS`; otherwise returns
/// `None` and bumps `MISSES`. Returns `None` without touching any counter
/// while the cache has not been marked ready.
pub fn lookup(
    term_type: i16,
    value: &str,
    datatype_id: Option<i64>,
    language: Option<&str>,
) -> Option<i64> {
    if !is_ready() {
        return None;
    }

    // The generation is sampled before the lock is taken.
    let generation = current_generation();
    let (h1, h2) = fingerprint(term_type, value, datatype_id, language);
    let slots = DICT_CACHE.share();
    let home = (h1 as usize) % SLOTS;

    let hit = (0..PROBE_DEPTH).find_map(|offset| {
        let entry = &slots[(home + offset) % SLOTS];
        let is_match = entry.occupied != 0
            && entry.generation == generation
            && entry.key_hash1 == h1
            && entry.key_hash2 == h2;
        is_match.then_some(entry.dict_id)
    });

    // Exactly one of the two counters advances per call.
    let counter = if hit.is_some() { &HITS } else { &MISSES };
    counter.get().fetch_add(1, Ordering::Relaxed);
    hit
}
thread_local! {
// Entries staged by `stage_for_commit` as `(key_hash1, key_hash2, dict_id)`;
// drained into the shared table by `flush_pending`, or cleared by the Abort callback.
static PENDING: RefCell<Vec<(u64, u64, i64)>> = const { RefCell::new(Vec::new()) };
// `true` while the Commit/Abort callbacks set up by
// `register_xact_callbacks_once` are registered and have not yet run.
static REGISTERED: RefCell<bool> = const { RefCell::new(false) };
}
/// Queues a key → `dict_id` mapping to be published to the shared cache when
/// the current transaction commits.
///
/// The key is reduced to its fingerprint immediately and pushed onto the
/// thread-local `PENDING` list; the transaction callbacks that flush or
/// discard that list are then registered if they are not already. Does
/// nothing while the cache has not been marked ready.
pub fn stage_for_commit(
    term_type: i16,
    value: &str,
    datatype_id: Option<i64>,
    language: Option<&str>,
    dict_id: i64,
) {
    if is_ready() {
        let (h1, h2) = fingerprint(term_type, value, datatype_id, language);
        PENDING.with(|queue| queue.borrow_mut().push((h1, h2, dict_id)));
        register_xact_callbacks_once();
    }
}
/// Writes a key → `dict_id` mapping into the shared cache immediately,
/// bypassing the `PENDING` staging list.
///
/// Does nothing while the cache has not been marked ready.
pub fn insert_committed(
    term_type: i16,
    value: &str,
    datatype_id: Option<i64>,
    language: Option<&str>,
    dict_id: i64,
) {
    if is_ready() {
        let (h1, h2) = fingerprint(term_type, value, datatype_id, language);
        insert_slot(h1, h2, dict_id);
    }
}
/// Registers the Commit and Abort transaction callbacks unless the
/// thread-local `REGISTERED` flag says they are already in place.
///
/// * Commit: flushes `PENDING` into the shared table via `flush_pending`.
/// * Abort: discards everything in `PENDING`.
///
/// Both callbacks clear `REGISTERED` afterwards so that the next call to
/// `stage_for_commit` registers a fresh pair.
///
/// NOTE(review): only the top-level `Commit` and `Abort` events are hooked
/// here; confirm whether entries staged inside a subtransaction that is later
/// rolled back need separate handling.
fn register_xact_callbacks_once() {
    // Test-and-set: store `true` and look at what was there before.
    let already_registered = REGISTERED.with(|flag| flag.replace(true));
    if already_registered {
        return;
    }

    register_xact_callback(PgXactCallbackEvent::Commit, || {
        flush_pending();
        REGISTERED.with(|flag| *flag.borrow_mut() = false);
    });
    register_xact_callback(PgXactCallbackEvent::Abort, || {
        PENDING.with(|queue| queue.borrow_mut().clear());
        REGISTERED.with(|flag| *flag.borrow_mut() = false);
    });
}
/// Moves every staged entry out of `PENDING` and writes it into the shared
/// table with `insert_slot`.
///
/// The list is taken (left empty) before any insert runs, so the `RefCell`
/// borrow is not held while the table lock is acquired.
fn flush_pending() {
    let staged: Vec<(u64, u64, i64)> = PENDING.with(|queue| queue.take());
    staged
        .into_iter()
        .for_each(|(h1, h2, dict_id)| insert_slot(h1, h2, dict_id));
}
/// Stores `(h1, h2) -> dict_id` in the shared table under the exclusive lock.
///
/// Walks up to `PROBE_DEPTH` slots (wrapping around the table) starting at
/// `h1 % SLOTS`:
///
/// * a slot that is unoccupied or stamped with a different generation is
///   overwritten with the new entry (`INSERTS` += 1);
/// * a live slot with the same fingerprint only gets its `dict_id` replaced
///   (no counter changes);
/// * if every probed slot is live and belongs to another key, the home slot is
///   overwritten (`EVICTIONS` += 1, `INSERTS` += 1).
fn insert_slot(h1: u64, h2: u64, dict_id: i64) {
    // The generation is sampled before the lock is taken, so the entry is
    // stamped with the generation that was current when the insert started.
    let generation = current_generation();
    let fresh = DictCacheSlot {
        key_hash1: h1,
        key_hash2: h2,
        generation,
        dict_id,
        occupied: 1,
        _pad: [0; 7],
    };

    let mut slots = DICT_CACHE.exclusive();
    let home = (h1 as usize) % SLOTS;

    for offset in 0..PROBE_DEPTH {
        let entry = &mut slots[(home + offset) % SLOTS];
        let live = entry.occupied != 0 && entry.generation == generation;

        if live && entry.key_hash1 == h1 && entry.key_hash2 == h2 {
            // Same key already cached: refresh the id in place.
            entry.dict_id = dict_id;
            return;
        }
        if !live {
            // Free or stale slot: claim it.
            *entry = fresh;
            INSERTS.get().fetch_add(1, Ordering::Relaxed);
            return;
        }
        // Live slot owned by a different key: keep probing.
    }

    // Probe window exhausted: evict whatever sits in the home slot.
    slots[home] = fresh;
    EVICTIONS.get().fetch_add(1, Ordering::Relaxed);
    INSERTS.get().fetch_add(1, Ordering::Relaxed);
}
/// Point-in-time copy of the cache statistics, as returned by `snapshot`.
///
/// The type is plain data, so it derives `Debug` (for logging and assertion
/// messages), `Clone`/`Copy` (it is only a handful of integers) and
/// `PartialEq`/`Eq` (so two snapshots can be compared directly).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Snapshot {
    /// Whether the cache had been marked ready when the snapshot was taken.
    pub ready: bool,
    /// Total number of slots in the cache table.
    pub slots: usize,
    /// Number of lookups that found a matching entry.
    pub hits: u64,
    /// Number of lookups that found no matching entry.
    pub misses: u64,
    /// Number of slot writes that stored a new entry.
    pub inserts: u64,
    /// Number of writes that overwrote a live entry of another key.
    pub evictions: u64,
}
pub fn snapshot() -> Snapshot {
Snapshot {
ready: is_ready(),
slots: SLOTS,
hits: if is_ready() {
HITS.get().load(Ordering::Relaxed)
} else {
0
},
misses: if is_ready() {
MISSES.get().load(Ordering::Relaxed)
} else {
0
},
inserts: if is_ready() {
INSERTS.get().load(Ordering::Relaxed)
} else {
0
},
evictions: if is_ready() {
EVICTIONS.get().load(Ordering::Relaxed)
} else {
0
},
}
}
// pgrx tests for the shared-memory dictionary cache. They exercise the public
// entry points (`is_ready`, `insert_committed`, `lookup`, `snapshot`) against
// the real shared objects.
#[cfg(any(test, feature = "pg_test"))]
#[pg_schema]
mod tests {
use super::*;
use crate::storage::dict::term_type;
// The cache must already be marked ready when a test runs.
#[pg_test]
fn shmem_ready_in_test() {
assert!(is_ready(), "shmem cache must be initialised in pg_test");
}
// A value written with `insert_committed` is returned by `lookup` for the same key.
#[pg_test]
fn shmem_roundtrip_via_committed() {
let key_value = "http://example.com/shmem-test-1";
insert_committed(term_type::URI, key_value, None, None, 4242);
let got = lookup(term_type::URI, key_value, None, None);
assert_eq!(got, Some(4242));
}
// Two keys that differ only in `value` map to their own ids.
#[pg_test]
fn shmem_disambiguates_keys() {
insert_committed(
term_type::URI,
"http://example.com/shmem-test-2a",
None,
None,
100,
);
insert_committed(
term_type::URI,
"http://example.com/shmem-test-2b",
None,
None,
200,
);
assert_eq!(
lookup(
term_type::URI,
"http://example.com/shmem-test-2a",
None,
None
),
Some(100)
);
assert_eq!(
lookup(
term_type::URI,
"http://example.com/shmem-test-2b",
None,
None
),
Some(200)
);
}
// `datatype_id` is part of the key: the same literal with and without a
// datatype resolves to different ids.
#[pg_test]
fn shmem_datatype_in_key() {
insert_committed(term_type::LITERAL, "42", None, None, 1);
insert_committed(term_type::LITERAL, "42", Some(7), None, 2);
assert_eq!(lookup(term_type::LITERAL, "42", None, None), Some(1));
assert_eq!(lookup(term_type::LITERAL, "42", Some(7), None), Some(2));
}
// The miss, hit and insert counters reported by `snapshot` move forward after
// a miss, a hit and an insert respectively. Only strict increase is asserted,
// not exact deltas.
#[pg_test]
fn shmem_counters_advance() {
let before = snapshot();
assert!(lookup(term_type::URI, "http://example.com/cold-miss", None, None).is_none());
let after_miss = snapshot();
assert!(after_miss.misses > before.misses);
insert_committed(
term_type::URI,
"http://example.com/warm-hit",
None,
None,
9999,
);
let _ = lookup(term_type::URI, "http://example.com/warm-hit", None, None);
let after_hit = snapshot();
assert!(after_hit.hits > after_miss.hits);
assert!(after_hit.inserts > before.inserts);
}
}