Skip to main content

inc_complete/db/
handle.rs

1use std::collections::BTreeSet;
2
3use crate::{
4    Cell, Computation, Db, Storage,
5    accumulate::{Accumulate, Accumulated},
6    storage::StorageFor,
7};
8
9use super::DbGet;
10
11/// A handle to the database during some operation.
12///
13/// This wraps calls to the Db so that any `get` calls
14/// will be automatically registered as dependencies of
15/// the current operation.
16pub struct DbHandle<'db, S> {
17    db: &'db Db<S>,
18    current_operation: Cell,
19}
20
21impl<'db, S> DbHandle<'db, S> {
22    pub(crate) fn new(db: &'db Db<S>, current_operation: Cell) -> Self {
23        // We're re-running a cell so remove any past dependencies
24        let mut cell = db.cells.get_mut(&current_operation).unwrap();
25
26        cell.dependencies.clear();
27        cell.input_dependencies.clear();
28
29        Self {
30            db,
31            current_operation,
32        }
33    }
34
35    /// Retrieve an immutable reference to this `Db`'s storage
36    ///
37    /// Note that any mutations made to the storage using this are _not_ tracked by the database!
38    /// Using this incorrectly may break correctness!
39    pub fn storage(&self) -> &S {
40        self.db.storage()
41    }
42}
43
44impl<S: Storage> DbHandle<'_, S> {
45    /// Locking behavior: This function locks the cell corresponding to the given computation. This
46    /// can cause a deadlock if the computation recursively depends on itself.
47    pub fn get<C: Computation>(&self, compute: C) -> C::Output
48    where
49        S: StorageFor<C>,
50    {
51        // Register the dependency
52        let dependency = self.db.get_or_insert_cell(compute);
53        self.update_and_register_dependency::<C>(dependency);
54
55        // Fetch the current value of the dependency, running it if out of date
56        self.db.get_with_cell(dependency)
57    }
58
59    /// Registers the given cell as a dependency, running it and updating any required metadata
60    fn update_and_register_dependency<C: Computation>(&self, dependency: Cell) {
61        self.update_and_register_dependency_inner(dependency, C::IS_INPUT);
62    }
63
64    fn update_and_register_dependency_inner(&self, dependency: Cell, is_input: bool) {
65        let mut cell = self.db.cells.get_mut(&self.current_operation).unwrap();
66
67        // If `dependency` is an input it must be remembered both as a dependency
68        // and as an input dependency. Otherwise we cannot differentiate between
69        // computations which directly depend on inputs and those that only indirectly
70        // depend on them.
71        cell.dependencies.push(dependency);
72        if is_input {
73            cell.input_dependencies.insert(dependency);
74        }
75
76        drop(cell);
77
78        // Run the computation to update its dependencies before we query them afterward
79        self.db.update_cell(dependency);
80
81        let dependency = self.db.cells.get(&dependency).unwrap();
82        let dependency_inputs = dependency.input_dependencies.clone();
83        drop(dependency);
84
85        let mut cell = self.db.cells.get_mut(&self.current_operation).unwrap();
86        for input in dependency_inputs {
87            cell.input_dependencies.insert(input);
88        }
89    }
90
91    /// Accumulate an item in the current computation. This item can be retrieved along
92    /// with all other accumulated items in this computation and its dependencies via
93    /// a call to `get_accumulated`.
94    ///
95    /// This is most often used for operations like pushing diagnostics or logs.
96    pub fn accumulate<Item>(&self, item: Item)
97    where
98        S: Accumulate<Item>,
99    {
100        self.storage().accumulate(self.current_operation, item);
101    }
102
103    /// Retrieve an accumulated value in a container of the user's choice.
104    /// This will return all the accumulated items after the given computation.
105    ///
106    /// This is most often used for operations like retrieving diagnostics or logs.
107    pub fn get_accumulated<Item, C>(&self, compute: C) -> BTreeSet<Item>
108    where
109        C: Computation,
110        Item: 'static + Ord,
111        S: StorageFor<Accumulated<Item>> + StorageFor<C> + Accumulate<Item>,
112    {
113        let dependency = self.db.get_or_insert_cell(compute);
114        self.get_accumulated_with_cell::<Item>(dependency)
115    }
116
117    /// Retrieve an accumulated value in a container of the user's choice.
118    /// This will return all the accumulated items after the given computation.
119    ///
120    /// This is the implementation of the publically accessible `db.get(Accumulated::<Item>(MyComputation))`.
121    ///
122    /// This is most often used for operations like retrieving diagnostics or logs.
123    pub(crate) fn get_accumulated_with_cell<Item>(&self, cell_id: Cell) -> BTreeSet<Item>
124    where
125        Item: 'static + Ord,
126        S: StorageFor<Accumulated<Item>> + Accumulate<Item>,
127    {
128        self.update_and_register_dependency_inner(cell_id, false);
129        let dependencies = self.db.with_cell(cell_id, |cell| cell.dependencies.clone());
130
131        // Collect `Accumulator` results from each dependency. This should also ensure we
132        // rerun this if any dependency changes, even if `cell_id` is updated such that it
133        // uses different dependencies but its output remains the same.
134        let computation_id = Accumulated::<Item>::computation_id();
135        let mut result: BTreeSet<Item> = dependencies
136            .into_iter()
137            // Filter out `Accumulated<Item>` cells from the dep list — they exist for staleness
138            // tracking only and must not be traversed for value collection, or we'd get duplicates.
139            .filter(|&dep| self.db.with_cell(dep, |cell| cell.computation_id) != computation_id)
140            .flat_map(|dependency| self.get(Accumulated::<Item>::new(dependency)))
141            .collect();
142
143        result.extend(self.storage().get_accumulated::<Vec<Item>>(cell_id));
144        result
145    }
146}
147
148impl<'db, S, C> DbGet<C> for DbHandle<'db, S>
149where
150    C: Computation,
151    S: Storage + StorageFor<C>,
152{
153    fn get(&self, key: C) -> C::Output {
154        self.get(key)
155    }
156}