1use std::sync::atomic::{AtomicU32, Ordering};
2
3use crate::cell::CellData;
4use crate::storage::{ComputationId, StorageFor};
5use crate::{Cell, OutputType, Storage};
6
7mod handle;
8mod serialize;
9mod tests;
10
11pub use handle::DbHandle;
12
/// Initial value of the database's global version counter (`Db::version`).
///
/// NOTE(review): presumably kept above 0 so that freshly created cells (whose
/// `last_verified_version` likely starts at 0 in `CellData::new`) never compare as
/// already verified at the current version — confirm against `CellData::new`.
const START_VERSION: u32 = 1;
14
15pub struct Db<Storage> {
20 cells: dashmap::DashMap<Cell, CellData>,
21 version: AtomicU32,
22 next_cell: AtomicU32,
23 storage: Storage,
24}
25
26impl<Storage: Default> Db<Storage> {
27 pub fn new() -> Self {
29 Self::with_storage(Storage::default())
30 }
31}
32
33impl<S: Default> Default for Db<S> {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39#[cfg(not(feature = "async"))]
42pub trait DbGet<C: OutputType> {
43 fn get(&self, key: C) -> C::Output;
44}
45#[cfg(feature = "async")]
46pub trait DbGet<C: OutputType> {
47 fn get(&self, key: C) -> impl Future<Output = C::Output> + Send;
48}
49
50#[cfg(not(feature = "async"))]
51impl<S, C> DbGet<C> for Db<S>
52where
53 C: OutputType + ComputationId,
54 S: Storage + StorageFor<C>,
55{
56 fn get(&self, key: C) -> C::Output {
57 self.get(key)
58 }
59}
60
/// Async [`DbGet`] for [`Db`]: forwards to the inherent async `Db::get`.
#[cfg(feature = "async")]
impl<S, C> DbGet<C> for Db<S>
where
    S: Storage + StorageFor<C> + Sync,
    C: OutputType + ComputationId,
{
    fn get(&self, key: C) -> impl Future<Output = C::Output> + Send {
        // Explicit path: inherent associated functions take precedence over trait
        // methods, so this calls `Db::get`, not this trait method recursively.
        Db::<S>::get(self, key)
    }
}
71
72impl<S> Db<S> {
73 pub fn with_storage(storage: S) -> Self {
75 Self {
76 cells: Default::default(),
77 version: AtomicU32::new(START_VERSION),
78 next_cell: AtomicU32::new(0),
79 storage,
80 }
81 }
82
83 pub fn storage(&self) -> &S {
85 &self.storage
86 }
87
88 pub fn storage_mut(&mut self) -> &mut S {
93 &mut self.storage
94 }
95}
96
97impl<S: Storage> Db<S> {
98 fn get_cell<C: OutputType>(&self, computation: &C) -> Option<Cell>
102 where
103 S: StorageFor<C>,
104 {
105 self.storage.get_cell_for_computation(computation)
106 }
107
108 pub(crate) fn get_or_insert_cell<C>(&self, input: C) -> Cell
109 where
110 C: OutputType + ComputationId,
111 S: StorageFor<C>,
112 {
113 if let Some(cell) = self.get_cell(&input) {
114 cell
115 } else {
116 let computation_id = C::computation_id();
117
118 let cell_id = self.next_cell.fetch_add(1, Ordering::Relaxed);
121 let new_cell = Cell::new(cell_id);
122
123 self.cells.insert(new_cell, CellData::new(computation_id));
124 self.storage.insert_new_cell(new_cell, input);
125 new_cell
126 }
127 }
128
129 fn handle(&self, cell: Cell) -> DbHandle<S> {
130 DbHandle::new(self, cell)
131 }
132
133 #[cfg(test)]
134 #[allow(unused)]
135 pub(crate) fn with_cell_data<C: OutputType>(&self, input: &C, f: impl FnOnce(&CellData))
136 where
137 S: StorageFor<C>,
138 {
139 let cell = self
140 .get_cell(input)
141 .unwrap_or_else(|| panic!("unwrap_cell_value: Expected cell to exist"));
142
143 self.cells.get(&cell).map(|value| f(&value)).unwrap()
144 }
145
146 pub fn version(&self) -> u32 {
147 self.version.load(Ordering::SeqCst)
148 }
149}
150
151#[cfg(not(feature = "async"))]
152impl<S: Storage> Db<S> {
153 pub fn update_input<C>(&mut self, input: C, new_value: C::Output)
161 where
162 C: OutputType + ComputationId,
163 S: StorageFor<C>,
164 {
165 let cell_id = self.get_or_insert_cell(input);
166 assert!(
167 self.is_input(cell_id),
168 "`update_input` given a non-input value. Inputs must have 0 dependencies",
169 );
170
171 let changed = self.storage.update_output(cell_id, new_value);
172 let mut cell = self.cells.get_mut(&cell_id).unwrap();
173
174 if changed {
175 let version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
176 cell.last_updated_version = version;
177 cell.last_verified_version = version;
178 } else {
179 cell.last_verified_version = self.version.load(Ordering::SeqCst);
180 }
181 }
182
183 fn is_input(&self, cell: Cell) -> bool {
184 self.with_cell(cell, |cell| cell.dependencies.is_empty())
185 }
186
187 pub fn is_stale<C: OutputType>(&self, input: &C) -> bool
192 where
193 S: StorageFor<C>,
194 {
195 let Some(cell) = self.get_cell(input) else {
197 return true;
198 };
199 self.is_stale_cell(cell)
200 }
201
202 fn is_stale_cell(&self, cell: Cell) -> bool {
206 let computation_id = self.with_cell(cell, |data| data.computation_id);
207
208 if self.storage.output_is_unset(cell, computation_id) {
209 return true;
210 }
211
212 let (last_verified, dependencies) = self.with_cell(cell, |data| {
214 (data.last_verified_version, data.dependencies.clone())
215 });
216
217 dependencies.iter().any(|dependency_id| {
221 self.update_cell(*dependency_id);
222
223 self.with_cell(*dependency_id, |dependency| {
226 dependency.last_updated_version > last_verified
227 })
228 })
229 }
230
231 fn run_compute_function(&self, cell_id: Cell) {
235 let computation_id = self.with_cell(cell_id, |data| data.computation_id);
236
237 let handle = self.handle(cell_id);
238 let changed = S::run_computation(&handle, cell_id, computation_id);
239
240 let version = self.version.load(Ordering::SeqCst);
241 let mut cell = self.cells.get_mut(&cell_id).unwrap();
242 cell.last_verified_version = version;
243
244 if changed {
245 cell.last_updated_version = version;
246 }
247 }
248
249 fn update_cell(&self, cell_id: Cell) {
252 let last_verified_version = self.with_cell(cell_id, |data| data.last_verified_version);
253 let version = self.version.load(Ordering::SeqCst);
254
255 if last_verified_version != version {
256 if self.is_stale_cell(cell_id) {
258 let lock = self.with_cell(cell_id, |cell| cell.lock.clone());
259
260 match lock.try_lock() {
261 Some(guard) => {
262 self.run_compute_function(cell_id);
263 drop(guard);
264 }
265 None => {
266 drop(lock.lock());
269 }
270 }
271 } else {
272 let mut cell = self.cells.get_mut(&cell_id).unwrap();
273 cell.last_verified_version = version;
274 }
275 }
276 }
277
278 pub fn get<C: OutputType + ComputationId>(&self, compute: C) -> C::Output
286 where
287 S: StorageFor<C>,
288 {
289 let cell_id = self.get_or_insert_cell(compute);
290 self.get_with_cell::<C>(cell_id)
291 }
292
293 pub(crate) fn get_with_cell<Concrete: OutputType>(&self, cell_id: Cell) -> Concrete::Output
294 where
295 S: StorageFor<Concrete>,
296 {
297 self.update_cell(cell_id);
298
299 self.storage
300 .get_output(cell_id)
301 .expect("cell result should have been computed already")
302 }
303
304 fn with_cell<R>(&self, cell: Cell, f: impl FnOnce(&CellData) -> R) -> R {
305 f(&self.cells.get(&cell).unwrap())
306 }
307}
308
#[cfg(feature = "async")]
impl<S: Storage + Sync> Db<S> {
    /// Sets the value of the input computation `input` to `new_value`.
    ///
    /// If storage reports that the stored value changed, the global version is
    /// incremented and the cell is marked as updated and verified at the new version.
    /// Otherwise the cell is only marked as verified at the current version.
    ///
    /// # Panics
    ///
    /// Panics if the cell for `input` has any recorded dependencies (it is not an input).
    /// Like the non-async `update_input`, this is checked in every build, not only with
    /// debug assertions enabled.
    pub async fn update_input<C: OutputType>(&mut self, input: C, new_value: C::Output)
    where
        C: ComputationId,
        S: StorageFor<C>,
    {
        let cell_id = self.get_or_insert_cell(input);
        assert!(
            self.is_input(cell_id).await,
            "`update_input` given a non-input value. Inputs must have 0 dependencies",
        );

        let changed = self.storage.update_output(cell_id, new_value);
        // Use the async accessor, as the rest of this impl does, rather than the
        // blocking `get`.
        let mut cell = self.cells.get_async(&cell_id).await.unwrap();

        if changed {
            let version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
            cell.last_updated_version = version;
            cell.last_verified_version = version;
        } else {
            cell.last_verified_version = self.version.load(Ordering::SeqCst);
        }
    }

    /// Returns whether `cell` has no recorded dependencies.
    fn is_input(&self, cell: Cell) -> impl Future<Output = bool> {
        self.with_cell(cell, |cell| cell.dependencies.is_empty())
    }

    /// Returns `true` if `input` has no cell yet, or its cell is stale.
    ///
    /// Checking staleness brings the cell's dependencies up to date first, so this may
    /// recompute dependencies.
    pub async fn is_stale<C: OutputType>(&self, input: &C) -> bool
    where
        S: StorageFor<C>,
    {
        let Some(cell) = self.get_cell(input) else {
            return true;
        };
        self.is_stale_cell(cell).await
    }

    /// Returns `true` if `cell`'s output is unset, or some dependency was updated after
    /// `cell` was last verified.
    ///
    /// Each dependency is brought up to date before being compared. The scan stops at
    /// the first dependency found to be newer, so later dependencies are not updated
    /// here.
    async fn is_stale_cell(&self, cell: Cell) -> bool {
        let computation_id = self.with_cell(cell, |data| data.computation_id).await;

        if self.storage.output_is_unset(cell, computation_id) {
            return true;
        }

        let (last_verified, dependencies) = self
            .with_cell(cell, |data| {
                (data.last_verified_version, data.dependencies.clone())
            })
            .await;

        for dependency_id in dependencies {
            self.update_cell(dependency_id).await;

            // Read through the async accessor instead of the blocking `read`.
            let updated_since_verified = self
                .with_cell(dependency_id, |dependency| {
                    dependency.last_updated_version > last_verified
                })
                .await;
            if updated_since_verified {
                return true;
            }
        }
        false
    }

    /// Runs `cell_id`'s computation and records the result versions.
    ///
    /// The cell is marked as verified at the current version, and also as updated
    /// when the computation reports that its output changed.
    async fn run_compute_function(&self, cell_id: Cell) {
        let computation_id = self.with_cell(cell_id, |data| data.computation_id).await;

        let handle = self.handle(cell_id);
        let changed = S::run_computation(&handle, cell_id, computation_id).await;

        let version = self.version.load(Ordering::SeqCst);
        let mut cell = self.cells.get_async(&cell_id).await.unwrap();
        cell.last_verified_version = version;

        if changed {
            cell.last_updated_version = version;
        }
    }

    /// Brings `cell_id` up to date with the current global version.
    ///
    /// A cell already verified at the current version is left alone. Otherwise a stale
    /// cell is recomputed, and a non-stale cell is just marked as verified.
    ///
    /// NOTE(review): unlike the non-async `update_cell`, no per-cell lock is taken
    /// here, so concurrent tasks may recompute the same stale cell more than once —
    /// confirm that this is acceptable.
    async fn update_cell(&self, cell_id: Cell) {
        let last_verified_version = self
            .with_cell(cell_id, |data| data.last_verified_version)
            .await;
        let version = self.version.load(Ordering::SeqCst);

        if last_verified_version != version {
            // Boxed because `is_stale_cell` and `update_cell` call each other
            // recursively, and recursive async calls need indirection.
            if Box::pin(self.is_stale_cell(cell_id)).await {
                self.run_compute_function(cell_id).await;
            } else {
                let mut cell = self.cells.get_async(&cell_id).await.unwrap();
                cell.last_verified_version = version;
            }
        }
    }

    /// Returns the output of `compute`, creating its cell if needed and recomputing it
    /// if stale.
    pub fn get<C: OutputType + ComputationId>(&self, compute: C) -> impl Future<Output = C::Output>
    where
        S: StorageFor<C>,
    {
        let cell_id = self.get_or_insert_cell(compute);
        self.get_with_cell::<C>(cell_id)
    }

    /// Brings `cell_id` up to date and returns its output from storage.
    ///
    /// # Panics
    ///
    /// Panics if storage has no output for the cell after updating it.
    pub(crate) fn get_with_cell<Concrete: OutputType>(
        &self,
        cell_id: Cell,
    ) -> impl Future<Output = Concrete::Output> + Send
    where
        S: StorageFor<Concrete>,
    {
        async move {
            self.update_cell(cell_id).await;

            self.storage
                .get_output(cell_id)
                .expect("cell result should have been computed already")
        }
    }

    /// Calls `f` with `cell`'s bookkeeping data using the async map reader.
    ///
    /// # Panics
    ///
    /// Panics if `cell` has no entry in `cells`.
    async fn with_cell<R>(&self, cell: Cell, f: impl FnOnce(&CellData) -> R) -> R {
        self.cells
            .read_async(&cell, |_, data| f(data))
            .await
            .unwrap()
    }
}