1use crate::cell::CellData;
2use crate::storage::{ComputationId, StorageFor};
3use crate::{Cell, OutputType, Storage};
4use petgraph::graph::DiGraph;
5
6mod handle;
7mod tests;
8
9pub use handle::DbHandle;
10
/// Value that `Db::version` is initialised to by `Db::with_storage`.
///
/// NOTE(review): presumably kept non-zero so it differs from whatever version
/// `CellData::new` stores in a fresh cell — confirm against `CellData::new`.
const START_VERSION: u32 = 1;
12
13#[derive(serde::Serialize, serde::Deserialize)]
18pub struct Db<Storage> {
19 cells: DiGraph<CellData, ()>,
20 version: u32,
21 storage: Storage,
22}
23
24impl<Storage: Default> Db<Storage> {
25 pub fn new() -> Self {
27 Self::with_storage(Storage::default())
28 }
29}
30
31impl<S> Db<S> {
32 pub fn with_storage(storage: S) -> Self {
34 Self {
35 cells: DiGraph::default(),
36 version: START_VERSION,
37 storage,
38 }
39 }
40
41 pub fn storage(&self) -> &S {
43 &self.storage
44 }
45
46 pub fn storage_mut(&mut self) -> &mut S {
51 &mut self.storage
52 }
53}
54
55impl<S: Storage> Db<S> {
56 pub fn is_stale<C: OutputType>(&self, input: &C) -> bool
61 where
62 S: StorageFor<C>,
63 {
64 let Some(cell) = self.get_cell(input) else {
66 return true;
67 };
68 self.is_stale_cell(cell)
69 }
70
71 fn is_stale_cell(&self, cell: Cell) -> bool {
74 let computation_id = self.cells[cell.index()].computation_id;
75 if self.storage.output_is_unset(cell, computation_id) {
76 return true;
77 }
78
79 let neighbors = self.cells.neighbors(cell.index()).collect::<Vec<_>>();
80
81 neighbors.into_iter().any(|dependency_id| {
83 let dependency = &self.cells[dependency_id];
84 let cell = &self.cells[cell.index()];
85
86 dependency.last_verified_version != self.version
87 || dependency.last_updated_version > cell.last_verified_version
88 })
89 }
90
91 fn get_cell<C: OutputType>(&self, computation: &C) -> Option<Cell>
95 where
96 S: StorageFor<C>,
97 {
98 self.storage.get_cell_for_computation(computation)
99 }
100
101 pub(crate) fn get_or_insert_cell<C>(&mut self, input: C) -> Cell
102 where
103 C: OutputType + ComputationId,
104 S: StorageFor<C>,
105 {
106 if let Some(cell) = self.get_cell(&input) {
107 cell
108 } else {
109 let computation_id = C::computation_id();
110
111 let new_id = self.cells.add_node(CellData::new(computation_id));
112 let cell = Cell::new(new_id);
113 self.storage.insert_new_cell(cell, input);
114 cell
115 }
116 }
117
118 pub fn update_input<C: OutputType>(&mut self, input: C, new_value: C::Output)
123 where
124 C: std::fmt::Debug + ComputationId,
125 S: StorageFor<C>,
126 {
127 let cell_id = self.get_or_insert_cell(input);
128 debug_assert!(
129 self.is_input(cell_id),
130 "`update_input` given a non-input value. Inputs must have 0 dependencies",
131 );
132
133 let changed = self.storage.update_output(cell_id, new_value);
134 let cell = &mut self.cells[cell_id.index()];
135
136 if changed {
137 self.version += 1;
138 cell.last_updated_version = self.version;
139 cell.last_verified_version = self.version;
140 } else {
141 cell.last_verified_version = self.version;
142 }
143 }
144
145 fn is_input(&self, cell: Cell) -> bool {
146 self.cells.neighbors(cell.index()).count() == 0
147 }
148
149 fn handle(&mut self, cell: Cell) -> DbHandle<S> {
150 DbHandle::new(self, cell)
151 }
152
153 #[cfg(test)]
154 pub(crate) fn unwrap_cell_value<C: OutputType>(&self, input: &C) -> &CellData
155 where
156 C: std::fmt::Debug,
157 S: StorageFor<C>,
158 {
159 let cell = self
160 .get_cell(input)
161 .unwrap_or_else(|| panic!("unwrap_cell_value: Expected cell for `{input:?}` to exist"));
162 &self.cells[cell.index()]
163 }
164
165 fn run_compute_function(&mut self, cell_id: Cell) {
169 let cell = &self.cells[cell_id.index()];
170 let computation_id = cell.computation_id;
171
172 let mut handle = self.handle(cell_id);
173 let changed = S::run_computation(&mut handle, cell_id, computation_id);
174
175 let cell = &mut self.cells[cell_id.index()];
176 cell.last_verified_version = self.version;
177
178 if changed {
179 cell.last_updated_version = self.version;
180 }
181 }
182
183 fn update_cell(&mut self, cell_id: Cell) {
186 let cell = &self.cells[cell_id.index()];
187
188 if cell.last_verified_version != self.version {
189 if self.is_stale_cell(cell_id) {
191 self.run_compute_function(cell_id);
192 } else {
193 let cell = &mut self.cells[cell_id.index()];
194 cell.last_verified_version = self.version;
195 }
196 }
197 }
198
199 pub fn get<C: OutputType + ComputationId>(&mut self, compute: C) -> &C::Output
204 where
205 S: StorageFor<C>,
206 {
207 let cell_id = self.get_or_insert_cell(compute);
208 self.get_with_cell::<C>(cell_id)
209 }
210
211 pub(crate) fn get_with_cell<Concrete: OutputType>(&mut self, cell_id: Cell) -> &Concrete::Output
216 where
217 S: StorageFor<Concrete>,
218 {
219 self.update_cell(cell_id);
220
221 self.storage
222 .get_output(cell_id)
223 .expect("cell result should have been computed already")
224 }
225}