//! commonware_storage/cache/storage.rs

1use super::{Config, Error};
2use crate::{
3    journal::segmented::variable::{Config as JConfig, Journal},
4    rmap::RMap,
5};
6use commonware_codec::{varint::UInt, CodecShared, EncodeSize, Read, ReadExt, Write};
7use commonware_runtime::{telemetry::metrics::status::GaugeExt, Buf, BufMut, Metrics, Storage};
8use futures::{future::try_join_all, pin_mut, StreamExt};
9use prometheus_client::metrics::{counter::Counter, gauge::Gauge};
10use std::collections::{BTreeMap, BTreeSet};
11use tracing::debug;
12
/// Record stored in the `Cache`.
///
/// Pairs an item's `u64` index with its value so the index can be recovered
/// when the journal is replayed during [Cache::init].
struct Record<V: CodecShared> {
    /// Position of the item; used as the lookup key in the cache.
    index: u64,
    /// The cached value.
    value: V,
}
18
19impl<V: CodecShared> Record<V> {
20    /// Create a new `Record`.
21    const fn new(index: u64, value: V) -> Self {
22        Self { index, value }
23    }
24}
25
impl<V: CodecShared> Write for Record<V> {
    fn write(&self, buf: &mut impl BufMut) {
        // Encoding order (varint index, then value) must match `Record::read_cfg`
        UInt(self.index).write(buf);
        self.value.write(buf);
    }
}
32
33impl<V: CodecShared> Read for Record<V> {
34    type Cfg = V::Cfg;
35
36    fn read_cfg(buf: &mut impl Buf, cfg: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
37        let index = UInt::read(buf)?.into();
38        let value = V::read_cfg(buf, cfg)?;
39        Ok(Self { index, value })
40    }
41}
42
impl<V: CodecShared> EncodeSize for Record<V> {
    fn encode_size(&self) -> usize {
        // Mirrors `write`: varint-encoded index plus the value's encoded size
        UInt(self.index).encode_size() + self.value.encode_size()
    }
}
48
#[cfg(feature = "arbitrary")]
impl<V: CodecShared> arbitrary::Arbitrary<'_> for Record<V>
where
    V: for<'a> arbitrary::Arbitrary<'a>,
{
    /// Generate a random `Record` for fuzzing/conformance tests.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        Ok(Self::new(u.arbitrary()?, u.arbitrary()?))
    }
}
58
/// Implementation of `Cache` storage.
pub struct Cache<E: Storage + Metrics, V: CodecShared> {
    /// Number of items grouped into each journal section (blob).
    items_per_blob: u64,
    /// Underlying segmented journal holding the encoded records.
    journal: Journal<E, Record<V>>,
    /// Sections with unsynced appends, flushed by `sync`.
    pending: BTreeSet<u64>,

    /// Oldest allowed section to read from. This is updated when `prune` is called.
    oldest_allowed: Option<u64>,
    /// Maps an item's index to its offset within its journal section.
    indices: BTreeMap<u64, u64>,
    /// Ranges of contiguous indices currently stored (used for gap queries).
    intervals: RMap,

    /// Number of items currently tracked.
    items_tracked: Gauge,
    /// Number of `get` calls.
    gets: Counter,
    /// Number of `has` calls.
    has: Counter,
    /// Number of journal section syncs issued.
    syncs: Counter,
}
75
impl<E: Storage + Metrics, V: CodecShared> Cache<E, V> {
    /// Calculate the section for a given index.
    ///
    /// Indices are grouped into sections of `items_per_blob` items; this
    /// rounds `index` down to the first index of its section.
    const fn section(&self, index: u64) -> u64 {
        (index / self.items_per_blob) * self.items_per_blob
    }

    /// Initialize a new `Cache` instance.
    ///
    /// The in-memory index for `Cache` is populated during this call
    /// by replaying the journal.
    pub async fn init(context: E, cfg: Config<V::Cfg>) -> Result<Self, Error> {
        // Initialize journal
        let journal = Journal::<E, Record<V>>::init(
            context.with_label("journal"),
            JConfig {
                partition: cfg.partition,
                compression: cfg.compression,
                codec_config: cfg.codec_config,
                page_cache: cfg.page_cache,
                write_buffer: cfg.write_buffer,
            },
        )
        .await?;

        // Initialize keys and run corruption check
        let mut indices = BTreeMap::new();
        let mut intervals = RMap::new();
        {
            debug!("initializing cache");
            // Replay the full journal (from section 0, offset 0) to rebuild
            // the in-memory index and interval map
            let stream = journal.replay(0, 0, cfg.replay_buffer).await?;
            pin_mut!(stream);
            while let Some(result) = stream.next().await {
                // Extract key from record (only the offset and decoded record
                // are used here)
                let (_, offset, _, data) = result?;

                // Store index -> offset so `get` can read the item directly.
                // A duplicate index replayed later replaces the earlier offset.
                indices.insert(data.index, offset);

                // Store index in intervals
                intervals.insert(data.index);
            }
            debug!(items = indices.len(), "cache initialized");
        }

        // Initialize metrics
        let items_tracked = Gauge::default();
        let gets = Counter::default();
        let has = Counter::default();
        let syncs = Counter::default();
        context.register(
            "items_tracked",
            "Number of items tracked",
            items_tracked.clone(),
        );
        context.register("gets", "Number of gets performed", gets.clone());
        context.register("has", "Number of has performed", has.clone());
        context.register("syncs", "Number of syncs called", syncs.clone());
        // Best-effort metric update; error is deliberately ignored
        // (NOTE(review): presumably `try_set` fails only when the usize
        // cannot be represented by the gauge — confirm in GaugeExt)
        let _ = items_tracked.try_set(indices.len());

        // Return populated cache
        Ok(Self {
            items_per_blob: cfg.items_per_blob.get(),
            journal,
            pending: BTreeSet::new(),
            oldest_allowed: None,
            indices,
            intervals,
            items_tracked,
            gets,
            has,
            syncs,
        })
    }

    /// Retrieve an item from the [Cache].
    ///
    /// Returns `Ok(None)` if the index was never stored (or has been pruned).
    pub async fn get(&self, index: u64) -> Result<Option<V>, Error> {
        // Update metrics
        self.gets.inc();

        // Get index location (pruned indices were removed in `prune`, so a
        // hit here always refers to a readable section)
        let offset = match self.indices.get(&index) {
            Some(offset) => *offset,
            None => return Ok(None),
        };

        // Fetch item from disk
        let section = self.section(index);
        let record = self.journal.get(section, offset).await?;
        Ok(Some(record.value))
    }

    /// Retrieve the next gap in the [Cache].
    pub fn next_gap(&self, index: u64) -> (Option<u64>, Option<u64>) {
        self.intervals.next_gap(index)
    }

    /// Returns the first index in the [Cache].
    pub fn first(&self) -> Option<u64> {
        // The start of the first interval is the smallest stored index
        self.intervals.iter().next().map(|(&start, _)| start)
    }

    /// Returns up to `max` missing items starting from `start`.
    ///
    /// This method iterates through gaps between existing ranges, collecting missing indices
    /// until either `max` items are found or there are no more gaps to fill.
    pub fn missing_items(&self, start: u64, max: usize) -> Vec<u64> {
        self.intervals.missing_items(start, max)
    }

    /// Check if an item exists in the [Cache].
    pub fn has(&self, index: u64) -> bool {
        // Update metrics
        self.has.inc();

        // Check if index exists (in-memory only; no disk access)
        self.indices.contains_key(&index)
    }

    /// Prune [Cache] to the provided `min`.
    ///
    /// If this is called with a min lower than the last pruned, nothing
    /// will happen.
    pub async fn prune(&mut self, min: u64) -> Result<(), Error> {
        // Update `min` to reflect section mask
        let min = self.section(min);

        // Check if min is less than last pruned
        if let Some(oldest_allowed) = self.oldest_allowed {
            if min <= oldest_allowed {
                // We don't return an error in this case because the caller
                // shouldn't be burdened with converting `min` to some section.
                return Ok(());
            }
        }
        debug!(min, "pruning cache");

        // Prune journal
        self.journal.prune(min).await.map_err(Error::Journal)?;

        // Remove pending writes (no need to call `sync` as we are pruning)
        loop {
            let next = match self.pending.iter().next() {
                Some(section) if *section < min => *section,
                _ => break,
            };
            self.pending.remove(&next);
        }

        // Remove all indices that are less than min
        loop {
            let next = match self.indices.first_key_value() {
                Some((index, _)) if *index < min => *index,
                _ => break,
            };
            // Unwrap is safe: `next` was just observed in the map above
            self.indices.remove(&next).unwrap();
        }

        // Remove all intervals that are less than min
        if min > 0 {
            self.intervals.remove(0, min - 1);
        }

        // Update last pruned (to prevent reads from
        // pruned sections)
        self.oldest_allowed = Some(min);
        let _ = self.items_tracked.try_set(self.indices.len());
        Ok(())
    }

    /// Store an item in the [Cache].
    ///
    /// If the index already exists, put does nothing and returns.
    ///
    /// Returns [Error::AlreadyPrunedTo] if `index` falls before the oldest
    /// allowed (pruned) section. The write is buffered; call [Self::sync]
    /// (or use [Self::put_sync]) to persist it to disk.
    pub async fn put(&mut self, index: u64, value: V) -> Result<(), Error> {
        // Check last pruned
        let oldest_allowed = self.oldest_allowed.unwrap_or(0);
        if index < oldest_allowed {
            return Err(Error::AlreadyPrunedTo(oldest_allowed));
        }

        // Check for existing index
        if self.indices.contains_key(&index) {
            return Ok(());
        }

        // Store item in journal
        let record = Record::new(index, value);
        let section = self.section(index);
        let (offset, _) = self.journal.append(section, &record).await?;

        // Store index
        self.indices.insert(index, offset);

        // Add index to intervals
        self.intervals.insert(index);

        // Add section to pending (so `sync` knows which sections to flush)
        self.pending.insert(section);

        // Update metrics
        let _ = self.items_tracked.try_set(self.indices.len());
        Ok(())
    }

    /// Sync all pending writes.
    ///
    /// Pending sections are synced concurrently; the pending set is cleared
    /// only if every sync succeeds.
    pub async fn sync(&mut self) -> Result<(), Error> {
        let mut syncs = Vec::with_capacity(self.pending.len());
        for section in self.pending.iter() {
            syncs.push(self.journal.sync(*section));
            // Counts sync attempts (incremented before the future resolves)
            self.syncs.inc();
        }
        try_join_all(syncs).await?;
        self.pending.clear();
        Ok(())
    }

    /// Stores an item in the [Cache] and syncs it, plus any other pending writes, to disk.
    ///
    /// If the index already exists, the cache is just synced.
    pub async fn put_sync(&mut self, index: u64, value: V) -> Result<(), Error> {
        self.put(index, value).await?;
        self.sync().await
    }

    /// Remove all persistent data created by this [Cache].
    pub async fn destroy(self) -> Result<(), Error> {
        self.journal.destroy().await.map_err(Error::Journal)
    }
}
304
#[cfg(all(test, feature = "arbitrary"))]
mod conformance {
    use super::*;
    use commonware_codec::conformance::CodecConformance;

    // Verifies that `Record`'s Write/Read/EncodeSize implementations
    // round-trip via the shared codec conformance harness.
    commonware_conformance::conformance_tests! {
        CodecConformance<Record<u64>>,
    }
}
314
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_runtime::deterministic::Context;

    /// Cache specialization used by the compile-time checks below.
    type TestCache = Cache<Context, u64>;

    /// Compile-time probe: accepts a value only if it implements `Send`.
    fn is_send<T: Send>(_: T) {}

    /// Compile-time check that the future returned by `get` is `Send`.
    #[allow(dead_code)]
    fn assert_cache_futures_are_send(cache: &TestCache, key: &u64) {
        let fut = cache.get(*key);
        is_send(fut);
    }

    /// Compile-time check that the future returned by `destroy` is `Send`.
    #[allow(dead_code)]
    fn assert_cache_destroy_is_send(cache: TestCache) {
        let fut = cache.destroy();
        is_send(fut);
    }
}