use super::{Config, Error};
use crate::Context;
use commonware_codec::{Codec, FixedSize, ReadExt};
use commonware_cryptography::{crc32, Crc32};
use commonware_runtime::{telemetry::metrics::status::GaugeExt, Blob, BufMut, Error as RError};
use commonware_utils::{sync::AsyncMutex, Span};
use futures::future::try_join_all;
use prometheus_client::metrics::{counter::Counter, gauge::Gauge};
use std::collections::{BTreeMap, BTreeSet, HashMap};
use tracing::{debug, warn};
/// Names of the two blobs that alternately persist full snapshots (double-buffering).
const BLOB_NAMES: [&[u8]; 2] = [b"left", b"right"];
/// Location of one encoded value inside a blob's cached byte buffer.
struct Info {
    /// Byte offset where the encoded value begins.
    start: usize,
    /// Encoded length of the value in bytes.
    length: usize,
}
impl Info {
    /// Creates an [Info] describing a value spanning `start..start + length`.
    const fn new(start: usize, length: usize) -> Self {
        Self { start, length }
    }
}
/// A blob handle plus a cached copy of everything persisted in it.
struct Wrapper<B: Blob, K: Span> {
    /// Handle to the underlying blob.
    blob: B,
    /// Version stored in the blob's 8-byte big-endian prefix (0 if empty).
    version: u64,
    /// Offset and length of each value inside `data`, keyed by its key.
    lengths: HashMap<K, Info>,
    /// Keys whose in-memory values changed but have not yet been written to
    /// THIS blob (each blob tracks its own pending set, since syncs alternate).
    modified: BTreeSet<K>,
    /// Cached raw contents of the blob (version prefix, pairs, checksum suffix).
    data: Vec<u8>,
}
impl<B: Blob, K: Span> Wrapper<B, K> {
    /// Wraps `blob` together with its parsed version, per-value offsets, and
    /// cached raw bytes. Starts with no pending modifications.
    const fn new(blob: B, version: u64, lengths: HashMap<K, Info>, data: Vec<u8>) -> Self {
        Self {
            data,
            lengths,
            version,
            modified: BTreeSet::new(),
            blob,
        }
    }
    /// Wraps a blob that holds no snapshot yet: version 0, no keys, no bytes.
    fn empty(blob: B) -> Self {
        Self::new(blob, 0, HashMap::new(), Vec::new())
    }
}
/// Mutable persistence state shared between mutation methods and sync.
struct State<B: Blob, K: Span> {
    /// Index (0 or 1) into `blobs` of the most recently synced snapshot.
    cursor: usize,
    /// Version the next sync will stamp onto the blob it writes.
    next_version: u64,
    /// `next_version` at the moment a key was last added or removed (i.e. the
    /// serialized key order changed); initialized to the first `next_version`
    /// so early syncs take the full-rewrite path.
    key_order_changed: u64,
    /// The two alternating blobs ("left" then "right").
    blobs: [Wrapper<B, K>; 2],
}
/// An in-memory key/value map persisted by double-buffering full snapshots
/// across two blobs, so a crash mid-write leaves the other (older, checksummed)
/// snapshot intact.
pub struct Metadata<E: Context, K: Span, V: Codec> {
    /// Runtime context used for blob access and metrics registration.
    context: E,
    /// In-memory source of truth for all key/value pairs.
    map: BTreeMap<K, V>,
    /// Storage partition containing both blobs.
    partition: String,
    /// Persistence state; async-locked in `sync`, accessed via `&mut self` elsewhere.
    state: AsyncMutex<State<E::Blob, K>>,
    /// Number of syncs that patched existing data in place.
    sync_overwrites: Counter,
    /// Number of syncs that rewrote the entire snapshot.
    sync_rewrites: Counter,
    /// Number of tracked keys.
    keys: Gauge,
}
impl<E: Context, K: Span, V: Codec> Metadata<E, K, V> {
    /// Opens both backing blobs, restores the newest valid snapshot into
    /// memory, and registers metrics with the runtime.
    ///
    /// Each blob is parsed independently (corrupt ones are truncated inside
    /// [Self::load]); the blob carrying the higher stored version seeds the
    /// key/value map and becomes the active (`cursor`) blob.
    /// `key_order_changed` starts at `next_version`, which forces the first
    /// syncs down the full-rewrite path (see [Self::sync]).
    pub async fn init(context: E, cfg: Config<V::Cfg>) -> Result<Self, Error> {
        let (left_blob, left_len) = context.open(&cfg.partition, BLOB_NAMES[0]).await?;
        let (right_blob, right_len) = context.open(&cfg.partition, BLOB_NAMES[1]).await?;
        let (left_map, left_wrapper) =
            Self::load(&cfg.codec_config, 0, left_blob, left_len).await?;
        let (right_map, right_wrapper) =
            Self::load(&cfg.codec_config, 1, right_blob, right_len).await?;
        // Prefer whichever blob persisted the higher version.
        let mut map = left_map;
        let mut cursor = 0;
        let mut version = left_wrapper.version;
        if right_wrapper.version > left_wrapper.version {
            cursor = 1;
            map = right_map;
            version = right_wrapper.version;
        }
        let next_version = version.checked_add(1).expect("version overflow");
        // Create and register metrics.
        let sync_rewrites = Counter::default();
        let sync_overwrites = Counter::default();
        let keys = Gauge::default();
        context.register(
            "sync_rewrites",
            "number of syncs that rewrote all data",
            sync_rewrites.clone(),
        );
        context.register(
            "sync_overwrites",
            "number of syncs that modified existing data",
            sync_overwrites.clone(),
        );
        context.register("keys", "number of tracked keys", keys.clone());
        let _ = keys.try_set(map.len());
        Ok(Self {
            context,
            map,
            partition: cfg.partition,
            state: AsyncMutex::new(State {
                cursor,
                next_version,
                key_order_changed: next_version,
                blobs: [left_wrapper, right_wrapper],
            }),
            sync_rewrites,
            sync_overwrites,
            keys,
        })
    }
    /// Parses a single blob into a key/value map plus its [Wrapper].
    ///
    /// Layout: an 8-byte big-endian version, a sequence of encoded key/value
    /// pairs, and a trailing big-endian CRC32 over everything before it. A
    /// blob that is empty, too short, or fails the checksum is truncated to
    /// zero length and treated as empty (expected after a crash mid-write).
    async fn load(
        codec_config: &V::Cfg,
        index: usize,
        blob: E::Blob,
        len: u64,
    ) -> Result<(BTreeMap<K, V>, Wrapper<E::Blob, K>), Error> {
        if len == 0 {
            return Ok((BTreeMap::new(), Wrapper::empty(blob)));
        }
        let len: usize = len.try_into().expect("blob too large for platform");
        let buf = blob.read_at(0, len).await?.coalesce();
        // Must hold at least the version prefix and the checksum suffix.
        if buf.len() < 8 + crc32::Digest::SIZE {
            warn!(
                blob = index,
                len = buf.len(),
                "blob is too short: truncating"
            );
            blob.resize(0).await?;
            blob.sync().await?;
            return Ok((BTreeMap::new(), Wrapper::empty(blob)));
        }
        // Verify the trailing checksum before trusting any contents.
        let checksum_index = buf.len() - crc32::Digest::SIZE;
        let stored_checksum =
            u32::from_be_bytes(buf.as_ref()[checksum_index..].try_into().unwrap());
        let computed_checksum = Crc32::checksum(&buf.as_ref()[..checksum_index]);
        if stored_checksum != computed_checksum {
            warn!(
                blob = index,
                stored = stored_checksum,
                computed = computed_checksum,
                "checksum mismatch: truncating"
            );
            blob.resize(0).await?;
            blob.sync().await?;
            return Ok((BTreeMap::new(), Wrapper::empty(blob)));
        }
        let version = u64::from_be_bytes(buf.as_ref()[..8].try_into().unwrap());
        // Decode all pairs, recording where each value sits so it can later be
        // overwritten in place (see [Self::sync]).
        let mut data = BTreeMap::new();
        let mut lengths = HashMap::new();
        let mut cursor = u64::SIZE;
        while cursor < checksum_index {
            // Data passed the checksum, so a decode failure here is a bug.
            let key = K::read(&mut buf.as_ref()[cursor..].as_ref())
                .expect("unable to read key from blob");
            cursor += key.encode_size();
            let value = V::read_cfg(&mut buf.as_ref()[cursor..].as_ref(), codec_config)
                .expect("unable to read value from blob");
            // `cursor` currently points at the value's first byte.
            lengths.insert(key.clone(), Info::new(cursor, value.encode_size()));
            cursor += value.encode_size();
            data.insert(key, value);
        }
        Ok((
            data,
            Wrapper::new(blob, version, lengths, buf.freeze().into()),
        ))
    }
    /// Returns a reference to the value for `key`, if present.
    pub fn get(&self, key: &K) -> Option<&V> {
        self.map.get(key)
    }
    /// Returns a mutable reference to the value for `key`, if present.
    ///
    /// The key is marked modified in BOTH blobs: syncs alternate blobs, and
    /// each must eventually persist the new bytes.
    pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
        let value = self.map.get_mut(key)?;
        let state = self.state.get_mut();
        state.blobs[state.cursor].modified.insert(key.clone());
        state.blobs[1 - state.cursor].modified.insert(key.clone());
        Some(value)
    }
    /// Removes all keys; bumps `key_order_changed` so subsequent syncs rewrite.
    pub fn clear(&mut self) {
        self.map.clear();
        let state = self.state.get_mut();
        state.key_order_changed = state.next_version;
        self.keys.set(0);
    }
    /// Inserts `key` with `value`, returning any previous value.
    ///
    /// Replacing an existing key only marks it modified (the new bytes may fit
    /// in place); inserting a NEW key changes the serialized key order and
    /// forces full rewrites on the following syncs.
    pub fn put(&mut self, key: K, value: V) -> Option<V> {
        let previous = self.map.insert(key.clone(), value);
        let state = self.state.get_mut();
        if previous.is_some() {
            state.blobs[state.cursor].modified.insert(key.clone());
            state.blobs[1 - state.cursor].modified.insert(key);
        } else {
            state.key_order_changed = state.next_version;
        }
        let _ = self.keys.try_set(self.map.len());
        previous
    }
    /// [Self::put] followed immediately by [Self::sync].
    pub async fn put_sync(&mut self, key: K, value: V) -> Result<(), Error> {
        self.put(key, value);
        self.sync().await
    }
    /// Mutates the value for `key` via `f`, inserting `V::default()` first if
    /// the key is absent.
    pub fn upsert(&mut self, key: K, f: impl FnOnce(&mut V))
    where
        V: Default,
    {
        if let Some(value) = self.get_mut(&key) {
            f(value);
        } else {
            let mut value = V::default();
            f(&mut value);
            self.put(key, value);
        }
    }
    /// [Self::upsert] followed immediately by [Self::sync].
    pub async fn upsert_sync(&mut self, key: K, f: impl FnOnce(&mut V)) -> Result<(), Error>
    where
        V: Default,
    {
        self.upsert(key, f);
        self.sync().await
    }
    /// Removes `key`, returning its value if it was present. A removal changes
    /// the key order, so subsequent syncs must rewrite.
    pub fn remove(&mut self, key: &K) -> Option<V> {
        let past = self.map.remove(key);
        if past.is_some() {
            let state = self.state.get_mut();
            state.key_order_changed = state.next_version;
        }
        let _ = self.keys.try_set(self.map.len());
        past
    }
    /// Iterates over all tracked keys (in `BTreeMap` order).
    pub fn keys(&self) -> impl Iterator<Item = &K> {
        self.map.keys()
    }
    /// Keeps only the entries for which `f` returns true. If anything was
    /// dropped, the key order changed and subsequent syncs must rewrite.
    pub fn retain(&mut self, mut f: impl FnMut(&K, &V) -> bool) {
        let old_len = self.map.len();
        self.map.retain(|k, v| f(k, v));
        let new_len = self.map.len();
        if new_len != old_len {
            let state = self.state.get_mut();
            state.key_order_changed = state.next_version;
            let _ = self.keys.try_set(self.map.len());
        }
    }
    /// Persists the current map to the inactive blob and makes it active.
    ///
    /// Fast path ("overwrite"): if no key has been added or removed since
    /// before the active blob was written AND every modified value re-encodes
    /// to its previous length, only the changed values, the version prefix,
    /// and the trailing checksum are written in place. Otherwise the entire
    /// snapshot is re-encoded and rewritten ("rewrite").
    pub async fn sync(&self) -> Result<(), Error> {
        let mut state = self.state.lock().await;
        let cursor = state.cursor;
        let next_version = state.next_version;
        let key_order_changed = state.key_order_changed;
        let past_version = state.blobs[cursor].version;
        let next_next_version = next_version.checked_add(1).expect("version overflow");
        // Flip to the other blob; it becomes the active one after this sync.
        let target_cursor = 1 - cursor;
        state.cursor = target_cursor;
        state.next_version = next_next_version;
        let target = &mut state.blobs[target_cursor];
        let mut overwrite = true;
        let mut writes = vec![];
        // NOTE(review): this gate relies on syncs strictly alternating blobs
        // (target written at `past_version - 1`), so `key_order_changed <
        // past_version` implies the target already reflects the current key
        // set and `target.lengths` covers every modified key — confirm.
        if key_order_changed < past_version {
            let write_capacity = target.modified.len() + 2;
            writes.reserve(write_capacity);
            for key in target.modified.iter() {
                let info = target.lengths.get(key).expect("key must exist");
                let new_value = self.map.get(key).expect("key must exist");
                if info.length == new_value.encode_size() {
                    // Same size: patch the cached bytes and queue an in-place write.
                    let encoded = new_value.encode_mut();
                    target.data[info.start..info.start + info.length].copy_from_slice(&encoded);
                    writes.push(target.blob.write_at(info.start as u64, encoded));
                } else {
                    // Size changed: later offsets shift, so fall back to a full
                    // rewrite (queued futures are lazy and never awaited).
                    overwrite = false;
                    break;
                }
            }
        } else {
            overwrite = false;
        }
        // Either path below leaves the target fully up to date.
        target.modified.clear();
        if overwrite {
            // Patch the version prefix and recompute the trailing checksum
            // over the (already patched) cached bytes.
            let version = next_version.to_be_bytes();
            target.data[0..8].copy_from_slice(&version);
            writes.push(target.blob.write_at(0, version.as_slice().into()));
            let checksum_index = target.data.len() - crc32::Digest::SIZE;
            let checksum = Crc32::checksum(&target.data[..checksum_index]).to_be_bytes();
            target.data[checksum_index..].copy_from_slice(&checksum);
            writes.push(
                target
                    .blob
                    .write_at(checksum_index as u64, checksum.as_slice().into()),
            );
            try_join_all(writes).await?;
            target.blob.sync().await?;
            target.version = next_version;
            self.sync_overwrites.inc();
            return Ok(());
        }
        // Rewrite: re-encode the full snapshot (version, pairs, checksum),
        // rebuilding the value-offset index as we go.
        let mut lengths = HashMap::new();
        let mut next_data = Vec::with_capacity(target.data.len());
        next_data.put_u64(next_version);
        for (key, value) in &self.map {
            key.write(&mut next_data);
            let start = next_data.len();
            value.write(&mut next_data);
            lengths.insert(key.clone(), Info::new(start, value.encode_size()));
        }
        next_data.put_u32(Crc32::checksum(&next_data[..]));
        target.blob.write_at(0, next_data.clone()).await?;
        // Drop any stale suffix left over from a previously larger snapshot.
        if next_data.len() < target.data.len() {
            target.blob.resize(next_data.len() as u64).await?;
        }
        target.blob.sync().await?;
        target.version = next_version;
        target.lengths = lengths;
        target.data = next_data;
        self.sync_rewrites.inc();
        Ok(())
    }
    /// Removes both blobs and then the partition itself. A missing partition
    /// at the final step is not treated as an error.
    pub async fn destroy(self) -> Result<(), Error> {
        let state = self.state.into_inner();
        for (i, wrapper) in state.blobs.into_iter().enumerate() {
            // Drop the handle before removing the underlying blob.
            drop(wrapper.blob);
            self.context
                .remove(&self.partition, Some(BLOB_NAMES[i]))
                .await?;
            debug!(blob = i, "destroyed blob");
        }
        match self.context.remove(&self.partition, None).await {
            Ok(()) => {}
            Err(RError::PartitionMissing(_)) => {
                // Partition already gone: nothing left to clean up.
            }
            Err(err) => return Err(Error::Runtime(err)),
        }
        Ok(())
    }
}