use crate::location::shard::Shard;
use crate::location::state::LocationState;
use crate::location::LocationDef;
use crate::schema::attr::{AttrType, OwnedValue};
use crate::schema::schema::Schema;
use crate::{AttrId, KindId, LocId};
use ahash::RandomState;
use arc_swap::ArcSwap;
use parking_lot::Mutex;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
/// Number of shards in a [`Shards`] table.
///
/// Must be a power of two: `Shards::shard_for` selects a shard by masking the id's
/// hash with `N_SHARDS - 1`.
pub const N_SHARDS: usize = 256;

// Compile-time guard for the invariant above. With a non-power-of-two the mask would
// not fail loudly; it would just leave some shards permanently unused.
const _: () = assert!(
    N_SHARDS.is_power_of_two(),
    "N_SHARDS must be a power of two"
);

/// A fixed array of `N_SHARDS` shards, each with its own lock, plus the hasher used
/// to route a `LocId` to one of them.
pub struct Shards<T> {
    shards: [Shard<T>; N_SHARDS],
    hasher: RandomState,
}
impl<T> Shards<T> {
    /// Creates an empty table: `N_SHARDS` shards, each with an empty map, plus a
    /// freshly seeded hasher for routing ids to shards.
    fn new() -> Self {
        let shards = std::array::from_fn(|_index| Shard {
            map: parking_lot::RwLock::new(ahash::AHashMap::default()),
        });
        let hasher = RandomState::new();
        Self { shards, hasher }
    }

    /// Picks the shard responsible for `id` from the low bits of its hash.
    #[inline]
    // The u64 hash may be truncated on 32-bit targets; only its low bits are used.
    #[allow(clippy::cast_possible_truncation)]
    fn shard_for(&self, id: LocId) -> &Shard<T> {
        let hash = self.hasher.hash_one(id.0) as usize;
        let index = hash & (N_SHARDS - 1);
        &self.shards[index]
    }

    /// Returns the state stored for `id`, if any.
    #[must_use]
    pub fn get(&self, id: LocId) -> Option<Arc<Mutex<LocationState<T>>>> {
        let shard = self.shard_for(id);
        shard.get(id)
    }

    /// Stores `st` under `id` in its shard and returns the `Arc` handle that
    /// `Shard::insert` yields (presumably the newly stored state).
    pub fn insert(&self, id: LocId, st: LocationState<T>) -> Arc<Mutex<LocationState<T>>> {
        let shard = self.shard_for(id);
        shard.insert(id, st)
    }

    /// Removes `id` from its shard; returns `true` if an entry was removed.
    pub fn remove(&self, id: LocId) -> bool {
        let shard = self.shard_for(id);
        shard.remove(id)
    }

    /// Counts entries across all shards.
    ///
    /// Shards are measured one after another, so under concurrent writers the total
    /// may not correspond to any single instant.
    #[must_use]
    pub fn total_locations(&self) -> usize {
        let mut total = 0;
        for shard in &self.shards {
            total += shard.len();
        }
        total
    }
}
/// Reasons a [`LocationDef`] can be rejected while building location state.
#[derive(Debug, thiserror::Error)]
pub enum ReloadErr {
    /// A `KindId` in `kinds_allowed` is not below the schema's kind count.
    #[error("unknown kind in LocationDef: {0:?}")]
    UnknownKind(KindId),

    /// A `(kind, attr)` pair in `ref_attrs` has no slot in the schema's layout.
    #[error("unknown attribute in LocationDef: {0:?}")]
    UnknownAttr(AttrId),

    /// A supplied value's variant does not match the slot's declared type.
    #[error("attribute type mismatch: expected {expected:?}")]
    TypeMismatch { expected: AttrType },
}
/// Concurrent table of per-location state, validated against a shared [`Schema`].
///
/// Lookups go through an `ArcSwap` snapshot of [`Shards`], so a full reload can
/// publish a replacement table with a single pointer swap.
pub struct LocationTable<T> {
    /// Schema that every location definition is validated against.
    schema: Arc<Schema>,
    /// Current snapshot of the sharded map; replaced wholesale by `reload_all`.
    inner: ArcSwap<Shards<T>>,
    /// Set by `reload_all` while it is building a replacement snapshot.
    pub reload_in_progress: AtomicBool,
}
impl<T> LocationTable<T> {
    /// Creates an empty table whose definitions will be validated against `schema`.
    #[must_use]
    pub fn new(schema: Arc<Schema>) -> Self {
        Self {
            schema,
            inner: ArcSwap::from_pointee(Shards::new()),
            reload_in_progress: AtomicBool::new(false),
        }
    }

    /// Returns the schema this table validates definitions against.
    #[must_use]
    pub const fn schema(&self) -> &Arc<Schema> {
        &self.schema
    }

    /// Looks up `id` in the current snapshot.
    ///
    /// The returned handle stays usable after a later [`reload_all`](Self::reload_all),
    /// but it then belongs to the replaced snapshot rather than the one the table
    /// currently serves.
    #[must_use]
    pub fn get(&self, id: LocId) -> Option<Arc<Mutex<LocationState<T>>>> {
        self.inner.load().get(id)
    }

    /// Validates `def` and inserts its freshly built state into the current snapshot.
    ///
    /// The state is built before the snapshot is loaded, so the `ArcSwap` guard is
    /// held only for the insert itself.
    ///
    /// NOTE: an upsert that races with [`reload_all`](Self::reload_all) can still land
    /// in the snapshot that the reload is about to replace, and is then lost. Callers
    /// that care can consult `reload_in_progress`.
    ///
    /// # Errors
    /// Returns the [`ReloadErr`] from validating `def`; the table is not modified then.
    pub fn upsert(&self, def: &LocationDef) -> Result<(), ReloadErr> {
        let state = Self::build_state(&self.schema, def)?;
        self.inner.load().insert(def.id, state);
        Ok(())
    }

    /// Removes `id` from the current snapshot, returning whether it was present.
    pub fn remove(&self, id: LocId) -> bool {
        self.inner.load().remove(id)
    }

    /// Rebuilds the whole table from `defs` and swaps it in as one new snapshot.
    ///
    /// Every definition is validated before anything is published. On the first error,
    /// the partially built table is dropped and the current snapshot stays in place.
    ///
    /// `reload_in_progress` is `true` for the duration of the rebuild. It is cleared on
    /// every exit path: success, error, or a panic raised by the `defs` iterator or by
    /// state construction.
    ///
    /// NOTE(review): overlapping calls share one flag, so the first call to finish
    /// clears it while another may still be running.
    ///
    /// # Errors
    /// Returns the first [`ReloadErr`] produced while validating `defs`.
    pub fn reload_all(&self, defs: impl IntoIterator<Item = LocationDef>) -> Result<(), ReloadErr> {
        /// Clears the wrapped flag when dropped, including during unwinding.
        struct ClearOnDrop<'a>(&'a AtomicBool);

        impl Drop for ClearOnDrop<'_> {
            fn drop(&mut self) {
                self.0.store(false, Ordering::Release);
            }
        }

        self.reload_in_progress.store(true, Ordering::Release);
        let _clear = ClearOnDrop(&self.reload_in_progress);

        let fresh = Shards::new();
        for def in defs {
            let state = Self::build_state(&self.schema, &def)?;
            fresh.insert(def.id, state);
        }
        // The new snapshot is published before `_clear` drops, so the flag is
        // only lowered once readers can already see the new table.
        self.inner.store(Arc::new(fresh));
        Ok(())
    }

    /// Number of locations in the current snapshot.
    #[must_use]
    pub fn total_locations(&self) -> usize {
        self.inner.load().total_locations()
    }

    /// Validates `def` against `schema` and encodes its reference attributes into a
    /// fresh [`LocationState`].
    ///
    /// Every `KindId` in `kinds_allowed` must be below `schema.kind_names.len()`.
    /// Every `(kind, attr)` pair in `ref_attrs` must resolve to a slot in
    /// `schema.slot_layout`.
    ///
    /// NOTE(review): `kinds_allowed` is only range-checked here; it is not handed to
    /// the resulting state. Confirm that is intended.
    ///
    /// # Errors
    /// [`ReloadErr::UnknownKind`], [`ReloadErr::UnknownAttr`], or the
    /// [`ReloadErr::TypeMismatch`] reported by `write_owned`.
    fn build_state(schema: &Arc<Schema>, def: &LocationDef) -> Result<LocationState<T>, ReloadErr> {
        if let Some(bad) = def
            .kinds_allowed
            .iter()
            .find(|kid| usize::from(kid.0) >= schema.kind_names.len())
        {
            return Err(ReloadErr::UnknownKind(*bad));
        }
        let mut st = LocationState::new(Arc::clone(schema));
        for (kid, aid, val) in &def.ref_attrs {
            let Some(slot) = schema.slot_layout.resolve(*kid, *aid) else {
                return Err(ReloadErr::UnknownAttr(*aid));
            };
            write_owned(&mut st.buf, slot.offset as usize, slot.ty, val)?;
        }
        Ok(st)
    }
}
/// Encodes `v` little-endian into `buf` at byte offset `off`, provided its variant
/// matches the slot type `ty`.
///
/// Widths: `Int` and `F64` occupy 8 bytes; `F32` and `EnumStr` occupy 4.
///
/// # Errors
/// [`ReloadErr::TypeMismatch`] when `v` does not correspond to `ty`. This includes
/// any `AttrType` that has no arm below.
///
/// # Panics
/// If `off + width` exceeds `buf.len()`.
fn write_owned(buf: &mut [u8], off: usize, ty: AttrType, v: &OwnedValue) -> Result<(), ReloadErr> {
    /// Copies `bytes` into the `width`-byte window of `dst` that starts at `at`.
    fn put(dst: &mut [u8], at: usize, width: usize, bytes: &[u8]) -> Result<(), ReloadErr> {
        dst[at..at + width].copy_from_slice(bytes);
        Ok(())
    }

    match (ty, v) {
        (AttrType::Int, OwnedValue::Int(n)) => put(buf, off, 8, &n.to_le_bytes()),
        (AttrType::F32, OwnedValue::F32(x)) => put(buf, off, 4, &x.to_le_bytes()),
        (AttrType::F64, OwnedValue::F64(x)) => put(buf, off, 8, &x.to_le_bytes()),
        (AttrType::EnumStr, OwnedValue::EnumCode(c)) => put(buf, off, 4, &c.to_le_bytes()),
        _ => Err(ReloadErr::TypeMismatch { expected: ty }),
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use crate::schema::SchemaBuilder;

    /// One kind ("audience") with an F32 and an Int attribute.
    fn tiny_schema() -> Arc<Schema> {
        let mut b = SchemaBuilder::new();
        let _ = b.kind(
            "audience",
            &[("male_frac", AttrType::F32), ("dwell", AttrType::Int)],
        );
        b.build()
    }

    /// A definition with no kinds and no reference attributes, keyed by `id`.
    fn bare_def(id: u64) -> LocationDef {
        LocationDef {
            id: LocId(id),
            kinds_allowed: vec![],
            ref_attrs: vec![],
        }
    }

    #[test]
    fn upsert_and_get() {
        let t: LocationTable<()> = LocationTable::new(tiny_schema());
        t.upsert(&bare_def(1)).unwrap();
        assert!(t.get(LocId(1)).is_some());
        assert!(t.get(LocId(2)).is_none());
        assert_eq!(t.total_locations(), 1);
    }

    #[test]
    fn reload_replaces_table_atomically() {
        let t: LocationTable<()> = LocationTable::new(tiny_schema());
        t.upsert(&bare_def(1)).unwrap();
        assert!(t.get(LocId(1)).is_some());
        t.reload_all([bare_def(2)]).unwrap();
        assert!(t.get(LocId(1)).is_none());
        assert!(t.get(LocId(2)).is_some());
        assert_eq!(t.total_locations(), 1);
    }

    #[test]
    fn remove_returns_whether_found() {
        let t: LocationTable<()> = LocationTable::new(tiny_schema());
        t.upsert(&bare_def(1)).unwrap();
        assert!(t.remove(LocId(1)));
        assert!(!t.remove(LocId(1)));
    }

    #[test]
    fn concurrent_reads_across_reload() {
        use std::sync::atomic::AtomicUsize;
        let t: Arc<LocationTable<()>> = Arc::new(LocationTable::new(tiny_schema()));
        for i in 0..1000_u64 {
            t.upsert(&bare_def(i)).unwrap();
        }
        let hits = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::new();
        for _ in 0..8 {
            let t2 = t.clone();
            let hits2 = hits.clone();
            handles.push(std::thread::spawn(move || {
                for _ in 0..10_000 {
                    if t2.get(LocId(500)).is_some() {
                        hits2.fetch_add(1, Ordering::Relaxed);
                    }
                }
            }));
        }
        let defs: Vec<_> = (0..1000_u64).map(bare_def).collect();
        t.reload_all(defs).unwrap();
        for h in handles {
            h.join().unwrap();
        }
        assert!(hits.load(Ordering::Relaxed) > 0);
    }

    #[test]
    fn reload_resets_in_progress_on_error() {
        let t: LocationTable<()> = LocationTable::new(tiny_schema());
        let bad = LocationDef {
            id: LocId(1),
            kinds_allowed: vec![],
            ref_attrs: vec![(KindId(0), AttrId(999), OwnedValue::Int(0))],
        };
        let err = t.reload_all([bad]).unwrap_err();
        assert!(matches!(err, ReloadErr::UnknownAttr(_)));
        assert!(!t.reload_in_progress.load(Ordering::Acquire));
    }

    #[test]
    fn reload_resets_in_progress_on_success() {
        let t: LocationTable<()> = LocationTable::new(tiny_schema());
        t.reload_all([bare_def(7)]).unwrap();
        assert!(!t.reload_in_progress.load(Ordering::Acquire));
        assert_eq!(t.total_locations(), 1);
    }

    #[test]
    fn reload_rejects_unknown_kind() {
        let t: LocationTable<()> = LocationTable::new(tiny_schema());
        let def = LocationDef {
            id: LocId(1),
            kinds_allowed: vec![KindId(999)],
            ref_attrs: vec![],
        };
        let err = t.reload_all([def]).unwrap_err();
        assert!(matches!(err, ReloadErr::UnknownKind(_)));
    }

    #[test]
    fn failed_reload_keeps_previous_table() {
        let t: LocationTable<()> = LocationTable::new(tiny_schema());
        t.upsert(&bare_def(1)).unwrap();
        let bad = LocationDef {
            id: LocId(2),
            kinds_allowed: vec![KindId(999)],
            ref_attrs: vec![],
        };
        // The valid def (3) precedes the bad one; it must not leak into the table.
        assert!(t.reload_all([bare_def(3), bad]).is_err());
        assert!(t.get(LocId(1)).is_some());
        assert!(t.get(LocId(3)).is_none());
        assert_eq!(t.total_locations(), 1);
    }
}