1use std::collections::BTreeSet;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU32, Ordering};
4
5use crate::accumulate::{ACCUMULATED_COMPUTATION_ID, Accumulate, Accumulated};
6use crate::cell::CellData;
7use crate::storage::StorageFor;
8use crate::{Cell, Computation, Storage};
9
10pub mod debug_with_db;
11mod handle;
12mod serialize;
13mod tests;
14
15pub use handle::DbHandle;
16use parking_lot::Mutex;
17use rustc_hash::FxHashSet;
18
19const START_VERSION: u32 = 1;
20
21pub struct Db<Storage> {
26 cells: dashmap::DashMap<Cell, CellData>,
27 version: AtomicU32,
28 next_cell: AtomicU32,
29 storage: Storage,
30
31 cell_locks: dashmap::DashMap<u32, Arc<Mutex<()>>>,
34}
35
36impl<Storage: Default> Db<Storage> {
37 pub fn new() -> Self {
39 Self::with_storage(Storage::default())
40 }
41}
42
43impl<S: Default> Default for Db<S> {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49pub trait DbGet<C: Computation> {
52 fn get(&self, key: C) -> C::Output;
55}
56
57impl<S, C> DbGet<C> for Db<S>
58where
59 C: Computation,
60 S: Storage + StorageFor<C>,
61{
62 fn get(&self, key: C) -> C::Output {
63 self.get(key)
64 }
65}
66
67impl<S> Db<S> {
68 pub fn with_storage(storage: S) -> Self {
70 Self {
71 cells: Default::default(),
72 version: AtomicU32::new(START_VERSION),
73 next_cell: AtomicU32::new(0),
74 cell_locks: Default::default(),
75 storage,
76 }
77 }
78
79 pub fn storage(&self) -> &S {
81 &self.storage
82 }
83
84 pub fn storage_mut(&mut self) -> &mut S {
89 &mut self.storage
90 }
91}
92
93impl<S: Storage> Db<S> {
94 fn get_cell<C: Computation>(&self, computation: &C) -> Option<Cell>
98 where
99 S: StorageFor<C>,
100 {
101 self.storage.get_cell_for_computation(computation)
102 }
103
104 pub(crate) fn get_or_insert_cell<C>(&self, input: C) -> Cell
105 where
106 C: Computation,
107 S: StorageFor<C>,
108 {
109 if let Some(cell) = self.get_cell(&input) {
110 return cell;
111 }
112
113 let computation_id = C::computation_id();
115 let lock = self.cell_locks.entry(computation_id).or_default().clone();
116 let _guard = lock.lock();
117
118 if let Some(cell) = self.get_cell(&input) {
121 cell
122 } else {
123 let cell_id = self.next_cell.fetch_add(1, Ordering::Relaxed);
126 let new_cell = Cell::new(cell_id);
127
128 self.cells.insert(new_cell, CellData::new(computation_id));
129 self.storage.insert_new_cell(new_cell, input);
130 new_cell
131 }
132 }
133
134 fn handle(&self, cell: Cell) -> DbHandle<'_, S> {
135 DbHandle::new(self, cell)
136 }
137
/// Test helper: runs `f` on the `CellData` of the cell belonging to `input`.
///
/// # Panics
///
/// Panics if `input` has no cell yet, or if that cell has no entry in
/// `self.cells`.
#[cfg(test)]
#[allow(unused)]
pub(crate) fn with_cell_data<C>(&self, input: &C, f: impl FnOnce(&CellData))
where
    C: Computation,
    S: StorageFor<C>,
{
    let Some(cell) = self.get_cell(input) else {
        panic!("unwrap_cell_value: Expected cell to exist")
    };

    let data = self.cells.get(&cell).unwrap();
    f(&data)
}
150
151 pub fn version(&self) -> u32 {
152 self.version.load(Ordering::SeqCst)
153 }
154
155 pub fn gc(&mut self, version: u32) {
156 let used_cells: std::collections::HashSet<Cell> = self
157 .cells
158 .iter()
159 .filter_map(|entry| {
160 if entry.value().last_verified_version >= version {
161 Some(entry.key().clone())
162 } else {
163 None
164 }
165 })
166 .collect();
167
168 self.storage.gc(&used_cells);
169 }
170}
171
172impl<S: Storage> Db<S> {
173 pub fn update_input<C>(&mut self, input: C, new_value: C::Output)
181 where
182 C: Computation,
183 S: StorageFor<C>,
184 {
185 let cell_id = self.get_or_insert_cell(input);
186 assert!(
187 self.is_input(cell_id),
188 "`update_input` given a non-input value. Inputs must have 0 dependencies",
189 );
190
191 let changed = self.storage.update_output(cell_id, new_value);
192 let mut cell = self.cells.get_mut(&cell_id).unwrap();
193
194 if changed {
195 let version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
196 cell.last_updated_version = version;
197 cell.last_verified_version = version;
198 } else {
199 cell.last_verified_version = self.version.load(Ordering::SeqCst);
200 }
201 }
202
203 fn is_input(&self, cell: Cell) -> bool {
204 self.with_cell(cell, |cell| {
205 cell.dependencies.is_empty() && cell.input_dependencies.is_empty()
206 })
207 }
208
209 pub fn is_stale<C: Computation>(&self, input: &C) -> bool
214 where
215 S: StorageFor<C>,
216 {
217 let Some(cell) = self.get_cell(input) else {
219 return true;
220 };
221 self.is_stale_cell(cell)
222 }
223
224 fn is_stale_cell(&self, cell: Cell) -> bool {
228 let (computation_id, last_verified, inputs, dependencies) = self.with_cell(cell, |data| {
229 (
230 data.computation_id,
231 data.last_verified_version,
232 data.input_dependencies.clone(),
233 data.dependencies.clone(),
234 )
235 });
236
237 if self.storage.output_is_unset(cell, computation_id) {
238 return true;
239 }
240
241 let inputs_changed = inputs.into_iter().any(|input_id| {
244 self.with_cell(input_id, |input| input.last_updated_version > last_verified)
247 });
248
249 inputs_changed
253 && dependencies.into_iter().any(|dependency_id| {
254 self.update_cell(dependency_id);
255 self.with_cell(dependency_id, |dependency| {
256 if computation_id == ACCUMULATED_COMPUTATION_ID {
257 dependency.last_run_version > last_verified
258 } else {
259 dependency.last_updated_version > last_verified
260 }
261 })
262 })
263 }
264
265 fn run_compute_function(&self, cell_id: Cell) {
269 let computation_id = self.with_cell(cell_id, |data| data.computation_id);
270 self.storage.clear_accumulated_for_cell(cell_id);
271 let handle = self.handle(cell_id);
272 let changed = S::run_computation(&handle, cell_id, computation_id);
273
274 let version = self.version.load(Ordering::SeqCst);
275 let mut cell = self.cells.get_mut(&cell_id).unwrap();
276 cell.last_verified_version = version;
277 cell.last_run_version = version;
278
279 if changed {
280 cell.last_updated_version = version;
281 }
282 }
283
284 fn update_cell(&self, cell_id: Cell) {
287 let last_verified_version = self.with_cell(cell_id, |data| data.last_verified_version);
288 let version = self.version.load(Ordering::SeqCst);
289
290 if last_verified_version != version {
291 if self.is_stale_cell(cell_id) {
293 let lock = self.with_cell(cell_id, |cell| cell.lock.clone());
294
295 match lock.try_lock() {
296 Some(guard) => {
297 self.run_compute_function(cell_id);
298 drop(guard);
299 }
300 None => {
301 self.check_for_cycle(cell_id);
305
306 drop(lock.lock());
308 }
309 }
310 } else {
311 let mut cell = self.cells.get_mut(&cell_id).unwrap();
312 cell.last_verified_version = version;
313 }
314 }
315 }
316
317 fn check_for_cycle(&self, starting_cell: Cell) {
319 let mut visited = FxHashSet::default();
320 let mut path = Vec::new();
321
322 let mut stack = Vec::new();
327 stack.push(Action::Traverse(starting_cell));
328
329 enum Action {
330 Traverse(Cell),
331 Pop(Cell),
332 }
333
334 while let Some(action) = stack.pop() {
335 match action {
336 Action::Pop(expected) => assert_eq!(path.pop(), Some(expected)),
338 Action::Traverse(cell) => {
339 if path.contains(&cell) {
340 path.push(cell);
342 self.cycle_error(&path);
343 }
344
345 if visited.insert(cell) {
346 path.push(cell);
347 stack.push(Action::Pop(cell));
348 self.with_cell(cell, |cell| {
349 for dependency in cell.dependencies.iter() {
350 stack.push(Action::Traverse(*dependency));
351 }
352 });
353 }
354 }
355 }
356 }
357 }
358
359 fn cycle_error(&self, cycle: &[Cell]) {
361 let mut error = String::new();
362 for (i, cell) in cycle.iter().enumerate() {
363 error += &format!(
364 "\n {}. {}",
365 i + 1,
366 self.storage.input_debug_string(self, *cell)
367 );
368 }
369 panic!("inc-complete: Cycle Detected!\n\nCycle:{error}")
370 }
371
372 pub fn get<C: Computation>(&self, compute: C) -> C::Output
380 where
381 S: StorageFor<C>,
382 {
383 let cell_id = self.get_or_insert_cell(compute);
384 self.get_with_cell::<C>(cell_id)
385 }
386
387 pub(crate) fn get_with_cell<Concrete: Computation>(&self, cell_id: Cell) -> Concrete::Output
388 where
389 S: StorageFor<Concrete>,
390 {
391 self.update_cell(cell_id);
392
393 self.storage
394 .get_output(cell_id)
395 .expect("cell result should have been computed already")
396 }
397
398 fn with_cell<R>(&self, cell: Cell, f: impl FnOnce(&CellData) -> R) -> R {
399 f(&self.cells.get(&cell).unwrap())
400 }
401
402 pub fn get_accumulated<Item, C>(&self, compute: C) -> BTreeSet<Item>
414 where
415 S: StorageFor<C> + StorageFor<Accumulated<Item>>,
416 C: Computation,
417 Item: 'static,
418 {
419 let cell_id = self.get_or_insert_cell(compute);
420 self.update_cell(cell_id);
421 self.get(Accumulated::<Item>::new(cell_id))
422 }
423
424 pub fn get_accumulated_uncached<Item, C>(&mut self, compute: C) -> BTreeSet<Item>
433 where
434 S: StorageFor<C> + StorageFor<Accumulated<Item>> + Accumulate<Item>,
435 C: Computation,
436 Item: 'static + Ord,
437 {
438 let cell_id = self.get_or_insert_cell(compute);
439 self.update_cell(cell_id);
440
441 let mut items = BTreeSet::new();
442 let mut visited = BTreeSet::new();
443 let mut queue = vec![cell_id];
444
445 while let Some(cell) = queue.pop() {
446 if visited.insert(cell) {
447 self.with_cell(cell, |data| queue.extend_from_slice(&data.dependencies));
448 items.extend(self.storage().get_accumulated::<Vec<Item>>(cell));
449 }
450 }
451
452 items
453 }
454}