Skip to main content

inc_complete/db/
mod.rs

1use std::collections::BTreeSet;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU32, Ordering};
4
5use crate::accumulate::{ACCUMULATED_COMPUTATION_ID, Accumulate, Accumulated};
6use crate::cell::CellData;
7use crate::storage::StorageFor;
8use crate::{Cell, Computation, Storage};
9
10pub mod debug_with_db;
11mod handle;
12mod serialize;
13mod tests;
14
15pub use handle::DbHandle;
16use parking_lot::Mutex;
17use rustc_hash::FxHashSet;
18
/// Version number a freshly constructed `Db` starts at.
const START_VERSION: u32 = 1;
20
21/// The central database object to manage and cache incremental computations.
22///
23/// To use this, a type implementing `Storage` is required to be provided.
24/// See the documentation for `impl_storage!`.
25pub struct Db<Storage> {
26    cells: dashmap::DashMap<Cell, CellData>,
27    version: AtomicU32,
28    next_cell: AtomicU32,
29    storage: Storage,
30
31    /// Lock used when acquiring new Cells to ensure the same data isn't assigned
32    /// multiple ids concurrently. Maps computation_id to each lock.
33    cell_locks: dashmap::DashMap<u32, Arc<Mutex<()>>>,
34}
35
36impl<Storage: Default> Db<Storage> {
37    /// Construct a new `Db` object using `Default::default()` for the initial storage.
38    pub fn new() -> Self {
39        Self::with_storage(Storage::default())
40    }
41}
42
43impl<S: Default> Default for Db<S> {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
/// Abstracts over the `get` function provided by both `Db<S>` and `DbHandle<S>`, so that
/// code which works with either does not need separate `get` and `get_db` variants
/// of each function.
pub trait DbGet<C: Computation> {
    /// Run the incremental computation `C` and return its output.
    /// If the result of `C` is already cached, no computation is performed.
    fn get(&self, key: C) -> C::Output;
}
56
57impl<S, C> DbGet<C> for Db<S>
58where
59    C: Computation,
60    S: Storage + StorageFor<C>,
61{
62    fn get(&self, key: C) -> C::Output {
63        self.get(key)
64    }
65}
66
67impl<S> Db<S> {
68    /// Construct a new `Db` object with the given initial storage.
69    pub fn with_storage(storage: S) -> Self {
70        Self {
71            cells: Default::default(),
72            version: AtomicU32::new(START_VERSION),
73            next_cell: AtomicU32::new(0),
74            cell_locks: Default::default(),
75            storage,
76        }
77    }
78
79    /// Retrieve an immutable reference to this `Db`'s storage
80    pub fn storage(&self) -> &S {
81        &self.storage
82    }
83
84    /// Retrieve a mutable reference to this `Db`'s storage.
85    ///
86    /// Note that any mutations made to the storage using this are _not_ tracked by the `Db`!
87    /// Using this incorrectly may break correctness!
88    pub fn storage_mut(&mut self) -> &mut S {
89        &mut self.storage
90    }
91}
92
93impl<S: Storage> Db<S> {
94    /// Return the corresponding Cell for a given computation, if it exists.
95    ///
96    /// This will not update any values.
97    fn get_cell<C: Computation>(&self, computation: &C) -> Option<Cell>
98    where
99        S: StorageFor<C>,
100    {
101        self.storage.get_cell_for_computation(computation)
102    }
103
104    pub(crate) fn get_or_insert_cell<C>(&self, input: C) -> Cell
105    where
106        C: Computation,
107        S: StorageFor<C>,
108    {
109        let computation_id = C::computation_id();
110        let lock = self.cell_locks.entry(computation_id).or_default().clone();
111        let _guard = lock.lock();
112
113        if let Some(cell) = self.get_cell(&input) {
114            cell
115        } else {
116            // We just need a unique ID here, we don't care about ordering between
117            // threads, so we're using Ordering::Relaxed.
118            let cell_id = self.next_cell.fetch_add(1, Ordering::Relaxed);
119            let new_cell = Cell::new(cell_id);
120
121            self.cells.insert(new_cell, CellData::new(computation_id));
122            self.storage.insert_new_cell(new_cell, input);
123            new_cell
124        }
125    }
126
127    fn handle(&self, cell: Cell) -> DbHandle<'_, S> {
128        DbHandle::new(self, cell)
129    }
130
131    #[cfg(test)]
132    #[allow(unused)]
133    pub(crate) fn with_cell_data<C: Computation>(&self, input: &C, f: impl FnOnce(&CellData))
134    where
135        S: StorageFor<C>,
136    {
137        let cell = self
138            .get_cell(input)
139            .unwrap_or_else(|| panic!("unwrap_cell_value: Expected cell to exist"));
140
141        self.cells.get(&cell).map(|value| f(&value)).unwrap()
142    }
143
144    pub fn version(&self) -> u32 {
145        self.version.load(Ordering::SeqCst)
146    }
147
148    pub fn gc(&mut self, version: u32) {
149        let used_cells: std::collections::HashSet<Cell> = self
150            .cells
151            .iter()
152            .filter_map(|entry| {
153                if entry.value().last_verified_version >= version {
154                    Some(entry.key().clone())
155                } else {
156                    None
157                }
158            })
159            .collect();
160
161        self.storage.gc(&used_cells);
162    }
163}
164
165impl<S: Storage> Db<S> {
166    /// Updates an input with a new value
167    ///
168    /// This requires an exclusive reference to self to ensure that there are no currently
169    /// running queries. Updating an input while an incremental computation is occurring
170    /// can break soundness for dependency tracking.
171    ///
172    /// Panics if the given computation is not an input - ie. panics if it has at least 1 dependency.
173    pub fn update_input<C>(&mut self, input: C, new_value: C::Output)
174    where
175        C: Computation,
176        S: StorageFor<C>,
177    {
178        let cell_id = self.get_or_insert_cell(input);
179        assert!(
180            self.is_input(cell_id),
181            "`update_input` given a non-input value. Inputs must have 0 dependencies",
182        );
183
184        let changed = self.storage.update_output(cell_id, new_value);
185        let mut cell = self.cells.get_mut(&cell_id).unwrap();
186
187        if changed {
188            let version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
189            cell.last_updated_version = version;
190            cell.last_verified_version = version;
191        } else {
192            cell.last_verified_version = self.version.load(Ordering::SeqCst);
193        }
194    }
195
196    fn is_input(&self, cell: Cell) -> bool {
197        self.with_cell(cell, |cell| {
198            cell.dependencies.is_empty() && cell.input_dependencies.is_empty()
199        })
200    }
201
202    /// True if a given computation is stale and needs to be re-computed.
203    /// Computations which have never been computed are also considered stale.
204    ///
205    /// Note that this may re-compute dependencies of the given computation.
206    pub fn is_stale<C: Computation>(&self, input: &C) -> bool
207    where
208        S: StorageFor<C>,
209    {
210        // If the cell doesn't exist, it is definitely stale
211        let Some(cell) = self.get_cell(input) else {
212            return true;
213        };
214        self.is_stale_cell(cell)
215    }
216
217    /// True if a given cell is stale and needs to be re-computed.
218    ///
219    /// Note that this may re-compute some input
220    fn is_stale_cell(&self, cell: Cell) -> bool {
221        let computation_id = self.with_cell(cell, |data| data.computation_id);
222
223        if self.storage.output_is_unset(cell, computation_id) {
224            return true;
225        }
226
227        // if any input dependency has changed, this cell is stale
228        let (last_verified, inputs, dependencies) = self.with_cell(cell, |data| {
229            (
230                data.last_verified_version,
231                data.input_dependencies.clone(),
232                data.dependencies.clone(),
233            )
234        });
235
236        // Optimization: only recursively check all dependencies if any
237        // of the inputs this cell depends on have changed
238        let inputs_changed = inputs.into_iter().any(|input_id| {
239            // This cell is stale if the dependency has been updated since
240            // we last verified this cell
241            self.with_cell(input_id, |input| input.last_updated_version > last_verified)
242        });
243
244        // Dependencies need to be iterated in the order they were computed.
245        // Otherwise we may re-run a computation which does not need to be re-run.
246        // In the worst case this could even lead to panics - see the div0 test.
247        inputs_changed
248            && dependencies.into_iter().any(|dependency_id| {
249                self.update_cell(dependency_id);
250                self.with_cell(dependency_id, |dependency| {
251                    if computation_id == ACCUMULATED_COMPUTATION_ID {
252                        dependency.last_run_version > last_verified
253                    } else {
254                        dependency.last_updated_version > last_verified
255                    }
256                })
257            })
258    }
259
260    /// Similar to `update_input` but runs the compute function
261    /// instead of accepting a given value. This also will not update
262    /// `self.version`
263    fn run_compute_function(&self, cell_id: Cell) {
264        let computation_id = self.with_cell(cell_id, |data| data.computation_id);
265        self.storage.clear_accumulated_for_cell(cell_id);
266        let handle = self.handle(cell_id);
267        let changed = S::run_computation(&handle, cell_id, computation_id);
268
269        let version = self.version.load(Ordering::SeqCst);
270        let mut cell = self.cells.get_mut(&cell_id).unwrap();
271        cell.last_verified_version = version;
272        cell.last_run_version = version;
273
274        if changed {
275            cell.last_updated_version = version;
276        }
277    }
278
279    /// Trigger an update of the given cell, recursively checking and re-running any out of date
280    /// dependencies.
281    fn update_cell(&self, cell_id: Cell) {
282        let last_verified_version = self.with_cell(cell_id, |data| data.last_verified_version);
283        let version = self.version.load(Ordering::SeqCst);
284
285        if last_verified_version != version {
286            // if any dependency may have changed, update
287            if self.is_stale_cell(cell_id) {
288                let lock = self.with_cell(cell_id, |cell| cell.lock.clone());
289
290                match lock.try_lock() {
291                    Some(guard) => {
292                        self.run_compute_function(cell_id);
293                        drop(guard);
294                    }
295                    None => {
296                        // This computation is already being run in another thread.
297                        // Before blocking and waiting, since we have time, check for a cycle and
298                        // issue and panic if found.
299                        self.check_for_cycle(cell_id);
300
301                        // Block until it finishes and return the result
302                        drop(lock.lock());
303                    }
304                }
305            } else {
306                let mut cell = self.cells.get_mut(&cell_id).unwrap();
307                cell.last_verified_version = version;
308            }
309        }
310    }
311
312    /// Perform a DFS to check for a cycle, panicking if found
313    fn check_for_cycle(&self, starting_cell: Cell) {
314        let mut visited = FxHashSet::default();
315        let mut path = Vec::new();
316
317        // We're going to push actions to this stack. Most actions will be pushing
318        // a dependency cell to track as the next node in the graph, but some will be
319        // pop actions for popping the top node off the current path. If we encounter
320        // a node which is already in the current path, we have found a cycle.
321        let mut stack = Vec::new();
322        stack.push(Action::Traverse(starting_cell));
323
324        enum Action {
325            Traverse(Cell),
326            Pop(Cell),
327        }
328
329        while let Some(action) = stack.pop() {
330            match action {
331                // This assert_eq is never expected to fail
332                Action::Pop(expected) => assert_eq!(path.pop(), Some(expected)),
333                Action::Traverse(cell) => {
334                    if path.contains(&cell) {
335                        // Include the same cell twice so the cycle is more clear to users
336                        path.push(cell);
337                        self.cycle_error(&path);
338                    }
339
340                    if visited.insert(cell) {
341                        path.push(cell);
342                        stack.push(Action::Pop(cell));
343                        self.with_cell(cell, |cell| {
344                            for dependency in cell.dependencies.iter() {
345                                stack.push(Action::Traverse(*dependency));
346                            }
347                        });
348                    }
349                }
350            }
351        }
352    }
353
354    /// Issue an error with the given cycle
355    fn cycle_error(&self, cycle: &[Cell]) {
356        let mut error = String::new();
357        for (i, cell) in cycle.iter().enumerate() {
358            error += &format!(
359                "\n  {}. {}",
360                i + 1,
361                self.storage.input_debug_string(self, *cell)
362            );
363        }
364        panic!("inc-complete: Cycle Detected!\n\nCycle:{error}")
365    }
366
367    /// Retrieves the up to date value for the given computation, re-running any dependencies as
368    /// necessary.
369    ///
370    /// This function can panic if the dynamic type of the value returned by `compute.run(..)` is not `T`.
371    ///
372    /// Locking behavior: This function locks the cell corresponding to the given computation. This
373    /// can cause a deadlock if the computation recursively depends on itself.
374    pub fn get<C: Computation>(&self, compute: C) -> C::Output
375    where
376        S: StorageFor<C>,
377    {
378        let cell_id = self.get_or_insert_cell(compute);
379        self.get_with_cell::<C>(cell_id)
380    }
381
382    pub(crate) fn get_with_cell<Concrete: Computation>(&self, cell_id: Cell) -> Concrete::Output
383    where
384        S: StorageFor<Concrete>,
385    {
386        self.update_cell(cell_id);
387
388        self.storage
389            .get_output(cell_id)
390            .expect("cell result should have been computed already")
391    }
392
393    fn with_cell<R>(&self, cell: Cell, f: impl FnOnce(&CellData) -> R) -> R {
394        f(&self.cells.get(&cell).unwrap())
395    }
396
397    /// Retrieve each accumulated value of the given type after the given computation is run.
398    ///
399    /// This is most often used for operations like retrieving diagnostics or logs.
400    ///
401    /// Compared to [Db::get_accumulated_uncached], this version reuses the normal flow for
402    /// queries and thus saves accumulated values for each intermediate query. This involves
403    /// more synching and data duplication but can be beneficial if intermediate results
404    /// ever need to be reused, e.g. if you call [Db::get_accumulated] in a loop where each
405    /// call may share dependencies. If you already have a single query which emits all the
406    /// accumulated values you need, [Db::get_accumulated_uncached] is likely faster, but
407    /// requires a `&mut Db`.
408    pub fn get_accumulated<Item, C>(&self, compute: C) -> BTreeSet<Item>
409    where
410        S: StorageFor<C> + StorageFor<Accumulated<Item>>,
411        C: Computation,
412        Item: 'static,
413    {
414        let cell_id = self.get_or_insert_cell(compute);
415        self.update_cell(cell_id);
416        self.get(Accumulated::<Item>::new(cell_id))
417    }
418
419    /// Retrieve each accumulated value of the given type after the given computation is run.
420    ///
421    /// This is most often used for operations like retrieving diagnostics or logs.
422    ///
423    /// This is a faster version of [Db::get_accumulated] for some use-cases. This version tends to be
424    /// more efficient when you already have a single query which emits all the accumulated values
425    /// you need, while the original [Db::get_accumulated] is more efficient when you have many
426    /// smaller calls since it avoids duplicated work and is safe to call with only a [DbHandle].
427    pub fn get_accumulated_uncached<Item, C>(&mut self, compute: C) -> BTreeSet<Item>
428    where
429        S: StorageFor<C> + StorageFor<Accumulated<Item>> + Accumulate<Item>,
430        C: Computation,
431        Item: 'static + Ord,
432    {
433        let cell_id = self.get_or_insert_cell(compute);
434        self.update_cell(cell_id);
435
436        let mut items = BTreeSet::new();
437        let mut visited = BTreeSet::new();
438        let mut queue = vec![cell_id];
439
440        while let Some(cell) = queue.pop() {
441            if visited.insert(cell) {
442                self.with_cell(cell, |data| queue.extend_from_slice(&data.dependencies));
443                items.extend(self.storage().get_accumulated::<Vec<Item>>(cell));
444            }
445        }
446
447        items
448    }
449}