use std::{collections::VecDeque, sync::Arc};
use qbice_stable_hash::{Compact128, StableHash};
use tokio::{sync::RwLock, task::JoinSet};
use crate::{
Engine, Query, TrackedEngine,
config::{Config, WriteTransaction},
engine::{
computation_graph::{
ActiveInputSessionGuard,
caller::{CallerInformation, CallerKind, QueryCaller},
},
guard::GuardExt,
},
query::QueryID,
};
/// A handle for editing the engine's input queries as one atomic batch.
///
/// Changes are staged into a write transaction and flushed by `commit`; an
/// uncommitted session commits itself from `Drop` on a background task.
pub struct InputSession<C: Config> {
    /// The engine whose input queries this session edits.
    engine: Arc<Engine<C>>,
    /// Queries whose value fingerprint changed in this session; consumed for
    /// dirty propagation at commit time.
    dirty_batch: Arc<RwLock<VecDeque<QueryID>>>,
    /// Set by `commit` so that `Drop` does not commit a second time.
    /// NOTE(review): misspelling of "committed" — renaming would touch the
    /// other impls in this file, so it is only flagged here.
    comitted: bool,
    /// The pending write transaction, paired with the guard marking this as
    /// the engine's active input session; taken (set to `None`) on commit.
    #[allow(clippy::type_complexity)]
    transaction:
        Arc<RwLock<Option<(WriteTransaction<C>, ActiveInputSessionGuard)>>>,
}
impl<C: Config> Drop for InputSession<C> {
    /// Commits the session's pending changes if [`InputSession::commit`] was
    /// never called, so that staged input edits are not silently lost.
    fn drop(&mut self) {
        if self.comitted {
            return;
        }

        // Committing is async, so it must be finished on a spawned task.
        // `tokio::spawn` panics when called outside a runtime context, and a
        // panic in `drop` during unwinding aborts the process — bail out
        // instead of risking an abort.
        let Ok(handle) = tokio::runtime::Handle::try_current() else {
            return;
        };

        let transaction = self.transaction.clone();
        let dirty_batch = self.dirty_batch.clone();
        let engine = self.engine.clone();
        handle.spawn(async move {
            // The transaction may already have been taken (e.g. by a racing
            // commit); nothing left to do in that case.
            let Some((transaction, guard)) = transaction.write().await.take()
            else {
                return;
            };
            let dirty_batch = std::mem::take(&mut *dirty_batch.write().await);
            Self::commit_internal(engine, dirty_batch, transaction).await;
            drop(guard);
        });
    }
}
impl<C: Config> InputSession<C> {
    /// Commits all input changes staged through this session.
    ///
    /// Marks the session as committed (so `Drop` does not commit again),
    /// drains the batch of dirtied queries, and flushes the write
    /// transaction. Dropping an uncommitted session triggers the same commit
    /// on a background task; calling `commit` does it deterministically.
    pub async fn commit(mut self) {
        let engine = self.engine.clone();
        async move {
            // Flip the flag first so `Drop` sees the session as handled even
            // if the commit below unwinds.
            self.comitted = true;
            let dirty_batch =
                std::mem::take(&mut *self.dirty_batch.write().await);
            let (transaction, guard) =
                self.transaction.write().await.take().expect(
                    "transaction must be present until the session commits; \
                     `commit` consumes the session so it can run only once",
                );
            Self::commit_internal(engine, dirty_batch, transaction).await;
            drop(guard);
        }
        .guarded()
        .await;
    }

    /// Shared commit tail for `commit` and the `Drop` path: resets graph
    /// statistics, clears previously dirtied queries, propagates dirtiness
    /// from `dirty_batch`, and submits the final write buffer.
    async fn commit_internal(
        engine: Arc<Engine<C>>,
        dirty_batch: VecDeque<QueryID>,
        mut transaction: WriteTransaction<C>,
    ) {
        engine.computation_graph.reset_statistic();
        engine.clear_dirtied_queries();
        // Dirty propagation may extend the transaction before it is flushed.
        transaction = engine
            .dirty_propagate_from_batch(dirty_batch.into_iter(), transaction)
            .await;
        engine.submit_write_buffer(transaction);
    }
}
impl<C: Config> std::fmt::Debug for InputSession<C> {
    /// Debug-formats the session, showing the engine and the dirty batch;
    /// the remaining fields are elided (`finish_non_exhaustive`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("InputSession");
        builder.field("engine", &self.engine);
        builder.field("dirty_batch", &self.dirty_batch);
        builder.finish_non_exhaustive()
    }
}
impl<C: Config> Engine<C> {
    /// Opens a new input session for editing this engine's input queries.
    ///
    /// Waits until the active-input-session guard can be acquired, then
    /// wraps the engine's write buffer in a session. The session commits
    /// either explicitly via [`InputSession::commit`] or implicitly when
    /// dropped.
    // NOTE: the previous `#[allow(clippy::unused_async)]` was stale — this
    // function genuinely awaits the guard acquisition below.
    #[must_use]
    pub async fn input_session(self: &Arc<Self>) -> InputSession<C> {
        let (write_buffer_with_lock, active_input_session_guard) =
            self.acquire_active_input_session_guard().await;
        InputSession {
            engine: self.clone(),
            dirty_batch: Arc::new(RwLock::new(VecDeque::new())),
            comitted: false,
            transaction: Arc::new(RwLock::new(Some((
                write_buffer_with_lock,
                active_input_session_guard,
            )))),
        }
    }
}
/// Outcome of writing a value to an input query.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, StableHash,
)]
pub enum SetInputResult {
    /// The query had no previously stored value.
    Fresh,
    /// A value existed and its fingerprint differs from the new value's;
    /// dependents will be dirtied.
    Updated,
    /// The new value's fingerprint matches the stored one; no dirtying.
    Unchanged,
}
impl<C: Config> InputSession<C> {
    /// Shared write path for [`Self::set_input`] and [`Self::update`]
    /// (previously duplicated verbatim in both methods).
    ///
    /// Hashes the query, takes an exclusive snapshot, derives the new value
    /// via `make_value` (reading the current value first only when
    /// `read_current` is set), classifies the write as
    /// fresh/updated/unchanged by fingerprint, and stages it into the
    /// session's pending write transaction.
    ///
    /// # Panics
    ///
    /// Panics if the session's transaction has already been committed.
    async fn write_value<Q: Query>(
        &mut self,
        query: Q,
        read_current: bool,
        make_value: impl FnOnce(Option<Q::Value>) -> Q::Value,
    ) -> SetInputResult {
        let query_hash = self.engine.hash(&query);
        let query_id = QueryID::new::<Q>(query_hash);
        let mut snapshot =
            self.engine.get_exclusive_snapshot(query_hash).await;

        // Only `update` needs the previous value; `set_input` skips the
        // read entirely, matching its original behavior.
        let current_value = if read_current {
            snapshot.query_result().await
        } else {
            None
        };
        let new_value = make_value(current_value);
        let query_value_fingerprint = self.engine.hash(&new_value);

        // Fresh if the node has never been stored; otherwise compare
        // fingerprints to decide whether dependents must be dirtied.
        let set_input_result = (snapshot.node_info().await).map_or(
            SetInputResult::Fresh,
            |node_info| {
                if node_info.value_fingerprint() == query_value_fingerprint {
                    SetInputResult::Unchanged
                } else {
                    SetInputResult::Updated
                }
            },
        );

        // SAFETY: presumably holding the active input session makes reading
        // the timestamp without a check sound — TODO confirm against
        // `get_current_timestamp_unchecked`'s contract.
        let ts = unsafe { self.engine.get_current_timestamp_unchecked() };
        let transaction = self.transaction.clone();
        let dirty_batch = self.dirty_batch.clone();

        async move {
            if set_input_result == SetInputResult::Updated {
                dirty_batch.write().await.push_back(query_id);
            }
            let mut transaction = transaction.write().await;
            let Some((write_buffer, _guard)) = transaction.as_mut() else {
                panic!("InputSession transaction has already been committed");
            };
            snapshot
                .set_computed_input(
                    query,
                    query_hash,
                    new_value,
                    query_value_fingerprint,
                    write_buffer,
                    true,
                    ts,
                )
                .await;
            set_input_result
        }
        .guarded()
        .await
    }

    /// Stores `new_value` as the result of the input `query`, staging it in
    /// this session's write transaction.
    ///
    /// Returns whether the query was fresh, updated, or unchanged relative
    /// to its previously stored fingerprint.
    ///
    /// # Panics
    ///
    /// Panics if the session's transaction has already been committed.
    pub async fn set_input<Q: Query>(
        &mut self,
        query: Q,
        new_value: Q::Value,
    ) -> SetInputResult {
        self.write_value(query, false, |_| new_value).await
    }

    /// Reads the current value of the input `query` (if any), derives a new
    /// value from it via `update_fn`, and stages the result in this
    /// session's write transaction.
    ///
    /// Returns whether the query was fresh, updated, or unchanged relative
    /// to its previously stored fingerprint.
    ///
    /// # Panics
    ///
    /// Panics if the session's transaction has already been committed.
    pub async fn update<Q: Query>(
        &mut self,
        query: Q,
        update_fn: impl FnOnce(Option<Q::Value>) -> Q::Value,
    ) -> SetInputResult {
        self.write_value(query, true, update_fn).await
    }

    /// Re-executes every registered external input query of type `Q` and
    /// stages the recomputed values, dirtying those whose fingerprint
    /// changed.
    ///
    /// The stored input set is split into chunks (sized from the available
    /// parallelism) and recomputed concurrently on a [`JoinSet`]; results
    /// are then folded into the write transaction sequentially.
    ///
    /// # Panics
    ///
    /// Panics if the session's transaction has already been committed, if a
    /// recomputation task panicked (the panic is resumed), or if a task was
    /// cancelled.
    #[allow(clippy::too_many_lines)]
    pub async fn refresh<Q: Query>(&mut self) {
        /// Result of recomputing one external input query.
        struct RefreshResult<K, V> {
            query_input_hash: Compact128,
            query_result_hash: Compact128,
            query_input: K,
            new_value: V,
        }
        let type_id = Q::STABLE_TYPE_ID;
        let external_input_set =
            self.engine.get_external_input_queries(&type_id).await;
        // SAFETY: presumably holding the active input session makes reading
        // the timestamp without a check sound — TODO confirm against
        // `get_current_timestamp_unchecked`'s contract.
        let timestamp =
            unsafe { self.engine.get_current_timestamp_unchecked() };
        let hashes = external_input_set.collect::<Vec<_>>();
        // Over-subscribe 4x so uneven per-query cost still keeps all
        // threads busy.
        let expected_parallelism = std::thread::available_parallelism()
            .map_or_else(|_| 1, std::num::NonZero::get)
            * 4;
        let chunk_size =
            std::cmp::max((hashes.len()) / expected_parallelism, 1);
        let mut join_set = JoinSet::new();
        for chunk in hashes.chunks(chunk_size) {
            let engine = self.engine.clone();
            let chunk = chunk.to_owned();
            join_set.spawn(async move {
                let mut results = Vec::with_capacity(chunk.len());
                for query_hash in chunk.iter().copied() {
                    let query_id = QueryID::new::<Q>(query_hash);
                    let query = engine.get_query_input::<Q>(&query_hash).await;
                    // The wait group keeps us from collecting the result
                    // before all work tracked by the caller has finished.
                    let wait_group = waitgroup::WaitGroup::new();
                    let tracked_engine = TrackedEngine::new(
                        engine.clone(),
                        CallerInformation::new(
                            CallerKind::Query(QueryCaller::new_external_input(
                                query_id,
                                wait_group.worker(),
                            )),
                            timestamp,
                            None,
                        ),
                    );
                    let entry =
                        engine.executor_registry.get_executor_entry::<Q>();
                    let result = entry
                        .invoke_executor::<Q>(&query, &tracked_engine)
                        .await;
                    drop(tracked_engine);
                    wait_group.wait().await;
                    let new_value = match result {
                        Ok(value) => value,
                        // Re-raise executor panics on this task.
                        Err(panic) => panic.resume_unwind(),
                    };
                    let new_fingerprint = engine.hash(&new_value);
                    results.push(RefreshResult {
                        query_input: query,
                        query_input_hash: query_hash,
                        query_result_hash: new_fingerprint,
                        new_value,
                    });
                }
                results
            });
        }
        let dirty_batch = self.dirty_batch.clone();
        let engine = self.engine.clone();
        let transaction = self.transaction.clone();
        async move {
            let mut dirty_batch = dirty_batch.write().await;
            let mut transaction = transaction.write().await;
            let Some((transaction, _guard)) = transaction.as_mut() else {
                panic!("InputSession transaction has already been committed");
            };
            while let Some(res) = join_set.join_next().await {
                match res {
                    Ok(results) => {
                        for refresh_result in results {
                            let mut snapshot = engine
                                .get_exclusive_snapshot::<Q>(
                                    refresh_result.query_input_hash,
                                )
                                .await;
                            let fingerprint_diff = snapshot
                                .value_fingerprint()
                                .await
                                .expect(
                                    "registered external input query must \
                                     have a stored fingerprint",
                                )
                                != refresh_result.query_result_hash;
                            if fingerprint_diff {
                                let query_id = QueryID::new::<Q>(
                                    refresh_result.query_input_hash,
                                );
                                dirty_batch.push_back(query_id);
                            }
                            snapshot
                                .set_computed_input(
                                    refresh_result.query_input,
                                    refresh_result.query_input_hash,
                                    refresh_result.new_value,
                                    refresh_result.query_result_hash,
                                    transaction,
                                    false,
                                    timestamp,
                                )
                                .await;
                        }
                    }
                    // A task failed to join: resume its panic here, or fail
                    // loudly on cancellation/other join errors.
                    Err(er) => match er.try_into_panic() {
                        Ok(panic_reason) => {
                            std::panic::resume_unwind(panic_reason);
                        }
                        Err(er) => {
                            panic!(
                                "Failed to refresh external input query: {er}"
                            );
                        }
                    },
                }
            }
        }
        .guarded()
        .await;
    }

    /// Interns `value` in the engine's intern table, returning a shared
    /// handle to the canonical copy.
    pub fn intern<
        T: StableHash + crate::Identifiable + Send + Sync + 'static,
    >(
        &self,
        value: T,
    ) -> qbice_storage::intern::Interned<T> {
        self.engine.intern(value)
    }

    /// Interns an unsized value via an owned carrier `Q` convertible to
    /// `Arc<T>`, returning a shared handle to the canonical copy.
    pub fn intern_unsized<
        T: StableHash + crate::Identifiable + Send + Sync + 'static + ?Sized,
        Q: std::borrow::Borrow<T> + Send + Sync + 'static,
    >(
        &self,
        value: Q,
    ) -> qbice_storage::intern::Interned<T>
    where
        Arc<T>: From<Q>,
    {
        self.engine.intern_unsized(value)
    }
}