use super::{Config, Error};
use crate::{
journal::segmented::variable::{Config as JConfig, Journal},
rmap::RMap,
};
use commonware_codec::{varint::UInt, CodecShared, EncodeSize, Read, ReadExt, Write};
use commonware_runtime::{telemetry::metrics::status::GaugeExt, Buf, BufMut, Metrics, Storage};
use futures::{future::try_join_all, pin_mut, StreamExt};
use prometheus_client::metrics::{counter::Counter, gauge::Gauge};
use std::collections::{BTreeMap, BTreeSet};
use tracing::debug;
/// A journaled cache entry pairing an item's index with its value.
///
/// Serialized as a varint-encoded index followed by the value's own codec
/// encoding (see the `Write`/`Read` impls for `Record`).
struct Record<V: CodecShared> {
    // Position of the item in the cache's logical sequence.
    index: u64,
    // The cached value itself.
    value: V,
}
impl<V: CodecShared> Record<V> {
    /// Creates a record associating `value` with `index`.
    const fn new(index: u64, value: V) -> Self {
        Self { index, value }
    }
}
impl<V: CodecShared> Write for Record<V> {
    /// Serializes the record: varint index first, then the value's encoding.
    fn write(&self, buf: &mut impl BufMut) {
        // Varint keeps small indices compact on disk; the index must precede
        // the value to match the decode order in `Read::read_cfg`.
        UInt(self.index).write(buf);
        self.value.write(buf);
    }
}
impl<V: CodecShared> Read for Record<V> {
    type Cfg = V::Cfg;

    /// Decodes a record: a varint index followed by a `V` decoded with `cfg`.
    fn read_cfg(buf: &mut impl Buf, cfg: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
        // Struct-literal initializers evaluate in written order, so the index
        // is consumed from the buffer before the value — mirroring `write`.
        Ok(Self {
            index: UInt::read(buf)?.into(),
            value: V::read_cfg(buf, cfg)?,
        })
    }
}
impl<V: CodecShared> EncodeSize for Record<V> {
    /// Returns the exact serialized length: varint-encoded index plus value.
    fn encode_size(&self) -> usize {
        let index_len = UInt(self.index).encode_size();
        let value_len = self.value.encode_size();
        index_len + value_len
    }
}
#[cfg(feature = "arbitrary")]
impl<V: CodecShared> arbitrary::Arbitrary<'_> for Record<V>
where
    V: for<'a> arbitrary::Arbitrary<'a>,
{
    /// Builds a record from fuzzer-provided bytes: index first, then value.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        Ok(Self::new(u.arbitrary()?, u.arbitrary()?))
    }
}
/// A prunable, index-addressed cache of values persisted in a segmented
/// journal, with in-memory lookup and gap-tracking structures rebuilt on
/// startup by replaying the journal.
pub struct Cache<E: Storage + Metrics, V: CodecShared> {
    // Number of items stored per journal section (blob); sections start at
    // multiples of this value (see `section`).
    items_per_blob: u64,
    // Durable backing store for records.
    journal: Journal<E, Record<V>>,
    // Sections appended to since the last `sync` (flushed there).
    pending: BTreeSet<u64>,
    // Lowest section start still retained after the last `prune`; `None`
    // until the first prune.
    oldest_allowed: Option<u64>,
    // Maps item index -> journal offset within its section.
    indices: BTreeMap<u64, u64>,
    // Interval set of present indices, used for gap queries.
    intervals: RMap,
    // Metrics registered in `init`.
    items_tracked: Gauge,
    gets: Counter,
    has: Counter,
    syncs: Counter,
}
impl<E: Storage + Metrics, V: CodecShared> Cache<E, V> {
    /// Maps an item index to the start index of the journal section (blob)
    /// that stores it.
    const fn section(&self, index: u64) -> u64 {
        // Integer division floors, rounding `index` down to the nearest
        // multiple of `items_per_blob`. `items_per_blob` comes from a
        // non-zero config value (`cfg.items_per_blob.get()` in `init`),
        // so this never divides by zero.
        (index / self.items_per_blob) * self.items_per_blob
    }

    /// Opens (or creates) the cache's backing journal, rebuilds the in-memory
    /// index map and interval set by replaying all persisted records, and
    /// registers the cache's metrics under `context`.
    ///
    /// # Errors
    ///
    /// Returns an error if the journal cannot be initialized or replayed.
    pub async fn init(context: E, cfg: Config<V::Cfg>) -> Result<Self, Error> {
        let journal = Journal::<E, Record<V>>::init(
            context.with_label("journal"),
            JConfig {
                partition: cfg.partition,
                compression: cfg.compression,
                codec_config: cfg.codec_config,
                page_cache: cfg.page_cache,
                write_buffer: cfg.write_buffer,
            },
        )
        .await?;

        // Rebuild index -> offset mappings and the contiguity intervals from
        // whatever is already on disk.
        let mut indices = BTreeMap::new();
        let mut intervals = RMap::new();
        {
            debug!("initializing cache");
            let stream = journal.replay(0, 0, cfg.replay_buffer).await?;
            pin_mut!(stream);
            while let Some(result) = stream.next().await {
                let (_, offset, _, data) = result?;
                indices.insert(data.index, offset);
                intervals.insert(data.index);
            }
            debug!(items = indices.len(), "cache initialized");
        }

        // Create and register metrics, then seed the item gauge with the
        // replayed count.
        let items_tracked = Gauge::default();
        let gets = Counter::default();
        let has = Counter::default();
        let syncs = Counter::default();
        context.register(
            "items_tracked",
            "Number of items tracked",
            items_tracked.clone(),
        );
        context.register("gets", "Number of gets performed", gets.clone());
        context.register("has", "Number of has performed", has.clone());
        context.register("syncs", "Number of syncs called", syncs.clone());
        let _ = items_tracked.try_set(indices.len());
        Ok(Self {
            items_per_blob: cfg.items_per_blob.get(),
            journal,
            pending: BTreeSet::new(),
            oldest_allowed: None,
            indices,
            intervals,
            items_tracked,
            gets,
            has,
            syncs,
        })
    }

    /// Fetches the value stored at `index`, or `None` if it is not cached.
    ///
    /// # Errors
    ///
    /// Returns an error if reading the record from the journal fails.
    pub async fn get(&self, index: u64) -> Result<Option<V>, Error> {
        self.gets.inc();
        // The in-memory map tells us both presence and the journal offset.
        let offset = match self.indices.get(&index) {
            Some(offset) => *offset,
            None => return Ok(None),
        };
        let section = self.section(index);
        let record = self.journal.get(section, offset).await?;
        Ok(Some(record.value))
    }

    /// Returns the bounds of the gap surrounding `index` in the set of cached
    /// indices (delegates to [`RMap::next_gap`]).
    pub fn next_gap(&self, index: u64) -> (Option<u64>, Option<u64>) {
        self.intervals.next_gap(index)
    }

    /// Returns the smallest cached index, or `None` if the cache is empty.
    pub fn first(&self) -> Option<u64> {
        self.intervals.iter().next().map(|(&start, _)| start)
    }

    /// Returns up to `max` indices missing from the cache, starting at
    /// `start` (delegates to [`RMap::missing_items`]).
    pub fn missing_items(&self, start: u64, max: usize) -> Vec<u64> {
        self.intervals.missing_items(start, max)
    }

    /// Returns whether `index` is currently cached (memory-only check).
    pub fn has(&self, index: u64) -> bool {
        self.has.inc();
        self.indices.contains_key(&index)
    }

    /// Prunes all items strictly below the section containing `min`.
    ///
    /// The effective prune point is rounded down to a section boundary, so
    /// some items below `min` in the same section may survive. No-op if the
    /// cache has already been pruned to this point or beyond.
    ///
    /// # Errors
    ///
    /// Returns an error if pruning the underlying journal fails.
    pub async fn prune(&mut self, min: u64) -> Result<(), Error> {
        // Round down to a section boundary: the journal prunes whole blobs.
        let min = self.section(min);
        if let Some(oldest_allowed) = self.oldest_allowed {
            if min <= oldest_allowed {
                return Ok(());
            }
        }
        debug!(min, "pruning cache");
        // Prune durable state first; in-memory state is only trimmed once
        // the journal prune has succeeded.
        self.journal.prune(min).await.map_err(Error::Journal)?;
        // Drop pending-sync markers for sections that no longer exist.
        loop {
            let next = match self.pending.iter().next() {
                Some(section) if *section < min => *section,
                _ => break,
            };
            self.pending.remove(&next);
        }
        // Drop index entries below the new floor (BTreeMap iterates in
        // ascending order, so we can stop at the first surviving key).
        loop {
            let next = match self.indices.first_key_value() {
                Some((index, _)) if *index < min => *index,
                _ => break,
            };
            self.indices.remove(&next).unwrap();
        }
        // Remove the pruned range [0, min - 1] from the interval set
        // (guard avoids underflow when min == 0, though min > 0 here
        // whenever anything was pruned).
        if min > 0 {
            self.intervals.remove(0, min - 1);
        }
        self.oldest_allowed = Some(min);
        let _ = self.items_tracked.try_set(self.indices.len());
        Ok(())
    }

    /// Stores `value` at `index` without forcing a sync; the write becomes
    /// durable on the next [`Cache::sync`]. Idempotent: re-putting an
    /// existing index is a no-op.
    ///
    /// # Errors
    ///
    /// Returns [`Error::AlreadyPrunedTo`] if `index` falls below the prune
    /// floor, or a journal error if the append fails.
    pub async fn put(&mut self, index: u64, value: V) -> Result<(), Error> {
        let oldest_allowed = self.oldest_allowed.unwrap_or(0);
        if index < oldest_allowed {
            return Err(Error::AlreadyPrunedTo(oldest_allowed));
        }
        // Already cached: avoid writing a duplicate record to the journal.
        if self.indices.contains_key(&index) {
            return Ok(());
        }
        let record = Record::new(index, value);
        let section = self.section(index);
        let (offset, _) = self.journal.append(section, &record).await?;
        self.indices.insert(index, offset);
        self.intervals.insert(index);
        // Remember this section needs syncing before its data is durable.
        self.pending.insert(section);
        let _ = self.items_tracked.try_set(self.indices.len());
        Ok(())
    }

    /// Syncs all sections written since the last sync, making prior `put`s
    /// durable. Section syncs are issued concurrently.
    ///
    /// # Errors
    ///
    /// Returns an error if any section sync fails (pending markers are then
    /// retained, since `pending` is only cleared after all syncs succeed).
    pub async fn sync(&mut self) -> Result<(), Error> {
        let mut syncs = Vec::with_capacity(self.pending.len());
        for section in self.pending.iter() {
            syncs.push(self.journal.sync(*section));
            self.syncs.inc();
        }
        try_join_all(syncs).await?;
        self.pending.clear();
        Ok(())
    }

    /// Stores `value` at `index` and immediately syncs all pending sections.
    ///
    /// # Errors
    ///
    /// Returns an error if the put or the sync fails.
    pub async fn put_sync(&mut self, index: u64, value: V) -> Result<(), Error> {
        self.put(index, value).await?;
        self.sync().await
    }

    /// Destroys the cache, removing all underlying journal storage.
    ///
    /// # Errors
    ///
    /// Returns an error if destroying the journal fails.
    pub async fn destroy(self) -> Result<(), Error> {
        self.journal.destroy().await.map_err(Error::Journal)
    }
}
// Codec round-trip conformance tests for `Record`, generated by the
// project's conformance macro (requires the `arbitrary` feature).
#[cfg(all(test, feature = "arbitrary"))]
mod conformance {
    use super::*;
    use commonware_codec::conformance::CodecConformance;

    commonware_conformance::conformance_tests! {
        CodecConformance<Record<u64>>,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_runtime::deterministic::Context;

    type TestCache = Cache<Context, u64>;

    /// Compile-time probe: accepts only values whose type implements `Send`.
    fn require_send<T: Send>(_value: T) {}

    /// Asserts (at compile time) that the future returned by `Cache::get`
    /// is `Send`.
    #[allow(dead_code)]
    fn assert_cache_futures_are_send(cache: &TestCache, key: &u64) {
        require_send(cache.get(*key));
    }

    /// Asserts (at compile time) that the future returned by
    /// `Cache::destroy` is `Send`.
    #[allow(dead_code)]
    fn assert_cache_destroy_is_send(cache: TestCache) {
        require_send(cache.destroy());
    }
}