1use crate::cell::CellData;
2use crate::storage::{ComputationId, StorageFor};
3use crate::{Cell, OutputType, Storage};
4use petgraph::graph::DiGraph;
5
6mod handle;
7mod tests;
8
9pub use handle::DbHandle;
10
/// Version number a freshly created `Db` starts at; it is incremented each time
/// `update_input` stores a value that storage reports as changed.
const START_VERSION: u32 = 1;
12
13#[derive(serde::Serialize, serde::Deserialize)]
18pub struct Db<Storage> {
19 cells: DiGraph<CellData, ()>,
20 version: u32,
21 storage: Storage,
22}
23
24impl<Storage: Default> Db<Storage> {
25 pub fn new() -> Self {
27 Self::with_storage(Storage::default())
28 }
29}
30
31impl<S: Default> Default for Db<S> {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl<S> Db<S> {
38 pub fn with_storage(storage: S) -> Self {
40 Self {
41 cells: DiGraph::default(),
42 version: START_VERSION,
43 storage,
44 }
45 }
46
47 pub fn storage(&self) -> &S {
49 &self.storage
50 }
51
52 pub fn storage_mut(&mut self) -> &mut S {
57 &mut self.storage
58 }
59}
60
61impl<S: Storage> Db<S> {
62 pub fn is_stale<C: OutputType>(&mut self, input: &C) -> bool
67 where
68 S: StorageFor<C>,
69 {
70 let Some(cell) = self.get_cell(input) else {
72 return true;
73 };
74 self.is_stale_cell(cell)
75 }
76
77 fn is_stale_cell(&mut self, cell: Cell) -> bool {
80 let data = &self.cells[cell.index()];
81 let computation_id = data.computation_id;
82 let last_verified = data.last_verified_version;
83
84 if self.storage.output_is_unset(cell, computation_id) {
85 return true;
86 }
87
88 let neighbors = self.cells.neighbors(cell.index()).collect::<Vec<_>>();
89
90 neighbors.into_iter().rev().any(|dependency_id| {
92 self.update_cell(Cell::new(dependency_id));
93
94 self.cells[dependency_id].last_updated_version > last_verified
97 })
98 }
99
100 fn get_cell<C: OutputType>(&self, computation: &C) -> Option<Cell>
104 where
105 S: StorageFor<C>,
106 {
107 self.storage.get_cell_for_computation(computation)
108 }
109
110 pub(crate) fn get_or_insert_cell<C>(&mut self, input: C) -> Cell
111 where
112 C: OutputType + ComputationId,
113 S: StorageFor<C>,
114 {
115 if let Some(cell) = self.get_cell(&input) {
116 cell
117 } else {
118 let computation_id = C::computation_id();
119
120 let new_id = self.cells.add_node(CellData::new(computation_id));
121 let cell = Cell::new(new_id);
122 self.storage.insert_new_cell(cell, input);
123 cell
124 }
125 }
126
127 pub fn update_input<C: OutputType>(&mut self, input: C, new_value: C::Output)
132 where
133 C: ComputationId,
134 S: StorageFor<C>,
135 {
136 let cell_id = self.get_or_insert_cell(input);
137 debug_assert!(
138 self.is_input(cell_id),
139 "`update_input` given a non-input value. Inputs must have 0 dependencies",
140 );
141
142 let changed = self.storage.update_output(cell_id, new_value);
143 let cell = &mut self.cells[cell_id.index()];
144
145 if changed {
146 self.version += 1;
147 cell.last_updated_version = self.version;
148 cell.last_verified_version = self.version;
149 } else {
150 cell.last_verified_version = self.version;
151 }
152 }
153
154 fn is_input(&self, cell: Cell) -> bool {
155 self.cells.neighbors(cell.index()).count() == 0
156 }
157
158 fn handle(&mut self, cell: Cell) -> DbHandle<S> {
159 DbHandle::new(self, cell)
160 }
161
162 #[cfg(test)]
163 pub(crate) fn unwrap_cell_value<C: OutputType>(&self, input: &C) -> &CellData
164 where
165 S: StorageFor<C>,
166 {
167 let cell = self
168 .get_cell(input)
169 .unwrap_or_else(|| panic!("unwrap_cell_value: Expected cell to exist"));
170 &self.cells[cell.index()]
171 }
172
173 fn run_compute_function(&mut self, cell_id: Cell) {
177 let cell = &self.cells[cell_id.index()];
178 let computation_id = cell.computation_id;
179
180 let mut handle = self.handle(cell_id);
181 let changed = S::run_computation(&mut handle, cell_id, computation_id);
182
183 let cell = &mut self.cells[cell_id.index()];
184 cell.last_verified_version = self.version;
185
186 if changed {
187 cell.last_updated_version = self.version;
188 }
189 }
190
191 fn update_cell(&mut self, cell_id: Cell) {
194 let cell = &self.cells[cell_id.index()];
195
196 if cell.last_verified_version != self.version {
197 if self.is_stale_cell(cell_id) {
199 self.run_compute_function(cell_id);
200 } else {
201 let cell = &mut self.cells[cell_id.index()];
202 cell.last_verified_version = self.version;
203 }
204 }
205 }
206
207 pub fn get<C: OutputType + ComputationId>(&mut self, compute: C) -> &C::Output
212 where
213 S: StorageFor<C>,
214 {
215 let cell_id = self.get_or_insert_cell(compute);
216 self.get_with_cell::<C>(cell_id)
217 }
218
219 pub(crate) fn get_with_cell<Concrete: OutputType>(&mut self, cell_id: Cell) -> &Concrete::Output
224 where
225 S: StorageFor<Concrete>,
226 {
227 self.update_cell(cell_id);
228
229 self.storage
230 .get_output(cell_id)
231 .expect("cell result should have been computed already")
232 }
233}