monoutils_store/implementations/stores/
memstore.rs

1use std::{
2    collections::{HashMap, HashSet},
3    pin::Pin,
4    sync::Arc,
5};
6
7use bytes::Bytes;
8use futures::StreamExt;
9use libipld::Cid;
10use serde::{de::DeserializeOwned, Serialize};
11use tokio::{io::AsyncRead, sync::RwLock};
12
13use crate::{
14    utils, Chunker, Codec, FixedSizeChunker, FlatLayout, IpldReferences, IpldStore,
15    IpldStoreSeekable, Layout, LayoutSeekable, SeekableReader, StoreError, StoreResult,
16};
17
18//--------------------------------------------------------------------------------------------------
19// Types
20//--------------------------------------------------------------------------------------------------
21
22/// An in-memory storage for IPLD node and raw blocks with reference counting.
23///
24/// This store maintains a reference count for each stored block. Reference counting is used to
25/// determine when a block can be safely removed from the store.
26#[derive(Debug, Clone)]
27// TODO: Use BalancedDagLayout as default
28pub struct MemoryStore<C = FixedSizeChunker, L = FlatLayout>
29where
30    C: Chunker,
31    L: Layout,
32{
33    /// Represents the blocks stored in the store.
34    ///
35    /// When data is added to the store, it may not necessarily fit into the acceptable block size
36    /// limit, so it is chunked into smaller blocks.
37    ///
38    /// The `usize` is used for counting the references to blocks within the store.
39    blocks: Arc<RwLock<HashMap<Cid, (usize, Bytes)>>>,
40
41    /// The chunking algorithm used to split data into chunks.
42    chunker: C,
43
44    /// The layout strategy used to store chunked data.
45    layout: L,
46}
47
48//--------------------------------------------------------------------------------------------------
49// Methods
50//--------------------------------------------------------------------------------------------------
51
52impl<C, L> MemoryStore<C, L>
53where
54    C: Chunker,
55    L: Layout,
56{
57    /// Creates a new `MemoryStore` with the given `chunker` and `layout`.
58    pub fn new(chunker: C, layout: L) -> Self {
59        MemoryStore {
60            blocks: Arc::new(RwLock::new(HashMap::new())),
61            chunker,
62            layout,
63        }
64    }
65
66    /// Prints all the blocks in the store.
67    pub fn debug(&self)
68    where
69        C: Clone + Send,
70        L: Clone + Send,
71    {
72        let store = self.clone();
73        tokio::spawn(async move {
74            let blocks = store.blocks.read().await;
75            for (cid, (size, bytes)) in blocks.iter() {
76                println!("\ncid: {} ({:?})\nkey: {}", cid, size, hex::encode(bytes));
77            }
78        });
79    }
80
81    /// Increments the reference count of the blocks with the given `Cid`s.
82    async fn inc_refs(&self, cids: impl Iterator<Item = &Cid>) {
83        for cid in cids {
84            if let Some((size, _)) = self.blocks.write().await.get_mut(cid) {
85                *size += 1;
86            }
87        }
88    }
89
90    /// Stores raw bytes in the store without any size checks.
91    async fn store_raw(&self, bytes: Bytes, codec: Codec) -> Cid {
92        let cid = utils::make_cid(codec, &bytes);
93        self.blocks.write().await.insert(cid, (1, bytes));
94        cid
95    }
96}
97
98//--------------------------------------------------------------------------------------------------
99// Trait Implementations
100//--------------------------------------------------------------------------------------------------
101
102impl<C, L> IpldStore for MemoryStore<C, L>
103where
104    C: Chunker + Clone + Send + Sync,
105    L: Layout + Clone + Send + Sync,
106{
107    async fn put_node<T>(&self, data: &T) -> StoreResult<Cid>
108    where
109        T: Serialize + IpldReferences + Sync,
110    {
111        // Serialize the data to bytes.
112        let bytes = Bytes::from(serde_ipld_dagcbor::to_vec(&data).map_err(StoreError::custom)?);
113
114        // Check if the data exceeds the node maximum block size.
115        if let Some(max_size) = self.get_node_block_max_size() {
116            if bytes.len() as u64 > max_size {
117                return Err(StoreError::NodeBlockTooLarge(bytes.len() as u64, max_size));
118            }
119        }
120
121        // Increment the reference count of the block.
122        self.inc_refs(data.get_references()).await;
123
124        Ok(self.store_raw(bytes, Codec::DagCbor).await)
125    }
126
127    async fn put_bytes<'a>(
128        &'a self,
129        reader: impl AsyncRead + Send + Sync + 'a,
130    ) -> StoreResult<Cid> {
131        let chunk_stream = self.chunker.chunk(reader).await?;
132        let mut cid_stream = self.layout.organize(chunk_stream, self.clone()).await?;
133
134        // Take the last `Cid` from the stream.
135        let mut cid = cid_stream.next().await.unwrap()?;
136        while let Some(result) = cid_stream.next().await {
137            cid = result?;
138        }
139
140        Ok(cid)
141    }
142
143    async fn put_raw_block(&self, bytes: impl Into<Bytes>) -> StoreResult<Cid> {
144        let bytes = bytes.into();
145        if let Some(max_size) = self.get_raw_block_max_size() {
146            if bytes.len() as u64 > max_size {
147                return Err(StoreError::RawBlockTooLarge(bytes.len() as u64, max_size));
148            }
149        }
150
151        Ok(self.store_raw(bytes, Codec::Raw).await)
152    }
153
154    async fn get_node<T>(&self, cid: &Cid) -> StoreResult<T>
155    where
156        T: DeserializeOwned,
157    {
158        let blocks = self.blocks.read().await;
159        match blocks.get(cid) {
160            Some((_, bytes)) => match cid.codec().try_into()? {
161                Codec::DagCbor => {
162                    let data =
163                        serde_ipld_dagcbor::from_slice::<T>(bytes).map_err(StoreError::custom)?;
164                    Ok(data)
165                }
166                codec => Err(StoreError::UnexpectedBlockCodec(Codec::DagCbor, codec)),
167            },
168            None => Err(StoreError::BlockNotFound(*cid)),
169        }
170    }
171
172    async fn get_bytes<'a>(
173        &'a self,
174        cid: &'a Cid,
175    ) -> StoreResult<Pin<Box<dyn AsyncRead + Send + Sync + 'a>>> {
176        self.layout.retrieve(cid, self.clone()).await
177    }
178
179    async fn get_raw_block(&self, cid: &Cid) -> StoreResult<Bytes> {
180        let blocks = self.blocks.read().await;
181        match blocks.get(cid) {
182            Some((_, bytes)) => match cid.codec().try_into()? {
183                Codec::Raw => Ok(bytes.clone()),
184                codec => Err(StoreError::UnexpectedBlockCodec(Codec::Raw, codec)),
185            },
186            None => Err(StoreError::BlockNotFound(*cid)),
187        }
188    }
189
190    #[inline]
191    async fn has(&self, cid: &Cid) -> bool {
192        let blocks = self.blocks.read().await;
193        blocks.contains_key(cid)
194    }
195
196    fn get_supported_codecs(&self) -> HashSet<Codec> {
197        let mut codecs = HashSet::new();
198        codecs.insert(Codec::DagCbor);
199        codecs.insert(Codec::Raw);
200        codecs
201    }
202
203    #[inline]
204    fn get_node_block_max_size(&self) -> Option<u64> {
205        self.chunker.chunk_max_size()
206    }
207
208    #[inline]
209    fn get_raw_block_max_size(&self) -> Option<u64> {
210        self.chunker.chunk_max_size()
211    }
212
213    async fn is_empty(&self) -> StoreResult<bool> {
214        Ok(self.blocks.read().await.is_empty())
215    }
216
217    async fn get_size(&self) -> StoreResult<u64> {
218        Ok(self.blocks.read().await.len() as u64)
219    }
220}
221
222impl<C, L> IpldStoreSeekable for MemoryStore<C, L>
223where
224    C: Chunker + Clone + Send + Sync,
225    L: LayoutSeekable + Clone + Send + Sync,
226{
227    async fn get_seekable_bytes<'a>(
228        &'a self,
229        cid: &'a Cid,
230    ) -> StoreResult<Pin<Box<dyn SeekableReader + Send + 'a>>> {
231        self.layout.retrieve_seekable(cid, self.clone()).await
232    }
233}
234
235impl Default for MemoryStore {
236    fn default() -> Self {
237        MemoryStore {
238            blocks: Arc::new(RwLock::new(HashMap::new())),
239            chunker: FixedSizeChunker::default(),
240            layout: FlatLayout::default(),
241        }
242    }
243}
244
245//--------------------------------------------------------------------------------------------------
246// Tests
247//--------------------------------------------------------------------------------------------------
248
#[cfg(test)]
mod tests {
    use tokio::io::AsyncReadExt;

    use super::*;

    /// Round-trips raw bytes and an IPLD node through a default `MemoryStore`.
    #[tokio::test]
    async fn test_memory_store_put_and_get() -> anyhow::Result<()> {
        let store = MemoryStore::default();

        //================== Raw ==================

        let raw_input = vec![1, 2, 3, 4, 5];
        let raw_cid = store.put_bytes(&raw_input[..]).await?;

        // Read everything back through the reader returned by the store.
        let mut reader = store.get_bytes(&raw_cid).await?;
        let mut raw_output = Vec::new();
        reader.read_to_end(&mut raw_output).await?;

        assert_eq!(raw_input, raw_output);

        //================= IPLD =================

        let entries = vec![
            utils::make_cid(Codec::Raw, &[1, 2, 3]),
            utils::make_cid(Codec::Raw, &[4, 5, 6]),
        ];
        let directory = fixtures::Directory {
            name: "root".to_string(),
            entries,
        };

        let node_cid = store.put_node(&directory).await?;
        let loaded = store.get_node::<fixtures::Directory>(&node_cid).await?;

        assert_eq!(loaded, directory);

        Ok(())
    }
}
288
#[cfg(test)]
mod fixtures {
    use serde::Deserialize;

    use super::*;

    //--------------------------------------------------------------------------------------------------
    // Types
    //--------------------------------------------------------------------------------------------------

    /// Minimal IPLD node used by the tests: a named directory whose entries are `Cid`s.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    pub(super) struct Directory {
        /// Name of the directory.
        pub(super) name: String,

        /// `Cid`s of the blocks this directory refers to.
        pub(super) entries: Vec<Cid>,
    }

    //--------------------------------------------------------------------------------------------------
    // Trait Implementations
    //--------------------------------------------------------------------------------------------------

    impl IpldReferences for Directory {
        /// Yields a reference to every `Cid` in `entries`, in order.
        fn get_references<'a>(&'a self) -> Box<dyn Iterator<Item = &'a Cid> + Send + 'a> {
            let references = self.entries.iter();
            Box::new(references)
        }
    }
}