1use std::collections::BTreeMap;
2use std::ops::RangeBounds;
3use std::sync::{Arc, RwLock};
4
5use async_trait::async_trait;
6use bytes::Bytes;
7
8use super::{MergeOperator, Storage, StorageSnapshot, WriteOptions};
9use crate::storage::RecordOp;
10use crate::{BytesRange, Record, StorageError, StorageIterator, StorageRead, StorageResult};
11
12pub struct InMemoryStorage {
17 data: Arc<RwLock<BTreeMap<Bytes, Bytes>>>,
18 merge_operator: Option<Arc<dyn MergeOperator + Send + Sync>>,
19}
20
21impl InMemoryStorage {
22 pub fn new() -> Self {
24 Self {
25 data: Arc::new(RwLock::new(BTreeMap::new())),
26 merge_operator: None,
27 }
28 }
29
30 pub fn with_merge_operator(merge_operator: Arc<dyn MergeOperator + Send + Sync>) -> Self {
36 Self {
37 data: Arc::new(RwLock::new(BTreeMap::new())),
38 merge_operator: Some(merge_operator),
39 }
40 }
41}
42
43impl Default for InMemoryStorage {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49#[async_trait]
50impl StorageRead for InMemoryStorage {
51 #[tracing::instrument(level = "trace", skip_all)]
55 async fn get(&self, key: Bytes) -> StorageResult<Option<Record>> {
56 let data = self
57 .data
58 .read()
59 .map_err(|e| StorageError::Internal(format!("Failed to acquire read lock: {}", e)))?;
60
61 match data.get(&key) {
62 Some(value) => Ok(Some(Record::new(key, value.clone()))),
63 None => Ok(None),
64 }
65 }
66
67 #[tracing::instrument(level = "trace", skip_all)]
68 async fn scan_iter(
69 &self,
70 range: BytesRange,
71 ) -> StorageResult<Box<dyn StorageIterator + Send + 'static>> {
72 let data = self
73 .data
74 .read()
75 .map_err(|e| StorageError::Internal(format!("Failed to acquire read lock: {}", e)))?;
76
77 let records: Vec<Record> = data
79 .range((range.start_bound().cloned(), range.end_bound().cloned()))
80 .map(|(k, v)| Record::new(k.clone(), v.clone()))
81 .collect();
82
83 Ok(Box::new(InMemoryIterator { records, index: 0 }))
84 }
85}
86
87struct InMemoryIterator {
88 records: Vec<Record>,
89 index: usize,
90}
91
92#[async_trait]
93impl StorageIterator for InMemoryIterator {
94 #[tracing::instrument(level = "trace", skip_all)]
95 async fn next(&mut self) -> StorageResult<Option<Record>> {
96 if self.index >= self.records.len() {
97 Ok(None)
98 } else {
99 let record = self.records[self.index].clone();
100 self.index += 1;
101 Ok(Some(record))
102 }
103 }
104}
105
106pub struct InMemoryStorageSnapshot {
110 data: Arc<BTreeMap<Bytes, Bytes>>,
111}
112
113#[async_trait]
114impl StorageRead for InMemoryStorageSnapshot {
115 #[tracing::instrument(level = "trace", skip_all)]
116 async fn get(&self, key: Bytes) -> StorageResult<Option<Record>> {
117 match self.data.get(&key) {
118 Some(value) => Ok(Some(Record::new(key, value.clone()))),
119 None => Ok(None),
120 }
121 }
122
123 #[tracing::instrument(level = "trace", skip_all)]
124 async fn scan_iter(
125 &self,
126 range: BytesRange,
127 ) -> StorageResult<Box<dyn StorageIterator + Send + 'static>> {
128 let records: Vec<Record> = self
130 .data
131 .range((range.start_bound().cloned(), range.end_bound().cloned()))
132 .map(|(k, v)| Record::new(k.clone(), v.clone()))
133 .collect();
134
135 Ok(Box::new(InMemoryIterator { records, index: 0 }))
136 }
137}
138
139#[async_trait]
140impl StorageSnapshot for InMemoryStorageSnapshot {}
141
142#[async_trait]
143impl Storage for InMemoryStorage {
144 async fn apply(&self, records: Vec<RecordOp>) -> StorageResult<()> {
145 let mut data = self
146 .data
147 .write()
148 .map_err(|e| StorageError::Internal(format!("Failed to acquire write lock: {}", e)))?;
149
150 for record in records {
151 match record {
152 RecordOp::Put(record) => {
153 data.insert(record.key, record.value);
154 }
155 RecordOp::Merge(record) => {
156 let existing_value = data.get(&record.key).cloned();
157 let merged_value = self.merge_operator.as_ref().unwrap().merge(
158 &record.key,
159 existing_value,
160 record.value.clone(),
161 );
162 data.insert(record.key, merged_value);
163 }
164 RecordOp::Delete(key) => {
165 data.remove(&key);
166 }
167 }
168 }
169
170 Ok(())
171 }
172
173 async fn put(&self, records: Vec<Record>) -> StorageResult<()> {
177 self.put_with_options(records, WriteOptions::default())
178 .await
179 }
180
181 async fn put_with_options(
187 &self,
188 records: Vec<Record>,
189 _options: WriteOptions,
190 ) -> StorageResult<()> {
191 let mut data = self
192 .data
193 .write()
194 .map_err(|e| StorageError::Internal(format!("Failed to acquire write lock: {}", e)))?;
195
196 for record in records {
197 data.insert(record.key, record.value);
198 }
199
200 Ok(())
201 }
202
203 async fn merge(&self, records: Vec<Record>) -> StorageResult<()> {
214 let merge_op = self
215 .merge_operator
216 .as_ref()
217 .ok_or_else(|| {
218 StorageError::Storage(
219 "Merge operator not configured: in-memory storage requires a merge operator to be set during construction".to_string(),
220 )
221 })?;
222
223 let mut data = self
224 .data
225 .write()
226 .map_err(|e| StorageError::Internal(format!("Failed to acquire write lock: {}", e)))?;
227
228 for record in records {
229 let existing_value = data.get(&record.key).cloned();
230 let merged_value = merge_op.merge(&record.key, existing_value, record.value.clone());
231 data.insert(record.key, merged_value);
232 }
233
234 Ok(())
235 }
236
237 async fn snapshot(&self) -> StorageResult<Arc<dyn StorageSnapshot>> {
243 let data = self
244 .data
245 .read()
246 .map_err(|e| StorageError::Internal(format!("Failed to acquire read lock: {}", e)))?;
247
248 let snapshot_data = Arc::new(data.clone());
250
251 Ok(Arc::new(InMemoryStorageSnapshot {
252 data: snapshot_data,
253 }))
254 }
255
256 async fn flush(&self) -> StorageResult<()> {
257 Ok(())
259 }
260
261 async fn close(&self) -> StorageResult<()> {
262 Ok(())
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use std::ops::Bound;

    /// Test-only merge operator: appends the new value to the existing one with a comma
    /// separator, or stores the new value unchanged when the key has no value yet.
    struct AppendMergeOperator;

    impl MergeOperator for AppendMergeOperator {
        fn merge(&self, _key: &Bytes, existing_value: Option<Bytes>, new_value: Bytes) -> Bytes {
            match existing_value {
                Some(existing) => {
                    let mut result = BytesMut::from(existing);
                    result.extend_from_slice(b",");
                    result.extend_from_slice(&new_value);
                    result.freeze()
                }
                None => new_value,
            }
        }
    }

    #[tokio::test]
    async fn should_return_none_when_key_not_found() {
        let storage = InMemoryStorage::new();

        let result = storage.get(Bytes::from("missing_key")).await.unwrap();

        assert!(result.is_none());
    }

    #[tokio::test]
    async fn should_store_and_retrieve_record() {
        let storage = InMemoryStorage::new();
        let key = Bytes::from("test_key");
        let value = Bytes::from("test_value");

        storage
            .put(vec![Record::new(key.clone(), value.clone())])
            .await
            .unwrap();
        let record = storage
            .get(key)
            .await
            .unwrap()
            .expect("stored record should be readable");

        assert_eq!(record.key, Bytes::from("test_key"));
        assert_eq!(record.value, value);
    }

    #[tokio::test]
    async fn should_overwrite_existing_key() {
        let storage = InMemoryStorage::new();
        let key = Bytes::from("test_key");
        let initial_value = Bytes::from("initial_value");
        let updated_value = Bytes::from("updated_value");

        storage
            .put(vec![Record::new(key.clone(), initial_value)])
            .await
            .unwrap();
        storage
            .put(vec![Record::new(key.clone(), updated_value.clone())])
            .await
            .unwrap();
        let record = storage
            .get(key)
            .await
            .unwrap()
            .expect("overwritten record should be readable");

        assert_eq!(record.value, updated_value);
    }

    #[tokio::test]
    async fn should_store_multiple_records() {
        let storage = InMemoryStorage::new();
        let records = vec![
            Record::new(Bytes::from("key1"), Bytes::from("value1")),
            Record::new(Bytes::from("key2"), Bytes::from("value2")),
            Record::new(Bytes::from("key3"), Bytes::from("value3")),
        ];

        storage.put(records.clone()).await.unwrap();

        for record in records {
            let retrieved = storage
                .get(record.key.clone())
                .await
                .unwrap()
                .expect("every stored record should be readable");
            assert_eq!(retrieved.value, record.value);
        }
    }

    #[tokio::test]
    async fn should_scan_all_records_when_unbounded() {
        let storage = InMemoryStorage::new();
        let records = vec![
            Record::new(Bytes::from("a"), Bytes::from("value_a")),
            Record::new(Bytes::from("b"), Bytes::from("value_b")),
            Record::new(Bytes::from("c"), Bytes::from("value_c")),
        ];
        storage.put(records.clone()).await.unwrap();

        let scanned = storage.scan(BytesRange::unbounded()).await.unwrap();

        assert_eq!(scanned.len(), 3);
        assert_eq!(scanned[0].key, Bytes::from("a"));
        assert_eq!(scanned[1].key, Bytes::from("b"));
        assert_eq!(scanned[2].key, Bytes::from("c"));
    }

    #[tokio::test]
    async fn should_scan_records_with_prefix() {
        let storage = InMemoryStorage::new();
        let records = vec![
            Record::new(Bytes::from("prefix_a"), Bytes::from("value1")),
            Record::new(Bytes::from("prefix_b"), Bytes::from("value2")),
            Record::new(Bytes::from("other_c"), Bytes::from("value3")),
        ];
        storage.put(records).await.unwrap();

        let scanned = storage
            .scan(BytesRange::prefix(Bytes::from("prefix_")))
            .await
            .unwrap();

        assert_eq!(scanned.len(), 2);
        assert_eq!(scanned[0].key, Bytes::from("prefix_a"));
        assert_eq!(scanned[1].key, Bytes::from("prefix_b"));
    }

    #[tokio::test]
    async fn should_scan_records_in_bounded_range() {
        let storage = InMemoryStorage::new();
        let records = vec![
            Record::new(Bytes::from("a"), Bytes::from("value_a")),
            Record::new(Bytes::from("b"), Bytes::from("value_b")),
            Record::new(Bytes::from("c"), Bytes::from("value_c")),
            Record::new(Bytes::from("d"), Bytes::from("value_d")),
        ];
        storage.put(records).await.unwrap();

        let range = BytesRange::new(
            Bound::Included(Bytes::from("b")),
            Bound::Excluded(Bytes::from("d")),
        );
        let scanned = storage.scan(range).await.unwrap();

        assert_eq!(scanned.len(), 2);
        assert_eq!(scanned[0].key, Bytes::from("b"));
        assert_eq!(scanned[1].key, Bytes::from("c"));
    }

    #[tokio::test]
    async fn should_return_empty_vec_when_scanning_empty_storage() {
        let storage = InMemoryStorage::new();

        let scanned = storage.scan(BytesRange::unbounded()).await.unwrap();

        assert!(scanned.is_empty());
    }

    #[tokio::test]
    async fn should_iterate_over_records() {
        let storage = InMemoryStorage::new();
        let records = vec![
            Record::new(Bytes::from("key1"), Bytes::from("value1")),
            Record::new(Bytes::from("key2"), Bytes::from("value2")),
        ];
        storage.put(records).await.unwrap();

        let mut iter = storage.scan_iter(BytesRange::unbounded()).await.unwrap();
        let first = iter.next().await.unwrap();
        let second = iter.next().await.unwrap();
        let third = iter.next().await.unwrap();

        assert_eq!(first.expect("first record").key, Bytes::from("key1"));
        assert_eq!(second.expect("second record").key, Bytes::from("key2"));
        assert!(third.is_none());
    }

    #[tokio::test]
    async fn should_create_snapshot_with_current_data() {
        let storage = InMemoryStorage::new();
        storage
            .put(vec![Record::new(
                Bytes::from("key1"),
                Bytes::from("value1"),
            )])
            .await
            .unwrap();

        let snapshot = storage.snapshot().await.unwrap();

        let record = snapshot
            .get(Bytes::from("key1"))
            .await
            .unwrap()
            .expect("snapshot should contain pre-existing data");
        assert_eq!(record.value, Bytes::from("value1"));
    }

    #[tokio::test]
    async fn should_not_see_writes_after_snapshot() {
        let storage = InMemoryStorage::new();
        storage
            .put(vec![Record::new(
                Bytes::from("key1"),
                Bytes::from("value1"),
            )])
            .await
            .unwrap();

        let snapshot = storage.snapshot().await.unwrap();
        storage
            .put(vec![Record::new(
                Bytes::from("key2"),
                Bytes::from("value2"),
            )])
            .await
            .unwrap();

        let snapshot_result = snapshot.get(Bytes::from("key2")).await.unwrap();
        assert!(snapshot_result.is_none());

        let storage_result = storage.get(Bytes::from("key2")).await.unwrap();
        assert!(storage_result.is_some());
    }

    #[tokio::test]
    async fn should_scan_snapshot_independently() {
        let storage = InMemoryStorage::new();
        storage
            .put(vec![Record::new(Bytes::from("a"), Bytes::from("value_a"))])
            .await
            .unwrap();

        let snapshot = storage.snapshot().await.unwrap();
        storage
            .put(vec![Record::new(Bytes::from("b"), Bytes::from("value_b"))])
            .await
            .unwrap();

        let snapshot_records = snapshot.scan(BytesRange::unbounded()).await.unwrap();
        assert_eq!(snapshot_records.len(), 1);
        assert_eq!(snapshot_records[0].key, Bytes::from("a"));

        let storage_records = storage.scan(BytesRange::unbounded()).await.unwrap();
        assert_eq!(storage_records.len(), 2);
    }

    #[tokio::test]
    async fn should_handle_empty_record() {
        let storage = InMemoryStorage::new();
        let key = Bytes::from("empty_key");

        storage.put(vec![Record::empty(key.clone())]).await.unwrap();
        let record = storage
            .get(key)
            .await
            .unwrap()
            .expect("empty record should still be stored");

        assert_eq!(record.value, Bytes::new());
    }

    #[tokio::test]
    async fn should_return_error_when_merge_operator_not_configured() {
        let storage = InMemoryStorage::new();
        let record = Record::new(Bytes::from("key1"), Bytes::from("value1"));

        let result = storage.merge(vec![record]).await;

        let err = result.expect_err("merge without an operator must fail");
        assert!(err.to_string().contains("Merge operator not configured"));
    }

    #[tokio::test]
    async fn should_merge_when_key_does_not_exist() {
        let merge_op = Arc::new(AppendMergeOperator);
        let storage = InMemoryStorage::with_merge_operator(merge_op);
        let key = Bytes::from("new_key");
        let value = Bytes::from("value1");

        storage
            .merge(vec![Record::new(key.clone(), value.clone())])
            .await
            .unwrap();
        let record = storage
            .get(key)
            .await
            .unwrap()
            .expect("merged record should be stored");

        assert_eq!(record.value, value);
    }

    #[tokio::test]
    async fn should_merge_when_key_exists() {
        let merge_op = Arc::new(AppendMergeOperator);
        let storage = InMemoryStorage::with_merge_operator(merge_op);
        let key = Bytes::from("key1");
        let initial_value = Bytes::from("value1");
        let new_value = Bytes::from("value2");

        storage
            .put(vec![Record::new(key.clone(), initial_value)])
            .await
            .unwrap();

        storage
            .merge(vec![Record::new(key.clone(), new_value)])
            .await
            .unwrap();
        let record = storage
            .get(key)
            .await
            .unwrap()
            .expect("merged record should be stored");

        assert_eq!(record.value, Bytes::from("value1,value2"));
    }

    #[tokio::test]
    async fn should_merge_multiple_keys() {
        let merge_op = Arc::new(AppendMergeOperator);
        let storage = InMemoryStorage::with_merge_operator(merge_op);
        let records = vec![
            Record::new(Bytes::from("key1"), Bytes::from("value1")),
            Record::new(Bytes::from("key2"), Bytes::from("value2")),
        ];
        storage.put(records).await.unwrap();

        storage
            .merge(vec![
                Record::new(Bytes::from("key1"), Bytes::from("value1a")),
                Record::new(Bytes::from("key2"), Bytes::from("value2a")),
            ])
            .await
            .unwrap();

        let result1 = storage.get(Bytes::from("key1")).await.unwrap();
        assert_eq!(result1.unwrap().value, Bytes::from("value1,value1a"));

        let result2 = storage.get(Bytes::from("key2")).await.unwrap();
        assert_eq!(result2.unwrap().value, Bytes::from("value2,value2a"));
    }

    #[tokio::test]
    async fn should_merge_empty_values() {
        let merge_op = Arc::new(AppendMergeOperator);
        let storage = InMemoryStorage::with_merge_operator(merge_op);
        let key = Bytes::from("key1");

        storage
            .merge(vec![Record::empty(key.clone())])
            .await
            .unwrap();
        let record = storage
            .get(key)
            .await
            .unwrap()
            .expect("merged empty record should be stored");

        assert_eq!(record.value, Bytes::new());
    }

    #[tokio::test]
    async fn should_apply_puts_and_deletes_in_order() {
        let storage = InMemoryStorage::new();
        storage
            .put(vec![Record::new(Bytes::from("stale"), Bytes::from("old"))])
            .await
            .unwrap();

        storage
            .apply(vec![
                RecordOp::Put(Record::new(Bytes::from("key1"), Bytes::from("value1"))),
                RecordOp::Delete(Bytes::from("stale")),
                RecordOp::Put(Record::new(Bytes::from("key1"), Bytes::from("value2"))),
            ])
            .await
            .unwrap();

        assert!(storage.get(Bytes::from("stale")).await.unwrap().is_none());
        let record = storage
            .get(Bytes::from("key1"))
            .await
            .unwrap()
            .expect("key1 should be present after apply");
        // The later put in the batch must win.
        assert_eq!(record.value, Bytes::from("value2"));
    }

    #[tokio::test]
    async fn should_ignore_delete_of_missing_key_in_apply() {
        let storage = InMemoryStorage::new();

        storage
            .apply(vec![RecordOp::Delete(Bytes::from("missing"))])
            .await
            .unwrap();

        let scanned = storage.scan(BytesRange::unbounded()).await.unwrap();
        assert!(scanned.is_empty());
    }

    #[tokio::test]
    async fn should_apply_merges_with_configured_operator() {
        let storage = InMemoryStorage::with_merge_operator(Arc::new(AppendMergeOperator));

        storage
            .apply(vec![
                RecordOp::Put(Record::new(Bytes::from("key1"), Bytes::from("a"))),
                RecordOp::Merge(Record::new(Bytes::from("key1"), Bytes::from("b"))),
                RecordOp::Merge(Record::new(Bytes::from("key2"), Bytes::from("c"))),
            ])
            .await
            .unwrap();

        let key1 = storage
            .get(Bytes::from("key1"))
            .await
            .unwrap()
            .expect("key1 should be present after apply");
        assert_eq!(key1.value, Bytes::from("a,b"));

        let key2 = storage
            .get(Bytes::from("key2"))
            .await
            .unwrap()
            .expect("key2 should be present after apply");
        assert_eq!(key2.value, Bytes::from("c"));
    }
}