1use std::collections::BTreeSet;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU32, Ordering};
4
5use crate::accumulate::{ACCUMULATED_COMPUTATION_ID, Accumulate, Accumulated};
6use crate::cell::CellData;
7use crate::storage::StorageFor;
8use crate::{Cell, Computation, Storage};
9
10pub mod debug_with_db;
11mod handle;
12mod serialize;
13mod tests;
14
15pub use handle::DbHandle;
16use parking_lot::Mutex;
17use rustc_hash::FxHashSet;
18
19const START_VERSION: u32 = 1;
20
/// The database tying together per-cell bookkeeping, the global version
/// counter and the user-provided storage.
pub struct Db<Storage> {
    /// Bookkeeping data (`CellData`) for every allocated `Cell`.
    cells: dashmap::DashMap<Cell, CellData>,
    /// Global version counter. Initialised to `START_VERSION` and incremented
    /// by `update_input` whenever an input's value actually changes.
    version: AtomicU32,
    /// Id handed to the next cell created by `get_or_insert_cell`.
    next_cell: AtomicU32,
    /// Backing storage, accessible through `storage` / `storage_mut`.
    storage: Storage,

    /// One mutex per computation id; `get_or_insert_cell` holds it while it
    /// looks up or creates a cell for that computation type.
    cell_locks: dashmap::DashMap<u32, Arc<Mutex<()>>>,
}
35
36impl<Storage: Default> Db<Storage> {
37 pub fn new() -> Self {
39 Self::with_storage(Storage::default())
40 }
41}
42
43impl<S: Default> Default for Db<S> {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
/// Abstraction over "something that can produce the output of a computation
/// `C`".
pub trait DbGet<C: Computation> {
    /// Returns the output associated with the computation `key`.
    fn get(&self, key: C) -> C::Output;
}
56
57impl<S, C> DbGet<C> for Db<S>
58where
59 C: Computation,
60 S: Storage + StorageFor<C>,
61{
62 fn get(&self, key: C) -> C::Output {
63 self.get(key)
64 }
65}
66
67impl<S> Db<S> {
68 pub fn with_storage(storage: S) -> Self {
70 Self {
71 cells: Default::default(),
72 version: AtomicU32::new(START_VERSION),
73 next_cell: AtomicU32::new(0),
74 cell_locks: Default::default(),
75 storage,
76 }
77 }
78
79 pub fn storage(&self) -> &S {
81 &self.storage
82 }
83
84 pub fn storage_mut(&mut self) -> &mut S {
89 &mut self.storage
90 }
91}
92
93impl<S: Storage> Db<S> {
94 fn get_cell<C: Computation>(&self, computation: &C) -> Option<Cell>
98 where
99 S: StorageFor<C>,
100 {
101 self.storage.get_cell_for_computation(computation)
102 }
103
104 pub(crate) fn get_or_insert_cell<C>(&self, input: C) -> Cell
105 where
106 C: Computation,
107 S: StorageFor<C>,
108 {
109 let computation_id = C::computation_id();
110 let lock = self.cell_locks.entry(computation_id).or_default().clone();
111 let _guard = lock.lock();
112
113 if let Some(cell) = self.get_cell(&input) {
114 cell
115 } else {
116 let cell_id = self.next_cell.fetch_add(1, Ordering::Relaxed);
119 let new_cell = Cell::new(cell_id);
120
121 self.cells.insert(new_cell, CellData::new(computation_id));
122 self.storage.insert_new_cell(new_cell, input);
123 new_cell
124 }
125 }
126
    /// Creates a `DbHandle` bound to this database and to `cell`.
    fn handle(&self, cell: Cell) -> DbHandle<'_, S> {
        DbHandle::new(self, cell)
    }
130
131 #[cfg(test)]
132 #[allow(unused)]
133 pub(crate) fn with_cell_data<C: Computation>(&self, input: &C, f: impl FnOnce(&CellData))
134 where
135 S: StorageFor<C>,
136 {
137 let cell = self
138 .get_cell(input)
139 .unwrap_or_else(|| panic!("unwrap_cell_value: Expected cell to exist"));
140
141 self.cells.get(&cell).map(|value| f(&value)).unwrap()
142 }
143
    /// Returns the current value of the global version counter.
    pub fn version(&self) -> u32 {
        self.version.load(Ordering::SeqCst)
    }
147
148 pub fn gc(&mut self, version: u32) {
149 let used_cells: std::collections::HashSet<Cell> = self
150 .cells
151 .iter()
152 .filter_map(|entry| {
153 if entry.value().last_verified_version >= version {
154 Some(entry.key().clone())
155 } else {
156 None
157 }
158 })
159 .collect();
160
161 self.storage.gc(&used_cells);
162 }
163}
164
165impl<S: Storage> Db<S> {
166 pub fn update_input<C>(&mut self, input: C, new_value: C::Output)
174 where
175 C: Computation,
176 S: StorageFor<C>,
177 {
178 let cell_id = self.get_or_insert_cell(input);
179 assert!(
180 self.is_input(cell_id),
181 "`update_input` given a non-input value. Inputs must have 0 dependencies",
182 );
183
184 let changed = self.storage.update_output(cell_id, new_value);
185 let mut cell = self.cells.get_mut(&cell_id).unwrap();
186
187 if changed {
188 let version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
189 cell.last_updated_version = version;
190 cell.last_verified_version = version;
191 } else {
192 cell.last_verified_version = self.version.load(Ordering::SeqCst);
193 }
194 }
195
196 fn is_input(&self, cell: Cell) -> bool {
197 self.with_cell(cell, |cell| {
198 cell.dependencies.is_empty() && cell.input_dependencies.is_empty()
199 })
200 }
201
202 pub fn is_stale<C: Computation>(&self, input: &C) -> bool
207 where
208 S: StorageFor<C>,
209 {
210 let Some(cell) = self.get_cell(input) else {
212 return true;
213 };
214 self.is_stale_cell(cell)
215 }
216
217 fn is_stale_cell(&self, cell: Cell) -> bool {
221 let computation_id = self.with_cell(cell, |data| data.computation_id);
222
223 if self.storage.output_is_unset(cell, computation_id) {
224 return true;
225 }
226
227 let (last_verified, inputs, dependencies) = self.with_cell(cell, |data| {
229 (
230 data.last_verified_version,
231 data.input_dependencies.clone(),
232 data.dependencies.clone(),
233 )
234 });
235
236 let inputs_changed = inputs.into_iter().any(|input_id| {
239 self.with_cell(input_id, |input| input.last_updated_version > last_verified)
242 });
243
244 inputs_changed
248 && dependencies.into_iter().any(|dependency_id| {
249 self.update_cell(dependency_id);
250 self.with_cell(dependency_id, |dependency| {
251 if computation_id == ACCUMULATED_COMPUTATION_ID {
252 dependency.last_run_version > last_verified
253 } else {
254 dependency.last_updated_version > last_verified
255 }
256 })
257 })
258 }
259
260 fn run_compute_function(&self, cell_id: Cell) {
264 let computation_id = self.with_cell(cell_id, |data| data.computation_id);
265 self.storage.clear_accumulated_for_cell(cell_id);
266 let handle = self.handle(cell_id);
267 let changed = S::run_computation(&handle, cell_id, computation_id);
268
269 let version = self.version.load(Ordering::SeqCst);
270 let mut cell = self.cells.get_mut(&cell_id).unwrap();
271 cell.last_verified_version = version;
272 cell.last_run_version = version;
273
274 if changed {
275 cell.last_updated_version = version;
276 }
277 }
278
279 fn update_cell(&self, cell_id: Cell) {
282 let last_verified_version = self.with_cell(cell_id, |data| data.last_verified_version);
283 let version = self.version.load(Ordering::SeqCst);
284
285 if last_verified_version != version {
286 if self.is_stale_cell(cell_id) {
288 let lock = self.with_cell(cell_id, |cell| cell.lock.clone());
289
290 match lock.try_lock() {
291 Some(guard) => {
292 self.run_compute_function(cell_id);
293 drop(guard);
294 }
295 None => {
296 self.check_for_cycle(cell_id);
300
301 drop(lock.lock());
303 }
304 }
305 } else {
306 let mut cell = self.cells.get_mut(&cell_id).unwrap();
307 cell.last_verified_version = version;
308 }
309 }
310 }
311
312 fn check_for_cycle(&self, starting_cell: Cell) {
314 let mut visited = FxHashSet::default();
315 let mut path = Vec::new();
316
317 let mut stack = Vec::new();
322 stack.push(Action::Traverse(starting_cell));
323
324 enum Action {
325 Traverse(Cell),
326 Pop(Cell),
327 }
328
329 while let Some(action) = stack.pop() {
330 match action {
331 Action::Pop(expected) => assert_eq!(path.pop(), Some(expected)),
333 Action::Traverse(cell) => {
334 if path.contains(&cell) {
335 path.push(cell);
337 self.cycle_error(&path);
338 }
339
340 if visited.insert(cell) {
341 path.push(cell);
342 stack.push(Action::Pop(cell));
343 self.with_cell(cell, |cell| {
344 for dependency in cell.dependencies.iter() {
345 stack.push(Action::Traverse(*dependency));
346 }
347 });
348 }
349 }
350 }
351 }
352 }
353
354 fn cycle_error(&self, cycle: &[Cell]) {
356 let mut error = String::new();
357 for (i, cell) in cycle.iter().enumerate() {
358 error += &format!(
359 "\n {}. {}",
360 i + 1,
361 self.storage.input_debug_string(self, *cell)
362 );
363 }
364 panic!("inc-complete: Cycle Detected!\n\nCycle:{error}")
365 }
366
367 pub fn get<C: Computation>(&self, compute: C) -> C::Output
375 where
376 S: StorageFor<C>,
377 {
378 let cell_id = self.get_or_insert_cell(compute);
379 self.get_with_cell::<C>(cell_id)
380 }
381
382 pub(crate) fn get_with_cell<Concrete: Computation>(&self, cell_id: Cell) -> Concrete::Output
383 where
384 S: StorageFor<Concrete>,
385 {
386 self.update_cell(cell_id);
387
388 self.storage
389 .get_output(cell_id)
390 .expect("cell result should have been computed already")
391 }
392
393 fn with_cell<R>(&self, cell: Cell, f: impl FnOnce(&CellData) -> R) -> R {
394 f(&self.cells.get(&cell).unwrap())
395 }
396
397 pub fn get_accumulated<Item, C>(&self, compute: C) -> BTreeSet<Item>
409 where
410 S: StorageFor<C> + StorageFor<Accumulated<Item>>,
411 C: Computation,
412 Item: 'static,
413 {
414 let cell_id = self.get_or_insert_cell(compute);
415 self.update_cell(cell_id);
416 self.get(Accumulated::<Item>::new(cell_id))
417 }
418
419 pub fn get_accumulated_uncached<Item, C>(&mut self, compute: C) -> BTreeSet<Item>
428 where
429 S: StorageFor<C> + StorageFor<Accumulated<Item>> + Accumulate<Item>,
430 C: Computation,
431 Item: 'static + Ord,
432 {
433 let cell_id = self.get_or_insert_cell(compute);
434 self.update_cell(cell_id);
435
436 let mut items = BTreeSet::new();
437 let mut visited = BTreeSet::new();
438 let mut queue = vec![cell_id];
439
440 while let Some(cell) = queue.pop() {
441 if visited.insert(cell) {
442 self.with_cell(cell, |data| queue.extend_from_slice(&data.dependencies));
443 items.extend(self.storage().get_accumulated::<Vec<Item>>(cell));
444 }
445 }
446
447 items
448 }
449}