// inc_complete/db/mod.rs

1use std::collections::BTreeSet;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU32, Ordering};
4
5use crate::accumulate::{ACCUMULATED_COMPUTATION_ID, Accumulate, Accumulated};
6use crate::cell::CellData;
7use crate::storage::StorageFor;
8use crate::{Cell, Computation, Storage};
9
10pub mod debug_with_db;
11mod handle;
12mod serialize;
13mod tests;
14
15pub use handle::DbHandle;
16use parking_lot::Mutex;
17use rustc_hash::FxHashSet;
18
/// Version number a freshly constructed `Db` starts at.
///
/// `update_input` bumps the version by one every time an input's value actually changes.
const START_VERSION: u32 = 1;
20
21/// The central database object to manage and cache incremental computations.
22///
23/// To use this, a type implementing `Storage` is required to be provided.
24/// See the documentation for `impl_storage!`.
25pub struct Db<Storage> {
26    cells: dashmap::DashMap<Cell, CellData>,
27    version: AtomicU32,
28    next_cell: AtomicU32,
29    storage: Storage,
30
31    /// Lock used when acquiring new Cells to ensure the same data isn't assigned
32    /// multiple ids concurrently. Maps computation_id to each lock.
33    cell_locks: dashmap::DashMap<u32, Arc<Mutex<()>>>,
34}
35
36impl<Storage: Default> Db<Storage> {
37    /// Construct a new `Db` object using `Default::default()` for the initial storage.
38    pub fn new() -> Self {
39        Self::with_storage(Storage::default())
40    }
41}
42
43impl<S: Default> Default for Db<S> {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49/// Abstracts over the `get` function provided by `Db<S>` and `DbHandle<S>` to avoid
50/// providing `get` and `get_db` variants for each function.
51pub trait DbGet<C: Computation> {
52    /// Run an incremental computation `C` and return its output.
53    /// If `C` is already cached, no computation will be performed.
54    fn get(&self, key: C) -> C::Output;
55}
56
57impl<S, C> DbGet<C> for Db<S>
58where
59    C: Computation,
60    S: Storage + StorageFor<C>,
61{
62    fn get(&self, key: C) -> C::Output {
63        self.get(key)
64    }
65}
66
67impl<S> Db<S> {
68    /// Construct a new `Db` object with the given initial storage.
69    pub fn with_storage(storage: S) -> Self {
70        Self {
71            cells: Default::default(),
72            version: AtomicU32::new(START_VERSION),
73            next_cell: AtomicU32::new(0),
74            cell_locks: Default::default(),
75            storage,
76        }
77    }
78
79    /// Retrieve an immutable reference to this `Db`'s storage
80    pub fn storage(&self) -> &S {
81        &self.storage
82    }
83
84    /// Retrieve a mutable reference to this `Db`'s storage.
85    ///
86    /// Note that any mutations made to the storage using this are _not_ tracked by the `Db`!
87    /// Using this incorrectly may break correctness!
88    pub fn storage_mut(&mut self) -> &mut S {
89        &mut self.storage
90    }
91}
92
93impl<S: Storage> Db<S> {
94    /// Return the corresponding Cell for a given computation, if it exists.
95    ///
96    /// This will not update any values.
97    fn get_cell<C: Computation>(&self, computation: &C) -> Option<Cell>
98    where
99        S: StorageFor<C>,
100    {
101        self.storage.get_cell_for_computation(computation)
102    }
103
104    pub(crate) fn get_or_insert_cell<C>(&self, input: C) -> Cell
105    where
106        C: Computation,
107        S: StorageFor<C>,
108    {
109        if let Some(cell) = self.get_cell(&input) {
110            return cell;
111        }
112
113        // Cell doesn't exist yet, we need to lock & create a unique Cell for this input
114        let computation_id = C::computation_id();
115        let lock = self.cell_locks.entry(computation_id).or_default().clone();
116        let _guard = lock.lock();
117
118        // Need to check get_cell again in case another thread created this Cell after
119        // our get_cell call but before we acquired the lock
120        if let Some(cell) = self.get_cell(&input) {
121            cell
122        } else {
123            // We just need a unique ID here, we don't care about ordering between
124            // threads, so we're using Ordering::Relaxed.
125            let cell_id = self.next_cell.fetch_add(1, Ordering::Relaxed);
126            let new_cell = Cell::new(cell_id);
127
128            self.cells.insert(new_cell, CellData::new(computation_id));
129            self.storage.insert_new_cell(new_cell, input);
130            new_cell
131        }
132    }
133
134    fn handle(&self, cell: Cell) -> DbHandle<'_, S> {
135        DbHandle::new(self, cell)
136    }
137
138    #[cfg(test)]
139    #[allow(unused)]
140    pub(crate) fn with_cell_data<C: Computation>(&self, input: &C, f: impl FnOnce(&CellData))
141    where
142        S: StorageFor<C>,
143    {
144        let cell = self
145            .get_cell(input)
146            .unwrap_or_else(|| panic!("unwrap_cell_value: Expected cell to exist"));
147
148        self.cells.get(&cell).map(|value| f(&value)).unwrap()
149    }
150
151    pub fn version(&self) -> u32 {
152        self.version.load(Ordering::SeqCst)
153    }
154
155    pub fn gc(&mut self, version: u32) {
156        let used_cells: std::collections::HashSet<Cell> = self
157            .cells
158            .iter()
159            .filter_map(|entry| {
160                if entry.value().last_verified_version >= version {
161                    Some(entry.key().clone())
162                } else {
163                    None
164                }
165            })
166            .collect();
167
168        self.storage.gc(&used_cells);
169    }
170}
171
172impl<S: Storage> Db<S> {
173    /// Updates an input with a new value
174    ///
175    /// This requires an exclusive reference to self to ensure that there are no currently
176    /// running queries. Updating an input while an incremental computation is occurring
177    /// can break soundness for dependency tracking.
178    ///
179    /// Panics if the given computation is not an input - ie. panics if it has at least 1 dependency.
180    pub fn update_input<C>(&mut self, input: C, new_value: C::Output)
181    where
182        C: Computation,
183        S: StorageFor<C>,
184    {
185        let cell_id = self.get_or_insert_cell(input);
186        assert!(
187            self.is_input(cell_id),
188            "`update_input` given a non-input value. Inputs must have 0 dependencies",
189        );
190
191        let changed = self.storage.update_output(cell_id, new_value);
192        let mut cell = self.cells.get_mut(&cell_id).unwrap();
193
194        if changed {
195            let version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
196            cell.last_updated_version = version;
197            cell.last_verified_version = version;
198        } else {
199            cell.last_verified_version = self.version.load(Ordering::SeqCst);
200        }
201    }
202
203    fn is_input(&self, cell: Cell) -> bool {
204        self.with_cell(cell, |cell| {
205            cell.dependencies.is_empty() && cell.input_dependencies.is_empty()
206        })
207    }
208
209    /// True if a given computation is stale and needs to be re-computed.
210    /// Computations which have never been computed are also considered stale.
211    ///
212    /// Note that this may re-compute dependencies of the given computation.
213    pub fn is_stale<C: Computation>(&self, input: &C) -> bool
214    where
215        S: StorageFor<C>,
216    {
217        // If the cell doesn't exist, it is definitely stale
218        let Some(cell) = self.get_cell(input) else {
219            return true;
220        };
221        self.is_stale_cell(cell)
222    }
223
224    /// True if a given cell is stale and needs to be re-computed.
225    ///
226    /// Note that this may re-compute some input
227    fn is_stale_cell(&self, cell: Cell) -> bool {
228        let (computation_id, last_verified, inputs, dependencies) = self.with_cell(cell, |data| {
229            (
230                data.computation_id,
231                data.last_verified_version,
232                data.input_dependencies.clone(),
233                data.dependencies.clone(),
234            )
235        });
236
237        if self.storage.output_is_unset(cell, computation_id) {
238            return true;
239        }
240
241        // Optimization: only recursively check all dependencies if any
242        // of the inputs this cell depends on have changed
243        let inputs_changed = inputs.into_iter().any(|input_id| {
244            // This cell is stale if the dependency has been updated since
245            // we last verified this cell
246            self.with_cell(input_id, |input| input.last_updated_version > last_verified)
247        });
248
249        // Dependencies need to be iterated in the order they were computed.
250        // Otherwise we may re-run a computation which does not need to be re-run.
251        // In the worst case this could even lead to panics - see the div0 test.
252        inputs_changed
253            && dependencies.into_iter().any(|dependency_id| {
254                self.update_cell(dependency_id);
255                self.with_cell(dependency_id, |dependency| {
256                    if computation_id == ACCUMULATED_COMPUTATION_ID {
257                        dependency.last_run_version > last_verified
258                    } else {
259                        dependency.last_updated_version > last_verified
260                    }
261                })
262            })
263    }
264
265    /// Similar to `update_input` but runs the compute function
266    /// instead of accepting a given value. This also will not update
267    /// `self.version`
268    fn run_compute_function(&self, cell_id: Cell) {
269        let computation_id = self.with_cell(cell_id, |data| data.computation_id);
270        self.storage.clear_accumulated_for_cell(cell_id);
271        let handle = self.handle(cell_id);
272        let changed = S::run_computation(&handle, cell_id, computation_id);
273
274        let version = self.version.load(Ordering::SeqCst);
275        let mut cell = self.cells.get_mut(&cell_id).unwrap();
276        cell.last_verified_version = version;
277        cell.last_run_version = version;
278
279        if changed {
280            cell.last_updated_version = version;
281        }
282    }
283
284    /// Trigger an update of the given cell, recursively checking and re-running any out of date
285    /// dependencies.
286    fn update_cell(&self, cell_id: Cell) {
287        let last_verified_version = self.with_cell(cell_id, |data| data.last_verified_version);
288        let version = self.version.load(Ordering::SeqCst);
289
290        if last_verified_version != version {
291            // if any dependency may have changed, update
292            if self.is_stale_cell(cell_id) {
293                let lock = self.with_cell(cell_id, |cell| cell.lock.clone());
294
295                match lock.try_lock() {
296                    Some(guard) => {
297                        self.run_compute_function(cell_id);
298                        drop(guard);
299                    }
300                    None => {
301                        // This computation is already being run in another thread.
302                        // Before blocking and waiting, since we have time, check for a cycle and
303                        // issue and panic if found.
304                        self.check_for_cycle(cell_id);
305
306                        // Block until it finishes and return the result
307                        drop(lock.lock());
308                    }
309                }
310            } else {
311                let mut cell = self.cells.get_mut(&cell_id).unwrap();
312                cell.last_verified_version = version;
313            }
314        }
315    }
316
317    /// Perform a DFS to check for a cycle, panicking if found
318    fn check_for_cycle(&self, starting_cell: Cell) {
319        let mut visited = FxHashSet::default();
320        let mut path = Vec::new();
321
322        // We're going to push actions to this stack. Most actions will be pushing
323        // a dependency cell to track as the next node in the graph, but some will be
324        // pop actions for popping the top node off the current path. If we encounter
325        // a node which is already in the current path, we have found a cycle.
326        let mut stack = Vec::new();
327        stack.push(Action::Traverse(starting_cell));
328
329        enum Action {
330            Traverse(Cell),
331            Pop(Cell),
332        }
333
334        while let Some(action) = stack.pop() {
335            match action {
336                // This assert_eq is never expected to fail
337                Action::Pop(expected) => assert_eq!(path.pop(), Some(expected)),
338                Action::Traverse(cell) => {
339                    if path.contains(&cell) {
340                        // Include the same cell twice so the cycle is more clear to users
341                        path.push(cell);
342                        self.cycle_error(&path);
343                    }
344
345                    if visited.insert(cell) {
346                        path.push(cell);
347                        stack.push(Action::Pop(cell));
348                        self.with_cell(cell, |cell| {
349                            for dependency in cell.dependencies.iter() {
350                                stack.push(Action::Traverse(*dependency));
351                            }
352                        });
353                    }
354                }
355            }
356        }
357    }
358
359    /// Issue an error with the given cycle
360    fn cycle_error(&self, cycle: &[Cell]) {
361        let mut error = String::new();
362        for (i, cell) in cycle.iter().enumerate() {
363            error += &format!(
364                "\n  {}. {}",
365                i + 1,
366                self.storage.input_debug_string(self, *cell)
367            );
368        }
369        panic!("inc-complete: Cycle Detected!\n\nCycle:{error}")
370    }
371
372    /// Retrieves the up to date value for the given computation, re-running any dependencies as
373    /// necessary.
374    ///
375    /// This function can panic if the dynamic type of the value returned by `compute.run(..)` is not `T`.
376    ///
377    /// Locking behavior: This function locks the cell corresponding to the given computation. This
378    /// can cause a deadlock if the computation recursively depends on itself.
379    pub fn get<C: Computation>(&self, compute: C) -> C::Output
380    where
381        S: StorageFor<C>,
382    {
383        let cell_id = self.get_or_insert_cell(compute);
384        self.get_with_cell::<C>(cell_id)
385    }
386
387    pub(crate) fn get_with_cell<Concrete: Computation>(&self, cell_id: Cell) -> Concrete::Output
388    where
389        S: StorageFor<Concrete>,
390    {
391        self.update_cell(cell_id);
392
393        self.storage
394            .get_output(cell_id)
395            .expect("cell result should have been computed already")
396    }
397
398    fn with_cell<R>(&self, cell: Cell, f: impl FnOnce(&CellData) -> R) -> R {
399        f(&self.cells.get(&cell).unwrap())
400    }
401
402    /// Retrieve each accumulated value of the given type after the given computation is run.
403    ///
404    /// This is most often used for operations like retrieving diagnostics or logs.
405    ///
406    /// Compared to [Db::get_accumulated_uncached], this version reuses the normal flow for
407    /// queries and thus saves accumulated values for each intermediate query. This involves
408    /// more synching and data duplication but can be beneficial if intermediate results
409    /// ever need to be reused, e.g. if you call [Db::get_accumulated] in a loop where each
410    /// call may share dependencies. If you already have a single query which emits all the
411    /// accumulated values you need, [Db::get_accumulated_uncached] is likely faster, but
412    /// requires a `&mut Db`.
413    pub fn get_accumulated<Item, C>(&self, compute: C) -> BTreeSet<Item>
414    where
415        S: StorageFor<C> + StorageFor<Accumulated<Item>>,
416        C: Computation,
417        Item: 'static,
418    {
419        let cell_id = self.get_or_insert_cell(compute);
420        self.update_cell(cell_id);
421        self.get(Accumulated::<Item>::new(cell_id))
422    }
423
424    /// Retrieve each accumulated value of the given type after the given computation is run.
425    ///
426    /// This is most often used for operations like retrieving diagnostics or logs.
427    ///
428    /// This is a faster version of [Db::get_accumulated] for some use-cases. This version tends to be
429    /// more efficient when you already have a single query which emits all the accumulated values
430    /// you need, while the original [Db::get_accumulated] is more efficient when you have many
431    /// smaller calls since it avoids duplicated work and is safe to call with only a [DbHandle].
432    pub fn get_accumulated_uncached<Item, C>(&mut self, compute: C) -> BTreeSet<Item>
433    where
434        S: StorageFor<C> + StorageFor<Accumulated<Item>> + Accumulate<Item>,
435        C: Computation,
436        Item: 'static + Ord,
437    {
438        let cell_id = self.get_or_insert_cell(compute);
439        self.update_cell(cell_id);
440
441        let mut items = BTreeSet::new();
442        let mut visited = BTreeSet::new();
443        let mut queue = vec![cell_id];
444
445        while let Some(cell) = queue.pop() {
446            if visited.insert(cell) {
447                self.with_cell(cell, |data| queue.extend_from_slice(&data.dependencies));
448                items.extend(self.storage().get_accumulated::<Vec<Item>>(cell));
449            }
450        }
451
452        items
453    }
454}