Skip to main content

common/storage/
in_memory.rs

1use std::collections::BTreeMap;
2use std::ops::RangeBounds;
3use std::sync::{Arc, RwLock};
4
5use async_trait::async_trait;
6use bytes::Bytes;
7
8use super::{MergeOperator, Storage, StorageSnapshot, WriteOptions};
9use crate::storage::RecordOp;
10use crate::{BytesRange, Record, StorageError, StorageIterator, StorageRead, StorageResult};
11
12/// In-memory implementation of the Storage trait using a BTreeMap.
13///
14/// This implementation stores all data in memory and is useful for testing
15/// or scenarios where durability is not required.
16pub struct InMemoryStorage {
17    data: Arc<RwLock<BTreeMap<Bytes, Bytes>>>,
18    merge_operator: Option<Arc<dyn MergeOperator + Send + Sync>>,
19}
20
21impl InMemoryStorage {
22    /// Creates a new InMemoryStorage instance with an empty store.
23    pub fn new() -> Self {
24        Self {
25            data: Arc::new(RwLock::new(BTreeMap::new())),
26            merge_operator: None,
27        }
28    }
29
30    /// Creates a new InMemoryStorage instance with an optional merge operator.
31    ///
32    /// If a merge operator is provided, the `merge` method will use it to combine
33    /// existing values with new values. If no merge operator is provided, the
34    /// `merge` method will return an error.
35    pub fn with_merge_operator(merge_operator: Arc<dyn MergeOperator + Send + Sync>) -> Self {
36        Self {
37            data: Arc::new(RwLock::new(BTreeMap::new())),
38            merge_operator: Some(merge_operator),
39        }
40    }
41}
42
43impl Default for InMemoryStorage {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49#[async_trait]
50impl StorageRead for InMemoryStorage {
51    /// Retrieves a single record by key from the in-memory store.
52    ///
53    /// Returns `None` if the key does not exist.
54    #[tracing::instrument(level = "trace", skip_all)]
55    async fn get(&self, key: Bytes) -> StorageResult<Option<Record>> {
56        let data = self
57            .data
58            .read()
59            .map_err(|e| StorageError::Internal(format!("Failed to acquire read lock: {}", e)))?;
60
61        match data.get(&key) {
62            Some(value) => Ok(Some(Record::new(key, value.clone()))),
63            None => Ok(None),
64        }
65    }
66
67    #[tracing::instrument(level = "trace", skip_all)]
68    async fn scan_iter(
69        &self,
70        range: BytesRange,
71    ) -> StorageResult<Box<dyn StorageIterator + Send + 'static>> {
72        let data = self
73            .data
74            .read()
75            .map_err(|e| StorageError::Internal(format!("Failed to acquire read lock: {}", e)))?;
76
77        // Collect all matching records into a Vec for the iterator
78        let records: Vec<Record> = data
79            .range((range.start_bound().cloned(), range.end_bound().cloned()))
80            .map(|(k, v)| Record::new(k.clone(), v.clone()))
81            .collect();
82
83        Ok(Box::new(InMemoryIterator { records, index: 0 }))
84    }
85}
86
87struct InMemoryIterator {
88    records: Vec<Record>,
89    index: usize,
90}
91
92#[async_trait]
93impl StorageIterator for InMemoryIterator {
94    #[tracing::instrument(level = "trace", skip_all)]
95    async fn next(&mut self) -> StorageResult<Option<Record>> {
96        if self.index >= self.records.len() {
97            Ok(None)
98        } else {
99            let record = self.records[self.index].clone();
100            self.index += 1;
101            Ok(Some(record))
102        }
103    }
104}
105
106/// In-memory snapshot that holds a copy of the data at the time of snapshot creation.
107///
108/// Provides a consistent read-only view of the database at the time the snapshot was created.
109pub struct InMemoryStorageSnapshot {
110    data: Arc<BTreeMap<Bytes, Bytes>>,
111}
112
113#[async_trait]
114impl StorageRead for InMemoryStorageSnapshot {
115    #[tracing::instrument(level = "trace", skip_all)]
116    async fn get(&self, key: Bytes) -> StorageResult<Option<Record>> {
117        match self.data.get(&key) {
118            Some(value) => Ok(Some(Record::new(key, value.clone()))),
119            None => Ok(None),
120        }
121    }
122
123    #[tracing::instrument(level = "trace", skip_all)]
124    async fn scan_iter(
125        &self,
126        range: BytesRange,
127    ) -> StorageResult<Box<dyn StorageIterator + Send + 'static>> {
128        // Collect all matching records into a Vec for the iterator
129        let records: Vec<Record> = self
130            .data
131            .range((range.start_bound().cloned(), range.end_bound().cloned()))
132            .map(|(k, v)| Record::new(k.clone(), v.clone()))
133            .collect();
134
135        Ok(Box::new(InMemoryIterator { records, index: 0 }))
136    }
137}
138
139#[async_trait]
140impl StorageSnapshot for InMemoryStorageSnapshot {}
141
142#[async_trait]
143impl Storage for InMemoryStorage {
144    async fn apply(&self, records: Vec<RecordOp>) -> StorageResult<()> {
145        let mut data = self
146            .data
147            .write()
148            .map_err(|e| StorageError::Internal(format!("Failed to acquire write lock: {}", e)))?;
149
150        for record in records {
151            match record {
152                RecordOp::Put(record) => {
153                    data.insert(record.key, record.value);
154                }
155                RecordOp::Merge(record) => {
156                    let existing_value = data.get(&record.key).cloned();
157                    let merged_value = self.merge_operator.as_ref().unwrap().merge(
158                        &record.key,
159                        existing_value,
160                        record.value.clone(),
161                    );
162                    data.insert(record.key, merged_value);
163                }
164                RecordOp::Delete(key) => {
165                    data.remove(&key);
166                }
167            }
168        }
169
170        Ok(())
171    }
172
173    /// Writes a batch of records to the in-memory store with default options.
174    ///
175    /// Delegates to [`put_with_options`](Self::put_with_options) with default options.
176    async fn put(&self, records: Vec<Record>) -> StorageResult<()> {
177        self.put_with_options(records, WriteOptions::default())
178            .await
179    }
180
181    /// Writes a batch of records to the in-memory store.
182    ///
183    /// All records are written atomically within a single write lock acquisition.
184    /// For in-memory storage, write options are ignored since there is no
185    /// durable storage to await.
186    async fn put_with_options(
187        &self,
188        records: Vec<Record>,
189        _options: WriteOptions,
190    ) -> StorageResult<()> {
191        let mut data = self
192            .data
193            .write()
194            .map_err(|e| StorageError::Internal(format!("Failed to acquire write lock: {}", e)))?;
195
196        for record in records {
197            data.insert(record.key, record.value);
198        }
199
200        Ok(())
201    }
202
203    /// Merges values for the given keys using the configured merge operator.
204    ///
205    /// This method requires a merge operator to be configured during construction.
206    /// For each record, it will:
207    /// 1. Get the existing value (if any)
208    /// 2. Call the merge operator to combine existing and new values
209    /// 3. Put the merged result back
210    ///
211    /// If no merge operator is configured, this method will return a
212    /// `StorageError::Storage` error.
213    async fn merge(&self, records: Vec<Record>) -> StorageResult<()> {
214        let merge_op = self
215            .merge_operator
216            .as_ref()
217            .ok_or_else(|| {
218                StorageError::Storage(
219                    "Merge operator not configured: in-memory storage requires a merge operator to be set during construction".to_string(),
220                )
221            })?;
222
223        let mut data = self
224            .data
225            .write()
226            .map_err(|e| StorageError::Internal(format!("Failed to acquire write lock: {}", e)))?;
227
228        for record in records {
229            let existing_value = data.get(&record.key).cloned();
230            let merged_value = merge_op.merge(&record.key, existing_value, record.value.clone());
231            data.insert(record.key, merged_value);
232        }
233
234        Ok(())
235    }
236
237    /// Creates a point-in-time snapshot of the in-memory storage.
238    ///
239    /// The snapshot provides a consistent read-only view of the database at the time
240    /// the snapshot was created. Reads from the snapshot will not see any subsequent
241    /// writes to the underlying storage.
242    async fn snapshot(&self) -> StorageResult<Arc<dyn StorageSnapshot>> {
243        let data = self
244            .data
245            .read()
246            .map_err(|e| StorageError::Internal(format!("Failed to acquire read lock: {}", e)))?;
247
248        // Clone the entire BTreeMap for the snapshot
249        let snapshot_data = Arc::new(data.clone());
250
251        Ok(Arc::new(InMemoryStorageSnapshot {
252            data: snapshot_data,
253        }))
254    }
255
256    async fn flush(&self) -> StorageResult<()> {
257        // No-op for in-memory storage - all writes are immediately visible
258        Ok(())
259    }
260
261    async fn close(&self) -> StorageResult<()> {
262        // No-op for in-memory storage
263        Ok(())
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use std::ops::Bound;

    /// Merge operator for tests: joins the existing and new values with a comma.
    struct AppendMergeOperator;

    impl MergeOperator for AppendMergeOperator {
        fn merge(&self, _key: &Bytes, existing_value: Option<Bytes>, new_value: Bytes) -> Bytes {
            if let Some(existing) = existing_value {
                let mut joined = BytesMut::from(existing);
                joined.extend_from_slice(b",");
                joined.extend_from_slice(&new_value);
                joined.freeze()
            } else {
                new_value
            }
        }
    }

    /// Shorthand for building `Bytes` from a string literal.
    fn b(text: &'static str) -> Bytes {
        Bytes::from(text)
    }

    /// Shorthand for building a record from two string literals.
    fn rec(key: &'static str, value: &'static str) -> Record {
        Record::new(b(key), b(value))
    }

    /// Builds a store configured with [`AppendMergeOperator`].
    fn storage_with_append_merge() -> InMemoryStorage {
        InMemoryStorage::with_merge_operator(Arc::new(AppendMergeOperator))
    }

    #[tokio::test]
    async fn should_return_none_when_key_not_found() {
        // given
        let storage = InMemoryStorage::new();

        // when
        let result = storage.get(b("missing_key")).await;

        // then
        assert!(matches!(result, Ok(None)));
    }

    #[tokio::test]
    async fn should_store_and_retrieve_record() {
        // given
        let storage = InMemoryStorage::new();

        // when
        storage
            .put(vec![rec("test_key", "test_value")])
            .await
            .unwrap();
        let fetched = storage.get(b("test_key")).await.unwrap();

        // then
        let record = fetched.expect("record should be present");
        assert_eq!(record.key, b("test_key"));
        assert_eq!(record.value, b("test_value"));
    }

    #[tokio::test]
    async fn should_overwrite_existing_key() {
        // given
        let storage = InMemoryStorage::new();

        // when
        storage
            .put(vec![rec("test_key", "initial_value")])
            .await
            .unwrap();
        storage
            .put(vec![rec("test_key", "updated_value")])
            .await
            .unwrap();
        let fetched = storage.get(b("test_key")).await.unwrap();

        // then
        let record = fetched.expect("record should be present");
        assert_eq!(record.value, b("updated_value"));
    }

    #[tokio::test]
    async fn should_store_multiple_records() {
        // given
        let storage = InMemoryStorage::new();
        let records = vec![
            rec("key1", "value1"),
            rec("key2", "value2"),
            rec("key3", "value3"),
        ];

        // when
        storage.put(records.clone()).await.unwrap();

        // then
        for expected in &records {
            let fetched = storage.get(expected.key.clone()).await.unwrap();
            let record = fetched.expect("record should be present");
            assert_eq!(record.value, expected.value);
        }
    }

    #[tokio::test]
    async fn should_scan_all_records_when_unbounded() {
        // given
        let storage = InMemoryStorage::new();
        storage
            .put(vec![
                rec("a", "value_a"),
                rec("b", "value_b"),
                rec("c", "value_c"),
            ])
            .await
            .unwrap();

        // when
        let scanned = storage.scan(BytesRange::unbounded()).await.unwrap();

        // then
        let keys: Vec<Bytes> = scanned.iter().map(|record| record.key.clone()).collect();
        assert_eq!(keys, vec![b("a"), b("b"), b("c")]);
    }

    #[tokio::test]
    async fn should_scan_records_with_prefix() {
        // given
        let storage = InMemoryStorage::new();
        storage
            .put(vec![
                rec("prefix_a", "value1"),
                rec("prefix_b", "value2"),
                rec("other_c", "value3"),
            ])
            .await
            .unwrap();

        // when
        let scanned = storage
            .scan(BytesRange::prefix(b("prefix_")))
            .await
            .unwrap();

        // then
        let keys: Vec<Bytes> = scanned.iter().map(|record| record.key.clone()).collect();
        assert_eq!(keys, vec![b("prefix_a"), b("prefix_b")]);
    }

    #[tokio::test]
    async fn should_scan_records_in_bounded_range() {
        // given
        let storage = InMemoryStorage::new();
        storage
            .put(vec![
                rec("a", "value_a"),
                rec("b", "value_b"),
                rec("c", "value_c"),
                rec("d", "value_d"),
            ])
            .await
            .unwrap();

        // when
        let lower = Bound::Included(b("b"));
        let upper = Bound::Excluded(b("d"));
        let scanned = storage.scan(BytesRange::new(lower, upper)).await.unwrap();

        // then
        let keys: Vec<Bytes> = scanned.iter().map(|record| record.key.clone()).collect();
        assert_eq!(keys, vec![b("b"), b("c")]);
    }

    #[tokio::test]
    async fn should_return_empty_vec_when_scanning_empty_storage() {
        // given
        let storage = InMemoryStorage::new();

        // when
        let scanned = storage.scan(BytesRange::unbounded()).await.unwrap();

        // then
        assert!(scanned.is_empty());
    }

    #[tokio::test]
    async fn should_iterate_over_records() {
        // given
        let storage = InMemoryStorage::new();
        storage
            .put(vec![rec("key1", "value1"), rec("key2", "value2")])
            .await
            .unwrap();

        // when
        let mut iter = storage.scan_iter(BytesRange::unbounded()).await.unwrap();
        let first = iter.next().await.unwrap();
        let second = iter.next().await.unwrap();
        let third = iter.next().await.unwrap();

        // then
        assert_eq!(first.expect("first record").key, b("key1"));
        assert_eq!(second.expect("second record").key, b("key2"));
        assert!(third.is_none());
    }

    #[tokio::test]
    async fn should_create_snapshot_with_current_data() {
        // given
        let storage = InMemoryStorage::new();
        storage.put(vec![rec("key1", "value1")]).await.unwrap();

        // when
        let snapshot = storage.snapshot().await.unwrap();

        // then
        let fetched = snapshot.get(b("key1")).await.unwrap();
        let record = fetched.expect("snapshot should contain key1");
        assert_eq!(record.value, b("value1"));
    }

    #[tokio::test]
    async fn should_not_see_writes_after_snapshot() {
        // given
        let storage = InMemoryStorage::new();
        storage.put(vec![rec("key1", "value1")]).await.unwrap();

        // when
        let snapshot = storage.snapshot().await.unwrap();
        storage.put(vec![rec("key2", "value2")]).await.unwrap();

        // then
        let from_snapshot = snapshot.get(b("key2")).await.unwrap();
        assert!(from_snapshot.is_none());

        let from_storage = storage.get(b("key2")).await.unwrap();
        assert!(from_storage.is_some());
    }

    #[tokio::test]
    async fn should_scan_snapshot_independently() {
        // given
        let storage = InMemoryStorage::new();
        storage.put(vec![rec("a", "value_a")]).await.unwrap();

        // when
        let snapshot = storage.snapshot().await.unwrap();
        storage.put(vec![rec("b", "value_b")]).await.unwrap();

        // then
        let snapshot_records = snapshot.scan(BytesRange::unbounded()).await.unwrap();
        assert_eq!(snapshot_records.len(), 1);
        assert_eq!(snapshot_records[0].key, b("a"));

        let storage_records = storage.scan(BytesRange::unbounded()).await.unwrap();
        assert_eq!(storage_records.len(), 2);
    }

    #[tokio::test]
    async fn should_handle_empty_record() {
        // given
        let storage = InMemoryStorage::new();
        let key = b("empty_key");

        // when
        storage.put(vec![Record::empty(key.clone())]).await.unwrap();
        let fetched = storage.get(key).await.unwrap();

        // then
        let record = fetched.expect("record should be present");
        assert_eq!(record.value, Bytes::new());
    }

    #[tokio::test]
    async fn should_return_error_when_merge_operator_not_configured() {
        // given
        let storage = InMemoryStorage::new();

        // when
        let result = storage.merge(vec![rec("key1", "value1")]).await;

        // then
        let error = result.expect_err("merge should fail without a merge operator");
        assert!(error.to_string().contains("Merge operator not configured"));
    }

    #[tokio::test]
    async fn should_merge_when_key_does_not_exist() {
        // given
        let storage = storage_with_append_merge();

        // when
        storage.merge(vec![rec("new_key", "value1")]).await.unwrap();
        let fetched = storage.get(b("new_key")).await.unwrap();

        // then
        let record = fetched.expect("record should be present");
        assert_eq!(record.value, b("value1"));
    }

    #[tokio::test]
    async fn should_merge_when_key_exists() {
        // given
        let storage = storage_with_append_merge();
        storage.put(vec![rec("key1", "value1")]).await.unwrap();

        // when
        storage.merge(vec![rec("key1", "value2")]).await.unwrap();
        let fetched = storage.get(b("key1")).await.unwrap();

        // then
        let record = fetched.expect("record should be present");
        assert_eq!(record.value, b("value1,value2"));
    }

    #[tokio::test]
    async fn should_merge_multiple_keys() {
        // given
        let storage = storage_with_append_merge();
        storage
            .put(vec![rec("key1", "value1"), rec("key2", "value2")])
            .await
            .unwrap();

        // when
        storage
            .merge(vec![rec("key1", "value1a"), rec("key2", "value2a")])
            .await
            .unwrap();

        // then
        let first = storage.get(b("key1")).await.unwrap();
        assert_eq!(first.unwrap().value, b("value1,value1a"));

        let second = storage.get(b("key2")).await.unwrap();
        assert_eq!(second.unwrap().value, b("value2,value2a"));
    }

    #[tokio::test]
    async fn should_merge_empty_values() {
        // given
        let storage = storage_with_append_merge();
        let key = b("key1");

        // when
        storage
            .merge(vec![Record::empty(key.clone())])
            .await
            .unwrap();
        let fetched = storage.get(key).await.unwrap();

        // then
        let record = fetched.expect("record should be present");
        assert_eq!(record.value, Bytes::new());
    }
}