1use crate::cell::CellData;
2use crate::storage::{ComputationId, StorageFor};
3use crate::{Cell, OutputType, Storage};
4use petgraph::graph::DiGraph;
5
6mod handle;
7mod tests;
8
9pub use handle::DbHandle;
10
/// Version a freshly created [`Db`] starts at. `update_input` increments the version
/// each time an input's stored value changes.
const START_VERSION: u32 = 1;
12
13#[derive(serde::Serialize, serde::Deserialize)]
18pub struct Db<Storage> {
19 cells: DiGraph<CellData, ()>,
20 version: u32,
21 storage: Storage,
22}
23
24impl<Storage: Default> Db<Storage> {
25 pub fn new() -> Self {
27 Self::with_storage(Storage::default())
28 }
29}
30
31impl<S: Default> Default for Db<S> {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl<S> Db<S> {
38 pub fn with_storage(storage: S) -> Self {
40 Self {
41 cells: DiGraph::default(),
42 version: START_VERSION,
43 storage,
44 }
45 }
46
47 pub fn storage(&self) -> &S {
49 &self.storage
50 }
51
52 pub fn storage_mut(&mut self) -> &mut S {
57 &mut self.storage
58 }
59}
60
61impl<S: Storage> Db<S> {
62 pub fn is_stale<C: OutputType>(&self, input: &C) -> bool
67 where
68 S: StorageFor<C>,
69 {
70 let Some(cell) = self.get_cell(input) else {
72 return true;
73 };
74 self.is_stale_cell(cell)
75 }
76
77 fn is_stale_cell(&self, cell: Cell) -> bool {
80 let computation_id = self.cells[cell.index()].computation_id;
81 if self.storage.output_is_unset(cell, computation_id) {
82 return true;
83 }
84
85 let neighbors = self.cells.neighbors(cell.index()).collect::<Vec<_>>();
86
87 neighbors.into_iter().any(|dependency_id| {
89 let dependency = &self.cells[dependency_id];
90 let cell = &self.cells[cell.index()];
91
92 dependency.last_verified_version != self.version
93 || dependency.last_updated_version > cell.last_verified_version
94 })
95 }
96
97 fn get_cell<C: OutputType>(&self, computation: &C) -> Option<Cell>
101 where
102 S: StorageFor<C>,
103 {
104 self.storage.get_cell_for_computation(computation)
105 }
106
107 pub(crate) fn get_or_insert_cell<C>(&mut self, input: C) -> Cell
108 where
109 C: OutputType + ComputationId,
110 S: StorageFor<C>,
111 {
112 if let Some(cell) = self.get_cell(&input) {
113 cell
114 } else {
115 let computation_id = C::computation_id();
116
117 let new_id = self.cells.add_node(CellData::new(computation_id));
118 let cell = Cell::new(new_id);
119 self.storage.insert_new_cell(cell, input);
120 cell
121 }
122 }
123
124 pub fn update_input<C: OutputType>(&mut self, input: C, new_value: C::Output)
129 where
130 C: std::fmt::Debug + ComputationId,
131 S: StorageFor<C>,
132 {
133 let cell_id = self.get_or_insert_cell(input);
134 debug_assert!(
135 self.is_input(cell_id),
136 "`update_input` given a non-input value. Inputs must have 0 dependencies",
137 );
138
139 let changed = self.storage.update_output(cell_id, new_value);
140 let cell = &mut self.cells[cell_id.index()];
141
142 if changed {
143 self.version += 1;
144 cell.last_updated_version = self.version;
145 cell.last_verified_version = self.version;
146 } else {
147 cell.last_verified_version = self.version;
148 }
149 }
150
151 fn is_input(&self, cell: Cell) -> bool {
152 self.cells.neighbors(cell.index()).count() == 0
153 }
154
155 fn handle(&mut self, cell: Cell) -> DbHandle<S> {
156 DbHandle::new(self, cell)
157 }
158
159 #[cfg(test)]
160 pub(crate) fn unwrap_cell_value<C: OutputType>(&self, input: &C) -> &CellData
161 where
162 C: std::fmt::Debug,
163 S: StorageFor<C>,
164 {
165 let cell = self
166 .get_cell(input)
167 .unwrap_or_else(|| panic!("unwrap_cell_value: Expected cell for `{input:?}` to exist"));
168 &self.cells[cell.index()]
169 }
170
171 fn run_compute_function(&mut self, cell_id: Cell) {
175 let cell = &self.cells[cell_id.index()];
176 let computation_id = cell.computation_id;
177
178 let mut handle = self.handle(cell_id);
179 let changed = S::run_computation(&mut handle, cell_id, computation_id);
180
181 let cell = &mut self.cells[cell_id.index()];
182 cell.last_verified_version = self.version;
183
184 if changed {
185 cell.last_updated_version = self.version;
186 }
187 }
188
189 fn update_cell(&mut self, cell_id: Cell) {
192 let cell = &self.cells[cell_id.index()];
193
194 if cell.last_verified_version != self.version {
195 if self.is_stale_cell(cell_id) {
197 self.run_compute_function(cell_id);
198 } else {
199 let cell = &mut self.cells[cell_id.index()];
200 cell.last_verified_version = self.version;
201 }
202 }
203 }
204
205 pub fn get<C: OutputType + ComputationId>(&mut self, compute: C) -> &C::Output
210 where
211 S: StorageFor<C>,
212 {
213 let cell_id = self.get_or_insert_cell(compute);
214 self.get_with_cell::<C>(cell_id)
215 }
216
217 pub(crate) fn get_with_cell<Concrete: OutputType>(&mut self, cell_id: Cell) -> &Concrete::Output
222 where
223 S: StorageFor<Concrete>,
224 {
225 self.update_cell(cell_id);
226
227 self.storage
228 .get_output(cell_id)
229 .expect("cell result should have been computed already")
230 }
231}