1use super::{Config, Error};
2use commonware_codec::{Codec, FixedSize, ReadExt};
3use commonware_cryptography::{crc32, Crc32};
4use commonware_runtime::{
5 telemetry::metrics::status::GaugeExt, Blob, BufMut, Clock, Error as RError, Metrics, Storage,
6};
7use commonware_utils::{sync::AsyncMutex, Span};
8use futures::future::try_join_all;
9use prometheus_client::metrics::{counter::Counter, gauge::Gauge};
10use std::collections::{BTreeMap, BTreeSet, HashMap};
11use tracing::{debug, warn};
12
/// Names of the two blobs that alternately persist the metadata: `sync` always
/// writes to the non-active blob, so the previously synced version survives an
/// interrupted write, and `init` recovers whichever blob carries the higher version.
const BLOB_NAMES: [&[u8]; 2] = [b"left", b"right"];
15
/// Location of a single encoded value within a blob's cached data.
struct Info {
    // Byte offset (within the blob) where the encoded value begins.
    start: usize,
    // Encoded length of the value in bytes.
    length: usize,
}
21
22impl Info {
23 const fn new(start: usize, length: usize) -> Self {
25 Self { start, length }
26 }
27}
28
/// In-memory mirror of a single backing blob.
struct Wrapper<B: Blob, K: Span> {
    // Handle to the underlying blob.
    blob: B,
    // Version stamped in the blob's 8-byte header (0 if never written).
    version: u64,
    // Offset and encoded length of each value within `data`, enabling
    // same-size in-place patches during `sync`.
    lengths: HashMap<K, Info>,
    // Keys whose values changed since this blob was last written.
    modified: BTreeSet<K>,
    // Cached copy of the blob's full contents (header through checksum).
    data: Vec<u8>,
}
37
38impl<B: Blob, K: Span> Wrapper<B, K> {
39 const fn new(blob: B, version: u64, lengths: HashMap<K, Info>, data: Vec<u8>) -> Self {
41 Self {
42 blob,
43 version,
44 lengths,
45 modified: BTreeSet::new(),
46 data,
47 }
48 }
49
50 fn empty(blob: B) -> Self {
52 Self {
53 blob,
54 version: 0,
55 lengths: HashMap::new(),
56 modified: BTreeSet::new(),
57 data: Vec::new(),
58 }
59 }
60}
61
/// Mutable persistence state shared by all mutating operations and `sync`.
struct State<B: Blob, K: Span> {
    // Index (0 or 1) of the blob holding the most recently synced state.
    cursor: usize,
    // Version that the next `sync` will stamp on its target blob.
    next_version: u64,
    // Version at which the most recent change to the key set (new key,
    // removal, clear, retain) will first be persisted; a blob written before
    // this version has a stale layout and cannot be patched in place.
    key_order_changed: u64,
    // The two alternating blobs (indexed by `cursor`).
    blobs: [Wrapper<B, K>; 2],
}
69
/// A crash-consistent key-value store for metadata.
///
/// Reads are served from an in-memory [BTreeMap]; durability comes from two
/// alternating, checksummed, versioned blobs managed behind `state`.
pub struct Metadata<E: Clock + Storage + Metrics, K: Span, V: Codec> {
    // Runtime handle used for storage access and metric registration.
    context: E,

    // Current in-memory view of all key-value pairs.
    map: BTreeMap<K, V>,
    // Storage partition containing both backing blobs.
    partition: String,
    // Blob images, versions, and dirty-key tracking; an async mutex so
    // `sync` can run behind a shared reference.
    state: AsyncMutex<State<E::Blob, K>>,

    // Number of syncs that patched existing data in place.
    sync_overwrites: Counter,
    // Number of syncs that rewrote all data from scratch.
    sync_rewrites: Counter,
    // Number of currently tracked keys.
    keys: Gauge,
}
82
impl<E: Clock + Storage + Metrics, K: Span, V: Codec> Metadata<E, K, V> {
    /// Initialize the store from `cfg.partition`, recovering whichever of the
    /// two backing blobs carries the higher version.
    ///
    /// Corrupt or truncated blobs are reset to empty rather than treated as
    /// errors (see [Self::load]).
    pub async fn init(context: E, cfg: Config<V::Cfg>) -> Result<Self, Error> {
        // Open both backing blobs.
        let (left_blob, left_len) = context.open(&cfg.partition, BLOB_NAMES[0]).await?;
        let (right_blob, right_len) = context.open(&cfg.partition, BLOB_NAMES[1]).await?;

        // Parse each blob into a map plus its wrapper.
        let (left_map, left_wrapper) =
            Self::load(&cfg.codec_config, 0, left_blob, left_len).await?;
        let (right_map, right_wrapper) =
            Self::load(&cfg.codec_config, 1, right_blob, right_len).await?;

        // Adopt the blob with the larger version as the current state
        // (ties, including both-empty, favor the left blob).
        let mut map = left_map;
        let mut cursor = 0;
        let mut version = left_wrapper.version;
        if right_wrapper.version > left_wrapper.version {
            cursor = 1;
            map = right_map;
            version = right_wrapper.version;
        }
        let next_version = version.checked_add(1).expect("version overflow");

        // Create and register metrics.
        let sync_rewrites = Counter::default();
        let sync_overwrites = Counter::default();
        let keys = Gauge::default();
        context.register(
            "sync_rewrites",
            "number of syncs that rewrote all data",
            sync_rewrites.clone(),
        );
        context.register(
            "sync_overwrites",
            "number of syncs that modified existing data",
            sync_overwrites.clone(),
        );
        context.register("keys", "number of tracked keys", keys.clone());

        // Seed the key gauge (best-effort; errors ignored).
        let _ = keys.try_set(map.len());
        Ok(Self {
            context,

            map,
            partition: cfg.partition,
            state: AsyncMutex::new(State {
                cursor,
                next_version,
                // Neither blob reflects a version >= `next_version` yet, so
                // in-place patching stays disabled until a full rewrite runs.
                key_order_changed: next_version,
                blobs: [left_wrapper, right_wrapper],
            }),

            sync_rewrites,
            sync_overwrites,
            keys,
        })
    }

    /// Parse a single blob into a key-value map and its [Wrapper].
    ///
    /// Layout: version (u64 BE) | (key | value)* | crc32 (BE). A blob that is
    /// empty, too short, or fails its checksum is truncated to zero length and
    /// returned as empty — crash recovery, not an error. Undecodable entries
    /// inside a checksum-valid blob, however, panic (they indicate a bug or a
    /// codec-config mismatch, not a torn write).
    async fn load(
        codec_config: &V::Cfg,
        index: usize,
        blob: E::Blob,
        len: u64,
    ) -> Result<(BTreeMap<K, V>, Wrapper<E::Blob, K>), Error> {
        // Nothing persisted yet.
        if len == 0 {
            return Ok((BTreeMap::new(), Wrapper::empty(blob)));
        }

        // Read the entire blob into memory.
        let len: usize = len.try_into().expect("blob too large for platform");
        let buf = blob.read_at(0, len).await?.coalesce();

        // A valid blob holds at least the u64 version header and the checksum.
        if buf.len() < 8 + crc32::Digest::SIZE {
            warn!(
                blob = index,
                len = buf.len(),
                "blob is too short: truncating"
            );
            blob.resize(0).await?;
            blob.sync().await?;
            return Ok((BTreeMap::new(), Wrapper::empty(blob)));
        }

        // Verify the trailing CRC32 over everything that precedes it.
        let checksum_index = buf.len() - crc32::Digest::SIZE;
        let stored_checksum =
            u32::from_be_bytes(buf.as_ref()[checksum_index..].try_into().unwrap());
        let computed_checksum = Crc32::checksum(&buf.as_ref()[..checksum_index]);
        if stored_checksum != computed_checksum {
            warn!(
                blob = index,
                stored = stored_checksum,
                computed = computed_checksum,
                "checksum mismatch: truncating"
            );
            blob.resize(0).await?;
            blob.sync().await?;
            return Ok((BTreeMap::new(), Wrapper::empty(blob)));
        }

        // Extract the version from the header.
        let version = u64::from_be_bytes(buf.as_ref()[..8].try_into().unwrap());

        // Decode all entries, recording where each value lives so it can be
        // patched in place by a later `sync`.
        let mut data = BTreeMap::new();
        let mut lengths = HashMap::new();
        let mut cursor = u64::SIZE;
        while cursor < checksum_index {
            let key = K::read(&mut buf.as_ref()[cursor..].as_ref())
                .expect("unable to read key from blob");
            cursor += key.encode_size();

            // `Info.start` points at the value (i.e. just past the key).
            let value = V::read_cfg(&mut buf.as_ref()[cursor..].as_ref(), codec_config)
                .expect("unable to read value from blob");
            lengths.insert(key.clone(), Info::new(cursor, value.encode_size()));
            cursor += value.encode_size();
            data.insert(key, value);
        }

        Ok((
            data,
            Wrapper::new(blob, version, lengths, buf.freeze().into()),
        ))
    }

    /// Get the value associated with `key`, if it exists.
    pub fn get(&self, key: &K) -> Option<&V> {
        self.map.get(key)
    }

    /// Get a mutable reference to the value associated with `key`, if it
    /// exists. The change is not persisted until [Self::sync] is called.
    pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
        let value = self.map.get_mut(key)?;

        // Assume the caller mutates the value: mark the key dirty in BOTH
        // blobs, since syncs alternate targets and each blob must eventually
        // receive the update.
        let state = self.state.get_mut();
        state.blobs[state.cursor].modified.insert(key.clone());
        state.blobs[1 - state.cursor].modified.insert(key.clone());

        Some(value)
    }

    /// Remove all keys. The change is not persisted until [Self::sync].
    pub fn clear(&mut self) {
        self.map.clear();

        // Removing keys invalidates the on-disk layout, so the next sync of
        // each blob must be a full rewrite.
        let state = self.state.get_mut();
        state.key_order_changed = state.next_version;
        self.keys.set(0);
    }

    /// Insert (or replace) a value, returning the previous value if any.
    /// The change is not persisted until [Self::sync].
    pub fn put(&mut self, key: K, value: V) -> Option<V> {
        let previous = self.map.insert(key.clone(), value);

        let state = self.state.get_mut();
        if previous.is_some() {
            // Existing key: the layout may still match, so mark it dirty in
            // both blobs and let `sync` decide patch vs. rewrite.
            state.blobs[state.cursor].modified.insert(key.clone());
            state.blobs[1 - state.cursor].modified.insert(key);
        } else {
            // New key: the key order changed, forcing a full rewrite.
            state.key_order_changed = state.next_version;
        }
        let _ = self.keys.try_set(self.map.len());
        previous
    }

    /// [Self::put] followed immediately by [Self::sync].
    pub async fn put_sync(&mut self, key: K, value: V) -> Result<(), Error> {
        self.put(key, value);
        self.sync().await
    }

    /// Mutate the existing value for `key` with `f`, or insert a mutated
    /// [V::default] if the key is absent.
    pub fn upsert(&mut self, key: K, f: impl FnOnce(&mut V))
    where
        V: Default,
    {
        if let Some(value) = self.get_mut(&key) {
            f(value);
        } else {
            let mut value = V::default();
            f(&mut value);
            self.put(key, value);
        }
    }

    /// [Self::upsert] followed immediately by [Self::sync].
    pub async fn upsert_sync(&mut self, key: K, f: impl FnOnce(&mut V)) -> Result<(), Error>
    where
        V: Default,
    {
        self.upsert(key, f);
        self.sync().await
    }

    /// Remove `key`, returning its value if it was present.
    /// The change is not persisted until [Self::sync].
    pub fn remove(&mut self, key: &K) -> Option<V> {
        let past = self.map.remove(key);

        // Only a successful removal changes the key order.
        if past.is_some() {
            let state = self.state.get_mut();
            state.key_order_changed = state.next_version;
        }
        let _ = self.keys.try_set(self.map.len());

        past
    }

    /// Iterate over all keys (in sorted order, per [BTreeMap]).
    pub fn keys(&self) -> impl Iterator<Item = &K> {
        self.map.keys()
    }

    /// Keep only the entries for which `f` returns true.
    /// The change is not persisted until [Self::sync].
    pub fn retain(&mut self, mut f: impl FnMut(&K, &V) -> bool) {
        let old_len = self.map.len();
        self.map.retain(|k, v| f(k, v));
        let new_len = self.map.len();

        // Only force a rewrite if something was actually removed.
        if new_len != old_len {
            let state = self.state.get_mut();
            state.key_order_changed = state.next_version;
            let _ = self.keys.try_set(self.map.len());
        }
    }

    /// Persist the current map to the non-active blob, then make it active.
    ///
    /// Two strategies:
    /// - overwrite: if the target blob already reflects the current key set
    ///   and every modified value re-encodes to its previous size, patch the
    ///   values, version, and checksum in place;
    /// - rewrite: otherwise, serialize the entire map from scratch.
    pub async fn sync(&self) -> Result<(), Error> {
        let mut state = self.state.lock().await;

        let cursor = state.cursor;
        let next_version = state.next_version;
        let key_order_changed = state.key_order_changed;

        // Version persisted in the currently active blob.
        let past_version = state.blobs[cursor].version;
        let next_next_version = next_version.checked_add(1).expect("version overflow");

        // Always write to the other blob so the previously synced state
        // survives if this write is interrupted.
        let target_cursor = 1 - cursor;

        // Flip the active blob and advance the version counter up front.
        state.cursor = target_cursor;
        state.next_version = next_next_version;

        let target = &mut state.blobs[target_cursor];

        // Attempt in-place patching. `key_order_changed < past_version`
        // implies the last key-set change was persisted no later than the
        // target blob's version (the target trails the active blob), so the
        // target's layout still matches the map — TODO confirm this invariant
        // holds across recovery edge cases.
        let mut overwrite = true;
        let mut writes = vec![];
        if key_order_changed < past_version {
            // Modified values plus the version header and checksum writes.
            let write_capacity = target.modified.len() + 2;
            writes.reserve(write_capacity);
            for key in target.modified.iter() {
                let info = target.lengths.get(key).expect("key must exist");
                let new_value = self.map.get(key).expect("key must exist");
                if info.length == new_value.encode_size() {
                    // Same size: splice the bytes into the cached image and
                    // queue a targeted disk write.
                    let encoded = new_value.encode_mut();
                    target.data[info.start..info.start + info.length].copy_from_slice(&encoded);
                    writes.push(target.blob.write_at(info.start as u64, encoded));
                } else {
                    // Size changed: fall back to a full rewrite. Queued
                    // futures are lazy and dropped unawaited; any bytes
                    // already spliced into `data` are replaced wholesale below.
                    overwrite = false;
                    break;
                }
            }
        } else {
            overwrite = false;
        }

        // Either path below resolves every pending modification.
        target.modified.clear();

        if overwrite {
            // Patch the version header...
            let version = next_version.to_be_bytes();
            target.data[0..8].copy_from_slice(&version);
            writes.push(target.blob.write_at(0, version.as_slice().into()));

            // ...and the trailing checksum, recomputed over the patched image.
            let checksum_index = target.data.len() - crc32::Digest::SIZE;
            let checksum = Crc32::checksum(&target.data[..checksum_index]).to_be_bytes();
            target.data[checksum_index..].copy_from_slice(&checksum);
            writes.push(
                target
                    .blob
                    .write_at(checksum_index as u64, checksum.as_slice().into()),
            );

            // Issue all queued writes concurrently, then flush the blob.
            try_join_all(writes).await?;
            target.blob.sync().await?;

            target.version = next_version;
            self.sync_overwrites.inc();
            return Ok(());
        }

        // Full rewrite: version | (key | value)* | crc32.
        let mut lengths = HashMap::new();
        let mut next_data = Vec::with_capacity(target.data.len());
        next_data.put_u64(next_version);

        for (key, value) in &self.map {
            key.write(&mut next_data);
            // Record the value's location for future in-place patches.
            let start = next_data.len();
            value.write(&mut next_data);
            lengths.insert(key.clone(), Info::new(start, value.encode_size()));
        }
        next_data.put_u32(Crc32::checksum(&next_data[..]));

        // Write the new image, shrinking the blob if the old image was larger.
        target.blob.write_at(0, next_data.clone()).await?;
        if next_data.len() < target.data.len() {
            target.blob.resize(next_data.len() as u64).await?;
        }
        target.blob.sync().await?;

        // Update the cached mirror to match what is now on disk.
        target.version = next_version;
        target.lengths = lengths;
        target.data = next_data;

        self.sync_rewrites.inc();
        Ok(())
    }

    /// Remove all persisted state: both blobs, then the partition itself.
    pub async fn destroy(self) -> Result<(), Error> {
        let state = self.state.into_inner();
        for (i, wrapper) in state.blobs.into_iter().enumerate() {
            // Release the handle before removing its backing storage.
            drop(wrapper.blob);
            self.context
                .remove(&self.partition, Some(BLOB_NAMES[i]))
                .await?;
            debug!(blob = i, "destroyed blob");
        }
        match self.context.remove(&self.partition, None).await {
            Ok(()) => {}
            // An already-missing partition is acceptable: nothing to remove.
            Err(RError::PartitionMissing(_)) => {}
            Err(err) => return Err(Error::Runtime(err)),
        }
        Ok(())
    }
}