1use std::sync::Arc;
2use std::sync::atomic::{AtomicU32, Ordering};
3
4use crate::accumulate::Accumulate;
5use crate::cell::CellData;
6use crate::storage::StorageFor;
7use crate::{Cell, Computation, Storage};
8
9mod handle;
10mod serialize;
11mod tests;
12
13pub use handle::DbHandle;
14use parking_lot::Mutex;
15use rustc_hash::FxHashSet;
16
/// Initial value of the database's global version counter.
///
/// NOTE(review): starting at 1 (rather than 0) presumably guarantees that a
/// freshly created `CellData` (whose initial `last_verified_version` is set in
/// `CellData::new`, not visible here — confirm it is 0) never looks verified at
/// the current version, so `update_cell` always evaluates new cells.
const START_VERSION: u32 = 1;
18
19pub struct Db<Storage> {
24 cells: dashmap::DashMap<Cell, CellData>,
25 version: AtomicU32,
26 next_cell: AtomicU32,
27 storage: Storage,
28
29 cell_locks: dashmap::DashMap<u32, Arc<Mutex<()>>>,
32}
33
34impl<Storage: Default> Db<Storage> {
35 pub fn new() -> Self {
37 Self::with_storage(Storage::default())
38 }
39}
40
41impl<S: Default> Default for Db<S> {
42 fn default() -> Self {
43 Self::new()
44 }
45}
46
47pub trait DbGet<C: Computation> {
50 fn get(&self, key: C) -> C::Output;
53}
54
55impl<S, C> DbGet<C> for Db<S>
56where
57 C: Computation,
58 S: Storage + StorageFor<C>,
59{
60 fn get(&self, key: C) -> C::Output {
61 self.get(key)
62 }
63}
64
65impl<S> Db<S> {
66 pub fn with_storage(storage: S) -> Self {
68 Self {
69 cells: Default::default(),
70 version: AtomicU32::new(START_VERSION),
71 next_cell: AtomicU32::new(0),
72 cell_locks: Default::default(),
73 storage,
74 }
75 }
76
77 pub fn storage(&self) -> &S {
79 &self.storage
80 }
81
82 pub fn storage_mut(&mut self) -> &mut S {
87 &mut self.storage
88 }
89}
90
91impl<S: Storage> Db<S> {
92 fn get_cell<C: Computation>(&self, computation: &C) -> Option<Cell>
96 where
97 S: StorageFor<C>,
98 {
99 self.storage.get_cell_for_computation(computation)
100 }
101
102 pub(crate) fn get_or_insert_cell<C>(&self, input: C) -> Cell
103 where
104 C: Computation,
105 S: StorageFor<C>,
106 {
107 let computation_id = C::computation_id();
108 let lock = self.cell_locks.entry(computation_id).or_default().clone();
109 let _guard = lock.lock();
110
111 if let Some(cell) = self.get_cell(&input) {
112 cell
113 } else {
114 let cell_id = self.next_cell.fetch_add(1, Ordering::Relaxed);
117 let new_cell = Cell::new(cell_id);
118
119 self.cells.insert(new_cell, CellData::new(computation_id));
120 self.storage.insert_new_cell(new_cell, input);
121 new_cell
122 }
123 }
124
125 fn handle(&self, cell: Cell) -> DbHandle<'_, S> {
126 DbHandle::new(self, cell)
127 }
128
129 #[cfg(test)]
130 #[allow(unused)]
131 pub(crate) fn with_cell_data<C: Computation>(&self, input: &C, f: impl FnOnce(&CellData))
132 where
133 S: StorageFor<C>,
134 {
135 let cell = self
136 .get_cell(input)
137 .unwrap_or_else(|| panic!("unwrap_cell_value: Expected cell to exist"));
138
139 self.cells.get(&cell).map(|value| f(&value)).unwrap()
140 }
141
142 pub fn version(&self) -> u32 {
143 self.version.load(Ordering::SeqCst)
144 }
145
146 pub fn gc(&mut self, version: u32) {
147 let used_cells: std::collections::HashSet<Cell> = self
148 .cells
149 .iter()
150 .filter_map(|entry| {
151 if entry.value().last_verified_version >= version {
152 Some(entry.key().clone())
153 } else {
154 None
155 }
156 })
157 .collect();
158
159 self.storage.gc(&used_cells);
160 }
161}
162
163impl<S: Storage> Db<S> {
164 pub fn update_input<C>(&mut self, input: C, new_value: C::Output)
172 where
173 C: Computation,
174 S: StorageFor<C>,
175 {
176 let cell_id = self.get_or_insert_cell(input);
177 assert!(
178 self.is_input(cell_id),
179 "`update_input` given a non-input value. Inputs must have 0 dependencies",
180 );
181
182 let changed = self.storage.update_output(cell_id, new_value);
183 let mut cell = self.cells.get_mut(&cell_id).unwrap();
184
185 if changed {
186 let version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
187 cell.last_updated_version = version;
188 cell.last_verified_version = version;
189 } else {
190 cell.last_verified_version = self.version.load(Ordering::SeqCst);
191 }
192 }
193
194 fn is_input(&self, cell: Cell) -> bool {
195 self.with_cell(cell, |cell| {
196 cell.dependencies.is_empty() && cell.input_dependencies.is_empty()
197 })
198 }
199
200 pub fn is_stale<C: Computation>(&self, input: &C) -> bool
205 where
206 S: StorageFor<C>,
207 {
208 let Some(cell) = self.get_cell(input) else {
210 return true;
211 };
212 self.is_stale_cell(cell)
213 }
214
215 fn is_stale_cell(&self, cell: Cell) -> bool {
219 let computation_id = self.with_cell(cell, |data| data.computation_id);
220
221 if self.storage.output_is_unset(cell, computation_id) {
222 return true;
223 }
224
225 let (last_verified, inputs, dependencies) = self.with_cell(cell, |data| {
227 (
228 data.last_verified_version,
229 data.input_dependencies.clone(),
230 data.dependencies.clone(),
231 )
232 });
233
234 let inputs_changed = inputs.into_iter().any(|input_id| {
237 self.with_cell(input_id, |input| input.last_updated_version > last_verified)
240 });
241
242 inputs_changed
246 && dependencies.into_iter().any(|dependency_id| {
247 self.update_cell(dependency_id);
248 self.with_cell(dependency_id, |dependency| {
249 dependency.last_updated_version > last_verified
250 })
251 })
252 }
253
254 fn run_compute_function(&self, cell_id: Cell) {
258 let computation_id = self.with_cell(cell_id, |data| data.computation_id);
259
260 let handle = self.handle(cell_id);
261 let changed = S::run_computation(&handle, cell_id, computation_id);
262
263 let version = self.version.load(Ordering::SeqCst);
264 let mut cell = self.cells.get_mut(&cell_id).unwrap();
265 cell.last_verified_version = version;
266
267 if changed {
268 cell.last_updated_version = version;
269 }
270 }
271
272 fn update_cell(&self, cell_id: Cell) {
275 let last_verified_version = self.with_cell(cell_id, |data| data.last_verified_version);
276 let version = self.version.load(Ordering::SeqCst);
277
278 if last_verified_version != version {
279 if self.is_stale_cell(cell_id) {
281 let lock = self.with_cell(cell_id, |cell| cell.lock.clone());
282
283 match lock.try_lock() {
284 Some(guard) => {
285 self.run_compute_function(cell_id);
286 drop(guard);
287 }
288 None => {
289 self.check_for_cycle(cell_id);
293
294 drop(lock.lock());
296 }
297 }
298 } else {
299 let mut cell = self.cells.get_mut(&cell_id).unwrap();
300 cell.last_verified_version = version;
301 }
302 }
303 }
304
305 fn check_for_cycle(&self, starting_cell: Cell) {
307 let mut visited = FxHashSet::default();
308 let mut path = Vec::new();
309
310 let mut stack = Vec::new();
315 stack.push(Action::Traverse(starting_cell));
316
317 enum Action {
318 Traverse(Cell),
319 Pop(Cell),
320 }
321
322 while let Some(action) = stack.pop() {
323 match action {
324 Action::Pop(expected) => assert_eq!(path.pop(), Some(expected)),
326 Action::Traverse(cell) => {
327 if path.contains(&cell) {
328 path.push(cell);
330 self.cycle_error(&path);
331 }
332
333 if visited.insert(cell) {
334 path.push(cell);
335 stack.push(Action::Pop(cell));
336 self.with_cell(cell, |cell| {
337 for dependency in cell.dependencies.iter() {
338 stack.push(Action::Traverse(*dependency));
339 }
340 });
341 }
342 }
343 }
344 }
345 }
346
347 fn cycle_error(&self, cycle: &[Cell]) {
350 let mut error = self.storage.input_debug_string(cycle[0]);
352 for cell in cycle.iter().skip(1) {
353 error += &format!(" -> {}", self.storage.input_debug_string(*cell));
354 }
355 panic!("inc-complete: Cycle Detected!\n\nCycle:\n {error}")
356 }
357
358 pub fn get<C: Computation>(&self, compute: C) -> C::Output
366 where
367 S: StorageFor<C>,
368 {
369 let cell_id = self.get_or_insert_cell(compute);
370 self.get_with_cell::<C>(cell_id)
371 }
372
373 pub(crate) fn get_with_cell<Concrete: Computation>(&self, cell_id: Cell) -> Concrete::Output
374 where
375 S: StorageFor<Concrete>,
376 {
377 self.update_cell(cell_id);
378
379 self.storage
380 .get_output(cell_id)
381 .expect("cell result should have been computed already")
382 }
383
384 fn with_cell<R>(&self, cell: Cell, f: impl FnOnce(&CellData) -> R) -> R {
385 f(&self.cells.get(&cell).unwrap())
386 }
387
388 fn collect_all_dependencies(&self, operation: Cell) -> Vec<Cell> {
392 self.update_cell(operation);
394
395 let mut queue: Vec<_> = self.with_cell(operation, |cell| {
396 cell.dependencies.iter().copied().collect()
397 });
398
399 let mut cells = Vec::new();
400 cells.push(operation);
401
402 while let Some(dependency) = queue.pop() {
405 cells.push(dependency);
406
407 self.with_cell(dependency, |cell| {
408 for dependency in cell.dependencies.iter() {
409 queue.push(*dependency);
410 }
411 });
412 }
413
414 cells.reverse();
415 cells
416 }
417
418 pub fn get_accumulated<Container, Item, C>(&self, compute: C) -> Container
427 where
428 Container: FromIterator<Item>,
429 S: Accumulate<Item> + StorageFor<C>,
430 C: Computation,
431 {
432 let cell_id = self.get_or_insert_cell(compute);
433
434 let cells = self.collect_all_dependencies(cell_id);
435 self.storage.get_accumulated(&cells)
436 }
437}