Skip to main content

inc_complete/db/
mod.rs

1use std::collections::BTreeSet;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU32, Ordering};
4
5use crate::accumulate::Accumulated;
6use crate::cell::CellData;
7use crate::storage::StorageFor;
8use crate::{Cell, Computation, Storage};
9
10pub mod debug_with_db;
11mod handle;
12mod serialize;
13mod tests;
14
15pub use handle::DbHandle;
16use parking_lot::Mutex;
17use rustc_hash::FxHashSet;
18
/// Version number a freshly constructed `Db` starts at (used by `Db::with_storage`).
/// `update_input` increments the version by one each time an input's value actually changes.
// NOTE(review): presumably starts at 1 rather than 0 so that 0 can mean "never verified" —
// confirm against `CellData::new`.
const START_VERSION: u32 = 1;
20
21/// The central database object to manage and cache incremental computations.
22///
23/// To use this, a type implementing `Storage` is required to be provided.
24/// See the documentation for `impl_storage!`.
25pub struct Db<Storage> {
26    cells: dashmap::DashMap<Cell, CellData>,
27    version: AtomicU32,
28    next_cell: AtomicU32,
29    storage: Storage,
30
31    /// Lock used when acquiring new Cells to ensure the same data isn't assigned
32    /// multiple ids concurrently. Maps computation_id to each lock.
33    cell_locks: dashmap::DashMap<u32, Arc<Mutex<()>>>,
34}
35
36impl<Storage: Default> Db<Storage> {
37    /// Construct a new `Db` object using `Default::default()` for the initial storage.
38    pub fn new() -> Self {
39        Self::with_storage(Storage::default())
40    }
41}
42
43impl<S: Default> Default for Db<S> {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49/// Abstracts over the `get` function provided by `Db<S>` and `DbHandle<S>` to avoid
50/// providing `get` and `get_db` variants for each function.
51pub trait DbGet<C: Computation> {
52    /// Run an incremental computation `C` and return its output.
53    /// If `C` is already cached, no computation will be performed.
54    fn get(&self, key: C) -> C::Output;
55}
56
57impl<S, C> DbGet<C> for Db<S>
58where
59    C: Computation,
60    S: Storage + StorageFor<C>,
61{
62    fn get(&self, key: C) -> C::Output {
63        self.get(key)
64    }
65}
66
67impl<S> Db<S> {
68    /// Construct a new `Db` object with the given initial storage.
69    pub fn with_storage(storage: S) -> Self {
70        Self {
71            cells: Default::default(),
72            version: AtomicU32::new(START_VERSION),
73            next_cell: AtomicU32::new(0),
74            cell_locks: Default::default(),
75            storage,
76        }
77    }
78
79    /// Retrieve an immutable reference to this `Db`'s storage
80    pub fn storage(&self) -> &S {
81        &self.storage
82    }
83
84    /// Retrieve a mutable reference to this `Db`'s storage.
85    ///
86    /// Note that any mutations made to the storage using this are _not_ tracked by the `Db`!
87    /// Using this incorrectly may break correctness!
88    pub fn storage_mut(&mut self) -> &mut S {
89        &mut self.storage
90    }
91}
92
93impl<S: Storage> Db<S> {
94    /// Return the corresponding Cell for a given computation, if it exists.
95    ///
96    /// This will not update any values.
97    fn get_cell<C: Computation>(&self, computation: &C) -> Option<Cell>
98    where
99        S: StorageFor<C>,
100    {
101        self.storage.get_cell_for_computation(computation)
102    }
103
104    pub(crate) fn get_or_insert_cell<C>(&self, input: C) -> Cell
105    where
106        C: Computation,
107        S: StorageFor<C>,
108    {
109        let computation_id = C::computation_id();
110        let lock = self.cell_locks.entry(computation_id).or_default().clone();
111        let _guard = lock.lock();
112
113        if let Some(cell) = self.get_cell(&input) {
114            cell
115        } else {
116            // We just need a unique ID here, we don't care about ordering between
117            // threads, so we're using Ordering::Relaxed.
118            let cell_id = self.next_cell.fetch_add(1, Ordering::Relaxed);
119            let new_cell = Cell::new(cell_id);
120
121            self.cells.insert(new_cell, CellData::new(computation_id));
122            self.storage.insert_new_cell(new_cell, input);
123            new_cell
124        }
125    }
126
127    fn handle(&self, cell: Cell) -> DbHandle<'_, S> {
128        DbHandle::new(self, cell)
129    }
130
131    #[cfg(test)]
132    #[allow(unused)]
133    pub(crate) fn with_cell_data<C: Computation>(&self, input: &C, f: impl FnOnce(&CellData))
134    where
135        S: StorageFor<C>,
136    {
137        let cell = self
138            .get_cell(input)
139            .unwrap_or_else(|| panic!("unwrap_cell_value: Expected cell to exist"));
140
141        self.cells.get(&cell).map(|value| f(&value)).unwrap()
142    }
143
144    pub fn version(&self) -> u32 {
145        self.version.load(Ordering::SeqCst)
146    }
147
148    pub fn gc(&mut self, version: u32) {
149        let used_cells: std::collections::HashSet<Cell> = self
150            .cells
151            .iter()
152            .filter_map(|entry| {
153                if entry.value().last_verified_version >= version {
154                    Some(entry.key().clone())
155                } else {
156                    None
157                }
158            })
159            .collect();
160
161        self.storage.gc(&used_cells);
162    }
163}
164
165impl<S: Storage> Db<S> {
166    /// Updates an input with a new value
167    ///
168    /// This requires an exclusive reference to self to ensure that there are no currently
169    /// running queries. Updating an input while an incremental computation is occurring
170    /// can break soundness for dependency tracking.
171    ///
172    /// Panics if the given computation is not an input - ie. panics if it has at least 1 dependency.
173    pub fn update_input<C>(&mut self, input: C, new_value: C::Output)
174    where
175        C: Computation,
176        S: StorageFor<C>,
177    {
178        let cell_id = self.get_or_insert_cell(input);
179        assert!(
180            self.is_input(cell_id),
181            "`update_input` given a non-input value. Inputs must have 0 dependencies",
182        );
183
184        let changed = self.storage.update_output(cell_id, new_value);
185        let mut cell = self.cells.get_mut(&cell_id).unwrap();
186
187        if changed {
188            let version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
189            cell.last_updated_version = version;
190            cell.last_verified_version = version;
191        } else {
192            cell.last_verified_version = self.version.load(Ordering::SeqCst);
193        }
194    }
195
196    fn is_input(&self, cell: Cell) -> bool {
197        self.with_cell(cell, |cell| {
198            cell.dependencies.is_empty() && cell.input_dependencies.is_empty()
199        })
200    }
201
202    /// True if a given computation is stale and needs to be re-computed.
203    /// Computations which have never been computed are also considered stale.
204    ///
205    /// Note that this may re-compute dependencies of the given computation.
206    pub fn is_stale<C: Computation>(&self, input: &C) -> bool
207    where
208        S: StorageFor<C>,
209    {
210        // If the cell doesn't exist, it is definitely stale
211        let Some(cell) = self.get_cell(input) else {
212            return true;
213        };
214        self.is_stale_cell(cell)
215    }
216
217    /// True if a given cell is stale and needs to be re-computed.
218    ///
219    /// Note that this may re-compute some input
220    fn is_stale_cell(&self, cell: Cell) -> bool {
221        let computation_id = self.with_cell(cell, |data| data.computation_id);
222
223        if self.storage.output_is_unset(cell, computation_id) {
224            return true;
225        }
226
227        // if any input dependency has changed, this cell is stale
228        let (last_verified, inputs, dependencies) = self.with_cell(cell, |data| {
229            (
230                data.last_verified_version,
231                data.input_dependencies.clone(),
232                data.dependencies.clone(),
233            )
234        });
235
236        // Optimization: only recursively check all dependencies if any
237        // of the inputs this cell depends on have changed
238        let inputs_changed = inputs.into_iter().any(|input_id| {
239            // This cell is stale if the dependency has been updated since
240            // we last verified this cell
241            self.with_cell(input_id, |input| input.last_updated_version > last_verified)
242        });
243
244        // Dependencies need to be iterated in the order they were computed.
245        // Otherwise we may re-run a computation which does not need to be re-run.
246        // In the worst case this could even lead to panics - see the div0 test.
247        inputs_changed
248            && dependencies.into_iter().any(|dependency_id| {
249                self.update_cell(dependency_id);
250                self.with_cell(dependency_id, |dependency| {
251                    dependency.last_updated_version > last_verified
252                })
253            })
254    }
255
256    /// Similar to `update_input` but runs the compute function
257    /// instead of accepting a given value. This also will not update
258    /// `self.version`
259    fn run_compute_function(&self, cell_id: Cell) {
260        let computation_id = self.with_cell(cell_id, |data| data.computation_id);
261        self.storage.clear_accumulated_for_cell(cell_id);
262        let handle = self.handle(cell_id);
263        let changed = S::run_computation(&handle, cell_id, computation_id);
264
265        let version = self.version.load(Ordering::SeqCst);
266        let mut cell = self.cells.get_mut(&cell_id).unwrap();
267        cell.last_verified_version = version;
268
269        if changed {
270            cell.last_updated_version = version;
271        }
272    }
273
274    /// Trigger an update of the given cell, recursively checking and re-running any out of date
275    /// dependencies.
276    fn update_cell(&self, cell_id: Cell) {
277        let last_verified_version = self.with_cell(cell_id, |data| data.last_verified_version);
278        let version = self.version.load(Ordering::SeqCst);
279
280        if last_verified_version != version {
281            // if any dependency may have changed, update
282            if self.is_stale_cell(cell_id) {
283                let lock = self.with_cell(cell_id, |cell| cell.lock.clone());
284
285                match lock.try_lock() {
286                    Some(guard) => {
287                        self.run_compute_function(cell_id);
288                        drop(guard);
289                    }
290                    None => {
291                        // This computation is already being run in another thread.
292                        // Before blocking and waiting, since we have time, check for a cycle and
293                        // issue and panic if found.
294                        self.check_for_cycle(cell_id);
295
296                        // Block until it finishes and return the result
297                        drop(lock.lock());
298                    }
299                }
300            } else {
301                let mut cell = self.cells.get_mut(&cell_id).unwrap();
302                cell.last_verified_version = version;
303            }
304        }
305    }
306
307    /// Perform a DFS to check for a cycle, panicking if found
308    fn check_for_cycle(&self, starting_cell: Cell) {
309        let mut visited = FxHashSet::default();
310        let mut path = Vec::new();
311
312        // We're going to push actions to this stack. Most actions will be pushing
313        // a dependency cell to track as the next node in the graph, but some will be
314        // pop actions for popping the top node off the current path. If we encounter
315        // a node which is already in the current path, we have found a cycle.
316        let mut stack = Vec::new();
317        stack.push(Action::Traverse(starting_cell));
318
319        enum Action {
320            Traverse(Cell),
321            Pop(Cell),
322        }
323
324        while let Some(action) = stack.pop() {
325            match action {
326                // This assert_eq is never expected to fail
327                Action::Pop(expected) => assert_eq!(path.pop(), Some(expected)),
328                Action::Traverse(cell) => {
329                    if path.contains(&cell) {
330                        // Include the same cell twice so the cycle is more clear to users
331                        path.push(cell);
332                        self.cycle_error(&path);
333                    }
334
335                    if visited.insert(cell) {
336                        path.push(cell);
337                        stack.push(Action::Pop(cell));
338                        self.with_cell(cell, |cell| {
339                            for dependency in cell.dependencies.iter() {
340                                stack.push(Action::Traverse(*dependency));
341                            }
342                        });
343                    }
344                }
345            }
346        }
347    }
348
349    /// Issue an error with the given cycle
350    fn cycle_error(&self, cycle: &[Cell]) {
351        let mut error = String::new();
352        for (i, cell) in cycle.iter().enumerate() {
353            error += &format!(
354                "\n  {}. {}",
355                i + 1,
356                self.storage.input_debug_string(self, *cell)
357            );
358        }
359        panic!("inc-complete: Cycle Detected!\n\nCycle:{error}")
360    }
361
362    /// Retrieves the up to date value for the given computation, re-running any dependencies as
363    /// necessary.
364    ///
365    /// This function can panic if the dynamic type of the value returned by `compute.run(..)` is not `T`.
366    ///
367    /// Locking behavior: This function locks the cell corresponding to the given computation. This
368    /// can cause a deadlock if the computation recursively depends on itself.
369    pub fn get<C: Computation>(&self, compute: C) -> C::Output
370    where
371        S: StorageFor<C>,
372    {
373        let cell_id = self.get_or_insert_cell(compute);
374        self.get_with_cell::<C>(cell_id)
375    }
376
377    pub(crate) fn get_with_cell<Concrete: Computation>(&self, cell_id: Cell) -> Concrete::Output
378    where
379        S: StorageFor<Concrete>,
380    {
381        self.update_cell(cell_id);
382
383        self.storage
384            .get_output(cell_id)
385            .expect("cell result should have been computed already")
386    }
387
388    fn with_cell<R>(&self, cell: Cell, f: impl FnOnce(&CellData) -> R) -> R {
389        f(&self.cells.get(&cell).unwrap())
390    }
391
392    /// Retrieve each accumulated value of the given type after the given computation is run.
393    /// Subsequent calls to this for the same computation or dependencies will be cached.
394    ///
395    /// This is most often used for operations like retrieving diagnostics or logs.
396    pub fn get_accumulated<Item, C>(&self, compute: C) -> BTreeSet<Item>
397    where
398        S: StorageFor<C> + StorageFor<Accumulated<Item>>,
399        C: Computation,
400        Item: 'static,
401    {
402        let cell_id = self.get_or_insert_cell(compute);
403        self.update_cell(cell_id);
404        self.get(Accumulated::<Item>::new(cell_id))
405    }
406}