Skip to main content

reifydb_cdc/storage/
memory.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{
5	collections::{BTreeMap, Bound},
6	sync::Arc,
7};
8
9use reifydb_core::{
10	common::CommitVersion,
11	interface::cdc::{Cdc, CdcBatch},
12};
13use reifydb_runtime::sync::rwlock::RwLock;
14
15use super::{CdcStorage, CdcStorageResult, DropBeforeResult, DroppedCdcEntry};
16
17#[derive(Clone)]
18pub struct MemoryCdcStorage {
19	inner: Arc<RwLock<BTreeMap<CommitVersion, Cdc>>>,
20}
21
22impl MemoryCdcStorage {
23	pub fn new() -> Self {
24		Self {
25			inner: Arc::new(RwLock::new(BTreeMap::new())),
26		}
27	}
28
29	pub fn with_entries(entries: impl IntoIterator<Item = Cdc>) -> Self {
30		let map: BTreeMap<CommitVersion, Cdc> = entries.into_iter().map(|cdc| (cdc.version, cdc)).collect();
31		Self {
32			inner: Arc::new(RwLock::new(map)),
33		}
34	}
35
36	pub fn len(&self) -> usize {
37		self.inner.read().len()
38	}
39
40	pub fn is_empty(&self) -> bool {
41		self.inner.read().is_empty()
42	}
43
44	pub fn clear(&self) {
45		self.inner.write().clear();
46	}
47}
48
49impl Default for MemoryCdcStorage {
50	fn default() -> Self {
51		Self::new()
52	}
53}
54
55impl CdcStorage for MemoryCdcStorage {
56	fn write(&self, cdc: &Cdc) -> CdcStorageResult<()> {
57		self.inner.write().insert(cdc.version, cdc.clone());
58		Ok(())
59	}
60
61	fn read(&self, version: CommitVersion) -> CdcStorageResult<Option<Cdc>> {
62		Ok(self.inner.read().get(&version).cloned())
63	}
64
65	fn read_range(
66		&self,
67		start: Bound<CommitVersion>,
68		end: Bound<CommitVersion>,
69		batch_size: u64,
70	) -> CdcStorageResult<CdcBatch> {
71		let guard = self.inner.read();
72		let batch_size = batch_size as usize;
73
74		let range_iter = guard.range((start, end));
75		let mut items: Vec<Cdc> = Vec::with_capacity(batch_size.min(64));
76		let mut count = 0;
77
78		for (_, cdc) in range_iter {
79			if count >= batch_size {
80				// We've hit the batch limit, there are more items
81				return Ok(CdcBatch {
82					items,
83					has_more: true,
84				});
85			}
86			items.push(cdc.clone());
87			count += 1;
88		}
89
90		Ok(CdcBatch {
91			items,
92			has_more: false,
93		})
94	}
95
96	fn count(&self, version: CommitVersion) -> CdcStorageResult<usize> {
97		Ok(self.inner.read().get(&version).map(|cdc| cdc.system_changes.len()).unwrap_or(0))
98	}
99
100	fn min_version(&self) -> CdcStorageResult<Option<CommitVersion>> {
101		Ok(self.inner.read().keys().next().copied())
102	}
103
104	fn max_version(&self) -> CdcStorageResult<Option<CommitVersion>> {
105		Ok(self.inner.read().keys().next_back().copied())
106	}
107
108	fn drop_before(&self, version: CommitVersion) -> CdcStorageResult<DropBeforeResult> {
109		let mut guard = self.inner.write();
110		let keys_to_remove: Vec<_> = guard.range(..version).map(|(k, _)| *k).collect();
111		let count = keys_to_remove.len();
112
113		let mut entries = Vec::new();
114		for key in &keys_to_remove {
115			if let Some(cdc) = guard.get(key) {
116				for sys_change in &cdc.system_changes {
117					entries.push(DroppedCdcEntry {
118						key: sys_change.key().clone(),
119						value_bytes: sys_change.value_bytes() as u64,
120					});
121				}
122			}
123		}
124
125		for key in keys_to_remove {
126			guard.remove(&key);
127		}
128
129		Ok(DropBeforeResult {
130			count,
131			entries,
132		})
133	}
134}
135
#[cfg(test)]
pub mod tests {
	use std::thread;

	use reifydb_core::{
		encoded::{key::EncodedKey, row::EncodedRow},
		interface::cdc::SystemChange,
	};
	use reifydb_type::util::cowvec::CowVec;

	use super::*;

	/// Builds a CDC record at `version` holding one system insert with an
	/// empty row payload.
	fn make_cdc(version: u64) -> Cdc {
		let change = SystemChange::Insert {
			key: EncodedKey::new(vec![1, 2, 3]),
			post: EncodedRow(CowVec::new(vec![])),
		};
		Cdc::new(CommitVersion(version), 12345, Vec::new(), vec![change])
	}

	/// Lists the versions of a batch's items in the order they were returned.
	fn versions_of(batch: &CdcBatch) -> Vec<CommitVersion> {
		batch.items.iter().map(|cdc| cdc.version).collect()
	}

	#[test]
	fn test_clone_shares_storage() {
		let original = MemoryCdcStorage::new();
		let cloned = original.clone();

		original.write(&make_cdc(1)).unwrap();

		// A write through one handle must be visible through the other.
		let seen_by_original = original.read(CommitVersion(1)).unwrap();
		let seen_by_clone = cloned.read(CommitVersion(1)).unwrap();
		assert!(seen_by_original.is_some());
		assert!(seen_by_clone.is_some());
	}

	#[test]
	fn test_concurrent_access() {
		let storage = MemoryCdcStorage::new();

		// One writer thread per version; all are spawned before any is joined.
		let writers: Vec<_> = (0..10)
			.map(|version| {
				let handle = storage.clone();
				thread::spawn(move || {
					handle.write(&make_cdc(version)).unwrap();
				})
			})
			.collect();

		for writer in writers {
			writer.join().unwrap();
		}

		// Every writer used a distinct version, so none may be lost.
		assert_eq!(storage.len(), 10);
	}

	#[test]
	fn test_range_exclusive_bounds() {
		let storage = MemoryCdcStorage::new();
		for version in 1..=5 {
			storage.write(&make_cdc(version)).unwrap();
		}

		// Start excluded, end included: (2, 4] -> 3, 4
		let after_two = storage
			.read_range(Bound::Excluded(CommitVersion(2)), Bound::Included(CommitVersion(4)), 100)
			.unwrap();
		assert_eq!(versions_of(&after_two), vec![CommitVersion(3), CommitVersion(4)]);

		// Start included, end excluded: [2, 4) -> 2, 3
		let before_four = storage
			.read_range(Bound::Included(CommitVersion(2)), Bound::Excluded(CommitVersion(4)), 100)
			.unwrap();
		assert_eq!(versions_of(&before_four), vec![CommitVersion(2), CommitVersion(3)]);
	}

	#[test]
	fn test_overwrite_entry() {
		let storage = MemoryCdcStorage::new();
		let version = CommitVersion(1);

		// First record: one system change, timestamp 100.
		let first = Cdc::new(
			version,
			100,
			Vec::new(),
			vec![SystemChange::Insert {
				key: EncodedKey::new(vec![1]),
				post: EncodedRow(CowVec::new(vec![])),
			}],
		);

		// Second record for the same version: two system changes, timestamp 200.
		let second = Cdc::new(
			version,
			200,
			Vec::new(),
			vec![
				SystemChange::Insert {
					key: EncodedKey::new(vec![2]),
					post: EncodedRow(CowVec::new(vec![])),
				},
				SystemChange::Insert {
					key: EncodedKey::new(vec![3]),
					post: EncodedRow(CowVec::new(vec![])),
				},
			],
		);

		storage.write(&first).unwrap();
		assert_eq!(storage.count(version).unwrap(), 1);

		// Writing the same version again replaces the earlier record.
		storage.write(&second).unwrap();
		assert_eq!(storage.count(version).unwrap(), 2);

		let stored = storage.read(version).unwrap().unwrap();
		assert_eq!(stored.timestamp, 200);
	}
}