1use std::sync::atomic::{AtomicU32, Ordering};
2
3use crate::accumulate::Accumulate;
4use crate::cell::CellData;
5use crate::storage::{ComputationId, StorageFor};
6use crate::{Cell, OutputType, Storage};
7
8mod handle;
9mod serialize;
10mod tests;
11
12pub use handle::DbHandle;
13
14const START_VERSION: u32 = 1;
15
16pub struct Db<Storage> {
21 cells: dashmap::DashMap<Cell, CellData>,
22 version: AtomicU32,
23 next_cell: AtomicU32,
24 storage: Storage,
25}
26
27impl<Storage: Default> Db<Storage> {
28 pub fn new() -> Self {
30 Self::with_storage(Storage::default())
31 }
32}
33
34impl<S: Default> Default for Db<S> {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40pub trait DbGet<C: OutputType> {
43 fn get(&self, key: C) -> C::Output;
46}
47
48impl<S, C> DbGet<C> for Db<S>
49where
50 C: OutputType + ComputationId,
51 S: Storage + StorageFor<C>,
52{
53 fn get(&self, key: C) -> C::Output {
54 self.get(key)
55 }
56}
57
58impl<S> Db<S> {
59 pub fn with_storage(storage: S) -> Self {
61 Self {
62 cells: Default::default(),
63 version: AtomicU32::new(START_VERSION),
64 next_cell: AtomicU32::new(0),
65 storage,
66 }
67 }
68
69 pub fn storage(&self) -> &S {
71 &self.storage
72 }
73
74 pub fn storage_mut(&mut self) -> &mut S {
79 &mut self.storage
80 }
81}
82
83impl<S: Storage> Db<S> {
84 fn get_cell<C: OutputType>(&self, computation: &C) -> Option<Cell>
88 where
89 S: StorageFor<C>,
90 {
91 self.storage.get_cell_for_computation(computation)
92 }
93
94 pub(crate) fn get_or_insert_cell<C>(&self, input: C) -> Cell
95 where
96 C: OutputType + ComputationId,
97 S: StorageFor<C>,
98 {
99 if let Some(cell) = self.get_cell(&input) {
100 cell
101 } else {
102 let computation_id = C::computation_id();
103
104 let cell_id = self.next_cell.fetch_add(1, Ordering::Relaxed);
107 let new_cell = Cell::new(cell_id);
108
109 self.cells.insert(new_cell, CellData::new(computation_id));
110 self.storage.insert_new_cell(new_cell, input);
111 new_cell
112 }
113 }
114
115 fn handle(&self, cell: Cell) -> DbHandle<S> {
116 DbHandle::new(self, cell)
117 }
118
/// Test-only helper: runs `f` on the `CellData` of the cell belonging to
/// `input`.
///
/// # Panics
///
/// Panics if no cell exists for `input`, or if the cell has no entry in
/// `cells`.
#[cfg(test)]
#[allow(unused)]
pub(crate) fn with_cell_data<C: OutputType>(&self, input: &C, f: impl FnOnce(&CellData))
where
    S: StorageFor<C>,
{
    let Some(cell) = self.get_cell(input) else {
        panic!("unwrap_cell_value: Expected cell to exist");
    };

    let data = self.cells.get(&cell).unwrap();
    f(&data);
}
131
132 pub fn version(&self) -> u32 {
133 self.version.load(Ordering::SeqCst)
134 }
135
136
137 pub fn gc(&mut self, version: u32) {
138 let used_cells: std::collections::HashSet<Cell> = self
139 .cells
140 .iter()
141 .filter_map(|entry| {
142 if entry.value().last_verified_version >= version {
143 Some(entry.key().clone())
144 } else {
145 None
146 }
147 })
148 .collect();
149
150 self.storage.gc(&used_cells);
151 }
152}
153
154impl<S: Storage> Db<S> {
155 pub fn update_input<C>(&mut self, input: C, new_value: C::Output)
163 where
164 C: OutputType + ComputationId,
165 S: StorageFor<C>,
166 {
167 let cell_id = self.get_or_insert_cell(input);
168 assert!(
169 self.is_input(cell_id),
170 "`update_input` given a non-input value. Inputs must have 0 dependencies",
171 );
172
173 let changed = self.storage.update_output(cell_id, new_value);
174 let mut cell = self.cells.get_mut(&cell_id).unwrap();
175
176 if changed {
177 let version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
178 cell.last_updated_version = version;
179 cell.last_verified_version = version;
180 } else {
181 cell.last_verified_version = self.version.load(Ordering::SeqCst);
182 }
183 }
184
185 fn is_input(&self, cell: Cell) -> bool {
186 self.with_cell(cell, |cell| {
187 cell.dependencies.is_empty() && cell.input_dependencies.is_empty()
188 })
189 }
190
191 pub fn is_stale<C: OutputType>(&self, input: &C) -> bool
196 where
197 S: StorageFor<C>,
198 {
199 let Some(cell) = self.get_cell(input) else {
201 return true;
202 };
203 self.is_stale_cell(cell)
204 }
205
206 fn is_stale_cell(&self, cell: Cell) -> bool {
210 let computation_id = self.with_cell(cell, |data| data.computation_id);
211
212 if self.storage.output_is_unset(cell, computation_id) {
213 return true;
214 }
215
216 let (last_verified, inputs, dependencies) = self.with_cell(cell, |data| {
218 (
219 data.last_verified_version,
220 data.input_dependencies.clone(),
221 data.dependencies.clone(),
222 )
223 });
224
225 let inputs_changed = inputs.into_iter().any(|input_id| {
228 self.with_cell(input_id, |input| input.last_updated_version > last_verified)
231 });
232
233 inputs_changed
237 && dependencies.into_iter().any(|dependency_id| {
238 self.update_cell(dependency_id);
239 self.with_cell(dependency_id, |dependency| {
240 dependency.last_updated_version > last_verified
241 })
242 })
243 }
244
245 fn run_compute_function(&self, cell_id: Cell) {
249 let computation_id = self.with_cell(cell_id, |data| data.computation_id);
250
251 let handle = self.handle(cell_id);
252 let changed = S::run_computation(&handle, cell_id, computation_id);
253
254 let version = self.version.load(Ordering::SeqCst);
255 let mut cell = self.cells.get_mut(&cell_id).unwrap();
256 cell.last_verified_version = version;
257
258 if changed {
259 cell.last_updated_version = version;
260 }
261 }
262
263 fn update_cell(&self, cell_id: Cell) {
266 let last_verified_version = self.with_cell(cell_id, |data| data.last_verified_version);
267 let version = self.version.load(Ordering::SeqCst);
268
269 if last_verified_version != version {
270 if self.is_stale_cell(cell_id) {
272 let lock = self.with_cell(cell_id, |cell| cell.lock.clone());
273
274 match lock.try_lock() {
275 Some(guard) => {
276 self.run_compute_function(cell_id);
277 drop(guard);
278 }
279 None => {
280 drop(lock.lock());
283 }
284 }
285 } else {
286 let mut cell = self.cells.get_mut(&cell_id).unwrap();
287 cell.last_verified_version = version;
288 }
289 }
290 }
291
292 pub fn get<C: OutputType + ComputationId>(&self, compute: C) -> C::Output
300 where
301 S: StorageFor<C>,
302 {
303 let cell_id = self.get_or_insert_cell(compute);
304 self.get_with_cell::<C>(cell_id)
305 }
306
307 pub(crate) fn get_with_cell<Concrete: OutputType>(&self, cell_id: Cell) -> Concrete::Output
308 where
309 S: StorageFor<Concrete>,
310 {
311 self.update_cell(cell_id);
312
313 self.storage
314 .get_output(cell_id)
315 .expect("cell result should have been computed already")
316 }
317
318 fn with_cell<R>(&self, cell: Cell, f: impl FnOnce(&CellData) -> R) -> R {
319 f(&self.cells.get(&cell).unwrap())
320 }
321
322 fn collect_all_dependencies(&self, operation: Cell) -> Vec<Cell> {
326 self.update_cell(operation);
328
329 let mut queue: Vec<_> = self.with_cell(operation, |cell| {
330 cell.dependencies.iter().copied().collect()
331 });
332
333 let mut cells = Vec::new();
334 cells.push(operation);
335
336 while let Some(dependency) = queue.pop() {
339 cells.push(dependency);
340
341 self.with_cell(dependency, |cell| {
342 for dependency in cell.dependencies.iter() {
343 queue.push(*dependency);
344 }
345 });
346 }
347
348 cells.reverse();
349 cells
350 }
351
352 pub fn get_accumulated<Container, Item, C>(&self, compute: C) -> Container where
361 Container: FromIterator<Item>,
362 S: Accumulate<Item> + StorageFor<C>,
363 C: OutputType + ComputationId
364 {
365 let cell_id = self.get_or_insert_cell(compute);
366
367 let cells = self.collect_all_dependencies(cell_id);
368 self.storage.get_accumulated(&cells)
369 }
370}