use std::{
any::Any, collections::HashMap, mem::MaybeUninit, panic::AssertUnwindSafe,
pin::Pin, sync::Arc,
};
use futures::FutureExt;
use qbice_stable_hash::Compact128;
use qbice_stable_type_id::StableTypeID;
use crate::{
Engine, TrackedEngine,
config::Config,
engine::computation_graph::{
CallerInformation, GuardedTrackedEngine, QueryDebug, QueryStatus,
},
query::{ExecutionStyle, Query},
};
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
PartialOrd,
Ord,
Hash,
Default,
thiserror::Error,
)]
#[error("cyclic query detected")]
pub struct CyclicError;
/// Payload of a panic that escaped a query executor.
///
/// The type-erased executor trampoline catches the unwind and stores the
/// payload here (as `Err(Panicked)`), so it can be carried around and re-raised
/// later with [`Panicked::resume_unwind`].
///
/// Derives `Debug` so that `Result<_, Panicked>` values can be inspected,
/// logged, or `expect`-ed; the payload itself renders opaquely.
#[derive(Debug)]
pub(crate) struct Panicked(Box<dyn Any + Send + 'static>);

impl Panicked {
    /// Resumes unwinding with the original, captured panic payload.
    ///
    /// Uses [`std::panic::resume_unwind`], which does not invoke the panic
    /// hook, so the panic is not reported a second time.
    pub fn resume_unwind(self) -> ! { std::panic::resume_unwind(self.0) }
}
/// Marker type used as a panic payload; presumably raised to unwind out of a
/// query evaluation when a cycle is detected (NOTE(review): confirm at the
/// catch site).
///
/// Code that catches the unwind can recognise it by downcasting the payload
/// to this type.
pub(crate) struct CyclicPanicPayload;

impl CyclicPanicPayload {
    /// Begins unwinding with a `CyclicPanicPayload` as the panic payload.
    pub fn unwind() -> ! {
        let marker = CyclicPanicPayload;
        std::panic::panic_any(marker)
    }
}
/// Computes values for query type `Q` under engine configuration `C`.
///
/// Executors are registered in a [`Registry`] and stored there type-erased
/// (as `Arc<dyn Any + Send + Sync>`), hence the `'static + Send + Sync`
/// supertraits.
pub trait Executor<Q: Query, C: Config>: 'static + Send + Sync {
    /// Computes the value of `query`.
    ///
    /// `engine` is the `TrackedEngine` the executor runs against; presumably
    /// sub-queries should be issued through it so they are tracked —
    /// NOTE(review): confirm against `TrackedEngine`.
    ///
    /// The returned future must be `Send` and may borrow `self`, `query` and
    /// `engine` (see the `use<...>` capture list). When invoked through the
    /// registry, a panic raised while polling this future is caught and
    /// reported as a `Panicked` value rather than propagating directly.
    fn execute<'s, 'q, 'e>(
        &'s self,
        query: &'q Q,
        engine: &'e TrackedEngine<C>,
    ) -> impl Future<Output = Q::Value> + Send + use<'s, 'q, 'e, Self, Q, C>;

    /// Returns the [`ExecutionStyle`] for this query type.
    ///
    /// Defaults to [`ExecutionStyle::Normal`].
    #[must_use]
    fn execution_style() -> ExecutionStyle { ExecutionStyle::Normal }

    /// Returns the value to use for this query type in SCC (strongly
    /// connected component, i.e. cycle) handling — presumably when the query
    /// is part of a detected cycle; NOTE(review): confirm against the engine.
    ///
    /// # Panics
    ///
    /// The default implementation panics with `"SCC value is not specified"`;
    /// override it to supply a value.
    #[must_use]
    fn scc_value() -> Q::Value { panic!("SCC value is not specified") }
}
/// Type-erased trampoline that runs executor `E` for query type `K`.
///
/// Stored in [`Entry`] as an [`InvokeExecutorFn`]. `key` must be a `K`,
/// `executor` an `E`, and `result` a
/// `MaybeUninit<Result<K::Value, Panicked>>`; each mismatch panics with a
/// "... type mismatch" message. The key and executor are checked when this
/// function is called. The result buffer is checked when the returned future
/// is first polled.
///
/// A panic raised by the executor's future is caught and written into
/// `result` as `Err(Panicked)`. The write into `result` is the last thing the
/// future does before resolving.
fn invoke_executor<'a, C, E, K>(
    key: &'a dyn Any,
    executor: &'a dyn Any,
    engine: &'a GuardedTrackedEngine<C>,
    result: &'a mut (dyn Any + Send),
) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>
where
    C: Config,
    E: Executor<K, C> + 'static,
    K: Query + 'static,
{
    let query = key.downcast_ref::<K>().expect("Key type mismatch");
    let concrete_executor =
        executor.downcast_ref::<E>().expect("Executor type mismatch");

    Box::pin(async move {
        let slot = result
            .downcast_mut::<MaybeUninit<Result<K::Value, Panicked>>>()
            .expect("Result type mismatch");

        // Executor panics must not tear through the engine; capture them so
        // the caller can decide whether to re-raise.
        let computation = concrete_executor.execute(query, engine.tracked_engine());
        let outcome = AssertUnwindSafe(computation)
            .catch_unwind()
            .await
            .map_err(Panicked);

        slot.write(outcome);
    })
}
/// Signature of the type-erased executor trampoline (`invoke_executor`).
///
/// The future writes the executor's outcome into `result`, which must be a
/// `MaybeUninit<Result<K::Value, Panicked>>` for the entry's query type `K`,
/// and then resolves to `()`.
type InvokeExecutorFn<C> =
    for<'a> fn(
        key: &'a dyn Any,
        executor: &'a dyn Any,
        engine: &'a GuardedTrackedEngine<C>,
        result: &'a mut (dyn Any + Send),
    ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;

/// Signature of `Engine::repair_query_from_query_id::<Q>` once erased to a
/// plain function pointer (one instantiation per registered query type).
type RepairQueryFn<C> = for<'a> fn(
    engine: &'a Arc<Engine<C>>,
    query_id: &'a Compact128,
    called_from: &'a CallerInformation,
) -> Pin<
    Box<dyn Future<Output = Result<QueryStatus, CyclicError>> + Send + 'a>,
>;

/// Writes the executor's SCC value into `buffer`, which must be a
/// `MaybeUninit<K::Value>` for the entry's query type `K`.
type ObtainSccValueFn = for<'a> fn(buffer: &'a mut dyn Any);

/// Returns the executor's [`ExecutionStyle`].
type ObtainExecutionStyleFn = fn() -> ExecutionStyle;

/// Signature of `Engine::get_query_debug_future::<Q>` once erased to a plain
/// function pointer; produces an optional `QueryDebug` for the query whose
/// input hash is `query_input_hash_128`.
type DebugQueryFn<C> = for<'a> fn(
    engine: &'a Engine<C>,
    query_input_hash_128: Compact128,
) -> Pin<
    Box<dyn Future<Output = Option<QueryDebug>> + Send + 'a>,
>;
/// Type-erased shim that stores `E::scc_value()` into `buffer`.
///
/// Stored in [`Entry`] as an [`ObtainSccValueFn`]. `buffer` must be a
/// `MaybeUninit<K::Value>`, otherwise this panics before the SCC value is
/// requested. On return the buffer is initialised.
fn obtain_scc_value<C, E, K>(buffer: &mut dyn Any)
where
    C: Config,
    E: Executor<K, C> + 'static,
    K: Query + 'static,
{
    // The receiver (downcast) is evaluated before the argument, so a type
    // mismatch is reported before `scc_value` runs.
    buffer
        .downcast_mut::<MaybeUninit<K::Value>>()
        .expect("SCC value buffer type mismatch")
        .write(E::scc_value());
}
/// Type-erased shim returning `E`'s [`Executor::execution_style`] for query
/// type `K`; stored in [`Entry`] as an [`ObtainExecutionStyleFn`].
fn obtain_execution_style<C, E, K>() -> ExecutionStyle
where
    C: Config,
    E: Executor<K, C> + 'static,
    K: Query + 'static,
{
    <E as Executor<K, C>>::execution_style()
}
/// Type-erased registration record for a single query type.
///
/// Holds the executor as `dyn Any` together with function pointers that
/// [`Entry::new`] instantiated for the concrete query/executor types, so the
/// engine can dispatch without knowing those types statically.
#[derive(Debug, Clone)]
pub(crate) struct Entry<C: Config> {
    /// The registered executor; downcast back to its concrete type by the
    /// `invoke_executor` trampoline.
    executor: Arc<dyn Any + Send + Sync>,
    /// Trampoline that runs the executor and writes the outcome into a
    /// caller-provided buffer.
    invoke_executor: InvokeExecutorFn<C>,
    /// Points at `Engine::get_query_debug_future` for this query type.
    query_debug: DebugQueryFn<C>,
    /// Points at `Engine::repair_query_from_query_id` for this query type.
    repair_query: RepairQueryFn<C>,
    /// Shim producing the executor's SCC value.
    obtain_scc_value: ObtainSccValueFn,
    /// Shim producing the executor's execution style.
    obtain_execution_style: ObtainExecutionStyleFn,
}
impl<C: Config> Entry<C> {
    /// Creates the type-erased entry for query type `Q` handled by `executor`.
    ///
    /// Every function pointer stored here is instantiated for the concrete
    /// `Q`/`E` pair. That is what lets the other methods recover those types
    /// from `dyn Any` later on.
    pub fn new<Q: Query, E: Executor<Q, C> + 'static>(
        executor: Arc<E>,
    ) -> Self {
        Self {
            obtain_execution_style: obtain_execution_style::<C, E, Q>,
            obtain_scc_value: obtain_scc_value::<C, E, Q>,
            repair_query: Engine::<C>::repair_query_from_query_id::<Q>,
            query_debug: Engine::<C>::get_query_debug_future::<Q>,
            invoke_executor: invoke_executor::<C, E, Q>,
            executor,
        }
    }

    /// Runs this entry's executor for `query_key`.
    ///
    /// Returns `Ok(value)` on success. If the executor panicked while running,
    /// returns `Err(Panicked)` carrying the payload.
    ///
    /// # Panics
    ///
    /// Panics if `Q` (or its value type) does not match the query type this
    /// entry was created for.
    pub async fn invoke_executor<Q: Query>(
        &self,
        query_key: &Q,
        engine: &GuardedTrackedEngine<C>,
    ) -> Result<Q::Value, Panicked> {
        let mut slot = MaybeUninit::<Result<Q::Value, Panicked>>::uninit();
        let invoke = self.invoke_executor;
        invoke(query_key, &*self.executor, engine, &mut slot).await;

        // SAFETY: the trampoline installed by `Entry::new` writes `slot` as
        // the final step before its future resolves. Executor panics are
        // caught and stored as `Err(Panicked)` rather than escaping. Any other
        // failure (a type-mismatch `expect`) unwinds past this point. If this
        // future is dropped before completion, we never get here. Reaching
        // this line therefore means `slot` is initialised.
        unsafe { slot.assume_init() }
    }

    /// Delegates to `Engine::repair_query_from_query_id` for this entry's
    /// query type.
    pub async fn repair_query_from_query_id(
        &self,
        engine: &Arc<Engine<C>>,
        query_id: &Compact128,
        caller_information: &CallerInformation,
    ) -> Result<QueryStatus, CyclicError> {
        let repair = self.repair_query;
        repair(engine, query_id, caller_information).await
    }

    /// Returns the executor's SCC value for query type `Q`.
    ///
    /// # Panics
    ///
    /// Panics if `Q::Value` does not match this entry's value type, or if the
    /// executor does not override `Executor::scc_value`.
    pub fn obtain_scc_value<Q: Query>(&self) -> Q::Value {
        let mut slot = MaybeUninit::<Q::Value>::uninit();
        let fill = self.obtain_scc_value;
        fill(&mut slot);

        // SAFETY: the shim installed by `Entry::new` writes the SCC value into
        // the buffer before returning. On a downcast mismatch or a panicking
        // `scc_value`, it unwinds instead of returning, so `slot` is
        // initialised here.
        unsafe { slot.assume_init() }
    }

    /// Returns the executor's [`ExecutionStyle`].
    pub fn obtain_execution_style(&self) -> ExecutionStyle {
        let style_of = self.obtain_execution_style;
        style_of()
    }

    /// Delegates to `Engine::get_query_debug_future` for this entry's query
    /// type.
    pub async fn get_query_debug(
        &self,
        engine: &Engine<C>,
        query_input_hash_128: Compact128,
    ) -> Option<QueryDebug> {
        let debug = self.query_debug;
        debug(engine, query_input_hash_128).await
    }
}
/// Maps each registered query type to the type-erased [`Entry`] holding its
/// executor.
#[derive(Debug, Default)]
pub struct Registry<C: Config> {
    /// Keyed by the query type's `STABLE_TYPE_ID` (see `register`).
    executors_by_key_type_id: HashMap<StableTypeID, Entry<C>>,
}
impl<C: Config> Registry<C> {
    /// Registers `executor` as the handler for query type `Q`.
    ///
    /// The entry is keyed by `Q::STABLE_TYPE_ID`. Registering again for the
    /// same query type replaces the earlier executor.
    pub fn register<Q: Query, E: Executor<Q, C> + 'static>(
        &mut self,
        executor: Arc<E>,
    ) {
        self.executors_by_key_type_id
            .insert(Q::STABLE_TYPE_ID, Entry::new::<Q, E>(executor));
    }

    /// Returns the entry registered under `type_id`.
    ///
    /// # Panics
    ///
    /// Panics if no executor is registered for `type_id`.
    #[must_use]
    pub(crate) fn get_executor_entry_by_type_id(
        &self,
        type_id: &StableTypeID,
    ) -> &Entry<C> {
        match self.try_get_executor_entry_by_type_id(type_id) {
            Some(entry) => entry,
            None => panic!(
                "Failed to find executor for query type id: {type_id:?}"
            ),
        }
    }

    /// Returns the entry registered under `type_id`, if any.
    #[must_use]
    pub(crate) fn try_get_executor_entry_by_type_id(
        &self,
        type_id: &StableTypeID,
    ) -> Option<&Entry<C>> {
        self.executors_by_key_type_id.get(type_id)
    }

    /// Returns the entry registered for query type `Q`.
    ///
    /// # Panics
    ///
    /// Panics, naming `Q`, if no executor is registered for it.
    #[must_use]
    pub(crate) fn get_executor_entry<Q: Query>(&self) -> &Entry<C> {
        let Some(entry) =
            self.try_get_executor_entry_by_type_id(&Q::STABLE_TYPE_ID)
        else {
            panic!(
                "Failed to find executor for query name: {}",
                std::any::type_name::<Q>()
            );
        };
        entry
    }
}