lumina_node/store/
in_memory_store.rs1use std::collections::hash_map::Entry;
2use std::collections::HashMap;
3use std::fmt::Display;
4use std::pin::pin;
5
6use async_trait::async_trait;
7use celestia_types::hash::Hash;
8use celestia_types::ExtendedHeader;
9use cid::Cid;
10use tokio::sync::{Notify, RwLock};
11use tracing::debug;
12
13use crate::block_ranges::BlockRanges;
14use crate::store::utils::VerifiedExtendedHeaders;
15use crate::store::{Result, SamplingMetadata, Store, StoreError, StoreInsertionError};
16
17#[derive(Debug)]
19pub struct InMemoryStore {
20 inner: RwLock<InMemoryStoreInner>,
22 header_added_notifier: Notify,
24}
25
26#[derive(Debug, Clone)]
27struct InMemoryStoreInner {
28 headers: HashMap<Hash, ExtendedHeader>,
30 height_to_hash: HashMap<u64, Hash>,
32 header_ranges: BlockRanges,
34 sampling_data: HashMap<u64, SamplingMetadata>,
36 sampled_ranges: BlockRanges,
38 pruned_ranges: BlockRanges,
40}
41
42impl InMemoryStoreInner {
43 fn new() -> Self {
44 Self {
45 headers: HashMap::new(),
46 height_to_hash: HashMap::new(),
47 header_ranges: BlockRanges::default(),
48 sampling_data: HashMap::new(),
49 sampled_ranges: BlockRanges::default(),
50 pruned_ranges: BlockRanges::default(),
51 }
52 }
53}
54
55impl InMemoryStore {
56 pub fn new() -> Self {
58 InMemoryStore {
59 inner: RwLock::new(InMemoryStoreInner::new()),
60 header_added_notifier: Notify::new(),
61 }
62 }
63
64 #[inline]
65 async fn get_head_height(&self) -> Result<u64> {
66 self.inner.read().await.get_head_height()
67 }
68
69 async fn get_head(&self) -> Result<ExtendedHeader> {
70 let head_height = self.get_head_height().await?;
71 self.get_by_height(head_height).await
72 }
73
74 async fn contains_hash(&self, hash: &Hash) -> bool {
75 self.inner.read().await.contains_hash(hash)
76 }
77
78 async fn get_by_hash(&self, hash: &Hash) -> Result<ExtendedHeader> {
79 self.inner.read().await.get_by_hash(hash)
80 }
81
82 async fn contains_height(&self, height: u64) -> bool {
83 self.inner.read().await.contains_height(height)
84 }
85
86 async fn get_by_height(&self, height: u64) -> Result<ExtendedHeader> {
87 self.inner.read().await.get_by_height(height)
88 }
89
90 pub(crate) async fn insert<R>(&self, headers: R) -> Result<()>
91 where
92 R: TryInto<VerifiedExtendedHeaders> + Send,
93 <R as TryInto<VerifiedExtendedHeaders>>::Error: Display,
94 {
95 let headers = headers
96 .try_into()
97 .map_err(|e| StoreInsertionError::HeadersVerificationFailed(e.to_string()))?;
98
99 self.inner.write().await.insert(headers).await?;
100 self.header_added_notifier.notify_waiters();
101
102 Ok(())
103 }
104
105 async fn update_sampling_metadata(&self, height: u64, cids: Vec<Cid>) -> Result<()> {
106 self.inner
107 .write()
108 .await
109 .update_sampling_metadata(height, cids)
110 .await
111 }
112
113 async fn get_sampling_metadata(&self, height: u64) -> Result<Option<SamplingMetadata>> {
114 self.inner.read().await.get_sampling_metadata(height).await
115 }
116
117 async fn mark_as_sampled(&self, height: u64) -> Result<()> {
118 self.inner.write().await.mark_as_sampled(height).await
119 }
120
121 async fn get_stored_ranges(&self) -> BlockRanges {
122 self.inner.read().await.header_ranges.clone()
123 }
124
125 async fn get_sampled_ranges(&self) -> BlockRanges {
126 self.inner.read().await.sampled_ranges.clone()
127 }
128
129 async fn get_pruned_ranges(&self) -> BlockRanges {
130 self.inner.read().await.pruned_ranges.clone()
131 }
132
133 pub async fn async_clone(&self) -> Self {
135 InMemoryStore {
136 inner: RwLock::new(self.inner.read().await.clone()),
137 header_added_notifier: Notify::new(),
138 }
139 }
140
141 async fn remove_height(&self, height: u64) -> Result<()> {
142 let mut inner = self.inner.write().await;
143 inner.remove_height(height)
144 }
145}
146
147impl InMemoryStoreInner {
148 #[inline]
149 fn get_head_height(&self) -> Result<u64> {
150 self.header_ranges.head().ok_or(StoreError::NotFound)
151 }
152
153 fn contains_hash(&self, hash: &Hash) -> bool {
154 self.headers.contains_key(hash)
155 }
156
157 fn get_by_hash(&self, hash: &Hash) -> Result<ExtendedHeader> {
158 self.headers.get(hash).cloned().ok_or(StoreError::NotFound)
159 }
160
161 fn contains_height(&self, height: u64) -> bool {
162 self.header_ranges.contains(height)
163 }
164
165 fn get_by_height(&self, height: u64) -> Result<ExtendedHeader> {
166 let Some(hash) = self.height_to_hash.get(&height).copied() else {
167 return Err(StoreError::NotFound);
168 };
169
170 Ok(self
171 .headers
172 .get(&hash)
173 .expect("inconsistent between header hash and header heights")
174 .to_owned())
175 }
176
177 async fn insert(&mut self, headers: VerifiedExtendedHeaders) -> Result<()> {
178 let (Some(head), Some(tail)) = (headers.as_ref().first(), headers.as_ref().last()) else {
179 return Ok(());
180 };
181
182 let headers_range = head.height().value()..=tail.height().value();
183 let (prev_exists, next_exists) = self
184 .header_ranges
185 .check_insertion_constraints(&headers_range)
186 .map_err(StoreInsertionError::ContraintsNotMet)?;
187
188 self.verify_against_neighbours(prev_exists.then_some(head), next_exists.then_some(tail))?;
190
191 for header in headers.into_iter() {
192 let hash = header.hash();
193 let height = header.height().value();
194
195 debug_assert!(
196 !self.height_to_hash.contains_key(&height),
197 "inconsistency between headers table and ranges table"
198 );
199
200 let Entry::Vacant(headers_entry) = self.headers.entry(hash) else {
201 return Err(StoreInsertionError::HashExists(hash).into());
203 };
204
205 debug!("Inserting header {hash} with height {height}");
206 headers_entry.insert(header);
207 self.height_to_hash.insert(height, hash);
208 }
209
210 self.header_ranges
211 .insert_relaxed(&headers_range)
212 .expect("invalid range");
213 self.sampled_ranges
214 .remove_relaxed(&headers_range)
215 .expect("invalid range");
216 self.pruned_ranges
217 .remove_relaxed(&headers_range)
218 .expect("invalid range");
219
220 Ok(())
221 }
222
223 fn verify_against_neighbours(
224 &self,
225 lowest_header: Option<&ExtendedHeader>,
226 highest_header: Option<&ExtendedHeader>,
227 ) -> Result<()> {
228 if let Some(lowest_header) = lowest_header {
229 let prev = self
230 .get_by_height(lowest_header.height().value() - 1)
231 .map_err(|e| match e {
232 StoreError::NotFound => {
233 panic!("inconsistency between headers and ranges table")
234 }
235 e => e,
236 })?;
237
238 prev.verify(lowest_header)
239 .map_err(|e| StoreInsertionError::NeighborsVerificationFailed(e.to_string()))?;
240 }
241
242 if let Some(highest_header) = highest_header {
243 let next = self
244 .get_by_height(highest_header.height().value() + 1)
245 .map_err(|e| match e {
246 StoreError::NotFound => {
247 panic!("inconsistency between headers and ranges table")
248 }
249 e => e,
250 })?;
251
252 highest_header
253 .verify(&next)
254 .map_err(|e| StoreInsertionError::NeighborsVerificationFailed(e.to_string()))?;
255 }
256
257 Ok(())
258 }
259
260 async fn update_sampling_metadata(&mut self, height: u64, cids: Vec<Cid>) -> Result<()> {
261 if !self.contains_height(height) {
262 return Err(StoreError::NotFound);
263 }
264
265 match self.sampling_data.entry(height) {
266 Entry::Vacant(entry) => {
267 entry.insert(SamplingMetadata { cids });
268 }
269 Entry::Occupied(mut entry) => {
270 let metadata = entry.get_mut();
271
272 for cid in cids {
273 if !metadata.cids.contains(&cid) {
274 metadata.cids.push(cid);
275 }
276 }
277 }
278 }
279
280 Ok(())
281 }
282
283 async fn get_sampling_metadata(&self, height: u64) -> Result<Option<SamplingMetadata>> {
284 if !self.contains_height(height) {
285 return Err(StoreError::NotFound);
286 }
287
288 let Some(metadata) = self.sampling_data.get(&height) else {
289 return Ok(None);
290 };
291
292 Ok(Some(metadata.clone()))
293 }
294
295 async fn mark_as_sampled(&mut self, height: u64) -> Result<()> {
296 if !self.contains_height(height) {
297 return Err(StoreError::NotFound);
298 }
299
300 self.sampled_ranges
301 .insert_relaxed(height..=height)
302 .expect("invalid height");
303
304 Ok(())
305 }
306
307 fn remove_height(&mut self, height: u64) -> Result<()> {
308 if !self.header_ranges.contains(height) {
309 return Err(StoreError::NotFound);
310 }
311
312 let Entry::Occupied(height_to_hash) = self.height_to_hash.entry(height) else {
313 return Err(StoreError::StoredDataError(format!(
314 "inconsistency between ranges and height_to_hash tables, height {height}"
315 )));
316 };
317
318 let hash = height_to_hash.get();
319 let Entry::Occupied(header) = self.headers.entry(*hash) else {
320 return Err(StoreError::StoredDataError(format!(
321 "inconsistency between header and height_to_hash tables, hash {hash}"
322 )));
323 };
324
325 self.sampling_data.remove(&height);
327
328 height_to_hash.remove_entry();
329 header.remove_entry();
330
331 self.header_ranges
332 .remove_relaxed(height..=height)
333 .expect("invalid height");
334 self.sampled_ranges
335 .remove_relaxed(height..=height)
336 .expect("invalid height");
337 self.pruned_ranges
338 .insert_relaxed(height..=height)
339 .expect("invalid height");
340
341 Ok(())
342 }
343}
344
345#[async_trait]
346impl Store for InMemoryStore {
347 async fn get_head(&self) -> Result<ExtendedHeader> {
348 self.get_head().await
349 }
350
351 async fn get_by_hash(&self, hash: &Hash) -> Result<ExtendedHeader> {
352 self.get_by_hash(hash).await
353 }
354
355 async fn get_by_height(&self, height: u64) -> Result<ExtendedHeader> {
356 self.get_by_height(height).await
357 }
358
359 async fn wait_new_head(&self) -> u64 {
360 let head = self.get_head_height().await.unwrap_or(0);
361 let mut notifier = pin!(self.header_added_notifier.notified());
362
363 loop {
364 let new_head = self.get_head_height().await.unwrap_or(0);
365
366 if head != new_head {
367 return new_head;
368 }
369
370 notifier.as_mut().await;
372
373 notifier.set(self.header_added_notifier.notified());
375 }
376 }
377
378 async fn wait_height(&self, height: u64) -> Result<()> {
379 let mut notifier = pin!(self.header_added_notifier.notified());
380
381 loop {
382 if self.contains_height(height).await {
383 return Ok(());
384 }
385
386 notifier.as_mut().await;
388
389 notifier.set(self.header_added_notifier.notified());
391 }
392 }
393
394 async fn head_height(&self) -> Result<u64> {
395 self.get_head_height().await
396 }
397
398 async fn has(&self, hash: &Hash) -> bool {
399 self.contains_hash(hash).await
400 }
401
402 async fn has_at(&self, height: u64) -> bool {
403 self.contains_height(height).await
404 }
405
406 async fn insert<R>(&self, header: R) -> Result<()>
407 where
408 R: TryInto<VerifiedExtendedHeaders> + Send,
409 <R as TryInto<VerifiedExtendedHeaders>>::Error: Display,
410 {
411 self.insert(header).await
412 }
413
414 async fn update_sampling_metadata(&self, height: u64, cids: Vec<Cid>) -> Result<()> {
415 self.update_sampling_metadata(height, cids).await
416 }
417
418 async fn mark_as_sampled(&self, height: u64) -> Result<()> {
419 self.mark_as_sampled(height).await
420 }
421
422 async fn get_sampling_metadata(&self, height: u64) -> Result<Option<SamplingMetadata>> {
423 self.get_sampling_metadata(height).await
424 }
425
426 async fn get_stored_header_ranges(&self) -> Result<BlockRanges> {
427 Ok(self.get_stored_ranges().await)
428 }
429
430 async fn get_sampled_ranges(&self) -> Result<BlockRanges> {
431 Ok(self.get_sampled_ranges().await)
432 }
433
434 async fn get_pruned_ranges(&self) -> Result<BlockRanges> {
435 Ok(self.get_pruned_ranges().await)
436 }
437
438 async fn remove_height(&self, height: u64) -> Result<()> {
439 self.remove_height(height).await
440 }
441
442 async fn close(self) -> Result<()> {
443 Ok(())
444 }
445}
446
447impl Default for InMemoryStore {
448 fn default() -> Self {
449 Self::new()
450 }
451}