1use std::sync::atomic::{AtomicU32, Ordering};
2
3use crate::cell::CellData;
4use crate::storage::{ComputationId, StorageFor};
5use crate::{Cell, OutputType, Storage};
6
7mod handle;
8mod serialize;
9mod tests;
10
11pub use handle::DbHandle;
12
13const START_VERSION: u32 = 1;
14
15pub struct Db<Storage> {
20 cells: dashmap::DashMap<Cell, CellData>,
21 version: AtomicU32,
22 next_cell: AtomicU32,
23 storage: Storage,
24}
25
26impl<Storage: Default> Db<Storage> {
27 pub fn new() -> Self {
29 Self::with_storage(Storage::default())
30 }
31}
32
33impl<S: Default> Default for Db<S> {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39#[cfg(not(feature = "async"))]
42pub trait DbGet<C: OutputType> {
43 fn get(&self, key: C) -> C::Output;
44}
/// Lookup interface for the output of a computation of type `C`
/// (asynchronous flavour, compiled when the `async` feature is on).
#[cfg(feature = "async")]
pub trait DbGet<C>
where
    C: OutputType,
{
    /// Returns a `Send` future resolving to the output associated with `key`.
    fn get(&self, key: C) -> impl Future<Output = C::Output> + Send;
}
49
50#[cfg(not(feature = "async"))]
51impl<S, C> DbGet<C> for Db<S>
52where
53 C: OutputType + ComputationId,
54 S: Storage + StorageFor<C>,
55{
56 fn get(&self, key: C) -> C::Output {
57 self.get(key)
58 }
59}
60
/// Implements [`DbGet`] for `Db` by delegating to the inherent async `Db::get`.
#[cfg(feature = "async")]
impl<S, C> DbGet<C> for Db<S>
where
    S: Storage + StorageFor<C> + Sync,
    C: OutputType + ComputationId,
{
    fn get(&self, key: C) -> impl Future<Output = C::Output> + Send {
        // Forward to the inherent method; the future it returns is passed on unchanged.
        Db::<S>::get(self, key)
    }
}
71
72impl<S> Db<S> {
73 pub fn with_storage(storage: S) -> Self {
75 Self {
76 cells: Default::default(),
77 version: AtomicU32::new(START_VERSION),
78 next_cell: AtomicU32::new(0),
79 storage,
80 }
81 }
82
83 pub fn storage(&self) -> &S {
85 &self.storage
86 }
87
88 pub fn storage_mut(&mut self) -> &mut S {
93 &mut self.storage
94 }
95}
96
97impl<S: Storage> Db<S> {
98 fn get_cell<C: OutputType>(&self, computation: &C) -> Option<Cell>
102 where
103 S: StorageFor<C>,
104 {
105 self.storage.get_cell_for_computation(computation)
106 }
107
108 pub(crate) fn get_or_insert_cell<C>(&self, input: C) -> Cell
109 where
110 C: OutputType + ComputationId,
111 S: StorageFor<C>,
112 {
113 if let Some(cell) = self.get_cell(&input) {
114 cell
115 } else {
116 let computation_id = C::computation_id();
117
118 let cell_id = self.next_cell.fetch_add(1, Ordering::Relaxed);
121 let new_cell = Cell::new(cell_id);
122
123 self.cells.insert(new_cell, CellData::new(computation_id));
124 self.storage.insert_new_cell(new_cell, input);
125 new_cell
126 }
127 }
128
129 fn handle(&self, cell: Cell) -> DbHandle<S> {
130 DbHandle::new(self, cell)
131 }
132
133 #[cfg(test)]
134 #[allow(unused)]
135 pub(crate) fn with_cell_data<C: OutputType>(&self, input: &C, f: impl FnOnce(&CellData))
136 where
137 S: StorageFor<C>,
138 {
139 let cell = self
140 .get_cell(input)
141 .unwrap_or_else(|| panic!("unwrap_cell_value: Expected cell to exist"));
142
143 self.cells.get(&cell).map(|value| f(&value)).unwrap()
144 }
145
146 pub fn version(&self) -> u32 {
147 self.version.load(Ordering::SeqCst)
148 }
149}
150
151#[cfg(not(feature = "async"))]
152impl<S: Storage> Db<S> {
153 pub fn update_input<C>(&mut self, input: C, new_value: C::Output)
161 where
162 C: OutputType + ComputationId,
163 S: StorageFor<C>,
164 {
165 let cell_id = self.get_or_insert_cell(input);
166 assert!(
167 self.is_input(cell_id),
168 "`update_input` given a non-input value. Inputs must have 0 dependencies",
169 );
170
171 let changed = self.storage.update_output(cell_id, new_value);
172 let mut cell = self.cells.get_mut(&cell_id).unwrap();
173
174 if changed {
175 let version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
176 cell.last_updated_version = version;
177 cell.last_verified_version = version;
178 } else {
179 cell.last_verified_version = self.version.load(Ordering::SeqCst);
180 }
181 }
182
183 fn is_input(&self, cell: Cell) -> bool {
184 self.with_cell(cell, |cell| cell.dependencies.is_empty() && cell.input_dependencies.is_empty())
185 }
186
187 pub fn is_stale<C: OutputType>(&self, input: &C) -> bool
192 where
193 S: StorageFor<C>,
194 {
195 let Some(cell) = self.get_cell(input) else {
197 return true;
198 };
199 self.is_stale_cell(cell)
200 }
201
202 fn is_stale_cell(&self, cell: Cell) -> bool {
206 let computation_id = self.with_cell(cell, |data| data.computation_id);
207
208 if self.storage.output_is_unset(cell, computation_id) {
209 return true;
210 }
211
212 let (last_verified, inputs, dependencies) = self.with_cell(cell, |data| {
214 (data.last_verified_version, data.input_dependencies.clone(), data.dependencies.clone())
215 });
216
217 let inputs_changed = inputs.into_iter().any(|input_id| {
220 self.with_cell(input_id, |input| {
223 input.last_updated_version > last_verified
224 })
225 });
226
227 inputs_changed && dependencies.into_iter().any(|dependency_id| {
231 self.update_cell(dependency_id);
232 self.with_cell(dependency_id, |dependency| {
233 dependency.last_updated_version > last_verified
234 })
235 })
236 }
237
238 fn run_compute_function(&self, cell_id: Cell) {
242 let computation_id = self.with_cell(cell_id, |data| data.computation_id);
243
244 let handle = self.handle(cell_id);
245 let changed = S::run_computation(&handle, cell_id, computation_id);
246
247 let version = self.version.load(Ordering::SeqCst);
248 let mut cell = self.cells.get_mut(&cell_id).unwrap();
249 cell.last_verified_version = version;
250
251 if changed {
252 cell.last_updated_version = version;
253 }
254 }
255
256 fn update_cell(&self, cell_id: Cell) {
259 let last_verified_version = self.with_cell(cell_id, |data| data.last_verified_version);
260 let version = self.version.load(Ordering::SeqCst);
261
262 if last_verified_version != version {
263 if self.is_stale_cell(cell_id) {
265 let lock = self.with_cell(cell_id, |cell| cell.lock.clone());
266
267 match lock.try_lock() {
268 Some(guard) => {
269 self.run_compute_function(cell_id);
270 drop(guard);
271 }
272 None => {
273 drop(lock.lock());
276 }
277 }
278 } else {
279 let mut cell = self.cells.get_mut(&cell_id).unwrap();
280 cell.last_verified_version = version;
281 }
282 }
283 }
284
285 pub fn get<C: OutputType + ComputationId>(&self, compute: C) -> C::Output
293 where
294 S: StorageFor<C>,
295 {
296 let cell_id = self.get_or_insert_cell(compute);
297 self.get_with_cell::<C>(cell_id)
298 }
299
300 pub(crate) fn get_with_cell<Concrete: OutputType>(&self, cell_id: Cell) -> Concrete::Output
301 where
302 S: StorageFor<Concrete>,
303 {
304 self.update_cell(cell_id);
305
306 self.storage
307 .get_output(cell_id)
308 .expect("cell result should have been computed already")
309 }
310
311 fn with_cell<R>(&self, cell: Cell, f: impl FnOnce(&CellData) -> R) -> R {
312 f(&self.cells.get(&cell).unwrap())
313 }
314}
315
#[cfg(feature = "async")]
impl<S: Storage + Sync> Db<S> {
    /// Stores `new_value` as the output of the input computation `input`.
    ///
    /// The cell for `input` is created if it does not exist yet. When the storage
    /// reports that the value changed, the database version is incremented and the
    /// cell is stamped as updated and verified at the new version. Otherwise the
    /// cell is only stamped as verified at the current version.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions, panics if the cell is not an input.
    pub async fn update_input<C: OutputType>(&mut self, input: C, new_value: C::Output)
    where
        C: ComputationId,
        S: StorageFor<C>,
    {
        let cell_id = self.get_or_insert_cell(input);
        debug_assert!(
            self.is_input(cell_id).await,
            "`update_input` given a non-input value. Inputs must have 0 dependencies",
        );

        let changed = self.storage.update_output(cell_id, new_value);
        // `cells` is a `DashMap`: `get` only yields a shared guard, so the write
        // below needs the exclusive guard from `get_mut`.
        let mut cell = self.cells.get_mut(&cell_id).unwrap();

        if changed {
            // `fetch_add` returns the previous value, so `+ 1` yields the new version.
            let version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
            cell.last_updated_version = version;
            cell.last_verified_version = version;
        } else {
            cell.last_verified_version = self.version.load(Ordering::SeqCst);
        }
    }

    /// Resolves to true when the cell has neither dependencies nor input
    /// dependencies recorded (same criterion as the synchronous implementation).
    fn is_input(&self, cell: Cell) -> impl Future<Output = bool> {
        self.with_cell(cell, |cell| {
            cell.dependencies.is_empty() && cell.input_dependencies.is_empty()
        })
    }

    /// Reports whether the computation `input` needs to be (re)computed.
    ///
    /// A computation without a cell is always stale. Checking an existing cell
    /// may update its dependencies along the way.
    pub async fn is_stale<C: OutputType>(&self, input: &C) -> bool
    where
        S: StorageFor<C>,
    {
        let Some(cell) = self.get_cell(input) else {
            return true;
        };
        self.is_stale_cell(cell).await
    }

    /// Reports whether `cell` needs to be (re)computed.
    ///
    /// A cell is stale when its output is unset, or when, after updating them in
    /// their stored order, any dependency was updated after this cell was last verified.
    async fn is_stale_cell(&self, cell: Cell) -> bool {
        let computation_id = self.with_cell(cell, |data| data.computation_id).await;

        if self.storage.output_is_unset(cell, computation_id) {
            return true;
        }

        // Snapshot the data so no map guard is held while dependencies are updated.
        let (last_verified, dependencies) = self
            .with_cell(cell, |data| {
                (data.last_verified_version, data.dependencies.clone())
            })
            .await;

        for dependency_id in dependencies {
            self.update_cell(dependency_id).await;

            let dependency_changed = self
                .with_cell(dependency_id, |dependency| {
                    dependency.last_updated_version > last_verified
                })
                .await;

            if dependency_changed {
                return true;
            }
        }
        false
    }

    /// Runs the computation behind `cell_id` and stamps the cell with the current version.
    ///
    /// The cell is always marked as verified at the current version; it is marked as
    /// updated only when the computation reports a change. The version counter itself
    /// is not modified.
    async fn run_compute_function(&self, cell_id: Cell) {
        let computation_id = self.with_cell(cell_id, |data| data.computation_id).await;

        let handle = self.handle(cell_id);
        let changed = S::run_computation(&handle, cell_id, computation_id).await;

        let version = self.version.load(Ordering::SeqCst);
        // Taken after the last `.await`, so the map guard is never held across a suspension point.
        let mut cell = self.cells.get_mut(&cell_id).unwrap();
        cell.last_verified_version = version;

        if changed {
            cell.last_updated_version = version;
        }
    }

    /// Brings `cell_id` up to date with the current version.
    ///
    /// Does nothing when the cell was already verified at the current version.
    /// Otherwise the cell is either recomputed (when stale) or simply re-stamped
    /// as verified.
    async fn update_cell(&self, cell_id: Cell) {
        let last_verified_version = self
            .with_cell(cell_id, |data| data.last_verified_version)
            .await;
        let version = self.version.load(Ordering::SeqCst);

        if last_verified_version != version {
            // `is_stale_cell` and `update_cell` call each other; boxing the future
            // gives the recursion a finite size.
            if Box::pin(self.is_stale_cell(cell_id)).await {
                self.run_compute_function(cell_id).await;
            } else {
                let mut cell = self.cells.get_mut(&cell_id).unwrap();
                cell.last_verified_version = version;
            }
        }
    }

    /// Returns a future resolving to the output of `compute`.
    ///
    /// The cell for `compute` is looked up or created immediately, before the
    /// returned future is polled.
    pub fn get<C: OutputType + ComputationId>(&self, compute: C) -> impl Future<Output = C::Output>
    where
        S: StorageFor<C>,
    {
        let cell_id = self.get_or_insert_cell(compute);
        self.get_with_cell::<C>(cell_id)
    }

    /// Returns a future that brings `cell_id` up to date and then yields its stored output.
    ///
    /// The future panics if the storage has no output for the cell after the update.
    pub(crate) fn get_with_cell<Concrete: OutputType>(
        &self,
        cell_id: Cell,
    ) -> impl Future<Output = Concrete::Output> + Send
    where
        S: StorageFor<Concrete>,
    {
        async move {
            self.update_cell(cell_id).await;

            self.storage
                .get_output(cell_id)
                .expect("cell result should have been computed already")
        }
    }

    /// Runs `f` on the bookkeeping data of `cell` and returns its result.
    ///
    /// Kept `async` so existing `.await` call sites stay valid; the lookup itself
    /// is synchronous and the map guard is released before the future completes.
    ///
    /// Panics if the cell is not present in the map.
    async fn with_cell<R>(&self, cell: Cell, f: impl FnOnce(&CellData) -> R) -> R {
        f(&self.cells.get(&cell).unwrap())
    }
}