1use std::sync::atomic::{AtomicU32, Ordering};
2
3use crate::accumulate::Accumulate;
4use crate::cell::CellData;
5use crate::storage::StorageFor;
6use crate::{Cell, Computation, Storage};
7
8mod handle;
9mod serialize;
10mod tests;
11
12pub use handle::DbHandle;
13use rustc_hash::FxHashSet;
14
/// Version number a freshly constructed database starts at.
///
/// `with_storage` initializes the version counter with this value; `update_input` increments
/// the counter whenever it stores a changed input value.
const START_VERSION: u32 = 1;
16
17pub struct Db<Storage> {
22 cells: dashmap::DashMap<Cell, CellData>,
23 version: AtomicU32,
24 next_cell: AtomicU32,
25 storage: Storage,
26}
27
28impl<Storage: Default> Db<Storage> {
29 pub fn new() -> Self {
31 Self::with_storage(Storage::default())
32 }
33}
34
35impl<S: Default> Default for Db<S> {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41pub trait DbGet<C: Computation> {
44 fn get(&self, key: C) -> C::Output;
47}
48
49impl<S, C> DbGet<C> for Db<S>
50where
51 C: Computation,
52 S: Storage + StorageFor<C>,
53{
54 fn get(&self, key: C) -> C::Output {
55 self.get(key)
56 }
57}
58
59impl<S> Db<S> {
60 pub fn with_storage(storage: S) -> Self {
62 Self {
63 cells: Default::default(),
64 version: AtomicU32::new(START_VERSION),
65 next_cell: AtomicU32::new(0),
66 storage,
67 }
68 }
69
70 pub fn storage(&self) -> &S {
72 &self.storage
73 }
74
75 pub fn storage_mut(&mut self) -> &mut S {
80 &mut self.storage
81 }
82}
83
84impl<S: Storage> Db<S> {
85 fn get_cell<C: Computation>(&self, computation: &C) -> Option<Cell>
89 where
90 S: StorageFor<C>,
91 {
92 self.storage.get_cell_for_computation(computation)
93 }
94
95 pub(crate) fn get_or_insert_cell<C>(&self, input: C) -> Cell
96 where
97 C: Computation,
98 S: StorageFor<C>,
99 {
100 if let Some(cell) = self.get_cell(&input) {
101 cell
102 } else {
103 let computation_id = C::computation_id();
104
105 let cell_id = self.next_cell.fetch_add(1, Ordering::Relaxed);
108 let new_cell = Cell::new(cell_id);
109
110 self.cells.insert(new_cell, CellData::new(computation_id));
111 self.storage.insert_new_cell(new_cell, input);
112 new_cell
113 }
114 }
115
116 fn handle(&self, cell: Cell) -> DbHandle<'_, S> {
117 DbHandle::new(self, cell)
118 }
119
120 #[cfg(test)]
121 #[allow(unused)]
122 pub(crate) fn with_cell_data<C: Computation>(&self, input: &C, f: impl FnOnce(&CellData))
123 where
124 S: StorageFor<C>,
125 {
126 let cell = self
127 .get_cell(input)
128 .unwrap_or_else(|| panic!("unwrap_cell_value: Expected cell to exist"));
129
130 self.cells.get(&cell).map(|value| f(&value)).unwrap()
131 }
132
133 pub fn version(&self) -> u32 {
134 self.version.load(Ordering::SeqCst)
135 }
136
137
138 pub fn gc(&mut self, version: u32) {
139 let used_cells: std::collections::HashSet<Cell> = self
140 .cells
141 .iter()
142 .filter_map(|entry| {
143 if entry.value().last_verified_version >= version {
144 Some(entry.key().clone())
145 } else {
146 None
147 }
148 })
149 .collect();
150
151 self.storage.gc(&used_cells);
152 }
153}
154
155impl<S: Storage> Db<S> {
156 pub fn update_input<C>(&mut self, input: C, new_value: C::Output)
164 where
165 C: Computation,
166 S: StorageFor<C>,
167 {
168 let cell_id = self.get_or_insert_cell(input);
169 assert!(
170 self.is_input(cell_id),
171 "`update_input` given a non-input value. Inputs must have 0 dependencies",
172 );
173
174 let changed = self.storage.update_output(cell_id, new_value);
175 let mut cell = self.cells.get_mut(&cell_id).unwrap();
176
177 if changed {
178 let version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
179 cell.last_updated_version = version;
180 cell.last_verified_version = version;
181 } else {
182 cell.last_verified_version = self.version.load(Ordering::SeqCst);
183 }
184 }
185
186 fn is_input(&self, cell: Cell) -> bool {
187 self.with_cell(cell, |cell| {
188 cell.dependencies.is_empty() && cell.input_dependencies.is_empty()
189 })
190 }
191
192 pub fn is_stale<C: Computation>(&self, input: &C) -> bool
197 where
198 S: StorageFor<C>,
199 {
200 let Some(cell) = self.get_cell(input) else {
202 return true;
203 };
204 self.is_stale_cell(cell)
205 }
206
207 fn is_stale_cell(&self, cell: Cell) -> bool {
211 let computation_id = self.with_cell(cell, |data| data.computation_id);
212
213 if self.storage.output_is_unset(cell, computation_id) {
214 return true;
215 }
216
217 let (last_verified, inputs, dependencies) = self.with_cell(cell, |data| {
219 (
220 data.last_verified_version,
221 data.input_dependencies.clone(),
222 data.dependencies.clone(),
223 )
224 });
225
226 let inputs_changed = inputs.into_iter().any(|input_id| {
229 self.with_cell(input_id, |input| input.last_updated_version > last_verified)
232 });
233
234 inputs_changed
238 && dependencies.into_iter().any(|dependency_id| {
239 self.update_cell(dependency_id);
240 self.with_cell(dependency_id, |dependency| {
241 dependency.last_updated_version > last_verified
242 })
243 })
244 }
245
246 fn run_compute_function(&self, cell_id: Cell) {
250 let computation_id = self.with_cell(cell_id, |data| data.computation_id);
251
252 let handle = self.handle(cell_id);
253 let changed = S::run_computation(&handle, cell_id, computation_id);
254
255 let version = self.version.load(Ordering::SeqCst);
256 let mut cell = self.cells.get_mut(&cell_id).unwrap();
257 cell.last_verified_version = version;
258
259 if changed {
260 cell.last_updated_version = version;
261 }
262 }
263
264 fn update_cell(&self, cell_id: Cell) {
267 let last_verified_version = self.with_cell(cell_id, |data| data.last_verified_version);
268 let version = self.version.load(Ordering::SeqCst);
269
270 if last_verified_version != version {
271 if self.is_stale_cell(cell_id) {
273 let lock = self.with_cell(cell_id, |cell| cell.lock.clone());
274
275 match lock.try_lock() {
276 Some(guard) => {
277 self.run_compute_function(cell_id);
278 drop(guard);
279 }
280 None => {
281 self.check_for_cycle(cell_id);
285
286 drop(lock.lock());
288 }
289 }
290 } else {
291 let mut cell = self.cells.get_mut(&cell_id).unwrap();
292 cell.last_verified_version = version;
293 }
294 }
295 }
296
297 fn check_for_cycle(&self, starting_cell: Cell) {
299 let mut visited = FxHashSet::default();
300 let mut path = Vec::new();
301
302 let mut stack = Vec::new();
307 stack.push(Action::Traverse(starting_cell));
308
309 enum Action {
310 Traverse(Cell),
311 Pop(Cell),
312 }
313
314 while let Some(action) = stack.pop() {
315 match action {
316 Action::Pop(expected) => assert_eq!(path.pop(), Some(expected)),
318 Action::Traverse(cell) => {
319 if path.contains(&cell) {
320 path.push(cell);
322 self.cycle_error(&path);
323 }
324
325 if visited.insert(cell) {
326 path.push(cell);
327 stack.push(Action::Pop(cell));
328 self.with_cell(cell, |cell| {
329 for dependency in cell.dependencies.iter() {
330 stack.push(Action::Traverse(*dependency));
331 }
332 });
333 }
334 }
335 }
336 }
337 }
338
339 fn cycle_error(&self, cycle: &[Cell]) {
342 let mut error = self.storage.input_debug_string(cycle[0]);
344 for cell in cycle.iter().skip(1) {
345 error += &format!(" -> {}", self.storage.input_debug_string(*cell));
346 }
347 panic!("inc-complete: Cycle Detected!\n\nCycle:\n {error}")
348 }
349
350 pub fn get<C: Computation>(&self, compute: C) -> C::Output
358 where
359 S: StorageFor<C>,
360 {
361 let cell_id = self.get_or_insert_cell(compute);
362 self.get_with_cell::<C>(cell_id)
363 }
364
365 pub(crate) fn get_with_cell<Concrete: Computation>(&self, cell_id: Cell) -> Concrete::Output
366 where
367 S: StorageFor<Concrete>,
368 {
369 self.update_cell(cell_id);
370
371 self.storage
372 .get_output(cell_id)
373 .expect("cell result should have been computed already")
374 }
375
376 fn with_cell<R>(&self, cell: Cell, f: impl FnOnce(&CellData) -> R) -> R {
377 f(&self.cells.get(&cell).unwrap())
378 }
379
380 fn collect_all_dependencies(&self, operation: Cell) -> Vec<Cell> {
384 self.update_cell(operation);
386
387 let mut queue: Vec<_> = self.with_cell(operation, |cell| {
388 cell.dependencies.iter().copied().collect()
389 });
390
391 let mut cells = Vec::new();
392 cells.push(operation);
393
394 while let Some(dependency) = queue.pop() {
397 cells.push(dependency);
398
399 self.with_cell(dependency, |cell| {
400 for dependency in cell.dependencies.iter() {
401 queue.push(*dependency);
402 }
403 });
404 }
405
406 cells.reverse();
407 cells
408 }
409
410 pub fn get_accumulated<Container, Item, C>(&self, compute: C) -> Container where
419 Container: FromIterator<Item>,
420 S: Accumulate<Item> + StorageFor<C>,
421 C: Computation
422 {
423 let cell_id = self.get_or_insert_cell(compute);
424
425 let cells = self.collect_all_dependencies(cell_id);
426 self.storage.get_accumulated(&cells)
427 }
428}