1use std::collections::BTreeSet;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU32, Ordering};
4
5use crate::accumulate::Accumulated;
6use crate::cell::CellData;
7use crate::storage::StorageFor;
8use crate::{Cell, Computation, Storage};
9
10pub mod debug_with_db;
11mod handle;
12mod serialize;
13mod tests;
14
15pub use handle::DbHandle;
16use parking_lot::Mutex;
17use rustc_hash::FxHashSet;
18
/// Version number a freshly constructed [`Db`] starts at (see `Db::with_storage`).
const START_VERSION: u32 = 1;
20
21pub struct Db<Storage> {
26 cells: dashmap::DashMap<Cell, CellData>,
27 version: AtomicU32,
28 next_cell: AtomicU32,
29 storage: Storage,
30
31 cell_locks: dashmap::DashMap<u32, Arc<Mutex<()>>>,
34}
35
36impl<Storage: Default> Db<Storage> {
37 pub fn new() -> Self {
39 Self::with_storage(Storage::default())
40 }
41}
42
43impl<S: Default> Default for Db<S> {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
/// Abstraction over anything able to produce the output of a computation
/// of type `C`.
pub trait DbGet<C: Computation> {
    /// Returns the output of the computation `key`.
    fn get(&self, key: C) -> C::Output;
}
56
57impl<S, C> DbGet<C> for Db<S>
58where
59 C: Computation,
60 S: Storage + StorageFor<C>,
61{
62 fn get(&self, key: C) -> C::Output {
63 self.get(key)
64 }
65}
66
67impl<S> Db<S> {
68 pub fn with_storage(storage: S) -> Self {
70 Self {
71 cells: Default::default(),
72 version: AtomicU32::new(START_VERSION),
73 next_cell: AtomicU32::new(0),
74 cell_locks: Default::default(),
75 storage,
76 }
77 }
78
79 pub fn storage(&self) -> &S {
81 &self.storage
82 }
83
84 pub fn storage_mut(&mut self) -> &mut S {
89 &mut self.storage
90 }
91}
92
93impl<S: Storage> Db<S> {
    /// Looks up the cell already associated with `computation`, if any.
    ///
    /// Delegates to the storage; returns `None` when the storage has no cell
    /// for this computation. Never creates a cell.
    fn get_cell<C: Computation>(&self, computation: &C) -> Option<Cell>
    where
        S: StorageFor<C>,
    {
        self.storage.get_cell_for_computation(computation)
    }
103
104 pub(crate) fn get_or_insert_cell<C>(&self, input: C) -> Cell
105 where
106 C: Computation,
107 S: StorageFor<C>,
108 {
109 let computation_id = C::computation_id();
110 let lock = self.cell_locks.entry(computation_id).or_default().clone();
111 let _guard = lock.lock();
112
113 if let Some(cell) = self.get_cell(&input) {
114 cell
115 } else {
116 let cell_id = self.next_cell.fetch_add(1, Ordering::Relaxed);
119 let new_cell = Cell::new(cell_id);
120
121 self.cells.insert(new_cell, CellData::new(computation_id));
122 self.storage.insert_new_cell(new_cell, input);
123 new_cell
124 }
125 }
126
    /// Builds a [`DbHandle`] referring to this database and the given cell.
    fn handle(&self, cell: Cell) -> DbHandle<'_, S> {
        DbHandle::new(self, cell)
    }
130
131 #[cfg(test)]
132 #[allow(unused)]
133 pub(crate) fn with_cell_data<C: Computation>(&self, input: &C, f: impl FnOnce(&CellData))
134 where
135 S: StorageFor<C>,
136 {
137 let cell = self
138 .get_cell(input)
139 .unwrap_or_else(|| panic!("unwrap_cell_value: Expected cell to exist"));
140
141 self.cells.get(&cell).map(|value| f(&value)).unwrap()
142 }
143
    /// Returns the database's current version number.
    pub fn version(&self) -> u32 {
        self.version.load(Ordering::SeqCst)
    }
147
148 pub fn gc(&mut self, version: u32) {
149 let used_cells: std::collections::HashSet<Cell> = self
150 .cells
151 .iter()
152 .filter_map(|entry| {
153 if entry.value().last_verified_version >= version {
154 Some(entry.key().clone())
155 } else {
156 None
157 }
158 })
159 .collect();
160
161 self.storage.gc(&used_cells);
162 }
163}
164
165impl<S: Storage> Db<S> {
166 pub fn update_input<C>(&mut self, input: C, new_value: C::Output)
174 where
175 C: Computation,
176 S: StorageFor<C>,
177 {
178 let cell_id = self.get_or_insert_cell(input);
179 assert!(
180 self.is_input(cell_id),
181 "`update_input` given a non-input value. Inputs must have 0 dependencies",
182 );
183
184 let changed = self.storage.update_output(cell_id, new_value);
185 let mut cell = self.cells.get_mut(&cell_id).unwrap();
186
187 if changed {
188 let version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
189 cell.last_updated_version = version;
190 cell.last_verified_version = version;
191 } else {
192 cell.last_verified_version = self.version.load(Ordering::SeqCst);
193 }
194 }
195
196 fn is_input(&self, cell: Cell) -> bool {
197 self.with_cell(cell, |cell| {
198 cell.dependencies.is_empty() && cell.input_dependencies.is_empty()
199 })
200 }
201
202 pub fn is_stale<C: Computation>(&self, input: &C) -> bool
207 where
208 S: StorageFor<C>,
209 {
210 let Some(cell) = self.get_cell(input) else {
212 return true;
213 };
214 self.is_stale_cell(cell)
215 }
216
    /// Returns `true` if the given cell's output is missing or out of date.
    ///
    /// A cell whose output is unset is always stale. Otherwise it is stale
    /// only when at least one of its input dependencies was updated after
    /// this cell was last verified *and* at least one of its dependencies,
    /// once brought up to date, was updated after this cell was last
    /// verified.
    ///
    /// This can run other computations: every dependency that gets inspected
    /// is first passed to `update_cell`.
    fn is_stale_cell(&self, cell: Cell) -> bool {
        let computation_id = self.with_cell(cell, |data| data.computation_id);

        // No stored output at all: it has to be computed.
        if self.storage.output_is_unset(cell, computation_id) {
            return true;
        }

        // Copy what we need out of the cell data so no reference into
        // `self.cells` is kept while other cells are read or updated below.
        let (last_verified, inputs, dependencies) = self.with_cell(cell, |data| {
            (
                data.last_verified_version,
                data.input_dependencies.clone(),
                data.dependencies.clone(),
            )
        });

        // Cheap pre-check: has any input this cell depends on been updated
        // since this cell was last verified?
        let inputs_changed = inputs.into_iter().any(|input_id| {
            self.with_cell(input_id, |input| input.last_updated_version > last_verified)
        });

        // Only when an input changed are the dependencies walked, in their
        // stored order. `any` stops at the first dependency found to have
        // changed, so later dependencies are not updated by this call.
        inputs_changed
            && dependencies.into_iter().any(|dependency_id| {
                self.update_cell(dependency_id);
                self.with_cell(dependency_id, |dependency| {
                    dependency.last_updated_version > last_verified
                })
            })
    }
255
256 fn run_compute_function(&self, cell_id: Cell) {
260 let computation_id = self.with_cell(cell_id, |data| data.computation_id);
261 self.storage.clear_accumulated_for_cell(cell_id);
262 let handle = self.handle(cell_id);
263 let changed = S::run_computation(&handle, cell_id, computation_id);
264
265 let version = self.version.load(Ordering::SeqCst);
266 let mut cell = self.cells.get_mut(&cell_id).unwrap();
267 cell.last_verified_version = version;
268
269 if changed {
270 cell.last_updated_version = version;
271 }
272 }
273
    /// Brings the given cell up to date with the current database version.
    ///
    /// Does nothing if the cell was already verified at the current version.
    /// If the cell is stale its computation is re-run (or, when the cell's
    /// lock is already held, this call waits for the holder to finish).
    /// If it is not stale it is simply marked as verified at the current
    /// version.
    ///
    /// # Panics
    ///
    /// Panics via `check_for_cycle` if the cell's lock is already held and a
    /// dependency cycle is reachable from the cell.
    fn update_cell(&self, cell_id: Cell) {
        let last_verified_version = self.with_cell(cell_id, |data| data.last_verified_version);
        let version = self.version.load(Ordering::SeqCst);

        if last_verified_version != version {
            if self.is_stale_cell(cell_id) {
                // Clone the lock out of the cell data so no reference into
                // `self.cells` is held while the computation runs.
                let lock = self.with_cell(cell_id, |cell| cell.lock.clone());

                match lock.try_lock() {
                    Some(guard) => {
                        // Lock acquired: run the computation here, keeping
                        // the guard alive until it has finished.
                        self.run_compute_function(cell_id);
                        drop(guard);
                    }
                    None => {
                        // The lock is already held.
                        // NOTE(review): presumably by another thread that is
                        // computing this cell, or by a caller further up
                        // this stack if the cell depends on itself - confirm.
                        // Check for a cycle before blocking.
                        self.check_for_cycle(cell_id);

                        // Wait until the current holder releases the lock,
                        // then release it again immediately.
                        drop(lock.lock());
                    }
                }
            } else {
                // Still valid: only record that it was checked at this version.
                let mut cell = self.cells.get_mut(&cell_id).unwrap();
                cell.last_verified_version = version;
            }
        }
    }
306
307 fn check_for_cycle(&self, starting_cell: Cell) {
309 let mut visited = FxHashSet::default();
310 let mut path = Vec::new();
311
312 let mut stack = Vec::new();
317 stack.push(Action::Traverse(starting_cell));
318
319 enum Action {
320 Traverse(Cell),
321 Pop(Cell),
322 }
323
324 while let Some(action) = stack.pop() {
325 match action {
326 Action::Pop(expected) => assert_eq!(path.pop(), Some(expected)),
328 Action::Traverse(cell) => {
329 if path.contains(&cell) {
330 path.push(cell);
332 self.cycle_error(&path);
333 }
334
335 if visited.insert(cell) {
336 path.push(cell);
337 stack.push(Action::Pop(cell));
338 self.with_cell(cell, |cell| {
339 for dependency in cell.dependencies.iter() {
340 stack.push(Action::Traverse(*dependency));
341 }
342 });
343 }
344 }
345 }
346 }
347 }
348
349 fn cycle_error(&self, cycle: &[Cell]) {
351 let mut error = String::new();
352 for (i, cell) in cycle.iter().enumerate() {
353 error += &format!(
354 "\n {}. {}",
355 i + 1,
356 self.storage.input_debug_string(self, *cell)
357 );
358 }
359 panic!("inc-complete: Cycle Detected!\n\nCycle:{error}")
360 }
361
362 pub fn get<C: Computation>(&self, compute: C) -> C::Output
370 where
371 S: StorageFor<C>,
372 {
373 let cell_id = self.get_or_insert_cell(compute);
374 self.get_with_cell::<C>(cell_id)
375 }
376
377 pub(crate) fn get_with_cell<Concrete: Computation>(&self, cell_id: Cell) -> Concrete::Output
378 where
379 S: StorageFor<Concrete>,
380 {
381 self.update_cell(cell_id);
382
383 self.storage
384 .get_output(cell_id)
385 .expect("cell result should have been computed already")
386 }
387
388 fn with_cell<R>(&self, cell: Cell, f: impl FnOnce(&CellData) -> R) -> R {
389 f(&self.cells.get(&cell).unwrap())
390 }
391
392 pub fn get_accumulated<Item, C>(&self, compute: C) -> BTreeSet<Item>
397 where
398 S: StorageFor<C> + StorageFor<Accumulated<Item>>,
399 C: Computation,
400 Item: 'static,
401 {
402 let cell_id = self.get_or_insert_cell(compute);
403 self.update_cell(cell_id);
404 self.get(Accumulated::<Item>::new(cell_id))
405 }
406}