use std::{any::Any, cell::RefCell, collections::HashMap, sync::Arc};
pub(crate) use caller::CallerInformation;
use dashmap::DashSet;
pub(crate) use database::{ActiveInputSessionGuard, QueryDebug};
pub use input_session::{InputSession, SetInputResult};
use qbice_serialize::{Decode, Encode};
use qbice_stable_hash::{BuildStableHasher, StableHash, StableHasher};
use qbice_stable_type_id::Identifiable;
pub(crate) use slow_path::GuardedTrackedEngine;
use thread_local::ThreadLocal;
use crate::{
Engine, ExecutionStyle, Query,
config::Config,
engine::computation_graph::{
caller::CallerKind,
computing::Computing,
database::{ActiveComputationGuard, Database},
dirty_worker::DirtyWorker,
fast_path::FastPathResult,
query_lock_manager::QueryLockManager,
slow_path::SlowPath,
statistic::Statistic,
},
executor::{CyclicError, CyclicPanicPayload},
query::QueryID,
};
mod backward_projection;
mod caller;
mod computing;
mod database;
mod dirty_worker;
mod fast_path;
mod input_session;
mod query_lock_manager;
mod register_callee;
mod repair;
mod slow_path;
mod statistic;
mod tfc_achetype;
mod visualization;
/// Classifies how a query obtains its value.
///
/// [`QueryKind::Input`] is the default; executable queries carry the
/// [`ExecutionStyle`] they run with.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Encode,
    Decode,
    Default,
)]
pub enum QueryKind {
    /// An input query (the default kind).
    #[default]
    Input,
    /// An executable query, tagged with its [`ExecutionStyle`].
    Executable(ExecutionStyle),
}

impl QueryKind {
    /// Returns `true` if this is [`QueryKind::Input`].
    #[must_use]
    pub const fn is_input(self) -> bool { matches!(self, Self::Input) }

    /// Returns `true` if this is an executable query whose style is
    /// [`ExecutionStyle::Firewall`].
    #[must_use]
    pub const fn is_firewall(self) -> bool {
        matches!(self, Self::Executable(ExecutionStyle::Firewall))
    }

    /// Returns `true` if this is an executable query whose style is
    /// [`ExecutionStyle::ExternalInput`].
    // `#[must_use]` for consistency with the sibling predicates: ignoring a
    // pure boolean query is always a mistake.
    #[must_use]
    pub const fn is_external_input(self) -> bool {
        matches!(self, Self::Executable(ExecutionStyle::ExternalInput))
    }
}
/// Engine-side state for tracking and recomputing queries.
///
/// The `database`, `statistic` and `dirtied_queries` handles are shared
/// (via `Arc`) with the `dirty_worker` built in [`ComputationGraph::new`].
pub struct ComputationGraph<C: Config> {
    /// Worker constructed from the shared handles below.
    dirty_worker: DirtyWorker<C>,
    /// Set of query IDs, shared with `dirty_worker`.
    dirtied_queries: Arc<DashSet<QueryID, C::BuildHasher>>,
    /// Query database opened over the storage engine.
    database: Arc<Database<C>>,
    /// Statistics, shared with `dirty_worker`.
    statistic: Arc<Statistic>,
    computing: Computing<C>,
    lock_manager: QueryLockManager,
}

impl<C: Config> ComputationGraph<C> {
    /// Builds a computation graph on top of the storage engine `db`.
    pub async fn new(db: &C::StorageEngine) -> Self {
        let database = Arc::new(Database::new(db).await);
        let statistic = Arc::<Statistic>::default();
        let dirtied_queries: Arc<DashSet<QueryID, C::BuildHasher>> =
            Arc::new(DashSet::with_hasher(C::BuildHasher::default()));

        // The worker only borrows the shared handles here, so they can still
        // be moved into the struct afterwards.
        let dirty_worker =
            DirtyWorker::new(&database, &statistic, &dirtied_queries);

        Self {
            dirty_worker,
            database,
            dirtied_queries,
            statistic,
            computing: Computing::new(),
            // 2^14 = 16384. NOTE(review): what this number sizes is defined by
            // `QueryLockManager::new`; confirm there.
            lock_manager: QueryLockManager::new(1_u64 << 14),
        }
    }
}
/// A handle that issues queries to an [`Engine`] on behalf of one caller.
///
/// Successful query results are memoized per thread, keyed by
/// [`QueryID`], so repeating a query through the same handle on the same
/// thread does not go back to the engine.
pub struct TrackedEngine<C: Config> {
    /// Engine that queries are forwarded to.
    engine: Arc<Engine<C>>,
    /// Per-thread memo of type-erased query values, keyed by query ID.
    cache: ThreadLocal<RefCell<HashMap<QueryID, Box<dyn Any + Send + Sync>>>>,
    /// Caller context passed along with every query.
    caller: CallerInformation,
}
impl<C: Config> TrackedEngine<C> {
    /// Returns the value of `query`, letting the engine compute or repair it
    /// when it is not already memoized by this handle on the current thread.
    ///
    /// # Panics
    ///
    /// Panics if a memoized value has an unexpected type or if the engine
    /// yields no return value. When the engine reports a cycle, this unwinds
    /// via `CyclicPanicPayload::unwind()`.
    pub async fn query<Q: Query>(&self, query: &Q) -> Q::Value {
        self.engine.yielder.tick().await;

        let keyed = self.engine.new_query_with_id(query);

        // The `RefCell` borrow is confined to this block so it is released
        // before any `.await` below.
        let cached: Option<Q::Value> = {
            let memo = self.cache.get_or_default().borrow();
            memo.get(&keyed.id).map(|stored| {
                stored
                    .downcast_ref::<Q::Value>()
                    .expect("cached value has incorrect type")
                    .clone()
            })
        };
        if let Some(hit) = cached {
            return hit;
        }

        match self.engine.query_for(&keyed, &self.caller).await {
            Ok(outcome) => {
                let value = outcome.unwrap_return();
                // Only successful results are memoized.
                self.cache
                    .get_or_default()
                    .borrow_mut()
                    .insert(keyed.id, Box::new(value.clone()));
                value
            }
            Err(_) => CyclicPanicPayload::unwind(),
        }
    }

    /// Asks the engine to repair the transitive firewall callees of `query`
    /// on behalf of this handle's caller.
    pub async fn repair_transitive_firewall_callees<Q: Query>(
        &self,
        query: &Q,
    ) {
        let keyed = self.engine.new_query_with_id(query);
        self.engine
            .repair_transitive_firewall_callees_for(&keyed, &self.caller)
            .await;
    }

    /// Interns `value` through the underlying engine.
    pub fn intern<T: StableHash + Identifiable + Send + Sync + 'static>(
        &self,
        value: T,
    ) -> qbice_storage::intern::Interned<T> {
        self.engine.intern(value)
    }

    /// Interns a possibly unsized `T`, supplied as any owner `Q` that can be
    /// converted into `Arc<T>`, through the underlying engine.
    pub fn intern_unsized<
        T: StableHash + Identifiable + Send + Sync + 'static + ?Sized,
        Q: std::borrow::Borrow<T> + Send + Sync + 'static,
    >(
        &self,
        value: Q,
    ) -> qbice_storage::intern::Interned<T>
    where
        Arc<T>: From<Q>,
    {
        self.engine.intern_unsized(value)
    }
}
impl<C: Config> TrackedEngine<C> {
    /// Creates a handle for `caller` over `engine`, starting with an empty
    /// per-thread cache.
    pub(crate) const fn new(
        engine: Arc<Engine<C>>,
        caller: CallerInformation,
    ) -> Self {
        Self { cache: ThreadLocal::new(), engine, caller }
    }
}
impl<C: Config> Clone for TrackedEngine<C> {
fn clone(&self) -> Self {
Self {
engine: self.engine.clone(),
cache: ThreadLocal::new(),
caller: self.caller.clone(),
}
}
}
impl<C: Config> std::fmt::Debug for TrackedEngine<C> {
    /// Prints `engine` and `caller`; the cache is omitted, which is signalled
    /// by the non-exhaustive `..` marker.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("TrackedEngine");
        out.field("engine", &self.engine);
        out.field("caller", &self.caller);
        out.finish_non_exhaustive()
    }
}
impl<C: Config> Engine<C> {
#[must_use]
#[allow(clippy::unused_async)]
pub async fn tracked(self: Arc<Self>) -> TrackedEngine<C> {
let (active_computation_guard, timestamp) =
self.acquire_active_computation_guard().await;
TrackedEngine {
caller: CallerInformation::new(
CallerKind::User,
timestamp,
Some(active_computation_guard),
),
cache: ThreadLocal::new(),
engine: self,
}
}
}
/// A borrowed query paired with its precomputed [`QueryID`].
#[derive(Debug, Clone)]
pub struct QueryWithID<'c, Q: Query> {
    /// Identifier derived from the query's stable hash.
    id: QueryID,
    /// The query being identified.
    query: &'c Q,
}
/// How a query's value was obtained for a particular request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryStatus {
    /// The query was (re)processed while serving the request.
    Repaired,
    /// The value was obtained without the request processing the query.
    UpToDate,
}

/// Outcome of resolving a query: its value (if any) and how it was obtained.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QueryResult<V> {
    /// The query's value, or `None` if no value was returned.
    pub return_value: Option<V>,
    /// Whether the value was already up to date or had to be repaired.
    pub status: QueryStatus,
}

impl<V> QueryResult<V> {
    /// Consumes the result and returns the query's value.
    ///
    /// # Panics
    ///
    /// Panics with `"Query did not return a value"` if `return_value` is
    /// `None`.
    #[must_use]
    #[track_caller]
    pub fn unwrap_return(self) -> V {
        self.return_value.expect("Query did not return a value")
    }
}
impl<C: Config> Engine<C> {
    /// Repairs the transitive firewall callees of `query` on behalf of
    /// `caller`.
    ///
    /// Does nothing when the query's read snapshot reports no
    /// `last_verified` value. NOTE(review): presumably this means the query
    /// has never been verified; confirm against `last_verified`.
    async fn repair_transitive_firewall_callees_for<Q: Query>(
        self: &Arc<Self>,
        query: &QueryWithID<'_, Q>,
        caller: &CallerInformation,
    ) {
        let mut snapshot =
            self.get_read_snapshot::<Q>(query.id.compact_hash_128()).await;
        if snapshot.last_verified().await.is_none() {
            return;
        }
        snapshot.repair_transitive_firewall_callees(caller).await;
    }

    /// Resolves `query` for `caller`, returning its value and a
    /// [`QueryStatus`] saying whether this call processed the query.
    ///
    /// Each loop iteration takes a fresh read snapshot and tries its fast
    /// path. On a fast-path hit the loop ends. Otherwise the loop:
    /// 1. follows the indicated slow path;
    /// 2. acquires a write guard, or retries from the top if none is
    ///    granted;
    /// 3. processes the query and marks the status
    ///    [`QueryStatus::Repaired`];
    /// 4. loops back to the fast path.
    ///
    /// # Errors
    ///
    /// Returns [`CyclicError`] if `exit_scc` fails for this query, or if
    /// `is_query_running_in_scc` fails for the calling query once a value
    /// has been obtained.
    async fn query_for<Q: Query>(
        self: &Arc<Self>,
        query: &QueryWithID<'_, Q>,
        caller: &CallerInformation,
    ) -> Result<QueryResult<Q::Value>, CyclicError> {
        self.yielder.tick().await;
        // Register `query` as a callee of `caller`. The returned handle (if
        // any) is `defuse`d on both paths that leave the loop below.
        // NOTE(review): presumably dropping it un-does the registration;
        // confirm in `register_callee`.
        let undo_register = self.register_callee(caller, &query.id);
        // Switches to `Repaired` once this call has processed the query.
        let mut status = QueryStatus::UpToDate;
        let value = loop {
            match self.exit_scc(&query.id, caller).await {
                Ok(_) => {}
                Err(err) => {
                    if let Some(undo) = undo_register {
                        undo.defuse();
                    }
                    return Err(err);
                }
            }
            // A fresh snapshot is read on every iteration.
            let mut snapshot =
                self.get_read_snapshot::<Q>(query.id.compact_hash_128()).await;
            let slow_path = match snapshot.fast_path(caller).await {
                FastPathResult::ToSlowPath(slow_path) => slow_path,
                FastPathResult::Hit(value) => {
                    if let Some(undo_register) = undo_register {
                        undo_register.defuse();
                    }
                    break QueryResult { return_value: value, status };
                }
            };
            // For user and firewall-repair callers, a `Repair` slow path first
            // repairs the query's transitive firewall callees and then
            // re-reads the snapshot.
            if matches!(
                caller.kind(),
                CallerKind::User | CallerKind::RepairFirewall
            ) && slow_path == SlowPath::Repair
            {
                snapshot.repair_transitive_firewall_callees(caller).await;
                snapshot = self
                    .get_read_snapshot::<Q>(query.id.compact_hash_128())
                    .await;
            }
            // `None` means no write guard was granted; retry from the top.
            // NOTE(review): presumably another task holds it; confirm in
            // `get_write_guard`.
            let Some((snapshot, guard)) =
                snapshot.get_write_guard(slow_path, caller).await
            else {
                continue;
            };
            snapshot.process_query(query.query, caller, guard).await;
            status = QueryStatus::Repaired;
        };
        // Final cycle check against the calling query before returning.
        self.is_query_running_in_scc(caller.get_caller())?;
        Ok(value)
    }

    /// Pairs `query` with its [`QueryID`].
    ///
    /// The query is hashed with a hasher from the engine's
    /// `build_stable_hasher`, and the hash is passed to
    /// `QueryID::new::<Q>`.
    pub(super) fn new_query_with_id<'c, Q: Query>(
        &'c self,
        query: &'c Q,
    ) -> QueryWithID<'c, Q> {
        let mut hash = self.build_stable_hasher.build_stable_hasher();
        query.stable_hash(&mut hash);
        QueryWithID { id: QueryID::new::<Q>(hash.finish().into()), query }
    }
}