1use std::sync::atomic::{AtomicU32, Ordering};
2
3use crate::cell::CellData;
4use crate::storage::{ComputationId, StorageFor};
5use crate::{Cell, OutputType, Storage};
6
7mod handle;
8mod serialize;
9mod tests;
10
11pub use handle::DbHandle;
12
/// Version number a freshly constructed [`Db`] starts at.
const START_VERSION: u32 = 1;
14
/// Database that owns the bookkeeping for every computation cell together
/// with the caller-supplied storage.
pub struct Db<Storage> {
    /// Per-cell metadata ([`CellData`]) keyed by [`Cell`].
    cells: dashmap::DashMap<Cell, CellData>,
    /// Global version counter; initialised to `START_VERSION`.
    version: AtomicU32,
    /// Id handed to the next newly created cell; initialised to 0.
    next_cell: AtomicU32,
    /// Backing storage supplied by the caller.
    storage: Storage,
}
25
26impl<Storage: Default> Db<Storage> {
27 pub fn new() -> Self {
29 Self::with_storage(Storage::default())
30 }
31}
32
33impl<S: Default> Default for Db<S> {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
#[cfg(not(feature = "async"))]
/// Something that can produce the output of a computation of type `C`.
///
/// This is the blocking flavour, compiled when the `async` feature is off.
pub trait DbGet<C: OutputType> {
    /// Returns the output associated with `key`.
    fn get(&self, key: C) -> C::Output;
}
#[cfg(feature = "async")]
/// Something that can produce the output of a computation of type `C`.
///
/// This is the asynchronous flavour, compiled when the `async` feature is on.
pub trait DbGet<C: OutputType> {
    /// Returns a `Send` future that resolves to the output associated with `key`.
    fn get(&self, key: C) -> impl Future<Output = C::Output> + Send;
}
49
50#[cfg(not(feature = "async"))]
51impl<S, C> DbGet<C> for Db<S>
52where
53 C: OutputType + ComputationId,
54 S: Storage + StorageFor<C>,
55{
56 fn get(&self, key: C) -> C::Output {
57 self.get(key)
58 }
59}
60
#[cfg(feature = "async")]
impl<S, C> DbGet<C> for Db<S>
where
    C: OutputType + ComputationId,
    S: Storage + StorageFor<C> + Sync,
{
    /// Forwards to the inherent [`Db::get`] and returns its future.
    fn get(&self, key: C) -> impl Future<Output = C::Output> + Send {
        Db::<S>::get(self, key)
    }
}
71
72impl<S> Db<S> {
73 pub fn with_storage(storage: S) -> Self {
75 Self {
76 cells: Default::default(),
77 version: AtomicU32::new(START_VERSION),
78 next_cell: AtomicU32::new(0),
79 storage,
80 }
81 }
82
83 pub fn storage(&self) -> &S {
85 &self.storage
86 }
87
88 pub fn storage_mut(&mut self) -> &mut S {
93 &mut self.storage
94 }
95}
96
97impl<S: Storage> Db<S> {
98 fn get_cell<C: OutputType>(&self, computation: &C) -> Option<Cell>
102 where
103 S: StorageFor<C>,
104 {
105 self.storage.get_cell_for_computation(computation)
106 }
107
108 pub(crate) fn get_or_insert_cell<C>(&self, input: C) -> Cell
109 where
110 C: OutputType + ComputationId,
111 S: StorageFor<C>,
112 {
113 if let Some(cell) = self.get_cell(&input) {
114 cell
115 } else {
116 let computation_id = C::computation_id();
117
118 let cell_id = self.next_cell.fetch_add(1, Ordering::Relaxed);
121 let new_cell = Cell::new(cell_id);
122
123 self.cells.insert(new_cell, CellData::new(computation_id));
124 self.storage.insert_new_cell(new_cell, input);
125 new_cell
126 }
127 }
128
129 fn handle(&self, cell: Cell) -> DbHandle<S> {
130 DbHandle::new(self, cell)
131 }
132
133 #[cfg(test)]
134 #[allow(unused)]
135 pub(crate) fn with_cell_data<C: OutputType>(&self, input: &C, f: impl FnOnce(&CellData))
136 where
137 S: StorageFor<C>,
138 {
139 let cell = self
140 .get_cell(input)
141 .unwrap_or_else(|| panic!("unwrap_cell_value: Expected cell to exist"));
142
143 self.cells.get(&cell).map(|value| f(&value)).unwrap()
144 }
145
146 pub fn version(&self) -> u32 {
147 self.version.load(Ordering::SeqCst)
148 }
149}
150
151#[cfg(not(feature = "async"))]
152impl<S: Storage> Db<S> {
153 pub fn update_input<C>(&mut self, input: C, new_value: C::Output)
161 where
162 C: OutputType + ComputationId,
163 S: StorageFor<C>,
164 {
165 let cell_id = self.get_or_insert_cell(input);
166 assert!(
167 self.is_input(cell_id),
168 "`update_input` given a non-input value. Inputs must have 0 dependencies",
169 );
170
171 let changed = self.storage.update_output(cell_id, new_value);
172 let mut cell = self.cells.get_mut(&cell_id).unwrap();
173
174 if changed {
175 let version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
176 cell.last_updated_version = version;
177 cell.last_verified_version = version;
178 } else {
179 cell.last_verified_version = self.version.load(Ordering::SeqCst);
180 }
181 }
182
183 fn is_input(&self, cell: Cell) -> bool {
184 self.with_cell(cell, |cell| {
185 cell.dependencies.is_empty() && cell.input_dependencies.is_empty()
186 })
187 }
188
189 pub fn is_stale<C: OutputType>(&self, input: &C) -> bool
194 where
195 S: StorageFor<C>,
196 {
197 let Some(cell) = self.get_cell(input) else {
199 return true;
200 };
201 self.is_stale_cell(cell)
202 }
203
204 fn is_stale_cell(&self, cell: Cell) -> bool {
208 let computation_id = self.with_cell(cell, |data| data.computation_id);
209
210 if self.storage.output_is_unset(cell, computation_id) {
211 return true;
212 }
213
214 let (last_verified, inputs, dependencies) = self.with_cell(cell, |data| {
216 (
217 data.last_verified_version,
218 data.input_dependencies.clone(),
219 data.dependencies.clone(),
220 )
221 });
222
223 let inputs_changed = inputs.into_iter().any(|input_id| {
226 self.with_cell(input_id, |input| input.last_updated_version > last_verified)
229 });
230
231 inputs_changed
235 && dependencies.into_iter().any(|dependency_id| {
236 self.update_cell(dependency_id);
237 self.with_cell(dependency_id, |dependency| {
238 dependency.last_updated_version > last_verified
239 })
240 })
241 }
242
243 fn run_compute_function(&self, cell_id: Cell) {
247 let computation_id = self.with_cell(cell_id, |data| data.computation_id);
248
249 let handle = self.handle(cell_id);
250 let changed = S::run_computation(&handle, cell_id, computation_id);
251
252 let version = self.version.load(Ordering::SeqCst);
253 let mut cell = self.cells.get_mut(&cell_id).unwrap();
254 cell.last_verified_version = version;
255
256 if changed {
257 cell.last_updated_version = version;
258 }
259 }
260
261 fn update_cell(&self, cell_id: Cell) {
264 let last_verified_version = self.with_cell(cell_id, |data| data.last_verified_version);
265 let version = self.version.load(Ordering::SeqCst);
266
267 if last_verified_version != version {
268 if self.is_stale_cell(cell_id) {
270 let lock = self.with_cell(cell_id, |cell| cell.lock.clone());
271
272 match lock.try_lock() {
273 Some(guard) => {
274 self.run_compute_function(cell_id);
275 drop(guard);
276 }
277 None => {
278 drop(lock.lock());
281 }
282 }
283 } else {
284 let mut cell = self.cells.get_mut(&cell_id).unwrap();
285 cell.last_verified_version = version;
286 }
287 }
288 }
289
290 pub fn get<C: OutputType + ComputationId>(&self, compute: C) -> C::Output
298 where
299 S: StorageFor<C>,
300 {
301 let cell_id = self.get_or_insert_cell(compute);
302 self.get_with_cell::<C>(cell_id)
303 }
304
305 pub(crate) fn get_with_cell<Concrete: OutputType>(&self, cell_id: Cell) -> Concrete::Output
306 where
307 S: StorageFor<Concrete>,
308 {
309 self.update_cell(cell_id);
310
311 self.storage
312 .get_output(cell_id)
313 .expect("cell result should have been computed already")
314 }
315
316 fn with_cell<R>(&self, cell: Cell, f: impl FnOnce(&CellData) -> R) -> R {
317 f(&self.cells.get(&cell).unwrap())
318 }
319}
320
#[cfg(feature = "async")]
impl<S: Storage + Sync> Db<S> {
    /// Replaces the stored output of the input computation `input` with `new_value`.
    ///
    /// The cell for `input` is created first if it does not exist yet. When the
    /// storage reports that the value really changed, the global version is
    /// incremented and the cell is stamped as updated and verified at that new
    /// version. Otherwise only the cell's verified version is moved up to the
    /// current global version.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions enabled, panics if the cell has any
    /// recorded dependencies, because such a cell is not an input.
    pub async fn update_input<C: OutputType>(&mut self, input: C, new_value: C::Output)
    where
        C: ComputationId,
        S: StorageFor<C>,
    {
        let cell_id = self.get_or_insert_cell(input);
        debug_assert!(
            self.is_input(cell_id).await,
            "`update_input` given a non-input value. Inputs must have 0 dependencies",
        );

        let changed = self.storage.update_output(cell_id, new_value);
        // The version stamps are written below, so a mutable guard is required.
        let mut cell = self.cells.get_mut(&cell_id).unwrap();

        if changed {
            // `fetch_add` returns the previous counter value.
            let version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
            cell.last_updated_version = version;
            cell.last_verified_version = version;
        } else {
            cell.last_verified_version = self.version.load(Ordering::SeqCst);
        }
    }

    /// Resolves to true when `cell` has no dependencies recorded.
    ///
    /// NOTE(review): the non-async `is_input` additionally requires
    /// `input_dependencies` to be empty — confirm whether this check should
    /// match it.
    fn is_input(&self, cell: Cell) -> impl Future<Output = bool> {
        self.with_cell(cell, |cell| cell.dependencies.is_empty())
    }

    /// Reports whether the cached output for `input` needs to be recomputed.
    ///
    /// A computation that has no cell yet is always stale. Answering the
    /// question may bring dependencies of the computation up to date.
    pub async fn is_stale<C: OutputType>(&self, input: &C) -> bool
    where
        S: StorageFor<C>,
    {
        let Some(cell) = self.get_cell(input) else {
            return true;
        };
        self.is_stale_cell(cell).await
    }

    /// Reports whether `cell` needs to be recomputed.
    ///
    /// A cell whose output is unset is stale. Otherwise every dependency is
    /// brought up to date, in stored order, and the cell is stale as soon as
    /// one of them turns out to have been updated after the cell was last
    /// verified.
    async fn is_stale_cell(&self, cell: Cell) -> bool {
        let computation_id = self.with_cell(cell, |data| data.computation_id).await;

        if self.storage.output_is_unset(cell, computation_id) {
            return true;
        }

        let (last_verified, dependencies) = self
            .with_cell(cell, |data| {
                (data.last_verified_version, data.dependencies.clone())
            })
            .await;

        for dependency_id in dependencies {
            self.update_cell(dependency_id).await;

            // Read the stamp through `with_cell` so the map guard is released
            // before the next iteration awaits again.
            let dependency_changed = self
                .with_cell(dependency_id, |dependency| {
                    dependency.last_updated_version > last_verified
                })
                .await;
            if dependency_changed {
                return true;
            }
        }
        false
    }

    /// Runs the computation behind `cell_id` and records the outcome.
    ///
    /// The cell is marked verified at the current global version, and also
    /// marked updated at that version when the computation reports a change.
    /// The global version itself is not modified.
    async fn run_compute_function(&self, cell_id: Cell) {
        let computation_id = self.with_cell(cell_id, |data| data.computation_id).await;

        let handle = self.handle(cell_id);
        let changed = S::run_computation(&handle, cell_id, computation_id).await;

        let version = self.version.load(Ordering::SeqCst);
        // Taken only after the computation finished: no map guard is alive
        // across an await point.
        let mut cell = self.cells.get_mut(&cell_id).unwrap();
        cell.last_verified_version = version;

        if changed {
            cell.last_updated_version = version;
        }
    }

    /// Brings `cell_id` up to date with the current global version.
    ///
    /// Nothing happens when the cell was already verified at the current
    /// version. A cell that is not stale is merely re-stamped as verified; a
    /// stale cell is recomputed.
    ///
    /// NOTE(review): the non-async `update_cell` serialises recomputation
    /// through `cell.lock`; this version takes no such lock, so two tasks may
    /// run the same computation concurrently — confirm that is acceptable.
    async fn update_cell(&self, cell_id: Cell) {
        let last_verified_version = self
            .with_cell(cell_id, |data| data.last_verified_version)
            .await;
        let version = self.version.load(Ordering::SeqCst);

        if last_verified_version != version {
            // Boxed because `is_stale_cell` awaits `update_cell` in turn.
            if Box::pin(self.is_stale_cell(cell_id)).await {
                self.run_compute_function(cell_id).await;
            } else {
                self.cells.get_mut(&cell_id).unwrap().last_verified_version = version;
            }
        }
    }

    /// Returns a future resolving to the output of `compute`, refreshing its
    /// cell first.
    ///
    /// The cell is created (synchronously, before the future is returned) if
    /// it does not exist yet.
    ///
    /// # Panics
    ///
    /// The returned future panics if the storage holds no output for the cell
    /// after the refresh.
    pub fn get<C: OutputType + ComputationId>(&self, compute: C) -> impl Future<Output = C::Output>
    where
        S: StorageFor<C>,
    {
        let cell_id = self.get_or_insert_cell(compute);
        self.get_with_cell::<C>(cell_id)
    }

    /// Like [`Db::get`], but for a cell that is already known.
    ///
    /// # Panics
    ///
    /// The returned future panics if the storage holds no output for
    /// `cell_id` after the refresh.
    pub(crate) fn get_with_cell<Concrete: OutputType>(
        &self,
        cell_id: Cell,
    ) -> impl Future<Output = Concrete::Output> + Send
    where
        S: StorageFor<Concrete>,
    {
        async move {
            self.update_cell(cell_id).await;

            self.storage
                .get_output(cell_id)
                .expect("cell result should have been computed already")
        }
    }

    /// Runs `f` on the metadata of `cell` and resolves to its result.
    ///
    /// The lookup itself is synchronous; the map guard lives only for the
    /// duration of the call to `f`.
    ///
    /// # Panics
    ///
    /// Panics if `cell` has no metadata entry.
    async fn with_cell<R>(&self, cell: Cell, f: impl FnOnce(&CellData) -> R) -> R {
        let data = self.cells.get(&cell).unwrap();
        f(&data)
    }
}