mod core;
use self::core::{FactStripeCore, Registration};
use crate::{FactKey, FactLoadError, FactLoadResult, FactSource, FactSourceRegistrationError};
use futures_channel::oneshot;
use std::any::{Any, TypeId};
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::Hasher;
use std::num::NonZeroUsize;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, MutexGuard, OnceLock};
use tracing::Instrument;
/// Number of independently locked stripes each `FactState` spreads its keys across.
const FACT_STATE_STRIPES: usize = 64;
/// Sender half used to signal a caller that joined another caller's in-flight load.
type StripeWaiter = oneshot::Sender<()>;
/// Receiver half awaited by a caller that joined another caller's in-flight load.
type StripeReceiver = oneshot::Receiver<()>;
/// Stripe core specialised to oneshot-sender waiters.
type StripeCore<K> = FactStripeCore<K, StripeWaiter>;
/// Shared state behind every clone of an `EvaluationSession` handle.
#[derive(Default)]
struct EvaluationSessionInner {
    // Per-fact-type state keyed by the fact key's `TypeId`; each value is the
    // `Arc<FactState<K>>` for that key type `K`.
    states: Mutex<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
    // Counter used to give each batched source load a distinct id in its tracing span.
    next_load_id: AtomicU64,
    // True only for the process-wide session built by `EvaluationSession::shared_empty`,
    // which rejects every registration.
    shared_empty: bool,
}
/// Per-fact-type state: the registered source (if any) plus a fixed set of
/// independently locked stripes holding cached results and in-flight registrations.
struct FactState<K>
where
    K: FactKey,
{
    // Source used to load keys of type `K`; `None` until one is registered.
    source: Mutex<Option<Arc<dyn FactSource<K>>>>,
    // `FACT_STATE_STRIPES` stripes; `stripe_index` maps each key to exactly one of them,
    // so keys are spread across several locks instead of one.
    stripes: Box<[Mutex<StripeCore<K>>]>,
}
impl<K> FactState<K>
where
    K: FactKey,
{
    /// Builds a state with `FACT_STATE_STRIPES` empty stripes and the given optional source.
    fn new(source: Option<Arc<dyn FactSource<K>>>) -> Self {
        let mut stripes = Vec::with_capacity(FACT_STATE_STRIPES);
        for _ in 0..FACT_STATE_STRIPES {
            stripes.push(Mutex::new(StripeCore::<K>::new()));
        }
        Self {
            source: Mutex::new(source),
            stripes: stripes.into_boxed_slice(),
        }
    }

    /// Installs `source` as the loader for `K`.
    ///
    /// Fails with `AlreadyRegistered` when a source exists and `replace` is false, and
    /// with `InFlight` when any stripe is not idle. On success every stripe's cache is
    /// cleared so stale results from the previous source are not served.
    fn insert_source(
        &self,
        source: Arc<dyn FactSource<K>>,
        replace: bool,
    ) -> Result<(), FactSourceRegistrationError> {
        // Lock order: source slot first, then every stripe — the order `plan_loads` uses.
        let mut slot = self.lock_source();
        let mut guards = self.lock_stripes();
        if slot.is_some() && !replace {
            return Err(FactSourceRegistrationError::AlreadyRegistered { fact_name: K::NAME });
        }
        let any_busy = guards.iter().any(|guard| !guard.is_idle());
        if any_busy {
            return Err(FactSourceRegistrationError::InFlight { fact_name: K::NAME });
        }
        *slot = Some(source);
        for guard in guards.iter_mut() {
            guard.clear_cache();
        }
        Ok(())
    }

    /// Locks the source slot.
    ///
    /// # Panics
    /// Panics if the source mutex is poisoned.
    fn lock_source(&self) -> MutexGuard<'_, Option<Arc<dyn FactSource<K>>>> {
        self.source
            .lock()
            .expect("fact source mutex should not be poisoned")
    }

    /// Locks the stripe at `index`.
    ///
    /// # Panics
    /// Panics if that stripe's mutex is poisoned or `index` is out of range.
    fn lock_stripe(&self, index: usize) -> MutexGuard<'_, StripeCore<K>> {
        self.stripes[index]
            .lock()
            .expect("fact state stripe mutex should not be poisoned")
    }

    /// Locks every stripe in index order and returns all the guards.
    fn lock_stripes(&self) -> Vec<MutexGuard<'_, StripeCore<K>>> {
        (0..self.stripes.len())
            .map(|index| self.lock_stripe(index))
            .collect()
    }

    /// Maps `key` to a stripe index; a freshly created `DefaultHasher` is used each
    /// time, so the same key always lands on the same stripe.
    fn stripe_index(&self, key: &K) -> usize {
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        let digest = hasher.finish();
        (digest as usize) % self.stripes.len()
    }

    /// Decides how `keys` will be satisfied.
    ///
    /// If every key is already cached the plan carries those results directly.
    /// Otherwise each distinct key is registered once: keys this caller must load are
    /// returned in `keys`, and keys another caller is already loading contribute a
    /// receiver to `waiters`.
    fn plan_loads(&self, keys: &[K]) -> LoadPlan<K> {
        // Hash each key once; both passes below reuse these stripe indices.
        let indices: Vec<usize> = keys.iter().map(|key| self.stripe_index(key)).collect();

        // Kept locked until the plan is built so `insert_source` cannot swap the
        // source or clear caches between the cache check and registration.
        let source_slot = self.lock_source();
        let source = source_slot.clone();

        // Fast path: every requested key (duplicates included) is already cached.
        let all_cached = keys
            .iter()
            .zip(&indices)
            .map(|(key, &index)| self.lock_stripe(index).peek_cache(key))
            .collect::<Option<Vec<_>>>();
        if let Some(results) = all_cached {
            return LoadPlan {
                source,
                cached_results: Some(results),
                keys: Vec::new(),
                waiters: Vec::new(),
            };
        }

        // Slow path: register each distinct key exactly once.
        let mut seen = std::collections::HashSet::new();
        let mut leaders = Vec::new();
        let mut waiters = Vec::new();
        for (key, &index) in keys.iter().zip(&indices) {
            if !seen.insert(key.clone()) {
                continue;
            }
            let mut stripe = self.lock_stripe(index);
            let registration =
                stripe.try_register::<_, StripeReceiver>(key, || oneshot::channel());
            match registration {
                Registration::Leading => leaders.push(key.clone()),
                Registration::Joined(receiver) => waiters.push(receiver),
                // Already cached: read later by `results_from_cache`, nothing to do now.
                Registration::Cached(_) => {}
            }
        }

        LoadPlan {
            source,
            cached_results: None,
            keys: leaders,
            waiters,
        }
    }

    /// Stores each result for its matching key and signals every waiter those keys release.
    ///
    /// Waiters are collected first and signalled only after each stripe lock has been
    /// released.
    ///
    /// # Panics
    /// Panics if `keys` and `results` differ in length.
    fn finish_keys(&self, keys: &[K], results: Vec<FactLoadResult<K::Value>>) {
        assert_eq!(
            keys.len(),
            results.len(),
            "finish_keys requires equal-length keys and results"
        );
        let to_wake: Vec<_> = keys
            .iter()
            .zip(results)
            .flat_map(|(key, result)| {
                self.lock_stripe(self.stripe_index(key))
                    .finish(key.clone(), result)
            })
            .collect();
        to_wake.into_iter().for_each(|waiter| {
            let _ = waiter.send(());
        });
    }

    /// Reads the cached result for each key, in order.
    ///
    /// A key with no cache entry is reported as
    /// `SourceContractViolation { expected: 1, actual: 0 }`.
    // NOTE(review): a concurrent `insert_source` can clear the caches after the keys
    // finish but before this runs, which surfaces as that violation — confirm intended.
    fn results_from_cache(&self, keys: &[K]) -> Vec<FactLoadResult<K::Value>> {
        let mut results = Vec::with_capacity(keys.len());
        for key in keys {
            let cached = self.lock_stripe(self.stripe_index(key)).peek_cache(key);
            results.push(cached.unwrap_or_else(|| {
                FactLoadResult::Error(FactLoadError::SourceContractViolation {
                    fact_name: K::NAME,
                    expected: 1,
                    actual: 0,
                })
            }));
        }
        results
    }
}
/// Cheaply cloneable handle to a set of registered fact sources and their cached results.
///
/// Clones share the same underlying state through an `Arc`.
#[derive(Clone, Default)]
pub struct EvaluationSession {
    inner: Arc<EvaluationSessionInner>,
}
impl EvaluationSession {
    /// Creates a new session with no registered fact sources.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a new session with no registered fact sources.
    ///
    /// Same as [`EvaluationSession::new`]; every call returns an independent session.
    pub fn empty() -> Self {
        Self::new()
    }

    /// Returns a process-wide session that never has any fact sources.
    ///
    /// Registration on it fails with `FactSourceRegistrationError::SharedEmptySession`
    /// (the non-`try_` variants panic with that error), and every lookup resolves to
    /// `FactLoadError::SourceNotRegistered`.
    pub fn shared_empty() -> &'static Self {
        static SHARED_EMPTY: OnceLock<EvaluationSession> = OnceLock::new();
        SHARED_EMPTY.get_or_init(|| EvaluationSession {
            inner: Arc::new(EvaluationSessionInner {
                shared_empty: true,
                ..EvaluationSessionInner::default()
            }),
        })
    }

    /// Returns a builder that registers sources on a fresh session.
    pub fn builder() -> EvaluationSessionBuilder {
        EvaluationSessionBuilder::new()
    }

    /// Registers `source` as the loader for fact type `K`.
    ///
    /// # Panics
    /// Panics if a source for `K` is already registered, if loads for `K` are in
    /// flight, or if this is the shared empty session.
    pub fn register<K, S>(&self, source: S)
    where
        K: FactKey,
        S: FactSource<K> + 'static,
    {
        self.register_arc::<K>(Arc::new(source));
    }

    /// Registers an already shared `source` as the loader for fact type `K`.
    ///
    /// # Panics
    /// Same conditions as [`EvaluationSession::register`].
    pub fn register_arc<K>(&self, source: Arc<dyn FactSource<K>>)
    where
        K: FactKey,
    {
        self.try_register_arc::<K>(source)
            .unwrap_or_else(|error| panic!("{error}"));
    }

    /// Registers `source` as the loader for fact type `K`.
    ///
    /// # Errors
    /// `AlreadyRegistered` if `K` already has a source, `InFlight` if loads for `K`
    /// are pending, `SharedEmptySession` on the shared empty session.
    pub fn try_register<K, S>(&self, source: S) -> Result<(), FactSourceRegistrationError>
    where
        K: FactKey,
        S: FactSource<K> + 'static,
    {
        self.try_register_arc::<K>(Arc::new(source))
    }

    /// Registers an already shared `source` as the loader for fact type `K`.
    ///
    /// # Errors
    /// Same as [`EvaluationSession::try_register`].
    pub fn try_register_arc<K>(
        &self,
        source: Arc<dyn FactSource<K>>,
    ) -> Result<(), FactSourceRegistrationError>
    where
        K: FactKey,
    {
        self.insert_source::<K>(source, false)
    }

    /// Registers or replaces the loader for fact type `K`, clearing its cached results.
    ///
    /// # Panics
    /// Panics if loads for `K` are in flight or if this is the shared empty session.
    pub fn replace<K, S>(&self, source: S)
    where
        K: FactKey,
        S: FactSource<K> + 'static,
    {
        self.replace_arc::<K>(Arc::new(source));
    }

    /// Registers or replaces the loader for fact type `K` with an already shared source.
    ///
    /// # Panics
    /// Same conditions as [`EvaluationSession::replace`].
    pub fn replace_arc<K>(&self, source: Arc<dyn FactSource<K>>)
    where
        K: FactKey,
    {
        self.try_replace_arc::<K>(source)
            .unwrap_or_else(|error| panic!("{error}"));
    }

    /// Registers or replaces the loader for fact type `K`, clearing its cached results.
    ///
    /// # Errors
    /// `InFlight` if loads for `K` are pending, `SharedEmptySession` on the shared
    /// empty session.
    pub fn try_replace<K, S>(&self, source: S) -> Result<(), FactSourceRegistrationError>
    where
        K: FactKey,
        S: FactSource<K> + 'static,
    {
        self.try_replace_arc::<K>(Arc::new(source))
    }

    /// Registers or replaces the loader for fact type `K` with an already shared source.
    ///
    /// # Errors
    /// Same as [`EvaluationSession::try_replace`].
    pub fn try_replace_arc<K>(
        &self,
        source: Arc<dyn FactSource<K>>,
    ) -> Result<(), FactSourceRegistrationError>
    where
        K: FactKey,
    {
        self.insert_source::<K>(source, true)
    }

    /// Installs `source` for fact type `K`, creating the per-type state on first use.
    ///
    /// With `replace == false` an existing source is an error. A successful call clears
    /// the cached results for `K`.
    fn insert_source<K>(
        &self,
        source: Arc<dyn FactSource<K>>,
        replace: bool,
    ) -> Result<(), FactSourceRegistrationError>
    where
        K: FactKey,
    {
        if self.inner.shared_empty {
            return Err(FactSourceRegistrationError::SharedEmptySession { fact_name: K::NAME });
        }
        // Share the lookup-or-create path with loads so both agree on the registry layout.
        self.state::<K>().insert_source(source, replace)
    }

    /// Loads the value for a single `key`.
    ///
    /// Convenience wrapper over [`EvaluationSession::get_many`] with a one-element slice.
    pub async fn get<K>(&self, key: K) -> FactLoadResult<K::Value>
    where
        K: FactKey,
    {
        self.get_many(&[key])
            .await
            .into_iter()
            .next()
            .unwrap_or_else(|| {
                FactLoadResult::Error(FactLoadError::SourceContractViolation {
                    fact_name: K::NAME,
                    expected: 1,
                    actual: 0,
                })
            })
    }

    /// Loads values for `keys`, returning one result per key in the same order.
    ///
    /// - Results already cached in this session are returned without calling the source.
    /// - Duplicate keys are loaded once.
    /// - Keys that a concurrent call is already loading are awaited, not reloaded.
    /// - The remaining keys go to the registered source in batches no larger than
    ///   `FactSource::max_batch_size`.
    ///
    /// A batch whose result count does not match its key count yields
    /// `SourceContractViolation` for every key in it. With no source registered, every
    /// key yields `SourceNotRegistered`. If this future is dropped mid-load, the keys it
    /// was loading are finished with `LoaderCancelled`, so concurrent waiters are released.
    pub async fn get_many<K>(&self, keys: &[K]) -> Vec<FactLoadResult<K::Value>>
    where
        K: FactKey,
    {
        if keys.is_empty() {
            return Vec::new();
        }
        // The shared empty session is a process-wide static that can never gain a
        // source. Answer directly rather than caching one entry per key in it, which
        // would accumulate for the lifetime of the process.
        if self.inner.shared_empty {
            return Self::not_registered_results::<K>(keys.len());
        }
        let state = self.state::<K>();
        let load_plan = state.plan_loads(keys);
        if let Some(results) = load_plan.cached_results {
            return results;
        }
        // Any led key not marked finished when this guard drops (e.g. the future is
        // dropped mid-load) is finished with `LoaderCancelled`.
        let mut in_flight_guard = InFlightGuard::new(Arc::clone(&state), load_plan.keys.clone());
        if !load_plan.keys.is_empty() {
            if let Some(source) = load_plan.source.as_ref() {
                let chunk_size = source
                    .max_batch_size()
                    .map_or(load_plan.keys.len(), NonZeroUsize::get)
                    .max(1);
                for chunk in load_plan.keys.chunks(chunk_size) {
                    let load_id = self.inner.next_load_id.fetch_add(1, Ordering::Relaxed);
                    let load_span = tracing::debug_span!(
                        "gatehouse.fact_load",
                        fact.name = K::NAME,
                        fact.load_id = load_id,
                        fact.key_count = chunk.len(),
                        fact.unique_key_count = chunk.len(),
                    );
                    let loaded = source.load_many(chunk).instrument(load_span).await;
                    let results = if loaded.len() == chunk.len() {
                        loaded
                    } else {
                        chunk
                            .iter()
                            .map(|_| {
                                FactLoadResult::Error(FactLoadError::SourceContractViolation {
                                    fact_name: K::NAME,
                                    expected: chunk.len(),
                                    actual: loaded.len(),
                                })
                            })
                            .collect()
                    };
                    state.finish_keys(chunk, results);
                    in_flight_guard.mark_finished(chunk);
                }
            } else {
                let results = Self::not_registered_results::<K>(load_plan.keys.len());
                state.finish_keys(&load_plan.keys, results);
                in_flight_guard.mark_finished(&load_plan.keys);
            }
        }
        for waiter in load_plan.waiters {
            // The outcome is read from the cache below; a cancelled channel is ignored here.
            let _ = waiter.await;
        }
        state.results_from_cache(keys)
    }

    /// Builds `count` `SourceNotRegistered` results for fact type `K`.
    fn not_registered_results<K>(count: usize) -> Vec<FactLoadResult<K::Value>>
    where
        K: FactKey,
    {
        (0..count)
            .map(|_| {
                FactLoadResult::Error(FactLoadError::SourceNotRegistered {
                    fact_name: K::NAME,
                })
            })
            .collect()
    }

    /// Returns the per-type state for `K`, creating an empty (sourceless) one on first use.
    ///
    /// # Panics
    /// Panics if the registry mutex is poisoned or the stored entry has the wrong type.
    fn state<K>(&self) -> Arc<FactState<K>>
    where
        K: FactKey,
    {
        let mut states = self
            .inner
            .states
            .lock()
            .expect("fact state registry mutex should not be poisoned");
        states
            .entry(TypeId::of::<K>())
            .or_insert_with(|| Box::new(Arc::new(FactState::<K>::new(None))))
            .downcast_ref::<Arc<FactState<K>>>()
            .expect("fact state type should match registry key")
            .clone()
    }
}
/// Builder that registers fact sources on a fresh [`EvaluationSession`].
pub struct EvaluationSessionBuilder {
    session: EvaluationSession,
}
impl EvaluationSessionBuilder {
    /// Starts a builder around a brand-new, empty session.
    fn new() -> Self {
        let session = EvaluationSession::default();
        Self { session }
    }

    /// Registers `source` as the loader for fact type `K`.
    ///
    /// # Panics
    /// Panics if registration fails, for example because `K` already has a source.
    pub fn with<K, S>(self, source: S) -> Self
    where
        K: FactKey,
        S: FactSource<K> + 'static,
    {
        self.with_arc::<K>(Arc::new(source))
    }

    /// Registers an already shared `source` as the loader for fact type `K`.
    ///
    /// # Panics
    /// Panics if registration fails, for example because `K` already has a source.
    pub fn with_arc<K>(self, source: Arc<dyn FactSource<K>>) -> Self
    where
        K: FactKey,
    {
        let Self { session } = self;
        session.register_arc::<K>(source);
        Self { session }
    }

    /// Finishes building and returns the configured session.
    pub fn build(self) -> EvaluationSession {
        let Self { session } = self;
        session
    }
}
/// Result of `FactState::plan_loads`: either every requested result was already
/// cached, or the keys this caller must load plus receivers for keys another caller
/// is loading.
struct LoadPlan<K>
where
    K: FactKey,
{
    // Snapshot of the registered source taken while planning; `None` if none was registered.
    source: Option<Arc<dyn FactSource<K>>>,
    // `Some` only when every requested key was served from the cache.
    cached_results: Option<Vec<FactLoadResult<K::Value>>>,
    // Distinct keys this caller registered as leader for and must finish.
    keys: Vec<K>,
    // Signalled when the leading caller finishes a key this caller joined.
    waiters: Vec<oneshot::Receiver<()>>,
}
/// Tracks the keys a `get_many` call is leading; any still unfinished when the guard
/// is dropped are finished with `FactLoadError::LoaderCancelled`.
struct InFlightGuard<K>
where
    K: FactKey,
{
    // State the tracked keys belong to.
    state: Arc<FactState<K>>,
    // Led keys that have not yet been marked finished.
    remaining: Vec<K>,
}
impl<K> InFlightGuard<K>
where
    K: FactKey,
{
    /// Starts tracking `keys` as led by this caller but not yet finished.
    fn new(state: Arc<FactState<K>>, keys: Vec<K>) -> Self {
        let remaining = keys;
        Self { state, remaining }
    }

    /// Stops tracking `keys`; they have been finished and must not be cancelled on drop.
    fn mark_finished(&mut self, keys: &[K]) {
        if !self.remaining.is_empty() {
            let done: std::collections::HashSet<&K> = keys.iter().collect();
            self.remaining.retain(|pending| !done.contains(&pending));
        }
    }
}
impl<K> Drop for InFlightGuard<K>
where
K: FactKey,
{
fn drop(&mut self) {
if self.remaining.is_empty() {
return;
}
let cancelled = std::mem::take(&mut self.remaining);
let results = cancelled
.iter()
.map(|_| FactLoadResult::Error(FactLoadError::LoaderCancelled { fact_name: K::NAME }))
.collect();
self.state.finish_keys(&cancelled, results);
}
}