1use std::sync::atomic::{AtomicU32, Ordering};
2
3use crate::cell::CellData;
4use crate::storage::{ComputationId, StorageFor};
5use crate::{Cell, OutputType, Storage};
6
7mod handle;
8mod serialize;
9mod tests;
10
11pub use handle::DbHandle;
12
/// Value the global version counter of a freshly constructed [`Db`] starts at.
const START_VERSION: u32 = 1;
14
15pub struct Db<Storage> {
20 cells: scc::HashMap<Cell, CellData>,
21 version: AtomicU32,
22 next_cell: AtomicU32,
23 storage: Storage,
24}
25
26impl<Storage: Default> Db<Storage> {
27 pub fn new() -> Self {
29 Self::with_storage(Storage::default())
30 }
31}
32
33impl<S: Default> Default for Db<S> {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39#[cfg(not(feature = "async"))]
42pub trait DbGet<C: OutputType> {
43 fn get(&self, key: C) -> C::Output;
44}
/// Abstraction over "something a `C` can be queried from" (asynchronous flavour).
///
/// Only compiled when the `async` feature is enabled.
#[cfg(feature = "async")]
pub trait DbGet<C>
where
    C: OutputType,
{
    /// Returns a `Send` future resolving to the output associated with `key`.
    fn get(&self, key: C) -> impl Future<Output = C::Output> + Send;
}
49
50#[cfg(not(feature = "async"))]
51impl<S, C> DbGet<C> for Db<S> where
52 C: OutputType + ComputationId,
53 S: Storage + StorageFor<C>
54{
55 fn get(&self, key: C) -> C::Output {
56 self.get(key)
57 }
58}
59
/// Asynchronous [`DbGet`] for [`Db`]: forwards to the inherent `Db::get`.
#[cfg(feature = "async")]
impl<S, C> DbGet<C> for Db<S>
where
    S: Storage + StorageFor<C> + Sync,
    C: OutputType + ComputationId,
{
    fn get(&self, key: C) -> impl Future<Output = C::Output> + Send {
        // Explicitly the inherent method, not this trait method.
        Db::<S>::get(self, key)
    }
}
69
70impl<S> Db<S> {
71 pub fn with_storage(storage: S) -> Self {
73 Self {
74 cells: Default::default(),
75 version: AtomicU32::new(START_VERSION),
76 next_cell: AtomicU32::new(0),
77 storage,
78 }
79 }
80
81 pub fn storage(&self) -> &S {
83 &self.storage
84 }
85
86 pub fn storage_mut(&mut self) -> &mut S {
91 &mut self.storage
92 }
93}
94
95impl<S: Storage> Db<S> {
96 fn get_cell<C: OutputType>(&self, computation: &C) -> Option<Cell>
100 where
101 S: StorageFor<C>,
102 {
103 self.storage.get_cell_for_computation(computation)
104 }
105
106 pub(crate) fn get_or_insert_cell<C>(&self, input: C) -> Cell
107 where
108 C: OutputType + ComputationId,
109 S: StorageFor<C>,
110 {
111 if let Some(cell) = self.get_cell(&input) {
112 cell
113 } else {
114 let computation_id = C::computation_id();
115
116 let cell_id = self.next_cell.fetch_add(1, Ordering::Relaxed);
119 let new_cell = Cell::new(cell_id);
120
121 self.cells
122 .insert(new_cell, CellData::new(computation_id))
123 .ok();
124 self.storage.insert_new_cell(new_cell, input);
125 new_cell
126 }
127 }
128
129 fn handle(&self, cell: Cell) -> DbHandle<S> {
130 DbHandle::new(self, cell)
131 }
132
/// Test helper: returns a clone of the [`CellData`] for `input`.
///
/// # Panics
///
/// Panics if `input` has no cell, or if that cell has no entry in `cells`.
#[cfg(test)]
#[allow(unused)]
pub(crate) fn unwrap_cell_value<C: OutputType>(&self, input: &C) -> CellData
where
    S: StorageFor<C>,
{
    let cell = self
        .get_cell(input)
        .expect("unwrap_cell_value: Expected cell to exist");

    let snapshot = self.cells.read(&cell, |_cell, data| data.clone());
    snapshot.unwrap()
}
145
146 pub fn version(&self) -> u32 {
147 self.version.load(Ordering::SeqCst)
148 }
149}
150
151#[cfg(not(feature = "async"))]
152impl<S: Storage> Db<S> {
153 pub fn update_input<C>(&mut self, input: C, new_value: C::Output)
161 where
162 C: OutputType + ComputationId,
163 S: StorageFor<C>,
164 {
165 let cell_id = self.get_or_insert_cell(input);
166 assert!(
167 self.is_input(cell_id),
168 "`update_input` given a non-input value. Inputs must have 0 dependencies",
169 );
170
171 let changed = self.storage.update_output(cell_id, new_value);
172 let mut cell = self.cells.get(&cell_id).unwrap();
173
174 if changed {
175 let version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
176 cell.last_updated_version = version;
177 cell.last_verified_version = version;
178 } else {
179 cell.last_verified_version = self.version.load(Ordering::SeqCst);
180 }
181 }
182
183 fn is_input(&self, cell: Cell) -> bool {
184 self.with_cell(cell, |cell| cell.dependencies.is_empty())
185 }
186
187 pub fn is_stale<C: OutputType>(&self, input: &C) -> bool
192 where
193 S: StorageFor<C>,
194 {
195 let Some(cell) = self.get_cell(input) else {
197 return true;
198 };
199 self.is_stale_cell(cell)
200 }
201
202 fn is_stale_cell(&self, cell: Cell) -> bool {
206 let computation_id = self.with_cell(cell, |data| data.computation_id);
207
208 if self.storage.output_is_unset(cell, computation_id) {
209 return true;
210 }
211
212 let (last_verified, dependencies) = self.with_cell(cell, |data| {
214 (data.last_verified_version, data.dependencies.clone())
215 });
216
217 dependencies.iter().any(|dependency_id| {
221 self.update_cell(*dependency_id);
222
223 self.cells
226 .read(dependency_id, |_, dependency| {
227 dependency.last_updated_version > last_verified
228 })
229 .unwrap()
230 })
231 }
232
233 fn run_compute_function(&self, cell_id: Cell) {
237 let computation_id = self.with_cell(cell_id, |data| data.computation_id);
238
239 let handle = self.handle(cell_id);
240 let changed = S::run_computation(&handle, cell_id, computation_id);
241
242 let version = self.version.load(Ordering::SeqCst);
243 let mut cell = self.cells.get(&cell_id).unwrap();
244 cell.last_verified_version = version;
245
246 if changed {
247 cell.last_updated_version = version;
248 }
249 }
250
251 fn update_cell(&self, cell_id: Cell) {
254 let last_verified_version = self.with_cell(cell_id, |data| data.last_verified_version);
255 let version = self.version.load(Ordering::SeqCst);
256
257 if last_verified_version != version {
258 if self.is_stale_cell(cell_id) {
260 self.run_compute_function(cell_id);
261 } else {
262 let mut cell = self.cells.get(&cell_id).unwrap();
263 cell.last_verified_version = version;
264 }
265 }
266 }
267
268 pub fn get<C: OutputType + ComputationId>(&self, compute: C) -> C::Output
275 where
276 S: StorageFor<C>,
277 {
278 let cell_id = self.get_or_insert_cell(compute);
279 self.get_with_cell::<C>(cell_id)
280 }
281
282 pub(crate) fn get_with_cell<Concrete: OutputType>(&self, cell_id: Cell) -> Concrete::Output
284 where
285 S: StorageFor<Concrete>,
286 {
287 self.update_cell(cell_id);
288
289 self.storage
290 .get_output(cell_id)
291 .expect("cell result should have been computed already")
292 }
293
294 fn with_cell<R>(&self, cell: Cell, f: impl FnOnce(&CellData) -> R) -> R {
295 self.cells.read(&cell, |_, data| f(data)).unwrap()
296 }
297}
298
299#[cfg(feature = "async")]
300impl<S: Storage + Sync> Db<S> {
301 pub async fn update_input<C: OutputType>(&mut self, input: C, new_value: C::Output)
306 where
307 C: ComputationId,
308 S: StorageFor<C>,
309 {
310 let cell_id = self.get_or_insert_cell(input);
311 debug_assert!(
312 self.is_input(cell_id).await,
313 "`update_input` given a non-input value. Inputs must have 0 dependencies",
314 );
315
316 let changed = self.storage.update_output(cell_id, new_value);
317 let mut cell = self.cells.get(&cell_id).unwrap();
318
319 if changed {
320 let version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
321 cell.last_updated_version = version;
322 cell.last_verified_version = version;
323 } else {
324 cell.last_verified_version = self.version.load(Ordering::SeqCst);
325 }
326 }
327
328 fn is_input(&self, cell: Cell) -> impl Future<Output = bool> {
329 self.with_cell(cell, |cell| cell.dependencies.len() == 0)
330 }
331
332 pub async fn is_stale<C: OutputType>(&self, input: &C) -> bool
337 where
338 S: StorageFor<C>,
339 {
340 let Some(cell) = self.get_cell(input) else {
342 return true;
343 };
344 self.is_stale_cell(cell).await
345 }
346
347 async fn is_stale_cell(&self, cell: Cell) -> bool {
351 let computation_id = self.with_cell(cell, |data| data.computation_id).await;
352
353 if self.storage.output_is_unset(cell, computation_id) {
354 return true;
355 }
356
357 let (last_verified, dependencies) = self
359 .with_cell(cell, |data| {
360 (data.last_verified_version, data.dependencies.clone())
361 })
362 .await;
363
364 for dependency_id in dependencies {
368 self.update_cell(dependency_id).await;
369
370 if self
373 .cells
374 .read(&dependency_id, |_, dependency| {
375 dependency.last_updated_version > last_verified
376 })
377 .unwrap()
378 {
379 return true;
380 }
381 }
382 false
383 }
384
385 async fn run_compute_function(&self, cell_id: Cell) {
389 let computation_id = self.with_cell(cell_id, |data| data.computation_id).await;
390
391 let handle = self.handle(cell_id);
392 let changed = S::run_computation(&handle, cell_id, computation_id).await;
393
394 let version = self.version.load(Ordering::SeqCst);
395 let mut cell = self.cells.get_async(&cell_id).await.unwrap();
396 cell.last_verified_version = version;
397
398 if changed {
399 cell.last_updated_version = version;
400 }
401 }
402
403 async fn update_cell(&self, cell_id: Cell) {
406 let last_verified_version = self
407 .with_cell(cell_id, |data| data.last_verified_version)
408 .await;
409 let version = self.version.load(Ordering::SeqCst);
410
411 if last_verified_version != version {
412 if Box::pin(self.is_stale_cell(cell_id)).await {
414 self.run_compute_function(cell_id).await;
415 } else {
416 let mut cell = self.cells.get_async(&cell_id).await.unwrap();
417 cell.last_verified_version = version;
418 }
419 }
420 }
421
422 pub fn get<C: OutputType + ComputationId>(&self, compute: C) -> impl Future<Output = C::Output>
429 where
430 S: StorageFor<C>,
431 {
432 let cell_id = self.get_or_insert_cell(compute);
433 self.get_with_cell::<C>(cell_id)
434 }
435
436 pub(crate) fn get_with_cell<Concrete: OutputType>(
438 &self,
439 cell_id: Cell,
440 ) -> impl Future<Output = Concrete::Output> + Send
441 where
442 S: StorageFor<Concrete>,
443 {
444 async move {
445 self.update_cell(cell_id).await;
446
447 self.storage
448 .get_output(cell_id)
449 .expect("cell result should have been computed already")
450 }
451 }
452
453 async fn with_cell<R>(&self, cell: Cell, f: impl FnOnce(&CellData) -> R) -> R {
454 self.cells
455 .read_async(&cell, |_, data| f(data))
456 .await
457 .unwrap()
458 }
459}