use std::collections::BTreeSet;
use crate::{
Cell, Computation, Db, Storage,
accumulate::{Accumulate, Accumulated},
storage::StorageFor,
};
use super::DbGet;
/// Handle through which a running computation reads the database.
///
/// Every value read through the handle is recorded as a dependency of
/// `current_operation`, and items passed to `accumulate` are stored against it.
pub struct DbHandle<'db, S> {
    /// Database the computation is running against.
    db: &'db Db<S>,
    /// Cell of the computation that owns this handle; reads made through the handle are
    /// registered as its dependencies.
    current_operation: Cell,
}
impl<'db, S> DbHandle<'db, S> {
pub(crate) fn new(db: &'db Db<S>, current_operation: Cell) -> Self {
let mut cell = db.cells.get_mut(¤t_operation).unwrap();
cell.dependencies.clear();
cell.input_dependencies.clear();
Self {
db,
current_operation,
}
}
pub fn storage(&self) -> &S {
self.db.storage()
}
}
impl<S: Storage> DbHandle<'_, S> {
    /// Returns the output of `compute` and records it as a dependency of the current
    /// computation.
    ///
    /// The steps are:
    /// 1. look up or create the computation's cell (`Db::get_or_insert_cell`);
    /// 2. register and update it (`update_and_register_dependency`);
    /// 3. read its value (`Db::get_with_cell`).
    pub fn get<C: Computation>(&self, compute: C) -> C::Output
    where
        S: StorageFor<C>,
    {
        let dependency = self.db.get_or_insert_cell(compute);
        self.update_and_register_dependency::<C>(dependency);
        self.db.get_with_cell(dependency)
    }

    /// Typed wrapper over `update_and_register_dependency_inner` that takes the
    /// "is input" flag from `C::IS_INPUT`.
    fn update_and_register_dependency<C: Computation>(&self, dependency: Cell) {
        self.update_and_register_dependency_inner(dependency, C::IS_INPUT);
    }

    /// Records `dependency` as a dependency of the current computation and brings it up to date.
    ///
    /// The steps are:
    /// 1. append `dependency` to the current cell's `dependencies`, and also insert it into
    ///    the current cell's `input_dependencies` when `is_input` is true;
    /// 2. update `dependency` with `Db::update_cell`;
    /// 3. merge the dependency's own `input_dependencies` into the current cell's, so that the
    ///    current cell also lists the inputs it reaches through `dependency`.
    ///
    /// # Panics
    ///
    /// Panics if the current cell or `dependency` has no entry in `db.cells`.
    fn update_and_register_dependency_inner(&self, dependency: Cell, is_input: bool) {
        // The write guard on the current cell lives only in this scope, so it is released
        // before `update_cell` runs. NOTE(review): presumably `update_cell` locks entries of
        // `db.cells` itself — confirm before widening this scope.
        {
            let mut current = self
                .db
                .cells
                .get_mut(&self.current_operation)
                .expect("current operation has no cell in the database");
            current.dependencies.push(dependency);
            if is_input {
                current.input_dependencies.insert(dependency);
            }
        }

        self.db.update_cell(dependency);

        // The read guard on `dependency` is a temporary dropped at the end of this statement.
        // The inputs are cloned so the current cell is never locked while it is held.
        let inherited_inputs = self
            .db
            .cells
            .get(&dependency)
            .expect("dependency has no cell in the database")
            .input_dependencies
            .clone();

        // Only take the write guard again when there is something to merge.
        if !inherited_inputs.is_empty() {
            let mut current = self
                .db
                .cells
                .get_mut(&self.current_operation)
                .expect("current operation has no cell in the database");
            for input in inherited_inputs {
                current.input_dependencies.insert(input);
            }
        }
    }

    /// Hands `item` to the storage's `Accumulate::accumulate`, tagged with the current
    /// computation's cell.
    pub fn accumulate<Item>(&self, item: Item)
    where
        S: Accumulate<Item>,
    {
        self.storage().accumulate(self.current_operation, item);
    }

    /// Collects the `Item`s accumulated for `compute` (see `get_accumulated_with_cell`).
    ///
    /// The cell for `compute` is looked up, or created if absent, before collection.
    pub fn get_accumulated<Item, C>(&self, compute: C) -> BTreeSet<Item>
    where
        C: Computation,
        Item: 'static + Ord,
        S: StorageFor<Accumulated<Item>> + StorageFor<C> + Accumulate<Item>,
    {
        let dependency = self.db.get_or_insert_cell(compute);
        self.get_accumulated_with_cell::<Item>(dependency)
    }

    /// Collects the `Item`s accumulated by `cell_id` together with those reported for its
    /// dependencies.
    ///
    /// The steps are:
    /// 1. `cell_id` is registered as a non-input dependency of the current computation and
    ///    brought up to date.
    /// 2. For each of `cell_id`'s recorded dependencies that is not itself an
    ///    `Accumulated<Item>` computation, the items are obtained by running
    ///    `Accumulated::<Item>::new(dependency)` through `self.get`. That call also records it
    ///    as a dependency. NOTE(review): presumably that computation calls back into this
    ///    method, making collection transitive — confirm.
    /// 3. The items stored directly for `cell_id` are merged in.
    ///
    /// The result is a `BTreeSet`, so duplicates collapse.
    ///
    /// # Panics
    ///
    /// Panics if the current cell or `cell_id` has no entry in `db.cells`.
    pub(crate) fn get_accumulated_with_cell<Item>(&self, cell_id: Cell) -> BTreeSet<Item>
    where
        Item: 'static + Ord,
        S: StorageFor<Accumulated<Item>> + Accumulate<Item>,
    {
        self.update_and_register_dependency_inner(cell_id, false);
        let dependencies = self.db.with_cell(cell_id, |cell| cell.dependencies.clone());
        let accumulated_id = Accumulated::<Item>::computation_id();
        let mut result: BTreeSet<Item> = dependencies
            .into_iter()
            // Skip dependencies that are themselves `Accumulated<Item>` computations.
            .filter(|&dep| self.db.with_cell(dep, |cell| cell.computation_id) != accumulated_id)
            .flat_map(|dep| self.get(Accumulated::<Item>::new(dep)))
            .collect();
        result.extend(self.storage().get_accumulated::<Vec<Item>>(cell_id));
        result
    }
}
/// Lets a `DbHandle` be passed wherever a `DbGet<C>` implementation is expected.
impl<'db, S, C> DbGet<C> for DbHandle<'db, S>
where
    C: Computation,
    S: Storage + StorageFor<C>,
{
    /// Forwards to the inherent `DbHandle::get`, which also records `key` as a dependency of
    /// the handle's current computation.
    fn get(&self, key: C) -> C::Output {
        // Inherent associated functions shadow trait items of the same name, so this path
        // resolves to the inherent `get` rather than recursing into this trait method.
        Self::get(self, key)
    }
}