Skip to main content

commonware_storage/cache/
storage.rs

1use super::{Config, Error};
2use crate::{
3    journal::segmented::variable::{Config as JConfig, Journal},
4    rmap::RMap,
5};
6use commonware_codec::{varint::UInt, CodecShared, EncodeSize, Read, ReadExt, Write};
7use commonware_runtime::{
8    telemetry::metrics::{Counter, Gauge, GaugeExt, MetricsExt as _},
9    Buf, BufMut, Metrics, Storage,
10};
11use futures::{future::try_join_all, pin_mut, StreamExt};
12use std::collections::{BTreeMap, BTreeSet};
13use tracing::debug;
14
15/// Record stored in the `Cache`.
16struct Record<V: CodecShared> {
17    index: u64,
18    value: V,
19}
20
21impl<V: CodecShared> Record<V> {
22    /// Create a new `Record`.
23    const fn new(index: u64, value: V) -> Self {
24        Self { index, value }
25    }
26}
27
28impl<V: CodecShared> Write for Record<V> {
29    fn write(&self, buf: &mut impl BufMut) {
30        UInt(self.index).write(buf);
31        self.value.write(buf);
32    }
33}
34
35impl<V: CodecShared> Read for Record<V> {
36    type Cfg = V::Cfg;
37
38    fn read_cfg(buf: &mut impl Buf, cfg: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
39        let index = UInt::read(buf)?.into();
40        let value = V::read_cfg(buf, cfg)?;
41        Ok(Self { index, value })
42    }
43}
44
45impl<V: CodecShared> EncodeSize for Record<V> {
46    fn encode_size(&self) -> usize {
47        UInt(self.index).encode_size() + self.value.encode_size()
48    }
49}
50
#[cfg(feature = "arbitrary")]
impl<V> arbitrary::Arbitrary<'_> for Record<V>
where
    V: CodecShared + for<'a> arbitrary::Arbitrary<'a>,
{
    /// Build a record from unstructured input, drawing the index first and the value second.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let index = u.arbitrary()?;
        let value = u.arbitrary()?;
        Ok(Self { index, value })
    }
}
60
61/// Implementation of `Cache` storage.
62pub struct Cache<E: Storage + Metrics, V: CodecShared> {
63    items_per_blob: u64,
64    journal: Journal<E, Record<V>>,
65    pending: BTreeSet<u64>,
66
67    // Oldest allowed section to read from. This is updated when `prune` is called.
68    oldest_allowed: Option<u64>,
69    indices: BTreeMap<u64, u64>,
70    intervals: RMap,
71
72    items_tracked: Gauge,
73    gets: Counter,
74    has: Counter,
75    syncs: Counter,
76}
77
78impl<E: Storage + Metrics, V: CodecShared> Cache<E, V> {
79    /// Calculate the section for a given index.
80    const fn section(&self, index: u64) -> u64 {
81        (index / self.items_per_blob) * self.items_per_blob
82    }
83
84    /// Initialize a new `Cache` instance.
85    ///
86    /// The in-memory index for `Cache` is populated during this call
87    /// by replaying the journal.
88    pub async fn init(context: E, cfg: Config<V::Cfg>) -> Result<Self, Error> {
89        // Initialize journal
90        let journal = Journal::<E, Record<V>>::init(
91            context.child("journal"),
92            JConfig {
93                partition: cfg.partition,
94                compression: cfg.compression,
95                codec_config: cfg.codec_config,
96                page_cache: cfg.page_cache,
97                write_buffer: cfg.write_buffer,
98            },
99        )
100        .await?;
101
102        // Initialize keys and run corruption check
103        let mut indices = BTreeMap::new();
104        let mut intervals = RMap::new();
105        {
106            debug!("initializing cache");
107            let stream = journal.replay(0, 0, cfg.replay_buffer).await?;
108            pin_mut!(stream);
109            while let Some(result) = stream.next().await {
110                // Extract key from record
111                let (_, offset, _, data) = result?;
112
113                // Store index
114                indices.insert(data.index, offset);
115
116                // Store index in intervals
117                intervals.insert(data.index);
118            }
119            debug!(items = indices.len(), "cache initialized");
120        }
121
122        // Initialize metrics
123        let items_tracked = context.gauge("items_tracked", "Number of items tracked");
124        let gets = context.counter("gets", "Number of gets performed");
125        let has = context.counter("has", "Number of has performed");
126        let syncs = context.counter("syncs", "Number of syncs called");
127        let _ = items_tracked.try_set(indices.len());
128
129        // Return populated cache
130        Ok(Self {
131            items_per_blob: cfg.items_per_blob.get(),
132            journal,
133            pending: BTreeSet::new(),
134            oldest_allowed: None,
135            indices,
136            intervals,
137            items_tracked,
138            gets,
139            has,
140            syncs,
141        })
142    }
143
144    /// Retrieve an item from the [Cache].
145    pub async fn get(&self, index: u64) -> Result<Option<V>, Error> {
146        // Update metrics
147        self.gets.inc();
148
149        // Get index location
150        let offset = match self.indices.get(&index) {
151            Some(offset) => *offset,
152            None => return Ok(None),
153        };
154
155        // Fetch item from disk
156        let section = self.section(index);
157        let record = self.journal.get(section, offset).await?;
158        Ok(Some(record.value))
159    }
160
161    /// Retrieve the next gap in the [Cache].
162    pub fn next_gap(&self, index: u64) -> (Option<u64>, Option<u64>) {
163        self.intervals.next_gap(index)
164    }
165
166    /// Returns the first index in the [Cache].
167    pub fn first(&self) -> Option<u64> {
168        self.intervals.iter().next().map(|(&start, _)| start)
169    }
170
171    /// Returns up to `max` missing items starting from `start`.
172    ///
173    /// This method iterates through gaps between existing ranges, collecting missing indices
174    /// until either `max` items are found or there are no more gaps to fill.
175    pub fn missing_items(&self, start: u64, max: usize) -> Vec<u64> {
176        self.intervals.missing_items(start, max)
177    }
178
179    /// Check if an item exists in the [Cache].
180    pub fn has(&self, index: u64) -> bool {
181        // Update metrics
182        self.has.inc();
183
184        // Check if index exists
185        self.indices.contains_key(&index)
186    }
187
188    /// Prune [Cache] to the provided `min`.
189    ///
190    /// If this is called with a min lower than the last pruned, nothing
191    /// will happen.
192    pub async fn prune(&mut self, min: u64) -> Result<(), Error> {
193        // Update `min` to reflect section mask
194        let min = self.section(min);
195
196        // Check if min is less than last pruned
197        if let Some(oldest_allowed) = self.oldest_allowed {
198            if min <= oldest_allowed {
199                // We don't return an error in this case because the caller
200                // shouldn't be burdened with converting `min` to some section.
201                return Ok(());
202            }
203        }
204        debug!(min, "pruning cache");
205
206        // Prune journal
207        self.journal.prune(min).await.map_err(Error::Journal)?;
208
209        // Remove pending writes (no need to call `sync` as we are pruning)
210        loop {
211            let next = match self.pending.iter().next() {
212                Some(section) if *section < min => *section,
213                _ => break,
214            };
215            self.pending.remove(&next);
216        }
217
218        // Remove all indices that are less than min
219        loop {
220            let next = match self.indices.first_key_value() {
221                Some((index, _)) if *index < min => *index,
222                _ => break,
223            };
224            self.indices.remove(&next).unwrap();
225        }
226
227        // Remove all intervals that are less than min
228        if min > 0 {
229            self.intervals.remove(0, min - 1);
230        }
231
232        // Update last pruned (to prevent reads from
233        // pruned sections)
234        self.oldest_allowed = Some(min);
235        let _ = self.items_tracked.try_set(self.indices.len());
236        Ok(())
237    }
238
239    /// Store an item in the [Cache].
240    ///
241    /// If the index already exists, put does nothing and returns.
242    pub async fn put(&mut self, index: u64, value: V) -> Result<(), Error> {
243        // Check last pruned
244        let oldest_allowed = self.oldest_allowed.unwrap_or(0);
245        if index < oldest_allowed {
246            return Err(Error::AlreadyPrunedTo(oldest_allowed));
247        }
248
249        // Check for existing index
250        if self.indices.contains_key(&index) {
251            return Ok(());
252        }
253
254        // Store item in journal
255        let record = Record::new(index, value);
256        let section = self.section(index);
257        let (offset, _) = self.journal.append(section, &record).await?;
258
259        // Store index
260        self.indices.insert(index, offset);
261
262        // Add index to intervals
263        self.intervals.insert(index);
264
265        // Add section to pending
266        self.pending.insert(section);
267
268        // Update metrics
269        let _ = self.items_tracked.try_set(self.indices.len());
270        Ok(())
271    }
272
273    /// Sync all pending writes.
274    pub async fn sync(&mut self) -> Result<(), Error> {
275        let mut syncs = Vec::with_capacity(self.pending.len());
276        for section in self.pending.iter() {
277            syncs.push(self.journal.sync(*section));
278            self.syncs.inc();
279        }
280        try_join_all(syncs).await?;
281        self.pending.clear();
282        Ok(())
283    }
284
285    /// Stores an item in the [Cache] and syncs it, plus any other pending writes, to disk.
286    ///
287    /// If the index already exists, the cache is just synced.
288    pub async fn put_sync(&mut self, index: u64, value: V) -> Result<(), Error> {
289        self.put(index, value).await?;
290        self.sync().await
291    }
292
293    /// Remove all persistent data created by this [Cache].
294    pub async fn destroy(self) -> Result<(), Error> {
295        self.journal.destroy().await.map_err(Error::Journal)
296    }
297}
298
/// Codec conformance tests for the on-disk [`Record`] encoding.
///
/// Only built for tests with the `arbitrary` feature enabled.
#[cfg(all(test, feature = "arbitrary"))]
mod conformance {
    use super::*;
    use commonware_codec::conformance::CodecConformance;

    // The type list is passed to the macro verbatim; keep it unchanged.
    commonware_conformance::conformance_tests! {
        CodecConformance<Record<u64>>,
    }
}
308
/// Compile-time checks that the futures returned by [`Cache`] are `Send`.
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_runtime::deterministic::Context;

    /// Cache instantiation used by the checks below.
    type TestCache = Cache<Context, u64>;

    /// Accepts any value whose type is `Send`; fails to compile otherwise.
    fn is_send<T>(_: T)
    where
        T: Send,
    {
    }

    // Never called: these functions only need to type-check.
    #[allow(dead_code)]
    fn assert_cache_futures_are_send(cache: &TestCache, key: &u64) {
        let get_future = cache.get(*key);
        is_send(get_future);
    }

    // Never called: this function only needs to type-check.
    #[allow(dead_code)]
    fn assert_cache_destroy_is_send(cache: TestCache) {
        let destroy_future = cache.destroy();
        is_send(destroy_future);
    }
}
327}