use std::collections::BTreeSet;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use crate::accumulate::Accumulated;
use crate::cell::CellData;
use crate::storage::StorageFor;
use crate::{Cell, Computation, Storage};
pub mod debug_with_db;
mod handle;
mod serialize;
mod tests;
pub use handle::DbHandle;
use parking_lot::Mutex;
use rustc_hash::FxHashSet;
/// Value the database-wide version counter starts at when a `Db` is created.
///
/// NOTE(review): this starts at 1 rather than 0 — presumably so that freshly created
/// `CellData` (likely initialised with version fields of 0) compares as "not yet verified";
/// confirm against `CellData::new`.
const START_VERSION: u32 = 1;
/// The incremental-computation database.
///
/// `S` is the storage backend that holds each cell's computation input and output. The
/// type parameter is named `S` (not `Storage`) so it does not shadow the `Storage` trait.
pub struct Db<S> {
    /// Per-cell bookkeeping (computation id, versions, dependencies, lock), keyed by cell.
    cells: dashmap::DashMap<Cell, CellData>,
    /// Database-wide version counter; advanced by `update_input` when an input actually changes.
    version: AtomicU32,
    /// Next raw id to hand out when a new `Cell` is allocated.
    next_cell: AtomicU32,
    /// Backend holding the inputs and outputs of every cell.
    storage: S,
    /// One mutex per computation id, used to serialise cell creation for that computation kind.
    cell_locks: dashmap::DashMap<u32, Arc<Mutex<()>>>,
}
impl<S: Default> Db<S> {
    /// Creates an empty database backed by `S::default()` storage.
    ///
    /// Equivalent to `Db::with_storage(S::default())`.
    pub fn new() -> Self {
        Self::with_storage(S::default())
    }
}
impl<S: Default> Default for Db<S> {
fn default() -> Self {
Self::new()
}
}
pub trait DbGet<C: Computation> {
fn get(&self, key: C) -> C::Output;
}
impl<S, C> DbGet<C> for Db<S>
where
C: Computation,
S: Storage + StorageFor<C>,
{
fn get(&self, key: C) -> C::Output {
self.get(key)
}
}
impl<S> Db<S> {
pub fn with_storage(storage: S) -> Self {
Self {
cells: Default::default(),
version: AtomicU32::new(START_VERSION),
next_cell: AtomicU32::new(0),
cell_locks: Default::default(),
storage,
}
}
pub fn storage(&self) -> &S {
&self.storage
}
pub fn storage_mut(&mut self) -> &mut S {
&mut self.storage
}
}
impl<S: Storage> Db<S> {
    /// Returns the cell already allocated for `computation`, or `None` if there is none.
    ///
    /// Never allocates; `get_or_insert_cell` is the allocating variant.
    fn get_cell<C: Computation>(&self, computation: &C) -> Option<Cell>
    where
        S: StorageFor<C>,
    {
        self.storage.get_cell_for_computation(computation)
    }

    /// Returns the cell for `input`, allocating and registering a fresh one if needed.
    ///
    /// Both the lookup and the insertion happen while holding the mutex associated with
    /// `C::computation_id()`. Assuming `insert_new_cell` makes the cell visible to
    /// `get_cell_for_computation`, concurrent callers cannot allocate two cells for the same
    /// input.
    pub(crate) fn get_or_insert_cell<C>(&self, input: C) -> Cell
    where
        C: Computation,
        S: StorageFor<C>,
    {
        let computation_id = C::computation_id();
        // Clone the `Arc` out of the map so the DashMap entry guard is dropped at the end of
        // this statement, before we block on the mutex itself.
        let lock = self.cell_locks.entry(computation_id).or_default().clone();
        let _guard = lock.lock();
        if let Some(cell) = self.get_cell(&input) {
            cell
        } else {
            // Only uniqueness of the id matters here, hence `Relaxed`.
            let cell_id = self.next_cell.fetch_add(1, Ordering::Relaxed);
            let new_cell = Cell::new(cell_id);
            self.cells.insert(new_cell, CellData::new(computation_id));
            self.storage.insert_new_cell(new_cell, input);
            new_cell
        }
    }

    /// Wraps `cell` in a [`DbHandle`] tied to this database.
    fn handle(&self, cell: Cell) -> DbHandle<'_, S> {
        DbHandle::new(self, cell)
    }

    /// Test helper: runs `f` on the bookkeeping data of `input`'s cell.
    ///
    /// # Panics
    ///
    /// Panics if no cell has been allocated for `input`.
    #[cfg(test)]
    // Presumably only some tests use this helper, so it may be unused in a given build.
    #[allow(unused)]
    pub(crate) fn with_cell_data<C: Computation>(&self, input: &C, f: impl FnOnce(&CellData))
    where
        S: StorageFor<C>,
    {
        let cell = self
            .get_cell(input)
            .unwrap_or_else(|| panic!("unwrap_cell_value: Expected cell to exist"));
        f(&self.cells.get(&cell).unwrap())
    }

    /// Current value of the database-wide version counter.
    pub fn version(&self) -> u32 {
        self.version.load(Ordering::SeqCst)
    }

    /// Garbage-collects storage, keeping only cells verified at or after `version`.
    ///
    /// The set of cells whose `last_verified_version >= version` is passed to the storage
    /// backend's `gc`. Entries in the database's own cell map are not removed here.
    pub fn gc(&mut self, version: u32) {
        let used_cells: std::collections::HashSet<Cell> = self
            .cells
            .iter()
            .filter(|entry| entry.value().last_verified_version >= version)
            // `Cell` is `Copy`, so dereferencing is enough; no clone needed.
            .map(|entry| *entry.key())
            .collect();
        self.storage.gc(&used_cells);
    }
}
impl<S: Storage> Db<S> {
    /// Sets the value of an input computation, creating its cell if necessary.
    ///
    /// The global version is advanced only when the storage reports that the stored value
    /// actually changed. Otherwise the cell is merely re-marked as verified at the current
    /// version.
    ///
    /// # Panics
    ///
    /// Panics if the cell for `input` has any recorded dependencies, i.e. it is not an input.
    pub fn update_input<C>(&mut self, input: C, new_value: C::Output)
    where
        C: Computation,
        S: StorageFor<C>,
    {
        let cell_id = self.get_or_insert_cell(input);
        assert!(
            self.is_input(cell_id),
            "`update_input` given a non-input value. Inputs must have 0 dependencies",
        );
        let changed = self.storage.update_output(cell_id, new_value);
        let mut cell = self.cells.get_mut(&cell_id).unwrap();
        if changed {
            // Not bumping the version on a no-op write means cells already verified at the
            // current version need no re-validation (see `update_cell`).
            let version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
            cell.last_updated_version = version;
            cell.last_verified_version = version;
        } else {
            cell.last_verified_version = self.version.load(Ordering::SeqCst);
        }
    }

    /// True if `cell` has neither computation dependencies nor input dependencies recorded.
    fn is_input(&self, cell: Cell) -> bool {
        self.with_cell(cell, |cell| {
            cell.dependencies.is_empty() && cell.input_dependencies.is_empty()
        })
    }

    /// Returns whether `input` would need to be (re)computed.
    ///
    /// A computation that has never been assigned a cell is always stale.
    pub fn is_stale<C: Computation>(&self, input: &C) -> bool
    where
        S: StorageFor<C>,
    {
        let Some(cell) = self.get_cell(input) else {
            return true;
        };
        self.is_stale_cell(cell)
    }

    /// Decides whether `cell` must be recomputed.
    ///
    /// A cell is stale if it has no output yet. Otherwise it is stale only when both hold:
    /// - some recorded input dependency was updated after the cell was last verified;
    /// - after bringing a dependency up to date, that dependency was updated after the cell
    ///   was last verified.
    ///
    /// The second check walks dependencies in their recorded order and stops at the first one
    /// that changed. Note that it calls `update_cell` on each dependency it visits.
    fn is_stale_cell(&self, cell: Cell) -> bool {
        let computation_id = self.with_cell(cell, |data| data.computation_id);
        if self.storage.output_is_unset(cell, computation_id) {
            return true;
        }
        // Copy everything we need out of the map so no DashMap guard is held while we
        // recursively update dependencies below.
        let (last_verified, inputs, dependencies) = self.with_cell(cell, |data| {
            (
                data.last_verified_version,
                data.input_dependencies.clone(),
                data.dependencies.clone(),
            )
        });
        // Cheap pre-check: only walk (and possibly recompute) dependencies if at least one
        // input this cell depends on has changed since it was last verified.
        let inputs_changed = inputs.into_iter().any(|input_id| {
            self.with_cell(input_id, |input| input.last_updated_version > last_verified)
        });
        // Dependencies are visited in recorded order and the walk short-circuits, so a later
        // dependency is never updated once an earlier one has already shown a change.
        inputs_changed
            && dependencies.into_iter().any(|dependency_id| {
                self.update_cell(dependency_id);
                self.with_cell(dependency_id, |dependency| {
                    dependency.last_updated_version > last_verified
                })
            })
    }

    /// Re-runs the compute function for `cell_id` and records the result's version info.
    ///
    /// Clears the cell's accumulated values first. Afterwards it marks the cell verified at
    /// the current version, and also updated if the storage reported a changed output. The
    /// global version is not advanced.
    fn run_compute_function(&self, cell_id: Cell) {
        let computation_id = self.with_cell(cell_id, |data| data.computation_id);
        self.storage.clear_accumulated_for_cell(cell_id);
        let handle = self.handle(cell_id);
        let changed = S::run_computation(&handle, cell_id, computation_id);
        let version = self.version.load(Ordering::SeqCst);
        let mut cell = self.cells.get_mut(&cell_id).unwrap();
        cell.last_verified_version = version;
        if changed {
            cell.last_updated_version = version;
        }
    }

    /// Brings `cell_id` up to date with the current version, recomputing it if stale.
    ///
    /// The recomputation happens under the cell's own lock. If that lock is already held,
    /// this checks for a dependency cycle and then blocks until the holder releases it.
    fn update_cell(&self, cell_id: Cell) {
        let last_verified_version = self.with_cell(cell_id, |data| data.last_verified_version);
        let version = self.version.load(Ordering::SeqCst);
        if last_verified_version != version {
            if self.is_stale_cell(cell_id) {
                let lock = self.with_cell(cell_id, |cell| cell.lock.clone());
                match lock.try_lock() {
                    Some(guard) => {
                        // Another thread may have re-validated this cell between our staleness
                        // check and acquiring the lock. Re-check so we don't recompute it a
                        // second time (which would also transiently clear its accumulated
                        // values while others may be reading them).
                        let verified_meanwhile =
                            self.with_cell(cell_id, |data| data.last_verified_version) == version;
                        if !verified_meanwhile {
                            self.run_compute_function(cell_id);
                        }
                        drop(guard);
                    }
                    None => {
                        // The lock is held elsewhere: presumably either another thread is
                        // computing this cell (so we wait for it), or this thread is, which
                        // indicates a cycle. `check_for_cycle` panics if it finds one through
                        // the recorded dependencies.
                        self.check_for_cycle(cell_id);
                        drop(lock.lock());
                    }
                }
            } else {
                let mut cell = self.cells.get_mut(&cell_id).unwrap();
                cell.last_verified_version = version;
            }
        }
    }

    /// Depth-first walk over recorded `dependencies` starting at `starting_cell`.
    ///
    /// Panics (via `cycle_error`) if it reaches a cell already on the current path. Uses an
    /// explicit stack instead of recursion. `Pop` markers remove a cell from `path` once all
    /// of its dependencies have been explored.
    fn check_for_cycle(&self, starting_cell: Cell) {
        let mut visited = FxHashSet::default();
        let mut path = Vec::new();
        let mut stack = Vec::new();
        stack.push(Action::Traverse(starting_cell));

        enum Action {
            Traverse(Cell),
            Pop(Cell),
        }

        while let Some(action) = stack.pop() {
            match action {
                Action::Pop(expected) => assert_eq!(path.pop(), Some(expected)),
                Action::Traverse(cell) => {
                    // A back-edge to a cell still on the path is a cycle; include the repeated
                    // cell so the report shows where the cycle closes.
                    if path.contains(&cell) {
                        path.push(cell);
                        self.cycle_error(&path);
                    }
                    if visited.insert(cell) {
                        path.push(cell);
                        stack.push(Action::Pop(cell));
                        self.with_cell(cell, |cell| {
                            for dependency in cell.dependencies.iter() {
                                stack.push(Action::Traverse(*dependency));
                            }
                        });
                    }
                }
            }
        }
    }

    /// Panics with a numbered listing of the cells forming `cycle`.
    ///
    /// Returns `!` because it never returns normally.
    fn cycle_error(&self, cycle: &[Cell]) -> ! {
        let error: String = cycle
            .iter()
            .enumerate()
            .map(|(i, cell)| {
                format!(
                    "\n {}. {}",
                    i + 1,
                    self.storage.input_debug_string(self, *cell)
                )
            })
            .collect();
        panic!("inc-complete: Cycle Detected!\n\nCycle:{error}")
    }

    /// Returns the output of `compute`, allocating its cell and (re)computing it as needed.
    pub fn get<C: Computation>(&self, compute: C) -> C::Output
    where
        S: StorageFor<C>,
    {
        let cell_id = self.get_or_insert_cell(compute);
        self.get_with_cell::<C>(cell_id)
    }

    /// Brings `cell_id` up to date and returns its stored output as `Concrete::Output`.
    ///
    /// # Panics
    ///
    /// Panics if the storage has no output for the cell after updating it.
    pub(crate) fn get_with_cell<Concrete: Computation>(&self, cell_id: Cell) -> Concrete::Output
    where
        S: StorageFor<Concrete>,
    {
        self.update_cell(cell_id);
        self.storage
            .get_output(cell_id)
            .expect("cell result should have been computed already")
    }

    /// Runs `f` with read access to `cell`'s bookkeeping data.
    ///
    /// The DashMap read guard is held while `f` runs. Panics if `cell` is unknown.
    fn with_cell<R>(&self, cell: Cell, f: impl FnOnce(&CellData) -> R) -> R {
        f(&self.cells.get(&cell).unwrap())
    }

    /// Brings `compute` up to date, then returns the result of the `Accumulated<Item>`
    /// computation keyed by its cell.
    ///
    /// NOTE(review): presumably this is the set of `Item`s accumulated while computing
    /// `compute` — confirm against `Accumulated`.
    pub fn get_accumulated<Item, C>(&self, compute: C) -> BTreeSet<Item>
    where
        S: StorageFor<C> + StorageFor<Accumulated<Item>>,
        C: Computation,
        Item: 'static,
    {
        let cell_id = self.get_or_insert_cell(compute);
        self.update_cell(cell_id);
        self.get(Accumulated::<Item>::new(cell_id))
    }
}