1use std::sync::Arc;
2use std::sync::atomic::{AtomicU32, Ordering};
3
4use crate::accumulate::Accumulated;
5use crate::cell::CellData;
6use crate::storage::StorageFor;
7use crate::{Cell, Computation, Storage};
8
9mod handle;
10mod serialize;
11mod tests;
12pub mod debug_with_db;
13
14pub use handle::DbHandle;
15use parking_lot::Mutex;
16use rustc_hash::FxHashSet;
17
/// Version number a freshly created database starts at (used by `Db::with_storage`);
/// `update_input` increments it whenever an input's stored output changes.
// NOTE(review): presumably chosen > 0 so that a brand-new cell (whose initial
// `last_verified_version` comes from `CellData::new`, not shown here) never compares equal
// to the database version and is therefore always checked on first `get` — confirm.
const START_VERSION: u32 = 1;
19
20pub struct Db<Storage> {
25 cells: dashmap::DashMap<Cell, CellData>,
26 version: AtomicU32,
27 next_cell: AtomicU32,
28 storage: Storage,
29
30 cell_locks: dashmap::DashMap<u32, Arc<Mutex<()>>>,
33}
34
35impl<Storage: Default> Db<Storage> {
36 pub fn new() -> Self {
38 Self::with_storage(Storage::default())
39 }
40}
41
42impl<S: Default> Default for Db<S> {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48pub trait DbGet<C: Computation> {
51 fn get(&self, key: C) -> C::Output;
54}
55
56impl<S, C> DbGet<C> for Db<S>
57where
58 C: Computation,
59 S: Storage + StorageFor<C>,
60{
61 fn get(&self, key: C) -> C::Output {
62 self.get(key)
63 }
64}
65
66impl<S> Db<S> {
67 pub fn with_storage(storage: S) -> Self {
69 Self {
70 cells: Default::default(),
71 version: AtomicU32::new(START_VERSION),
72 next_cell: AtomicU32::new(0),
73 cell_locks: Default::default(),
74 storage,
75 }
76 }
77
78 pub fn storage(&self) -> &S {
80 &self.storage
81 }
82
83 pub fn storage_mut(&mut self) -> &mut S {
88 &mut self.storage
89 }
90}
91
92impl<S: Storage> Db<S> {
93 fn get_cell<C: Computation>(&self, computation: &C) -> Option<Cell>
97 where
98 S: StorageFor<C>,
99 {
100 self.storage.get_cell_for_computation(computation)
101 }
102
103 pub(crate) fn get_or_insert_cell<C>(&self, input: C) -> Cell
104 where
105 C: Computation,
106 S: StorageFor<C>,
107 {
108 let computation_id = C::computation_id();
109 let lock = self.cell_locks.entry(computation_id).or_default().clone();
110 let _guard = lock.lock();
111
112 if let Some(cell) = self.get_cell(&input) {
113 cell
114 } else {
115 let cell_id = self.next_cell.fetch_add(1, Ordering::Relaxed);
118 let new_cell = Cell::new(cell_id);
119
120 self.cells.insert(new_cell, CellData::new(computation_id));
121 self.storage.insert_new_cell(new_cell, input);
122 new_cell
123 }
124 }
125
126 fn handle(&self, cell: Cell) -> DbHandle<'_, S> {
127 DbHandle::new(self, cell)
128 }
129
130 #[cfg(test)]
131 #[allow(unused)]
132 pub(crate) fn with_cell_data<C: Computation>(&self, input: &C, f: impl FnOnce(&CellData))
133 where
134 S: StorageFor<C>,
135 {
136 let cell = self
137 .get_cell(input)
138 .unwrap_or_else(|| panic!("unwrap_cell_value: Expected cell to exist"));
139
140 self.cells.get(&cell).map(|value| f(&value)).unwrap()
141 }
142
143 pub fn version(&self) -> u32 {
144 self.version.load(Ordering::SeqCst)
145 }
146
147 pub fn gc(&mut self, version: u32) {
148 let used_cells: std::collections::HashSet<Cell> = self
149 .cells
150 .iter()
151 .filter_map(|entry| {
152 if entry.value().last_verified_version >= version {
153 Some(entry.key().clone())
154 } else {
155 None
156 }
157 })
158 .collect();
159
160 self.storage.gc(&used_cells);
161 }
162}
163
164impl<S: Storage> Db<S> {
165 pub fn update_input<C>(&mut self, input: C, new_value: C::Output)
173 where
174 C: Computation,
175 S: StorageFor<C>,
176 {
177 let cell_id = self.get_or_insert_cell(input);
178 assert!(
179 self.is_input(cell_id),
180 "`update_input` given a non-input value. Inputs must have 0 dependencies",
181 );
182
183 let changed = self.storage.update_output(cell_id, new_value);
184 let mut cell = self.cells.get_mut(&cell_id).unwrap();
185
186 if changed {
187 let version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
188 cell.last_updated_version = version;
189 cell.last_verified_version = version;
190 } else {
191 cell.last_verified_version = self.version.load(Ordering::SeqCst);
192 }
193 }
194
195 fn is_input(&self, cell: Cell) -> bool {
196 self.with_cell(cell, |cell| {
197 cell.dependencies.is_empty() && cell.input_dependencies.is_empty()
198 })
199 }
200
201 pub fn is_stale<C: Computation>(&self, input: &C) -> bool
206 where
207 S: StorageFor<C>,
208 {
209 let Some(cell) = self.get_cell(input) else {
211 return true;
212 };
213 self.is_stale_cell(cell)
214 }
215
216 fn is_stale_cell(&self, cell: Cell) -> bool {
220 let computation_id = self.with_cell(cell, |data| data.computation_id);
221
222 if self.storage.output_is_unset(cell, computation_id) {
223 return true;
224 }
225
226 let (last_verified, inputs, dependencies) = self.with_cell(cell, |data| {
228 (
229 data.last_verified_version,
230 data.input_dependencies.clone(),
231 data.dependencies.clone(),
232 )
233 });
234
235 let inputs_changed = inputs.into_iter().any(|input_id| {
238 self.with_cell(input_id, |input| input.last_updated_version > last_verified)
241 });
242
243 inputs_changed
247 && dependencies.into_iter().any(|dependency_id| {
248 self.update_cell(dependency_id);
249 self.with_cell(dependency_id, |dependency| {
250 dependency.last_updated_version > last_verified
251 })
252 })
253 }
254
255 fn run_compute_function(&self, cell_id: Cell) {
259 let computation_id = self.with_cell(cell_id, |data| data.computation_id);
260 self.storage.clear_accumulated_for_cell(cell_id);
261 let handle = self.handle(cell_id);
262 let changed = S::run_computation(&handle, cell_id, computation_id);
263
264 let version = self.version.load(Ordering::SeqCst);
265 let mut cell = self.cells.get_mut(&cell_id).unwrap();
266 cell.last_verified_version = version;
267
268 if changed {
269 cell.last_updated_version = version;
270 }
271 }
272
273 fn update_cell(&self, cell_id: Cell) {
276 let last_verified_version = self.with_cell(cell_id, |data| data.last_verified_version);
277 let version = self.version.load(Ordering::SeqCst);
278
279 if last_verified_version != version {
280 if self.is_stale_cell(cell_id) {
282 let lock = self.with_cell(cell_id, |cell| cell.lock.clone());
283
284 match lock.try_lock() {
285 Some(guard) => {
286 self.run_compute_function(cell_id);
287 drop(guard);
288 }
289 None => {
290 self.check_for_cycle(cell_id);
294
295 drop(lock.lock());
297 }
298 }
299 } else {
300 let mut cell = self.cells.get_mut(&cell_id).unwrap();
301 cell.last_verified_version = version;
302 }
303 }
304 }
305
306 fn check_for_cycle(&self, starting_cell: Cell) {
308 let mut visited = FxHashSet::default();
309 let mut path = Vec::new();
310
311 let mut stack = Vec::new();
316 stack.push(Action::Traverse(starting_cell));
317
318 enum Action {
319 Traverse(Cell),
320 Pop(Cell),
321 }
322
323 while let Some(action) = stack.pop() {
324 match action {
325 Action::Pop(expected) => assert_eq!(path.pop(), Some(expected)),
327 Action::Traverse(cell) => {
328 if path.contains(&cell) {
329 path.push(cell);
331 self.cycle_error(&path);
332 }
333
334 if visited.insert(cell) {
335 path.push(cell);
336 stack.push(Action::Pop(cell));
337 self.with_cell(cell, |cell| {
338 for dependency in cell.dependencies.iter() {
339 stack.push(Action::Traverse(*dependency));
340 }
341 });
342 }
343 }
344 }
345 }
346 }
347
348 fn cycle_error(&self, cycle: &[Cell]) {
350 let mut error = String::new();
351 for (i, cell) in cycle.iter().enumerate() {
352 error += &format!("\n {}. {}", i + 1, self.storage.input_debug_string(self, *cell));
353 }
354 panic!("inc-complete: Cycle Detected!\n\nCycle:{error}")
355 }
356
357 pub fn get<C: Computation>(&self, compute: C) -> C::Output
365 where
366 S: StorageFor<C>,
367 {
368 let cell_id = self.get_or_insert_cell(compute);
369 self.get_with_cell::<C>(cell_id)
370 }
371
372 pub(crate) fn get_with_cell<Concrete: Computation>(&self, cell_id: Cell) -> Concrete::Output
373 where
374 S: StorageFor<Concrete>,
375 {
376 self.update_cell(cell_id);
377
378 self.storage
379 .get_output(cell_id)
380 .expect("cell result should have been computed already")
381 }
382
383 fn with_cell<R>(&self, cell: Cell, f: impl FnOnce(&CellData) -> R) -> R {
384 f(&self.cells.get(&cell).unwrap())
385 }
386
387 pub fn get_accumulated<Item, C>(&self, compute: C) -> Vec<Item>
392 where
393 S: StorageFor<C> + StorageFor<Accumulated<Item>>,
394 C: Computation,
395 Item: 'static,
396 {
397 let cell_id = self.get_or_insert_cell(compute);
398 self.update_cell(cell_id);
399 self.get(Accumulated::<Item>::new(cell_id))
400 }
401}