commonware_storage/cache/storage.rs

use super::{Config, Error};
use crate::{
    journal::segmented::variable::{Config as JConfig, Journal},
    kv,
    rmap::RMap,
};
use bytes::{Buf, BufMut};
use commonware_codec::{varint::UInt, CodecShared, EncodeSize, Read, ReadExt, Write};
use commonware_runtime::{telemetry::metrics::status::GaugeExt, Metrics, Storage};
use futures::{future::try_join_all, pin_mut, StreamExt};
use prometheus_client::metrics::{counter::Counter, gauge::Gauge};
use std::collections::{BTreeMap, BTreeSet};
use tracing::debug;

/// Record stored in the `Cache`.
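///
/// Encoded as the varint-encoded `index` followed by the value's own codec
/// encoding (see the `Write` and `Read` impls below).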
struct Record<V: CodecShared> {
    index: u64,
    value: V,
}

impl<V: CodecShared> Record<V> {
    /// Create a new `Record`.
    const fn new(index: u64, value: V) -> Self {
        Self { index, value }
    }
}

impl<V: CodecShared> Write for Record<V> {
    fn write(&self, buf: &mut impl BufMut) {
        UInt(self.index).write(buf);
        self.value.write(buf);
    }
}

impl<V: CodecShared> Read for Record<V> {
    type Cfg = V::Cfg;

    fn read_cfg(buf: &mut impl Buf, cfg: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
        let index = UInt::read(buf)?.into();
        let value = V::read_cfg(buf, cfg)?;
        Ok(Self { index, value })
    }
}

impl<V: CodecShared> EncodeSize for Record<V> {
    fn encode_size(&self) -> usize {
        UInt(self.index).encode_size() + self.value.encode_size()
    }
}

#[cfg(feature = "arbitrary")]
impl<V: CodecShared> arbitrary::Arbitrary<'_> for Record<V>
where
    V: for<'a> arbitrary::Arbitrary<'a>,
{
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        Ok(Self::new(u.arbitrary()?, u.arbitrary()?))
    }
}

/// Implementation of `Cache` storage.
pub struct Cache<E: Storage + Metrics, V: CodecShared> {
    items_per_blob: u64,
    journal: Journal<E, Record<V>>,
    pending: BTreeSet<u64>,

    // Oldest allowed section to read from. This is updated when `prune` is called.
    oldest_allowed: Option<u64>,
    indices: BTreeMap<u64, u64>,
    intervals: RMap,

    items_tracked: Gauge,
    gets: Counter,
    has: Counter,
    syncs: Counter,
}

impl<E: Storage + Metrics, V: CodecShared> Cache<E, V> {
    /// Calculate the section for a given index.
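    ///
    /// A sketch of the mapping (assuming `items_per_blob = 10`, a value chosen
    /// purely for illustration):
    ///
    /// ```text
    /// index:    0..=9   10..=19   20..=29
    /// section:  0       10        20
    /// ```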
    const fn section(&self, index: u64) -> u64 {
        (index / self.items_per_blob) * self.items_per_blob
    }

    /// Initialize a new `Cache` instance.
    ///
    /// The in-memory index for `Cache` is populated during this call
    /// by replaying the journal.
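    ///
    /// A minimal usage sketch (assuming `ctx` implements [Storage] + [Metrics]
    /// and `cfg` is a fully populated [Config] for a `u64` value type):
    ///
    /// ```ignore
    /// // Replays the journal in `cfg.partition` and rebuilds the in-memory index.
    /// let mut cache = Cache::<_, u64>::init(ctx, cfg).await?;
    /// ```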
    pub async fn init(context: E, cfg: Config<V::Cfg>) -> Result<Self, Error> {
        // Initialize journal
        let journal = Journal::<E, Record<V>>::init(
            context.with_label("journal"),
            JConfig {
                partition: cfg.partition,
                compression: cfg.compression,
                codec_config: cfg.codec_config,
                buffer_pool: cfg.buffer_pool,
                write_buffer: cfg.write_buffer,
            },
        )
        .await?;

        // Initialize keys and run corruption check
        let mut indices = BTreeMap::new();
        let mut intervals = RMap::new();
        {
            debug!("initializing cache");
            let stream = journal.replay(0, 0, cfg.replay_buffer).await?;
            pin_mut!(stream);
            while let Some(result) = stream.next().await {
                // Extract key from record
                let (_, offset, _, data) = result?;

                // Store index
                indices.insert(data.index, offset);

                // Store index in intervals
                intervals.insert(data.index);
            }
            debug!(items = indices.len(), "cache initialized");
        }

        // Initialize metrics
        let items_tracked = Gauge::default();
        let gets = Counter::default();
        let has = Counter::default();
        let syncs = Counter::default();
        context.register(
            "items_tracked",
            "Number of items tracked",
            items_tracked.clone(),
        );
        context.register("gets", "Number of gets performed", gets.clone());
        context.register("has", "Number of has checks performed", has.clone());
        context.register("syncs", "Number of syncs performed", syncs.clone());
        let _ = items_tracked.try_set(indices.len());

        // Return populated cache
        Ok(Self {
            items_per_blob: cfg.items_per_blob.get(),
            journal,
            pending: BTreeSet::new(),
            oldest_allowed: None,
            indices,
            intervals,
            items_tracked,
            gets,
            has,
            syncs,
        })
    }

    /// Retrieve an item from the [Cache].
    pub async fn get(&self, index: u64) -> Result<Option<V>, Error> {
        // Update metrics
        self.gets.inc();

        // Get index location
        let offset = match self.indices.get(&index) {
            Some(offset) => *offset,
            None => return Ok(None),
        };

        // Fetch item from disk
        let section = self.section(index);
        let record = self.journal.get(section, offset).await?;
        Ok(Some(record.value))
    }

    /// Retrieve the next gap in the [Cache].
    pub fn next_gap(&self, index: u64) -> (Option<u64>, Option<u64>) {
        self.intervals.next_gap(index)
    }

    /// Returns the first index in the [Cache].
    pub fn first(&self) -> Option<u64> {
        self.intervals.iter().next().map(|(&start, _)| start)
    }

    /// Returns up to `max` missing items starting from `start`.
    ///
    /// This method iterates through gaps between existing ranges, collecting missing indices
    /// until either `max` items are found or there are no more gaps to fill.
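    ///
    /// A sketch (indices illustrative): if the cache holds `0..=2` and `6..=7`,
    /// `missing_items(0, 2)` would return `[3, 4]`.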
    pub fn missing_items(&self, start: u64, max: usize) -> Vec<u64> {
        self.intervals.missing_items(start, max)
    }

    /// Check if an item exists in the [Cache].
    pub fn has(&self, index: u64) -> bool {
        // Update metrics
        self.has.inc();

        // Check if index exists
        self.indices.contains_key(&index)
    }

    /// Prune [Cache] to the provided `min`.
    ///
    /// If this is called with a `min` at or below the last pruned section,
    /// nothing happens.
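    ///
    /// `min` is first rounded down to a section boundary. A sketch (assuming
    /// `items_per_blob = 10`): `prune(25)` prunes everything below section `20`,
    /// removing indices `0..=19` while `20..=25` remain readable.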
    pub async fn prune(&mut self, min: u64) -> Result<(), Error> {
        // Update `min` to reflect section mask
        let min = self.section(min);

        // Check if min is less than last pruned
        if let Some(oldest_allowed) = self.oldest_allowed {
            if min <= oldest_allowed {
                // We don't return an error in this case because the caller
                // shouldn't be burdened with converting `min` to some section.
                return Ok(());
            }
        }
        debug!(min, "pruning cache");

        // Prune journal
        self.journal.prune(min).await.map_err(Error::Journal)?;

        // Remove pending writes (no need to call `sync` as we are pruning)
        loop {
            let next = match self.pending.iter().next() {
                Some(section) if *section < min => *section,
                _ => break,
            };
            self.pending.remove(&next);
        }

        // Remove all indices that are less than min
        loop {
            let next = match self.indices.first_key_value() {
                Some((index, _)) if *index < min => *index,
                _ => break,
            };
            self.indices.remove(&next).unwrap();
        }

        // Remove all intervals that are less than min
        if min > 0 {
            self.intervals.remove(0, min - 1);
        }

        // Update last pruned (to prevent reads from
        // pruned sections)
        self.oldest_allowed = Some(min);
        let _ = self.items_tracked.try_set(self.indices.len());
        Ok(())
    }

    /// Store an item in the [Cache].
    ///
    /// If the index already exists, `put` does nothing and returns.
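    ///
    /// Returns [Error::AlreadyPrunedTo] if `index` is below the oldest allowed
    /// section set by a prior call to [Self::prune].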
    pub async fn put(&mut self, index: u64, value: V) -> Result<(), Error> {
        // Check last pruned
        let oldest_allowed = self.oldest_allowed.unwrap_or(0);
        if index < oldest_allowed {
            return Err(Error::AlreadyPrunedTo(oldest_allowed));
        }

        // Check for existing index
        if self.indices.contains_key(&index) {
            return Ok(());
        }

        // Store item in journal
        let record = Record::new(index, value);
        let section = self.section(index);
        let (offset, _) = self.journal.append(section, record).await?;

        // Store index
        self.indices.insert(index, offset);

        // Add index to intervals
        self.intervals.insert(index);

        // Add section to pending
        self.pending.insert(section);

        // Update metrics
        let _ = self.items_tracked.try_set(self.indices.len());
        Ok(())
    }

    /// Sync all pending writes.
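    ///
    /// Pending sections are synced concurrently; once all succeed, the pending
    /// set is cleared.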
    pub async fn sync(&mut self) -> Result<(), Error> {
        let mut syncs = Vec::with_capacity(self.pending.len());
        for section in self.pending.iter() {
            syncs.push(self.journal.sync(*section));
            self.syncs.inc();
        }
        try_join_all(syncs).await?;
        self.pending.clear();
        Ok(())
    }

    /// Stores an item in the [Cache] and syncs it, plus any other pending writes, to disk.
    ///
    /// If the index already exists, the cache is just synced.
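    ///
    /// A minimal usage sketch (values illustrative):
    ///
    /// ```ignore
    /// cache.put(10, 42).await?;      // buffered; section marked pending
    /// cache.put_sync(11, 43).await?; // appends index 11, then persists both writes
    /// ```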
    pub async fn put_sync(&mut self, index: u64, value: V) -> Result<(), Error> {
        self.put(index, value).await?;
        self.sync().await
    }

    /// Remove all persistent data created by this [Cache].
    pub async fn destroy(self) -> Result<(), Error> {
        self.journal.destroy().await.map_err(Error::Journal)
    }
}

impl<E: Storage + Metrics, V: CodecShared> kv::Gettable for Cache<E, V> {
    type Key = u64;
    type Value = V;
    type Error = Error;

    async fn get(&self, key: &Self::Key) -> Result<Option<Self::Value>, Self::Error> {
        self.get(*key).await
    }
}

#[cfg(all(test, feature = "arbitrary"))]
mod conformance {
    use super::*;
    use commonware_codec::conformance::CodecConformance;

    commonware_conformance::conformance_tests! {
        CodecConformance<Record<u64>>,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::kv::tests::{assert_gettable, assert_send};
    use commonware_runtime::deterministic::Context;

    type TestCache = Cache<Context, u64>;

    #[allow(dead_code)]
    fn assert_cache_futures_are_send(cache: &TestCache, key: &u64) {
        assert_gettable(cache, key);
    }

    #[allow(dead_code)]
    fn assert_cache_destroy_is_send(cache: TestCache) {
        assert_send(cache.destroy());
    }
}