inc_complete/
lib.rs

1use std::{any::Any, collections::HashMap, hash::{ Hasher, Hash }};
2
3use petgraph::graph::DiGraph;
4
/// Version number a freshly created [`Db`] starts at.
///
/// New cells are stamped with version `0`, so a cell that has never been
/// verified always differs from any version the database can be at.
const START_VERSION: u32 = 1;
6
7pub struct Db<F> {
8    cells: DiGraph<CellValue<F>, ()>,
9    version: u32,
10
11    input_to_cell: HashMap<F, Cell>,
12}
13
14#[derive(Copy, Clone)]
15pub struct Cell(petgraph::graph::NodeIndex);
16
/// Cached state of a single computation stored in the database graph.
struct CellValue<F> {
    /// The computation this cell caches.
    compute: F,
    /// Hash of the latest result paired with the type-erased result itself;
    /// `None` until the computation has produced a value.
    result: Option<(u64, Box<dyn Any>)>,

    /// Database version at which `result` last received a new value.
    last_updated_version: u32,
    /// Database version at which `result` was last confirmed to be current.
    last_verified_version: u32,
}

impl<F> CellValue<F> {
    /// Builds an empty cell for `compute`: no cached result, and both version
    /// stamps set to `0`.
    fn new(compute: F) -> Self {
        // Version stamp meaning "this has not happened yet".
        const NEVER: u32 = 0;

        CellValue {
            compute,
            result: None,
            last_updated_version: NEVER,
            last_verified_version: NEVER,
        }
    }
}
35
36pub trait Run: Sized {
37    fn run(self, db: &mut Db<Self>) -> impl Any + Hash;
38}
39
40impl<F> Db<F> {
41    pub fn new() -> Self {
42        Self {
43            cells: DiGraph::default(),
44            input_to_cell: HashMap::default(),
45            version: START_VERSION,
46        }
47    }
48}
49
50impl<F: Run + Copy + Eq + Hash + Clone> Db<F> {
51    pub fn get<'a, T: 'static>(&'a mut self, compute: F) -> &'a T {
52        let cell_id = self.input(compute);
53        let cell = &self.cells[cell_id.0];
54
55        if cell.last_verified_version != self.version {
56            if cell.result.is_some() {
57                let neighbors = self.cells.neighbors(cell_id.0).collect::<Vec<_>>();
58
59                if neighbors.into_iter().any(|input_id| {
60                    let input = &self.cells[input_id];
61                    let cell = &self.cells[cell_id.0];
62                    let need_to_update = input.last_updated_version > cell.last_verified_version;
63                    if need_to_update {
64                        self.update_cell(Cell(input_id));
65                    }
66                    need_to_update
67                }) {
68                    self.update_cell(cell_id);
69                } else {
70                    let cell = &mut self.cells[cell_id.0];
71                    cell.last_verified_version = self.version;
72                }
73            } else /* cell.result is None, initialize it */ {
74                self.update_cell(cell_id);
75            }
76        }
77
78        let cell = &self.cells[cell_id.0];
79        let result = cell.result.as_ref()
80            .expect("cell result should have been computed already").1.as_ref();
81
82        result
83            .downcast_ref()
84            .expect("Output type to `Db::get` does not match the type of the value returned by the `Run::run` function")
85    }
86
87    pub fn input(&mut self, input: F) -> Cell {
88        if let Some(cell_id) = self.input_to_cell.get(&input) {
89            *cell_id
90        } else {
91            self.version += 1;
92            let new_id = self.cells.add_node(CellValue::new(input.clone()));
93            let cell = Cell(new_id);
94            self.input_to_cell.insert(input, cell);
95            cell
96        }
97    }
98
99    #[cfg(test)]
100    fn get_cell(&mut self, input: F) -> &CellValue<F> {
101        let cell = self.input(input);
102        &self.cells[cell.0]
103    }
104
105    fn update_cell(&mut self, cell_id: Cell) {
106        let cell = &self.cells[cell_id.0];
107        let result = cell.compute.run(self);
108        let new_hash = hash(&result);
109        let cell = &mut self.cells[cell_id.0];
110
111        if let Some((old_hash, _)) = cell.result.as_ref() {
112            if new_hash == *old_hash {
113                cell.last_verified_version = self.version;
114                return;
115            }
116        }
117
118        cell.result = Some((new_hash, Box::new(result)));
119        cell.last_verified_version = self.version;
120        cell.last_updated_version = self.version;
121    }
122}
123
// TODO: Use stable hash
/// Hashes `x` with a default-constructed `std::hash::DefaultHasher` and
/// returns the resulting 64-bit digest.
fn hash<T: Hash>(x: T) -> u64 {
    let mut state = std::hash::DefaultHasher::default();
    Hash::hash(&x, &mut state);
    state.finish()
}
130
#[cfg(test)]
mod tests {
    use crate::{Db, Run, START_VERSION};

    /// Cells of a one-column spreadsheet used by every test in this module:
    ///
    /// ```text
    ///      A
    /// 1 [ =20 ]
    /// 2 [ =A1 + 1 ]
    /// 3 [ =A2 + 2 ]
    /// ```
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    enum Basic {
        A1,
        A2,
        A3,
    }

    impl Run for Basic {
        fn run(self, db: &mut Db<Self>) -> impl std::any::Any + std::hash::Hash {
            match self {
                Basic::A1 => 20,
                Basic::A2 => db.get::<i32>(Basic::A1) + 1,
                Basic::A3 => db.get::<i32>(Basic::A2) + 2,
            }
        }
    }

    /// A chain of dependent cells evaluates to the expected value.
    #[test]
    fn basic() {
        let mut db = Db::new();
        let a3 = *db.get::<i32>(Basic::A3);
        assert_eq!(a3, 23);
    }

    /// Asking for the same cell twice returns the same value and leaves the
    /// version bookkeeping exactly where the first evaluation put it.
    #[test]
    fn no_recompute_basic() {
        let mut db = Db::new();
        let first = *db.get::<i32>(Basic::A3);
        let second = *db.get::<i32>(Basic::A3);
        assert_eq!(first, 23);
        assert_eq!(second, 23);

        // Expect 3 updates from caching A3, A2, A1
        let expected_version = START_VERSION + 3;
        assert_eq!(db.version, expected_version);

        for input in [Basic::A1, Basic::A2, Basic::A3] {
            let cell = db.get_cell(input);
            assert_eq!(cell.last_updated_version, expected_version);
            assert_eq!(cell.last_verified_version, expected_version);
        }
    }
}