1use std::{collections::HashMap, hash::Hash};
2
3use db_handle::DbHandle;
4use petgraph::graph::DiGraph;
5
6mod cell;
7mod db_handle;
8mod value;
9
10use value::HashEqObj;
11pub use value::Run;
12pub use value::Value;
13
/// Version a freshly created [`Db`] starts at.
///
/// Newly created cells are stamped with version 0, which is below this value, so a cell that
/// was never computed can never look verified at the current version.
const START_VERSION: u32 = 1_u32;
15
16pub struct Db<F> {
17 cells: DiGraph<CellValue<F>, ()>,
18 version: u32,
19
20 input_to_cell: HashMap<F, Cell>,
21}
22
23#[derive(Debug, Copy, Clone)]
24pub struct Cell(petgraph::graph::NodeIndex);
25
26struct CellValue<F> {
27 compute: F,
28 result: Option<(u64, Value)>,
29
30 last_updated_version: u32,
31 last_verified_version: u32,
32}
33
34impl<F> CellValue<F> {
35 fn new(compute: F) -> Self {
36 Self {
37 compute,
38 result: None,
39 last_updated_version: 0,
40 last_verified_version: 0,
41 }
42 }
43}
44
45impl<F> Db<F> {
46 pub fn new() -> Self {
47 Self {
48 cells: DiGraph::default(),
49 input_to_cell: HashMap::default(),
50 version: START_VERSION,
51 }
52 }
53}
54
55impl<F: Run + Copy + Eq + Hash + Clone> Db<F> {
56 pub fn get<'a, T: 'static>(&'a mut self, compute: F) -> &'a T
61 where
62 F: std::fmt::Debug,
63 {
64 let cell_id = self.cell(compute);
65 self.update_cell(cell_id);
66
67 let cell = &self.cells[cell_id.0];
68 let result = &cell
69 .result
70 .as_ref()
71 .expect("cell result should have been computed already")
72 .1;
73
74 result.downcast_obj_ref()
75 .expect("Output type to `Db::get` does not match the type of the value returned by the `Run::run` function")
76 }
77
78 pub fn update_cell(&mut self, cell_id: Cell)
79 where
80 F: std::fmt::Debug,
81 {
82 let cell = &self.cells[cell_id.0];
83
84 if cell.last_verified_version != self.version {
85 if cell.result.is_some() {
86 let neighbors = self.cells.neighbors(cell_id.0).collect::<Vec<_>>();
87
88 if neighbors.into_iter().any(|input_id| {
90 let input = &self.cells[input_id];
91 let cell = &self.cells[cell_id.0];
92 let dependency_stale = input.last_verified_version != self.version
93 || input.last_updated_version > cell.last_verified_version;
94
95 dependency_stale
96 }) {
97 self.run_compute_function(cell_id);
98 } else {
99 let cell = &mut self.cells[cell_id.0];
100 cell.last_verified_version = self.version;
101 }
102 } else {
103 self.run_compute_function(cell_id);
105 }
106 }
107 }
108
109 pub fn cell(&mut self, input: F) -> Cell {
110 if let Some(cell_id) = self.input_to_cell.get(&input) {
111 *cell_id
112 } else {
113 let new_id = self.cells.add_node(CellValue::new(input.clone()));
114 let cell = Cell(new_id);
115 self.input_to_cell.insert(input, cell);
116 cell
117 }
118 }
119
120 pub fn update_input(&mut self, input: F, new_value: Value)
125 where
126 F: std::fmt::Debug,
127 {
128 let cell_id = self.cell(input);
129 debug_assert!(
130 self.is_input(cell_id),
131 "`{input:?}` is not an input - inputs must have 0 dependencies"
132 );
133
134 let cell = &mut self.cells[cell_id.0];
135 let new_hash = new_value.get_hash();
136
137 if let Some((old_hash, _)) = cell.result.as_ref() {
138 if new_hash == *old_hash {
139 cell.last_verified_version = self.version;
140 return;
141 }
142 }
143
144 self.version += 1;
145 cell.result = Some((new_hash, new_value));
146 cell.last_updated_version = self.version;
147 cell.last_verified_version = self.version;
148 }
149
150 fn is_input(&self, cell: Cell) -> bool {
151 self.cells.neighbors(cell.0).count() == 0
152 }
153
154 #[cfg(test)]
155 fn get_cell(&mut self, input: F) -> &CellValue<F> {
156 let cell = self.cell(input);
157 &self.cells[cell.0]
158 }
159
160 fn run_compute_function(&mut self, cell_id: Cell) {
164 let cell = &self.cells[cell_id.0];
165 let result = cell.compute.run(&mut self.handle(cell_id));
166 let new_hash = result.get_hash();
167 let cell = &mut self.cells[cell_id.0];
168
169 if let Some((old_hash, _)) = cell.result.as_ref() {
170 if new_hash == *old_hash {
171 cell.last_verified_version = self.version;
172 return;
173 }
174 }
175
176 cell.result = Some((new_hash, result));
177 cell.last_verified_version = self.version;
178 cell.last_updated_version = self.version;
179 }
180
181 fn handle(&mut self, cell: Cell) -> DbHandle<F> {
182 DbHandle::new(self, cell)
183 }
184}
185
#[cfg(test)]
mod tests {
    use crate::{Db, Run, START_VERSION, Value, db_handle::DbHandle};

    /// Linear chain: A3 reads A2, which reads A1.
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    enum Basic {
        A1,
        A2,
        A3,
    }

    impl Run for Basic {
        fn run(self, db: &mut DbHandle<Self>) -> Value {
            match self {
                Basic::A1 => Value::new(20i32),
                Basic::A2 => Value::new(db.get::<i32>(Basic::A1) + 1i32),
                Basic::A3 => Value::new(db.get::<i32>(Basic::A2) + 2i32),
            }
        }
    }

    #[test]
    fn basic() {
        let mut db = Db::new();
        let result = *db.get::<i32>(Basic::A3);
        assert_eq!(result, 23);
    }

    #[test]
    fn no_recompute_basic() {
        let mut db = Db::new();
        let result1 = *db.get::<i32>(Basic::A3);
        let result2 = *db.get::<i32>(Basic::A3);
        assert_eq!(result1, 23);
        assert_eq!(result2, 23);

        // Nothing was updated, so the version never moved and every cell was computed once.
        let expected_version = START_VERSION;
        assert_eq!(db.version, expected_version);

        let a1 = db.get_cell(Basic::A1);
        assert_eq!(a1.last_updated_version, expected_version);
        assert_eq!(a1.last_verified_version, expected_version);

        let a2 = db.get_cell(Basic::A2);
        assert_eq!(a2.last_updated_version, expected_version);
        assert_eq!(a2.last_verified_version, expected_version);

        let a3 = db.get_cell(Basic::A3);
        assert_eq!(a3.last_updated_version, expected_version);
        assert_eq!(a3.last_verified_version, expected_version);
    }

    #[test]
    fn early_cutoff() {
        /// `Result` only reads `Division` when the denominator is non-zero, so `Division` must
        /// never be forced while the denominator is 0 (it would divide by zero).
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
        enum EarlyCutoff {
            Numerator,
            Denominator,
            Division,
            DenominatorIs0,
            Result,
        }

        impl Run for EarlyCutoff {
            fn run(self, db: &mut DbHandle<Self>) -> Value {
                use EarlyCutoff::*;
                match self {
                    Numerator => Value::new(6),
                    Denominator => Value::new(0),
                    Division => Value::new(*db.get::<i32>(Numerator) / *db.get::<i32>(Denominator)),
                    DenominatorIs0 => Value::new(*db.get::<i32>(Denominator) == 0),
                    Result => {
                        if *db.get(DenominatorIs0) {
                            Value::new(0i32)
                        } else {
                            Value::new(*db.get::<i32>(Division))
                        }
                    }
                }
            }
        }

        {
            // Default denominator is 0: `Division` is never evaluated.
            assert_eq!(0i32, *Db::new().get(EarlyCutoff::Result));
        }

        {
            let mut db = Db::new();
            assert_eq!(db.version, START_VERSION);
            db.update_input(EarlyCutoff::Denominator, Value::new(2i32));
            assert_eq!(db.version, START_VERSION + 1);

            assert_eq!(3i32, *db.get(EarlyCutoff::Result));

            db.update_input(EarlyCutoff::Denominator, Value::new(0i32));
            assert_eq!(db.version, START_VERSION + 2);

            assert_eq!(0i32, *db.get(EarlyCutoff::Result));
        }

        {
            // After the denominator drops to 0, `Result` no longer reads `Division`. Changing
            // the numerator must then still yield 0 without evaluating `Division` (8 / 0 would
            // panic), and a later non-zero denominator must pick up the new numerator.
            let mut db = Db::new();
            db.update_input(EarlyCutoff::Denominator, Value::new(2i32));
            assert_eq!(3i32, *db.get(EarlyCutoff::Result));

            db.update_input(EarlyCutoff::Denominator, Value::new(0i32));
            assert_eq!(0i32, *db.get(EarlyCutoff::Result));

            db.update_input(EarlyCutoff::Numerator, Value::new(8i32));
            assert_eq!(0i32, *db.get(EarlyCutoff::Result));

            db.update_input(EarlyCutoff::Denominator, Value::new(4i32));
            assert_eq!(2i32, *db.get(EarlyCutoff::Result));
        }
    }
}
321}