inc_complete/
lib.rs

1use std::{collections::HashMap, hash::Hash};
2
3use db_handle::DbHandle;
4use petgraph::graph::DiGraph;
5
6mod cell;
7mod db_handle;
8mod value;
9
10use value::HashEqObj;
11pub use value::Run;
12pub use value::Value;
13
/// Version a freshly created `Db` starts at.
///
/// `CellValue::new` initializes both of a cell's version stamps to 0, which is below this
/// value, so a brand-new cell never compares as verified at the database's current version.
const START_VERSION: u32 = 1;
15
/// An incremental computation database, generic over the computation type `F`.
///
/// Every distinct `F` value gets one cell in `cells`. Each cell caches its latest result along
/// with the versions at which that result last changed and was last verified.
pub struct Db<F> {
    /// One node per computation. `update_cell` treats a node's neighbors as its dependencies.
    cells: DiGraph<CellValue<F>, ()>,
    /// Current version. `update_input` increments it whenever an input's value hash changes.
    version: u32,

    /// Maps each computation to the cell that `Db::cell` allocated for it.
    input_to_cell: HashMap<F, Cell>,
}
22
23#[derive(Debug, Copy, Clone)]
24pub struct Cell(petgraph::graph::NodeIndex);
25
/// The state stored for one cell: the computation itself plus its cached result.
struct CellValue<F> {
    /// The computation this cell represents.
    compute: F,
    /// The hash of the latest value paired with the value itself; `None` until first computed
    /// or set.
    result: Option<(u64, Value)>,

    /// Database version at which `result` was last replaced by a value with a different hash.
    last_updated_version: u32,
    /// Database version at which `result` was last confirmed to be up to date.
    last_verified_version: u32,
}
33
34impl<F> CellValue<F> {
35    fn new(compute: F) -> Self {
36        Self {
37            compute,
38            result: None,
39            last_updated_version: 0,
40            last_verified_version: 0,
41        }
42    }
43}
44
45impl<F> Db<F> {
46    pub fn new() -> Self {
47        Self {
48            cells: DiGraph::default(),
49            input_to_cell: HashMap::default(),
50            version: START_VERSION,
51        }
52    }
53}
54
55impl<F: Run + Copy + Eq + Hash + Clone> Db<F> {
56    /// Retrieves the up to date value for the given computation, re-running any dependencies as
57    /// necessary.
58    ///
59    /// This function can panic if the dynamic type of the value returned by `compute.run(..)` is not `T`.
60    pub fn get<'a, T: 'static>(&'a mut self, compute: F) -> &'a T
61    where
62        F: std::fmt::Debug,
63    {
64        let cell_id = self.cell(compute);
65        self.update_cell(cell_id);
66
67        let cell = &self.cells[cell_id.0];
68        let result = &cell
69            .result
70            .as_ref()
71            .expect("cell result should have been computed already")
72            .1;
73
74        result.downcast_obj_ref()
75            .expect("Output type to `Db::get` does not match the type of the value returned by the `Run::run` function")
76    }
77
78    pub fn update_cell(&mut self, cell_id: Cell)
79    where
80        F: std::fmt::Debug,
81    {
82        let cell = &self.cells[cell_id.0];
83
84        if cell.last_verified_version != self.version {
85            if cell.result.is_some() {
86                let neighbors = self.cells.neighbors(cell_id.0).collect::<Vec<_>>();
87
88                // if any dependency may have changed, update
89                if neighbors.into_iter().any(|input_id| {
90                    let input = &self.cells[input_id];
91                    let cell = &self.cells[cell_id.0];
92                    let dependency_stale = input.last_verified_version != self.version
93                        || input.last_updated_version > cell.last_verified_version;
94
95                    dependency_stale
96                }) {
97                    self.run_compute_function(cell_id);
98                } else {
99                    let cell = &mut self.cells[cell_id.0];
100                    cell.last_verified_version = self.version;
101                }
102            } else {
103                // cell.result is None, initialize it
104                self.run_compute_function(cell_id);
105            }
106        }
107    }
108
109    pub fn cell(&mut self, input: F) -> Cell {
110        if let Some(cell_id) = self.input_to_cell.get(&input) {
111            *cell_id
112        } else {
113            let new_id = self.cells.add_node(CellValue::new(input.clone()));
114            let cell = Cell(new_id);
115            self.input_to_cell.insert(input, cell);
116            cell
117        }
118    }
119
120    /// Updates an input with a new value
121    ///
122    /// May panic in Debug mode if the input is not an input - ie. it has at least 1 dependency.
123    /// Note that this step is skipped when compiling in Release mode.
124    pub fn update_input(&mut self, input: F, new_value: Value)
125    where
126        F: std::fmt::Debug,
127    {
128        let cell_id = self.cell(input);
129        debug_assert!(
130            self.is_input(cell_id),
131            "`{input:?}` is not an input - inputs must have 0 dependencies"
132        );
133
134        let cell = &mut self.cells[cell_id.0];
135        let new_hash = new_value.get_hash();
136
137        if let Some((old_hash, _)) = cell.result.as_ref() {
138            if new_hash == *old_hash {
139                cell.last_verified_version = self.version;
140                return;
141            }
142        }
143
144        self.version += 1;
145        cell.result = Some((new_hash, new_value));
146        cell.last_updated_version = self.version;
147        cell.last_verified_version = self.version;
148    }
149
150    fn is_input(&self, cell: Cell) -> bool {
151        self.cells.neighbors(cell.0).count() == 0
152    }
153
154    #[cfg(test)]
155    fn get_cell(&mut self, input: F) -> &CellValue<F> {
156        let cell = self.cell(input);
157        &self.cells[cell.0]
158    }
159
160    /// Similar to `update_input` but runs the compute function
161    /// instead of accepting a given value. This also will not update
162    /// `self.version`
163    fn run_compute_function(&mut self, cell_id: Cell) {
164        let cell = &self.cells[cell_id.0];
165        let result = cell.compute.run(&mut self.handle(cell_id));
166        let new_hash = result.get_hash();
167        let cell = &mut self.cells[cell_id.0];
168
169        if let Some((old_hash, _)) = cell.result.as_ref() {
170            if new_hash == *old_hash {
171                cell.last_verified_version = self.version;
172                return;
173            }
174        }
175
176        cell.result = Some((new_hash, result));
177        cell.last_verified_version = self.version;
178        cell.last_updated_version = self.version;
179    }
180
181    fn handle(&mut self, cell: Cell) -> DbHandle<F> {
182        DbHandle::new(self, cell)
183    }
184}
185
#[cfg(test)]
mod tests {
    use crate::{Db, Run, START_VERSION, Value, db_handle::DbHandle};

    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    enum Basic {
        A1,
        A2,
        A3,
    }

    //      A
    // 1 [ =20 ]
    // 2 [ =A1 + 1 ]
    // 3 [ =A2 + 2 ]
    impl Run for Basic {
        fn run(self, db: &mut DbHandle<Self>) -> Value {
            match self {
                Basic::A1 => Value::new(20i32),
                Basic::A2 => Value::new(db.get::<i32>(Basic::A1) + 1i32),
                Basic::A3 => Value::new(db.get::<i32>(Basic::A2) + 2i32),
            }
        }
    }

    // Test that we can compute a basic chain of computation
    //      A
    // 1 [ =20 ]
    // 2 [ =A1 + 1 ]
    // 3 [ =A2 + 2 ]
    #[test]
    fn basic() {
        let mut db = Db::new();
        let result = *db.get::<i32>(Basic::A3);
        assert_eq!(result, 23);
    }

    // Test that we can re-use values from past runs
    //      A
    // 1 [ =20 ]
    // 2 [ =A1 + 1 ]
    // 3 [ =A2 + 2 ]
    #[test]
    fn no_recompute_basic() {
        let mut db = Db::new();
        let result1 = *db.get::<i32>(Basic::A3);
        let result2 = *db.get::<i32>(Basic::A3);
        assert_eq!(result1, 23);
        assert_eq!(result2, 23);

        // No input has been updated, so every cell must still carry the starting version
        let expected_version = START_VERSION;
        assert_eq!(db.version, expected_version);

        let a1 = db.get_cell(Basic::A1);
        assert_eq!(a1.last_updated_version, expected_version);
        assert_eq!(a1.last_verified_version, expected_version);

        let a2 = db.get_cell(Basic::A2);
        assert_eq!(a2.last_updated_version, expected_version);
        assert_eq!(a2.last_verified_version, expected_version);

        let a3 = db.get_cell(Basic::A3);
        assert_eq!(a3.last_updated_version, expected_version);
        assert_eq!(a3.last_verified_version, expected_version);
    }

    #[test]
    fn early_cutoff() {
        // Given:
        //  Numerator = 6
        //  Denominator = 2 (set through `update_input`; its compute function yields 0)
        //  Division = Numerator / Denominator
        //  DenominatorIs0 = Denominator == 0
        //  Result = if DenominatorIs0 { 0 } else { Division }
        //
        // We should expect a result of 3. When changing Denominator to 0,
        // we should avoid recalculating Division even though it was previously
        // a dependency of Result since Division is no longer required, and doing
        // so would result in a divide by zero error.
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
        enum EarlyCutoff {
            Numerator,
            Denominator,
            Division,
            DenominatorIs0,
            Result,
        }

        impl Run for EarlyCutoff {
            fn run(self, db: &mut DbHandle<Self>) -> Value {
                use EarlyCutoff::*;
                match self {
                    Numerator => Value::new(6i32),
                    Denominator => Value::new(0i32),
                    Division => Value::new(*db.get::<i32>(Numerator) / *db.get::<i32>(Denominator)),
                    DenominatorIs0 => Value::new(*db.get::<i32>(Denominator) == 0),
                    Result => {
                        if *db.get::<bool>(DenominatorIs0) {
                            Value::new(0i32)
                        } else {
                            Value::new(*db.get::<i32>(Division))
                        }
                    }
                }
            }
        }

        {
            // Run from scratch with Denominator = 0
            assert_eq!(0i32, *Db::new().get(EarlyCutoff::Result));
        }

        {
            // Start with Denominator = 2, then recompute with Denominator = 0
            let mut db = Db::new();
            assert_eq!(db.version, START_VERSION);
            db.update_input(EarlyCutoff::Denominator, Value::new(2i32));
            assert_eq!(db.version, START_VERSION + 1);

            // 6 / 2 = 3
            assert_eq!(3i32, *db.get(EarlyCutoff::Result));

            db.update_input(EarlyCutoff::Denominator, Value::new(0i32));
            assert_eq!(db.version, START_VERSION + 2);

            // Although Division was previously a dependency of Result,
            // we shouldn't update Division due to the `DenominatorIs0` changing as well,
            // leading us into a different branch where `Division` is no longer required.
            // If we did recalculate `Division` we would get a divide by zero error.
            //
            // Shouldn't get a divide by zero here
            assert_eq!(0i32, *db.get(EarlyCutoff::Result));
        }
    }
}