commonware_storage/cache/
storage.rs

1use super::{Config, Error};
2use crate::{
3    journal::segmented::variable::{Config as JConfig, Journal},
4    rmap::RMap,
5};
6use bytes::{Buf, BufMut};
7use commonware_codec::{varint::UInt, Codec, EncodeSize, Read, ReadExt, Write};
8use commonware_runtime::{telemetry::metrics::status::GaugeExt, Metrics, Storage};
9use futures::{future::try_join_all, pin_mut, StreamExt};
10use prometheus_client::metrics::{counter::Counter, gauge::Gauge};
11use std::collections::{BTreeMap, BTreeSet};
12use tracing::debug;
13
/// Record stored in the `Cache`.
struct Record<V: Codec> {
    /// Index of the item; acts as the cache key and drives section placement.
    index: u64,
    /// The cached value, encoded/decoded via its [Codec] implementation.
    value: V,
}
19
20impl<V: Codec> Record<V> {
21    /// Create a new `Record`.
22    const fn new(index: u64, value: V) -> Self {
23        Self { index, value }
24    }
25}
26
impl<V: Codec> Write for Record<V> {
    fn write(&self, buf: &mut impl BufMut) {
        // Varint-encode the index to save space, then append the value's own
        // encoding. This order must match `Read::read_cfg` exactly.
        UInt(self.index).write(buf);
        self.value.write(buf);
    }
}
33
impl<V: Codec> Read for Record<V> {
    // Decoding a value may require configuration (e.g. size limits), so the
    // record forwards the value codec's config type.
    type Cfg = V::Cfg;

    fn read_cfg(buf: &mut impl Buf, cfg: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
        // Must mirror `Write::write`: varint index first, then the value.
        let index = UInt::read(buf)?.into();
        let value = V::read_cfg(buf, cfg)?;
        Ok(Self { index, value })
    }
}
43
impl<V: Codec> EncodeSize for Record<V> {
    fn encode_size(&self) -> usize {
        // Sum of the two components written in `Write::write`.
        UInt(self.index).encode_size() + self.value.encode_size()
    }
}
49
// Fuzzing/conformance support: generate random records when the value type
// itself supports `Arbitrary`.
#[cfg(feature = "arbitrary")]
impl<V: Codec> arbitrary::Arbitrary<'_> for Record<V>
where
    V: for<'a> arbitrary::Arbitrary<'a>,
{
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        Ok(Self::new(u.arbitrary()?, u.arbitrary()?))
    }
}
59
/// Implementation of `Cache` storage.
pub struct Cache<E: Storage + Metrics, V: Codec> {
    /// Number of items grouped into each journal section (blob).
    items_per_blob: u64,
    /// Backing segmented journal that persists records.
    journal: Journal<E, Record<V>>,
    /// Sections with unsynced appends; flushed by `sync`.
    pending: BTreeSet<u64>,

    // Oldest allowed section to read from. This is updated when `prune` is called.
    oldest_allowed: Option<u64>,
    /// Maps item index -> journal offset within its section.
    indices: BTreeMap<u64, u32>,
    /// Interval map of stored indices, used for gap queries.
    intervals: RMap,

    // Metrics: item count gauge and operation counters.
    items_tracked: Gauge,
    gets: Counter,
    has: Counter,
    syncs: Counter,
}
76
77impl<E: Storage + Metrics, V: Codec> Cache<E, V> {
78    /// Calculate the section for a given index.
79    const fn section(&self, index: u64) -> u64 {
80        (index / self.items_per_blob) * self.items_per_blob
81    }
82
83    /// Initialize a new `Cache` instance.
84    ///
85    /// The in-memory index for `Cache` is populated during this call
86    /// by replaying the journal.
87    pub async fn init(context: E, cfg: Config<V::Cfg>) -> Result<Self, Error> {
88        // Initialize journal
89        let journal = Journal::<E, Record<V>>::init(
90            context.with_label("journal"),
91            JConfig {
92                partition: cfg.partition,
93                compression: cfg.compression,
94                codec_config: cfg.codec_config,
95                buffer_pool: cfg.buffer_pool,
96                write_buffer: cfg.write_buffer,
97            },
98        )
99        .await?;
100
101        // Initialize keys and run corruption check
102        let mut indices = BTreeMap::new();
103        let mut intervals = RMap::new();
104        {
105            debug!("initializing cache");
106            let stream = journal.replay(0, 0, cfg.replay_buffer).await?;
107            pin_mut!(stream);
108            while let Some(result) = stream.next().await {
109                // Extract key from record
110                let (_, offset, _, data) = result?;
111
112                // Store index
113                indices.insert(data.index, offset);
114
115                // Store index in intervals
116                intervals.insert(data.index);
117            }
118            debug!(items = indices.len(), "cache initialized");
119        }
120
121        // Initialize metrics
122        let items_tracked = Gauge::default();
123        let gets = Counter::default();
124        let has = Counter::default();
125        let syncs = Counter::default();
126        context.register(
127            "items_tracked",
128            "Number of items tracked",
129            items_tracked.clone(),
130        );
131        context.register("gets", "Number of gets performed", gets.clone());
132        context.register("has", "Number of has performed", has.clone());
133        context.register("syncs", "Number of syncs called", syncs.clone());
134        let _ = items_tracked.try_set(indices.len());
135
136        // Return populated cache
137        Ok(Self {
138            items_per_blob: cfg.items_per_blob.get(),
139            journal,
140            pending: BTreeSet::new(),
141            oldest_allowed: None,
142            indices,
143            intervals,
144            items_tracked,
145            gets,
146            has,
147            syncs,
148        })
149    }
150
151    /// Retrieve an item from the [Cache].
152    pub async fn get(&self, index: u64) -> Result<Option<V>, Error> {
153        // Update metrics
154        self.gets.inc();
155
156        // Get index location
157        let offset = match self.indices.get(&index) {
158            Some(offset) => *offset,
159            None => return Ok(None),
160        };
161
162        // Fetch item from disk
163        let section = self.section(index);
164        let record = self.journal.get(section, offset).await?;
165        Ok(Some(record.value))
166    }
167
168    /// Retrieve the next gap in the [Cache].
169    pub fn next_gap(&self, index: u64) -> (Option<u64>, Option<u64>) {
170        self.intervals.next_gap(index)
171    }
172
173    /// Returns the first index in the [Cache].
174    pub fn first(&self) -> Option<u64> {
175        self.intervals.iter().next().map(|(&start, _)| start)
176    }
177
178    /// Returns up to `max` missing items starting from `start`.
179    ///
180    /// This method iterates through gaps between existing ranges, collecting missing indices
181    /// until either `max` items are found or there are no more gaps to fill.
182    pub fn missing_items(&self, start: u64, max: usize) -> Vec<u64> {
183        self.intervals.missing_items(start, max)
184    }
185
186    /// Check if an item exists in the [Cache].
187    pub fn has(&self, index: u64) -> bool {
188        // Update metrics
189        self.has.inc();
190
191        // Check if index exists
192        self.indices.contains_key(&index)
193    }
194
195    /// Prune [Cache] to the provided `min`.
196    ///
197    /// If this is called with a min lower than the last pruned, nothing
198    /// will happen.
199    pub async fn prune(&mut self, min: u64) -> Result<(), Error> {
200        // Update `min` to reflect section mask
201        let min = self.section(min);
202
203        // Check if min is less than last pruned
204        if let Some(oldest_allowed) = self.oldest_allowed {
205            if min <= oldest_allowed {
206                // We don't return an error in this case because the caller
207                // shouldn't be burdened with converting `min` to some section.
208                return Ok(());
209            }
210        }
211        debug!(min, "pruning cache");
212
213        // Prune journal
214        self.journal.prune(min).await.map_err(Error::Journal)?;
215
216        // Remove pending writes (no need to call `sync` as we are pruning)
217        loop {
218            let next = match self.pending.iter().next() {
219                Some(section) if *section < min => *section,
220                _ => break,
221            };
222            self.pending.remove(&next);
223        }
224
225        // Remove all indices that are less than min
226        loop {
227            let next = match self.indices.first_key_value() {
228                Some((index, _)) if *index < min => *index,
229                _ => break,
230            };
231            self.indices.remove(&next).unwrap();
232        }
233
234        // Remove all intervals that are less than min
235        if min > 0 {
236            self.intervals.remove(0, min - 1);
237        }
238
239        // Update last pruned (to prevent reads from
240        // pruned sections)
241        self.oldest_allowed = Some(min);
242        let _ = self.items_tracked.try_set(self.indices.len());
243        Ok(())
244    }
245
246    /// Store an item in the [Cache].
247    ///
248    /// If the index already exists, put does nothing and returns.
249    pub async fn put(&mut self, index: u64, value: V) -> Result<(), Error> {
250        // Check last pruned
251        let oldest_allowed = self.oldest_allowed.unwrap_or(0);
252        if index < oldest_allowed {
253            return Err(Error::AlreadyPrunedTo(oldest_allowed));
254        }
255
256        // Check for existing index
257        if self.indices.contains_key(&index) {
258            return Ok(());
259        }
260
261        // Store item in journal
262        let record = Record::new(index, value);
263        let section = self.section(index);
264        let (offset, _) = self.journal.append(section, record).await?;
265
266        // Store index
267        self.indices.insert(index, offset);
268
269        // Add index to intervals
270        self.intervals.insert(index);
271
272        // Add section to pending
273        self.pending.insert(section);
274
275        // Update metrics
276        let _ = self.items_tracked.try_set(self.indices.len());
277        Ok(())
278    }
279
280    /// Sync all pending writes.
281    pub async fn sync(&mut self) -> Result<(), Error> {
282        let mut syncs = Vec::with_capacity(self.pending.len());
283        for section in self.pending.iter() {
284            syncs.push(self.journal.sync(*section));
285            self.syncs.inc();
286        }
287        try_join_all(syncs).await?;
288        self.pending.clear();
289        Ok(())
290    }
291
292    /// Stores an item in the [Cache] and syncs it, plus any other pending writes, to disk.
293    ///
294    /// If the index already exists, the cache is just synced.
295    pub async fn put_sync(&mut self, index: u64, value: V) -> Result<(), Error> {
296        self.put(index, value).await?;
297        self.sync().await
298    }
299
300    /// Close the [Cache].
301    ///
302    /// Any pending writes will be synced prior to closing.
303    pub async fn close(self) -> Result<(), Error> {
304        self.journal.close().await.map_err(Error::Journal)
305    }
306
307    /// Remove all persistent data created by this [Cache].
308    pub async fn destroy(self) -> Result<(), Error> {
309        self.journal.destroy().await.map_err(Error::Journal)
310    }
311}
312
313impl<E: Storage + Metrics, V: Codec> crate::store::Store for Cache<E, V> {
314    type Key = u64;
315    type Value = V;
316    type Error = Error;
317
318    async fn get(&self, key: &Self::Key) -> Result<Option<Self::Value>, Self::Error> {
319        self.get(*key).await
320    }
321}
322
// Codec conformance: verifies `Record` round-trips through encode/decode
// (only built when both `test` and the `arbitrary` feature are enabled).
#[cfg(all(test, feature = "arbitrary"))]
mod conformance {
    use super::*;
    use commonware_codec::conformance::CodecConformance;

    commonware_conformance::conformance_tests! {
        CodecConformance<Record<u64>>,
    }
}
331}