1use std::{any::Any, collections::HashMap, hash::{ Hasher, Hash }};
2
3use petgraph::graph::DiGraph;
4
/// Version a freshly created [`Db`] starts at.
///
/// New cells stamp both of their versions with 0, which is below this value, so a new
/// cell never looks verified for any real database version.
const START_VERSION: u32 = 1;
6
/// A memoizing database of query results, keyed by the query value `F`.
pub struct Db<F> {
    /// One node per query; a node's outgoing edges are read as its dependencies.
    cells: DiGraph<CellValue<F>, ()>,
    /// Current database version; starts at `START_VERSION` and is bumped whenever a new
    /// cell is created.
    version: u32,

    /// Maps each query value to its node in `cells`.
    input_to_cell: HashMap<F, Cell>,
}
13
14#[derive(Copy, Clone)]
15pub struct Cell(petgraph::graph::NodeIndex);
16
/// A memoized query together with its cached result and version bookkeeping.
struct CellValue<F> {
    /// The query to run (via `Run::run`) when the cell needs (re)computing.
    compute: F,
    /// `None` until first computed; afterwards the hash of the latest result and the
    /// result itself, type-erased so `Db::get` can downcast it to the caller's type.
    result: Option<(u64, Box<dyn Any>)>,

    /// Database version at which `result` last changed. A recomputation whose hash
    /// matches the previous one does not bump this.
    last_updated_version: u32,
    /// Database version at which `result` was last confirmed to be current.
    last_verified_version: u32,
}
24
25impl<F> CellValue<F> {
26 fn new(compute: F) -> Self {
27 Self {
28 compute,
29 result: None,
30 last_updated_version: 0,
31 last_verified_version: 0,
32 }
33 }
34}
35
36pub trait Run: Sized {
37 fn run(self, db: &mut Db<Self>) -> impl Any + Hash;
38}
39
40impl<F> Db<F> {
41 pub fn new() -> Self {
42 Self {
43 cells: DiGraph::default(),
44 input_to_cell: HashMap::default(),
45 version: START_VERSION,
46 }
47 }
48}
49
50impl<F: Run + Copy + Eq + Hash + Clone> Db<F> {
51 pub fn get<'a, T: 'static>(&'a mut self, compute: F) -> &'a T {
52 let cell_id = self.input(compute);
53 let cell = &self.cells[cell_id.0];
54
55 if cell.last_verified_version != self.version {
56 if cell.result.is_some() {
57 let neighbors = self.cells.neighbors(cell_id.0).collect::<Vec<_>>();
58
59 if neighbors.into_iter().any(|input_id| {
60 let input = &self.cells[input_id];
61 let cell = &self.cells[cell_id.0];
62 let need_to_update = input.last_updated_version > cell.last_verified_version;
63 if need_to_update {
64 self.update_cell(Cell(input_id));
65 }
66 need_to_update
67 }) {
68 self.update_cell(cell_id);
69 } else {
70 let cell = &mut self.cells[cell_id.0];
71 cell.last_verified_version = self.version;
72 }
73 } else {
74 self.update_cell(cell_id);
75 }
76 }
77
78 let cell = &self.cells[cell_id.0];
79 let result = cell.result.as_ref()
80 .expect("cell result should have been computed already").1.as_ref();
81
82 result
83 .downcast_ref()
84 .expect("Output type to `Db::get` does not match the type of the value returned by the `Run::run` function")
85 }
86
87 pub fn input(&mut self, input: F) -> Cell {
88 if let Some(cell_id) = self.input_to_cell.get(&input) {
89 *cell_id
90 } else {
91 self.version += 1;
92 let new_id = self.cells.add_node(CellValue::new(input.clone()));
93 let cell = Cell(new_id);
94 self.input_to_cell.insert(input, cell);
95 cell
96 }
97 }
98
99 #[cfg(test)]
100 fn get_cell(&mut self, input: F) -> &CellValue<F> {
101 let cell = self.input(input);
102 &self.cells[cell.0]
103 }
104
105 fn update_cell(&mut self, cell_id: Cell) {
106 let cell = &self.cells[cell_id.0];
107 let result = cell.compute.run(self);
108 let new_hash = hash(&result);
109 let cell = &mut self.cells[cell_id.0];
110
111 if let Some((old_hash, _)) = cell.result.as_ref() {
112 if new_hash == *old_hash {
113 cell.last_verified_version = self.version;
114 return;
115 }
116 }
117
118 cell.result = Some((new_hash, Box::new(result)));
119 cell.last_verified_version = self.version;
120 cell.last_updated_version = self.version;
121 }
122}
123
/// Hashes `x` with the standard library's `DefaultHasher`.
///
/// The result is deterministic within one build, since every `DefaultHasher::new()`
/// starts in the same state. Cells use it to decide whether a recomputed value changed.
fn hash<T: Hash>(x: T) -> u64 {
    let mut hasher = std::hash::DefaultHasher::new();
    Hash::hash(&x, &mut hasher);
    Hasher::finish(&hasher)
}
130
#[cfg(test)]
mod tests {
    use crate::{Db, Run, START_VERSION};

    /// A three-step chain of queries: `A3` reads `A2`, which reads `A1`.
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    enum Basic {
        A1,
        A2,
        A3,
    }

    impl Run for Basic {
        fn run(self, db: &mut Db<Self>) -> impl std::any::Any + std::hash::Hash {
            match self {
                Self::A1 => 20,
                Self::A2 => db.get::<i32>(Self::A1) + 1,
                Self::A3 => db.get::<i32>(Self::A2) + 2,
            }
        }
    }

    #[test]
    fn basic() {
        let mut db = Db::new();
        assert_eq!(*db.get::<i32>(Basic::A3), 23);
    }

    #[test]
    fn no_recompute_basic() {
        let mut db = Db::new();
        let first = *db.get::<i32>(Basic::A3);
        let second = *db.get::<i32>(Basic::A3);
        assert_eq!((first, second), (23, 23));

        // One cell per query was created, and each creation bumps the version once.
        // The second `get` must not change anything.
        let expected = START_VERSION + 3;
        assert_eq!(db.version, expected);

        for query in [Basic::A1, Basic::A2, Basic::A3] {
            let cell = db.get_cell(query);
            assert_eq!(cell.last_updated_version, expected, "{query:?}");
            assert_eq!(cell.last_verified_version, expected, "{query:?}");
        }
    }
}