1use super::{Config, Error};
2use commonware_codec::{Codec, FixedSize, ReadExt};
3use commonware_cryptography::{crc32, Crc32};
4use commonware_runtime::{
5 telemetry::metrics::status::GaugeExt, Blob, BufMut, Clock, Error as RError, IoBufMut, Metrics,
6 Storage,
7};
8use commonware_utils::Span;
9use futures::{future::try_join_all, lock::Mutex};
10use prometheus_client::metrics::{counter::Counter, gauge::Gauge};
11use std::collections::{BTreeMap, BTreeSet, HashMap};
12use tracing::{debug, warn};
13
/// Names of the two alternating backing blobs; syncs ping-pong between them so
/// that one fully valid copy always exists on disk during a write.
const BLOB_NAMES: [&[u8]; 2] = [b"left", b"right"];
16
/// Location of one encoded value inside a blob's serialized contents.
struct Info {
    /// Byte offset where the encoded value begins.
    start: usize,
    /// Encoded length of the value in bytes.
    length: usize,
}
22
23impl Info {
24 const fn new(start: usize, length: usize) -> Self {
26 Self { start, length }
27 }
28}
29
/// In-memory mirror of one backing blob: its handle, persisted version stamp,
/// the location of every key's encoded value, and the keys modified since this
/// blob was last written.
struct Wrapper<B: Blob, K: Span> {
    /// Handle to the underlying storage blob.
    blob: B,
    /// Version stamp stored in the blob's first 8 bytes; the higher of the two
    /// blobs' versions wins at init.
    version: u64,
    /// Byte range of each key's encoded value within `data`.
    lengths: HashMap<K, Info>,
    /// Keys whose values have changed since this blob last synced; consumed by
    /// the next sync that targets this blob.
    modified: BTreeSet<K>,
    /// Full serialized blob contents (version, entries, trailing checksum).
    data: Vec<u8>,
}
38
39impl<B: Blob, K: Span> Wrapper<B, K> {
40 const fn new(blob: B, version: u64, lengths: HashMap<K, Info>, data: Vec<u8>) -> Self {
42 Self {
43 blob,
44 version,
45 lengths,
46 modified: BTreeSet::new(),
47 data,
48 }
49 }
50
51 fn empty(blob: B) -> Self {
53 Self {
54 blob,
55 version: 0,
56 lengths: HashMap::new(),
57 modified: BTreeSet::new(),
58 data: Vec::new(),
59 }
60 }
61}
62
/// Mutable sync state, kept behind a mutex so `sync` can run with `&self`.
struct State<B: Blob, K: Span> {
    /// Index (0 or 1) into `blobs` of the most recently synced blob.
    cursor: usize,
    /// Version stamp the next sync will persist.
    next_version: u64,
    /// Version at which the key set last changed (insert/remove/clear/retain);
    /// sync uses this to decide whether an in-place overwrite is still valid.
    key_order_changed: u64,
    /// The two alternating blob wrappers (left and right).
    blobs: [Wrapper<B, K>; 2],
}
70
/// A key-value store that keeps the full map in memory and persists it across
/// two alternating blobs on `sync`, so a crash mid-write never loses the last
/// durable state.
pub struct Metadata<E: Clock + Storage + Metrics, K: Span, V: Codec> {
    /// Runtime context providing storage access and metrics registration.
    context: E,

    /// In-memory source of truth for all key-value pairs.
    map: BTreeMap<K, V>,
    /// Storage partition holding both backing blobs.
    partition: String,
    /// Blob/sync state, locked so `sync` only needs `&self`.
    state: Mutex<State<E::Blob, K>>,

    /// Number of syncs that patched existing data in place.
    sync_overwrites: Counter,
    /// Number of syncs that rewrote the entire blob.
    sync_rewrites: Counter,
    /// Number of tracked keys.
    keys: Gauge,
}
83
84impl<E: Clock + Storage + Metrics, K: Span, V: Codec> Metadata<E, K, V> {
    /// Opens (or recovers) the metadata store from the two backing blobs in
    /// `cfg.partition`, adopting whichever blob carries the higher version
    /// stamp, and registers this store's metrics with the runtime.
    pub async fn init(context: E, cfg: Config<V::Cfg>) -> Result<Self, Error> {
        // Open both backing blobs.
        let (left_blob, left_len) = context.open(&cfg.partition, BLOB_NAMES[0]).await?;
        let (right_blob, right_len) = context.open(&cfg.partition, BLOB_NAMES[1]).await?;

        // Parse both blobs (short or checksum-failing blobs are truncated and
        // treated as empty by `load`).
        let (left_map, left_wrapper) =
            Self::load(&cfg.codec_config, 0, left_blob, left_len).await?;
        let (right_map, right_wrapper) =
            Self::load(&cfg.codec_config, 1, right_blob, right_len).await?;

        // Adopt the blob with the higher version stamp — it holds the most
        // recent successfully synced state.
        let mut map = left_map;
        let mut cursor = 0;
        let mut version = left_wrapper.version;
        if right_wrapper.version > left_wrapper.version {
            cursor = 1;
            map = right_map;
            version = right_wrapper.version;
        }
        let next_version = version.checked_add(1).expect("version overflow");

        // Create and register metrics.
        let sync_rewrites = Counter::default();
        let sync_overwrites = Counter::default();
        let keys = Gauge::default();
        context.register(
            "sync_rewrites",
            "number of syncs that rewrote all data",
            sync_rewrites.clone(),
        );
        context.register(
            "sync_overwrites",
            "number of syncs that modified existing data",
            sync_overwrites.clone(),
        );
        context.register("keys", "number of tracked keys", keys.clone());

        // Seed the key gauge with the recovered key count.
        let _ = keys.try_set(map.len());
        Ok(Self {
            context,

            map,
            partition: cfg.partition,
            state: Mutex::new(State {
                cursor,
                next_version,
                // Mark the key order as changed at `next_version` so the first
                // syncs perform full rewrites before any in-place overwrite is
                // attempted.
                key_order_changed: next_version,
                blobs: [left_wrapper, right_wrapper],
            }),

            sync_rewrites,
            sync_overwrites,
            keys,
        })
    }
143
    /// Parses one blob into a key-value map plus its `Wrapper`.
    ///
    /// A blob that is empty, too short to hold a version and checksum, or that
    /// fails its checksum is truncated to zero length and returned as empty —
    /// recovery from a torn or partial write.
    async fn load(
        codec_config: &V::Cfg,
        index: usize,
        blob: E::Blob,
        len: u64,
    ) -> Result<(BTreeMap<K, V>, Wrapper<E::Blob, K>), Error> {
        // Nothing persisted yet.
        if len == 0 {
            return Ok((BTreeMap::new(), Wrapper::empty(blob)));
        }

        // Read the entire blob into memory.
        let buf = blob
            .read_at(0, IoBufMut::zeroed(len as usize))
            .await?
            .coalesce();

        // A valid blob holds at least an 8-byte version plus trailing checksum.
        if buf.len() < 8 + crc32::Digest::SIZE {
            warn!(
                blob = index,
                len = buf.len(),
                "blob is too short: truncating"
            );
            blob.resize(0).await?;
            blob.sync().await?;
            return Ok((BTreeMap::new(), Wrapper::empty(blob)));
        }

        // Verify the trailing CRC32 over everything that precedes it.
        let checksum_index = buf.len() - crc32::Digest::SIZE;
        let stored_checksum =
            u32::from_be_bytes(buf.as_ref()[checksum_index..].try_into().unwrap());
        let computed_checksum = Crc32::checksum(&buf.as_ref()[..checksum_index]);
        if stored_checksum != computed_checksum {
            warn!(
                blob = index,
                stored = stored_checksum,
                computed = computed_checksum,
                "checksum mismatch: truncating"
            );
            blob.resize(0).await?;
            blob.sync().await?;
            return Ok((BTreeMap::new(), Wrapper::empty(blob)));
        }

        // The first 8 bytes are the big-endian version stamp.
        let version = u64::from_be_bytes(buf.as_ref()[..8].try_into().unwrap());

        // Decode (key, value) pairs up to the checksum boundary, recording the
        // byte range each value occupies so it can later be patched in place.
        // Decode failures panic: the checksum already passed, so a failure here
        // indicates a codec bug rather than on-disk corruption.
        let mut data = BTreeMap::new();
        let mut lengths = HashMap::new();
        let mut cursor = u64::SIZE;
        while cursor < checksum_index {
            let key = K::read(&mut buf.as_ref()[cursor..].as_ref())
                .expect("unable to read key from blob");
            cursor += key.encode_size();

            let value = V::read_cfg(&mut buf.as_ref()[cursor..].as_ref(), codec_config)
                .expect("unable to read value from blob");
            lengths.insert(key.clone(), Info::new(cursor, value.encode_size()));
            cursor += value.encode_size();
            data.insert(key, value);
        }

        Ok((
            data,
            Wrapper::new(blob, version, lengths, buf.freeze().into()),
        ))
    }
225
    /// Returns a reference to the value stored for `key`, if any.
    pub fn get(&self, key: &K) -> Option<&V> {
        self.map.get(key)
    }
230
231 pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
233 let value = self.map.get_mut(key)?;
235
236 let state = self.state.get_mut();
240 state.blobs[state.cursor].modified.insert(key.clone());
241 state.blobs[1 - state.cursor].modified.insert(key.clone());
242
243 Some(value)
244 }
245
246 pub fn clear(&mut self) {
249 self.map.clear();
251
252 let state = self.state.get_mut();
254 state.key_order_changed = state.next_version;
255 self.keys.set(0);
256 }
257
258 pub fn put(&mut self, key: K, value: V) -> Option<V> {
264 let previous = self.map.insert(key.clone(), value);
266
267 let state = self.state.get_mut();
271 if previous.is_some() {
272 state.blobs[state.cursor].modified.insert(key.clone());
273 state.blobs[1 - state.cursor].modified.insert(key);
274 } else {
275 state.key_order_changed = state.next_version;
276 }
277 let _ = self.keys.try_set(self.map.len());
278 previous
279 }
280
    /// Inserts `value` under `key` and immediately syncs to disk.
    pub async fn put_sync(&mut self, key: K, value: V) -> Result<(), Error> {
        self.put(key, value);
        self.sync().await
    }
286
287 pub fn upsert(&mut self, key: K, f: impl FnOnce(&mut V))
289 where
290 V: Default,
291 {
292 if let Some(value) = self.get_mut(&key) {
293 f(value);
295 } else {
296 let mut value = V::default();
298 f(&mut value);
299 self.put(key, value);
300 }
301 }
302
    /// Applies `f` to the value stored for `key` (inserting a default value
    /// first if absent) and immediately syncs to disk.
    pub async fn upsert_sync(&mut self, key: K, f: impl FnOnce(&mut V)) -> Result<(), Error>
    where
        V: Default,
    {
        self.upsert(key, f);
        self.sync().await
    }
311
312 pub fn remove(&mut self, key: &K) -> Option<V> {
314 let past = self.map.remove(key);
316
317 if past.is_some() {
319 let state = self.state.get_mut();
320 state.key_order_changed = state.next_version;
321 }
322 let _ = self.keys.try_set(self.map.len());
323
324 past
325 }
326
    /// Returns an iterator over all tracked keys in sorted order (the backing
    /// map is a `BTreeMap`).
    pub fn keys(&self) -> impl Iterator<Item = &K> {
        self.map.keys()
    }
331
332 pub fn retain(&mut self, mut f: impl FnMut(&K, &V) -> bool) {
334 let old_len = self.map.len();
336 self.map.retain(|k, v| f(k, v));
337 let new_len = self.map.len();
338
339 if new_len != old_len {
341 let state = self.state.get_mut();
342 state.key_order_changed = state.next_version;
343 let _ = self.keys.try_set(self.map.len());
344 }
345 }
346
    /// Persists the in-memory map to the non-current blob, alternating blobs
    /// each call so the previously synced copy survives a crash mid-write.
    ///
    /// If the key set has not changed since the target blob was last rewritten
    /// and every modified value still encodes to its previous size, only the
    /// changed byte ranges (plus version and checksum) are patched in place;
    /// otherwise the entire blob is rewritten from scratch.
    pub async fn sync(&self) -> Result<(), Error> {
        let mut state = self.state.lock().await;

        // Snapshot the values this sync will use.
        let cursor = state.cursor;
        let next_version = state.next_version;
        let key_order_changed = state.key_order_changed;

        // Version persisted in the currently adopted blob; used to decide
        // whether the target blob's layout predates the last key-set change.
        let past_version = state.blobs[cursor].version;
        let next_next_version = next_version.checked_add(1).expect("version overflow");

        // Write to the other blob, leaving the current one untouched.
        let target_cursor = 1 - cursor;

        // Advance the shared state before any I/O.
        // NOTE(review): this happens even if the writes below fail — the next
        // sync then targets the other blob; presumably intentional, confirm.
        state.cursor = target_cursor;
        state.next_version = next_next_version;

        let target = &mut state.blobs[target_cursor];

        // Attempt an in-place overwrite: only valid when the key order last
        // changed before the current blob's version (so the target blob has
        // already been rewritten with the current key set) and every modified
        // value still fits exactly in its previously recorded byte range.
        let mut overwrite = true;
        let mut writes = vec![];
        if key_order_changed < past_version {
            let write_capacity = target.modified.len() + 2;
            writes.reserve(write_capacity);
            for key in target.modified.iter() {
                let info = target.lengths.get(key).expect("key must exist");
                let new_value = self.map.get(key).expect("key must exist");
                if info.length == new_value.encode_size() {
                    // Patch both the cached copy and queue the blob write; the
                    // queued futures are awaited together below.
                    let encoded = new_value.encode_mut();
                    target.data[info.start..info.start + info.length].copy_from_slice(&encoded);
                    writes.push(target.blob.write_at(info.start as u64, encoded));
                } else {
                    // Size changed: fall back to a full rewrite (any queued
                    // writes are dropped unawaited and never issued).
                    overwrite = false;
                    break;
                }
            }
        } else {
            overwrite = false;
        }

        // Modified keys are consumed either by the patches above or by the
        // full rewrite below.
        target.modified.clear();

        if overwrite {
            // Stamp the new version into the first 8 bytes.
            let version = next_version.to_be_bytes();
            target.data[0..8].copy_from_slice(&version);
            writes.push(target.blob.write_at(0, version.as_slice().into()));

            // Recompute the trailing checksum over the patched contents.
            let checksum_index = target.data.len() - crc32::Digest::SIZE;
            let checksum = Crc32::checksum(&target.data[..checksum_index]).to_be_bytes();
            target.data[checksum_index..].copy_from_slice(&checksum);
            writes.push(
                target
                    .blob
                    .write_at(checksum_index as u64, checksum.as_slice().into()),
            );

            // Issue all patches concurrently, then make them durable.
            try_join_all(writes).await?;
            target.blob.sync().await?;

            target.version = next_version;
            self.sync_overwrites.inc();
            return Ok(());
        }

        // Full rewrite: serialize version, all entries (in key order), and the
        // trailing checksum into a fresh buffer.
        let mut lengths = HashMap::new();
        let mut next_data = Vec::with_capacity(target.data.len());
        next_data.put_u64(next_version);

        for (key, value) in &self.map {
            key.write(&mut next_data);
            let start = next_data.len();
            value.write(&mut next_data);
            lengths.insert(key.clone(), Info::new(start, value.encode_size()));
        }
        next_data.put_u32(Crc32::checksum(&next_data[..]));

        // Write the new contents, shrink the blob if it got smaller, and make
        // the result durable.
        target.blob.write_at(0, next_data.clone()).await?;
        if next_data.len() < target.data.len() {
            target.blob.resize(next_data.len() as u64).await?;
        }
        target.blob.sync().await?;

        // Replace the cached view of the target blob.
        target.version = next_version;
        target.lengths = lengths;
        target.data = next_data;

        self.sync_rewrites.inc();
        Ok(())
    }
461
462 pub async fn destroy(self) -> Result<(), Error> {
464 let state = self.state.into_inner();
465 for (i, wrapper) in state.blobs.into_iter().enumerate() {
466 drop(wrapper.blob);
467 self.context
468 .remove(&self.partition, Some(BLOB_NAMES[i]))
469 .await?;
470 debug!(blob = i, "destroyed blob");
471 }
472 match self.context.remove(&self.partition, None).await {
473 Ok(()) => {}
474 Err(RError::PartitionMissing(_)) => {
475 }
477 Err(err) => return Err(Error::Runtime(err)),
478 }
479 Ok(())
480 }
481}