use crate::error::{Error, Result};
use crate::kv::{KVStore, KVTransaction};
use crate::log::wal::{WalReader, WalRecord, WalWriter};
use crate::storage::flush::write_empty_vector_segment;
use crate::storage::sstable::{SstableReader, SstableWriter};
use crate::txn::TxnManager;
use crate::types::{Key, TxnId, TxnMode, TxnState, Value};
use std::collections::{BTreeMap, HashMap};
use std::ops::Bound::{Excluded, Included};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, RwLock, RwLockReadGuard};
use tracing::warn;

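/// A point-in-time snapshot of the store's approximate memory usage, in bytes.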
#[derive(Debug, Clone, Default)]
pub struct MemoryStats {
    /// Total tracked bytes (currently key/value data only).
    pub total_bytes: usize,
    /// Bytes attributed to key/value entries.
    pub kv_bytes: usize,
    /// Bytes attributed to secondary indexes (not yet tracked; always zero).
    pub index_bytes: usize,
}

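/// An in-memory key/value store with optimistic (OCC) transactions, optional
/// write-ahead logging, and optional SSTable persistence.
///
/// A minimal usage sketch (crate paths and error handling are assumed, hence
/// `ignore`):
///
/// ```ignore
/// use crate::kv::{KVStore, KVTransaction};
/// use crate::types::TxnMode;
///
/// let store = MemoryKV::new();
/// let manager = store.txn_manager();
/// let mut txn = manager.begin(TxnMode::ReadWrite)?;
/// txn.put(b"hello".to_vec(), b"world".to_vec())?;
/// manager.commit(txn)?;
/// ```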
#[derive(Clone)]
pub struct MemoryKV {
    manager: Arc<MemoryTxnManager>,
}

impl MemoryKV {
    pub fn new() -> Self {
        Self {
            manager: Arc::new(MemoryTxnManager::new(None, None, None)),
        }
    }

    pub fn memory_stats(&self) -> MemoryStats {
        self.manager.memory_stats()
    }

    pub fn new_with_limit(limit: Option<usize>) -> Self {
        Self {
            manager: Arc::new(MemoryTxnManager::new_with_limit(limit)),
        }
    }

    pub fn open(path: &Path) -> Result<Self> {
        let wal_writer = WalWriter::new(path)?;
        let sstable_path = path.with_extension("sst");
        let manager = Arc::new(MemoryTxnManager::new(
            Some(wal_writer),
            Some(path.to_path_buf()),
            Some(sstable_path),
        ));
        manager.recover()?;
        Ok(Self { manager })
    }

    pub fn flush(&self) -> Result<()> {
        self.manager.flush()
    }
}

impl Default for MemoryKV {
    fn default() -> Self {
        Self::new()
    }
}

impl KVStore for MemoryKV {
    type Transaction<'a> = MemoryTransaction<'a>;
    type Manager<'a> = &'a MemoryTxnManager;

    fn txn_manager(&self) -> Self::Manager<'_> {
        &self.manager
    }

    fn begin(&self, mode: TxnMode) -> Result<Self::Transaction<'_>> {
        self.manager.begin_internal(mode)
    }
}

type VersionedValue = (Value, u64);

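/// State shared by all transactions: the versioned key/value map, transaction
/// and commit counters, optional WAL and SSTable handles, and memory accounting.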
struct MemorySharedState {
    data: RwLock<BTreeMap<Key, VersionedValue>>,
    next_txn_id: AtomicU64,
    commit_version: AtomicU64,
    wal_writer: Option<RwLock<WalWriter>>,
    wal_path: Option<PathBuf>,
    sstable: RwLock<Option<SstableReader>>,
    sstable_path: Option<PathBuf>,
    memory_limit: RwLock<Option<usize>>,
    current_memory: AtomicUsize,
}

impl MemorySharedState {
    fn check_memory_limit(&self, additional: usize) -> Result<()> {
        if let Some(limit) = *self.memory_limit.read().unwrap() {
            let current = self.current_memory.load(Ordering::Relaxed);
            let requested = current.saturating_add(additional);
            if requested > limit {
                return Err(Error::MemoryLimitExceeded { limit, requested });
            }
        }
        Ok(())
    }

    fn memory_stats(&self) -> MemoryStats {
        let kv_bytes = self.current_memory.load(Ordering::Relaxed);
        MemoryStats {
            total_bytes: kv_bytes,
            kv_bytes,
            index_bytes: 0,
        }
    }

    fn recompute_current_memory(&self) {
        let data = self.data.read().unwrap();
        let mut total = 0usize;
        for (k, (v, _)) in data.iter() {
            total = total.saturating_add(k.len() + v.len());
        }
        self.current_memory.store(total, Ordering::Relaxed);
    }
}

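/// Hands out transactions and performs OCC validation and durability work at
/// commit time.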
pub struct MemoryTxnManager {
    state: Arc<MemorySharedState>,
}

impl MemoryTxnManager {
    fn new_with_params(
        wal_writer: Option<WalWriter>,
        wal_path: Option<PathBuf>,
        sstable_path: Option<PathBuf>,
        memory_limit: Option<usize>,
    ) -> Self {
        Self {
            state: Arc::new(MemorySharedState {
                data: RwLock::new(BTreeMap::new()),
                next_txn_id: AtomicU64::new(1),
                commit_version: AtomicU64::new(0),
                wal_writer: wal_writer.map(RwLock::new),
                wal_path,
                sstable: RwLock::new(None),
                sstable_path,
                memory_limit: RwLock::new(memory_limit),
                current_memory: AtomicUsize::new(0),
            }),
        }
    }

    fn new(
        wal_writer: Option<WalWriter>,
        wal_path: Option<PathBuf>,
        sstable_path: Option<PathBuf>,
    ) -> Self {
        Self::new_with_params(wal_writer, wal_path, sstable_path, None)
    }

    pub fn new_with_limit(limit: Option<usize>) -> Self {
        Self::new_with_params(None, None, None, limit)
    }

    pub fn memory_stats(&self) -> MemoryStats {
        self.state.memory_stats()
    }

    pub fn set_memory_limit(&self, limit: Option<usize>) {
        let mut guard = self.state.memory_limit.write().unwrap();
        *guard = limit;
    }

    pub fn snapshot(&self) -> Vec<(Key, Value)> {
        let data = self.state.data.read().unwrap();
        data.iter()
            .map(|(k, (v, _))| (k.clone(), v.clone()))
            .collect()
    }

    pub fn clear_all(&self) {
        let mut data = self.state.data.write().unwrap();
        data.clear();
        drop(data);
        self.state.current_memory.store(0, Ordering::Relaxed);
        self.state.commit_version.store(0, Ordering::Relaxed);
    }

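    /// Runs `run` only if replacing `input_bytes` of tracked memory with
    /// `output_bytes` would stay within the configured memory limit; returns
    /// `Ok(false)` (and logs a warning) when the compaction is skipped.
    ///
    /// A minimal sketch of how a caller might use it (sizes are illustrative):
    ///
    /// ```ignore
    /// // Rewrite roughly 4 KiB of tracked entries into roughly 1 KiB.
    /// let ran = manager.compact_with_limit(4096, 1024, || {
    ///     // ... perform the actual rewrite here ...
    ///     Ok(())
    /// })?;
    /// ```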
    pub fn compact_with_limit<F>(
        &self,
        input_bytes: usize,
        output_bytes: usize,
        run: F,
    ) -> Result<bool>
    where
        F: FnOnce() -> Result<()>,
    {
        if let Some(limit) = *self.state.memory_limit.read().unwrap() {
            let current = self.state.current_memory.load(Ordering::Relaxed);
            let prospective = current
                .saturating_sub(input_bytes)
                .saturating_add(output_bytes);
            if prospective > limit {
                warn!(
                    limit,
                    requested = prospective,
                    "compaction skipped due to memory limit"
                );
                return Ok(false);
            }
        }

        run()?;

        let current = self.state.current_memory.load(Ordering::Relaxed);
        let new_usage = current
            .saturating_sub(input_bytes)
            .saturating_add(output_bytes);
        self.state
            .current_memory
            .store(new_usage, Ordering::Relaxed);
        Ok(true)
    }

    pub fn compact_in_memory(&self) -> Result<bool> {
        let snapshot_bytes = {
            let data = self.state.data.read().unwrap();
            let mut bytes = 0usize;
            for (k, (v, _)) in data.iter() {
                bytes = bytes.saturating_add(k.len() + v.len());
            }
            bytes
        };

        self.compact_with_limit(snapshot_bytes, snapshot_bytes, || {
            let data = self.state.data.read().unwrap();
            let mut rebuilt = BTreeMap::new();
            for (k, (v, version)) in data.iter() {
                rebuilt.insert(k.clone(), (v.clone(), *version));
            }
            drop(data);

            let mut write_guard = self.state.data.write().unwrap();
            *write_guard = rebuilt;
            Ok(())
        })
    }

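    /// Writes the current contents to the configured SSTable path (plus an
    /// empty vector segment alongside it) and swaps in a reader for the new file.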
    pub fn flush(&self) -> Result<()> {
        let path = self
            .state
            .sstable_path
            .as_ref()
            .ok_or_else(|| Error::InvalidFormat("sstable path is not configured".into()))?;

        let data = self.state.data.read().unwrap();
        let mut writer = SstableWriter::create(path)?;
        for (key, (value, _version)) in data.iter() {
            writer.append(key, value)?;
        }
        drop(data);

        let _footer = writer.finish()?;
        let reader = SstableReader::open(path)?;
        let vec_path = path.with_extension("vec");
        write_empty_vector_segment(&vec_path)?;

        let mut slot = self.state.sstable.write().unwrap();
        *slot = Some(reader);
        Ok(())
    }

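    /// Replays the WAL: writes are buffered per transaction and applied to the
    /// in-memory map only when the matching `Commit` record is seen, so
    /// uncommitted transactions are discarded on recovery.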
    fn replay(&self) -> Result<()> {
        let path = match &self.state.wal_path {
            Some(p) => p,
            None => return Ok(()),
        };
        if !path.exists() || std::fs::metadata(path)?.len() == 0 {
            return Ok(());
        }

        let mut data = self.state.data.write().unwrap();
        let mut max_txn_id = 0;
        let mut max_version = self.state.commit_version.load(Ordering::Acquire);
        let reader = WalReader::new(path)?;
        let mut pending_txns: HashMap<TxnId, Vec<(Key, Option<Value>)>> = HashMap::new();

        for record_result in reader {
            match record_result? {
                WalRecord::Begin(txn_id) => {
                    max_txn_id = max_txn_id.max(txn_id.0);
                    pending_txns.entry(txn_id).or_default();
                }
                WalRecord::Put(txn_id, key, value) => {
                    max_txn_id = max_txn_id.max(txn_id.0);
                    pending_txns
                        .entry(txn_id)
                        .or_default()
                        .push((key, Some(value)));
                }
                WalRecord::Delete(txn_id, key) => {
                    max_txn_id = max_txn_id.max(txn_id.0);
                    pending_txns.entry(txn_id).or_default().push((key, None));
                }
                WalRecord::Commit(txn_id) => {
                    if let Some(writes) = pending_txns.remove(&txn_id) {
                        max_version += 1;
                        for (key, value) in writes {
                            if let Some(v) = value {
                                data.insert(key, (v, max_version));
                            } else {
                                data.remove(&key);
                            }
                        }
                    }
                }
            }
        }

        self.state
            .next_txn_id
            .store(max_txn_id + 1, Ordering::SeqCst);
        self.state
            .commit_version
            .store(max_version, Ordering::SeqCst);
        Ok(())
    }

    fn load_sstable(&self) -> Result<()> {
        let path = match &self.state.sstable_path {
            Some(p) => p,
            None => return Ok(()),
        };
        if !path.exists() {
            return Ok(());
        }

        let mut reader = SstableReader::open(path)?;
        let mut data = self.state.data.write().unwrap();
        let mut version = self.state.commit_version.load(Ordering::Acquire);

        let keys: Vec<Key> = reader
            .index()
            .iter()
            .map(|entry| entry.key.clone())
            .collect();

        for key in keys {
            if let Some(value) = reader.get(&key)? {
                version += 1;
                data.insert(key, (value, version));
            }
        }

        self.state.commit_version.store(version, Ordering::SeqCst);
        let mut slot = self.state.sstable.write().unwrap();
        *slot = Some(reader);
        Ok(())
    }

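    /// Recovery order: load the SSTable baseline first, overlay committed WAL
    /// entries on top of it, then recompute tracked memory usage.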
    fn recover(&self) -> Result<()> {
        self.load_sstable()?;
        self.replay()?;
        self.state.recompute_current_memory();
        Ok(())
    }

    fn sstable_get(&self, key: &Key) -> Result<Option<Value>> {
        let mut guard = self.state.sstable.write().unwrap();
        if let Some(reader) = guard.as_mut() {
            return reader.get(key);
        }
        Ok(None)
    }

    fn begin_internal(&self, mode: TxnMode) -> Result<MemoryTransaction<'_>> {
        let txn_id = self.state.next_txn_id.fetch_add(1, Ordering::SeqCst);
        let start_version = self.state.commit_version.load(Ordering::Acquire);
        Ok(MemoryTransaction::new(
            self,
            TxnId(txn_id),
            mode,
            start_version,
        ))
    }
}

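// Commit protocol (optimistic concurrency control): validate that no key in the
// read set or write set has been committed at a version newer than the
// transaction's start version, check the memory budget, append the write set to
// the WAL (if configured), then apply the writes at a freshly allocated commit
// version.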
impl<'a> TxnManager<'a, MemoryTransaction<'a>> for &'a MemoryTxnManager {
    fn begin(&'a self, mode: TxnMode) -> Result<MemoryTransaction<'a>> {
        self.begin_internal(mode)
    }

    fn commit(&'a self, mut txn: MemoryTransaction<'a>) -> Result<()> {
        if txn.state != TxnState::Active {
            return Err(Error::TxnClosed);
        }
        if txn.mode == TxnMode::ReadOnly || txn.writes.is_empty() {
            txn.state = TxnState::Committed;
            return Ok(());
        }

        let mut data = self.state.data.write().unwrap();

        for key in txn.read_set.keys() {
            let current_version = data.get(key).map(|(_, v)| *v).unwrap_or(0);
            if current_version > txn.start_version {
                return Err(Error::TxnConflict);
            }
        }

        for key in txn.writes.keys() {
            let current_version = data.get(key).map(|(_, v)| *v).unwrap_or(0);
            if current_version > txn.start_version {
                return Err(Error::TxnConflict);
            }
        }

        let mut delta: isize = 0;
        for (key, value) in &txn.writes {
            let current_size = data.get(key).map(|(v, _)| key.len() + v.len()).unwrap_or(0);
            let new_size = match value {
                Some(v) => key.len() + v.len(),
                None => 0,
            };
            delta += new_size as isize - current_size as isize;
        }

        let current_mem = self.state.current_memory.load(Ordering::Relaxed);
        let prospective = if delta >= 0 {
            current_mem.saturating_add(delta as usize)
        } else {
            current_mem.saturating_sub(delta.unsigned_abs())
        };

        if delta > 0 {
            self.state.check_memory_limit(delta as usize)?;
        }

        let commit_version = self.state.commit_version.fetch_add(1, Ordering::AcqRel) + 1;

        if let Some(wal_lock) = &self.state.wal_writer {
            let mut wal = wal_lock.write().unwrap();
            wal.append(&WalRecord::Begin(txn.id))?;
            for (key, value) in &txn.writes {
                let record = match value {
                    Some(v) => WalRecord::Put(txn.id, key.clone(), v.clone()),
                    None => WalRecord::Delete(txn.id, key.clone()),
                };
                wal.append(&record)?;
            }
            wal.append(&WalRecord::Commit(txn.id))?;
        }

        for (key, value) in std::mem::take(&mut txn.writes) {
            if let Some(v) = value {
                data.insert(key, (v, commit_version));
            } else {
                data.remove(&key);
            }
        }

        self.state
            .current_memory
            .store(prospective, Ordering::Relaxed);

        txn.state = TxnState::Committed;
        Ok(())
    }

    fn rollback(&'a self, mut txn: MemoryTransaction<'a>) -> Result<()> {
        if txn.state != TxnState::Active {
            return Err(Error::TxnClosed);
        }
        txn.state = TxnState::RolledBack;
        Ok(())
    }
}

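/// A transaction handle. Reads are recorded (key to observed version) for OCC
/// validation; writes are buffered locally and only become visible on commit.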
pub struct MemoryTransaction<'a> {
    manager: &'a MemoryTxnManager,
    id: TxnId,
    mode: TxnMode,
    state: TxnState,
    start_version: u64,
    writes: BTreeMap<Key, Option<Value>>,
    read_set: HashMap<Key, u64>,
}

impl<'a> MemoryTransaction<'a> {
    fn new(manager: &'a MemoryTxnManager, id: TxnId, mode: TxnMode, start_version: u64) -> Self {
        Self {
            manager,
            id,
            mode,
            state: TxnState::Active,
            start_version,
            writes: BTreeMap::new(),
            read_set: HashMap::new(),
        }
    }

    fn ensure_active(&self) -> Result<()> {
        if self.state != TxnState::Active {
            return Err(Error::TxnClosed);
        }
        Ok(())
    }

    fn scan_range_internal(&mut self, start: &[u8], end: &[u8]) -> MergedScanIter<'_> {
        let start_vec = start.to_vec();
        let end_vec = end.to_vec();
        let data_guard = self.manager.state.data.read().unwrap();
        let data_ptr: *const BTreeMap<Key, VersionedValue> = &*data_guard;
        // SAFETY: the iterator borrows the map behind `data_guard`. The guard is
        // moved into the returned `MergedScanIter` alongside the iterator, so the
        // read lock (and the map it protects) outlives every use of `data_iter`.
        let data_iter = unsafe {
            (&*data_ptr).range((Included(start_vec.clone()), Excluded(end_vec.clone())))
        };
        let write_iter = self
            .writes
            .range((Included(start_vec.clone()), Excluded(end_vec.clone())));

        MergedScanIter::new(
            data_guard,
            data_iter,
            write_iter,
            None,
            Some(end_vec),
            self.start_version,
            &mut self.read_set,
        )
    }

    fn scan_prefix_internal(&mut self, prefix: &[u8]) -> MergedScanIter<'_> {
        let prefix_vec = prefix.to_vec();
        let data_guard = self.manager.state.data.read().unwrap();
        let data_ptr: *const BTreeMap<Key, VersionedValue> = &*data_guard;
        // SAFETY: same invariant as `scan_range_internal`; the guard is kept
        // alive inside the returned iterator.
        let data_iter = unsafe { (&*data_ptr).range(prefix_vec.clone()..) };
        let write_iter = self.writes.range(prefix_vec.clone()..);
        MergedScanIter::new(
            data_guard,
            data_iter,
            write_iter,
            Some(prefix_vec),
            None,
            self.start_version,
            &mut self.read_set,
        )
    }
}

impl<'a> KVTransaction<'a> for MemoryTransaction<'a> {
    fn id(&self) -> TxnId {
        self.id
    }

    fn mode(&self) -> TxnMode {
        self.mode
    }

    fn get(&mut self, key: &Key) -> Result<Option<Value>> {
        if self.state != TxnState::Active {
            return Err(Error::TxnClosed);
        }

        if let Some(value) = self.writes.get(key) {
            return Ok(value.clone());
        }

        let result = {
            let data = self.manager.state.data.read().unwrap();
            data.get(key).cloned()
        };

        if let Some((v, version)) = result {
            self.read_set.insert(key.clone(), version);
            return Ok(Some(v));
        }

        if let Some(value) = self.manager.sstable_get(key)? {
            let version = self.manager.state.commit_version.load(Ordering::Acquire);
            self.read_set.insert(key.clone(), version);
            return Ok(Some(value));
        }

        Ok(None)
    }

    fn put(&mut self, key: Key, value: Value) -> Result<()> {
        if self.state != TxnState::Active {
            return Err(Error::TxnClosed);
        }
        if self.mode == TxnMode::ReadOnly {
            return Err(Error::TxnConflict);
        }
        self.writes.insert(key, Some(value));
        Ok(())
    }

    fn delete(&mut self, key: Key) -> Result<()> {
        if self.state != TxnState::Active {
            return Err(Error::TxnClosed);
        }
        if self.mode == TxnMode::ReadOnly {
            return Err(Error::TxnConflict);
        }
        self.writes.insert(key, None);
        Ok(())
    }

    fn scan_prefix(
        &mut self,
        prefix: &[u8],
    ) -> Result<Box<dyn Iterator<Item = (Key, Value)> + '_>> {
        self.ensure_active()?;
        let iter = self
            .scan_prefix_internal(prefix)
            .filter_map(|(k, v)| v.map(|val| (k, val)));
        Ok(Box::new(iter))
    }

    fn scan_range(
        &mut self,
        start: &[u8],
        end: &[u8],
    ) -> Result<Box<dyn Iterator<Item = (Key, Value)> + '_>> {
        self.ensure_active()?;
        let iter = self
            .scan_range_internal(start, end)
            .filter_map(|(k, v)| v.map(|val| (k, val)));
        Ok(Box::new(iter))
    }

    fn commit_self(mut self) -> Result<()> {
        if self.state != TxnState::Active {
            return Err(Error::TxnClosed);
        }
        if self.mode == TxnMode::ReadOnly || self.writes.is_empty() {
            self.state = TxnState::Committed;
            return Ok(());
        }

        let mut data = self.manager.state.data.write().unwrap();

        for key in self.read_set.keys() {
            let current_version = data.get(key).map(|(_, v)| *v).unwrap_or(0);
            if current_version > self.start_version {
                return Err(Error::TxnConflict);
            }
        }

        for key in self.writes.keys() {
            let current_version = data.get(key).map(|(_, v)| *v).unwrap_or(0);
            if current_version > self.start_version {
                return Err(Error::TxnConflict);
            }
        }

        let mut delta: isize = 0;
        for (key, value) in &self.writes {
            let current_size = data.get(key).map(|(v, _)| key.len() + v.len()).unwrap_or(0);
            let new_size = match value {
                Some(v) => key.len() + v.len(),
                None => 0,
            };
            delta += new_size as isize - current_size as isize;
        }

        let current_mem = self.manager.state.current_memory.load(Ordering::Relaxed);
        let prospective = if delta >= 0 {
            current_mem.saturating_add(delta as usize)
        } else {
            current_mem.saturating_sub(delta.unsigned_abs())
        };

        if delta > 0 {
            self.manager.state.check_memory_limit(delta as usize)?;
        }

        let commit_version = self
            .manager
            .state
            .commit_version
            .fetch_add(1, Ordering::AcqRel)
            + 1;

        if let Some(wal_lock) = &self.manager.state.wal_writer {
            let mut wal = wal_lock.write().unwrap();
            wal.append(&WalRecord::Begin(self.id))?;
            for (key, value) in &self.writes {
                let record = match value {
                    Some(v) => WalRecord::Put(self.id, key.clone(), v.clone()),
                    None => WalRecord::Delete(self.id, key.clone()),
                };
                wal.append(&record)?;
            }
            wal.append(&WalRecord::Commit(self.id))?;
        }

        for (key, value) in std::mem::take(&mut self.writes) {
            if let Some(v) = value {
                data.insert(key, (v, commit_version));
            } else {
                data.remove(&key);
            }
        }

        self.manager
            .state
            .current_memory
            .store(prospective, Ordering::Relaxed);

        self.state = TxnState::Committed;
        Ok(())
    }

    fn rollback_self(mut self) -> Result<()> {
        if self.state != TxnState::Active {
            return Err(Error::TxnClosed);
        }
        self.state = TxnState::RolledBack;
        Ok(())
    }
}

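/// Merges the committed snapshot (keys at or below the transaction's start
/// version) with the transaction's local write set, in key order. On a key
/// collision the local write wins; every snapshot key yielded is recorded in
/// the read set for OCC validation at commit.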
struct MergedScanIter<'a> {
    _data_guard: RwLockReadGuard<'a, BTreeMap<Key, VersionedValue>>,
    data_iter: std::collections::btree_map::Range<'a, Key, VersionedValue>,
    write_iter: std::collections::btree_map::Range<'a, Key, Option<Value>>,
    data_peek: Option<(Key, (Value, u64))>,
    write_peek: Option<(Key, Option<Value>)>,
    prefix: Option<Vec<u8>>,
    end: Option<Key>,
    start_version: u64,
    read_set: &'a mut HashMap<Key, u64>,
}

impl<'a> MergedScanIter<'a> {
    #[allow(clippy::too_many_arguments)]
    fn new(
        data_guard: RwLockReadGuard<'a, BTreeMap<Key, VersionedValue>>,
        data_iter: std::collections::btree_map::Range<'a, Key, VersionedValue>,
        write_iter: std::collections::btree_map::Range<'a, Key, Option<Value>>,
        prefix: Option<Vec<u8>>,
        end: Option<Key>,
        start_version: u64,
        read_set: &'a mut HashMap<Key, u64>,
    ) -> Self {
        let mut iter = Self {
            _data_guard: data_guard,
            data_iter,
            write_iter,
            data_peek: None,
            write_peek: None,
            prefix,
            end,
            start_version,
            read_set,
        };
        iter.advance_data();
        iter.advance_write();
        iter
    }

    fn advance_data(&mut self) {
        self.data_peek = None;
        while let Some((k, (v, ver))) = self.data_iter.next().map(|(k, v)| (k.clone(), v.clone())) {
            if let Some(end) = &self.end {
                if k >= *end {
                    return;
                }
            }
            if let Some(prefix) = &self.prefix {
                if !k.starts_with(prefix) {
                    return;
                }
            }
            if ver > self.start_version {
                continue;
            }
            self.data_peek = Some((k, (v, ver)));
            return;
        }
    }

    fn advance_write(&mut self) {
        self.write_peek = None;
        if let Some((k, v)) = self.write_iter.next().map(|(k, v)| (k.clone(), v.clone())) {
            if let Some(end) = &self.end {
                if k >= *end {
                    return;
                }
            }
            if let Some(prefix) = &self.prefix {
                if !k.starts_with(prefix) {
                    return;
                }
            }
            self.write_peek = Some((k, v));
        }
    }
}

impl<'a> Iterator for MergedScanIter<'a> {
    type Item = (Key, Option<Value>);

    fn next(&mut self) -> Option<Self::Item> {
        let data_key = self.data_peek.as_ref().map(|(k, _)| k.clone());
        let write_key = self.write_peek.as_ref().map(|(k, _)| k.clone());

        match (data_key, write_key) {
            (Some(dk), Some(wk)) => {
                if dk == wk {
                    let (_, (_, ver)) = self.data_peek.take().unwrap();
                    let (_, write_val) = self.write_peek.take().unwrap();
                    self.read_set.insert(dk.clone(), ver);
                    self.advance_data();
                    self.advance_write();
                    Some((dk, write_val))
                } else if dk < wk {
                    let (k, (v, ver)) = self.data_peek.take().unwrap();
                    self.read_set.insert(k.clone(), ver);
                    self.advance_data();
                    Some((k, Some(v)))
                } else {
                    let (k, write_val) = self.write_peek.take().unwrap();
                    self.advance_write();
                    Some((k, write_val))
                }
            }
            (Some(_), None) => {
                let (k, (v, ver)) = self.data_peek.take().unwrap();
                self.read_set.insert(k.clone(), ver);
                self.advance_data();
                Some((k, Some(v)))
            }
            (None, Some(_)) => {
                let (k, write_val) = self.write_peek.take().unwrap();
                self.advance_write();
                Some((k, write_val))
            }
            (None, None) => None,
        }
    }
}

impl<'a> Drop for MemoryTransaction<'a> {
    fn drop(&mut self) {
        if self.state == TxnState::Active {
            self.state = TxnState::RolledBack;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{KVTransaction, TxnManager};
    use tempfile::tempdir;
    use tracing::Level;

    fn key(s: &str) -> Key {
        s.as_bytes().to_vec()
    }

    fn value(s: &str) -> Value {
        s.as_bytes().to_vec()
    }

    #[test]
    fn test_put_and_get_transient() {
        let store = MemoryKV::new();
        let manager = store.txn_manager();
        let mut txn = manager.begin(TxnMode::ReadWrite).unwrap();
        txn.put(key("hello"), value("world")).unwrap();
        let val = txn.get(&key("hello")).unwrap();
        assert_eq!(val, Some(value("world")));
        manager.commit(txn).unwrap();

        let mut txn2 = manager.begin(TxnMode::ReadOnly).unwrap();
        let val2 = txn2.get(&key("hello")).unwrap();
        assert_eq!(val2, Some(value("world")));
    }

    #[test]
    fn test_occ_conflict() {
        let store = MemoryKV::new();
        let manager = store.txn_manager();

        let mut t1 = manager.begin(TxnMode::ReadWrite).unwrap();
        t1.get(&key("k1")).unwrap();

        let mut t2 = manager.begin(TxnMode::ReadWrite).unwrap();
        t2.put(key("k1"), value("v2")).unwrap();
        assert!(manager.commit(t2).is_ok());

        t1.put(key("k1"), value("v1")).unwrap();
        let result = manager.commit(t1);
        assert!(matches!(result, Err(Error::TxnConflict)));
    }

    #[test]
    fn test_blind_write_conflict() {
        let store = MemoryKV::new();
        let manager = store.txn_manager();

        let mut t1 = manager.begin(TxnMode::ReadWrite).unwrap();
        t1.put(key("k1"), value("v1")).unwrap();

        let mut t2 = manager.begin(TxnMode::ReadWrite).unwrap();
        t2.put(key("k1"), value("v2")).unwrap();
        assert!(manager.commit(t2).is_ok());

        let result = manager.commit(t1);
        assert!(matches!(result, Err(Error::TxnConflict)));
    }

    #[test]
    fn test_read_only_write_fails() {
        let store = MemoryKV::new();
        let manager = store.txn_manager();
        let mut txn = manager.begin(TxnMode::ReadOnly).unwrap();
        assert!(matches!(
            txn.put(key("k1"), value("v1")),
            Err(Error::TxnConflict)
        ));
        assert!(matches!(txn.delete(key("k1")), Err(Error::TxnConflict)));
    }

    #[test]
    fn test_txn_closed_error() {
        let store = MemoryKV::new();
        let manager = store.txn_manager();
        let txn = manager.begin(TxnMode::ReadWrite).unwrap();
        manager.commit(txn).unwrap();

        let mut closed_txn = manager.begin(TxnMode::ReadWrite).unwrap();
        closed_txn.state = TxnState::Committed;
        assert!(matches!(closed_txn.get(&key("k1")), Err(Error::TxnClosed)));
        assert!(matches!(
            closed_txn.put(key("k1"), value("v1")),
            Err(Error::TxnClosed)
        ));
    }

    #[test]
    fn test_get_not_found() {
        let store = MemoryKV::new();
        let manager = store.txn_manager();
        let mut txn = manager.begin(TxnMode::ReadOnly).unwrap();
        let res = txn.get(&key("non-existent"));
        assert!(res.is_ok());
        assert!(res.unwrap().is_none());
    }

    #[test]
    fn flush_and_reopen_reads_from_sstable() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("wal.log");
        {
            let store = MemoryKV::open(&wal_path).unwrap();
            let manager = store.txn_manager();
            let mut txn = manager.begin(TxnMode::ReadWrite).unwrap();
            txn.put(key("k1"), value("v1")).unwrap();
            manager.commit(txn).unwrap();
            store.flush().unwrap();
        }

        let reopened = MemoryKV::open(&wal_path).unwrap();
        let manager = reopened.txn_manager();
        let mut txn = manager.begin(TxnMode::ReadOnly).unwrap();
        assert_eq!(txn.get(&key("k1")).unwrap(), Some(value("v1")));
    }

    #[test]
    fn wal_overlays_sstable_on_reopen() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("wal.log");
        {
            let store = MemoryKV::open(&wal_path).unwrap();
            let manager = store.txn_manager();
            let mut txn = manager.begin(TxnMode::ReadWrite).unwrap();
            txn.put(key("k1"), value("v1")).unwrap();
            manager.commit(txn).unwrap();
            store.flush().unwrap();

            let mut txn2 = manager.begin(TxnMode::ReadWrite).unwrap();
            txn2.put(key("k1"), value("v2")).unwrap();
            manager.commit(txn2).unwrap();
        }

        let reopened = MemoryKV::open(&wal_path).unwrap();
        let manager = reopened.txn_manager();
        let mut txn = manager.begin(TxnMode::ReadOnly).unwrap();
        assert_eq!(txn.get(&key("k1")).unwrap(), Some(value("v2")));
    }

    #[test]
    fn scan_prefix_merges_snapshot_and_writes() {
        let store = MemoryKV::new();
        let manager = store.txn_manager();

        let mut seed = manager.begin(TxnMode::ReadWrite).unwrap();
        seed.put(key("p:1"), value("old1")).unwrap();
        seed.put(key("p:2"), value("old2")).unwrap();
        seed.put(key("q:1"), value("other")).unwrap();
        manager.commit(seed).unwrap();

        let mut txn = manager.begin(TxnMode::ReadWrite).unwrap();
        txn.put(key("p:1"), value("new1")).unwrap();
        txn.delete(key("p:2")).unwrap();
        txn.put(key("p:3"), value("new3")).unwrap();

        let results: Vec<_> = txn.scan_prefix(b"p:").unwrap().collect();
        assert_eq!(
            results,
            vec![(key("p:1"), value("new1")), (key("p:3"), value("new3"))]
        );
    }

    #[test]
    fn scan_range_skips_newer_versions() {
        let store = MemoryKV::new();
        let manager = store.txn_manager();

        let mut seed = manager.begin(TxnMode::ReadWrite).unwrap();
        seed.put(key("b"), value("v1")).unwrap();
        manager.commit(seed).unwrap();

        let mut txn1 = manager.begin(TxnMode::ReadWrite).unwrap();

        let mut txn2 = manager.begin(TxnMode::ReadWrite).unwrap();
        txn2.put(key("ba"), value("v2")).unwrap();
        manager.commit(txn2).unwrap();

        let results: Vec<_> = txn1.scan_range(b"b", b"c").unwrap().collect();
        assert_eq!(results, vec![(key("b"), value("v1"))]);
    }

    #[test]
    fn scan_range_records_reads_for_conflict_detection() {
        let store = MemoryKV::new();
        let manager = store.txn_manager();

        let mut seed = manager.begin(TxnMode::ReadWrite).unwrap();
        seed.put(key("k1"), value("v1")).unwrap();
        manager.commit(seed).unwrap();

        let mut t1 = manager.begin(TxnMode::ReadWrite).unwrap();
        let results: Vec<_> = t1.scan_range(b"k0", b"kz").unwrap().collect();
        assert_eq!(results, vec![(key("k1"), value("v1"))]);
        t1.put(key("k_new"), value("v_new")).unwrap();

        let mut t2 = manager.begin(TxnMode::ReadWrite).unwrap();
        t2.put(key("k1"), value("v2")).unwrap();
        manager.commit(t2).unwrap();

        let result = manager.commit(t1);
        assert!(matches!(result, Err(Error::TxnConflict)));
    }

    #[test]
    fn memory_stats_tracks_put_and_delete() {
        let store = MemoryKV::new();
        let manager = store.txn_manager();

        let stats = manager.memory_stats();
        assert_eq!(stats.total_bytes, 0);
        assert_eq!(stats.kv_bytes, 0);
        assert_eq!(stats.index_bytes, 0);

        let mut txn = manager.begin(TxnMode::ReadWrite).unwrap();
        txn.put(key("a"), value("1234")).unwrap();
        manager.commit(txn).unwrap();

        let stats = manager.memory_stats();
        assert_eq!(stats.total_bytes, 5);
        assert_eq!(stats.kv_bytes, 5);
        assert_eq!(stats.index_bytes, 0);

        let mut txn = manager.begin(TxnMode::ReadWrite).unwrap();
        txn.delete(key("a")).unwrap();
        manager.commit(txn).unwrap();

        let stats = manager.memory_stats();
        assert_eq!(stats.total_bytes, 0);
        assert_eq!(stats.kv_bytes, 0);
    }

    #[test]
    fn memory_limit_error_does_not_break_reads() {
        let store = MemoryKV::new_with_limit(Some(10));
        let manager = store.txn_manager();

        let mut txn = manager.begin_internal(TxnMode::ReadWrite).unwrap();
        txn.put(key("k1"), value("vvvv")).unwrap();
        manager.commit(txn).unwrap();

        let mut txn2 = manager.begin_internal(TxnMode::ReadWrite).unwrap();
        txn2.put(key("k2"), value("vvvvvv")).unwrap();
        let result = manager.commit(txn2);
        assert!(matches!(result, Err(Error::MemoryLimitExceeded { .. })));

        let mut read_txn = manager.begin_internal(TxnMode::ReadOnly).unwrap();
        let got = read_txn.get(&key("k1")).unwrap();
        assert_eq!(got, Some(value("vvvv")));

        let stats = manager.memory_stats();
        assert_eq!(stats.total_bytes, 6);
    }

    struct VecWriter(std::sync::Arc<std::sync::Mutex<Vec<u8>>>);

    impl std::io::Write for VecWriter {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            let mut guard = self.0.lock().unwrap();
            guard.extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    #[test]
    fn compaction_skips_when_over_limit_and_logs_warning() {
        let store = MemoryKV::new_with_limit(Some(12));
        let manager = store.txn_manager();

        let mut txn = manager.begin_internal(TxnMode::ReadWrite).unwrap();
        txn.put(key("k1"), value("123456")).unwrap();
        manager.commit(txn).unwrap();

        let buffer = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
        let make_writer = {
            let buf = buffer.clone();
            move || VecWriter(buf.clone())
        };
        let subscriber = tracing_subscriber::fmt()
            .with_max_level(Level::WARN)
            .with_writer(make_writer)
            .without_time()
            .finish();
        let _guard = tracing::subscriber::set_default(subscriber);

        let ran = manager.compact_with_limit(2, 10, || Ok(())).unwrap();
        assert!(!ran);

        assert_eq!(manager.memory_stats().total_bytes, 8);

        let log = String::from_utf8(buffer.lock().unwrap().clone()).unwrap();
        assert!(
            log.contains("compaction skipped due to memory limit"),
            "expected warning log, got: {}",
            log
        );
    }
}