kitt_score 0.1.0

Decision engine at the core of Project KITT — in-memory stateful matching with pluggable scoring backends.
Documentation
//! Sharded location table behind an `ArcSwap` for atomic bulk reload.
//!
//! # Algorithm summary (spec §4.5)
//!
//! - Number of shards `N` is a const power of two (default 256). The shard
//!   index for a `LocId` is `hash(id) & (N - 1)`.
//! - The shard array is held in a `Shards<T>` whose shard count and hasher
//!   are fixed at construction, behind an `ArcSwap<Shards<T>>`. Hot reads do
//!   `arc.load()` then hash into the array.
//! - `reload_all` builds a brand-new `Shards<T>` on the caller's thread, then
//!   does `arc.store(new)`. Outstanding readers retain the old `Shards<T>`
//!   via their `Guard` until they drop it; the old structure is reclaimed
//!   deterministically.
//! - `upsert` / `remove` mutate the *current* `Shards<T>` in place (the
//!   per-shard `RwLock` guards this), so small edits do not require
//!   rebuilding the whole table.
//!
//! # References
//!
//! - `arc_swap` crate docs: the core wait-free read primitive.
//! - Fraser, "Practical lock-freedom" (2004): epoch-style reclamation theory
//!   underlying the `ArcSwap::load` guarantee.

use crate::location::shard::Shard;
use crate::location::state::LocationState;
use crate::location::LocationDef;
use crate::schema::attr::{AttrType, OwnedValue};
use crate::schema::schema::Schema;
use crate::{AttrId, KindId, LocId};
use ahash::RandomState;
use arc_swap::ArcSwap;
use parking_lot::Mutex;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

/// Number of shards in the location table. Must be a power of two, because
/// `Shards::shard_for` selects a shard by masking the hash with
/// `N_SHARDS - 1`.
pub const N_SHARDS: usize = 256;

// Enforce the power-of-two requirement at compile time rather than relying on
// the doc comment alone: with any other value the `N_SHARDS - 1` mask would
// silently leave some shards permanently unused.
const _: () = assert!(N_SHARDS.is_power_of_two(), "N_SHARDS must be a power of two");

/// The inner sharded map: `N_SHARDS` shards plus the hasher that routes a
/// `LocId` to its shard.
///
/// The shard array and the hasher are fixed once constructed; the contents of
/// each shard are mutated in place through that shard's `RwLock`.
pub struct Shards<T> {
    /// Fixed-size shard array, indexed by `hash(id) & (N_SHARDS - 1)`.
    shards: [Shard<T>; N_SHARDS],
    /// Hasher used to pick a shard; a new one is created (via
    /// `RandomState::new`) for every `Shards` value.
    hasher: RandomState,
}

impl<T> Shards<T> {
    /// Build an empty table: `N_SHARDS` empty shards and a newly created hasher.
    fn new() -> Self {
        let shards: [Shard<T>; N_SHARDS] = std::array::from_fn(|_| Shard {
            map: parking_lot::RwLock::new(ahash::AHashMap::default()),
        });
        let hasher = RandomState::new();
        Self { shards, hasher }
    }

    /// Route `id` to its shard by masking its hash down to a shard index.
    // The `u64 -> usize` cast may drop high bits on 32-bit targets; that is
    // harmless because only the low `log2(N_SHARDS)` bits are kept.
    #[inline]
    #[allow(clippy::cast_possible_truncation)]
    fn shard_for(&self, id: LocId) -> &Shard<T> {
        let index = (self.hasher.hash_one(id.0) as usize) & (N_SHARDS - 1);
        &self.shards[index]
    }

    /// Look up a location by ID.
    #[must_use]
    pub fn get(&self, id: LocId) -> Option<Arc<Mutex<LocationState<T>>>> {
        let shard = self.shard_for(id);
        shard.get(id)
    }

    /// Insert or replace a location, returning the handle the shard hands back.
    pub fn insert(&self, id: LocId, st: LocationState<T>) -> Arc<Mutex<LocationState<T>>> {
        let shard = self.shard_for(id);
        shard.insert(id, st)
    }

    /// Remove a location by ID. Returns `true` if it was present.
    pub fn remove(&self, id: LocId) -> bool {
        let shard = self.shard_for(id);
        shard.remove(id)
    }

    /// Total number of locations across all shards.
    #[must_use]
    pub fn total_locations(&self) -> usize {
        let mut total = 0;
        for shard in &self.shards {
            total += shard.len();
        }
        total
    }
}

/// Errors from `upsert` and `reload_all`.
///
/// Every variant comes from validating or encoding a `LocationDef` against
/// the table's schema.
#[derive(Debug, thiserror::Error)]
pub enum ReloadErr {
    /// A `LocationDef` listed a `KindId` in `kinds_allowed` that is out of
    /// range for the schema's kinds.
    #[error("unknown kind in LocationDef: {0:?}")]
    UnknownKind(KindId),
    /// A `LocationDef` referenced an `AttrId` that is not in the schema or
    /// not in the given kind.
    ///
    /// The `(kind, attr)` pair of a reference attribute is resolved together,
    /// so an unknown kind in `ref_attrs` also surfaces as this variant.
    #[error("unknown attribute in LocationDef: {0:?}")]
    UnknownAttr(AttrId),
    /// A `LocationDef` reference-attribute value does not match its slot type.
    #[error("attribute type mismatch: expected {expected:?}")]
    TypeMismatch {
        /// The slot type expected by the schema.
        expected: AttrType,
    },
}

/// Sharded, bulk-reloadable location table.
///
/// Lookups and single-location edits go through the `Shards<T>` snapshot that
/// is currently installed in `inner`; `reload_all` replaces that snapshot
/// wholesale.
pub struct LocationTable<T> {
    /// Schema used to validate and encode every `LocationDef`.
    schema: Arc<Schema>,
    /// Currently installed sharded snapshot; swapped atomically by `reload_all`.
    inner: ArcSwap<Shards<T>>,
    /// Set to `true` while a `reload_all` call is running.
    ///
    /// This is an advisory flag, not a lock: nothing in this type waits on it,
    /// and overlapping `reload_all` calls clear it as soon as the first one
    /// finishes. NOTE(review): presumably polled by callers that pause
    /// single-location edits during a bulk reload — confirm.
    pub reload_in_progress: AtomicBool,
}

impl<T> LocationTable<T> {
    /// Create a new empty table that validates defs against `schema`.
    #[must_use]
    pub fn new(schema: Arc<Schema>) -> Self {
        Self {
            schema,
            inner: ArcSwap::from_pointee(Shards::new()),
            reload_in_progress: AtomicBool::new(false),
        }
    }

    /// The schema this table was constructed with.
    #[must_use]
    pub const fn schema(&self) -> &Arc<Schema> {
        &self.schema
    }

    /// Look up a location handle by ID in the currently installed snapshot.
    ///
    /// The handle is an `Arc`, so it stays usable after a later `reload_all`.
    /// However, `reload_all` builds fresh states, so edits made through an old
    /// handle are not visible in the new snapshot.
    #[must_use]
    pub fn get(&self, id: LocId) -> Option<Arc<Mutex<LocationState<T>>>> {
        self.inner.load().get(id)
    }

    /// Insert or replace a single location.
    ///
    /// # Errors
    ///
    /// Returns [`ReloadErr`] if the def names an unknown kind, or if any
    /// reference attribute cannot be resolved through the schema or has a
    /// value of the wrong type. Nothing is written on error.
    ///
    /// # Concurrency
    ///
    /// The write goes into the current snapshot through its per-shard
    /// `RwLock`. A concurrent `reload_all` may install a new snapshot while the
    /// write is in flight. The write would then have landed in the orphaned
    /// snapshot, so this is detected and the write is repeated against the new
    /// one instead of being silently lost.
    ///
    /// An upsert that completes before `reload_all` swaps is still discarded by
    /// that reload, because `reload_all` replaces the entire table.
    pub fn upsert(&self, def: &LocationDef) -> Result<(), ReloadErr> {
        loop {
            // Validate and encode before touching any snapshot, so an invalid
            // def never reaches the table.
            let state = Self::build_state(&self.schema, def)?;
            let snap = self.inner.load();
            snap.insert(def.id, state);
            // `snap` keeps the old allocation alive, so the pointer comparison
            // cannot be fooled by address reuse: a match means no reload was
            // installed since our `load`.
            if Arc::ptr_eq(&*snap, &*self.inner.load()) {
                return Ok(());
            }
        }
    }

    /// Remove a location by ID. Returns `true` if it was present.
    ///
    /// # Concurrency
    ///
    /// A concurrent `reload_all` may install a new snapshot while the removal
    /// is in flight. In that case the removal is repeated against the new
    /// snapshot, and the return value reports whether the ID was present in
    /// the snapshot the final attempt ran against.
    pub fn remove(&self, id: LocId) -> bool {
        loop {
            let snap = self.inner.load();
            let found = snap.remove(id);
            // See `upsert`: holding `snap` makes the pointer check reliable.
            if Arc::ptr_eq(&*snap, &*self.inner.load()) {
                return found;
            }
        }
    }

    /// Replace the entire table atomically from an iterator of defs.
    ///
    /// The new snapshot is built on the caller's thread and installed with a
    /// single `ArcSwap::store`. Readers holding a `Guard` into the previous
    /// table continue to see it until they drop the guard.
    ///
    /// `reload_in_progress` is `true` for the duration of the call. It is reset
    /// on every exit path: success, error, or a panic while building.
    ///
    /// # Errors
    ///
    /// Returns [`ReloadErr`] if any def contains an unknown kind or an
    /// unresolvable or mistyped reference attribute. The new table is not
    /// installed on error.
    pub fn reload_all(
        &self,
        defs: impl IntoIterator<Item = LocationDef>,
    ) -> Result<(), ReloadErr> {
        // Clears the wrapped flag when dropped.
        struct ClearOnDrop<'a>(&'a AtomicBool);

        impl Drop for ClearOnDrop<'_> {
            fn drop(&mut self) {
                self.0.store(false, Ordering::Release);
            }
        }

        self.reload_in_progress.store(true, Ordering::Release);
        // Dropped when this function exits. That is after the `store` below on
        // success, and also on an early `?` return or during unwinding.
        let _in_progress = ClearOnDrop(&self.reload_in_progress);

        let new = Shards::new();
        for def in defs {
            let state = Self::build_state(&self.schema, &def)?;
            new.insert(def.id, state);
        }
        self.inner.store(Arc::new(new));
        Ok(())
    }

    /// Total number of locations in the currently installed snapshot.
    ///
    /// Shards are counted one after another, so concurrent upserts/removes may
    /// make the result slightly stale.
    #[must_use]
    pub fn total_locations(&self) -> usize {
        self.inner.load().total_locations()
    }

    /// Validate `def` against `schema` and encode its reference attributes
    /// into a fresh `LocationState`.
    ///
    /// # Errors
    ///
    /// - [`ReloadErr::UnknownKind`] if an entry of `kinds_allowed` is out of
    ///   range for `schema.kind_names`.
    /// - [`ReloadErr::UnknownAttr`] if a `(kind, attr)` pair in `ref_attrs` has
    ///   no slot in the schema's slot layout.
    /// - [`ReloadErr::TypeMismatch`] if a value does not match its slot type.
    fn build_state(
        schema: &Arc<Schema>,
        def: &LocationDef,
    ) -> Result<LocationState<T>, ReloadErr> {
        for kid in &def.kinds_allowed {
            if usize::from(kid.0) >= schema.kind_names.len() {
                return Err(ReloadErr::UnknownKind(*kid));
            }
        }
        let mut st = LocationState::new(schema.clone());
        for (kid, aid, val) in &def.ref_attrs {
            let Some(slot) = schema.slot_layout.resolve(*kid, *aid) else {
                return Err(ReloadErr::UnknownAttr(*aid));
            };
            write_owned(&mut st.buf, slot.offset as usize, slot.ty, val)?;
        }
        Ok(st)
    }
}

/// Encode `v` little-endian into `buf` at byte offset `off`, after checking
/// that the value variant matches the slot type `ty`.
///
/// Each slot type has a fixed width (8 bytes for `Int`/`F64`, 4 for
/// `F32`/`EnumStr`); indexing past the end of `buf` panics.
///
/// # Errors
///
/// Returns [`ReloadErr::TypeMismatch`] when `v` is not the variant `ty` expects.
fn write_owned(buf: &mut [u8], off: usize, ty: AttrType, v: &OwnedValue) -> Result<(), ReloadErr> {
    // Copy `bytes` into the `width`-byte window starting at `off`.
    let mut put = |width: usize, bytes: &[u8]| buf[off..off + width].copy_from_slice(bytes);
    match (ty, v) {
        (AttrType::Int, OwnedValue::Int(n)) => put(8, &n.to_le_bytes()),
        (AttrType::F32, OwnedValue::F32(x)) => put(4, &x.to_le_bytes()),
        (AttrType::F64, OwnedValue::F64(x)) => put(8, &x.to_le_bytes()),
        (AttrType::EnumStr, OwnedValue::EnumCode(c)) => put(4, &c.to_le_bytes()),
        _ => return Err(ReloadErr::TypeMismatch { expected: ty }),
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use crate::schema::SchemaBuilder;
    use std::sync::atomic::AtomicUsize;

    /// A schema with one kind ("audience") holding an F32 and an Int attribute.
    fn tiny_schema() -> Arc<Schema> {
        let mut builder = SchemaBuilder::new();
        let _ = builder.kind(
            "audience",
            &[("male_frac", AttrType::F32), ("dwell", AttrType::Int)],
        );
        builder.build()
    }

    /// A fresh, empty table over `tiny_schema`.
    fn empty_table() -> LocationTable<()> {
        LocationTable::new(tiny_schema())
    }

    /// A def with no allowed kinds and no reference attributes.
    fn bare_def(id: u64) -> LocationDef {
        LocationDef {
            id: LocId(id),
            kinds_allowed: vec![],
            ref_attrs: vec![],
        }
    }

    #[test]
    fn upsert_and_get() {
        let table = empty_table();
        table.upsert(&bare_def(1)).unwrap();
        assert!(table.get(LocId(1)).is_some());
        assert!(table.get(LocId(2)).is_none());
        assert_eq!(table.total_locations(), 1);
    }

    #[test]
    fn reload_replaces_table_atomically() {
        let table = empty_table();
        table.upsert(&bare_def(1)).unwrap();
        assert!(table.get(LocId(1)).is_some());
        table.reload_all([bare_def(2)]).unwrap();
        assert!(table.get(LocId(1)).is_none());
        assert!(table.get(LocId(2)).is_some());
    }

    #[test]
    fn remove_returns_whether_found() {
        let table = empty_table();
        table.upsert(&bare_def(1)).unwrap();
        assert!(table.remove(LocId(1)));
        assert!(!table.remove(LocId(1)));
    }

    #[test]
    fn concurrent_reads_across_reload() {
        const READERS: usize = 8;
        const READS_PER_READER: usize = 10_000;

        let table = Arc::new(empty_table());
        for id in 0..1000_u64 {
            table.upsert(&bare_def(id)).unwrap();
        }
        let hits = Arc::new(AtomicUsize::new(0));
        // `collect` is eager, so every reader is spawned before the reload.
        let readers: Vec<_> = (0..READERS)
            .map(|_| {
                let table = Arc::clone(&table);
                let hits = Arc::clone(&hits);
                std::thread::spawn(move || {
                    for _ in 0..READS_PER_READER {
                        if table.get(LocId(500)).is_some() {
                            hits.fetch_add(1, Ordering::Relaxed);
                        }
                    }
                })
            })
            .collect();
        table.reload_all((0..1000_u64).map(bare_def)).unwrap();
        for reader in readers {
            reader.join().unwrap();
        }
        assert!(hits.load(Ordering::Relaxed) > 0);
    }

    #[test]
    fn reload_resets_in_progress_on_error() {
        let table = empty_table();
        // AttrId(999) is not in the schema, so build_state fails with UnknownAttr.
        let bad = LocationDef {
            ref_attrs: vec![(KindId(0), AttrId(999), OwnedValue::Int(0))],
            ..bare_def(1)
        };
        let err = table.reload_all([bad]).unwrap_err();
        assert!(matches!(err, ReloadErr::UnknownAttr(_)));
        assert!(!table.reload_in_progress.load(Ordering::Acquire));
    }

    #[test]
    fn reload_rejects_unknown_kind() {
        let table = empty_table();
        let def = LocationDef {
            kinds_allowed: vec![KindId(999)],
            ..bare_def(1)
        };
        let err = table.reload_all([def]).unwrap_err();
        assert!(matches!(err, ReloadErr::UnknownKind(_)));
    }
}